Tighten message IDs in attachments

This commit is contained in:
binwiederhier 2026-03-23 12:54:13 -04:00
parent 075f2ffa15
commit b95efe8dd3
2 changed files with 18 additions and 17 deletions

View file

@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"io"
"regexp"
"sync"
"time"
@ -20,10 +19,7 @@ const (
orphanGracePeriod = time.Hour // Don't delete orphaned objects younger than this to avoid races with in-flight uploads
)
var (
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
errInvalidFileID = errors.New("invalid file ID")
)
var errInvalidFileID = errors.New("invalid file ID")
// Store manages attachment storage with shared logic for size tracking, limiting,
// ID validation, and background sync to reconcile storage with the database.
@ -86,7 +82,7 @@ func newStore(backend backend, totalSizeLimit int64, attachmentsWithSizes func()
// from the client's Content-Length header; backends may use it to optimize uploads (e.g.
// streaming directly to S3 without buffering).
func (c *Store) Write(id string, reader io.Reader, untrustedLength int64, limiters ...util.Limiter) (int64, error) {
if !fileIDRegex.MatchString(id) {
if !model.ValidMessageID(id) {
return 0, errInvalidFileID
}
log.Tag(tagStore).Field("message_id", id).Debug("Writing attachment")
@ -107,7 +103,7 @@ func (c *Store) Write(id string, reader io.Reader, untrustedLength int64, limite
// Read retrieves an attachment file by ID
func (c *Store) Read(id string) (io.ReadCloser, int64, error) {
if !fileIDRegex.MatchString(id) {
if !model.ValidMessageID(id) {
return nil, 0, errInvalidFileID
}
return c.backend.Get(id)
@ -118,7 +114,7 @@ func (c *Store) Read(id string) (io.ReadCloser, int64, error) {
// started and before the first sync) are corrected by the next sync() call.
func (c *Store) Remove(ids ...string) error {
for _, id := range ids {
if !fileIDRegex.MatchString(id) {
if !model.ValidMessageID(id) {
return errInvalidFileID
}
}
@ -166,7 +162,7 @@ func (c *Store) sync() error {
var count, totalSize int64
sizes := make(map[string]int64, len(remoteObjects))
for _, obj := range remoteObjects {
if !fileIDRegex.MatchString(obj.ID) {
if !model.ValidMessageID(obj.ID) {
continue
}
if _, ok := attachmentsWithSizes[obj.ID]; !ok && obj.LastModified.Before(cutoff) {

View file

@ -19,8 +19,8 @@ const (
PollRequestEvent = "poll_request"
)
// MessageIDLength is the length of a randomly generated message ID
const MessageIDLength = 12
// messageIDLength is the length of a randomly generated message ID
const messageIDLength = 12
// Errors for message operations
var (
@ -133,10 +133,20 @@ func NewAction() *Action {
}
}
// GenerateMessageID creates a new random message ID
func GenerateMessageID() string {
return util.RandomString(messageIDLength)
}
// ValidMessageID returns true if the given string is a valid message ID
func ValidMessageID(s string) bool {
return util.ValidRandomString(s, messageIDLength)
}
// NewMessage creates a new message with the current timestamp
func NewMessage(event, topic, msg string) *Message {
return &Message{
ID: util.RandomString(MessageIDLength),
ID: GenerateMessageID(),
Time: time.Now().Unix(),
Event: event,
Topic: topic,
@ -173,11 +183,6 @@ func NewPollRequestMessage(topic, pollID string) *Message {
return m
}
// ValidMessageID returns true if the given string is a valid message ID
func ValidMessageID(s string) bool {
return util.ValidRandomString(s, MessageIDLength)
}
// SinceMarker represents a point in time or message ID from which to retrieve messages
type SinceMarker struct {
time time.Time