From f2d4575831a3f48014c56576e81686ce19312d79 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 08:15:23 -0400 Subject: [PATCH] Use real S3 for tests --- attachment/store_s3_test.go | 299 ++++++++---------------------------- attachment/store_test.go | 11 +- s3/client.go | 2 +- 3 files changed, 71 insertions(+), 241 deletions(-) diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go index 2d4635ff..a41c6f8b 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -1,11 +1,9 @@ package attachment import ( - "encoding/xml" - "fmt" + "context" "io" - "net/http" - "net/http/httptest" + "os" "strings" "sync" "testing" @@ -15,13 +13,23 @@ import ( "heckel.io/ntfy/v2/s3" ) -// --- S3-specific tests --- - -func TestS3Store_WriteNoPrefix(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - - cache := newTestS3Store(t, server, "my-bucket", "", 10*1024) +func TestS3Store_WriteWithPrefix(t *testing.T) { + s3URL := os.Getenv("NTFY_TEST_ATTACHMENT_S3_URL") + if s3URL == "" { + t.Skip("NTFY_TEST_ATTACHMENT_S3_URL not set") + } + cfg, err := s3.ParseURL(s3URL) + require.Nil(t, err) + cfg.Prefix = "test-prefix" + client := s3.New(cfg) + deleteAllObjects(client) + backend := newS3Backend(client) + cache, err := newStore(backend, 10*1024, nil) + require.Nil(t, err) + t.Cleanup(func() { + deleteAllObjects(client) + cache.Close() + }) size, err := cache.Write("abcdefghijkl", strings.NewReader("test"), 0) require.Nil(t, err) @@ -37,241 +45,64 @@ func TestS3Store_WriteNoPrefix(t *testing.T) { // --- Helpers --- -func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) *Store { +func newTestRealS3Store(t *testing.T, totalSizeLimit int64) (*Store, *modTimeOverrideBackend) { t.Helper() - host := strings.TrimPrefix(server.URL, "https://") - backend := newS3Backend(s3.New(&s3.Config{ - AccessKey: "AKID", - SecretKey: "SECRET", - Region: "us-east-1", - Endpoint: host, - Bucket: bucket, - Prefix: prefix, - PathStyle: true, - HTTPClient: server.Client(), - })) - cache, err := newStore(backend, totalSizeLimit, nil) + s3URL := os.Getenv("NTFY_TEST_ATTACHMENT_S3_URL") + if s3URL == "" { + t.Skip("NTFY_TEST_ATTACHMENT_S3_URL not set") + } + cfg, err := s3.ParseURL(s3URL) require.Nil(t, err) - t.Cleanup(func() { cache.Close() }) - return cache + client := s3.New(cfg) + inner := newS3Backend(client) + wrapper := &modTimeOverrideBackend{backend: inner, modTimes: make(map[string]time.Time)} + deleteAllObjects(client) + store, err := newStore(wrapper, totalSizeLimit, nil) + require.Nil(t, err) + t.Cleanup(func() { + deleteAllObjects(client) + store.Close() + }) + return store, wrapper } -// --- Mock S3 server --- -// -// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and -// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. - -type mockS3Server struct { - objects map[string][]byte // full key (bucket/key) -> body - modTimes map[string]time.Time // full key (bucket/key) -> last modified time - uploads map[string]map[int][]byte // uploadID -> partNumber -> data - nextID int // counter for generating upload IDs - mu sync.RWMutex -} - -func newMockS3Server() (*httptest.Server, *mockS3Server) { - m := &mockS3Server{ - objects: make(map[string][]byte), - modTimes: make(map[string]time.Time), - uploads: make(map[string]map[int][]byte), +func deleteAllObjects(client *s3.Client) { + objects, _ := client.ListObjectsV2(context.Background()) + keys := make([]string, 0, len(objects)) + for _, obj := range objects { + keys = append(keys, obj.Key) } - return httptest.NewTLSServer(m), m -} - -func (m *mockS3Server) setModTime(path string, t time.Time) { - m.mu.Lock() - m.modTimes[path] = t - m.mu.Unlock() -} - -func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Path is /{bucket}[/{key...}] - path := strings.TrimPrefix(r.URL.Path, "/") - q := r.URL.Query() - - switch { - case r.Method == http.MethodPut && q.Has("partNumber"): - m.handleUploadPart(w, r, path) - case r.Method == http.MethodPut: - m.handlePut(w, r, path) - case r.Method == http.MethodPost && q.Has("uploads"): - m.handleInitiateMultipart(w, r, path) - case r.Method == http.MethodPost && q.Has("uploadId"): - m.handleCompleteMultipart(w, r, path) - case r.Method == http.MethodDelete && q.Has("uploadId"): - m.handleAbortMultipart(w, r, path) - case r.Method == http.MethodGet && q.Get("list-type") == "2": - m.handleList(w, r, path) - case r.Method == http.MethodGet: - m.handleGet(w, r, path) - case r.Method == http.MethodPost && q.Has("delete"): - m.handleDelete(w, r, path) - default: - http.Error(w, "not implemented", http.StatusNotImplemented) + if len(keys) > 0 { + client.DeleteObjects(context.Background(), keys) //nolint:errcheck } } -func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) { - body, err := io.ReadAll(r.Body) +// modTimeOverrideBackend wraps a backend and allows overriding LastModified times returned by List(). +// This is used in tests to simulate old objects on backends (like real S3) where +// LastModified cannot be set directly. +type modTimeOverrideBackend struct { + backend + mu sync.Mutex + modTimes map[string]time.Time // object ID -> override time +} + +func (b *modTimeOverrideBackend) List() ([]object, error) { + objects, err := b.backend.List() if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + return nil, err } - m.mu.Lock() - m.objects[path] = body - m.modTimes[path] = time.Now() - m.mu.Unlock() - w.WriteHeader(http.StatusOK) -} - -func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) { - m.mu.Lock() - m.nextID++ - uploadID := fmt.Sprintf("upload-%d", m.nextID) - m.uploads[uploadID] = make(map[int][]byte) - m.mu.Unlock() - - w.Header().Set("Content-Type", "application/xml") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `%s`, uploadID) -} - -func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) { - uploadID := r.URL.Query().Get("uploadId") - var partNumber int - fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber) - - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - m.mu.Lock() - parts, ok := m.uploads[uploadID] - if !ok { - m.mu.Unlock() - http.Error(w, "NoSuchUpload", http.StatusNotFound) - return - } - parts[partNumber] = body - m.mu.Unlock() - - etag := fmt.Sprintf(`"etag-part-%d"`, partNumber) - w.Header().Set("ETag", etag) - w.WriteHeader(http.StatusOK) -} - -func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) { - uploadID := r.URL.Query().Get("uploadId") - - m.mu.Lock() - parts, ok := m.uploads[uploadID] - if !ok { - m.mu.Unlock() - http.Error(w, "NoSuchUpload", http.StatusNotFound) - return - } - - // Assemble parts in order - var assembled []byte - for i := 1; i <= len(parts); i++ { - assembled = append(assembled, parts[i]...) - } - m.objects[path] = assembled - m.modTimes[path] = time.Now() - delete(m.uploads, uploadID) - m.mu.Unlock() - - w.Header().Set("Content-Type", "application/xml") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `%s`, path) -} - -func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) { - uploadID := r.URL.Query().Get("uploadId") - m.mu.Lock() - delete(m.uploads, uploadID) - m.mu.Unlock() - w.WriteHeader(http.StatusNoContent) -} - -func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) { - m.mu.RLock() - body, ok := m.objects[path] - m.mu.RUnlock() - if !ok { - w.WriteHeader(http.StatusNotFound) - w.Write([]byte(`NoSuchKeyThe specified key does not exist.`)) - return - } - w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body))) - w.WriteHeader(http.StatusOK) - w.Write(body) -} - -func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) { - // bucketPath is just the bucket name - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - var req struct { - Objects []struct { - Key string `xml:"Key"` - } `xml:"Object"` - } - if err := xml.Unmarshal(body, &req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - m.mu.Lock() - for _, obj := range req.Objects { - delete(m.objects, bucketPath+"/"+obj.Key) - } - m.mu.Unlock() - w.WriteHeader(http.StatusOK) - w.Write([]byte(``)) -} - -func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) { - prefix := r.URL.Query().Get("prefix") - m.mu.RLock() - var contents []s3ListObject - for key, body := range m.objects { - // key is "bucket/objectkey", strip bucket prefix - objKey := strings.TrimPrefix(key, bucketPath+"/") - if objKey == key { - continue // different bucket - } - if prefix == "" || strings.HasPrefix(objKey, prefix) { - contents = append(contents, s3ListObject{ - Key: objKey, - Size: int64(len(body)), - LastModified: m.modTimes[key].Format(time.RFC3339), - }) + b.mu.Lock() + defer b.mu.Unlock() + for i, obj := range objects { + if t, ok := b.modTimes[obj.ID]; ok { + objects[i].LastModified = t } } - m.mu.RUnlock() - - resp := s3ListResponse{ - Contents: contents, - IsTruncated: false, - } - w.Header().Set("Content-Type", "application/xml") - w.WriteHeader(http.StatusOK) - xml.NewEncoder(w).Encode(resp) + return objects, nil } -type s3ListResponse struct { - XMLName xml.Name `xml:"ListBucketResult"` - Contents []s3ListObject `xml:"Contents"` - IsTruncated bool `xml:"IsTruncated"` -} - -type s3ListObject struct { - Key string `xml:"Key"` - Size int64 `xml:"Size"` - LastModified string `xml:"LastModified"` +func (b *modTimeOverrideBackend) setModTime(id string, t time.Time) { + b.mu.Lock() + b.modTimes[id] = t + b.mu.Unlock() } diff --git a/attachment/store_test.go b/attachment/store_test.go index 7b5a6013..645a2159 100644 --- a/attachment/store_test.go +++ b/attachment/store_test.go @@ -229,8 +229,9 @@ func TestStore_Sync_SkipsRecentFiles(t *testing.T) { // forEachBackend runs f against both the file and S3 backends. It also provides a makeOld // callback that makes a specific object's timestamp old enough for orphan cleanup (> 1 hour). -// For the file backend, this uses os.Chtimes; for the S3 backend, it sets the object's -// LastModified time in the mock server. Objects start with recent timestamps by default. +// For the file backend, this uses os.Chtimes; for the S3 backend, it overrides the object's +// LastModified time via a modTimeOverrideBackend wrapper. Objects start with recent timestamps +// by default. The S3 subtest is skipped if NTFY_TEST_ATTACHMENT_S3_URL is not set. func forEachBackend(t *testing.T, totalSizeLimit int64, f func(t *testing.T, s *Store, makeOld func(string))) { t.Run("file", func(t *testing.T) { dir, s := newTestFileStore(t, totalSizeLimit) @@ -241,11 +242,9 @@ func forEachBackend(t *testing.T, totalSizeLimit int64, f func(t *testing.T, s * f(t, s, makeOld) }) t.Run("s3", func(t *testing.T) { - server, mock := newMockS3Server() - defer server.Close() - s := newTestS3Store(t, server, "my-bucket", "pfx", totalSizeLimit) + s, wrapper := newTestRealS3Store(t, totalSizeLimit) makeOld := func(id string) { - mock.setModTime("my-bucket/pfx/"+id, time.Unix(1, 0)) + wrapper.setModTime(id, time.Unix(1, 0)) } f(t, s, makeOld) }) diff --git a/s3/client.go b/s3/client.go index 29cad3a4..d9ec1ab8 100644 --- a/s3/client.go +++ b/s3/client.go @@ -86,7 +86,7 @@ func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size if err != nil { return fmt.Errorf("uploading object %s failed: %w", key, err) } - resp.Body.Close() + defer resp.Body.Close() if !isHTTPSuccess(resp) { return parseError(resp) }