Extract ExecTx

This commit is contained in:
binwiederhier
2026-03-02 19:45:35 -05:00
parent 31f0234098
commit ea4739f79b
9 changed files with 222 additions and 260 deletions

View File

@@ -91,3 +91,36 @@ func extractDurationParam(q url.Values, key string, defaultValue time.Duration)
}
return d, nil
}
// ExecTx executes a function within a database transaction. If the function returns an error,
// the transaction is rolled back. Otherwise, the transaction is committed.
func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := f(tx); err != nil {
return err
}
return tx.Commit()
}
// QueryTx executes a function within a database transaction and returns the result. If the function
// returns an error, the transaction is rolled back. Otherwise, the transaction is committed.
func QueryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
tx, err := db.Begin()
if err != nil {
var zero T
return zero, err
}
defer tx.Rollback()
t, err := f(tx)
if err != nil {
return t, err
}
if err := tx.Commit(); err != nil {
return t, err
}
return t, nil
}

View File

@@ -9,6 +9,7 @@ import (
"sync"
"time"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
@@ -334,17 +335,14 @@ func (c *Cache) Topics() ([]string, error) {
func (c *Cache) DeleteMessages(ids ...string) error {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, id := range ids {
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
return err
return db.ExecTx(c.db, func(tx *sql.Tx) error {
for _, id := range ids {
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
return err
}
}
}
return tx.Commit()
return nil
})
}
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
@@ -352,54 +350,43 @@ func (c *Cache) DeleteMessages(ids ...string) error {
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
// First, get the message IDs of scheduled messages to be deleted
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return db.QueryTx(c.db, func(tx *sql.Tx) ([]string, error) {
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
if err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
rows.Close() // Close rows before executing delete in same transaction
// Then delete the messages
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return ids, nil
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
rows.Close() // Close rows before executing delete in same transaction
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
return nil, err
}
return ids, nil
})
}
// ExpireMessages marks messages in the given topics as expired
func (c *Cache) ExpireMessages(topics ...string) error {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, t := range topics {
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
return err
return db.ExecTx(c.db, func(tx *sql.Tx) error {
for _, t := range topics {
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
return err
}
}
}
return tx.Commit()
return nil
})
}
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
@@ -427,17 +414,14 @@ func (c *Cache) AttachmentsExpired() ([]string, error) {
func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, id := range ids {
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
return err
return db.ExecTx(c.db, func(tx *sql.Tx) error {
for _, id := range ids {
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
return err
}
}
}
return tx.Commit()
return nil
})
}
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender

View File

@@ -3,6 +3,8 @@ package message
import (
"database/sql"
"fmt"
"heckel.io/ntfy/v2/db"
)
// Initial PostgreSQL schema
@@ -55,34 +57,29 @@ const (
// PostgreSQL schema management queries
const (
pgCurrentSchemaVersion = 14
postgresCurrentSchemaVersion = 14
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
)
func setupPostgres(db *sql.DB) error {
var schemaVersion int
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion)
if err != nil {
if err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
return setupNewPostgresDB(db)
}
if schemaVersion > pgCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
} else if schemaVersion > postgresCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
}
return nil
}
func setupNewPostgresDB(db *sql.DB) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
return err
}
return tx.Commit()
func setupNewPostgresDB(sqlDB *sql.DB) error {
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, postgresCurrentSchemaVersion); err != nil {
return err
}
return nil
})
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"time"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
)
@@ -382,85 +383,70 @@ func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error {
return nil
}
func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
func sqliteMigrateFrom9(sqlDB *sql.DB, cacheDuration time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error {
func sqliteMigrateFrom10(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error {
func sqliteMigrateFrom11(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error {
func sqliteMigrateFrom12(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error {
func sqliteMigrateFrom13(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
return err
}
return nil
})
}

View File

@@ -13,6 +13,7 @@ import (
"time"
"golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/payments"
"heckel.io/ntfy/v2/util"
@@ -122,7 +123,7 @@ func (a *Manager) AddUser(username, password string, role Role, hashed bool) err
if err != nil {
return err
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.addUserTx(tx, username, hash, role, false)
})
}
@@ -150,7 +151,7 @@ func (a *Manager) RemoveUser(username string) error {
if err := a.CanChangeUser(username); err != nil {
return err
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.removeUserTx(tx, username)
})
}
@@ -173,7 +174,7 @@ func (a *Manager) MarkUserRemoved(user *User) error {
if !AllowedUsername(user.Name) {
return ErrInvalidArgument
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.resetUserAccessTx(tx, user.Name); err != nil {
return err
}
@@ -205,7 +206,7 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error {
if err != nil {
return err
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.changePasswordHashTx(tx, username, hash)
})
}
@@ -224,7 +225,7 @@ func (a *Manager) ChangeRole(username string, role Role) error {
if err := a.CanChangeUser(username); err != nil {
return err
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.changeRoleTx(tx, username, role)
})
}
@@ -365,7 +366,7 @@ func (a *Manager) writeUserStatsQueue() error {
a.statsQueue = make(map[string]*Stats)
a.mu.Unlock()
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
for userID, update := range statsQueue {
log.
@@ -573,7 +574,7 @@ func (a *Manager) resolvePerms(base, perm Permission) error {
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
// owner may either be a user (username), or the system (empty).
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.allowAccessTx(tx, username, topicPattern, permission, false)
})
}
@@ -591,7 +592,7 @@ func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
// empty) for an entire user. The parameter topicPattern may include wildcards (*).
func (a *Manager) ResetAccess(username string, topicPattern string) error {
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.resetAccessTx(tx, username, topicPattern)
})
}
@@ -715,7 +716,7 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
return ErrInvalidArgument
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil {
return err
}
@@ -735,7 +736,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
return ErrInvalidArgument
}
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
for _, topic := range topics {
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
return err
@@ -874,7 +875,7 @@ func (a *Manager) resetTopicAccessTx(tx *sql.Tx, username, topicPattern string)
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
// given user, if there are too many of them.
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
return db.QueryTx(a.db, func(tx *sql.Tx) (*Token, error) {
return a.createTokenTx(tx, userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned)
})
}
@@ -1033,7 +1034,7 @@ func (a *Manager) writeTokenUpdateQueue() error {
a.tokenQueue = make(map[string]*TokenUpdate)
a.mu.Unlock()
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
for tokenID, update := range tokenQueue {
log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
@@ -1254,7 +1255,7 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
if err != nil {
return err
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil {
return fmt.Errorf("failed to provision users: %v", err)
}

View File

@@ -113,35 +113,3 @@ func escapeUnderscore(s string) string {
func unescapeUnderscore(s string) string {
return strings.ReplaceAll(s, "\\_", "_")
}
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := f(tx); err != nil {
return err
}
return tx.Commit()
}
// queryTx executes a function in a transaction and returns the result. If the function
// returns an error, the transaction is rolled back.
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
tx, err := db.Begin()
if err != nil {
var zero T
return zero, err
}
defer tx.Rollback()
t, err := f(tx)
if err != nil {
return t, err
}
if err := tx.Commit(); err != nil {
return t, err
}
return t, nil
}

View File

@@ -6,6 +6,7 @@ import (
"net/netip"
"time"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/util"
)
@@ -46,41 +47,38 @@ type queries struct {
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID.
func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Read number of subscriptions for subscriber IP address
var subscriptionCount int
if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil {
return err
}
// Read existing subscription ID for endpoint (or create new ID)
var subscriptionID string
if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) {
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return ErrWebPushTooManySubscriptions
}
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
} else if err != nil {
return err
}
// Insert or update subscription
updatedAt, warnedAt := time.Now().Unix(), 0
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
return err
}
// Replace all subscription topics
if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
return err
}
for _, topic := range topics {
if _, err = tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
return db.ExecTx(s.db, func(tx *sql.Tx) error {
// Read number of subscriptions for subscriber IP address
var subscriptionCount int
if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil {
return err
}
}
return tx.Commit()
// Read existing subscription ID for endpoint (or create new ID)
var subscriptionID string
if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) {
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return ErrWebPushTooManySubscriptions
}
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
} else if err != nil {
return err
}
// Insert or update subscription
updatedAt, warnedAt := time.Now().Unix(), 0
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
return err
}
// Replace all subscription topics
if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
return err
}
for _, topic := range topics {
if _, err := tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
return err
}
}
return nil
})
}
// SubscriptionsForTopic returns all subscriptions for the given topic.
@@ -105,17 +103,14 @@ func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription,
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon.
func (s *Store) MarkExpiryWarningSent(subscriptions []*Subscription) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, subscription := range subscriptions {
if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
return err
return db.ExecTx(s.db, func(tx *sql.Tx) error {
for _, subscription := range subscriptions {
if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
return err
}
}
}
return tx.Commit()
return nil
})
}
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint.

View File

@@ -3,6 +3,8 @@ package webpush
import (
"database/sql"
"fmt"
"heckel.io/ntfy/v2/db"
)
const (
@@ -107,17 +109,14 @@ func setupPostgres(db *sql.DB) error {
return nil
}
func setupNewPostgres(db *sql.DB) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
return err
}
return tx.Commit()
func setupNewPostgres(sqlDB *sql.DB) error {
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
return err
}
return nil
})
}

View File

@@ -5,6 +5,8 @@ import (
"fmt"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/db"
)
const (
@@ -119,19 +121,16 @@ func setupSQLite(db *sql.DB) error {
return nil
}
func setupNewSQLite(db *sql.DB) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
return err
}
return tx.Commit()
func setupNewSQLite(sqlDB *sql.DB) error {
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
return err
}
return nil
})
}
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {