Fix postgres primary/replica races

This commit is contained in:
binwiederhier
2026-03-16 11:21:21 -04:00
parent df82fdf44c
commit 59ce581ba2
2 changed files with 37 additions and 15 deletions

View File

@@ -11,6 +11,12 @@ type Beginner interface {
Begin() (*sql.Tx, error) Begin() (*sql.Tx, error)
} }
// Querier is an interface for types that can execute SQL queries.
// *sql.DB, *sql.Tx, and *DB all implement this.
type Querier interface {
Query(query string, args ...any) (*sql.Rows, error)
}
// Host pairs a *sql.DB with the host:port it was opened against. // Host pairs a *sql.DB with the host:port it was opened against.
type Host struct { type Host struct {
Addr string // "host:port" Addr string // "host:port"

View File

@@ -288,33 +288,41 @@ func (a *Manager) ChangeTier(username, tier string) error {
t, err := a.Tier(tier) t, err := a.Tier(tier)
if err != nil { if err != nil {
return err return err
} else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil { }
return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.checkReservationsLimitTx(tx, username, t.ReservationLimit); err != nil {
return err return err
} }
if _, err := a.db.Exec(a.queries.updateUserTier, tier, username); err != nil { if _, err := tx.Exec(a.queries.updateUserTier, tier, username); err != nil {
return err return err
} }
return nil return nil
})
} }
// ResetTier removes the tier from the given user // ResetTier removes the tier from the given user
func (a *Manager) ResetTier(username string) error { func (a *Manager) ResetTier(username string) error {
if !AllowedUsername(username) && username != Everyone && username != "" { if !AllowedUsername(username) && username != Everyone && username != "" {
return ErrInvalidArgument return ErrInvalidArgument
} else if err := a.checkReservationsLimit(username, 0); err != nil { }
return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.checkReservationsLimitTx(tx, username, 0); err != nil {
return err return err
} }
_, err := a.db.Exec(a.queries.deleteUserTier, username) if _, err := tx.Exec(a.queries.deleteUserTier, username); err != nil {
return err return err
} }
return nil
})
}
func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error { func (a *Manager) checkReservationsLimitTx(tx *sql.Tx, username string, reservationsLimit int64) error {
u, err := a.User(username) u, err := a.userTx(tx, username)
if err != nil { if err != nil {
return err return err
} }
if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit { if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
reservations, err := a.Reservations(username) reservations, err := a.reservationsTx(tx, username)
if err != nil { if err != nil {
return err return err
} else if int64(len(reservations)) > reservationsLimit { } else if int64(len(reservations)) > reservationsLimit {
@@ -388,7 +396,11 @@ func (a *Manager) writeUserStatsQueue() error {
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise // User returns the user with the given username if it exists, or ErrUserNotFound otherwise
func (a *Manager) User(username string) (*User, error) { func (a *Manager) User(username string) (*User, error) {
rows, err := a.db.Query(a.queries.selectUserByName, username) return a.userTx(a.db, username)
}
func (a *Manager) userTx(tx db.Querier, username string) (*User, error) {
rows, err := tx.Query(a.queries.selectUserByName, username)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -753,7 +765,11 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
// Reservations returns all user-owned topics, and the associated everyone-access // Reservations returns all user-owned topics, and the associated everyone-access
func (a *Manager) Reservations(username string) ([]Reservation, error) { func (a *Manager) Reservations(username string) ([]Reservation, error) {
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservations, Everyone, username) return a.reservationsTx(a.db.ReadOnly(), username)
}
func (a *Manager) reservationsTx(tx db.Querier, username string) ([]Reservation, error) {
rows, err := tx.Query(a.queries.selectUserReservations, Everyone, username)
if err != nil { if err != nil {
return nil, err return nil, err
} }