mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-18 21:30:44 +01:00
PG races
This commit is contained in:
@@ -427,7 +427,7 @@ func (a *Manager) userByToken(token string) (*User, error) {
|
||||
|
||||
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
|
||||
func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByStripeCustomerID, customerID)
|
||||
rows, err := a.db.Query(a.queries.selectUserByStripeCustomerID, customerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -725,16 +725,35 @@ func (a *Manager) Grants(username string) ([]Grant, error) {
|
||||
|
||||
// AddReservation creates two access control entries for the given topic: one with full read/write
|
||||
// access for the given user, and one for Everyone with the given permission. Both entries are
|
||||
// created atomically in a single transaction.
|
||||
func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
|
||||
// created atomically in a single transaction. If limit is > 0, the reservation count is checked
|
||||
// inside the transaction and ErrTooManyReservations is returned if the limit would be exceeded.
|
||||
func (a *Manager) AddReservation(username string, topic string, everyone Permission, limit int64) error {
|
||||
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil {
|
||||
if limit > 0 {
|
||||
hasReservation, err := a.hasReservationTx(tx, username, topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasReservation {
|
||||
count, err := a.reservationsCountTx(tx, username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count >= limit {
|
||||
return ErrTooManyReservations
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), true, true, username, username, false); err != nil {
|
||||
return err
|
||||
}
|
||||
return a.addReservationAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username)
|
||||
if _, err := tx.Exec(a.queries.upsertUserAccess, Everyone, toSQLWildcard(topic), everyone.IsRead(), everyone.IsWrite(), username, username, false); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -752,10 +771,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
|
||||
}
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
for _, topic := range topics {
|
||||
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.resetTopicAccessTx(tx, Everyone, topic); err != nil {
|
||||
if err := a.removeReservationAccessTx(tx, username, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -795,7 +811,11 @@ func (a *Manager) reservationsTx(tx db.Querier, username string) ([]Reservation,
|
||||
|
||||
// HasReservation returns true if the given topic access is owned by the user
|
||||
func (a *Manager) HasReservation(username, topic string) (bool, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||
return a.hasReservationTx(a.db, username, topic)
|
||||
}
|
||||
|
||||
func (a *Manager) hasReservationTx(tx db.Querier, username, topic string) (bool, error) {
|
||||
rows, err := tx.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -812,7 +832,11 @@ func (a *Manager) HasReservation(username, topic string) (bool, error) {
|
||||
|
||||
// ReservationsCount returns the number of reservations owned by this user
|
||||
func (a *Manager) ReservationsCount(username string) (int64, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsCount, username)
|
||||
return a.reservationsCountTx(a.db, username)
|
||||
}
|
||||
|
||||
func (a *Manager) reservationsCountTx(tx db.Querier, username string) (int64, error) {
|
||||
rows, err := tx.Query(a.queries.selectUserReservationsCount, username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -844,6 +868,30 @@ func (a *Manager) ReservationOwner(topic string) (string, error) {
|
||||
return ownerUserID, nil
|
||||
}
|
||||
|
||||
// RemoveExcessReservations removes reservations that exceed the given limit for the user.
|
||||
// It returns the list of topics whose reservations were removed. The read and removal are
|
||||
// performed atomically in a single transaction to avoid issues with stale replica data.
|
||||
func (a *Manager) RemoveExcessReservations(username string, limit int64) ([]string, error) {
|
||||
return db.QueryTx(a.db, func(tx *sql.Tx) ([]string, error) {
|
||||
reservations, err := a.reservationsTx(tx, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(reservations)) <= limit {
|
||||
return []string{}, nil
|
||||
}
|
||||
removedTopics := make([]string, 0)
|
||||
for i := int64(len(reservations)) - 1; i >= limit; i-- {
|
||||
topic := reservations[i].Topic
|
||||
if err := a.removeReservationAccessTx(tx, username, topic); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
removedTopics = append(removedTopics, topic)
|
||||
}
|
||||
return removedTopics, nil
|
||||
})
|
||||
}
|
||||
|
||||
// otherAccessCount returns the number of access entries for the given topic that are not owned by the user
|
||||
func (a *Manager) otherAccessCount(username, topic string) (int, error) {
|
||||
rows, err := a.db.Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
|
||||
@@ -861,14 +909,11 @@ func (a *Manager) otherAccessCount(username, topic string) (int, error) {
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (a *Manager) addReservationAccessTx(tx *sql.Tx, username, topic string, read, write bool, ownerUsername string) error {
|
||||
if !AllowedUsername(username) && username != Everyone {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topic) {
|
||||
return ErrInvalidArgument
|
||||
func (a *Manager) removeReservationAccessTx(tx *sql.Tx, username, topic string) error {
|
||||
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), read, write, ownerUsername, ownerUsername, false)
|
||||
return err
|
||||
return a.resetTopicAccessTx(tx, Everyone, topic)
|
||||
}
|
||||
|
||||
func (a *Manager) resetUserAccessTx(tx *sql.Tx, username string) error {
|
||||
@@ -1150,7 +1195,7 @@ func (a *Manager) Tiers() ([]*Tier, error) {
|
||||
|
||||
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
|
||||
func (a *Manager) Tier(code string) (*Tier, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByCode, code)
|
||||
rows, err := a.db.Query(a.queries.selectTierByCode, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1160,7 +1205,7 @@ func (a *Manager) Tier(code string) (*Tier, error) {
|
||||
|
||||
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
|
||||
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByPriceID, priceID, priceID)
|
||||
rows, err := a.db.Query(a.queries.selectTierByPriceID, priceID, priceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user