From 1742302f83e1a0b3963924bc354d75ace228793e Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 21 Mar 2026 16:27:41 -0400 Subject: [PATCH] More tests and human review --- attachment/store.go | 2 + s3/client_multipart.go | 11 +++-- s3/client_test.go | 4 +- s3/types.go | 6 +-- s3/util.go | 4 +- s3/util_test.go | 4 +- util/limit_test.go | 100 ++++++++++++++++++++++++++++++++++++++++- 7 files changed, 115 insertions(+), 16 deletions(-) diff --git a/attachment/store.go b/attachment/store.go index 0192b09a..10dcd17b 100644 --- a/attachment/store.go +++ b/attachment/store.go @@ -110,9 +110,11 @@ func (c *Store) Remove(ids ...string) error { return errInvalidFileID } } + // Remove from backend if err := c.backend.Delete(ids...); err != nil { return err } + // Update total cache size c.mu.Lock() for _, id := range ids { if size, ok := c.sizes[id]; ok { diff --git a/s3/client_multipart.go b/s3/client_multipart.go index 5e98db38..198175d4 100644 --- a/s3/client_multipart.go +++ b/s3/client_multipart.go @@ -71,7 +71,7 @@ func (c *Client) listMultipartUploads(ctx context.Context) ([]*multipartUpload, keyMarker = result.NextKeyMarker uploadIDMarker = result.NextUploadIDMarker } - return nil, fmt.Errorf("s3: listMultipartUploads exceeded %d pages", maxPages) + return nil, fmt.Errorf("error listing multipart uploads, exceeded %d pages", maxPages) } // abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up. @@ -122,10 +122,9 @@ func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Rea } if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { break - } - if err != nil { + } else if err != nil { c.abortMultipartUpload(ctx, key, uploadID) - return fmt.Errorf("s3: PutObject read: %w", err) + return fmt.Errorf("error uploading object %s, reading from client failed: %w", key, err) } } @@ -172,7 +171,7 @@ func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID stri log.Tag(tagS3Client).Debug("Completing multipart upload for object %s, %d parts", key, len(parts)) bodyBytes, err := xml.Marshal(&completeMultipartUploadRequest{Parts: parts}) if err != nil { - return fmt.Errorf("s3: CompleteMultipartUpload marshal: %w", err) + return fmt.Errorf("error marshalling complete multipart upload request: %w", err) } reqURL := fmt.Sprintf("%s?uploadId=%s", c.config.ObjectURL(key), url.QueryEscape(uploadID)) respBody, err := c.do(ctx, "CompleteMultipartUpload", http.MethodPost, reqURL, bodyBytes, nil) @@ -180,7 +179,7 @@ func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID stri return err } // Check if the response contains an error (S3 can return 200 with an error body) - var errResp ErrorResponse + var errResp errorResponse if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" { return &errResp } diff --git a/s3/client_test.go b/s3/client_test.go index d267a6a8..84402831 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -456,7 +456,7 @@ func TestClient_GetObject_NotFound(t *testing.T) { _, _, err := client.GetObject(context.Background(), "nonexistent") require.Error(t, err) - var errResp *ErrorResponse + var errResp *errorResponse require.ErrorAs(t, err, &errResp) require.Equal(t, 404, errResp.StatusCode) require.Equal(t, "NoSuchKey", errResp.Code) @@ -799,7 +799,7 @@ func TestClient_RealBucket(t *testing.T) { // Get after delete should fail _, _, err = client.GetObject(ctx, key) require.Error(t, err) - var errResp *ErrorResponse + var errResp *errorResponse require.ErrorAs(t, err, &errResp) require.Equal(t, 404, errResp.StatusCode) }) diff --git a/s3/types.go b/s3/types.go index 1782b88d..96b62649 100644 --- a/s3/types.go +++ b/s3/types.go @@ -76,15 +76,15 @@ type Object struct { LastModified time.Time } -// ErrorResponse is returned when S3 responds with a non-2xx status code. -type ErrorResponse struct { +// errorResponse is returned when S3 responds with a non-2xx status code. +type errorResponse struct { StatusCode int Code string `xml:"Code"` Message string `xml:"Message"` Body string `xml:"-"` // raw response body } -func (e *ErrorResponse) Error() string { +func (e *errorResponse) Error() string { if e.Code != "" { return fmt.Sprintf("s3: %s (HTTP %d): %s", e.Code, e.StatusCode, e.Message) } diff --git a/s3/util.go b/s3/util.go index 06f7e3d1..0bcc96d2 100644 --- a/s3/util.go +++ b/s3/util.go @@ -84,7 +84,7 @@ func ParseURL(s3URL string) (*Config, error) { }, nil } -// parseError reads an S3 error response and returns an *ErrorResponse. +// parseError reads an S3 error response and returns an *errorResponse. func parseError(resp *http.Response) error { body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) if err != nil { @@ -94,7 +94,7 @@ func parseError(resp *http.Response) error { } func parseErrorFromBytes(statusCode int, body []byte) error { - errResp := &ErrorResponse{ + errResp := &errorResponse{ StatusCode: statusCode, Body: string(body), } diff --git a/s3/util_test.go b/s3/util_test.go index 3f08911d..93ddd707 100644 --- a/s3/util_test.go +++ b/s3/util_test.go @@ -163,7 +163,7 @@ func TestParseError_XMLResponse(t *testing.T) { xmlBody := []byte(`NoSuchKeyThe specified key does not exist.`) err := parseErrorFromBytes(404, xmlBody) - var errResp *ErrorResponse + var errResp *errorResponse require.ErrorAs(t, err, &errResp) require.Equal(t, 404, errResp.StatusCode) require.Equal(t, "NoSuchKey", errResp.Code) @@ -173,7 +173,7 @@ func TestParseError_XMLResponse(t *testing.T) { func TestParseError_NonXMLResponse(t *testing.T) { err := parseErrorFromBytes(500, []byte("internal server error")) - var errResp *ErrorResponse + var errResp *errorResponse require.ErrorAs(t, err, &errResp) require.Equal(t, 500, errResp.StatusCode) require.Equal(t, "", errResp.Code) // XML parsing failed, no code diff --git a/util/limit_test.go b/util/limit_test.go index 51595351..9ca9fe39 100644 --- a/util/limit_test.go +++ b/util/limit_test.go @@ -2,9 +2,12 @@ package util import ( "bytes" - "github.com/stretchr/testify/require" + "io" + "strings" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestFixedLimiter_AllowValueReset(t *testing.T) { @@ -147,3 +150,98 @@ func TestLimitWriter_WriteTwoDifferentLimiters_Wait_FixedLimiterFail(t *testing. _, err = lw.Write(make([]byte, 8)) // <<< FixedLimiter fails require.Equal(t, ErrLimitReached, err) } + +func TestCountingReader_Total(t *testing.T) { + cr := NewCountingReader(strings.NewReader("hello world")) + buf := make([]byte, 5) + + n, err := cr.Read(buf) + require.Nil(t, err) + require.Equal(t, 5, n) + require.Equal(t, int64(5), cr.Total()) + + n, err = cr.Read(buf) + require.Nil(t, err) + require.Equal(t, 5, n) + require.Equal(t, int64(10), cr.Total()) + + n, err = cr.Read(buf) + require.Nil(t, err) + require.Equal(t, 1, n) + require.Equal(t, int64(11), cr.Total()) + + _, err = cr.Read(buf) + require.Equal(t, io.EOF, err) + require.Equal(t, int64(11), cr.Total()) +} + +func TestCountingReader_Empty(t *testing.T) { + cr := NewCountingReader(strings.NewReader("")) + require.Equal(t, int64(0), cr.Total()) + + _, err := cr.Read(make([]byte, 10)) + require.Equal(t, io.EOF, err) + require.Equal(t, int64(0), cr.Total()) +} + +func TestLimitReader_ReadNoLimiter(t *testing.T) { + lr := NewLimitReader(strings.NewReader("hello")) + data, err := io.ReadAll(lr) + require.Nil(t, err) + require.Equal(t, "hello", string(data)) +} + +func TestLimitReader_ReadOneLimiter(t *testing.T) { + l := NewFixedLimiter(10) + lr := NewLimitReader(strings.NewReader("hello world!"), l) + + buf := make([]byte, 5) + n, err := lr.Read(buf) + require.Nil(t, err) + require.Equal(t, 5, n) + require.Equal(t, int64(5), l.Value()) + + n, err = lr.Read(buf) + require.Nil(t, err) + require.Equal(t, 5, n) + require.Equal(t, int64(10), l.Value()) + + _, err = lr.Read(buf) + require.Equal(t, ErrLimitReached, err) +} + +func TestLimitReader_ReadTwoLimiters(t *testing.T) { + l1 := NewFixedLimiter(11) + l2 := NewFixedLimiter(8) + lr := NewLimitReader(strings.NewReader("hello world!"), l1, l2) + + buf := make([]byte, 5) + n, err := lr.Read(buf) + require.Nil(t, err) + require.Equal(t, 5, n) + + // Second read: l2 (limit 8) should reject 5 more bytes + _, err = lr.Read(buf) + require.Equal(t, ErrLimitReached, err) + // l1 should have been reverted + require.Equal(t, int64(5), l1.Value()) + require.Equal(t, int64(5), l2.Value()) +} + +func TestLimitReader_ReadAll(t *testing.T) { + l := NewFixedLimiter(100) + lr := NewLimitReader(strings.NewReader("hello"), l) + data, err := io.ReadAll(lr) + require.Nil(t, err) + require.Equal(t, "hello", string(data)) + require.Equal(t, int64(5), l.Value()) +} + +func TestLimitReader_ReadExactLimit(t *testing.T) { + l := NewFixedLimiter(5) + lr := NewLimitReader(bytes.NewReader(make([]byte, 5)), l) + data, err := io.ReadAll(lr) + require.Nil(t, err) + require.Equal(t, 5, len(data)) + require.Equal(t, int64(5), l.Value()) +}