diff --git a/db/types.go b/db/types.go index 534d6168..137753a4 100644 --- a/db/types.go +++ b/db/types.go @@ -11,6 +11,12 @@ type Beginner interface { Begin() (*sql.Tx, error) } +// Querier is an interface for types that can execute SQL queries. +// *sql.DB, *sql.Tx, and *DB all implement this. +type Querier interface { + Query(query string, args ...any) (*sql.Rows, error) +} + // Host pairs a *sql.DB with the host:port it was opened against. type Host struct { Addr string // "host:port" diff --git a/docs/releases.md b/docs/releases.md index fe3e2bd9..4d5e6c02 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1795,4 +1795,5 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release **Bug fixes + maintenance:** +* Fix race condition in web push subscription causing FK constraint violation when concurrent requests hit the same endpoint * Route authorization query to read-only database replica to reduce primary database load diff --git a/go.mod b/go.mod index c073d6aa..c5bb4968 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.25.0 require ( cloud.google.com/go/firestore v1.21.0 // indirect - cloud.google.com/go/storage v1.61.1 // indirect + cloud.google.com/go/storage v1.61.3 // indirect github.com/BurntSushi/toml v1.6.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/emersion/go-smtp v0.18.0 diff --git a/go.sum b/go.sum index 1c6eada9..781b9856 100644 --- a/go.sum +++ b/go.sum @@ -18,8 +18,8 @@ cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7 cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk= cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE= cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI= -cloud.google.com/go/storage v1.61.1 h1:VELCSvZKiSw0AS1k3so5mKGy3CB7bTCYD8EHhTF42bY= -cloud.google.com/go/storage v1.61.1/go.mod h1:k30/hwYfd0M8aULYbPkQLgNf+SFcdjlRHvLMXggw18E= +cloud.google.com/go/storage v1.61.3 h1:VS//ZfBuPGDvakfD9xyPW1RGF1Vy3BWUoVZXgW1KMOg= +cloud.google.com/go/storage v1.61.3/go.mod h1:JtqK8BBB7TWv0HVGHubtUdzYYrakOQIsMLffZ2Z/HWk= cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U= cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s= firebase.google.com/go/v4 v4.19.0 h1:f5NMlC2YHFsncz00c2+ecBr+ZYlRMhKIhj1z8Iz0lD8= diff --git a/server/log.go b/server/log.go index e4ddc178..432f6743 100644 --- a/server/log.go +++ b/server/log.go @@ -35,7 +35,7 @@ const ( ) var ( - normalErrorCodes = []int{http.StatusNotFound, http.StatusBadRequest, http.StatusTooManyRequests, http.StatusUnauthorized, http.StatusForbidden, http.StatusInsufficientStorage} + normalErrorCodes = []int{http.StatusNotFound, http.StatusBadRequest, http.StatusTooManyRequests, http.StatusUnauthorized, http.StatusForbidden, http.StatusInsufficientStorage, http.StatusRequestEntityTooLarge} rateLimitingErrorCodes = []int{http.StatusTooManyRequests, http.StatusRequestEntityTooLarge} ) diff --git a/server/server_account.go b/server/server_account.go index 19a14042..7b719533 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -3,14 +3,15 @@ package server import ( "encoding/json" "errors" - "heckel.io/ntfy/v2/log" - "heckel.io/ntfy/v2/model" - "heckel.io/ntfy/v2/user" - "heckel.io/ntfy/v2/util" "net/http" "net/netip" "strings" "time" + + "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/model" + "heckel.io/ntfy/v2/user" + "heckel.io/ntfy/v2/util" ) const ( @@ -455,21 +456,8 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ return errHTTPUnauthorized } else if err := s.userManager.AllowReservation(u.Name, req.Topic); err != nil { return errHTTPConflictTopicReserved - } else if u.IsUser() { - hasReservation, err := s.userManager.HasReservation(u.Name, req.Topic) - if err != nil { - return err - } - if !hasReservation { - reservations, err := s.userManager.ReservationsCount(u.Name) - if err != nil { - return err - } else if reservations >= u.Tier.ReservationLimit { - return errHTTPTooManyRequestsLimitReservations - } - } } - // Actually add the reservation + // Actually add the reservation (with limit check inside the transaction to avoid races) logvr(v, r). Tag(tagAccount). Fields(log.Context{ @@ -477,7 +465,14 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ "everyone": everyone.String(), }). Debug("Adding topic reservation") - if err := s.userManager.AddReservation(u.Name, req.Topic, everyone); err != nil { + var limit int64 + if u.IsUser() && u.Tier != nil { + limit = u.Tier.ReservationLimit + } + if err := s.userManager.AddReservation(u.Name, req.Topic, everyone, limit); err != nil { + if errors.Is(err, user.ErrTooManyReservations) { + return errHTTPTooManyRequestsLimitReservations + } return err } // Kill existing subscribers @@ -530,22 +525,15 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R // and marks associated messages for the topics as deleted. This also eventually deletes attachments. // The process relies on the manager to perform the actual deletions (see runManager). func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *visitor, u *user.User, reservationsLimit int64) error { - reservations, err := s.userManager.Reservations(u.Name) + removedTopics, err := s.userManager.RemoveExcessReservations(u.Name, reservationsLimit) if err != nil { return err - } else if int64(len(reservations)) <= reservationsLimit { + } else if len(removedTopics) == 0 { logvr(v, r).Tag(tagAccount).Debug("No excess reservations to remove") return nil } - topics := make([]string, 0) - for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- { - topics = append(topics, reservations[i].Topic) - } - logvr(v, r).Tag(tagAccount).Info("Removing excess reservations for topics %s", strings.Join(topics, ", ")) - if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil { - return err - } - if err := s.messageCache.ExpireMessages(topics...); err != nil { + logvr(v, r).Tag(tagAccount).Info("Removed excess topic reservations, now removing messages for topics %s", strings.Join(removedTopics, ", ")) + if err := s.messageCache.ExpireMessages(removedTopics...); err != nil { return err } go s.pruneMessages() diff --git a/server/server_account_test.go b/server/server_account_test.go index 7bf6f6d5..0360fcd4 100644 --- a/server/server_account_test.go +++ b/server/server_account_test.go @@ -503,7 +503,7 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) { })) require.Nil(t, s.userManager.AddUser("noadmin1", "pass", user.RoleUser, false)) require.Nil(t, s.userManager.ChangeTier("noadmin1", "pro")) - require.Nil(t, s.userManager.AddReservation("noadmin1", "mytopic", user.PermissionDenyAll)) + require.Nil(t, s.userManager.AddReservation("noadmin1", "mytopic", user.PermissionDenyAll, 0)) require.Nil(t, s.userManager.AddUser("noadmin2", "pass", user.RoleUser, false)) require.Nil(t, s.userManager.ChangeTier("noadmin2", "pro")) diff --git a/server/server_payments_test.go b/server/server_payments_test.go index 06523edc..9873d6d8 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -478,8 +478,8 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false)) require.Nil(t, s.userManager.ChangeTier("phil", "pro")) - require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll)) - require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll)) + require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll, 0)) + require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll, 0)) // Add billing details u, err := s.userManager.User("phil") @@ -589,7 +589,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) { })) require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false)) require.Nil(t, s.userManager.ChangeTier("phil", "pro")) - require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll)) + require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll, 0)) // Add billing details u, err := s.userManager.User("phil") diff --git a/user/manager.go b/user/manager.go index 28243a24..99bd705e 100644 --- a/user/manager.go +++ b/user/manager.go @@ -288,33 +288,41 @@ func (a *Manager) ChangeTier(username, tier string) error { t, err := a.Tier(tier) if err != nil { return err - } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil { - return err } - if _, err := a.db.Exec(a.queries.updateUserTier, tier, username); err != nil { - return err - } - return nil + return db.ExecTx(a.db, func(tx *sql.Tx) error { + if err := a.checkReservationsLimitTx(tx, username, t.ReservationLimit); err != nil { + return err + } + if _, err := tx.Exec(a.queries.updateUserTier, tier, username); err != nil { + return err + } + return nil + }) } // ResetTier removes the tier from the given user func (a *Manager) ResetTier(username string) error { if !AllowedUsername(username) && username != Everyone && username != "" { return ErrInvalidArgument - } else if err := a.checkReservationsLimit(username, 0); err != nil { - return err } - _, err := a.db.Exec(a.queries.deleteUserTier, username) - return err + return db.ExecTx(a.db, func(tx *sql.Tx) error { + if err := a.checkReservationsLimitTx(tx, username, 0); err != nil { + return err + } + if _, err := tx.Exec(a.queries.deleteUserTier, username); err != nil { + return err + } + return nil + }) } -func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error { - u, err := a.User(username) +func (a *Manager) checkReservationsLimitTx(tx *sql.Tx, username string, reservationsLimit int64) error { + u, err := a.userTx(tx, username) if err != nil { return err } if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit { - reservations, err := a.Reservations(username) + reservations, err := a.reservationsTx(tx, username) if err != nil { return err } else if int64(len(reservations)) > reservationsLimit { @@ -388,7 +396,11 @@ func (a *Manager) writeUserStatsQueue() error { // User returns the user with the given username if it exists, or ErrUserNotFound otherwise func (a *Manager) User(username string) (*User, error) { - rows, err := a.db.Query(a.queries.selectUserByName, username) + return a.userTx(a.db, username) +} + +func (a *Manager) userTx(tx db.Querier, username string) (*User, error) { + rows, err := tx.Query(a.queries.selectUserByName, username) if err != nil { return nil, err } @@ -415,7 +427,7 @@ func (a *Manager) userByToken(token string) (*User, error) { // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) { - rows, err := a.db.ReadOnly().Query(a.queries.selectUserByStripeCustomerID, customerID) + rows, err := a.db.Query(a.queries.selectUserByStripeCustomerID, customerID) if err != nil { return nil, err } @@ -713,16 +725,35 @@ func (a *Manager) Grants(username string) ([]Grant, error) { // AddReservation creates two access control entries for the given topic: one with full read/write // access for the given user, and one for Everyone with the given permission. Both entries are -// created atomically in a single transaction. -func (a *Manager) AddReservation(username string, topic string, everyone Permission) error { +// created atomically in a single transaction. If limit is > 0, the reservation count is checked +// inside the transaction and ErrTooManyReservations is returned if the limit would be exceeded. +func (a *Manager) AddReservation(username string, topic string, everyone Permission, limit int64) error { if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) { return ErrInvalidArgument } return db.ExecTx(a.db, func(tx *sql.Tx) error { - if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil { + if limit > 0 { + hasReservation, err := a.hasReservationTx(tx, username, topic) + if err != nil { + return err + } + if !hasReservation { + count, err := a.reservationsCountTx(tx, username) + if err != nil { + return err + } + if count >= limit { + return ErrTooManyReservations + } + } + } + if _, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), true, true, username, username, false); err != nil { return err } - return a.addReservationAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username) + if _, err := tx.Exec(a.queries.upsertUserAccess, Everyone, toSQLWildcard(topic), everyone.IsRead(), everyone.IsWrite(), username, username, false); err != nil { + return err + } + return nil }) } @@ -740,10 +771,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error { } return db.ExecTx(a.db, func(tx *sql.Tx) error { for _, topic := range topics { - if err := a.resetTopicAccessTx(tx, username, topic); err != nil { - return err - } - if err := a.resetTopicAccessTx(tx, Everyone, topic); err != nil { + if err := a.removeReservationAccessTx(tx, username, topic); err != nil { return err } } @@ -753,7 +781,11 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error { // Reservations returns all user-owned topics, and the associated everyone-access func (a *Manager) Reservations(username string) ([]Reservation, error) { - rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservations, Everyone, username) + return a.reservationsTx(a.db.ReadOnly(), username) +} + +func (a *Manager) reservationsTx(tx db.Querier, username string) ([]Reservation, error) { + rows, err := tx.Query(a.queries.selectUserReservations, Everyone, username) if err != nil { return nil, err } @@ -779,7 +811,11 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) { // HasReservation returns true if the given topic access is owned by the user func (a *Manager) HasReservation(username, topic string) (bool, error) { - rows, err := a.db.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic)) + return a.hasReservationTx(a.db, username, topic) +} + +func (a *Manager) hasReservationTx(tx db.Querier, username, topic string) (bool, error) { + rows, err := tx.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic)) if err != nil { return false, err } @@ -796,7 +832,11 @@ func (a *Manager) HasReservation(username, topic string) (bool, error) { // ReservationsCount returns the number of reservations owned by this user func (a *Manager) ReservationsCount(username string) (int64, error) { - rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsCount, username) + return a.reservationsCountTx(a.db, username) +} + +func (a *Manager) reservationsCountTx(tx db.Querier, username string) (int64, error) { + rows, err := tx.Query(a.queries.selectUserReservationsCount, username) if err != nil { return 0, err } @@ -828,6 +868,30 @@ func (a *Manager) ReservationOwner(topic string) (string, error) { return ownerUserID, nil } +// RemoveExcessReservations removes reservations that exceed the given limit for the user. +// It returns the list of topics whose reservations were removed. The read and removal are +// performed atomically in a single transaction to avoid issues with stale replica data. +func (a *Manager) RemoveExcessReservations(username string, limit int64) ([]string, error) { + return db.QueryTx(a.db, func(tx *sql.Tx) ([]string, error) { + reservations, err := a.reservationsTx(tx, username) + if err != nil { + return nil, err + } + if int64(len(reservations)) <= limit { + return []string{}, nil + } + removedTopics := make([]string, 0) + for i := int64(len(reservations)) - 1; i >= limit; i-- { + topic := reservations[i].Topic + if err := a.removeReservationAccessTx(tx, username, topic); err != nil { + return nil, err + } + removedTopics = append(removedTopics, topic) + } + return removedTopics, nil + }) +} + // otherAccessCount returns the number of access entries for the given topic that are not owned by the user func (a *Manager) otherAccessCount(username, topic string) (int, error) { rows, err := a.db.Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username) @@ -845,14 +909,11 @@ func (a *Manager) otherAccessCount(username, topic string) (int, error) { return count, nil } -func (a *Manager) addReservationAccessTx(tx *sql.Tx, username, topic string, read, write bool, ownerUsername string) error { - if !AllowedUsername(username) && username != Everyone { - return ErrInvalidArgument - } else if !AllowedTopicPattern(topic) { - return ErrInvalidArgument +func (a *Manager) removeReservationAccessTx(tx *sql.Tx, username, topic string) error { + if err := a.resetTopicAccessTx(tx, username, topic); err != nil { + return err } - _, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), read, write, ownerUsername, ownerUsername, false) - return err + return a.resetTopicAccessTx(tx, Everyone, topic) } func (a *Manager) resetUserAccessTx(tx *sql.Tx, username string) error { @@ -1134,7 +1195,7 @@ func (a *Manager) Tiers() ([]*Tier, error) { // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist func (a *Manager) Tier(code string) (*Tier, error) { - rows, err := a.db.ReadOnly().Query(a.queries.selectTierByCode, code) + rows, err := a.db.Query(a.queries.selectTierByCode, code) if err != nil { return nil, err } @@ -1144,7 +1205,7 @@ func (a *Manager) Tier(code string) (*Tier, error) { // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { - rows, err := a.db.ReadOnly().Query(a.queries.selectTierByPriceID, priceID, priceID) + rows, err := a.db.Query(a.queries.selectTierByPriceID, priceID, priceID) if err != nil { return nil, err } diff --git a/user/manager_test.go b/user/manager_test.go index 3e023909..c8e619cf 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -226,7 +226,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) { // Create user, add reservations and token require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false)) - require.Nil(t, a.AddReservation("user", "mytopic", PermissionRead)) + require.Nil(t, a.AddReservation("user", "mytopic", PermissionRead, 0)) u, err := a.User("user") require.Nil(t, err) @@ -439,8 +439,8 @@ func TestManager_Reservations(t *testing.T) { a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("phil", "phil", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) - require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll)) - require.Nil(t, a.AddReservation("ben", "readme", PermissionRead)) + require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll, 0)) + require.Nil(t, a.AddReservation("ben", "readme", PermissionRead, 0)) require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead)) reservations, err := a.Reservations("ben") @@ -523,7 +523,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { })) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.ChangeTier("ben", "pro")) - require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll)) + require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll, 0)) ben, err := a.User("ben") require.Nil(t, err) @@ -1076,7 +1076,7 @@ func TestManager_Tier_Change_And_Reset(t *testing.T) { // Add 10 reservations (pro tier allows that) for i := 0; i < 4; i++ { - require.Nil(t, a.AddReservation("phil", fmt.Sprintf("topic%d", i), PermissionWrite)) + require.Nil(t, a.AddReservation("phil", fmt.Sprintf("topic%d", i), PermissionWrite, 0)) } // Downgrading will not work (too many reservations) @@ -2118,7 +2118,7 @@ func TestStoreAuthorizeTopicAccessDenyAll(t *testing.T) { func TestStoreReservations(t *testing.T) { forEachStoreBackend(t, func(t *testing.T, manager *Manager) { require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) - require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionRead)) + require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionRead, 0)) reservations, err := manager.Reservations("phil") require.Nil(t, err) @@ -2133,8 +2133,8 @@ func TestStoreReservations(t *testing.T) { func TestStoreReservationsCount(t *testing.T) { forEachStoreBackend(t, func(t *testing.T, manager *Manager) { require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) - require.Nil(t, manager.AddReservation("phil", "topic1", PermissionReadWrite)) - require.Nil(t, manager.AddReservation("phil", "topic2", PermissionReadWrite)) + require.Nil(t, manager.AddReservation("phil", "topic1", PermissionReadWrite, 0)) + require.Nil(t, manager.AddReservation("phil", "topic2", PermissionReadWrite, 0)) count, err := manager.ReservationsCount("phil") require.Nil(t, err) @@ -2145,7 +2145,7 @@ func TestStoreReservationsCount(t *testing.T) { func TestStoreHasReservation(t *testing.T) { forEachStoreBackend(t, func(t *testing.T, manager *Manager) { require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) - require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite)) + require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite, 0)) has, err := manager.HasReservation("phil", "mytopic") require.Nil(t, err) @@ -2160,7 +2160,7 @@ func TestStoreHasReservation(t *testing.T) { func TestStoreReservationOwner(t *testing.T) { forEachStoreBackend(t, func(t *testing.T, manager *Manager) { require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) - require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite)) + require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite, 0)) owner, err := manager.ReservationOwner("mytopic") require.Nil(t, err) @@ -2172,6 +2172,26 @@ func TestStoreReservationOwner(t *testing.T) { }) } +func TestStoreAddReservationWithLimit(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + + // Adding reservations within limit succeeds + require.Nil(t, manager.AddReservation("phil", "topic1", PermissionReadWrite, 2)) + require.Nil(t, manager.AddReservation("phil", "topic2", PermissionRead, 2)) + + // Adding a third reservation exceeds the limit + require.Equal(t, ErrTooManyReservations, manager.AddReservation("phil", "topic3", PermissionRead, 2)) + + // Updating an existing reservation within the limit succeeds + require.Nil(t, manager.AddReservation("phil", "topic1", PermissionRead, 2)) + + reservations, err := manager.Reservations("phil") + require.Nil(t, err) + require.Len(t, reservations, 2) + }) +} + func TestStoreTiers(t *testing.T) { forEachStoreBackend(t, func(t *testing.T, manager *Manager) { tier := &Tier{ @@ -2431,7 +2451,7 @@ func TestStoreOtherAccessCount(t *testing.T) { forEachStoreBackend(t, func(t *testing.T, manager *Manager) { require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) require.Nil(t, manager.AddUser("ben", "benpass", RoleUser, false)) - require.Nil(t, manager.AddReservation("ben", "mytopic", PermissionReadWrite)) + require.Nil(t, manager.AddReservation("ben", "mytopic", PermissionReadWrite, 0)) count, err := manager.otherAccessCount("phil", "mytopic") require.Nil(t, err) diff --git a/web/package-lock.json b/web/package-lock.json index bec8660f..29c0845e 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -3642,9 +3642,9 @@ "license": "MIT" }, "node_modules/baseline-browser-mapping": { - "version": "2.10.0", - "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.0.tgz", - "integrity": "sha512-lIyg0szRfYbiy67j9KN8IyeD7q7hcmqnJ1ddWmNt19ItGpNN64mnllmxUNFIOdOm6by97jlL6wfpTTJrmnjWAA==", + "version": "2.10.8", + "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.8.tgz", + "integrity": "sha512-PCLz/LXGBsNTErbtB6i5u4eLpHeMfi93aUv5duMmj6caNu6IphS4q6UevDnL36sZQv9lrP11dbPKGMaXPwMKfQ==", "dev": true, "license": "Apache-2.0", "bin": { @@ -3766,9 +3766,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001777", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001777.tgz", - "integrity": "sha512-tmN+fJxroPndC74efCdp12j+0rk0RHwV5Jwa1zWaFVyw2ZxAuPeG8ZgWC3Wz7uSjT3qMRQ5XHZ4COgQmsCMJAQ==", + "version": "1.0.30001779", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001779.tgz", + "integrity": "sha512-U5og2PN7V4DMgF50YPNtnZJGWVLFjjsN3zb6uMT5VGYIewieDj1upwfuVNXf4Kor+89c3iCRJnSzMD5LmTvsfA==", "dev": true, "funding": [ { @@ -4203,9 +4203,9 @@ } }, "node_modules/electron-to-chromium": { - "version": "1.5.307", - "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.307.tgz", - "integrity": "sha512-5z3uFKBWjiNR44nFcYdkcXjKMbg5KXNdciu7mhTPo9tB7NbqSNP2sSnGR+fqknZSCwKkBN+oxiiajWs4dT6ORg==", + "version": "1.5.313", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.313.tgz", + "integrity": "sha512-QBMrTWEf00GXZmJyx2lbYD45jpI3TUFnNIzJ5BBc8piGUDwMPa1GV6HJWTZVvY/eiN3fSopl7NRbgGp9sZ9LTA==", "dev": true, "license": "ISC" }, @@ -4324,9 +4324,9 @@ } }, "node_modules/es-iterator-helpers": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/es-iterator-helpers/-/es-iterator-helpers-1.3.0.tgz", - "integrity": "sha512-04cg8iJFDOxWcYlu0GFFWgs7vtaEPCmr5w1nrj9V3z3axu/48HCMwK6VMp45Zh3ZB+xLP1ifbJfrq86+1ypKKQ==", + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/es-iterator-helpers/-/es-iterator-helpers-1.3.1.tgz", + "integrity": "sha512-zWwRvqWiuBPr0muUG/78cW3aHROFCNIQ3zpmYDpwdbnt2m+xlNyRWpHBpa2lJjSBit7BQ+RXA1iwbSmu5yJ/EQ==", "dev": true, "license": "MIT", "dependencies": { @@ -7043,9 +7043,9 @@ } }, "node_modules/path-scurry/node_modules/lru-cache": { - "version": "11.2.6", - "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.6.tgz", - "integrity": "sha512-ESL2CrkS/2wTPfuend7Zhkzo2u0daGJ/A2VucJOgQ/C48S/zB8MMeMHSGKYpXhIjbPxfuezITkaBH1wqv00DDQ==", + "version": "11.2.7", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.7.tgz", + "integrity": "sha512-aY/R+aEsRelme17KGQa/1ZSIpLpNYYrhcrepKTZgE+W3WM16YMCaPwOHLHsmopZHELU0Ojin1lPVxKR0MihncA==", "dev": true, "license": "BlueOak-1.0.0", "engines": { @@ -8307,9 +8307,9 @@ } }, "node_modules/terser": { - "version": "5.46.0", - "resolved": "https://registry.npmjs.org/terser/-/terser-5.46.0.tgz", - "integrity": "sha512-jTwoImyr/QbOWFFso3YoU3ik0jBBDJ6JTOQiy/J2YxVJdZCc+5u7skhNwiOR3FQIygFqVUPHl7qbbxtjW2K3Qg==", + "version": "5.46.1", + "resolved": "https://registry.npmjs.org/terser/-/terser-5.46.1.tgz", + "integrity": "sha512-vzCjQO/rgUuK9sf8VJZvjqiqiHFaZLnOiimmUuOKODxWL8mm/xua7viT7aqX7dgPY60otQjUotzFMmCB4VdmqQ==", "dev": true, "license": "BSD-2-Clause", "dependencies": { diff --git a/webpush/store.go b/webpush/store.go index 02b7552e..1a9825f5 100644 --- a/webpush/store.go +++ b/webpush/store.go @@ -63,9 +63,10 @@ func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, } else if err != nil { return err } - // Insert or update subscription + // Insert or update subscription, and read back the actual ID (which may differ from + // the generated one if another request for the same endpoint raced us and inserted first) updatedAt, warnedAt := time.Now().Unix(), 0 - if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil { + if err := tx.QueryRow(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt).Scan(&subscriptionID); err != nil { return err } // Replace all subscription topics diff --git a/webpush/store_postgres.go b/webpush/store_postgres.go index 1c9adf0a..84168d89 100644 --- a/webpush/store_postgres.go +++ b/webpush/store_postgres.go @@ -53,6 +53,7 @@ const ( VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (endpoint) DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at + RETURNING id ` postgresUpdateSubscriptionWarningSentQuery = `UPDATE webpush_subscription SET warned_at = $1 WHERE id = $2` postgresUpdateSubscriptionUpdatedAtQuery = `UPDATE webpush_subscription SET updated_at = $1 WHERE endpoint = $2` diff --git a/webpush/store_sqlite.go b/webpush/store_sqlite.go index fcf49fcf..7677f1ce 100644 --- a/webpush/store_sqlite.go +++ b/webpush/store_sqlite.go @@ -56,8 +56,9 @@ const ( sqliteUpsertSubscriptionQuery = ` INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT (endpoint) + ON CONFLICT (endpoint) DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at + RETURNING id ` sqliteUpdateSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?` sqliteUpdateSubscriptionUpdatedAtQuery = `UPDATE subscription SET updated_at = ? WHERE endpoint = ?`