From b81218953a9ac89b57ad7e86e6d97de94c67b46e Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 21 Mar 2026 21:14:49 -0400 Subject: [PATCH] Allow streaming to S3 --- attachment/backend.go | 2 +- attachment/backend_file.go | 12 ++++- attachment/backend_s3.go | 4 +- attachment/store.go | 14 +++--- attachment/store_file_test.go | 46 ++++++++++++++---- attachment/store_s3_test.go | 22 ++++----- s3/client.go | 29 ++++++++---- s3/client_test.go | 87 +++++++++++++++++++++++++++++------ s3/util.go | 4 ++ server/server.go | 19 ++++---- server/server_test.go | 4 +- tools/s3cli/main.go | 8 +++- 12 files changed, 181 insertions(+), 70 deletions(-) diff --git a/attachment/backend.go b/attachment/backend.go index e95fc91e..921ceb3e 100644 --- a/attachment/backend.go +++ b/attachment/backend.go @@ -15,7 +15,7 @@ type object struct { // backend is a minimal I/O interface for storing and retrieving attachment files. // It has no knowledge of size tracking, limiting, or ID validation. type backend interface { - Put(id string, in io.Reader) error + Put(id string, reader io.Reader, untrustedLength int64) error Get(id string) (io.ReadCloser, int64, error) List() ([]object, error) Delete(ids ...string) error diff --git a/attachment/backend_file.go b/attachment/backend_file.go index 260236d1..8726ddf4 100644 --- a/attachment/backend_file.go +++ b/attachment/backend_file.go @@ -1,6 +1,7 @@ package attachment import ( + "fmt" "io" "os" "path/filepath" @@ -24,16 +25,23 @@ func newFileBackend(dir string) (*fileBackend, error) { return &fileBackend{dir: dir}, nil } -func (b *fileBackend) Put(id string, in io.Reader) error { +func (b *fileBackend) Put(id string, reader io.Reader, untrustedLength int64) error { + if untrustedLength > 0 { + reader = io.LimitReader(reader, untrustedLength) + } file := filepath.Join(b.dir, id) f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) if err != nil { return err } defer f.Close() - if _, err := io.Copy(f, in); err != nil { + n, err := io.Copy(f, reader) + if err != nil { os.Remove(file) return err + } else if untrustedLength > 0 && n != untrustedLength { + os.Remove(file) + return fmt.Errorf("content length mismatch: claimed %d, got %d", untrustedLength, n) } if err := f.Close(); err != nil { os.Remove(file) diff --git a/attachment/backend_s3.go b/attachment/backend_s3.go index 61d1c7b1..081d6002 100644 --- a/attachment/backend_s3.go +++ b/attachment/backend_s3.go @@ -24,8 +24,8 @@ func newS3Backend(client *s3.Client) *s3Backend { return &s3Backend{client: client} } -func (b *s3Backend) Put(id string, in io.Reader) error { - return b.client.PutObject(context.Background(), id, in) +func (b *s3Backend) Put(id string, reader io.Reader, untrustedLength int64) error { + return b.client.PutObject(context.Background(), id, reader, untrustedLength) } func (b *s3Backend) Get(id string) (io.ReadCloser, int64, error) { diff --git a/attachment/store.go b/attachment/store.go index ba2e22cc..6e7cfb99 100644 --- a/attachment/store.go +++ b/attachment/store.go @@ -72,20 +72,22 @@ func newStore(backend backend, totalSizeLimit int64, localIDs func() ([]string, } // Write stores an attachment file. The id is validated, and the write is subject to -// the total size limit and any additional limiters. -func (c *Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) { +// the total size limit and any additional limiters. The untrustedLength is a hint +// from the client's Content-Length header; backends may use it to optimize uploads (e.g. +// streaming directly to S3 without buffering). +func (c *Store) Write(id string, reader io.Reader, untrustedLength int64, limiters ...util.Limiter) (int64, error) { if !fileIDRegex.MatchString(id) { return 0, errInvalidFileID } log.Tag(tagStore).Field("message_id", id).Debug("Writing attachment") limiters = append(limiters, util.NewFixedLimiter(c.Remaining())) - cr := util.NewCountingReader(in) - lr := util.NewLimitReader(cr, limiters...) - if err := c.backend.Put(id, lr); err != nil { + countingReader := util.NewCountingReader(reader) + limitReader := util.NewLimitReader(countingReader, limiters...) + if err := c.backend.Put(id, limitReader, untrustedLength); err != nil { c.backend.Delete(id) //nolint:errcheck return 0, err } - size := cr.Total() + size := countingReader.Total() c.mu.Lock() c.size += size c.sizes[id] = size diff --git a/attachment/store_file_test.go b/attachment/store_file_test.go index c65bad92..998a2bea 100644 --- a/attachment/store_file_test.go +++ b/attachment/store_file_test.go @@ -19,7 +19,7 @@ var ( func TestFileStore_Write_Success(t *testing.T) { dir, c := newTestFileStore(t) - size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999)) + size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), 0, util.NewFixedLimiter(999)) require.Nil(t, err) require.Equal(t, int64(11), size) require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl")) @@ -29,7 +29,7 @@ func TestFileStore_Write_Success(t *testing.T) { func TestFileStore_Write_Read_Success(t *testing.T) { _, c := newTestFileStore(t) - size, err := c.Write("abcdefghijkl", strings.NewReader("hello world")) + size, err := c.Write("abcdefghijkl", strings.NewReader("hello world"), 0) require.Nil(t, err) require.Equal(t, int64(11), size) @@ -45,7 +45,7 @@ func TestFileStore_Write_Read_Success(t *testing.T) { func TestFileStore_Write_Remove_Success(t *testing.T) { dir, c := newTestFileStore(t) // max = 10k (10240), each = 1k (1024) for i := 0; i < 10; i++ { // 10x999 = 9990 - size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999))) + size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)), 0) require.Nil(t, err) require.Equal(t, int64(999), size) } @@ -64,22 +64,48 @@ func TestFileStore_Write_Remove_Success(t *testing.T) { func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) { dir, c := newTestFileStore(t) for i := 0; i < 10; i++ { - size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray)) + size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray), 0) require.Nil(t, err) require.Equal(t, int64(1024), size) } - _, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray)) + _, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray), 0) require.Equal(t, util.ErrLimitReached, err) require.NoFileExists(t, dir+"/abcdefghijkX") } func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) { dir, c := newTestFileStore(t) - _, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000)) + _, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), 0, util.NewFixedLimiter(1000)) require.Equal(t, util.ErrLimitReached, err) require.NoFileExists(t, dir+"/abcdefghijkl") } +func TestFileStore_Write_UntrustedContentLengthExact(t *testing.T) { + dir, c := newTestFileStore(t) + size, err := c.Write("abcdefghijkl", strings.NewReader("hello world"), 11) + require.Nil(t, err) + require.Equal(t, int64(11), size) + require.Equal(t, "hello world", readFile(t, dir+"/abcdefghijkl")) +} + +func TestFileStore_Write_UntrustedContentLengthBodyLonger(t *testing.T) { + dir, c := newTestFileStore(t) + // Body has 11 bytes, but we claim 5 — only first 5 bytes should be stored + size, err := c.Write("abcdefghijkl", strings.NewReader("hello world"), 5) + require.Nil(t, err) + require.Equal(t, int64(5), size) + require.Equal(t, "hello", readFile(t, dir+"/abcdefghijkl")) +} + +func TestFileStore_Write_UntrustedContentLengthBodyShorter(t *testing.T) { + dir, c := newTestFileStore(t) + // Body has 5 bytes, but we claim 100 — should fail with content length mismatch + _, err := c.Write("abcdefghijkl", strings.NewReader("hello"), 100) + require.Error(t, err) + require.Contains(t, err.Error(), "content length mismatch") + require.NoFileExists(t, dir+"/abcdefghijkl") +} + func TestFileStore_Read_NotFound(t *testing.T) { _, c := newTestFileStore(t) _, _, err := c.Read("abcdefghijkl") @@ -90,11 +116,11 @@ func TestFileStore_Sync(t *testing.T) { dir, c := newTestFileStore(t) // Write some files - _, err := c.Write("abcdefghijk0", strings.NewReader("file0")) + _, err := c.Write("abcdefghijk0", strings.NewReader("file0"), 0) require.Nil(t, err) - _, err = c.Write("abcdefghijk1", strings.NewReader("file1")) + _, err = c.Write("abcdefghijk1", strings.NewReader("file1"), 0) require.Nil(t, err) - _, err = c.Write("abcdefghijk2", strings.NewReader("file2")) + _, err = c.Write("abcdefghijk2", strings.NewReader("file2"), 0) require.Nil(t, err) require.Equal(t, int64(15), c.Size()) @@ -124,7 +150,7 @@ func TestFileStore_Sync_SkipsRecentFiles(t *testing.T) { dir, c := newTestFileStore(t) // Write a file - _, err := c.Write("abcdefghijk0", strings.NewReader("file0")) + _, err := c.Write("abcdefghijk0", strings.NewReader("file0"), 0) require.Nil(t, err) // Set the ID provider to return empty (no valid IDs) diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go index 3ad5a93c..37bd0ecb 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -26,7 +26,7 @@ func TestS3Store_WriteReadRemove(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) // Write - size, err := cache.Write("abcdefghijkl", strings.NewReader("hello world")) + size, err := cache.Write("abcdefghijkl", strings.NewReader("hello world"), 0) require.Nil(t, err) require.Equal(t, int64(11), size) require.Equal(t, int64(11), cache.Size()) @@ -55,7 +55,7 @@ func TestS3Store_WriteNoPrefix(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "", 10*1024) - size, err := cache.Write("abcdefghijkl", strings.NewReader("test")) + size, err := cache.Write("abcdefghijkl", strings.NewReader("test"), 0) require.Nil(t, err) require.Equal(t, int64(4), size) @@ -74,13 +74,13 @@ func TestS3Store_WriteTotalSizeLimit(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "pfx", 100) // First write fits - _, err := cache.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80))) + _, err := cache.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)), 0) require.Nil(t, err) require.Equal(t, int64(80), cache.Size()) require.Equal(t, int64(20), cache.Remaining()) // Second write exceeds total limit - _, err = cache.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50))) + _, err = cache.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)), 0) require.ErrorIs(t, err, util.ErrLimitReached) } @@ -90,7 +90,7 @@ func TestS3Store_WriteFileSizeLimit(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - _, err := cache.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100)) + _, err := cache.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), 0, util.NewFixedLimiter(100)) require.ErrorIs(t, err, util.ErrLimitReached) } @@ -101,7 +101,7 @@ func TestS3Store_WriteRemoveMultiple(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) for i := 0; i < 5; i++ { - _, err := cache.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100))) + _, err := cache.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)), 0) require.Nil(t, err) } require.Equal(t, int64(500), cache.Size()) @@ -126,7 +126,7 @@ func TestS3Store_InvalidID(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - _, err := cache.Write("bad", strings.NewReader("x")) + _, err := cache.Write("bad", strings.NewReader("x"), 0) require.Equal(t, errInvalidFileID, err) _, _, err = cache.Read("bad") @@ -143,11 +143,11 @@ func TestS3Store_Sync(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) // Write some files - _, err := cache.Write("abcdefghijk0", strings.NewReader("file0")) + _, err := cache.Write("abcdefghijk0", strings.NewReader("file0"), 0) require.Nil(t, err) - _, err = cache.Write("abcdefghijk1", strings.NewReader("file1")) + _, err = cache.Write("abcdefghijk1", strings.NewReader("file1"), 0) require.Nil(t, err) - _, err = cache.Write("abcdefghijk2", strings.NewReader("file2")) + _, err = cache.Write("abcdefghijk2", strings.NewReader("file2"), 0) require.Nil(t, err) require.Equal(t, int64(15), cache.Size()) @@ -175,7 +175,7 @@ func TestS3Store_Sync_SkipsRecentFiles(t *testing.T) { cache := newTestS3Store(t, mockServer, "my-bucket", "pfx", 10*1024) - _, err := cache.Write("abcdefghijk0", strings.NewReader("file0")) + _, err := cache.Write("abcdefghijk0", strings.NewReader("file0"), 0) require.Nil(t, err) // Set the ID provider to return empty (no valid IDs) diff --git a/s3/client.go b/s3/client.go index 83d8195e..29cad3a4 100644 --- a/s3/client.go +++ b/s3/client.go @@ -45,25 +45,36 @@ func New(config *Config) *Client { } // PutObject uploads body to the given key. The key is automatically prefixed with the client's -// configured prefix. The body size does not need to be known in advance. +// configured prefix. // -// If the entire body fits in a single part (5 MB), it is uploaded with a simple PUT request -// (https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html). Otherwise, the body -// is uploaded using S3 multipart upload, reading one part at a time into memory -// (https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateMultipartUpload.html). -func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error { +// If untrustedLength is between 1 and 5 GB, the body is streamed directly to S3 via a +// single PUT request without buffering. The read is limited to untrustedLength bytes; +// any extra data in the body is ignored. If the body is shorter than claimed, the upload fails. +// +// Otherwise (untrustedLength <= 0 or > 5 GB), the first 5 MB are buffered to decide +// between a simple PUT and multipart upload. +// +// See https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html +// and https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateMultipartUpload.html +func (c *Client) PutObject(ctx context.Context, key string, body io.Reader, untrustedLength int64) error { + if untrustedLength > 0 && untrustedLength <= maxSinglePutSize { + // Stream directly: Content-Length is known (but untrusted). LimitReader ensures we send at most + // untrustedLength bytes, and any extra data in body is ignored. + return c.putObject(ctx, key, io.LimitReader(body, untrustedLength), untrustedLength) + } + // Buffered path: read first 5 MB to decide simple vs multipart first := make([]byte, partSize) n, err := io.ReadFull(body, first) if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF { - return c.putObjectSimple(ctx, key, bytes.NewReader(first[:n]), int64(n)) + return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n)) } else if err != nil { return fmt.Errorf("error reading object %s from client: %w", key, err) } return c.putObjectMultipart(ctx, key, io.MultiReader(bytes.NewReader(first), body)) } -// putObjectSimple uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD. -func (c *Client) putObjectSimple(ctx context.Context, key string, body io.Reader, size int64) error { +// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD. +func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error { log.Tag(tagS3Client).Debug("Uploading object %s (%d bytes)", key, size) req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.config.ObjectURL(key), body) if err != nil { diff --git a/s3/client_test.go b/s3/client_test.go index 84402831..652db3e7 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -419,7 +419,7 @@ func TestClient_PutGetObject(t *testing.T) { ctx := context.Background() // Put - err := client.PutObject(ctx, "test-key", strings.NewReader("hello world")) + err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"), 0) require.Nil(t, err) // Get @@ -439,7 +439,7 @@ func TestClient_PutGetObject_WithPrefix(t *testing.T) { ctx := context.Background() - err := client.PutObject(ctx, "test-key", strings.NewReader("hello")) + err := client.PutObject(ctx, "test-key", strings.NewReader("hello"), 0) require.Nil(t, err) reader, _, err := client.GetObject(ctx, "test-key") @@ -471,7 +471,7 @@ func TestClient_DeleteObjects(t *testing.T) { // Put several objects for i := 0; i < 5; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data"))) + err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")), 0) require.Nil(t, err) } require.Equal(t, 5, mock.objectCount()) @@ -502,13 +502,13 @@ func TestClient_ListObjects(t *testing.T) { // Client with prefix "pfx": list should only return objects under pfx/ client := newTestClient(server, "my-bucket", "pfx") for i := 0; i < 3; i++ { - err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x"))) + err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")), 0) require.Nil(t, err) } // Also put an object outside the prefix using a no-prefix client clientNoPrefix := newTestClient(server, "my-bucket", "") - err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y"))) + err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")), 0) require.Nil(t, err) // List with prefix client: should only see 3 @@ -532,7 +532,7 @@ func TestClient_ListObjects_Pagination(t *testing.T) { // Put 5 objects for i := 0; i < 5; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x"))) + err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 0) require.Nil(t, err) } @@ -564,7 +564,7 @@ func TestClient_ListAllObjects(t *testing.T) { ctx := context.Background() for i := 0; i < 10; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x"))) + err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 0) require.Nil(t, err) } @@ -585,7 +585,7 @@ func TestClient_PutObject_LargeBody(t *testing.T) { for i := range data { data[i] = byte(i % 256) } - err := client.PutObject(ctx, "large", bytes.NewReader(data)) + err := client.PutObject(ctx, "large", bytes.NewReader(data), 0) require.Nil(t, err) reader, size, err := client.GetObject(ctx, "large") @@ -609,7 +609,7 @@ func TestClient_PutObject_ChunkedUpload(t *testing.T) { for i := range data { data[i] = byte(i % 256) } - err := client.PutObject(ctx, "multipart", bytes.NewReader(data)) + err := client.PutObject(ctx, "multipart", bytes.NewReader(data), 0) require.Nil(t, err) reader, size, err := client.GetObject(ctx, "multipart") @@ -633,7 +633,7 @@ func TestClient_PutObject_ExactPartSize(t *testing.T) { for i := range data { data[i] = byte(i % 256) } - err := client.PutObject(ctx, "exact", bytes.NewReader(data)) + err := client.PutObject(ctx, "exact", bytes.NewReader(data), 0) require.Nil(t, err) reader, size, err := client.GetObject(ctx, "exact") @@ -645,6 +645,63 @@ func TestClient_PutObject_ExactPartSize(t *testing.T) { require.Equal(t, data, got) } +func TestClient_PutObject_StreamingExactLength(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "pfx") + + ctx := context.Background() + + // untrustedLength matches body exactly — streams directly via putObject + err := client.PutObject(ctx, "stream-exact", strings.NewReader("hello world"), 11) + require.Nil(t, err) + + reader, size, err := client.GetObject(ctx, "stream-exact") + require.Nil(t, err) + require.Equal(t, int64(11), size) + got, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, "hello world", string(got)) +} + +func TestClient_PutObject_StreamingBodyLongerThanClaimed(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "pfx") + + ctx := context.Background() + + // Body has 11 bytes, but we claim 5 — only first 5 bytes should be stored + err := client.PutObject(ctx, "stream-long", strings.NewReader("hello world"), 5) + require.Nil(t, err) + + reader, size, err := client.GetObject(ctx, "stream-long") + require.Nil(t, err) + require.Equal(t, int64(5), size) + got, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, "hello", string(got)) +} + +func TestClient_PutObject_StreamingBodyShorterThanClaimed(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "pfx") + + ctx := context.Background() + + // Body has 5 bytes, but we claim 100 — should fail + err := client.PutObject(ctx, "stream-short", strings.NewReader("hello"), 100) + require.Error(t, err) + require.Contains(t, err.Error(), "ContentLength") + + // Object should not exist + _, _, err = client.GetObject(ctx, "stream-short") + require.Error(t, err) +} + func TestClient_PutObject_NestedKey(t *testing.T) { server, _ := newMockS3Server() defer server.Close() @@ -652,7 +709,7 @@ func TestClient_PutObject_NestedKey(t *testing.T) { ctx := context.Background() - err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested")) + err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"), 0) require.Nil(t, err) reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt") @@ -682,7 +739,7 @@ func TestClient_ListAllObjects_20k(t *testing.T) { for i := 0; i < batchSize; i++ { idx := batch*batchSize + i key := fmt.Sprintf("%08d", idx) - err := client.PutObject(ctx, key, bytes.NewReader([]byte("x"))) + err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")), 0) require.Nil(t, err) } } @@ -780,7 +837,7 @@ func TestClient_RealBucket(t *testing.T) { content := "hello from ntfy s3 test" // Put - err := client.PutObject(ctx, key, strings.NewReader(content)) + err := client.PutObject(ctx, key, strings.NewReader(content), 0) require.Nil(t, err) // Get @@ -818,7 +875,7 @@ func TestClient_RealBucket(t *testing.T) { // Put 10 objects for i := 0; i < 10; i++ { - err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x")) + err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"), 0) require.Nil(t, err) } @@ -843,7 +900,7 @@ func TestClient_RealBucket(t *testing.T) { data[i] = byte(i % 256) } - err := client.PutObject(ctx, key, bytes.NewReader(data)) + err := client.PutObject(ctx, key, bytes.NewReader(data), 0) require.Nil(t, err) reader, size, err := client.GetObject(ctx, key) diff --git a/s3/util.go b/s3/util.go index 0bcc96d2..1f4c2dd9 100644 --- a/s3/util.go +++ b/s3/util.go @@ -28,6 +28,10 @@ const ( // part size of 5 MB for all parts except the last. partSize = 5 * 1024 * 1024 + // maxSinglePutSize is the maximum size for a single PUT upload (5 GB). + // Objects larger than this must use multipart upload. + maxSinglePutSize = 5 * 1024 * 1024 * 1024 + // maxPages is the max number of pages to iterate through when listing objects maxPages = 500 ) diff --git a/server/server.go b/server/server.go index 99a61906..87eee5d6 100644 --- a/server/server.go +++ b/server/server.go @@ -1432,16 +1432,13 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Me if m.Time > attachmentExpiry { return errHTTPBadRequestAttachmentsExpiryBeforeDelivery.With(m) } - contentLengthStr := r.Header.Get("Content-Length") - if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below - contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64) - if err == nil && (contentLength > vinfo.Stats.AttachmentTotalSizeRemaining || contentLength > vinfo.Limits.AttachmentFileSizeLimit) { - return errHTTPEntityTooLargeAttachment.With(m).Fields(log.Context{ - "message_content_length": contentLength, - "attachment_total_size_remaining": vinfo.Stats.AttachmentTotalSizeRemaining, - "attachment_file_size_limit": vinfo.Limits.AttachmentFileSizeLimit, - }) - } + // Early "do-not-trust" check, hard limit see below + if r.ContentLength > 0 && (r.ContentLength > vinfo.Stats.AttachmentTotalSizeRemaining || r.ContentLength > vinfo.Limits.AttachmentFileSizeLimit) { + return errHTTPEntityTooLargeAttachment.With(m).Fields(log.Context{ + "message_content_length": r.ContentLength, + "attachment_total_size_remaining": vinfo.Stats.AttachmentTotalSizeRemaining, + "attachment_file_size_limit": vinfo.Limits.AttachmentFileSizeLimit, + }) } if m.Attachment == nil { m.Attachment = &model.Attachment{} @@ -1461,7 +1458,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Me util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit), util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining), } - m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...) + m.Attachment.Size, err = s.fileCache.Write(m.ID, body, r.ContentLength, limiters...) if errors.Is(err, util.ErrLimitReached) { return errHTTPEntityTooLargeAttachment.With(m) } else if err != nil { diff --git a/server/server_test.go b/server/server_test.go index cb20cbda..449b6006 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2218,8 +2218,8 @@ func TestServer_PublishAttachmentTooLargeContentLength(t *testing.T) { forEachBackend(t, func(t *testing.T, databaseURL string) { content := util.RandomString(5000) // > 4096 s := newTestServer(t, newTestConfig(t, databaseURL)) - response := request(t, s, "PUT", "/mytopic", content, map[string]string{ - "Content-Length": "20000000", + response := request(t, s, "PUT", "/mytopic", content, nil, func(r *http.Request) { + r.ContentLength = 20000000 }) err := toHTTPError(t, response.Body.String()) require.Equal(t, 413, response.Code) diff --git a/tools/s3cli/main.go b/tools/s3cli/main.go index 0e640823..5de8a75c 100644 --- a/tools/s3cli/main.go +++ b/tools/s3cli/main.go @@ -58,6 +58,7 @@ func cmdPut(ctx context.Context, client *s3.Client) { path := os.Args[3] var r io.Reader + var size int64 if path == "-" { r = os.Stdin } else { @@ -66,10 +67,15 @@ func cmdPut(ctx context.Context, client *s3.Client) { fail("open %s: %s", path, err) } defer f.Close() + stat, err := f.Stat() + if err != nil { + fail("stat %s: %s", path, err) + } r = f + size = stat.Size() } - if err := client.PutObject(ctx, key, r); err != nil { + if err := client.PutObject(ctx, key, r, size); err != nil { fail("put: %s", err) } fmt.Fprintf(os.Stderr, "uploaded %s\n", key)