Use real S3 for tests

This commit is contained in:
binwiederhier 2026-03-22 08:15:23 -04:00
parent ad501feab1
commit f2d4575831
3 changed files with 71 additions and 241 deletions

View file

@ -1,11 +1,9 @@
package attachment package attachment
import ( import (
"encoding/xml" "context"
"fmt"
"io" "io"
"net/http" "os"
"net/http/httptest"
"strings" "strings"
"sync" "sync"
"testing" "testing"
@ -15,13 +13,23 @@ import (
"heckel.io/ntfy/v2/s3" "heckel.io/ntfy/v2/s3"
) )
// --- S3-specific tests --- func TestS3Store_WriteWithPrefix(t *testing.T) {
s3URL := os.Getenv("NTFY_TEST_ATTACHMENT_S3_URL")
func TestS3Store_WriteNoPrefix(t *testing.T) { if s3URL == "" {
server, _ := newMockS3Server() t.Skip("NTFY_TEST_ATTACHMENT_S3_URL not set")
defer server.Close() }
cfg, err := s3.ParseURL(s3URL)
cache := newTestS3Store(t, server, "my-bucket", "", 10*1024) require.Nil(t, err)
cfg.Prefix = "test-prefix"
client := s3.New(cfg)
deleteAllObjects(client)
backend := newS3Backend(client)
cache, err := newStore(backend, 10*1024, nil)
require.Nil(t, err)
t.Cleanup(func() {
deleteAllObjects(client)
cache.Close()
})
size, err := cache.Write("abcdefghijkl", strings.NewReader("test"), 0) size, err := cache.Write("abcdefghijkl", strings.NewReader("test"), 0)
require.Nil(t, err) require.Nil(t, err)
@ -37,241 +45,64 @@ func TestS3Store_WriteNoPrefix(t *testing.T) {
// --- Helpers --- // --- Helpers ---
func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) *Store { func newTestRealS3Store(t *testing.T, totalSizeLimit int64) (*Store, *modTimeOverrideBackend) {
t.Helper() t.Helper()
host := strings.TrimPrefix(server.URL, "https://") s3URL := os.Getenv("NTFY_TEST_ATTACHMENT_S3_URL")
backend := newS3Backend(s3.New(&s3.Config{ if s3URL == "" {
AccessKey: "AKID", t.Skip("NTFY_TEST_ATTACHMENT_S3_URL not set")
SecretKey: "SECRET", }
Region: "us-east-1", cfg, err := s3.ParseURL(s3URL)
Endpoint: host,
Bucket: bucket,
Prefix: prefix,
PathStyle: true,
HTTPClient: server.Client(),
}))
cache, err := newStore(backend, totalSizeLimit, nil)
require.Nil(t, err) require.Nil(t, err)
t.Cleanup(func() { cache.Close() }) client := s3.New(cfg)
return cache inner := newS3Backend(client)
wrapper := &modTimeOverrideBackend{backend: inner, modTimes: make(map[string]time.Time)}
deleteAllObjects(client)
store, err := newStore(wrapper, totalSizeLimit, nil)
require.Nil(t, err)
t.Cleanup(func() {
deleteAllObjects(client)
store.Close()
})
return store, wrapper
} }
// --- Mock S3 server --- func deleteAllObjects(client *s3.Client) {
// objects, _ := client.ListObjectsV2(context.Background())
// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and keys := make([]string, 0, len(objects))
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. for _, obj := range objects {
keys = append(keys, obj.Key)
type mockS3Server struct {
objects map[string][]byte // full key (bucket/key) -> body
modTimes map[string]time.Time // full key (bucket/key) -> last modified time
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
nextID int // counter for generating upload IDs
mu sync.RWMutex
}
func newMockS3Server() (*httptest.Server, *mockS3Server) {
m := &mockS3Server{
objects: make(map[string][]byte),
modTimes: make(map[string]time.Time),
uploads: make(map[string]map[int][]byte),
} }
return httptest.NewTLSServer(m), m if len(keys) > 0 {
} client.DeleteObjects(context.Background(), keys) //nolint:errcheck
func (m *mockS3Server) setModTime(path string, t time.Time) {
m.mu.Lock()
m.modTimes[path] = t
m.mu.Unlock()
}
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Path is /{bucket}[/{key...}]
path := strings.TrimPrefix(r.URL.Path, "/")
q := r.URL.Query()
switch {
case r.Method == http.MethodPut && q.Has("partNumber"):
m.handleUploadPart(w, r, path)
case r.Method == http.MethodPut:
m.handlePut(w, r, path)
case r.Method == http.MethodPost && q.Has("uploads"):
m.handleInitiateMultipart(w, r, path)
case r.Method == http.MethodPost && q.Has("uploadId"):
m.handleCompleteMultipart(w, r, path)
case r.Method == http.MethodDelete && q.Has("uploadId"):
m.handleAbortMultipart(w, r, path)
case r.Method == http.MethodGet && q.Get("list-type") == "2":
m.handleList(w, r, path)
case r.Method == http.MethodGet:
m.handleGet(w, r, path)
case r.Method == http.MethodPost && q.Has("delete"):
m.handleDelete(w, r, path)
default:
http.Error(w, "not implemented", http.StatusNotImplemented)
} }
} }
func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) { // modTimeOverrideBackend wraps a backend and allows overriding LastModified times returned by List().
body, err := io.ReadAll(r.Body) // This is used in tests to simulate old objects on backends (like real S3) where
// LastModified cannot be set directly.
type modTimeOverrideBackend struct {
backend
mu sync.Mutex
modTimes map[string]time.Time // object ID -> override time
}
func (b *modTimeOverrideBackend) List() ([]object, error) {
objects, err := b.backend.List()
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) return nil, err
return
} }
m.mu.Lock() b.mu.Lock()
m.objects[path] = body defer b.mu.Unlock()
m.modTimes[path] = time.Now() for i, obj := range objects {
m.mu.Unlock() if t, ok := b.modTimes[obj.ID]; ok {
w.WriteHeader(http.StatusOK) objects[i].LastModified = t
}
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
m.mu.Lock()
m.nextID++
uploadID := fmt.Sprintf("upload-%d", m.nextID)
m.uploads[uploadID] = make(map[int][]byte)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
}
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
var partNumber int
fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
parts[partNumber] = body
m.mu.Unlock()
etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
w.Header().Set("ETag", etag)
w.WriteHeader(http.StatusOK)
}
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
// Assemble parts in order
var assembled []byte
for i := 1; i <= len(parts); i++ {
assembled = append(assembled, parts[i]...)
}
m.objects[path] = assembled
m.modTimes[path] = time.Now()
delete(m.uploads, uploadID)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
}
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
delete(m.uploads, uploadID)
m.mu.Unlock()
w.WriteHeader(http.StatusNoContent)
}
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
m.mu.RLock()
body, ok := m.objects[path]
m.mu.RUnlock()
if !ok {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`))
return
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
w.WriteHeader(http.StatusOK)
w.Write(body)
}
func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) {
// bucketPath is just the bucket name
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var req struct {
Objects []struct {
Key string `xml:"Key"`
} `xml:"Object"`
}
if err := xml.Unmarshal(body, &req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
m.mu.Lock()
for _, obj := range req.Objects {
delete(m.objects, bucketPath+"/"+obj.Key)
}
m.mu.Unlock()
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></DeleteResult>`))
}
func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) {
prefix := r.URL.Query().Get("prefix")
m.mu.RLock()
var contents []s3ListObject
for key, body := range m.objects {
// key is "bucket/objectkey", strip bucket prefix
objKey := strings.TrimPrefix(key, bucketPath+"/")
if objKey == key {
continue // different bucket
}
if prefix == "" || strings.HasPrefix(objKey, prefix) {
contents = append(contents, s3ListObject{
Key: objKey,
Size: int64(len(body)),
LastModified: m.modTimes[key].Format(time.RFC3339),
})
} }
} }
m.mu.RUnlock() return objects, nil
resp := s3ListResponse{
Contents: contents,
IsTruncated: false,
}
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
xml.NewEncoder(w).Encode(resp)
} }
type s3ListResponse struct { func (b *modTimeOverrideBackend) setModTime(id string, t time.Time) {
XMLName xml.Name `xml:"ListBucketResult"` b.mu.Lock()
Contents []s3ListObject `xml:"Contents"` b.modTimes[id] = t
IsTruncated bool `xml:"IsTruncated"` b.mu.Unlock()
}
type s3ListObject struct {
Key string `xml:"Key"`
Size int64 `xml:"Size"`
LastModified string `xml:"LastModified"`
} }

View file

@ -229,8 +229,9 @@ func TestStore_Sync_SkipsRecentFiles(t *testing.T) {
// forEachBackend runs f against both the file and S3 backends. It also provides a makeOld // forEachBackend runs f against both the file and S3 backends. It also provides a makeOld
// callback that makes a specific object's timestamp old enough for orphan cleanup (> 1 hour). // callback that makes a specific object's timestamp old enough for orphan cleanup (> 1 hour).
// For the file backend, this uses os.Chtimes; for the S3 backend, it sets the object's // For the file backend, this uses os.Chtimes; for the S3 backend, it overrides the object's
// LastModified time in the mock server. Objects start with recent timestamps by default. // LastModified time via a modTimeOverrideBackend wrapper. Objects start with recent timestamps
// by default. The S3 subtest is skipped if NTFY_TEST_ATTACHMENT_S3_URL is not set.
func forEachBackend(t *testing.T, totalSizeLimit int64, f func(t *testing.T, s *Store, makeOld func(string))) { func forEachBackend(t *testing.T, totalSizeLimit int64, f func(t *testing.T, s *Store, makeOld func(string))) {
t.Run("file", func(t *testing.T) { t.Run("file", func(t *testing.T) {
dir, s := newTestFileStore(t, totalSizeLimit) dir, s := newTestFileStore(t, totalSizeLimit)
@ -241,11 +242,9 @@ func forEachBackend(t *testing.T, totalSizeLimit int64, f func(t *testing.T, s *
f(t, s, makeOld) f(t, s, makeOld)
}) })
t.Run("s3", func(t *testing.T) { t.Run("s3", func(t *testing.T) {
server, mock := newMockS3Server() s, wrapper := newTestRealS3Store(t, totalSizeLimit)
defer server.Close()
s := newTestS3Store(t, server, "my-bucket", "pfx", totalSizeLimit)
makeOld := func(id string) { makeOld := func(id string) {
mock.setModTime("my-bucket/pfx/"+id, time.Unix(1, 0)) wrapper.setModTime(id, time.Unix(1, 0))
} }
f(t, s, makeOld) f(t, s, makeOld)
}) })

View file

@ -86,7 +86,7 @@ func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size
if err != nil { if err != nil {
return fmt.Errorf("uploading object %s failed: %w", key, err) return fmt.Errorf("uploading object %s failed: %w", key, err)
} }
resp.Body.Close() defer resp.Body.Close()
if !isHTTPSuccess(resp) { if !isHTTPSuccess(resp) {
return parseError(resp) return parseError(resp)
} }