Extract ExecTx

This commit is contained in:
binwiederhier
2026-03-02 19:45:35 -05:00
parent 31f0234098
commit ea4739f79b
9 changed files with 222 additions and 260 deletions

View File

@@ -9,6 +9,7 @@ import (
"sync"
"time"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
@@ -334,17 +335,14 @@ func (c *Cache) Topics() ([]string, error) {
func (c *Cache) DeleteMessages(ids ...string) error {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, id := range ids {
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
return err
return db.ExecTx(c.db, func(tx *sql.Tx) error {
for _, id := range ids {
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
return err
}
}
}
return tx.Commit()
return nil
})
}
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
@@ -352,54 +350,43 @@ func (c *Cache) DeleteMessages(ids ...string) error {
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
// First, get the message IDs of scheduled messages to be deleted
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return db.QueryTx(c.db, func(tx *sql.Tx) ([]string, error) {
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
if err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
rows.Close() // Close rows before executing delete in same transaction
// Then delete the messages
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return ids, nil
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
rows.Close() // Close rows before executing delete in same transaction
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
return nil, err
}
return ids, nil
})
}
// ExpireMessages marks messages in the given topics as expired
func (c *Cache) ExpireMessages(topics ...string) error {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, t := range topics {
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
return err
return db.ExecTx(c.db, func(tx *sql.Tx) error {
for _, t := range topics {
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
return err
}
}
}
return tx.Commit()
return nil
})
}
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
@@ -427,17 +414,14 @@ func (c *Cache) AttachmentsExpired() ([]string, error) {
func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, id := range ids {
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
return err
return db.ExecTx(c.db, func(tx *sql.Tx) error {
for _, id := range ids {
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
return err
}
}
}
return tx.Commit()
return nil
})
}
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender

View File

@@ -3,6 +3,8 @@ package message
import (
"database/sql"
"fmt"
"heckel.io/ntfy/v2/db"
)
// Initial PostgreSQL schema
@@ -55,34 +57,29 @@ const (
// PostgreSQL schema management queries
const (
pgCurrentSchemaVersion = 14
postgresCurrentSchemaVersion = 14
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
)
func setupPostgres(db *sql.DB) error {
var schemaVersion int
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion)
if err != nil {
if err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
return setupNewPostgresDB(db)
}
if schemaVersion > pgCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
} else if schemaVersion > postgresCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
}
return nil
}
func setupNewPostgresDB(db *sql.DB) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
return err
}
return tx.Commit()
func setupNewPostgresDB(sqlDB *sql.DB) error {
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, postgresCurrentSchemaVersion); err != nil {
return err
}
return nil
})
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"time"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
)
@@ -382,85 +383,70 @@ func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error {
return nil
}
func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
func sqliteMigrateFrom9(sqlDB *sql.DB, cacheDuration time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error {
func sqliteMigrateFrom10(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error {
func sqliteMigrateFrom11(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error {
func sqliteMigrateFrom12(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error {
func sqliteMigrateFrom13(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
return err
}
return nil
})
}