Move auth queries to primary, redo health check loop

This commit is contained in:
binwiederhier 2026-03-11 20:26:29 -04:00
parent ab33ac7ae5
commit ac65df1e83
2 changed files with 49 additions and 35 deletions

View file

@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"sync/atomic"
"time"
@ -9,7 +10,8 @@ import (
)
const (
replicaHealthCheckInterval = 5 * time.Second
replicaHealthCheckInterval = 30 * time.Second
replicaHealthCheckTimeout = 2 * time.Second
)
// Beginner is an interface for types that can begin a database transaction.
@ -25,25 +27,29 @@ type DB struct {
primary *sql.DB
replicas []*replica
counter atomic.Uint64
cancel context.CancelFunc
}
type replica struct {
db *sql.DB
healthy atomic.Bool
lastChecked atomic.Int64
db *sql.DB
healthy atomic.Bool
}
// NewDB creates a new DB that wraps the given primary and optional replica connections.
// If replicas is nil or empty, ReadOnly() simply returns the primary.
// Replicas start unhealthy and are checked immediately by a background goroutine.
func NewDB(primary *sql.DB, replicas []*sql.DB) *DB {
ctx, cancel := context.WithCancel(context.Background())
d := &DB{
primary: primary,
replicas: make([]*replica, len(replicas)),
cancel: cancel,
}
for i, r := range replicas {
rep := &replica{db: r}
rep.healthy.Store(true)
d.replicas[i] = rep
d.replicas[i] = &replica{db: r} // healthy defaults to false
}
if len(d.replicas) > 0 {
go d.healthCheckLoop(ctx)
}
return d
}
@ -79,8 +85,9 @@ func (d *DB) Ping() error {
return d.primary.Ping()
}
// Close closes the primary database and all replicas.
// Close closes the primary database and all replicas, and stops the health-check goroutine.
func (d *DB) Close() error {
d.cancel()
for _, r := range d.replicas {
r.db.Close()
}
@ -88,9 +95,7 @@ func (d *DB) Close() error {
}
// ReadOnly returns a *sql.DB suitable for read-only queries. It round-robins across healthy
// replicas. If a replica's health status is stale (older than replicaHealthCheckInterval), it
// is re-checked with a ping. If all replicas are unhealthy or none are configured, the primary
// is returned.
// replicas. If all replicas are unhealthy or none are configured, the primary is returned.
func (d *DB) ReadOnly() *sql.DB {
if len(d.replicas) == 0 {
return d.primary
@ -99,34 +104,43 @@ func (d *DB) ReadOnly() *sql.DB {
start := int(d.counter.Add(1) - 1)
for i := 0; i < n; i++ {
r := d.replicas[(start+i)%n]
if d.isHealthy(r) {
if r.healthy.Load() {
return r.db
}
}
return d.primary
}
// isHealthy returns whether the replica is healthy. If the cached health status is stale,
// it pings the replica and updates the cache.
func (d *DB) isHealthy(r *replica) bool {
now := time.Now().Unix()
lastChecked := r.lastChecked.Load()
if now-lastChecked >= int64(replicaHealthCheckInterval.Seconds()) {
if r.lastChecked.CompareAndSwap(lastChecked, now) {
wasHealthy := r.healthy.Load()
if err := r.db.Ping(); err != nil {
r.healthy.Store(false)
if wasHealthy {
log.Error("Database replica is now unhealthy: %s", err)
}
return false
// healthCheckLoop checks replicas immediately, then periodically on a ticker.
func (d *DB) healthCheckLoop(ctx context.Context) {
d.checkReplicas(ctx)
for {
select {
case <-ctx.Done():
return
case <-time.After(replicaHealthCheckInterval):
d.checkReplicas(ctx)
}
}
}
// checkReplicas pings each replica with a timeout and updates its health status.
func (d *DB) checkReplicas(ctx context.Context) {
for _, r := range d.replicas {
wasHealthy := r.healthy.Load()
pingCtx, cancel := context.WithTimeout(ctx, replicaHealthCheckTimeout)
err := r.db.PingContext(pingCtx)
cancel()
if err != nil {
r.healthy.Store(false)
if wasHealthy {
log.Error("Database replica is now unhealthy: %s", err)
}
} else {
r.healthy.Store(true)
if !wasHealthy {
log.Info("Database replica is now healthy again")
}
return true
}
}
return r.healthy.Load()
}

View file

@ -388,7 +388,7 @@ func (a *Manager) writeUserStatsQueue() error {
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
func (a *Manager) User(username string) (*User, error) {
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByName, username)
rows, err := a.db.Query(a.queries.selectUserByName, username)
if err != nil {
return nil, err
}
@ -397,7 +397,7 @@ func (a *Manager) User(username string) (*User, error) {
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
func (a *Manager) UserByID(id string) (*User, error) {
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByID, id)
rows, err := a.db.Query(a.queries.selectUserByID, id)
if err != nil {
return nil, err
}
@ -406,7 +406,7 @@ func (a *Manager) UserByID(id string) (*User, error) {
// userByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise
func (a *Manager) userByToken(token string) (*User, error) {
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByToken, token, time.Now().Unix())
rows, err := a.db.Query(a.queries.selectUserByToken, token, time.Now().Unix())
if err != nil {
return nil, err
}
@ -642,7 +642,7 @@ func (a *Manager) AllowReservation(username string, topic string) error {
// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
// - It also prioritizes write permissions over read permissions
func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
rows, err := a.db.ReadOnly().Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
rows, err := a.db.Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
if err != nil {
return false, false, false, err
}
@ -779,7 +779,7 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) {
// HasReservation returns true if the given topic access is owned by the user
func (a *Manager) HasReservation(username, topic string) (bool, error) {
rows, err := a.db.ReadOnly().Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
rows, err := a.db.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
if err != nil {
return false, err
}
@ -813,7 +813,7 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone
func (a *Manager) ReservationOwner(topic string) (string, error) {
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic))
rows, err := a.db.Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic))
if err != nil {
return "", err
}
@ -830,7 +830,7 @@ func (a *Manager) ReservationOwner(topic string) (string, error) {
// otherAccessCount returns the number of access entries for the given topic that are not owned by the user
func (a *Manager) otherAccessCount(username, topic string) (int, error) {
rows, err := a.db.ReadOnly().Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
rows, err := a.db.Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
if err != nil {
return 0, err
}