More tests and human review

This commit is contained in:
binwiederhier 2026-03-21 16:27:41 -04:00
parent 393f730d11
commit 1742302f83
7 changed files with 115 additions and 16 deletions

View file

@ -110,9 +110,11 @@ func (c *Store) Remove(ids ...string) error {
return errInvalidFileID
}
}
// Remove from backend
if err := c.backend.Delete(ids...); err != nil {
return err
}
// Update total cache size
c.mu.Lock()
for _, id := range ids {
if size, ok := c.sizes[id]; ok {

View file

@ -71,7 +71,7 @@ func (c *Client) listMultipartUploads(ctx context.Context) ([]*multipartUpload,
keyMarker = result.NextKeyMarker
uploadIDMarker = result.NextUploadIDMarker
}
return nil, fmt.Errorf("s3: listMultipartUploads exceeded %d pages", maxPages)
return nil, fmt.Errorf("error listing multipart uploads, exceeded %d pages", maxPages)
}
// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
@ -122,10 +122,9 @@ func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Rea
}
if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
break
}
if err != nil {
} else if err != nil {
c.abortMultipartUpload(ctx, key, uploadID)
return fmt.Errorf("s3: PutObject read: %w", err)
return fmt.Errorf("error uploading object %s, reading from client failed: %w", key, err)
}
}
@ -172,7 +171,7 @@ func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID stri
log.Tag(tagS3Client).Debug("Completing multipart upload for object %s, %d parts", key, len(parts))
bodyBytes, err := xml.Marshal(&completeMultipartUploadRequest{Parts: parts})
if err != nil {
return fmt.Errorf("s3: CompleteMultipartUpload marshal: %w", err)
return fmt.Errorf("error marshalling complete multipart upload request: %w", err)
}
reqURL := fmt.Sprintf("%s?uploadId=%s", c.config.ObjectURL(key), url.QueryEscape(uploadID))
respBody, err := c.do(ctx, "CompleteMultipartUpload", http.MethodPost, reqURL, bodyBytes, nil)
@ -180,7 +179,7 @@ func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID stri
return err
}
// Check if the response contains an error (S3 can return 200 with an error body)
var errResp ErrorResponse
var errResp errorResponse
if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
return &errResp
}

View file

@ -456,7 +456,7 @@ func TestClient_GetObject_NotFound(t *testing.T) {
_, _, err := client.GetObject(context.Background(), "nonexistent")
require.Error(t, err)
var errResp *ErrorResponse
var errResp *errorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 404, errResp.StatusCode)
require.Equal(t, "NoSuchKey", errResp.Code)
@ -799,7 +799,7 @@ func TestClient_RealBucket(t *testing.T) {
// Get after delete should fail
_, _, err = client.GetObject(ctx, key)
require.Error(t, err)
var errResp *ErrorResponse
var errResp *errorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 404, errResp.StatusCode)
})

View file

@ -76,15 +76,15 @@ type Object struct {
LastModified time.Time
}
// ErrorResponse is returned when S3 responds with a non-2xx status code.
type ErrorResponse struct {
// errorResponse is returned when S3 responds with a non-2xx status code.
type errorResponse struct {
StatusCode int
Code string `xml:"Code"`
Message string `xml:"Message"`
Body string `xml:"-"` // raw response body
}
func (e *ErrorResponse) Error() string {
func (e *errorResponse) Error() string {
if e.Code != "" {
return fmt.Sprintf("s3: %s (HTTP %d): %s", e.Code, e.StatusCode, e.Message)
}

View file

@ -84,7 +84,7 @@ func ParseURL(s3URL string) (*Config, error) {
}, nil
}
// parseError reads an S3 error response and returns an *ErrorResponse.
// parseError reads an S3 error response and returns an *errorResponse.
func parseError(resp *http.Response) error {
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
@ -94,7 +94,7 @@ func parseError(resp *http.Response) error {
}
func parseErrorFromBytes(statusCode int, body []byte) error {
errResp := &ErrorResponse{
errResp := &errorResponse{
StatusCode: statusCode,
Body: string(body),
}

View file

@ -163,7 +163,7 @@ func TestParseError_XMLResponse(t *testing.T) {
xmlBody := []byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`)
err := parseErrorFromBytes(404, xmlBody)
var errResp *ErrorResponse
var errResp *errorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 404, errResp.StatusCode)
require.Equal(t, "NoSuchKey", errResp.Code)
@ -173,7 +173,7 @@ func TestParseError_XMLResponse(t *testing.T) {
func TestParseError_NonXMLResponse(t *testing.T) {
err := parseErrorFromBytes(500, []byte("internal server error"))
var errResp *ErrorResponse
var errResp *errorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 500, errResp.StatusCode)
require.Equal(t, "", errResp.Code) // XML parsing failed, no code

View file

@ -2,9 +2,12 @@ package util
import (
"bytes"
"github.com/stretchr/testify/require"
"io"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestFixedLimiter_AllowValueReset(t *testing.T) {
@ -147,3 +150,98 @@ func TestLimitWriter_WriteTwoDifferentLimiters_Wait_FixedLimiterFail(t *testing.
_, err = lw.Write(make([]byte, 8)) // <<< FixedLimiter fails
require.Equal(t, ErrLimitReached, err)
}
func TestCountingReader_Total(t *testing.T) {
cr := NewCountingReader(strings.NewReader("hello world"))
buf := make([]byte, 5)
n, err := cr.Read(buf)
require.Nil(t, err)
require.Equal(t, 5, n)
require.Equal(t, int64(5), cr.Total())
n, err = cr.Read(buf)
require.Nil(t, err)
require.Equal(t, 5, n)
require.Equal(t, int64(10), cr.Total())
n, err = cr.Read(buf)
require.Nil(t, err)
require.Equal(t, 1, n)
require.Equal(t, int64(11), cr.Total())
_, err = cr.Read(buf)
require.Equal(t, io.EOF, err)
require.Equal(t, int64(11), cr.Total())
}
func TestCountingReader_Empty(t *testing.T) {
cr := NewCountingReader(strings.NewReader(""))
require.Equal(t, int64(0), cr.Total())
_, err := cr.Read(make([]byte, 10))
require.Equal(t, io.EOF, err)
require.Equal(t, int64(0), cr.Total())
}
func TestLimitReader_ReadNoLimiter(t *testing.T) {
lr := NewLimitReader(strings.NewReader("hello"))
data, err := io.ReadAll(lr)
require.Nil(t, err)
require.Equal(t, "hello", string(data))
}
func TestLimitReader_ReadOneLimiter(t *testing.T) {
l := NewFixedLimiter(10)
lr := NewLimitReader(strings.NewReader("hello world!"), l)
buf := make([]byte, 5)
n, err := lr.Read(buf)
require.Nil(t, err)
require.Equal(t, 5, n)
require.Equal(t, int64(5), l.Value())
n, err = lr.Read(buf)
require.Nil(t, err)
require.Equal(t, 5, n)
require.Equal(t, int64(10), l.Value())
_, err = lr.Read(buf)
require.Equal(t, ErrLimitReached, err)
}
func TestLimitReader_ReadTwoLimiters(t *testing.T) {
l1 := NewFixedLimiter(11)
l2 := NewFixedLimiter(8)
lr := NewLimitReader(strings.NewReader("hello world!"), l1, l2)
buf := make([]byte, 5)
n, err := lr.Read(buf)
require.Nil(t, err)
require.Equal(t, 5, n)
// Second read: l2 (limit 8) should reject 5 more bytes
_, err = lr.Read(buf)
require.Equal(t, ErrLimitReached, err)
// l1 should have been reverted
require.Equal(t, int64(5), l1.Value())
require.Equal(t, int64(5), l2.Value())
}
func TestLimitReader_ReadAll(t *testing.T) {
l := NewFixedLimiter(100)
lr := NewLimitReader(strings.NewReader("hello"), l)
data, err := io.ReadAll(lr)
require.Nil(t, err)
require.Equal(t, "hello", string(data))
require.Equal(t, int64(5), l.Value())
}
func TestLimitReader_ReadExactLimit(t *testing.T) {
l := NewFixedLimiter(5)
lr := NewLimitReader(bytes.NewReader(make([]byte, 5)), l)
data, err := io.ReadAll(lr)
require.Nil(t, err)
require.Equal(t, 5, len(data))
require.Equal(t, int64(5), l.Value())
}