diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go
index 2d4635ff..a41c6f8b 100644
--- a/attachment/store_s3_test.go
+++ b/attachment/store_s3_test.go
@@ -1,11 +1,9 @@
package attachment
import (
- "encoding/xml"
- "fmt"
+ "context"
"io"
- "net/http"
- "net/http/httptest"
+ "os"
"strings"
"sync"
"testing"
@@ -15,13 +13,23 @@ import (
"heckel.io/ntfy/v2/s3"
)
-// --- S3-specific tests ---
-
-func TestS3Store_WriteNoPrefix(t *testing.T) {
- server, _ := newMockS3Server()
- defer server.Close()
-
- cache := newTestS3Store(t, server, "my-bucket", "", 10*1024)
+func TestS3Store_WriteWithPrefix(t *testing.T) {
+ s3URL := os.Getenv("NTFY_TEST_ATTACHMENT_S3_URL")
+ if s3URL == "" {
+ t.Skip("NTFY_TEST_ATTACHMENT_S3_URL not set")
+ }
+ cfg, err := s3.ParseURL(s3URL)
+ require.Nil(t, err)
+ cfg.Prefix = "test-prefix"
+ client := s3.New(cfg)
+ deleteAllObjects(client)
+ backend := newS3Backend(client)
+ cache, err := newStore(backend, 10*1024, nil)
+ require.Nil(t, err)
+ t.Cleanup(func() {
+ deleteAllObjects(client)
+ cache.Close()
+ })
size, err := cache.Write("abcdefghijkl", strings.NewReader("test"), 0)
require.Nil(t, err)
@@ -37,241 +45,64 @@ func TestS3Store_WriteNoPrefix(t *testing.T) {
// --- Helpers ---
-func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) *Store {
+func newTestRealS3Store(t *testing.T, totalSizeLimit int64) (*Store, *modTimeOverrideBackend) {
t.Helper()
- host := strings.TrimPrefix(server.URL, "https://")
- backend := newS3Backend(s3.New(&s3.Config{
- AccessKey: "AKID",
- SecretKey: "SECRET",
- Region: "us-east-1",
- Endpoint: host,
- Bucket: bucket,
- Prefix: prefix,
- PathStyle: true,
- HTTPClient: server.Client(),
- }))
- cache, err := newStore(backend, totalSizeLimit, nil)
+ s3URL := os.Getenv("NTFY_TEST_ATTACHMENT_S3_URL")
+ if s3URL == "" {
+ t.Skip("NTFY_TEST_ATTACHMENT_S3_URL not set")
+ }
+ cfg, err := s3.ParseURL(s3URL)
require.Nil(t, err)
- t.Cleanup(func() { cache.Close() })
- return cache
+ client := s3.New(cfg)
+ inner := newS3Backend(client)
+ wrapper := &modTimeOverrideBackend{backend: inner, modTimes: make(map[string]time.Time)}
+ deleteAllObjects(client)
+ store, err := newStore(wrapper, totalSizeLimit, nil)
+ require.Nil(t, err)
+ t.Cleanup(func() {
+ deleteAllObjects(client)
+ store.Close()
+ })
+ return store, wrapper
}
-// --- Mock S3 server ---
-//
-// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and
-// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
-
-type mockS3Server struct {
- objects map[string][]byte // full key (bucket/key) -> body
- modTimes map[string]time.Time // full key (bucket/key) -> last modified time
- uploads map[string]map[int][]byte // uploadID -> partNumber -> data
- nextID int // counter for generating upload IDs
- mu sync.RWMutex
-}
-
-func newMockS3Server() (*httptest.Server, *mockS3Server) {
- m := &mockS3Server{
- objects: make(map[string][]byte),
- modTimes: make(map[string]time.Time),
- uploads: make(map[string]map[int][]byte),
+func deleteAllObjects(client *s3.Client) {
+ objects, _ := client.ListObjectsV2(context.Background())
+ keys := make([]string, 0, len(objects))
+ for _, obj := range objects {
+ keys = append(keys, obj.Key)
}
- return httptest.NewTLSServer(m), m
-}
-
-func (m *mockS3Server) setModTime(path string, t time.Time) {
- m.mu.Lock()
- m.modTimes[path] = t
- m.mu.Unlock()
-}
-
-func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- // Path is /{bucket}[/{key...}]
- path := strings.TrimPrefix(r.URL.Path, "/")
- q := r.URL.Query()
-
- switch {
- case r.Method == http.MethodPut && q.Has("partNumber"):
- m.handleUploadPart(w, r, path)
- case r.Method == http.MethodPut:
- m.handlePut(w, r, path)
- case r.Method == http.MethodPost && q.Has("uploads"):
- m.handleInitiateMultipart(w, r, path)
- case r.Method == http.MethodPost && q.Has("uploadId"):
- m.handleCompleteMultipart(w, r, path)
- case r.Method == http.MethodDelete && q.Has("uploadId"):
- m.handleAbortMultipart(w, r, path)
- case r.Method == http.MethodGet && q.Get("list-type") == "2":
- m.handleList(w, r, path)
- case r.Method == http.MethodGet:
- m.handleGet(w, r, path)
- case r.Method == http.MethodPost && q.Has("delete"):
- m.handleDelete(w, r, path)
- default:
- http.Error(w, "not implemented", http.StatusNotImplemented)
+ if len(keys) > 0 {
+ client.DeleteObjects(context.Background(), keys) //nolint:errcheck
}
}
-func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) {
- body, err := io.ReadAll(r.Body)
+// modTimeOverrideBackend wraps a backend and allows overriding LastModified times returned by List().
+// This is used in tests to simulate old objects on backends (like real S3) where
+// LastModified cannot be set directly.
+type modTimeOverrideBackend struct {
+ backend
+ mu sync.Mutex
+ modTimes map[string]time.Time // object ID -> override time
+}
+
+func (b *modTimeOverrideBackend) List() ([]object, error) {
+ objects, err := b.backend.List()
if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
+ return nil, err
}
- m.mu.Lock()
- m.objects[path] = body
- m.modTimes[path] = time.Now()
- m.mu.Unlock()
- w.WriteHeader(http.StatusOK)
-}
-
-func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
- m.mu.Lock()
- m.nextID++
- uploadID := fmt.Sprintf("upload-%d", m.nextID)
- m.uploads[uploadID] = make(map[int][]byte)
- m.mu.Unlock()
-
- w.Header().Set("Content-Type", "application/xml")
- w.WriteHeader(http.StatusOK)
- fmt.Fprintf(w, `%s`, uploadID)
-}
-
-func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
- uploadID := r.URL.Query().Get("uploadId")
- var partNumber int
- fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
-
- body, err := io.ReadAll(r.Body)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- m.mu.Lock()
- parts, ok := m.uploads[uploadID]
- if !ok {
- m.mu.Unlock()
- http.Error(w, "NoSuchUpload", http.StatusNotFound)
- return
- }
- parts[partNumber] = body
- m.mu.Unlock()
-
- etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
- w.Header().Set("ETag", etag)
- w.WriteHeader(http.StatusOK)
-}
-
-func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
- uploadID := r.URL.Query().Get("uploadId")
-
- m.mu.Lock()
- parts, ok := m.uploads[uploadID]
- if !ok {
- m.mu.Unlock()
- http.Error(w, "NoSuchUpload", http.StatusNotFound)
- return
- }
-
- // Assemble parts in order
- var assembled []byte
- for i := 1; i <= len(parts); i++ {
- assembled = append(assembled, parts[i]...)
- }
- m.objects[path] = assembled
- m.modTimes[path] = time.Now()
- delete(m.uploads, uploadID)
- m.mu.Unlock()
-
- w.Header().Set("Content-Type", "application/xml")
- w.WriteHeader(http.StatusOK)
- fmt.Fprintf(w, `%s`, path)
-}
-
-func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
- uploadID := r.URL.Query().Get("uploadId")
- m.mu.Lock()
- delete(m.uploads, uploadID)
- m.mu.Unlock()
- w.WriteHeader(http.StatusNoContent)
-}
-
-func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
- m.mu.RLock()
- body, ok := m.objects[path]
- m.mu.RUnlock()
- if !ok {
- w.WriteHeader(http.StatusNotFound)
- w.Write([]byte(`NoSuchKeyThe specified key does not exist.`))
- return
- }
- w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
- w.WriteHeader(http.StatusOK)
- w.Write(body)
-}
-
-func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) {
- // bucketPath is just the bucket name
- body, err := io.ReadAll(r.Body)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
- var req struct {
- Objects []struct {
- Key string `xml:"Key"`
- } `xml:"Object"`
- }
- if err := xml.Unmarshal(body, &req); err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
- m.mu.Lock()
- for _, obj := range req.Objects {
- delete(m.objects, bucketPath+"/"+obj.Key)
- }
- m.mu.Unlock()
- w.WriteHeader(http.StatusOK)
- w.Write([]byte(``))
-}
-
-func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) {
- prefix := r.URL.Query().Get("prefix")
- m.mu.RLock()
- var contents []s3ListObject
- for key, body := range m.objects {
- // key is "bucket/objectkey", strip bucket prefix
- objKey := strings.TrimPrefix(key, bucketPath+"/")
- if objKey == key {
- continue // different bucket
- }
- if prefix == "" || strings.HasPrefix(objKey, prefix) {
- contents = append(contents, s3ListObject{
- Key: objKey,
- Size: int64(len(body)),
- LastModified: m.modTimes[key].Format(time.RFC3339),
- })
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ for i, obj := range objects {
+ if t, ok := b.modTimes[obj.ID]; ok {
+ objects[i].LastModified = t
}
}
- m.mu.RUnlock()
-
- resp := s3ListResponse{
- Contents: contents,
- IsTruncated: false,
- }
- w.Header().Set("Content-Type", "application/xml")
- w.WriteHeader(http.StatusOK)
- xml.NewEncoder(w).Encode(resp)
+ return objects, nil
}
-type s3ListResponse struct {
- XMLName xml.Name `xml:"ListBucketResult"`
- Contents []s3ListObject `xml:"Contents"`
- IsTruncated bool `xml:"IsTruncated"`
-}
-
-type s3ListObject struct {
- Key string `xml:"Key"`
- Size int64 `xml:"Size"`
- LastModified string `xml:"LastModified"`
+func (b *modTimeOverrideBackend) setModTime(id string, t time.Time) {
+ b.mu.Lock()
+ b.modTimes[id] = t
+ b.mu.Unlock()
}
diff --git a/attachment/store_test.go b/attachment/store_test.go
index 7b5a6013..645a2159 100644
--- a/attachment/store_test.go
+++ b/attachment/store_test.go
@@ -229,8 +229,9 @@ func TestStore_Sync_SkipsRecentFiles(t *testing.T) {
// forEachBackend runs f against both the file and S3 backends. It also provides a makeOld
// callback that makes a specific object's timestamp old enough for orphan cleanup (> 1 hour).
-// For the file backend, this uses os.Chtimes; for the S3 backend, it sets the object's
-// LastModified time in the mock server. Objects start with recent timestamps by default.
+// For the file backend, this uses os.Chtimes; for the S3 backend, it overrides the object's
+// LastModified time via a modTimeOverrideBackend wrapper. Objects start with recent timestamps
+// by default. The S3 subtest is skipped if NTFY_TEST_ATTACHMENT_S3_URL is not set.
func forEachBackend(t *testing.T, totalSizeLimit int64, f func(t *testing.T, s *Store, makeOld func(string))) {
t.Run("file", func(t *testing.T) {
dir, s := newTestFileStore(t, totalSizeLimit)
@@ -241,11 +242,9 @@ func forEachBackend(t *testing.T, totalSizeLimit int64, f func(t *testing.T, s *
f(t, s, makeOld)
})
t.Run("s3", func(t *testing.T) {
- server, mock := newMockS3Server()
- defer server.Close()
- s := newTestS3Store(t, server, "my-bucket", "pfx", totalSizeLimit)
+ s, wrapper := newTestRealS3Store(t, totalSizeLimit)
makeOld := func(id string) {
- mock.setModTime("my-bucket/pfx/"+id, time.Unix(1, 0))
+ wrapper.setModTime(id, time.Unix(1, 0))
}
f(t, s, makeOld)
})
diff --git a/s3/client.go b/s3/client.go
index 29cad3a4..d9ec1ab8 100644
--- a/s3/client.go
+++ b/s3/client.go
@@ -86,7 +86,7 @@ func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size
if err != nil {
return fmt.Errorf("uploading object %s failed: %w", key, err)
}
- resp.Body.Close()
+ defer resp.Body.Close()
if !isHTTPSuccess(resp) {
return parseError(resp)
}