mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-05-15 07:35:49 -06:00
Move auth queries to primary, redo health check loop
This commit is contained in:
parent
ab33ac7ae5
commit
ac65df1e83
2 changed files with 49 additions and 35 deletions
70
db/db.go
70
db/db.go
|
|
@ -1,6 +1,7 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
|
@ -9,7 +10,8 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
replicaHealthCheckInterval = 5 * time.Second
|
||||
replicaHealthCheckInterval = 30 * time.Second
|
||||
replicaHealthCheckTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
// Beginner is an interface for types that can begin a database transaction.
|
||||
|
|
@ -25,25 +27,29 @@ type DB struct {
|
|||
primary *sql.DB
|
||||
replicas []*replica
|
||||
counter atomic.Uint64
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type replica struct {
|
||||
db *sql.DB
|
||||
healthy atomic.Bool
|
||||
lastChecked atomic.Int64
|
||||
db *sql.DB
|
||||
healthy atomic.Bool
|
||||
}
|
||||
|
||||
// NewDB creates a new DB that wraps the given primary and optional replica connections.
|
||||
// If replicas is nil or empty, ReadOnly() simply returns the primary.
|
||||
// Replicas start unhealthy and are checked immediately by a background goroutine.
|
||||
func NewDB(primary *sql.DB, replicas []*sql.DB) *DB {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
d := &DB{
|
||||
primary: primary,
|
||||
replicas: make([]*replica, len(replicas)),
|
||||
cancel: cancel,
|
||||
}
|
||||
for i, r := range replicas {
|
||||
rep := &replica{db: r}
|
||||
rep.healthy.Store(true)
|
||||
d.replicas[i] = rep
|
||||
d.replicas[i] = &replica{db: r} // healthy defaults to false
|
||||
}
|
||||
if len(d.replicas) > 0 {
|
||||
go d.healthCheckLoop(ctx)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
|
@ -79,8 +85,9 @@ func (d *DB) Ping() error {
|
|||
return d.primary.Ping()
|
||||
}
|
||||
|
||||
// Close closes the primary database and all replicas.
|
||||
// Close closes the primary database and all replicas, and stops the health-check goroutine.
|
||||
func (d *DB) Close() error {
|
||||
d.cancel()
|
||||
for _, r := range d.replicas {
|
||||
r.db.Close()
|
||||
}
|
||||
|
|
@ -88,9 +95,7 @@ func (d *DB) Close() error {
|
|||
}
|
||||
|
||||
// ReadOnly returns a *sql.DB suitable for read-only queries. It round-robins across healthy
|
||||
// replicas. If a replica's health status is stale (older than replicaHealthCheckInterval), it
|
||||
// is re-checked with a ping. If all replicas are unhealthy or none are configured, the primary
|
||||
// is returned.
|
||||
// replicas. If all replicas are unhealthy or none are configured, the primary is returned.
|
||||
func (d *DB) ReadOnly() *sql.DB {
|
||||
if len(d.replicas) == 0 {
|
||||
return d.primary
|
||||
|
|
@ -99,34 +104,43 @@ func (d *DB) ReadOnly() *sql.DB {
|
|||
start := int(d.counter.Add(1) - 1)
|
||||
for i := 0; i < n; i++ {
|
||||
r := d.replicas[(start+i)%n]
|
||||
if d.isHealthy(r) {
|
||||
if r.healthy.Load() {
|
||||
return r.db
|
||||
}
|
||||
}
|
||||
return d.primary
|
||||
}
|
||||
|
||||
// isHealthy returns whether the replica is healthy. If the cached health status is stale,
|
||||
// it pings the replica and updates the cache.
|
||||
func (d *DB) isHealthy(r *replica) bool {
|
||||
now := time.Now().Unix()
|
||||
lastChecked := r.lastChecked.Load()
|
||||
if now-lastChecked >= int64(replicaHealthCheckInterval.Seconds()) {
|
||||
if r.lastChecked.CompareAndSwap(lastChecked, now) {
|
||||
wasHealthy := r.healthy.Load()
|
||||
if err := r.db.Ping(); err != nil {
|
||||
r.healthy.Store(false)
|
||||
if wasHealthy {
|
||||
log.Error("Database replica is now unhealthy: %s", err)
|
||||
}
|
||||
return false
|
||||
// healthCheckLoop checks replicas immediately, then periodically on a ticker.
|
||||
func (d *DB) healthCheckLoop(ctx context.Context) {
|
||||
d.checkReplicas(ctx)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(replicaHealthCheckInterval):
|
||||
d.checkReplicas(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkReplicas pings each replica with a timeout and updates its health status.
|
||||
func (d *DB) checkReplicas(ctx context.Context) {
|
||||
for _, r := range d.replicas {
|
||||
wasHealthy := r.healthy.Load()
|
||||
pingCtx, cancel := context.WithTimeout(ctx, replicaHealthCheckTimeout)
|
||||
err := r.db.PingContext(pingCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
r.healthy.Store(false)
|
||||
if wasHealthy {
|
||||
log.Error("Database replica is now unhealthy: %s", err)
|
||||
}
|
||||
} else {
|
||||
r.healthy.Store(true)
|
||||
if !wasHealthy {
|
||||
log.Info("Database replica is now healthy again")
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return r.healthy.Load()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -388,7 +388,7 @@ func (a *Manager) writeUserStatsQueue() error {
|
|||
|
||||
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
|
||||
func (a *Manager) User(username string) (*User, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByName, username)
|
||||
rows, err := a.db.Query(a.queries.selectUserByName, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -397,7 +397,7 @@ func (a *Manager) User(username string) (*User, error) {
|
|||
|
||||
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
|
||||
func (a *Manager) UserByID(id string) (*User, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByID, id)
|
||||
rows, err := a.db.Query(a.queries.selectUserByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -406,7 +406,7 @@ func (a *Manager) UserByID(id string) (*User, error) {
|
|||
|
||||
// userByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise
|
||||
func (a *Manager) userByToken(token string) (*User, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByToken, token, time.Now().Unix())
|
||||
rows, err := a.db.Query(a.queries.selectUserByToken, token, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -642,7 +642,7 @@ func (a *Manager) AllowReservation(username string, topic string) error {
|
|||
// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
|
||||
// - It also prioritizes write permissions over read permissions
|
||||
func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
|
||||
rows, err := a.db.Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
|
||||
if err != nil {
|
||||
return false, false, false, err
|
||||
}
|
||||
|
|
@ -779,7 +779,7 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) {
|
|||
|
||||
// HasReservation returns true if the given topic access is owned by the user
|
||||
func (a *Manager) HasReservation(username, topic string) (bool, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||
rows, err := a.db.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
@ -813,7 +813,7 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
|
|||
|
||||
// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone
|
||||
func (a *Manager) ReservationOwner(topic string) (string, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic))
|
||||
rows, err := a.db.Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
@ -830,7 +830,7 @@ func (a *Manager) ReservationOwner(topic string) (string, error) {
|
|||
|
||||
// otherAccessCount returns the number of access entries for the given topic that are not owned by the user
|
||||
func (a *Manager) otherAccessCount(username, topic string) (int, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
|
||||
rows, err := a.db.Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue