Allow streaming to S3

This commit is contained in:
binwiederhier 2026-03-21 21:14:49 -04:00
parent b3a8f18019
commit b81218953a
12 changed files with 181 additions and 70 deletions

View file

@ -15,7 +15,7 @@ type object struct {
// backend is a minimal I/O interface for storing and retrieving attachment files.
// It has no knowledge of size tracking, limiting, or ID validation.
type backend interface {
Put(id string, in io.Reader) error
Put(id string, reader io.Reader, untrustedLength int64) error
Get(id string) (io.ReadCloser, int64, error)
List() ([]object, error)
Delete(ids ...string) error

View file

@ -1,6 +1,7 @@
package attachment
import (
"fmt"
"io"
"os"
"path/filepath"
@ -24,16 +25,23 @@ func newFileBackend(dir string) (*fileBackend, error) {
return &fileBackend{dir: dir}, nil
}
func (b *fileBackend) Put(id string, in io.Reader) error {
func (b *fileBackend) Put(id string, reader io.Reader, untrustedLength int64) error {
if untrustedLength > 0 {
reader = io.LimitReader(reader, untrustedLength)
}
file := filepath.Join(b.dir, id)
f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer f.Close()
if _, err := io.Copy(f, in); err != nil {
n, err := io.Copy(f, reader)
if err != nil {
os.Remove(file)
return err
} else if untrustedLength > 0 && n != untrustedLength {
os.Remove(file)
return fmt.Errorf("content length mismatch: claimed %d, got %d", untrustedLength, n)
}
if err := f.Close(); err != nil {
os.Remove(file)

View file

@ -24,8 +24,8 @@ func newS3Backend(client *s3.Client) *s3Backend {
return &s3Backend{client: client}
}
func (b *s3Backend) Put(id string, in io.Reader) error {
return b.client.PutObject(context.Background(), id, in)
// Put uploads the attachment with the given ID to S3. The untrustedLength is the
// client-claimed (unverified) byte count; it is forwarded to the S3 client, which
// uses it to stream the body directly when possible.
func (b *s3Backend) Put(id string, reader io.Reader, untrustedLength int64) error {
	ctx := context.Background()
	return b.client.PutObject(ctx, id, reader, untrustedLength)
}
func (b *s3Backend) Get(id string) (io.ReadCloser, int64, error) {

View file

@ -72,20 +72,22 @@ func newStore(backend backend, totalSizeLimit int64, localIDs func() ([]string,
}
// Write stores an attachment file. The id is validated, and the write is subject to
// the total size limit and any additional limiters.
func (c *Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
// the total size limit and any additional limiters. The untrustedLength is a hint
// from the client's Content-Length header; backends may use it to optimize uploads (e.g.
// streaming directly to S3 without buffering).
func (c *Store) Write(id string, reader io.Reader, untrustedLength int64, limiters ...util.Limiter) (int64, error) {
if !fileIDRegex.MatchString(id) {
return 0, errInvalidFileID
}
log.Tag(tagStore).Field("message_id", id).Debug("Writing attachment")
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
cr := util.NewCountingReader(in)
lr := util.NewLimitReader(cr, limiters...)
if err := c.backend.Put(id, lr); err != nil {
countingReader := util.NewCountingReader(reader)
limitReader := util.NewLimitReader(countingReader, limiters...)
if err := c.backend.Put(id, limitReader, untrustedLength); err != nil {
c.backend.Delete(id) //nolint:errcheck
return 0, err
}
size := cr.Total()
size := countingReader.Total()
c.mu.Lock()
c.size += size
c.sizes[id] = size

View file

@ -19,7 +19,7 @@ var (
func TestFileStore_Write_Success(t *testing.T) {
dir, c := newTestFileStore(t)
size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), 0, util.NewFixedLimiter(999))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
@ -29,7 +29,7 @@ func TestFileStore_Write_Success(t *testing.T) {
func TestFileStore_Write_Read_Success(t *testing.T) {
_, c := newTestFileStore(t)
size, err := c.Write("abcdefghijkl", strings.NewReader("hello world"))
size, err := c.Write("abcdefghijkl", strings.NewReader("hello world"), 0)
require.Nil(t, err)
require.Equal(t, int64(11), size)
@ -45,7 +45,7 @@ func TestFileStore_Write_Read_Success(t *testing.T) {
func TestFileStore_Write_Remove_Success(t *testing.T) {
dir, c := newTestFileStore(t) // max = 10k (10240), each = 1k (1024)
for i := 0; i < 10; i++ { // 10x999 = 9990
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)), 0)
require.Nil(t, err)
require.Equal(t, int64(999), size)
}
@ -64,22 +64,48 @@ func TestFileStore_Write_Remove_Success(t *testing.T) {
func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) {
dir, c := newTestFileStore(t)
for i := 0; i < 10; i++ {
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray), 0)
require.Nil(t, err)
require.Equal(t, int64(1024), size)
}
_, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
_, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray), 0)
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkX")
}
func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) {
dir, c := newTestFileStore(t)
_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), 0, util.NewFixedLimiter(1000))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkl")
}
// When the claimed content length matches the body exactly, the write succeeds
// and the full content is stored on disk.
func TestFileStore_Write_UntrustedContentLengthExact(t *testing.T) {
	const body = "hello world"
	dir, store := newTestFileStore(t)
	written, err := store.Write("abcdefghijkl", strings.NewReader(body), int64(len(body)))
	require.Nil(t, err)
	require.Equal(t, int64(len(body)), written)
	require.Equal(t, body, readFile(t, dir+"/abcdefghijkl"))
}
// A body longer than the claimed length is truncated: only the first
// claimed-length bytes are stored.
func TestFileStore_Write_UntrustedContentLengthBodyLonger(t *testing.T) {
	dir, store := newTestFileStore(t)
	// 11-byte body with a claimed length of 5 — the store must keep only "hello"
	written, err := store.Write("abcdefghijkl", strings.NewReader("hello world"), 5)
	require.Nil(t, err)
	require.Equal(t, int64(5), written)
	require.Equal(t, "hello", readFile(t, dir+"/abcdefghijkl"))
}
// A body shorter than the claimed length must be rejected with a
// "content length mismatch" error, and no partial file may be left behind.
func TestFileStore_Write_UntrustedContentLengthBodyShorter(t *testing.T) {
	dir, store := newTestFileStore(t)
	// 5-byte body with a claimed length of 100 — the write must fail
	_, err := store.Write("abcdefghijkl", strings.NewReader("hello"), 100)
	require.Error(t, err)
	require.Contains(t, err.Error(), "content length mismatch")
	require.NoFileExists(t, dir+"/abcdefghijkl")
}
func TestFileStore_Read_NotFound(t *testing.T) {
_, c := newTestFileStore(t)
_, _, err := c.Read("abcdefghijkl")
@ -90,11 +116,11 @@ func TestFileStore_Sync(t *testing.T) {
dir, c := newTestFileStore(t)
// Write some files
_, err := c.Write("abcdefghijk0", strings.NewReader("file0"))
_, err := c.Write("abcdefghijk0", strings.NewReader("file0"), 0)
require.Nil(t, err)
_, err = c.Write("abcdefghijk1", strings.NewReader("file1"))
_, err = c.Write("abcdefghijk1", strings.NewReader("file1"), 0)
require.Nil(t, err)
_, err = c.Write("abcdefghijk2", strings.NewReader("file2"))
_, err = c.Write("abcdefghijk2", strings.NewReader("file2"), 0)
require.Nil(t, err)
require.Equal(t, int64(15), c.Size())
@ -124,7 +150,7 @@ func TestFileStore_Sync_SkipsRecentFiles(t *testing.T) {
dir, c := newTestFileStore(t)
// Write a file
_, err := c.Write("abcdefghijk0", strings.NewReader("file0"))
_, err := c.Write("abcdefghijk0", strings.NewReader("file0"), 0)
require.Nil(t, err)
// Set the ID provider to return empty (no valid IDs)

View file

@ -26,7 +26,7 @@ func TestS3Store_WriteReadRemove(t *testing.T) {
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
// Write
size, err := cache.Write("abcdefghijkl", strings.NewReader("hello world"))
size, err := cache.Write("abcdefghijkl", strings.NewReader("hello world"), 0)
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, int64(11), cache.Size())
@ -55,7 +55,7 @@ func TestS3Store_WriteNoPrefix(t *testing.T) {
cache := newTestS3Store(t, server, "my-bucket", "", 10*1024)
size, err := cache.Write("abcdefghijkl", strings.NewReader("test"))
size, err := cache.Write("abcdefghijkl", strings.NewReader("test"), 0)
require.Nil(t, err)
require.Equal(t, int64(4), size)
@ -74,13 +74,13 @@ func TestS3Store_WriteTotalSizeLimit(t *testing.T) {
cache := newTestS3Store(t, server, "my-bucket", "pfx", 100)
// First write fits
_, err := cache.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)))
_, err := cache.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)), 0)
require.Nil(t, err)
require.Equal(t, int64(80), cache.Size())
require.Equal(t, int64(20), cache.Remaining())
// Second write exceeds total limit
_, err = cache.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)))
_, err = cache.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)), 0)
require.ErrorIs(t, err, util.ErrLimitReached)
}
@ -90,7 +90,7 @@ func TestS3Store_WriteFileSizeLimit(t *testing.T) {
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, err := cache.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100))
_, err := cache.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), 0, util.NewFixedLimiter(100))
require.ErrorIs(t, err, util.ErrLimitReached)
}
@ -101,7 +101,7 @@ func TestS3Store_WriteRemoveMultiple(t *testing.T) {
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
for i := 0; i < 5; i++ {
_, err := cache.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)))
_, err := cache.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)), 0)
require.Nil(t, err)
}
require.Equal(t, int64(500), cache.Size())
@ -126,7 +126,7 @@ func TestS3Store_InvalidID(t *testing.T) {
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, err := cache.Write("bad", strings.NewReader("x"))
_, err := cache.Write("bad", strings.NewReader("x"), 0)
require.Equal(t, errInvalidFileID, err)
_, _, err = cache.Read("bad")
@ -143,11 +143,11 @@ func TestS3Store_Sync(t *testing.T) {
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
// Write some files
_, err := cache.Write("abcdefghijk0", strings.NewReader("file0"))
_, err := cache.Write("abcdefghijk0", strings.NewReader("file0"), 0)
require.Nil(t, err)
_, err = cache.Write("abcdefghijk1", strings.NewReader("file1"))
_, err = cache.Write("abcdefghijk1", strings.NewReader("file1"), 0)
require.Nil(t, err)
_, err = cache.Write("abcdefghijk2", strings.NewReader("file2"))
_, err = cache.Write("abcdefghijk2", strings.NewReader("file2"), 0)
require.Nil(t, err)
require.Equal(t, int64(15), cache.Size())
@ -175,7 +175,7 @@ func TestS3Store_Sync_SkipsRecentFiles(t *testing.T) {
cache := newTestS3Store(t, mockServer, "my-bucket", "pfx", 10*1024)
_, err := cache.Write("abcdefghijk0", strings.NewReader("file0"))
_, err := cache.Write("abcdefghijk0", strings.NewReader("file0"), 0)
require.Nil(t, err)
// Set the ID provider to return empty (no valid IDs)

View file

@ -45,25 +45,36 @@ func New(config *Config) *Client {
}
// PutObject uploads body to the given key. The key is automatically prefixed with the client's
// configured prefix. The body size does not need to be known in advance.
// configured prefix.
//
// If the entire body fits in a single part (5 MB), it is uploaded with a simple PUT request
// (https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html). Otherwise, the body
// is uploaded using S3 multipart upload, reading one part at a time into memory
// (https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateMultipartUpload.html).
func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error {
// If untrustedLength is between 1 and 5 GB, the body is streamed directly to S3 via a
// single PUT request without buffering. The read is limited to untrustedLength bytes;
// any extra data in the body is ignored. If the body is shorter than claimed, the upload fails.
//
// Otherwise (untrustedLength <= 0 or > 5 GB), the first 5 MB are buffered to decide
// between a simple PUT and multipart upload.
//
// See https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html
// and https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateMultipartUpload.html
func (c *Client) PutObject(ctx context.Context, key string, body io.Reader, untrustedLength int64) error {
	// Fast path: a plausible client-claimed length lets us stream without buffering.
	if untrustedLength > 0 && untrustedLength <= maxSinglePutSize {
		// Stream directly: Content-Length is known (but untrusted). LimitReader ensures we send at most
		// untrustedLength bytes, and any extra data in body is ignored.
		return c.putObject(ctx, key, io.LimitReader(body, untrustedLength), untrustedLength)
	}
	// Buffered path: read first 5 MB to decide simple vs multipart
	first := make([]byte, partSize)
	n, err := io.ReadFull(body, first)
	if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF {
		// Body fit entirely in one part: a single PUT of the n bytes read suffices.
		// NOTE(review): the next two lines appear to be the old and new versions of the same
		// call from the rendered diff (putObjectSimple was renamed to putObject); only the
		// second belongs in the final file — confirm against the repository.
		return c.putObjectSimple(ctx, key, bytes.NewReader(first[:n]), int64(n))
		return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n))
	} else if err != nil {
		return fmt.Errorf("error reading object %s from client: %w", key, err)
	}
	// More than one part's worth of data: replay the buffered prefix, then stream the rest
	// through a multipart upload.
	return c.putObjectMultipart(ctx, key, io.MultiReader(bytes.NewReader(first), body))
}
// putObjectSimple uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD.
func (c *Client) putObjectSimple(ctx context.Context, key string, body io.Reader, size int64) error {
// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD.
func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error {
log.Tag(tagS3Client).Debug("Uploading object %s (%d bytes)", key, size)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.config.ObjectURL(key), body)
if err != nil {

View file

@ -419,7 +419,7 @@ func TestClient_PutGetObject(t *testing.T) {
ctx := context.Background()
// Put
err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"))
err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"), 0)
require.Nil(t, err)
// Get
@ -439,7 +439,7 @@ func TestClient_PutGetObject_WithPrefix(t *testing.T) {
ctx := context.Background()
err := client.PutObject(ctx, "test-key", strings.NewReader("hello"))
err := client.PutObject(ctx, "test-key", strings.NewReader("hello"), 0)
require.Nil(t, err)
reader, _, err := client.GetObject(ctx, "test-key")
@ -471,7 +471,7 @@ func TestClient_DeleteObjects(t *testing.T) {
// Put several objects
for i := 0; i < 5; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")))
err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")), 0)
require.Nil(t, err)
}
require.Equal(t, 5, mock.objectCount())
@ -502,13 +502,13 @@ func TestClient_ListObjects(t *testing.T) {
// Client with prefix "pfx": list should only return objects under pfx/
client := newTestClient(server, "my-bucket", "pfx")
for i := 0; i < 3; i++ {
err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")))
err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")), 0)
require.Nil(t, err)
}
// Also put an object outside the prefix using a no-prefix client
clientNoPrefix := newTestClient(server, "my-bucket", "")
err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")))
err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")), 0)
require.Nil(t, err)
// List with prefix client: should only see 3
@ -532,7 +532,7 @@ func TestClient_ListObjects_Pagination(t *testing.T) {
// Put 5 objects
for i := 0; i < 5; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 0)
require.Nil(t, err)
}
@ -564,7 +564,7 @@ func TestClient_ListAllObjects(t *testing.T) {
ctx := context.Background()
for i := 0; i < 10; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 0)
require.Nil(t, err)
}
@ -585,7 +585,7 @@ func TestClient_PutObject_LargeBody(t *testing.T) {
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, "large", bytes.NewReader(data))
err := client.PutObject(ctx, "large", bytes.NewReader(data), 0)
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "large")
@ -609,7 +609,7 @@ func TestClient_PutObject_ChunkedUpload(t *testing.T) {
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, "multipart", bytes.NewReader(data))
err := client.PutObject(ctx, "multipart", bytes.NewReader(data), 0)
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "multipart")
@ -633,7 +633,7 @@ func TestClient_PutObject_ExactPartSize(t *testing.T) {
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, "exact", bytes.NewReader(data))
err := client.PutObject(ctx, "exact", bytes.NewReader(data), 0)
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "exact")
@ -645,6 +645,63 @@ func TestClient_PutObject_ExactPartSize(t *testing.T) {
require.Equal(t, data, got)
}
// Streaming path: when the claimed length equals the actual body length, the
// object is uploaded in a single streamed PUT and reads back intact.
func TestClient_PutObject_StreamingExactLength(t *testing.T) {
	srv, _ := newMockS3Server()
	defer srv.Close()
	cl := newTestClient(srv, "my-bucket", "pfx")
	ctx := context.Background()
	const payload = "hello world"
	require.Nil(t, cl.PutObject(ctx, "stream-exact", strings.NewReader(payload), int64(len(payload))))
	reader, size, err := cl.GetObject(ctx, "stream-exact")
	require.Nil(t, err)
	require.Equal(t, int64(len(payload)), size)
	data, err := io.ReadAll(reader)
	reader.Close()
	require.Nil(t, err)
	require.Equal(t, payload, string(data))
}
// Streaming path: bytes beyond the claimed length are discarded — only the
// first claimed-length bytes end up in the bucket.
func TestClient_PutObject_StreamingBodyLongerThanClaimed(t *testing.T) {
	srv, _ := newMockS3Server()
	defer srv.Close()
	cl := newTestClient(srv, "my-bucket", "pfx")
	ctx := context.Background()
	// 11-byte body with a claimed length of 5 — only "hello" is stored
	require.Nil(t, cl.PutObject(ctx, "stream-long", strings.NewReader("hello world"), 5))
	reader, size, err := cl.GetObject(ctx, "stream-long")
	require.Nil(t, err)
	require.Equal(t, int64(5), size)
	data, err := io.ReadAll(reader)
	reader.Close()
	require.Nil(t, err)
	require.Equal(t, "hello", string(data))
}
// Streaming path: a body shorter than the claimed length must fail the upload,
// and no object may be created.
func TestClient_PutObject_StreamingBodyShorterThanClaimed(t *testing.T) {
	srv, _ := newMockS3Server()
	defer srv.Close()
	cl := newTestClient(srv, "my-bucket", "pfx")
	ctx := context.Background()
	// 5-byte body with a claimed length of 100 — the PUT must fail with a ContentLength error
	err := cl.PutObject(ctx, "stream-short", strings.NewReader("hello"), 100)
	require.Error(t, err)
	require.Contains(t, err.Error(), "ContentLength")
	// The failed upload must not have left an object behind
	_, _, err = cl.GetObject(ctx, "stream-short")
	require.Error(t, err)
}
func TestClient_PutObject_NestedKey(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
@ -652,7 +709,7 @@ func TestClient_PutObject_NestedKey(t *testing.T) {
ctx := context.Background()
err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"))
err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"), 0)
require.Nil(t, err)
reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt")
@ -682,7 +739,7 @@ func TestClient_ListAllObjects_20k(t *testing.T) {
for i := 0; i < batchSize; i++ {
idx := batch*batchSize + i
key := fmt.Sprintf("%08d", idx)
err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")))
err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")), 0)
require.Nil(t, err)
}
}
@ -780,7 +837,7 @@ func TestClient_RealBucket(t *testing.T) {
content := "hello from ntfy s3 test"
// Put
err := client.PutObject(ctx, key, strings.NewReader(content))
err := client.PutObject(ctx, key, strings.NewReader(content), 0)
require.Nil(t, err)
// Get
@ -818,7 +875,7 @@ func TestClient_RealBucket(t *testing.T) {
// Put 10 objects
for i := 0; i < 10; i++ {
err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"))
err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"), 0)
require.Nil(t, err)
}
@ -843,7 +900,7 @@ func TestClient_RealBucket(t *testing.T) {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, key, bytes.NewReader(data))
err := client.PutObject(ctx, key, bytes.NewReader(data), 0)
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, key)

View file

@ -28,6 +28,10 @@ const (
// part size of 5 MB for all parts except the last.
partSize = 5 * 1024 * 1024
// maxSinglePutSize is the maximum size for a single PUT upload (5 GB).
// Objects larger than this must use multipart upload.
maxSinglePutSize = 5 * 1024 * 1024 * 1024
// maxPages is the max number of pages to iterate through when listing objects
maxPages = 500
)

View file

@ -1432,17 +1432,14 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Me
if m.Time > attachmentExpiry {
return errHTTPBadRequestAttachmentsExpiryBeforeDelivery.With(m)
}
contentLengthStr := r.Header.Get("Content-Length")
if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
if err == nil && (contentLength > vinfo.Stats.AttachmentTotalSizeRemaining || contentLength > vinfo.Limits.AttachmentFileSizeLimit) {
// Early "do-not-trust" check, hard limit see below
if r.ContentLength > 0 && (r.ContentLength > vinfo.Stats.AttachmentTotalSizeRemaining || r.ContentLength > vinfo.Limits.AttachmentFileSizeLimit) {
return errHTTPEntityTooLargeAttachment.With(m).Fields(log.Context{
"message_content_length": contentLength,
"message_content_length": r.ContentLength,
"attachment_total_size_remaining": vinfo.Stats.AttachmentTotalSizeRemaining,
"attachment_file_size_limit": vinfo.Limits.AttachmentFileSizeLimit,
})
}
}
if m.Attachment == nil {
m.Attachment = &model.Attachment{}
}
@ -1461,7 +1458,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Me
util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit),
util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining),
}
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...)
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, r.ContentLength, limiters...)
if errors.Is(err, util.ErrLimitReached) {
return errHTTPEntityTooLargeAttachment.With(m)
} else if err != nil {

View file

@ -2218,8 +2218,8 @@ func TestServer_PublishAttachmentTooLargeContentLength(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
content := util.RandomString(5000) // > 4096
s := newTestServer(t, newTestConfig(t, databaseURL))
response := request(t, s, "PUT", "/mytopic", content, map[string]string{
"Content-Length": "20000000",
response := request(t, s, "PUT", "/mytopic", content, nil, func(r *http.Request) {
r.ContentLength = 20000000
})
err := toHTTPError(t, response.Body.String())
require.Equal(t, 413, response.Code)

View file

@ -58,6 +58,7 @@ func cmdPut(ctx context.Context, client *s3.Client) {
path := os.Args[3]
var r io.Reader
var size int64
if path == "-" {
r = os.Stdin
} else {
@ -66,10 +67,15 @@ func cmdPut(ctx context.Context, client *s3.Client) {
fail("open %s: %s", path, err)
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
fail("stat %s: %s", path, err)
}
r = f
size = stat.Size()
}
if err := client.PutObject(ctx, key, r); err != nil {
if err := client.PutObject(ctx, key, r, size); err != nil {
fail("put: %s", err)
}
fmt.Fprintf(os.Stderr, "uploaded %s\n", key)