Compare commits

...

6 Commits

Author SHA1 Message Date
binwiederhier
66449bd19b pgimport readme 2026-03-02 20:15:35 -05:00
binwiederhier
bedbb121e4 Remove codecov.io 2026-03-02 20:05:29 -05:00
binwiederhier
c4b8cfa756 Manual correction 2026-03-02 20:04:03 -05:00
binwiederhier
c864a9baeb Put more things in tx 2026-03-02 19:58:26 -05:00
binwiederhier
8afeb813d9 Move OpenPostgres 2026-03-02 19:52:36 -05:00
binwiederhier
ea4739f79b Extract ExecTx 2026-03-02 19:45:35 -05:00
19 changed files with 335 additions and 369 deletions

View File

@@ -42,5 +42,3 @@ jobs:
run: make checkv
- name: Run coverage
run: make coverage
- name: Upload coverage to codecov.io
run: make coverage-upload

View File

@@ -56,7 +56,6 @@ help:
@echo " make race - Run tests with -race flag"
@echo " make coverage - Run tests and show coverage"
@echo " make coverage-html - Run tests and show coverage (as HTML)"
@echo " make coverage-upload - Upload coverage results to codecov.io"
@echo
@echo "Lint/format:"
@echo " make fmt - Run 'go fmt'"
@@ -286,9 +285,6 @@ coverage-html:
go test -race -coverprofile=build/coverage/coverage.txt -covermode=atomic $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
go tool cover -html build/coverage/coverage.txt
coverage-upload:
cd build/coverage && (curl -s https://codecov.io/bash | bash)
# Lint/formatting targets

View File

@@ -11,7 +11,7 @@ import (
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/server"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
@@ -379,7 +379,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
}
if databaseURL != "" {
pool, dbErr := db.OpenPostgres(databaseURL)
pool, dbErr := pg.Open(databaseURL)
if dbErr != nil {
return nil, dbErr
}

View File

@@ -2,92 +2,37 @@ package db
import (
"database/sql"
"fmt"
"net/url"
"strconv"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
)
const (
paramMaxOpenConns = "pool_max_conns"
paramMaxIdleConns = "pool_max_idle_conns"
paramConnMaxLifetime = "pool_conn_max_lifetime"
paramConnMaxIdleTime = "pool_conn_max_idle_time"
defaultMaxOpenConns = 10
)
// OpenPostgres opens a PostgreSQL database connection pool from a DSN string. It supports custom
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
// the DSN before passing it to the driver.
func OpenPostgres(dsn string) (*sql.DB, error) {
u, err := url.Parse(dsn)
// ExecTx executes a function within a database transaction. If the function returns an error,
// the transaction is rolled back. Otherwise, the transaction is committed.
func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return nil, fmt.Errorf("invalid database URL: %w", err)
return err
}
q := u.Query()
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
if err != nil {
return nil, err
defer tx.Rollback()
if err := f(tx); err != nil {
return err
}
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
if err != nil {
return nil, err
}
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
if err != nil {
return nil, err
}
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
if err != nil {
return nil, err
}
u.RawQuery = q.Encode()
db, err := sql.Open("pgx", u.String())
if err != nil {
return nil, err
}
db.SetMaxOpenConns(maxOpenConns)
if maxIdleConns > 0 {
db.SetMaxIdleConns(maxIdleConns)
}
if connMaxLifetime > 0 {
db.SetConnMaxLifetime(connMaxLifetime)
}
if connMaxIdleTime > 0 {
db.SetConnMaxIdleTime(connMaxIdleTime)
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping failed: %w", err)
}
return db, nil
return tx.Commit()
}
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
v, err := strconv.Atoi(s)
// QueryTx executes a function within a database transaction and returns the result. If the function
// returns an error, the transaction is rolled back. Otherwise, the transaction is committed.
func QueryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
tx, err := db.Begin()
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
var zero T
return zero, err
}
return v, nil
}
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
d, err := time.ParseDuration(s)
defer tx.Rollback()
t, err := f(tx)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
return t, err
}
return d, nil
if err := tx.Commit(); err != nil {
return t, err
}
return t, nil
}

93
db/pg/pg.go Normal file
View File

@@ -0,0 +1,93 @@
package pg
import (
"database/sql"
"fmt"
"net/url"
"strconv"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
)
const (
paramMaxOpenConns = "pool_max_conns"
paramMaxIdleConns = "pool_max_idle_conns"
paramConnMaxLifetime = "pool_conn_max_lifetime"
paramConnMaxIdleTime = "pool_conn_max_idle_time"
defaultMaxOpenConns = 10
)
// Open opens a PostgreSQL database connection pool from a DSN string. It supports custom
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
// the DSN before passing it to the driver.
func Open(dsn string) (*sql.DB, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, fmt.Errorf("invalid database URL: %w", err)
}
q := u.Query()
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
if err != nil {
return nil, err
}
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
if err != nil {
return nil, err
}
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
if err != nil {
return nil, err
}
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
if err != nil {
return nil, err
}
u.RawQuery = q.Encode()
db, err := sql.Open("pgx", u.String())
if err != nil {
return nil, err
}
db.SetMaxOpenConns(maxOpenConns)
if maxIdleConns > 0 {
db.SetMaxIdleConns(maxIdleConns)
}
if connMaxLifetime > 0 {
db.SetConnMaxLifetime(connMaxLifetime)
}
if connMaxIdleTime > 0 {
db.SetConnMaxIdleTime(connMaxIdleTime)
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping failed: %w", err)
}
return db, nil
}
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
v, err := strconv.Atoi(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return v, nil
}
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
d, err := time.ParseDuration(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return d, nil
}

View File

@@ -8,7 +8,7 @@ import (
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/util"
)
@@ -30,7 +30,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
q.Set("pool_max_conns", testPoolMaxConns)
u.RawQuery = q.Encode()
dsn = u.String()
setupDB, err := db.OpenPostgres(dsn)
setupDB, err := pg.Open(dsn)
require.Nil(t, err)
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
require.Nil(t, err)
@@ -39,7 +39,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
u.RawQuery = q.Encode()
schemaDSN := u.String()
t.Cleanup(func() {
cleanDB, err := db.OpenPostgres(dsn)
cleanDB, err := pg.Open(dsn)
if err == nil {
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
cleanDB.Close()
@@ -54,7 +54,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
func CreateTestPostgres(t *testing.T) *sql.DB {
t.Helper()
schemaDSN := CreateTestPostgresSchema(t)
testDB, err := db.OpenPostgres(schemaDSN)
testDB, err := pg.Open(schemaDSN)
require.Nil(t, err)
t.Cleanup(func() {
testDB.Close()

View File

@@ -9,6 +9,7 @@ import (
"sync"
"time"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
@@ -334,17 +335,14 @@ func (c *Cache) Topics() ([]string, error) {
func (c *Cache) DeleteMessages(ids ...string) error {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, id := range ids {
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
return err
return db.ExecTx(c.db, func(tx *sql.Tx) error {
for _, id := range ids {
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
return err
}
}
}
return tx.Commit()
return nil
})
}
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
@@ -352,54 +350,43 @@ func (c *Cache) DeleteMessages(ids ...string) error {
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
// First, get the message IDs of scheduled messages to be deleted
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return db.QueryTx(c.db, func(tx *sql.Tx) ([]string, error) {
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
if err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
rows.Close() // Close rows before executing delete in same transaction
// Then delete the messages
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return ids, nil
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
rows.Close() // Close rows before executing delete in same transaction
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
return nil, err
}
return ids, nil
})
}
// ExpireMessages marks messages in the given topics as expired
func (c *Cache) ExpireMessages(topics ...string) error {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, t := range topics {
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
return err
return db.ExecTx(c.db, func(tx *sql.Tx) error {
for _, t := range topics {
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
return err
}
}
}
return tx.Commit()
return nil
})
}
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
@@ -427,17 +414,14 @@ func (c *Cache) AttachmentsExpired() ([]string, error) {
func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, id := range ids {
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
return err
return db.ExecTx(c.db, func(tx *sql.Tx) error {
for _, id := range ids {
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
return err
}
}
}
return tx.Commit()
return nil
})
}
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender

View File

@@ -3,6 +3,8 @@ package message
import (
"database/sql"
"fmt"
"heckel.io/ntfy/v2/db"
)
// Initial PostgreSQL schema
@@ -55,34 +57,29 @@ const (
// PostgreSQL schema management queries
const (
pgCurrentSchemaVersion = 14
postgresCurrentSchemaVersion = 14
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
)
func setupPostgres(db *sql.DB) error {
var schemaVersion int
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion)
if err != nil {
if err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
return setupNewPostgresDB(db)
}
if schemaVersion > pgCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
} else if schemaVersion > postgresCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
}
return nil
}
func setupNewPostgresDB(db *sql.DB) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
return err
}
return tx.Commit()
func setupNewPostgresDB(sqlDB *sql.DB) error {
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, postgresCurrentSchemaVersion); err != nil {
return err
}
return nil
})
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"time"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
)
@@ -382,85 +383,70 @@ func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error {
return nil
}
func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
func sqliteMigrateFrom9(sqlDB *sql.DB, cacheDuration time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error {
func sqliteMigrateFrom10(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error {
func sqliteMigrateFrom11(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error {
func sqliteMigrateFrom12(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
return err
}
return nil
})
}
func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error {
func sqliteMigrateFrom13(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
return err
}
return tx.Commit()
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
return err
}
return nil
})
}

View File

@@ -33,7 +33,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
@@ -178,11 +178,11 @@ func New(conf *Config) (*Server, error) {
if payments.Available && conf.StripeSecretKey != "" {
stripe = newStripeAPI()
}
// OpenPostgres shared PostgreSQL connection pool if configured
// Open shared PostgreSQL connection pool if configured
var pool *sql.DB
if conf.DatabaseURL != "" {
var err error
pool, err = db.OpenPostgres(conf.DatabaseURL)
pool, err = pg.Open(conf.DatabaseURL)
if err != nil {
return nil, err
}

View File

@@ -1,6 +1,10 @@
# pgimport
Migrates ntfy data from SQLite to PostgreSQL.
One-off migration script to import ntfy data from SQLite to PostgreSQL.
This is **not** a generic migration tool. It only works with specific SQLite schema versions
(message cache v14, user db v6, web push v1) and their corresponding PostgreSQL schemas.
If your database versions differ, this tool will refuse to run.
## Build

View File

@@ -1,3 +1,6 @@
// pgimport is a one-off migration script to import ntfy data from SQLite to PostgreSQL.
// It is not a generic migration tool. It expects specific schema versions for each database
// (message cache v14, user db v6, web push v1) and will refuse to run if versions don't match.
package main
import (
@@ -11,7 +14,7 @@ import (
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
)
const (
@@ -33,7 +36,7 @@ var flags = []cli.Flag{
func main() {
app := &cli.App{
Name: "pgimport",
Usage: "SQLite to PostgreSQL migration tool for ntfy",
Usage: "One-off SQLite to PostgreSQL migration script for ntfy",
UsageText: "pgimport [OPTIONS]",
Flags: flags,
Before: loadConfigFile("config", flags),
@@ -79,7 +82,7 @@ func execImport(c *cli.Context) error {
}
fmt.Println()
pgDB, err := db.OpenPostgres(databaseURL)
pgDB, err := pg.Open(databaseURL)
if err != nil {
return fmt.Errorf("cannot connect to PostgreSQL: %w", err)
}

View File

@@ -13,6 +13,7 @@ import (
"time"
"golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/payments"
"heckel.io/ntfy/v2/util"
@@ -122,7 +123,7 @@ func (a *Manager) AddUser(username, password string, role Role, hashed bool) err
if err != nil {
return err
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.addUserTx(tx, username, hash, role, false)
})
}
@@ -150,7 +151,7 @@ func (a *Manager) RemoveUser(username string) error {
if err := a.CanChangeUser(username); err != nil {
return err
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.removeUserTx(tx, username)
})
}
@@ -173,7 +174,7 @@ func (a *Manager) MarkUserRemoved(user *User) error {
if !AllowedUsername(user.Name) {
return ErrInvalidArgument
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.resetUserAccessTx(tx, user.Name); err != nil {
return err
}
@@ -205,7 +206,7 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error {
if err != nil {
return err
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.changePasswordHashTx(tx, username, hash)
})
}
@@ -224,7 +225,7 @@ func (a *Manager) ChangeRole(username string, role Role) error {
if err := a.CanChangeUser(username); err != nil {
return err
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.changeRoleTx(tx, username, role)
})
}
@@ -365,7 +366,7 @@ func (a *Manager) writeUserStatsQueue() error {
a.statsQueue = make(map[string]*Stats)
a.mu.Unlock()
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
for userID, update := range statsQueue {
log.
@@ -573,7 +574,7 @@ func (a *Manager) resolvePerms(base, perm Permission) error {
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
// owner may either be a user (username), or the system (empty).
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.allowAccessTx(tx, username, topicPattern, permission, false)
})
}
@@ -591,7 +592,7 @@ func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
// empty) for an entire user. The parameter topicPattern may include wildcards (*).
func (a *Manager) ResetAccess(username string, topicPattern string) error {
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.resetAccessTx(tx, username, topicPattern)
})
}
@@ -715,7 +716,7 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
return ErrInvalidArgument
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil {
return err
}
@@ -735,7 +736,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
return ErrInvalidArgument
}
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
for _, topic := range topics {
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
return err
@@ -874,7 +875,7 @@ func (a *Manager) resetTopicAccessTx(tx *sql.Tx, username, topicPattern string)
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
// given user, if there are too many of them.
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
return db.QueryTx(a.db, func(tx *sql.Tx) (*Token, error) {
return a.createTokenTx(tx, userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned)
})
}
@@ -1033,7 +1034,7 @@ func (a *Manager) writeTokenUpdateQueue() error {
a.tokenQueue = make(map[string]*TokenUpdate)
a.mu.Unlock()
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
for tokenID, update := range tokenQueue {
log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
@@ -1254,7 +1255,7 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
if err != nil {
return err
}
return execTx(a.db, func(tx *sql.Tx) error {
return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil {
return fmt.Errorf("failed to provision users: %v", err)
}

View File

@@ -328,8 +328,7 @@ var (
func setupSQLite(db *sql.DB) error {
var schemaVersion int
err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion)
if err != nil {
if err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
return setupNewSQLite(db)
}
if schemaVersion == sqliteCurrentSchemaVersion {

View File

@@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/util"
)
@@ -36,7 +36,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc
t.Run("postgres", func(t *testing.T) {
schemaDSN := dbtest.CreateTestPostgresSchema(t)
f(t, func(config *Config) *Manager {
pool, err := db.OpenPostgres(schemaDSN)
pool, err := pg.Open(schemaDSN)
require.Nil(t, err)
a, err := NewPostgresManager(pool, config)
require.Nil(t, err)

View File

@@ -113,35 +113,3 @@ func escapeUnderscore(s string) string {
func unescapeUnderscore(s string) string {
return strings.ReplaceAll(s, "\\_", "_")
}
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := f(tx); err != nil {
return err
}
return tx.Commit()
}
// queryTx executes a function in a transaction and returns the result. If the function
// returns an error, the transaction is rolled back.
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
tx, err := db.Begin()
if err != nil {
var zero T
return zero, err
}
defer tx.Rollback()
t, err := f(tx)
if err != nil {
return t, err
}
if err := tx.Commit(); err != nil {
return t, err
}
return t, nil
}

View File

@@ -6,6 +6,7 @@ import (
"net/netip"
"time"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/util"
)
@@ -46,41 +47,38 @@ type queries struct {
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID.
func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Read number of subscriptions for subscriber IP address
var subscriptionCount int
if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil {
return err
}
// Read existing subscription ID for endpoint (or create new ID)
var subscriptionID string
if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) {
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return ErrWebPushTooManySubscriptions
}
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
} else if err != nil {
return err
}
// Insert or update subscription
updatedAt, warnedAt := time.Now().Unix(), 0
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
return err
}
// Replace all subscription topics
if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
return err
}
for _, topic := range topics {
if _, err = tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
return db.ExecTx(s.db, func(tx *sql.Tx) error {
// Read number of subscriptions for subscriber IP address
var subscriptionCount int
if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil {
return err
}
}
return tx.Commit()
// Read existing subscription ID for endpoint (or create new ID)
var subscriptionID string
if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) {
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return ErrWebPushTooManySubscriptions
}
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
} else if err != nil {
return err
}
// Insert or update subscription
updatedAt, warnedAt := time.Now().Unix(), 0
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
return err
}
// Replace all subscription topics
if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
return err
}
for _, topic := range topics {
if _, err := tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
return err
}
}
return nil
})
}
// SubscriptionsForTopic returns all subscriptions for the given topic.
@@ -105,17 +103,14 @@ func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription,
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon.
func (s *Store) MarkExpiryWarningSent(subscriptions []*Subscription) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, subscription := range subscriptions {
if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
return err
return db.ExecTx(s.db, func(tx *sql.Tx) error {
for _, subscription := range subscriptions {
if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
return err
}
}
}
return tx.Commit()
return nil
})
}
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint.
@@ -135,12 +130,13 @@ func (s *Store) RemoveSubscriptionsByUserID(userID string) error {
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period.
func (s *Store) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
_, err := s.db.Exec(s.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix())
if err != nil {
return db.ExecTx(s.db, func(tx *sql.Tx) error {
if _, err := tx.Exec(s.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix()); err != nil {
return err
}
_, err := tx.Exec(s.queries.deleteSubscriptionTopicWithoutSubscription)
return err
}
_, err = s.db.Exec(s.queries.deleteSubscriptionTopicWithoutSubscription)
return err
})
}
// SetSubscriptionUpdatedAt updates the updated_at timestamp for a subscription by endpoint. This is

View File

@@ -3,6 +3,8 @@ package webpush
import (
"database/sql"
"fmt"
"heckel.io/ntfy/v2/db"
)
const (
@@ -107,17 +109,14 @@ func setupPostgres(db *sql.DB) error {
return nil
}
func setupNewPostgres(db *sql.DB) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
return err
}
return tx.Commit()
func setupNewPostgres(sqlDB *sql.DB) error {
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
return err
}
return nil
})
}

View File

@@ -5,6 +5,8 @@ import (
"fmt"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/db"
)
const (
@@ -109,29 +111,24 @@ func NewSQLiteStore(filename, startupQueries string) (*Store, error) {
func setupSQLite(db *sql.DB) error {
var schemaVersion int
err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion)
if err != nil {
if err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
return setupNewSQLite(db)
}
if schemaVersion > sqliteCurrentSchemaVersion {
} else if schemaVersion > sqliteCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
}
return nil
}
func setupNewSQLite(db *sql.DB) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
return err
}
return tx.Commit()
func setupNewSQLite(sqlDB *sql.DB) error {
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
return err
}
return nil
})
}
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {