diff --git a/db/db.go b/db/db.go index e21b0590..3b6585c1 100644 --- a/db/db.go +++ b/db/db.go @@ -91,3 +91,36 @@ func extractDurationParam(q url.Values, key string, defaultValue time.Duration) } return d, nil } + +// ExecTx executes a function within a database transaction. If the function returns an error, +// the transaction is rolled back. Otherwise, the transaction is committed. +func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error { + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if err := f(tx); err != nil { + return err + } + return tx.Commit() +} + +// QueryTx executes a function within a database transaction and returns the result. If the function +// returns an error, the transaction is rolled back. Otherwise, the transaction is committed. +func QueryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) { + tx, err := db.Begin() + if err != nil { + var zero T + return zero, err + } + defer tx.Rollback() + t, err := f(tx) + if err != nil { + return t, err + } + if err := tx.Commit(); err != nil { + return t, err + } + return t, nil +} diff --git a/message/cache.go b/message/cache.go index 953a6f7f..3b12af3e 100644 --- a/message/cache.go +++ b/message/cache.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/util" @@ -334,17 +335,14 @@ func (c *Cache) Topics() ([]string, error) { func (c *Cache) DeleteMessages(ids ...string) error { c.maybeLock() defer c.maybeUnlock() - tx, err := c.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for _, id := range ids { - if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil { - return err + return db.ExecTx(c.db, func(tx *sql.Tx) error { + for _, id := range ids { + if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil { + return err + } } - } - return tx.Commit() + return nil + }) } // DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID. @@ -352,54 +350,43 @@ func (c *Cache) DeleteMessages(ids ...string) error { func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) { c.maybeLock() defer c.maybeUnlock() - tx, err := c.db.Begin() - if err != nil { - return nil, err - } - defer tx.Rollback() - // First, get the message IDs of scheduled messages to be deleted - rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID) - if err != nil { - return nil, err - } - defer rows.Close() - ids := make([]string, 0) - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { + return db.QueryTx(c.db, func(tx *sql.Tx) ([]string, error) { + rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID) + if err != nil { return nil, err } - ids = append(ids, id) - } - if err := rows.Err(); err != nil { - return nil, err - } - rows.Close() // Close rows before executing delete in same transaction - // Then delete the messages - if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil { - return nil, err - } - if err := tx.Commit(); err != nil { - return nil, err - } - return ids, nil + defer rows.Close() + ids := make([]string, 0) + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return nil, err + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, err + } + rows.Close() // Close rows before executing delete in same transaction + if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil { + return nil, err + } + return ids, nil + }) } // ExpireMessages marks messages in the given topics as expired func (c *Cache) ExpireMessages(topics ...string) error { c.maybeLock() defer c.maybeUnlock() - tx, err := c.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for _, t := range topics { - if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil { - return err + return db.ExecTx(c.db, func(tx *sql.Tx) error { + for _, t := range topics { + if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil { + return err + } } - } - return tx.Commit() + return nil + }) } // AttachmentsExpired returns message IDs with expired attachments that have not been deleted @@ -427,17 +414,14 @@ func (c *Cache) AttachmentsExpired() ([]string, error) { func (c *Cache) MarkAttachmentsDeleted(ids ...string) error { c.maybeLock() defer c.maybeUnlock() - tx, err := c.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for _, id := range ids { - if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil { - return err + return db.ExecTx(c.db, func(tx *sql.Tx) error { + for _, id := range ids { + if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil { + return err + } } - } - return tx.Commit() + return nil + }) } // AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender diff --git a/message/cache_postgres_schema.go b/message/cache_postgres_schema.go index 4b5730a0..4a539d0e 100644 --- a/message/cache_postgres_schema.go +++ b/message/cache_postgres_schema.go @@ -3,6 +3,8 @@ package message import ( "database/sql" "fmt" + + "heckel.io/ntfy/v2/db" ) // Initial PostgreSQL schema @@ -55,34 +57,29 @@ const ( // PostgreSQL schema management queries const ( - pgCurrentSchemaVersion = 14 + postgresCurrentSchemaVersion = 14 postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('message', $1)` postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'` ) func setupPostgres(db *sql.DB) error { var schemaVersion int - err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion) - if err != nil { + if err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil { return setupNewPostgresDB(db) - } - if schemaVersion > pgCurrentSchemaVersion { - return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion) + } else if schemaVersion > postgresCurrentSchemaVersion { + return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion) } return nil } -func setupNewPostgresDB(db *sql.DB) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(postgresCreateTablesQuery); err != nil { - return err - } - if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil { - return err - } - return tx.Commit() +func setupNewPostgresDB(sqlDB *sql.DB) error { + return db.ExecTx(sqlDB, func(tx *sql.Tx) error { + if _, err := tx.Exec(postgresCreateTablesQuery); err != nil { + return err + } + if _, err := tx.Exec(postgresInsertSchemaVersionQuery, postgresCurrentSchemaVersion); err != nil { + return err + } + return nil + }) } diff --git a/message/cache_sqlite_schema.go b/message/cache_sqlite_schema.go index 00a26c09..da744887 100644 --- a/message/cache_sqlite_schema.go +++ b/message/cache_sqlite_schema.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/log" ) @@ -382,85 +383,70 @@ func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error { return nil } -func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error { +func sqliteMigrateFrom9(sqlDB *sql.DB, cacheDuration time.Duration) error { log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil { - return err - } - if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil { - return err - } - if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil { - return err - } - return tx.Commit() + return db.ExecTx(sqlDB, func(tx *sql.Tx) error { + if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil { + return err + } + if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil { + return err + } + return nil + }) } -func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error { +func sqliteMigrateFrom10(sqlDB *sql.DB, _ time.Duration) error { log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil { - return err - } - if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil { - return err - } - return tx.Commit() + return db.ExecTx(sqlDB, func(tx *sql.Tx) error { + if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil { + return err + } + return nil + }) } -func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error { +func sqliteMigrateFrom11(sqlDB *sql.DB, _ time.Duration) error { log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil { - return err - } - if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil { - return err - } - return tx.Commit() + return db.ExecTx(sqlDB, func(tx *sql.Tx) error { + if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil { + return err + } + return nil + }) } -func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error { +func sqliteMigrateFrom12(sqlDB *sql.DB, _ time.Duration) error { log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil { - return err - } - if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil { - return err - } - return tx.Commit() + return db.ExecTx(sqlDB, func(tx *sql.Tx) error { + if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil { + return err + } + return nil + }) } -func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error { +func sqliteMigrateFrom13(sqlDB *sql.DB, _ time.Duration) error { log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil { - return err - } - if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil { - return err - } - return tx.Commit() + return db.ExecTx(sqlDB, func(tx *sql.Tx) error { + if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil { + return err + } + return nil + }) } diff --git a/user/manager.go b/user/manager.go index a1412be3..cd37cb3b 100644 --- a/user/manager.go +++ b/user/manager.go @@ -13,6 +13,7 @@ import ( "time" "golang.org/x/crypto/bcrypt" + "heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/payments" "heckel.io/ntfy/v2/util" @@ -122,7 +123,7 @@ func (a *Manager) AddUser(username, password string, role Role, hashed bool) err if err != nil { return err } - return execTx(a.db, func(tx *sql.Tx) error { + return db.ExecTx(a.db, func(tx *sql.Tx) error { return a.addUserTx(tx, username, hash, role, false) }) } @@ -150,7 +151,7 @@ func (a *Manager) RemoveUser(username string) error { if err := a.CanChangeUser(username); err != nil { return err } - return execTx(a.db, func(tx *sql.Tx) error { + return db.ExecTx(a.db, func(tx *sql.Tx) error { return a.removeUserTx(tx, username) }) } @@ -173,7 +174,7 @@ func (a *Manager) MarkUserRemoved(user *User) error { if !AllowedUsername(user.Name) { return ErrInvalidArgument } - return execTx(a.db, func(tx *sql.Tx) error { + return db.ExecTx(a.db, func(tx *sql.Tx) error { if err := a.resetUserAccessTx(tx, user.Name); err != nil { return err } @@ -205,7 +206,7 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error { if err != nil { return err } - return execTx(a.db, func(tx *sql.Tx) error { + return db.ExecTx(a.db, func(tx *sql.Tx) error { return a.changePasswordHashTx(tx, username, hash) }) } @@ -224,7 +225,7 @@ func (a *Manager) ChangeRole(username string, role Role) error { if err := a.CanChangeUser(username); err != nil { return err } - return execTx(a.db, func(tx *sql.Tx) error { + return db.ExecTx(a.db, func(tx *sql.Tx) error { return a.changeRoleTx(tx, username, role) }) } @@ -365,7 +366,7 @@ func (a *Manager) writeUserStatsQueue() error { a.statsQueue = make(map[string]*Stats) a.mu.Unlock() - return execTx(a.db, func(tx *sql.Tx) error { + return db.ExecTx(a.db, func(tx *sql.Tx) error { log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue)) for userID, update := range statsQueue { log. @@ -573,7 +574,7 @@ func (a *Manager) resolvePerms(base, perm Permission) error { // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry // owner may either be a user (username), or the system (empty). func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error { - return execTx(a.db, func(tx *sql.Tx) error { + return db.ExecTx(a.db, func(tx *sql.Tx) error { return a.allowAccessTx(tx, username, topicPattern, permission, false) }) } @@ -591,7 +592,7 @@ func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is // empty) for an entire user. The parameter topicPattern may include wildcards (*). func (a *Manager) ResetAccess(username string, topicPattern string) error { - return execTx(a.db, func(tx *sql.Tx) error { + return db.ExecTx(a.db, func(tx *sql.Tx) error { return a.resetAccessTx(tx, username, topicPattern) }) } @@ -715,7 +716,7 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) { return ErrInvalidArgument } - return execTx(a.db, func(tx *sql.Tx) error { + return db.ExecTx(a.db, func(tx *sql.Tx) error { if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil { return err } @@ -735,7 +736,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error { return ErrInvalidArgument } } - return execTx(a.db, func(tx *sql.Tx) error { + return db.ExecTx(a.db, func(tx *sql.Tx) error { for _, topic := range topics { if err := a.resetTopicAccessTx(tx, username, topic); err != nil { return err @@ -874,7 +875,7 @@ func (a *Manager) resetTopicAccessTx(tx *sql.Tx, username, topicPattern string) // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the // given user, if there are too many of them. func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) { - return queryTx(a.db, func(tx *sql.Tx) (*Token, error) { + return db.QueryTx(a.db, func(tx *sql.Tx) (*Token, error) { return a.createTokenTx(tx, userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned) }) } @@ -1033,7 +1034,7 @@ func (a *Manager) writeTokenUpdateQueue() error { a.tokenQueue = make(map[string]*TokenUpdate) a.mu.Unlock() - return execTx(a.db, func(tx *sql.Tx) error { + return db.ExecTx(a.db, func(tx *sql.Tx) error { log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue)) for tokenID, update := range tokenQueue { log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix()) @@ -1254,7 +1255,7 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error { if err != nil { return err } - return execTx(a.db, func(tx *sql.Tx) error { + return db.ExecTx(a.db, func(tx *sql.Tx) error { if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil { return fmt.Errorf("failed to provision users: %v", err) } diff --git a/user/util.go b/user/util.go index 6354464d..16f6cc09 100644 --- a/user/util.go +++ b/user/util.go @@ -113,35 +113,3 @@ func escapeUnderscore(s string) string { func unescapeUnderscore(s string) string { return strings.ReplaceAll(s, "\\_", "_") } - -// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back. -func execTx(db *sql.DB, f func(tx *sql.Tx) error) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if err := f(tx); err != nil { - return err - } - return tx.Commit() -} - -// queryTx executes a function in a transaction and returns the result. If the function -// returns an error, the transaction is rolled back. -func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) { - tx, err := db.Begin() - if err != nil { - var zero T - return zero, err - } - defer tx.Rollback() - t, err := f(tx) - if err != nil { - return t, err - } - if err := tx.Commit(); err != nil { - return t, err - } - return t, nil -} diff --git a/webpush/store.go b/webpush/store.go index 20ff19cc..ab5c517b 100644 --- a/webpush/store.go +++ b/webpush/store.go @@ -6,6 +6,7 @@ import ( "net/netip" "time" + "heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/util" ) @@ -46,41 +47,38 @@ type queries struct { // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error { - tx, err := s.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - // Read number of subscriptions for subscriber IP address - var subscriptionCount int - if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil { - return err - } - // Read existing subscription ID for endpoint (or create new ID) - var subscriptionID string - if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) { - if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP { - return ErrWebPushTooManySubscriptions - } - subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) - } else if err != nil { - return err - } - // Insert or update subscription - updatedAt, warnedAt := time.Now().Unix(), 0 - if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil { - return err - } - // Replace all subscription topics - if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil { - return err - } - for _, topic := range topics { - if _, err = tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil { + return db.ExecTx(s.db, func(tx *sql.Tx) error { + // Read number of subscriptions for subscriber IP address + var subscriptionCount int + if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil { return err } - } - return tx.Commit() + // Read existing subscription ID for endpoint (or create new ID) + var subscriptionID string + if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) { + if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP { + return ErrWebPushTooManySubscriptions + } + subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) + } else if err != nil { + return err + } + // Insert or update subscription + updatedAt, warnedAt := time.Now().Unix(), 0 + if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil { + return err + } + // Replace all subscription topics + if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil { + return err + } + for _, topic := range topics { + if _, err := tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil { + return err + } + } + return nil + }) } // SubscriptionsForTopic returns all subscriptions for the given topic. @@ -105,17 +103,14 @@ func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, // MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon. func (s *Store) MarkExpiryWarningSent(subscriptions []*Subscription) error { - tx, err := s.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for _, subscription := range subscriptions { - if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil { - return err + return db.ExecTx(s.db, func(tx *sql.Tx) error { + for _, subscription := range subscriptions { + if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil { + return err + } } - } - return tx.Commit() + return nil + }) } // RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint. diff --git a/webpush/store_postgres.go b/webpush/store_postgres.go index d9c61470..ec541d37 100644 --- a/webpush/store_postgres.go +++ b/webpush/store_postgres.go @@ -3,6 +3,8 @@ package webpush import ( "database/sql" "fmt" + + "heckel.io/ntfy/v2/db" ) const ( @@ -107,17 +109,14 @@ func setupPostgres(db *sql.DB) error { return nil } -func setupNewPostgres(db *sql.DB) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(postgresCreateTablesQuery); err != nil { - return err - } - if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil { - return err - } - return tx.Commit() +func setupNewPostgres(sqlDB *sql.DB) error { + return db.ExecTx(sqlDB, func(tx *sql.Tx) error { + if _, err := tx.Exec(postgresCreateTablesQuery); err != nil { + return err + } + if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil { + return err + } + return nil + }) } diff --git a/webpush/store_sqlite.go b/webpush/store_sqlite.go index d41e1c8d..90ca14a8 100644 --- a/webpush/store_sqlite.go +++ b/webpush/store_sqlite.go @@ -5,6 +5,8 @@ import ( "fmt" _ "github.com/mattn/go-sqlite3" // SQLite driver + + "heckel.io/ntfy/v2/db" ) const ( @@ -119,19 +121,16 @@ func setupSQLite(db *sql.DB) error { return nil } -func setupNewSQLite(db *sql.DB) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil { - return err - } - if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil { - return err - } - return tx.Commit() +func setupNewSQLite(sqlDB *sql.DB) error { + return db.ExecTx(sqlDB, func(tx *sql.Tx) error { + if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil { + return err + } + if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil { + return err + } + return nil + }) } func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {