mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-05-15 07:35:49 -06:00
More tests and human review
This commit is contained in:
parent
393f730d11
commit
1742302f83
7 changed files with 115 additions and 16 deletions
|
|
@ -110,9 +110,11 @@ func (c *Store) Remove(ids ...string) error {
|
|||
return errInvalidFileID
|
||||
}
|
||||
}
|
||||
// Remove from backend
|
||||
if err := c.backend.Delete(ids...); err != nil {
|
||||
return err
|
||||
}
|
||||
// Update total cache size
|
||||
c.mu.Lock()
|
||||
for _, id := range ids {
|
||||
if size, ok := c.sizes[id]; ok {
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ func (c *Client) listMultipartUploads(ctx context.Context) ([]*multipartUpload,
|
|||
keyMarker = result.NextKeyMarker
|
||||
uploadIDMarker = result.NextUploadIDMarker
|
||||
}
|
||||
return nil, fmt.Errorf("s3: listMultipartUploads exceeded %d pages", maxPages)
|
||||
return nil, fmt.Errorf("error listing multipart uploads, exceeded %d pages", maxPages)
|
||||
}
|
||||
|
||||
// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
|
||||
|
|
@ -122,10 +122,9 @@ func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Rea
|
|||
}
|
||||
if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
} else if err != nil {
|
||||
c.abortMultipartUpload(ctx, key, uploadID)
|
||||
return fmt.Errorf("s3: PutObject read: %w", err)
|
||||
return fmt.Errorf("error uploading object %s, reading from client failed: %w", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -172,7 +171,7 @@ func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID stri
|
|||
log.Tag(tagS3Client).Debug("Completing multipart upload for object %s, %d parts", key, len(parts))
|
||||
bodyBytes, err := xml.Marshal(&completeMultipartUploadRequest{Parts: parts})
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: CompleteMultipartUpload marshal: %w", err)
|
||||
return fmt.Errorf("error marshalling complete multipart upload request: %w", err)
|
||||
}
|
||||
reqURL := fmt.Sprintf("%s?uploadId=%s", c.config.ObjectURL(key), url.QueryEscape(uploadID))
|
||||
respBody, err := c.do(ctx, "CompleteMultipartUpload", http.MethodPost, reqURL, bodyBytes, nil)
|
||||
|
|
@ -180,7 +179,7 @@ func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID stri
|
|||
return err
|
||||
}
|
||||
// Check if the response contains an error (S3 can return 200 with an error body)
|
||||
var errResp ErrorResponse
|
||||
var errResp errorResponse
|
||||
if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
|
||||
return &errResp
|
||||
}
|
||||
|
|
|
|||
|
|
@ -456,7 +456,7 @@ func TestClient_GetObject_NotFound(t *testing.T) {
|
|||
|
||||
_, _, err := client.GetObject(context.Background(), "nonexistent")
|
||||
require.Error(t, err)
|
||||
var errResp *ErrorResponse
|
||||
var errResp *errorResponse
|
||||
require.ErrorAs(t, err, &errResp)
|
||||
require.Equal(t, 404, errResp.StatusCode)
|
||||
require.Equal(t, "NoSuchKey", errResp.Code)
|
||||
|
|
@ -799,7 +799,7 @@ func TestClient_RealBucket(t *testing.T) {
|
|||
// Get after delete should fail
|
||||
_, _, err = client.GetObject(ctx, key)
|
||||
require.Error(t, err)
|
||||
var errResp *ErrorResponse
|
||||
var errResp *errorResponse
|
||||
require.ErrorAs(t, err, &errResp)
|
||||
require.Equal(t, 404, errResp.StatusCode)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -76,15 +76,15 @@ type Object struct {
|
|||
LastModified time.Time
|
||||
}
|
||||
|
||||
// ErrorResponse is returned when S3 responds with a non-2xx status code.
|
||||
type ErrorResponse struct {
|
||||
// errorResponse is returned when S3 responds with a non-2xx status code.
|
||||
type errorResponse struct {
|
||||
StatusCode int
|
||||
Code string `xml:"Code"`
|
||||
Message string `xml:"Message"`
|
||||
Body string `xml:"-"` // raw response body
|
||||
}
|
||||
|
||||
func (e *ErrorResponse) Error() string {
|
||||
func (e *errorResponse) Error() string {
|
||||
if e.Code != "" {
|
||||
return fmt.Sprintf("s3: %s (HTTP %d): %s", e.Code, e.StatusCode, e.Message)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ func ParseURL(s3URL string) (*Config, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// parseError reads an S3 error response and returns an *ErrorResponse.
|
||||
// parseError reads an S3 error response and returns an *errorResponse.
|
||||
func parseError(resp *http.Response) error {
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||
if err != nil {
|
||||
|
|
@ -94,7 +94,7 @@ func parseError(resp *http.Response) error {
|
|||
}
|
||||
|
||||
func parseErrorFromBytes(statusCode int, body []byte) error {
|
||||
errResp := &ErrorResponse{
|
||||
errResp := &errorResponse{
|
||||
StatusCode: statusCode,
|
||||
Body: string(body),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ func TestParseError_XMLResponse(t *testing.T) {
|
|||
xmlBody := []byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`)
|
||||
err := parseErrorFromBytes(404, xmlBody)
|
||||
|
||||
var errResp *ErrorResponse
|
||||
var errResp *errorResponse
|
||||
require.ErrorAs(t, err, &errResp)
|
||||
require.Equal(t, 404, errResp.StatusCode)
|
||||
require.Equal(t, "NoSuchKey", errResp.Code)
|
||||
|
|
@ -173,7 +173,7 @@ func TestParseError_XMLResponse(t *testing.T) {
|
|||
func TestParseError_NonXMLResponse(t *testing.T) {
|
||||
err := parseErrorFromBytes(500, []byte("internal server error"))
|
||||
|
||||
var errResp *ErrorResponse
|
||||
var errResp *errorResponse
|
||||
require.ErrorAs(t, err, &errResp)
|
||||
require.Equal(t, 500, errResp.StatusCode)
|
||||
require.Equal(t, "", errResp.Code) // XML parsing failed, no code
|
||||
|
|
|
|||
|
|
@ -2,9 +2,12 @@ package util
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFixedLimiter_AllowValueReset(t *testing.T) {
|
||||
|
|
@ -147,3 +150,98 @@ func TestLimitWriter_WriteTwoDifferentLimiters_Wait_FixedLimiterFail(t *testing.
|
|||
_, err = lw.Write(make([]byte, 8)) // <<< FixedLimiter fails
|
||||
require.Equal(t, ErrLimitReached, err)
|
||||
}
|
||||
|
||||
func TestCountingReader_Total(t *testing.T) {
|
||||
cr := NewCountingReader(strings.NewReader("hello world"))
|
||||
buf := make([]byte, 5)
|
||||
|
||||
n, err := cr.Read(buf)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 5, n)
|
||||
require.Equal(t, int64(5), cr.Total())
|
||||
|
||||
n, err = cr.Read(buf)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 5, n)
|
||||
require.Equal(t, int64(10), cr.Total())
|
||||
|
||||
n, err = cr.Read(buf)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, n)
|
||||
require.Equal(t, int64(11), cr.Total())
|
||||
|
||||
_, err = cr.Read(buf)
|
||||
require.Equal(t, io.EOF, err)
|
||||
require.Equal(t, int64(11), cr.Total())
|
||||
}
|
||||
|
||||
func TestCountingReader_Empty(t *testing.T) {
|
||||
cr := NewCountingReader(strings.NewReader(""))
|
||||
require.Equal(t, int64(0), cr.Total())
|
||||
|
||||
_, err := cr.Read(make([]byte, 10))
|
||||
require.Equal(t, io.EOF, err)
|
||||
require.Equal(t, int64(0), cr.Total())
|
||||
}
|
||||
|
||||
func TestLimitReader_ReadNoLimiter(t *testing.T) {
|
||||
lr := NewLimitReader(strings.NewReader("hello"))
|
||||
data, err := io.ReadAll(lr)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "hello", string(data))
|
||||
}
|
||||
|
||||
func TestLimitReader_ReadOneLimiter(t *testing.T) {
|
||||
l := NewFixedLimiter(10)
|
||||
lr := NewLimitReader(strings.NewReader("hello world!"), l)
|
||||
|
||||
buf := make([]byte, 5)
|
||||
n, err := lr.Read(buf)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 5, n)
|
||||
require.Equal(t, int64(5), l.Value())
|
||||
|
||||
n, err = lr.Read(buf)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 5, n)
|
||||
require.Equal(t, int64(10), l.Value())
|
||||
|
||||
_, err = lr.Read(buf)
|
||||
require.Equal(t, ErrLimitReached, err)
|
||||
}
|
||||
|
||||
func TestLimitReader_ReadTwoLimiters(t *testing.T) {
|
||||
l1 := NewFixedLimiter(11)
|
||||
l2 := NewFixedLimiter(8)
|
||||
lr := NewLimitReader(strings.NewReader("hello world!"), l1, l2)
|
||||
|
||||
buf := make([]byte, 5)
|
||||
n, err := lr.Read(buf)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 5, n)
|
||||
|
||||
// Second read: l2 (limit 8) should reject 5 more bytes
|
||||
_, err = lr.Read(buf)
|
||||
require.Equal(t, ErrLimitReached, err)
|
||||
// l1 should have been reverted
|
||||
require.Equal(t, int64(5), l1.Value())
|
||||
require.Equal(t, int64(5), l2.Value())
|
||||
}
|
||||
|
||||
func TestLimitReader_ReadAll(t *testing.T) {
|
||||
l := NewFixedLimiter(100)
|
||||
lr := NewLimitReader(strings.NewReader("hello"), l)
|
||||
data, err := io.ReadAll(lr)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "hello", string(data))
|
||||
require.Equal(t, int64(5), l.Value())
|
||||
}
|
||||
|
||||
func TestLimitReader_ReadExactLimit(t *testing.T) {
|
||||
l := NewFixedLimiter(5)
|
||||
lr := NewLimitReader(bytes.NewReader(make([]byte, 5)), l)
|
||||
data, err := io.ReadAll(lr)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 5, len(data))
|
||||
require.Equal(t, int64(5), l.Value())
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue