diff --git a/db/types.go b/db/types.go index 534d6168..137753a4 100644 --- a/db/types.go +++ b/db/types.go @@ -11,6 +11,12 @@ type Beginner interface { Begin() (*sql.Tx, error) } +// Querier is an interface for types that can execute SQL queries. +// *sql.DB, *sql.Tx, and *DB all implement this. +type Querier interface { + Query(query string, args ...any) (*sql.Rows, error) +} + // Host pairs a *sql.DB with the host:port it was opened against. type Host struct { Addr string // "host:port" diff --git a/user/manager.go b/user/manager.go index 28243a24..76352384 100644 --- a/user/manager.go +++ b/user/manager.go @@ -288,33 +288,41 @@ func (a *Manager) ChangeTier(username, tier string) error { t, err := a.Tier(tier) if err != nil { return err - } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil { - return err } - if _, err := a.db.Exec(a.queries.updateUserTier, tier, username); err != nil { - return err - } - return nil + return db.ExecTx(a.db, func(tx *sql.Tx) error { + if err := a.checkReservationsLimitTx(tx, username, t.ReservationLimit); err != nil { + return err + } + if _, err := tx.Exec(a.queries.updateUserTier, tier, username); err != nil { + return err + } + return nil + }) } // ResetTier removes the tier from the given user func (a *Manager) ResetTier(username string) error { if !AllowedUsername(username) && username != Everyone && username != "" { return ErrInvalidArgument - } else if err := a.checkReservationsLimit(username, 0); err != nil { - return err } - _, err := a.db.Exec(a.queries.deleteUserTier, username) - return err + return db.ExecTx(a.db, func(tx *sql.Tx) error { + if err := a.checkReservationsLimitTx(tx, username, 0); err != nil { + return err + } + if _, err := tx.Exec(a.queries.deleteUserTier, username); err != nil { + return err + } + return nil + }) } -func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error { - u, err := a.User(username) +func (a *Manager) checkReservationsLimitTx(tx *sql.Tx, username string, reservationsLimit int64) error { + u, err := a.userTx(tx, username) if err != nil { return err } if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit { - reservations, err := a.Reservations(username) + reservations, err := a.reservationsTx(tx, username) if err != nil { return err } else if int64(len(reservations)) > reservationsLimit { @@ -388,7 +396,11 @@ func (a *Manager) writeUserStatsQueue() error { // User returns the user with the given username if it exists, or ErrUserNotFound otherwise func (a *Manager) User(username string) (*User, error) { - rows, err := a.db.Query(a.queries.selectUserByName, username) + return a.userTx(a.db, username) +} + +func (a *Manager) userTx(tx db.Querier, username string) (*User, error) { + rows, err := tx.Query(a.queries.selectUserByName, username) if err != nil { return nil, err } @@ -753,7 +765,11 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error { // Reservations returns all user-owned topics, and the associated everyone-access func (a *Manager) Reservations(username string) ([]Reservation, error) { - rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservations, Everyone, username) + return a.reservationsTx(a.db.ReadOnly(), username) +} + +func (a *Manager) reservationsTx(tx db.Querier, username string) ([]Reservation, error) { + rows, err := tx.Query(a.queries.selectUserReservations, Everyone, username) if err != nil { return nil, err }