mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-18 21:30:44 +01:00
Compare commits
6 Commits
31f0234098
...
66449bd19b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
66449bd19b | ||
|
|
bedbb121e4 | ||
|
|
c4b8cfa756 | ||
|
|
c864a9baeb | ||
|
|
8afeb813d9 | ||
|
|
ea4739f79b |
2
.github/workflows/test.yaml
vendored
2
.github/workflows/test.yaml
vendored
@@ -42,5 +42,3 @@ jobs:
|
||||
run: make checkv
|
||||
- name: Run coverage
|
||||
run: make coverage
|
||||
- name: Upload coverage to codecov.io
|
||||
run: make coverage-upload
|
||||
|
||||
4
Makefile
4
Makefile
@@ -56,7 +56,6 @@ help:
|
||||
@echo " make race - Run tests with -race flag"
|
||||
@echo " make coverage - Run tests and show coverage"
|
||||
@echo " make coverage-html - Run tests and show coverage (as HTML)"
|
||||
@echo " make coverage-upload - Upload coverage results to codecov.io"
|
||||
@echo
|
||||
@echo "Lint/format:"
|
||||
@echo " make fmt - Run 'go fmt'"
|
||||
@@ -286,9 +285,6 @@ coverage-html:
|
||||
go test -race -coverprofile=build/coverage/coverage.txt -covermode=atomic $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||
go tool cover -html build/coverage/coverage.txt
|
||||
|
||||
coverage-upload:
|
||||
cd build/coverage && (curl -s https://codecov.io/bash | bash)
|
||||
|
||||
|
||||
# Lint/formatting targets
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
"github.com/urfave/cli/v2/altsrc"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
"heckel.io/ntfy/v2/server"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
@@ -379,7 +379,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
||||
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
||||
}
|
||||
if databaseURL != "" {
|
||||
pool, dbErr := db.OpenPostgres(databaseURL)
|
||||
pool, dbErr := pg.Open(databaseURL)
|
||||
if dbErr != nil {
|
||||
return nil, dbErr
|
||||
}
|
||||
|
||||
99
db/db.go
99
db/db.go
@@ -2,92 +2,37 @@ package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
|
||||
const (
|
||||
paramMaxOpenConns = "pool_max_conns"
|
||||
paramMaxIdleConns = "pool_max_idle_conns"
|
||||
paramConnMaxLifetime = "pool_conn_max_lifetime"
|
||||
paramConnMaxIdleTime = "pool_conn_max_idle_time"
|
||||
|
||||
defaultMaxOpenConns = 10
|
||||
)
|
||||
|
||||
// OpenPostgres opens a PostgreSQL database connection pool from a DSN string. It supports custom
|
||||
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
|
||||
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
|
||||
// the DSN before passing it to the driver.
|
||||
func OpenPostgres(dsn string) (*sql.DB, error) {
|
||||
u, err := url.Parse(dsn)
|
||||
// ExecTx executes a function within a database transaction. If the function returns an error,
|
||||
// the transaction is rolled back. Otherwise, the transaction is committed.
|
||||
func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid database URL: %w", err)
|
||||
return err
|
||||
}
|
||||
q := u.Query()
|
||||
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
defer tx.Rollback()
|
||||
if err := f(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
db, err := sql.Open("pgx", u.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetMaxOpenConns(maxOpenConns)
|
||||
if maxIdleConns > 0 {
|
||||
db.SetMaxIdleConns(maxIdleConns)
|
||||
}
|
||||
if connMaxLifetime > 0 {
|
||||
db.SetConnMaxLifetime(connMaxLifetime)
|
||||
}
|
||||
if connMaxIdleTime > 0 {
|
||||
db.SetConnMaxIdleTime(connMaxIdleTime)
|
||||
}
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("ping failed: %w", err)
|
||||
}
|
||||
return db, nil
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
|
||||
s := q.Get(key)
|
||||
if s == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
q.Del(key)
|
||||
v, err := strconv.Atoi(s)
|
||||
// QueryTx executes a function within a database transaction and returns the result. If the function
|
||||
// returns an error, the transaction is rolled back. Otherwise, the transaction is committed.
|
||||
func QueryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||
var zero T
|
||||
return zero, err
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
|
||||
s := q.Get(key)
|
||||
if s == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
q.Del(key)
|
||||
d, err := time.ParseDuration(s)
|
||||
defer tx.Rollback()
|
||||
t, err := f(tx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||
return t, err
|
||||
}
|
||||
return d, nil
|
||||
if err := tx.Commit(); err != nil {
|
||||
return t, err
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
93
db/pg/pg.go
Normal file
93
db/pg/pg.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package pg
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
|
||||
const (
|
||||
paramMaxOpenConns = "pool_max_conns"
|
||||
paramMaxIdleConns = "pool_max_idle_conns"
|
||||
paramConnMaxLifetime = "pool_conn_max_lifetime"
|
||||
paramConnMaxIdleTime = "pool_conn_max_idle_time"
|
||||
|
||||
defaultMaxOpenConns = 10
|
||||
)
|
||||
|
||||
// Open opens a PostgreSQL database connection pool from a DSN string. It supports custom
|
||||
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
|
||||
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
|
||||
// the DSN before passing it to the driver.
|
||||
func Open(dsn string) (*sql.DB, error) {
|
||||
u, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid database URL: %w", err)
|
||||
}
|
||||
q := u.Query()
|
||||
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
db, err := sql.Open("pgx", u.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetMaxOpenConns(maxOpenConns)
|
||||
if maxIdleConns > 0 {
|
||||
db.SetMaxIdleConns(maxIdleConns)
|
||||
}
|
||||
if connMaxLifetime > 0 {
|
||||
db.SetConnMaxLifetime(connMaxLifetime)
|
||||
}
|
||||
if connMaxIdleTime > 0 {
|
||||
db.SetConnMaxIdleTime(connMaxIdleTime)
|
||||
}
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("ping failed: %w", err)
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
|
||||
s := q.Get(key)
|
||||
if s == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
q.Del(key)
|
||||
v, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
|
||||
s := q.Get(key)
|
||||
if s == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
q.Del(key)
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
@@ -30,7 +30,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
|
||||
q.Set("pool_max_conns", testPoolMaxConns)
|
||||
u.RawQuery = q.Encode()
|
||||
dsn = u.String()
|
||||
setupDB, err := db.OpenPostgres(dsn)
|
||||
setupDB, err := pg.Open(dsn)
|
||||
require.Nil(t, err)
|
||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||
require.Nil(t, err)
|
||||
@@ -39,7 +39,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
|
||||
u.RawQuery = q.Encode()
|
||||
schemaDSN := u.String()
|
||||
t.Cleanup(func() {
|
||||
cleanDB, err := db.OpenPostgres(dsn)
|
||||
cleanDB, err := pg.Open(dsn)
|
||||
if err == nil {
|
||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||
cleanDB.Close()
|
||||
@@ -54,7 +54,7 @@ func CreateTestPostgresSchema(t *testing.T) string {
|
||||
func CreateTestPostgres(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
schemaDSN := CreateTestPostgresSchema(t)
|
||||
testDB, err := db.OpenPostgres(schemaDSN)
|
||||
testDB, err := pg.Open(schemaDSN)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() {
|
||||
testDB.Close()
|
||||
|
||||
102
message/cache.go
102
message/cache.go
@@ -9,6 +9,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
@@ -334,17 +335,14 @@ func (c *Cache) Topics() ([]string, error) {
|
||||
func (c *Cache) DeleteMessages(ids ...string) error {
|
||||
c.maybeLock()
|
||||
defer c.maybeUnlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, id := range ids {
|
||||
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
|
||||
return err
|
||||
return db.ExecTx(c.db, func(tx *sql.Tx) error {
|
||||
for _, id := range ids {
|
||||
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
|
||||
@@ -352,54 +350,43 @@ func (c *Cache) DeleteMessages(ids ...string) error {
|
||||
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
|
||||
c.maybeLock()
|
||||
defer c.maybeUnlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// First, get the message IDs of scheduled messages to be deleted
|
||||
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return db.QueryTx(c.db, func(tx *sql.Tx) ([]string, error) {
|
||||
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows.Close() // Close rows before executing delete in same transaction
|
||||
// Then delete the messages
|
||||
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows.Close() // Close rows before executing delete in same transaction
|
||||
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ExpireMessages marks messages in the given topics as expired
|
||||
func (c *Cache) ExpireMessages(topics ...string) error {
|
||||
c.maybeLock()
|
||||
defer c.maybeUnlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, t := range topics {
|
||||
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
|
||||
return err
|
||||
return db.ExecTx(c.db, func(tx *sql.Tx) error {
|
||||
for _, t := range topics {
|
||||
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
|
||||
@@ -427,17 +414,14 @@ func (c *Cache) AttachmentsExpired() ([]string, error) {
|
||||
func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
|
||||
c.maybeLock()
|
||||
defer c.maybeUnlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, id := range ids {
|
||||
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
|
||||
return err
|
||||
return db.ExecTx(c.db, func(tx *sql.Tx) error {
|
||||
for _, id := range ids {
|
||||
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
|
||||
|
||||
@@ -3,6 +3,8 @@ package message
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
)
|
||||
|
||||
// Initial PostgreSQL schema
|
||||
@@ -55,34 +57,29 @@ const (
|
||||
|
||||
// PostgreSQL schema management queries
|
||||
const (
|
||||
pgCurrentSchemaVersion = 14
|
||||
postgresCurrentSchemaVersion = 14
|
||||
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
|
||||
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
|
||||
)
|
||||
|
||||
func setupPostgres(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
if err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
|
||||
return setupNewPostgresDB(db)
|
||||
}
|
||||
if schemaVersion > pgCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
|
||||
} else if schemaVersion > postgresCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPostgresDB(db *sql.DB) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
func setupNewPostgresDB(sqlDB *sql.DB) error {
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, postgresCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
)
|
||||
|
||||
@@ -382,85 +383,70 @@ func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
|
||||
func sqliteMigrateFrom9(sqlDB *sql.DB, cacheDuration time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error {
|
||||
func sqliteMigrateFrom10(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error {
|
||||
func sqliteMigrateFrom11(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error {
|
||||
func sqliteMigrateFrom12(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error {
|
||||
func sqliteMigrateFrom13(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gopkg.in/yaml.v2"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
@@ -178,11 +178,11 @@ func New(conf *Config) (*Server, error) {
|
||||
if payments.Available && conf.StripeSecretKey != "" {
|
||||
stripe = newStripeAPI()
|
||||
}
|
||||
// OpenPostgres shared PostgreSQL connection pool if configured
|
||||
// Open shared PostgreSQL connection pool if configured
|
||||
var pool *sql.DB
|
||||
if conf.DatabaseURL != "" {
|
||||
var err error
|
||||
pool, err = db.OpenPostgres(conf.DatabaseURL)
|
||||
pool, err = pg.Open(conf.DatabaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
# pgimport
|
||||
|
||||
Migrates ntfy data from SQLite to PostgreSQL.
|
||||
One-off migration script to import ntfy data from SQLite to PostgreSQL.
|
||||
|
||||
This is **not** a generic migration tool. It only works with specific SQLite schema versions
|
||||
(message cache v14, user db v6, web push v1) and their corresponding PostgreSQL schemas.
|
||||
If your database versions differ, this tool will refuse to run.
|
||||
|
||||
## Build
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// pgimport is a one-off migration script to import ntfy data from SQLite to PostgreSQL.
|
||||
// It is not a generic migration tool. It expects specific schema versions for each database
|
||||
// (message cache v14, user db v6, web push v1) and will refuse to run if versions don't match.
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -11,7 +14,7 @@ import (
|
||||
"github.com/urfave/cli/v2"
|
||||
"github.com/urfave/cli/v2/altsrc"
|
||||
"gopkg.in/yaml.v2"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,7 +36,7 @@ var flags = []cli.Flag{
|
||||
func main() {
|
||||
app := &cli.App{
|
||||
Name: "pgimport",
|
||||
Usage: "SQLite to PostgreSQL migration tool for ntfy",
|
||||
Usage: "One-off SQLite to PostgreSQL migration script for ntfy",
|
||||
UsageText: "pgimport [OPTIONS]",
|
||||
Flags: flags,
|
||||
Before: loadConfigFile("config", flags),
|
||||
@@ -79,7 +82,7 @@ func execImport(c *cli.Context) error {
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
pgDB, err := db.OpenPostgres(databaseURL)
|
||||
pgDB, err := pg.Open(databaseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot connect to PostgreSQL: %w", err)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/payments"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
@@ -122,7 +123,7 @@ func (a *Manager) AddUser(username, password string, role Role, hashed bool) err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.addUserTx(tx, username, hash, role, false)
|
||||
})
|
||||
}
|
||||
@@ -150,7 +151,7 @@ func (a *Manager) RemoveUser(username string) error {
|
||||
if err := a.CanChangeUser(username); err != nil {
|
||||
return err
|
||||
}
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.removeUserTx(tx, username)
|
||||
})
|
||||
}
|
||||
@@ -173,7 +174,7 @@ func (a *Manager) MarkUserRemoved(user *User) error {
|
||||
if !AllowedUsername(user.Name) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
if err := a.resetUserAccessTx(tx, user.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -205,7 +206,7 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.changePasswordHashTx(tx, username, hash)
|
||||
})
|
||||
}
|
||||
@@ -224,7 +225,7 @@ func (a *Manager) ChangeRole(username string, role Role) error {
|
||||
if err := a.CanChangeUser(username); err != nil {
|
||||
return err
|
||||
}
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.changeRoleTx(tx, username, role)
|
||||
})
|
||||
}
|
||||
@@ -365,7 +366,7 @@ func (a *Manager) writeUserStatsQueue() error {
|
||||
a.statsQueue = make(map[string]*Stats)
|
||||
a.mu.Unlock()
|
||||
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
|
||||
for userID, update := range statsQueue {
|
||||
log.
|
||||
@@ -573,7 +574,7 @@ func (a *Manager) resolvePerms(base, perm Permission) error {
|
||||
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
|
||||
// owner may either be a user (username), or the system (empty).
|
||||
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.allowAccessTx(tx, username, topicPattern, permission, false)
|
||||
})
|
||||
}
|
||||
@@ -591,7 +592,7 @@ func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string
|
||||
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
|
||||
// empty) for an entire user. The parameter topicPattern may include wildcards (*).
|
||||
func (a *Manager) ResetAccess(username string, topicPattern string) error {
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.resetAccessTx(tx, username, topicPattern)
|
||||
})
|
||||
}
|
||||
@@ -715,7 +716,7 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss
|
||||
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -735,7 +736,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
}
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
for _, topic := range topics {
|
||||
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
|
||||
return err
|
||||
@@ -874,7 +875,7 @@ func (a *Manager) resetTopicAccessTx(tx *sql.Tx, username, topicPattern string)
|
||||
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
|
||||
// given user, if there are too many of them.
|
||||
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
|
||||
return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
|
||||
return db.QueryTx(a.db, func(tx *sql.Tx) (*Token, error) {
|
||||
return a.createTokenTx(tx, userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned)
|
||||
})
|
||||
}
|
||||
@@ -1033,7 +1034,7 @@ func (a *Manager) writeTokenUpdateQueue() error {
|
||||
a.tokenQueue = make(map[string]*TokenUpdate)
|
||||
a.mu.Unlock()
|
||||
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
|
||||
for tokenID, update := range tokenQueue {
|
||||
log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
|
||||
@@ -1254,7 +1255,7 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil {
|
||||
return fmt.Errorf("failed to provision users: %v", err)
|
||||
}
|
||||
|
||||
@@ -328,8 +328,7 @@ var (
|
||||
|
||||
func setupSQLite(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
if err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
|
||||
return setupNewSQLite(db)
|
||||
}
|
||||
if schemaVersion == sqliteCurrentSchemaVersion {
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
dbtest "heckel.io/ntfy/v2/db/test"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
@@ -36,7 +36,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc
|
||||
t.Run("postgres", func(t *testing.T) {
|
||||
schemaDSN := dbtest.CreateTestPostgresSchema(t)
|
||||
f(t, func(config *Config) *Manager {
|
||||
pool, err := db.OpenPostgres(schemaDSN)
|
||||
pool, err := pg.Open(schemaDSN)
|
||||
require.Nil(t, err)
|
||||
a, err := NewPostgresManager(pool, config)
|
||||
require.Nil(t, err)
|
||||
|
||||
32
user/util.go
32
user/util.go
@@ -113,35 +113,3 @@ func escapeUnderscore(s string) string {
|
||||
func unescapeUnderscore(s string) string {
|
||||
return strings.ReplaceAll(s, "\\_", "_")
|
||||
}
|
||||
|
||||
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
|
||||
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if err := f(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// queryTx executes a function in a transaction and returns the result. If the function
|
||||
// returns an error, the transaction is rolled back.
|
||||
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
var zero T
|
||||
return zero, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
t, err := f(tx)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return t, err
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
@@ -46,41 +47,38 @@ type queries struct {
|
||||
|
||||
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID.
|
||||
func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Read number of subscriptions for subscriber IP address
|
||||
var subscriptionCount int
|
||||
if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil {
|
||||
return err
|
||||
}
|
||||
// Read existing subscription ID for endpoint (or create new ID)
|
||||
var subscriptionID string
|
||||
if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) {
|
||||
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
|
||||
return ErrWebPushTooManySubscriptions
|
||||
}
|
||||
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert or update subscription
|
||||
updatedAt, warnedAt := time.Now().Unix(), 0
|
||||
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
// Replace all subscription topics
|
||||
if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, topic := range topics {
|
||||
if _, err = tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
|
||||
return db.ExecTx(s.db, func(tx *sql.Tx) error {
|
||||
// Read number of subscriptions for subscriber IP address
|
||||
var subscriptionCount int
|
||||
if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
// Read existing subscription ID for endpoint (or create new ID)
|
||||
var subscriptionID string
|
||||
if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) {
|
||||
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
|
||||
return ErrWebPushTooManySubscriptions
|
||||
}
|
||||
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert or update subscription
|
||||
updatedAt, warnedAt := time.Now().Unix(), 0
|
||||
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
// Replace all subscription topics
|
||||
if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, topic := range topics {
|
||||
if _, err := tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// SubscriptionsForTopic returns all subscriptions for the given topic.
|
||||
@@ -105,17 +103,14 @@ func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription,
|
||||
|
||||
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon.
|
||||
func (s *Store) MarkExpiryWarningSent(subscriptions []*Subscription) error {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, subscription := range subscriptions {
|
||||
if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
|
||||
return err
|
||||
return db.ExecTx(s.db, func(tx *sql.Tx) error {
|
||||
for _, subscription := range subscriptions {
|
||||
if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint.
|
||||
@@ -135,12 +130,13 @@ func (s *Store) RemoveSubscriptionsByUserID(userID string) error {
|
||||
|
||||
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period.
|
||||
func (s *Store) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
|
||||
_, err := s.db.Exec(s.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix())
|
||||
if err != nil {
|
||||
return db.ExecTx(s.db, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(s.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix()); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := tx.Exec(s.queries.deleteSubscriptionTopicWithoutSubscription)
|
||||
return err
|
||||
}
|
||||
_, err = s.db.Exec(s.queries.deleteSubscriptionTopicWithoutSubscription)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// SetSubscriptionUpdatedAt updates the updated_at timestamp for a subscription by endpoint. This is
|
||||
|
||||
@@ -3,6 +3,8 @@ package webpush
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -107,17 +109,14 @@ func setupPostgres(db *sql.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPostgres(db *sql.DB) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
func setupNewPostgres(sqlDB *sql.DB) error {
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -109,29 +111,24 @@ func NewSQLiteStore(filename, startupQueries string) (*Store, error) {
|
||||
|
||||
func setupSQLite(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
if err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
|
||||
return setupNewSQLite(db)
|
||||
}
|
||||
if schemaVersion > sqliteCurrentSchemaVersion {
|
||||
} else if schemaVersion > sqliteCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSQLite(db *sql.DB) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
func setupNewSQLite(sqlDB *sql.DB) error {
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
|
||||
Reference in New Issue
Block a user