mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-05-15 07:35:49 -06:00
More tests and human review
This commit is contained in:
parent
393f730d11
commit
1742302f83
7 changed files with 115 additions and 16 deletions
|
|
@ -110,9 +110,11 @@ func (c *Store) Remove(ids ...string) error {
|
||||||
return errInvalidFileID
|
return errInvalidFileID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Remove from backend
|
||||||
if err := c.backend.Delete(ids...); err != nil {
|
if err := c.backend.Delete(ids...); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// Update total cache size
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
if size, ok := c.sizes[id]; ok {
|
if size, ok := c.sizes[id]; ok {
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ func (c *Client) listMultipartUploads(ctx context.Context) ([]*multipartUpload,
|
||||||
keyMarker = result.NextKeyMarker
|
keyMarker = result.NextKeyMarker
|
||||||
uploadIDMarker = result.NextUploadIDMarker
|
uploadIDMarker = result.NextUploadIDMarker
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("s3: listMultipartUploads exceeded %d pages", maxPages)
|
return nil, fmt.Errorf("error listing multipart uploads, exceeded %d pages", maxPages)
|
||||||
}
|
}
|
||||||
|
|
||||||
// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
|
// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
|
||||||
|
|
@ -122,10 +122,9 @@ func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Rea
|
||||||
}
|
}
|
||||||
if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
|
if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
break
|
break
|
||||||
}
|
} else if err != nil {
|
||||||
if err != nil {
|
|
||||||
c.abortMultipartUpload(ctx, key, uploadID)
|
c.abortMultipartUpload(ctx, key, uploadID)
|
||||||
return fmt.Errorf("s3: PutObject read: %w", err)
|
return fmt.Errorf("error uploading object %s, reading from client failed: %w", key, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -172,7 +171,7 @@ func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID stri
|
||||||
log.Tag(tagS3Client).Debug("Completing multipart upload for object %s, %d parts", key, len(parts))
|
log.Tag(tagS3Client).Debug("Completing multipart upload for object %s, %d parts", key, len(parts))
|
||||||
bodyBytes, err := xml.Marshal(&completeMultipartUploadRequest{Parts: parts})
|
bodyBytes, err := xml.Marshal(&completeMultipartUploadRequest{Parts: parts})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("s3: CompleteMultipartUpload marshal: %w", err)
|
return fmt.Errorf("error marshalling complete multipart upload request: %w", err)
|
||||||
}
|
}
|
||||||
reqURL := fmt.Sprintf("%s?uploadId=%s", c.config.ObjectURL(key), url.QueryEscape(uploadID))
|
reqURL := fmt.Sprintf("%s?uploadId=%s", c.config.ObjectURL(key), url.QueryEscape(uploadID))
|
||||||
respBody, err := c.do(ctx, "CompleteMultipartUpload", http.MethodPost, reqURL, bodyBytes, nil)
|
respBody, err := c.do(ctx, "CompleteMultipartUpload", http.MethodPost, reqURL, bodyBytes, nil)
|
||||||
|
|
@ -180,7 +179,7 @@ func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID stri
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Check if the response contains an error (S3 can return 200 with an error body)
|
// Check if the response contains an error (S3 can return 200 with an error body)
|
||||||
var errResp ErrorResponse
|
var errResp errorResponse
|
||||||
if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
|
if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
|
||||||
return &errResp
|
return &errResp
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -456,7 +456,7 @@ func TestClient_GetObject_NotFound(t *testing.T) {
|
||||||
|
|
||||||
_, _, err := client.GetObject(context.Background(), "nonexistent")
|
_, _, err := client.GetObject(context.Background(), "nonexistent")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
var errResp *ErrorResponse
|
var errResp *errorResponse
|
||||||
require.ErrorAs(t, err, &errResp)
|
require.ErrorAs(t, err, &errResp)
|
||||||
require.Equal(t, 404, errResp.StatusCode)
|
require.Equal(t, 404, errResp.StatusCode)
|
||||||
require.Equal(t, "NoSuchKey", errResp.Code)
|
require.Equal(t, "NoSuchKey", errResp.Code)
|
||||||
|
|
@ -799,7 +799,7 @@ func TestClient_RealBucket(t *testing.T) {
|
||||||
// Get after delete should fail
|
// Get after delete should fail
|
||||||
_, _, err = client.GetObject(ctx, key)
|
_, _, err = client.GetObject(ctx, key)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
var errResp *ErrorResponse
|
var errResp *errorResponse
|
||||||
require.ErrorAs(t, err, &errResp)
|
require.ErrorAs(t, err, &errResp)
|
||||||
require.Equal(t, 404, errResp.StatusCode)
|
require.Equal(t, 404, errResp.StatusCode)
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -76,15 +76,15 @@ type Object struct {
|
||||||
LastModified time.Time
|
LastModified time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrorResponse is returned when S3 responds with a non-2xx status code.
|
// errorResponse is returned when S3 responds with a non-2xx status code.
|
||||||
type ErrorResponse struct {
|
type errorResponse struct {
|
||||||
StatusCode int
|
StatusCode int
|
||||||
Code string `xml:"Code"`
|
Code string `xml:"Code"`
|
||||||
Message string `xml:"Message"`
|
Message string `xml:"Message"`
|
||||||
Body string `xml:"-"` // raw response body
|
Body string `xml:"-"` // raw response body
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ErrorResponse) Error() string {
|
func (e *errorResponse) Error() string {
|
||||||
if e.Code != "" {
|
if e.Code != "" {
|
||||||
return fmt.Sprintf("s3: %s (HTTP %d): %s", e.Code, e.StatusCode, e.Message)
|
return fmt.Sprintf("s3: %s (HTTP %d): %s", e.Code, e.StatusCode, e.Message)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@ func ParseURL(s3URL string) (*Config, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseError reads an S3 error response and returns an *ErrorResponse.
|
// parseError reads an S3 error response and returns an *errorResponse.
|
||||||
func parseError(resp *http.Response) error {
|
func parseError(resp *http.Response) error {
|
||||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -94,7 +94,7 @@ func parseError(resp *http.Response) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseErrorFromBytes(statusCode int, body []byte) error {
|
func parseErrorFromBytes(statusCode int, body []byte) error {
|
||||||
errResp := &ErrorResponse{
|
errResp := &errorResponse{
|
||||||
StatusCode: statusCode,
|
StatusCode: statusCode,
|
||||||
Body: string(body),
|
Body: string(body),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -163,7 +163,7 @@ func TestParseError_XMLResponse(t *testing.T) {
|
||||||
xmlBody := []byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`)
|
xmlBody := []byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`)
|
||||||
err := parseErrorFromBytes(404, xmlBody)
|
err := parseErrorFromBytes(404, xmlBody)
|
||||||
|
|
||||||
var errResp *ErrorResponse
|
var errResp *errorResponse
|
||||||
require.ErrorAs(t, err, &errResp)
|
require.ErrorAs(t, err, &errResp)
|
||||||
require.Equal(t, 404, errResp.StatusCode)
|
require.Equal(t, 404, errResp.StatusCode)
|
||||||
require.Equal(t, "NoSuchKey", errResp.Code)
|
require.Equal(t, "NoSuchKey", errResp.Code)
|
||||||
|
|
@ -173,7 +173,7 @@ func TestParseError_XMLResponse(t *testing.T) {
|
||||||
func TestParseError_NonXMLResponse(t *testing.T) {
|
func TestParseError_NonXMLResponse(t *testing.T) {
|
||||||
err := parseErrorFromBytes(500, []byte("internal server error"))
|
err := parseErrorFromBytes(500, []byte("internal server error"))
|
||||||
|
|
||||||
var errResp *ErrorResponse
|
var errResp *errorResponse
|
||||||
require.ErrorAs(t, err, &errResp)
|
require.ErrorAs(t, err, &errResp)
|
||||||
require.Equal(t, 500, errResp.StatusCode)
|
require.Equal(t, 500, errResp.StatusCode)
|
||||||
require.Equal(t, "", errResp.Code) // XML parsing failed, no code
|
require.Equal(t, "", errResp.Code) // XML parsing failed, no code
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,12 @@ package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"github.com/stretchr/testify/require"
|
"io"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFixedLimiter_AllowValueReset(t *testing.T) {
|
func TestFixedLimiter_AllowValueReset(t *testing.T) {
|
||||||
|
|
@ -147,3 +150,98 @@ func TestLimitWriter_WriteTwoDifferentLimiters_Wait_FixedLimiterFail(t *testing.
|
||||||
_, err = lw.Write(make([]byte, 8)) // <<< FixedLimiter fails
|
_, err = lw.Write(make([]byte, 8)) // <<< FixedLimiter fails
|
||||||
require.Equal(t, ErrLimitReached, err)
|
require.Equal(t, ErrLimitReached, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCountingReader_Total(t *testing.T) {
|
||||||
|
cr := NewCountingReader(strings.NewReader("hello world"))
|
||||||
|
buf := make([]byte, 5)
|
||||||
|
|
||||||
|
n, err := cr.Read(buf)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 5, n)
|
||||||
|
require.Equal(t, int64(5), cr.Total())
|
||||||
|
|
||||||
|
n, err = cr.Read(buf)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 5, n)
|
||||||
|
require.Equal(t, int64(10), cr.Total())
|
||||||
|
|
||||||
|
n, err = cr.Read(buf)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, n)
|
||||||
|
require.Equal(t, int64(11), cr.Total())
|
||||||
|
|
||||||
|
_, err = cr.Read(buf)
|
||||||
|
require.Equal(t, io.EOF, err)
|
||||||
|
require.Equal(t, int64(11), cr.Total())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCountingReader_Empty(t *testing.T) {
|
||||||
|
cr := NewCountingReader(strings.NewReader(""))
|
||||||
|
require.Equal(t, int64(0), cr.Total())
|
||||||
|
|
||||||
|
_, err := cr.Read(make([]byte, 10))
|
||||||
|
require.Equal(t, io.EOF, err)
|
||||||
|
require.Equal(t, int64(0), cr.Total())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLimitReader_ReadNoLimiter(t *testing.T) {
|
||||||
|
lr := NewLimitReader(strings.NewReader("hello"))
|
||||||
|
data, err := io.ReadAll(lr)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "hello", string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLimitReader_ReadOneLimiter(t *testing.T) {
|
||||||
|
l := NewFixedLimiter(10)
|
||||||
|
lr := NewLimitReader(strings.NewReader("hello world!"), l)
|
||||||
|
|
||||||
|
buf := make([]byte, 5)
|
||||||
|
n, err := lr.Read(buf)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 5, n)
|
||||||
|
require.Equal(t, int64(5), l.Value())
|
||||||
|
|
||||||
|
n, err = lr.Read(buf)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 5, n)
|
||||||
|
require.Equal(t, int64(10), l.Value())
|
||||||
|
|
||||||
|
_, err = lr.Read(buf)
|
||||||
|
require.Equal(t, ErrLimitReached, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLimitReader_ReadTwoLimiters(t *testing.T) {
|
||||||
|
l1 := NewFixedLimiter(11)
|
||||||
|
l2 := NewFixedLimiter(8)
|
||||||
|
lr := NewLimitReader(strings.NewReader("hello world!"), l1, l2)
|
||||||
|
|
||||||
|
buf := make([]byte, 5)
|
||||||
|
n, err := lr.Read(buf)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 5, n)
|
||||||
|
|
||||||
|
// Second read: l2 (limit 8) should reject 5 more bytes
|
||||||
|
_, err = lr.Read(buf)
|
||||||
|
require.Equal(t, ErrLimitReached, err)
|
||||||
|
// l1 should have been reverted
|
||||||
|
require.Equal(t, int64(5), l1.Value())
|
||||||
|
require.Equal(t, int64(5), l2.Value())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLimitReader_ReadAll(t *testing.T) {
|
||||||
|
l := NewFixedLimiter(100)
|
||||||
|
lr := NewLimitReader(strings.NewReader("hello"), l)
|
||||||
|
data, err := io.ReadAll(lr)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "hello", string(data))
|
||||||
|
require.Equal(t, int64(5), l.Value())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLimitReader_ReadExactLimit(t *testing.T) {
|
||||||
|
l := NewFixedLimiter(5)
|
||||||
|
lr := NewLimitReader(bytes.NewReader(make([]byte, 5)), l)
|
||||||
|
data, err := io.ReadAll(lr)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 5, len(data))
|
||||||
|
require.Equal(t, int64(5), l.Value())
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue