mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-18 21:30:44 +01:00
Fix postgres primary/replica races
This commit is contained in:
@@ -11,6 +11,12 @@ type Beginner interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
}
|
||||
|
||||
// Querier is an interface for types that can execute SQL queries.
|
||||
// *sql.DB, *sql.Tx, and *DB all implement this.
|
||||
type Querier interface {
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
// Host pairs a *sql.DB with the host:port it was opened against.
|
||||
type Host struct {
|
||||
Addr string // "host:port"
|
||||
|
||||
@@ -288,33 +288,41 @@ func (a *Manager) ChangeTier(username, tier string) error {
|
||||
t, err := a.Tier(tier)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
|
||||
}
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
if err := a.checkReservationsLimitTx(tx, username, t.ReservationLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := a.db.Exec(a.queries.updateUserTier, tier, username); err != nil {
|
||||
if _, err := tx.Exec(a.queries.updateUserTier, tier, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ResetTier removes the tier from the given user
|
||||
func (a *Manager) ResetTier(username string) error {
|
||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||
return ErrInvalidArgument
|
||||
} else if err := a.checkReservationsLimit(username, 0); err != nil {
|
||||
}
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
if err := a.checkReservationsLimitTx(tx, username, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := a.db.Exec(a.queries.deleteUserTier, username)
|
||||
if _, err := tx.Exec(a.queries.deleteUserTier, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
|
||||
u, err := a.User(username)
|
||||
func (a *Manager) checkReservationsLimitTx(tx *sql.Tx, username string, reservationsLimit int64) error {
|
||||
u, err := a.userTx(tx, username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
|
||||
reservations, err := a.Reservations(username)
|
||||
reservations, err := a.reservationsTx(tx, username)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if int64(len(reservations)) > reservationsLimit {
|
||||
@@ -388,7 +396,11 @@ func (a *Manager) writeUserStatsQueue() error {
|
||||
|
||||
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
|
||||
func (a *Manager) User(username string) (*User, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserByName, username)
|
||||
return a.userTx(a.db, username)
|
||||
}
|
||||
|
||||
func (a *Manager) userTx(tx db.Querier, username string) (*User, error) {
|
||||
rows, err := tx.Query(a.queries.selectUserByName, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -753,7 +765,11 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
|
||||
|
||||
// Reservations returns all user-owned topics, and the associated everyone-access
|
||||
func (a *Manager) Reservations(username string) ([]Reservation, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservations, Everyone, username)
|
||||
return a.reservationsTx(a.db.ReadOnly(), username)
|
||||
}
|
||||
|
||||
func (a *Manager) reservationsTx(tx db.Querier, username string) ([]Reservation, error) {
|
||||
rows, err := tx.Query(a.queries.selectUserReservations, Everyone, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user