diff --git a/attachment/store.go b/attachment/store.go
index d70ea2ab..a9eaaeae 100644
--- a/attachment/store.go
+++ b/attachment/store.go
@@ -28,44 +28,54 @@ var (
 // Store manages attachment storage with shared logic for size tracking, limiting,
 // ID validation, and background sync to reconcile storage with the database.
 type Store struct {
-	backend   backend
-	limit     int64                    // Defined limit of the store in bytes
-	size      int64                    // Current size of the store in bytes
-	sizes     map[string]int64         // File ID -> size, for subtracting on Remove
-	localIDs  func() ([]string, error) // Returns file IDs that should exist locally, used for sync()
-	closeChan chan struct{}
-	mu        sync.RWMutex // Protects size and sizes
+	backend              backend
+	limit                int64                            // Defined limit of the store in bytes
+	size                 int64                            // Current size of the store in bytes
+	sizes                map[string]int64                 // File ID -> size, for subtracting on Remove
+	attachmentsWithSizes func() (map[string]int64, error) // Returns file ID -> size for active attachments
+	closeChan            chan struct{}
+	mu                   sync.RWMutex // Protects size and sizes
 }
 
 // NewFileStore creates a new file-system backed attachment cache
-func NewFileStore(dir string, totalSizeLimit int64, localIDsFn func() ([]string, error)) (*Store, error) {
+func NewFileStore(dir string, totalSizeLimit int64, attachmentsWithSizes func() (map[string]int64, error)) (*Store, error) {
 	b, err := newFileBackend(dir)
 	if err != nil {
 		return nil, err
 	}
-	return newStore(b, totalSizeLimit, localIDsFn)
+	return newStore(b, totalSizeLimit, attachmentsWithSizes)
 }
 
 // NewS3Store creates a new S3-backed attachment cache. The s3URL must be in the format:
 //
 //	s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
-func NewS3Store(s3URL string, totalSizeLimit int64, localIDs func() ([]string, error)) (*Store, error) {
+func NewS3Store(s3URL string, totalSizeLimit int64, attachmentsWithSizes func() (map[string]int64, error)) (*Store, error) {
 	config, err := s3.ParseURL(s3URL)
 	if err != nil {
 		return nil, err
 	}
-	return newStore(newS3Backend(s3.New(config)), totalSizeLimit, localIDs)
+	return newStore(newS3Backend(s3.New(config)), totalSizeLimit, attachmentsWithSizes)
 }
 
-func newStore(backend backend, totalSizeLimit int64, localIDs func() ([]string, error)) (*Store, error) {
+func newStore(backend backend, totalSizeLimit int64, attachmentsWithSizes func() (map[string]int64, error)) (*Store, error) {
 	c := &Store{
-		backend:   backend,
-		limit:     totalSizeLimit,
-		sizes:     make(map[string]int64),
-		localIDs:  localIDs,
-		closeChan: make(chan struct{}),
+		backend:              backend,
+		limit:                totalSizeLimit,
+		sizes:                make(map[string]int64),
+		attachmentsWithSizes: attachmentsWithSizes,
+		closeChan:            make(chan struct{}),
 	}
-	if localIDs != nil {
+	// Hydrate sizes from the database immediately so that Size()/Remaining()/Remove()
+	// are accurate from the start, without waiting for the first sync() call.
+	if attachmentsWithSizes != nil {
+		attachments, err := attachmentsWithSizes()
+		if err != nil {
+			return nil, fmt.Errorf("attachment store: failed to load existing attachments: %w", err)
+		}
+		for id, size := range attachments {
+			c.sizes[id] = size
+			c.size += size
+		}
 		go c.syncLoop()
 	}
 	return c, nil
@@ -136,18 +146,14 @@ func (c *Store) Remove(ids ...string) error {
 
 // sync reconciles the backend storage with the database. It lists all objects,
 // deletes orphans (not in the valid ID set and older than 1 hour), and recomputes
-// the total size from the remaining objects.
+// the total size from the existing attachments in the database.
 func (c *Store) sync() error {
-	if c.localIDs == nil {
+	if c.attachmentsWithSizes == nil {
 		return nil
 	}
-	localIDs, err := c.localIDs()
+	attachmentsWithSizes, err := c.attachmentsWithSizes()
 	if err != nil {
-		return fmt.Errorf("attachment sync: failed to get valid IDs: %w", err)
-	}
-	localIDMap := make(map[string]struct{}, len(localIDs))
-	for _, id := range localIDs {
-		localIDMap[id] = struct{}{}
+		return fmt.Errorf("attachment sync: failed to get existing attachments: %w", err)
 	}
 	remoteObjects, err := c.backend.List()
 	if err != nil {
@@ -157,23 +163,23 @@
 	// than the grace period to account for races, and skipping objects with invalid IDs.
 	cutoff := time.Now().Add(-orphanGracePeriod)
 	var orphanIDs []string
-	var count, size int64
+	var count, totalSize int64
 	sizes := make(map[string]int64, len(remoteObjects))
 	for _, obj := range remoteObjects {
 		if !fileIDRegex.MatchString(obj.ID) {
 			continue
 		}
-		if _, ok := localIDMap[obj.ID]; !ok && obj.LastModified.Before(cutoff) {
+		if _, ok := attachmentsWithSizes[obj.ID]; !ok && obj.LastModified.Before(cutoff) {
 			orphanIDs = append(orphanIDs, obj.ID)
 		} else {
 			count++
-			size += obj.Size
-			sizes[obj.ID] = obj.Size
+			totalSize += attachmentsWithSizes[obj.ID]
+			sizes[obj.ID] = attachmentsWithSizes[obj.ID]
 		}
 	}
-	log.Tag(tagStore).Debug("Attachment store updated: %d attachment(s), %s", count, util.FormatSizeHuman(size))
+	log.Tag(tagStore).Debug("Attachment store updated: %d attachment(s), %s", count, util.FormatSizeHuman(totalSize))
 	c.mu.Lock()
-	c.size = size
+	c.size = totalSize
 	c.sizes = sizes
 	c.mu.Unlock()
 	// Delete orphaned attachments
diff --git a/attachment/store_test.go b/attachment/store_test.go
index 11d0b244..0cb32a3c 100644
--- a/attachment/store_test.go
+++ b/attachment/store_test.go
@@ -162,9 +162,9 @@ func TestStore_SyncRecomputesSize(t *testing.T) {
 	s.mu.Unlock()
 	require.Equal(t, int64(999), s.Size())
 
-	// Set localIDs to include both files so nothing gets deleted
-	s.localIDs = func() ([]string, error) {
-		return []string{"abcdefghijk0", "abcdefghijk1"}, nil
+	// Set attachmentsWithSizes to include both files so nothing gets deleted
+	s.attachmentsWithSizes = func() (map[string]int64, error) {
+		return map[string]int64{"abcdefghijk0": 100, "abcdefghijk1": 200}, nil
 	}
 
 	// Sync should recompute size from the backend
@@ -280,8 +280,8 @@ func TestStore_Sync(t *testing.T) {
 	require.Equal(t, int64(15), s.Size())
 
 	// Set the ID provider to only know about file 0 and 2
-	s.localIDs = func() ([]string, error) {
-		return []string{"abcdefghijk0", "abcdefghijk2"}, nil
+	s.attachmentsWithSizes = func() (map[string]int64, error) {
+		return map[string]int64{"abcdefghijk0": 5, "abcdefghijk2": 5}, nil
 	}
 
 	// Make file 1 old enough to be cleaned up
@@ -314,8 +314,8 @@ func TestStore_Sync_SkipsRecentFiles(t *testing.T) {
 	require.Nil(t, err)
 
 	// Set the ID provider to return empty (no valid IDs)
-	s.localIDs = func() ([]string, error) {
-		return []string{}, nil
+	s.attachmentsWithSizes = func() (map[string]int64, error) {
+		return map[string]int64{}, nil
 	}
 
 	// File was just created, so it should NOT be deleted (< 1 hour old)
diff --git a/message/cache.go b/message/cache.go
index dd4ef0a4..76ba7926 100644
--- a/message/cache.go
+++ b/message/cache.go
@@ -43,10 +43,10 @@ type queries struct {
 	selectAttachmentsExpired      string
 	selectAttachmentsSizeBySender string
 	selectAttachmentsSizeByUserID string
+	selectAttachmentsWithSizes    string
 	selectStats                   string
 	updateStats                   string
 	updateMessageTime             string
-	selectAttachmentIDs           string
 }
 
 // Cache stores published messages
@@ -363,16 +363,6 @@ func (c *Cache) ExpireMessages(topics ...string) error {
 	})
 }
 
-// AttachmentIDs returns message IDs with active (non-expired, non-deleted) attachments
-func (c *Cache) AttachmentIDs() ([]string, error) {
-	rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentIDs, time.Now().Unix())
-	if err != nil {
-		return nil, err
-	}
-	defer rows.Close()
-	return readStrings(rows)
-}
-
 // AttachmentsExpired returns message IDs with expired attachments that have not been deleted
 func (c *Cache) AttachmentsExpired() ([]string, error) {
 	rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
@@ -415,6 +405,30 @@ func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) {
 	return c.readAttachmentBytesUsed(rows)
 }
 
+// AttachmentsWithSizes returns a map of message ID to attachment size for all active
+// (non-expired, non-deleted) attachments. This is used to hydrate the attachment store's
+// size tracking on startup and during periodic sync.
+func (c *Cache) AttachmentsWithSizes() (map[string]int64, error) {
+	rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsWithSizes, time.Now().Unix())
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	attachments := make(map[string]int64)
+	for rows.Next() {
+		var id string
+		var size int64
+		if err := rows.Scan(&id, &size); err != nil {
+			return nil, err
+		}
+		attachments[id] = size
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return attachments, nil
+}
+
 func (c *Cache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
 	defer rows.Close()
 	var size int64
diff --git a/message/cache_postgres.go b/message/cache_postgres.go
index d59b2590..f0a32036 100644
--- a/message/cache_postgres.go
+++ b/message/cache_postgres.go
@@ -70,12 +70,11 @@
 
 	postgresSelectAttachmentsExpiredQuery      = `SELECT mid FROM message WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`
 	postgresSelectAttachmentsSizeBySenderQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = '' AND sender = $1 AND attachment_expires >= $2`
 	postgresSelectAttachmentsSizeByUserIDQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = $1 AND attachment_expires >= $2`
+	postgresSelectAttachmentsWithSizesQuery    = `SELECT mid, attachment_size FROM message WHERE attachment_expires > $1 AND attachment_deleted = FALSE`
 	postgresSelectStatsQuery                   = `SELECT value FROM message_stats WHERE key = 'messages'`
 	postgresUpdateStatsQuery                   = `UPDATE message_stats SET value = $1 WHERE key = 'messages'`
 	postgresUpdateMessageTimeQuery             = `UPDATE message SET time = $1 WHERE mid = $2`
-
-	postgresSelectAttachmentIDsQuery = `SELECT mid FROM message WHERE attachment_expires > $1 AND attachment_deleted = FALSE`
 )
 
 var postgresQueries = queries{
@@ -99,10 +98,10 @@
 	selectAttachmentsExpired:      postgresSelectAttachmentsExpiredQuery,
 	selectAttachmentsSizeBySender: postgresSelectAttachmentsSizeBySenderQuery,
 	selectAttachmentsSizeByUserID: postgresSelectAttachmentsSizeByUserIDQuery,
+	selectAttachmentsWithSizes:    postgresSelectAttachmentsWithSizesQuery,
 	selectStats:                   postgresSelectStatsQuery,
 	updateStats:                   postgresUpdateStatsQuery,
 	updateMessageTime:             postgresUpdateMessageTimeQuery,
-	selectAttachmentIDs:           postgresSelectAttachmentIDsQuery,
 }
 
 // NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.
diff --git a/message/cache_sqlite.go b/message/cache_sqlite.go
index 6126f1e1..b39095e0 100644
--- a/message/cache_sqlite.go
+++ b/message/cache_sqlite.go
@@ -73,12 +73,11 @@
 
 	sqliteSelectAttachmentsExpiredQuery      = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
 	sqliteSelectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
 	sqliteSelectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
+	sqliteSelectAttachmentsWithSizesQuery    = `SELECT mid, attachment_size FROM messages WHERE attachment_expires > ? AND attachment_deleted = 0`
 	sqliteSelectStatsQuery                   = `SELECT value FROM stats WHERE key = 'messages'`
 	sqliteUpdateStatsQuery                   = `UPDATE stats SET value = ? WHERE key = 'messages'`
 	sqliteUpdateMessageTimeQuery             = `UPDATE messages SET time = ? WHERE mid = ?`
-
-	sqliteSelectAttachmentIDsQuery = `SELECT mid FROM messages WHERE attachment_expires > ? AND attachment_deleted = 0`
 )
 
 var sqliteQueries = queries{
@@ -102,10 +101,10 @@
 	selectAttachmentsExpired:      sqliteSelectAttachmentsExpiredQuery,
 	selectAttachmentsSizeBySender: sqliteSelectAttachmentsSizeBySenderQuery,
 	selectAttachmentsSizeByUserID: sqliteSelectAttachmentsSizeByUserIDQuery,
+	selectAttachmentsWithSizes:    sqliteSelectAttachmentsWithSizesQuery,
 	selectStats:                   sqliteSelectStatsQuery,
 	updateStats:                   sqliteUpdateStatsQuery,
 	updateMessageTime:             sqliteUpdateMessageTimeQuery,
-	selectAttachmentIDs:           sqliteSelectAttachmentIDsQuery,
 }
 
 // NewSQLiteStore creates a SQLite file-backed cache
diff --git a/server/server.go b/server/server.go
index 77e7b0c0..dc56d57f 100644
--- a/server/server.go
+++ b/server/server.go
@@ -301,13 +301,10 @@ func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) {
 }
 
 func createAttachmentStore(conf *Config, messageCache *message.Cache) (*attachment.Store, error) {
-	attachmentIDs := func() ([]string, error) {
-		return messageCache.AttachmentIDs()
-	}
 	if strings.HasPrefix(conf.AttachmentCacheDir, "s3://") {
-		return attachment.NewS3Store(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit, attachmentIDs)
+		return attachment.NewS3Store(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit, messageCache.AttachmentsWithSizes)
 	} else if conf.AttachmentCacheDir != "" {
-		return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit, attachmentIDs)
+		return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit, messageCache.AttachmentsWithSizes)
 	}
 	return nil, nil
 }