mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-18 21:30:44 +01:00
PG races
This commit is contained in:
@@ -3,14 +3,15 @@ package server
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -455,21 +456,8 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
||||
return errHTTPUnauthorized
|
||||
} else if err := s.userManager.AllowReservation(u.Name, req.Topic); err != nil {
|
||||
return errHTTPConflictTopicReserved
|
||||
} else if u.IsUser() {
|
||||
hasReservation, err := s.userManager.HasReservation(u.Name, req.Topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasReservation {
|
||||
reservations, err := s.userManager.ReservationsCount(u.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if reservations >= u.Tier.ReservationLimit {
|
||||
return errHTTPTooManyRequestsLimitReservations
|
||||
}
|
||||
}
|
||||
}
|
||||
// Actually add the reservation
|
||||
// Actually add the reservation (with limit check inside the transaction to avoid races)
|
||||
logvr(v, r).
|
||||
Tag(tagAccount).
|
||||
Fields(log.Context{
|
||||
@@ -477,7 +465,14 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
||||
"everyone": everyone.String(),
|
||||
}).
|
||||
Debug("Adding topic reservation")
|
||||
if err := s.userManager.AddReservation(u.Name, req.Topic, everyone); err != nil {
|
||||
var limit int64
|
||||
if u.IsUser() && u.Tier != nil {
|
||||
limit = u.Tier.ReservationLimit
|
||||
}
|
||||
if err := s.userManager.AddReservation(u.Name, req.Topic, everyone, limit); err != nil {
|
||||
if errors.Is(err, user.ErrTooManyReservations) {
|
||||
return errHTTPTooManyRequestsLimitReservations
|
||||
}
|
||||
return err
|
||||
}
|
||||
// Kill existing subscribers
|
||||
@@ -530,22 +525,16 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
|
||||
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
|
||||
// The process relies on the manager to perform the actual deletions (see runManager).
|
||||
func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *visitor, u *user.User, reservationsLimit int64) error {
|
||||
reservations, err := s.userManager.Reservations(u.Name)
|
||||
removedTopics, err := s.userManager.RemoveExcessReservations(u.Name, reservationsLimit)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if int64(len(reservations)) <= reservationsLimit {
|
||||
}
|
||||
if len(removedTopics) == 0 {
|
||||
logvr(v, r).Tag(tagAccount).Debug("No excess reservations to remove")
|
||||
return nil
|
||||
}
|
||||
topics := make([]string, 0)
|
||||
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
|
||||
topics = append(topics, reservations[i].Topic)
|
||||
}
|
||||
logvr(v, r).Tag(tagAccount).Info("Removing excess reservations for topics %s", strings.Join(topics, ", "))
|
||||
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.messageCache.ExpireMessages(topics...); err != nil {
|
||||
logvr(v, r).Tag(tagAccount).Info("Removed excess topic reservations, now removing messages for topics %s", strings.Join(removedTopics, ", "))
|
||||
if err := s.messageCache.ExpireMessages(removedTopics...); err != nil {
|
||||
return err
|
||||
}
|
||||
go s.pruneMessages()
|
||||
|
||||
@@ -503,7 +503,7 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("noadmin1", "pass", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("noadmin1", "pro"))
|
||||
require.Nil(t, s.userManager.AddReservation("noadmin1", "mytopic", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.AddReservation("noadmin1", "mytopic", user.PermissionDenyAll, 0))
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("noadmin2", "pass", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("noadmin2", "pro"))
|
||||
|
||||
@@ -478,8 +478,8 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll, 0))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll, 0))
|
||||
|
||||
// Add billing details
|
||||
u, err := s.userManager.User("phil")
|
||||
@@ -589,7 +589,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll, 0))
|
||||
|
||||
// Add billing details
|
||||
u, err := s.userManager.User("phil")
|
||||
|
||||
@@ -427,7 +427,7 @@ func (a *Manager) userByToken(token string) (*User, error) {
|
||||
|
||||
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
|
||||
func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByStripeCustomerID, customerID)
|
||||
rows, err := a.db.Query(a.queries.selectUserByStripeCustomerID, customerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -725,16 +725,35 @@ func (a *Manager) Grants(username string) ([]Grant, error) {
|
||||
|
||||
// AddReservation creates two access control entries for the given topic: one with full read/write
|
||||
// access for the given user, and one for Everyone with the given permission. Both entries are
|
||||
// created atomically in a single transaction.
|
||||
func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
|
||||
// created atomically in a single transaction. If limit is > 0, the reservation count is checked
|
||||
// inside the transaction and ErrTooManyReservations is returned if the limit would be exceeded.
|
||||
func (a *Manager) AddReservation(username string, topic string, everyone Permission, limit int64) error {
|
||||
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil {
|
||||
if limit > 0 {
|
||||
hasReservation, err := a.hasReservationTx(tx, username, topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return a.addReservationAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username)
|
||||
if !hasReservation {
|
||||
count, err := a.reservationsCountTx(tx, username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count >= limit {
|
||||
return ErrTooManyReservations
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), true, true, username, username, false); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(a.queries.upsertUserAccess, Everyone, toSQLWildcard(topic), everyone.IsRead(), everyone.IsWrite(), username, username, false); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -752,10 +771,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
|
||||
}
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
for _, topic := range topics {
|
||||
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.resetTopicAccessTx(tx, Everyone, topic); err != nil {
|
||||
if err := a.removeReservationAccessTx(tx, username, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -795,7 +811,11 @@ func (a *Manager) reservationsTx(tx db.Querier, username string) ([]Reservation,
|
||||
|
||||
// HasReservation returns true if the given topic access is owned by the user
|
||||
func (a *Manager) HasReservation(username, topic string) (bool, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||
return a.hasReservationTx(a.db, username, topic)
|
||||
}
|
||||
|
||||
func (a *Manager) hasReservationTx(tx db.Querier, username, topic string) (bool, error) {
|
||||
rows, err := tx.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -812,7 +832,11 @@ func (a *Manager) HasReservation(username, topic string) (bool, error) {
|
||||
|
||||
// ReservationsCount returns the number of reservations owned by this user
|
||||
func (a *Manager) ReservationsCount(username string) (int64, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsCount, username)
|
||||
return a.reservationsCountTx(a.db, username)
|
||||
}
|
||||
|
||||
func (a *Manager) reservationsCountTx(tx db.Querier, username string) (int64, error) {
|
||||
rows, err := tx.Query(a.queries.selectUserReservationsCount, username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -844,6 +868,30 @@ func (a *Manager) ReservationOwner(topic string) (string, error) {
|
||||
return ownerUserID, nil
|
||||
}
|
||||
|
||||
// RemoveExcessReservations removes reservations that exceed the given limit for the user.
|
||||
// It returns the list of topics whose reservations were removed. The read and removal are
|
||||
// performed atomically in a single transaction to avoid issues with stale replica data.
|
||||
func (a *Manager) RemoveExcessReservations(username string, limit int64) ([]string, error) {
|
||||
return db.QueryTx(a.db, func(tx *sql.Tx) ([]string, error) {
|
||||
reservations, err := a.reservationsTx(tx, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(reservations)) <= limit {
|
||||
return []string{}, nil
|
||||
}
|
||||
removedTopics := make([]string, 0)
|
||||
for i := int64(len(reservations)) - 1; i >= limit; i-- {
|
||||
topic := reservations[i].Topic
|
||||
if err := a.removeReservationAccessTx(tx, username, topic); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
removedTopics = append(removedTopics, topic)
|
||||
}
|
||||
return removedTopics, nil
|
||||
})
|
||||
}
|
||||
|
||||
// otherAccessCount returns the number of access entries for the given topic that are not owned by the user
|
||||
func (a *Manager) otherAccessCount(username, topic string) (int, error) {
|
||||
rows, err := a.db.Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
|
||||
@@ -861,15 +909,12 @@ func (a *Manager) otherAccessCount(username, topic string) (int, error) {
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (a *Manager) addReservationAccessTx(tx *sql.Tx, username, topic string, read, write bool, ownerUsername string) error {
|
||||
if !AllowedUsername(username) && username != Everyone {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topic) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
_, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), read, write, ownerUsername, ownerUsername, false)
|
||||
func (a *Manager) removeReservationAccessTx(tx *sql.Tx, username, topic string) error {
|
||||
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
return a.resetTopicAccessTx(tx, Everyone, topic)
|
||||
}
|
||||
|
||||
func (a *Manager) resetUserAccessTx(tx *sql.Tx, username string) error {
|
||||
if !AllowedUsername(username) && username != Everyone {
|
||||
@@ -1150,7 +1195,7 @@ func (a *Manager) Tiers() ([]*Tier, error) {
|
||||
|
||||
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
|
||||
func (a *Manager) Tier(code string) (*Tier, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByCode, code)
|
||||
rows, err := a.db.Query(a.queries.selectTierByCode, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1160,7 +1205,7 @@ func (a *Manager) Tier(code string) (*Tier, error) {
|
||||
|
||||
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
|
||||
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByPriceID, priceID, priceID)
|
||||
rows, err := a.db.Query(a.queries.selectTierByPriceID, priceID, priceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -226,7 +226,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
|
||||
|
||||
// Create user, add reservations and token
|
||||
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false))
|
||||
require.Nil(t, a.AddReservation("user", "mytopic", PermissionRead))
|
||||
require.Nil(t, a.AddReservation("user", "mytopic", PermissionRead, 0))
|
||||
|
||||
u, err := a.User("user")
|
||||
require.Nil(t, err)
|
||||
@@ -439,8 +439,8 @@ func TestManager_Reservations(t *testing.T) {
|
||||
a := newTestManager(t, newManager, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||
require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll))
|
||||
require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
|
||||
require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll, 0))
|
||||
require.Nil(t, a.AddReservation("ben", "readme", PermissionRead, 0))
|
||||
require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
|
||||
|
||||
reservations, err := a.Reservations("ben")
|
||||
@@ -523,7 +523,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
||||
}))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||
require.Nil(t, a.ChangeTier("ben", "pro"))
|
||||
require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll))
|
||||
require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll, 0))
|
||||
|
||||
ben, err := a.User("ben")
|
||||
require.Nil(t, err)
|
||||
@@ -1076,7 +1076,7 @@ func TestManager_Tier_Change_And_Reset(t *testing.T) {
|
||||
|
||||
// Add 10 reservations (pro tier allows that)
|
||||
for i := 0; i < 4; i++ {
|
||||
require.Nil(t, a.AddReservation("phil", fmt.Sprintf("topic%d", i), PermissionWrite))
|
||||
require.Nil(t, a.AddReservation("phil", fmt.Sprintf("topic%d", i), PermissionWrite, 0))
|
||||
}
|
||||
|
||||
// Downgrading will not work (too many reservations)
|
||||
@@ -2118,7 +2118,7 @@ func TestStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
|
||||
func TestStoreReservations(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
|
||||
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionRead))
|
||||
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionRead, 0))
|
||||
|
||||
reservations, err := manager.Reservations("phil")
|
||||
require.Nil(t, err)
|
||||
@@ -2133,8 +2133,8 @@ func TestStoreReservations(t *testing.T) {
|
||||
func TestStoreReservationsCount(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
|
||||
require.Nil(t, manager.AddReservation("phil", "topic1", PermissionReadWrite))
|
||||
require.Nil(t, manager.AddReservation("phil", "topic2", PermissionReadWrite))
|
||||
require.Nil(t, manager.AddReservation("phil", "topic1", PermissionReadWrite, 0))
|
||||
require.Nil(t, manager.AddReservation("phil", "topic2", PermissionReadWrite, 0))
|
||||
|
||||
count, err := manager.ReservationsCount("phil")
|
||||
require.Nil(t, err)
|
||||
@@ -2145,7 +2145,7 @@ func TestStoreReservationsCount(t *testing.T) {
|
||||
func TestStoreHasReservation(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
|
||||
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite))
|
||||
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite, 0))
|
||||
|
||||
has, err := manager.HasReservation("phil", "mytopic")
|
||||
require.Nil(t, err)
|
||||
@@ -2160,7 +2160,7 @@ func TestStoreHasReservation(t *testing.T) {
|
||||
func TestStoreReservationOwner(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
|
||||
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite))
|
||||
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite, 0))
|
||||
|
||||
owner, err := manager.ReservationOwner("mytopic")
|
||||
require.Nil(t, err)
|
||||
@@ -2172,6 +2172,26 @@ func TestStoreReservationOwner(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreAddReservationWithLimit(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
|
||||
|
||||
// Adding reservations within limit succeeds
|
||||
require.Nil(t, manager.AddReservation("phil", "topic1", PermissionReadWrite, 2))
|
||||
require.Nil(t, manager.AddReservation("phil", "topic2", PermissionRead, 2))
|
||||
|
||||
// Adding a third reservation exceeds the limit
|
||||
require.Equal(t, ErrTooManyReservations, manager.AddReservation("phil", "topic3", PermissionRead, 2))
|
||||
|
||||
// Updating an existing reservation within the limit succeeds
|
||||
require.Nil(t, manager.AddReservation("phil", "topic1", PermissionRead, 2))
|
||||
|
||||
reservations, err := manager.Reservations("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, reservations, 2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreTiers(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
tier := &Tier{
|
||||
@@ -2431,7 +2451,7 @@ func TestStoreOtherAccessCount(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
|
||||
require.Nil(t, manager.AddUser("ben", "benpass", RoleUser, false))
|
||||
require.Nil(t, manager.AddReservation("ben", "mytopic", PermissionReadWrite))
|
||||
require.Nil(t, manager.AddReservation("ben", "mytopic", PermissionReadWrite, 0))
|
||||
|
||||
count, err := manager.otherAccessCount("phil", "mytopic")
|
||||
require.Nil(t, err)
|
||||
|
||||
Reference in New Issue
Block a user