Tighten message IDs in attachments

This commit is contained in:
binwiederhier 2026-03-23 12:54:13 -04:00
parent 075f2ffa15
commit b95efe8dd3
2 changed files with 18 additions and 17 deletions

View file

@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"regexp"
"sync" "sync"
"time" "time"
@ -20,10 +19,7 @@ const (
orphanGracePeriod = time.Hour // Don't delete orphaned objects younger than this to avoid races with in-flight uploads orphanGracePeriod = time.Hour // Don't delete orphaned objects younger than this to avoid races with in-flight uploads
) )
var ( var errInvalidFileID = errors.New("invalid file ID")
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
errInvalidFileID = errors.New("invalid file ID")
)
// Store manages attachment storage with shared logic for size tracking, limiting, // Store manages attachment storage with shared logic for size tracking, limiting,
// ID validation, and background sync to reconcile storage with the database. // ID validation, and background sync to reconcile storage with the database.
@ -86,7 +82,7 @@ func newStore(backend backend, totalSizeLimit int64, attachmentsWithSizes func()
// from the client's Content-Length header; backends may use it to optimize uploads (e.g. // from the client's Content-Length header; backends may use it to optimize uploads (e.g.
// streaming directly to S3 without buffering). // streaming directly to S3 without buffering).
func (c *Store) Write(id string, reader io.Reader, untrustedLength int64, limiters ...util.Limiter) (int64, error) { func (c *Store) Write(id string, reader io.Reader, untrustedLength int64, limiters ...util.Limiter) (int64, error) {
if !fileIDRegex.MatchString(id) { if !model.ValidMessageID(id) {
return 0, errInvalidFileID return 0, errInvalidFileID
} }
log.Tag(tagStore).Field("message_id", id).Debug("Writing attachment") log.Tag(tagStore).Field("message_id", id).Debug("Writing attachment")
@ -107,7 +103,7 @@ func (c *Store) Write(id string, reader io.Reader, untrustedLength int64, limite
// Read retrieves an attachment file by ID // Read retrieves an attachment file by ID
func (c *Store) Read(id string) (io.ReadCloser, int64, error) { func (c *Store) Read(id string) (io.ReadCloser, int64, error) {
if !fileIDRegex.MatchString(id) { if !model.ValidMessageID(id) {
return nil, 0, errInvalidFileID return nil, 0, errInvalidFileID
} }
return c.backend.Get(id) return c.backend.Get(id)
@ -118,7 +114,7 @@ func (c *Store) Read(id string) (io.ReadCloser, int64, error) {
// started and before the first sync) are corrected by the next sync() call. // started and before the first sync) are corrected by the next sync() call.
func (c *Store) Remove(ids ...string) error { func (c *Store) Remove(ids ...string) error {
for _, id := range ids { for _, id := range ids {
if !fileIDRegex.MatchString(id) { if !model.ValidMessageID(id) {
return errInvalidFileID return errInvalidFileID
} }
} }
@ -166,7 +162,7 @@ func (c *Store) sync() error {
var count, totalSize int64 var count, totalSize int64
sizes := make(map[string]int64, len(remoteObjects)) sizes := make(map[string]int64, len(remoteObjects))
for _, obj := range remoteObjects { for _, obj := range remoteObjects {
if !fileIDRegex.MatchString(obj.ID) { if !model.ValidMessageID(obj.ID) {
continue continue
} }
if _, ok := attachmentsWithSizes[obj.ID]; !ok && obj.LastModified.Before(cutoff) { if _, ok := attachmentsWithSizes[obj.ID]; !ok && obj.LastModified.Before(cutoff) {

View file

@ -19,8 +19,8 @@ const (
PollRequestEvent = "poll_request" PollRequestEvent = "poll_request"
) )
// MessageIDLength is the length of a randomly generated message ID // messageIDLength is the length of a randomly generated message ID
const MessageIDLength = 12 const messageIDLength = 12
// Errors for message operations // Errors for message operations
var ( var (
@ -133,10 +133,20 @@ func NewAction() *Action {
} }
} }
// GenerateMessageID creates a new random message ID
func GenerateMessageID() string {
return util.RandomString(messageIDLength)
}
// ValidMessageID returns true if the given string is a valid message ID
func ValidMessageID(s string) bool {
return util.ValidRandomString(s, messageIDLength)
}
// NewMessage creates a new message with the current timestamp // NewMessage creates a new message with the current timestamp
func NewMessage(event, topic, msg string) *Message { func NewMessage(event, topic, msg string) *Message {
return &Message{ return &Message{
ID: util.RandomString(MessageIDLength), ID: GenerateMessageID(),
Time: time.Now().Unix(), Time: time.Now().Unix(),
Event: event, Event: event,
Topic: topic, Topic: topic,
@ -173,11 +183,6 @@ func NewPollRequestMessage(topic, pollID string) *Message {
return m return m
} }
// ValidMessageID returns true if the given string is a valid message ID
func ValidMessageID(s string) bool {
return util.ValidRandomString(s, MessageIDLength)
}
// SinceMarker represents a point in time or message ID from which to retrieve messages // SinceMarker represents a point in time or message ID from which to retrieve messages
type SinceMarker struct { type SinceMarker struct {
time time.Time time time.Time