diff --git a/db/db.go b/db/db.go index 66412ede..6586e763 100644 --- a/db/db.go +++ b/db/db.go @@ -16,12 +16,6 @@ const ( replicaHealthCheckTimeout = 10 * time.Second ) -// Beginner is an interface for types that can begin a database transaction. -// Both *sql.DB and *DB implement this. -type Beginner interface { - Begin() (*sql.Tx, error) -} - // DB wraps a primary *sql.DB and optional read replicas. All standard query/exec methods // delegate to the primary. The ReadOnly() method returns a *sql.DB from a healthy replica // (round-robin), falling back to the primary if no replicas are configured or all are unhealthy. @@ -32,13 +26,6 @@ type DB struct { cancel context.CancelFunc } -// Host pairs a *sql.DB with the host:port it was opened against. -type Host struct { - Addr string // "host:port" - DB *sql.DB - healthy atomic.Bool -} - // New creates a new DB that wraps the given primary and optional replica connections. // If replicas is nil or empty, ReadOnly() simply returns the primary. // Replicas start unhealthy and are checked immediately by a background goroutine. @@ -55,12 +42,6 @@ func New(primary *Host, replicas []*Host) *DB { return d } -// Primary returns the underlying primary *sql.DB. This is only intended for -// one-time schema setup during store initialization, not for regular queries. -func (d *DB) Primary() *sql.DB { - return d.primary.DB -} - // Query delegates to the primary database. func (d *DB) Query(query string, args ...any) (*sql.Rows, error) { return d.primary.DB.Query(query, args...) @@ -86,13 +67,10 @@ func (d *DB) Ping() error { return d.primary.DB.Ping() } -// Close closes the primary database and all replicas, and stops the health-check goroutine. -func (d *DB) Close() error { - d.cancel() - for _, r := range d.replicas { - r.DB.Close() - } - return d.primary.DB.Close() +// Primary returns the underlying primary *sql.DB. This is only intended for +// one-time schema setup during store initialization, not for regular queries. +func (d *DB) Primary() *sql.DB { + return d.primary.DB } // ReadOnly returns a *sql.DB suitable for read-only queries. It round-robins across healthy @@ -112,6 +90,15 @@ func (d *DB) ReadOnly() *sql.DB { return d.primary.DB } +// Close closes the primary database and all replicas, and stops the health-check goroutine. +func (d *DB) Close() error { + d.cancel() + for _, r := range d.replicas { + r.DB.Close() + } + return d.primary.DB.Close() +} + // healthCheckLoop checks replicas immediately, then periodically on a ticker. func (d *DB) healthCheckLoop(ctx context.Context) { select { diff --git a/db/types.go b/db/types.go new file mode 100644 index 00000000..534d6168 --- /dev/null +++ b/db/types.go @@ -0,0 +1,19 @@ +package db + +import ( + "database/sql" + "sync/atomic" +) + +// Beginner is an interface for types that can begin a database transaction. +// Both *sql.DB and *DB implement this. +type Beginner interface { + Begin() (*sql.Tx, error) +} + +// Host pairs a *sql.DB with the host:port it was opened against. +type Host struct { + Addr string // "host:port" + DB *sql.DB + healthy atomic.Bool +} diff --git a/message/cache_sqlite.go b/message/cache_sqlite.go index 5593ec63..a36aba0e 100644 --- a/message/cache_sqlite.go +++ b/message/cache_sqlite.go @@ -111,14 +111,14 @@ func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration if !util.FileExists(parentDir) { return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir) } - sqlDB, err := sql.Open("sqlite3", filename) + d, err := sql.Open("sqlite3", filename) if err != nil { return nil, err } - if err := setupSQLite(sqlDB, startupQueries, cacheDuration); err != nil { + if err := setupSQLite(d, startupQueries, cacheDuration); err != nil { return nil, err } - return newCache(db.New(&db.Host{DB: sqlDB}, nil), sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil + return newCache(db.New(&db.Host{DB: d}, nil), sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil } // NewMemStore creates an in-memory cache diff --git a/user/manager_sqlite.go b/user/manager_sqlite.go index a28a846a..e92c6349 100644 --- a/user/manager_sqlite.go +++ b/user/manager_sqlite.go @@ -281,15 +281,15 @@ func NewSQLiteManager(filename, startupQueries string, config *Config) (*Manager if !util.FileExists(parentDir) { return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir) } - sqlDB, err := sql.Open("sqlite3", filename) + d, err := sql.Open("sqlite3", filename) if err != nil { return nil, err } - if err := setupSQLite(sqlDB); err != nil { + if err := setupSQLite(d); err != nil { return nil, err } - if err := runSQLiteStartupQueries(sqlDB, startupQueries); err != nil { + if err := runSQLiteStartupQueries(d, startupQueries); err != nil { return nil, err } - return newManager(db.New(&db.Host{DB: sqlDB}, nil), sqliteQueries, config) + return newManager(db.New(&db.Host{DB: d}, nil), sqliteQueries, config) } diff --git a/webpush/store_postgres.go b/webpush/store_postgres.go index a8404af3..1c9adf0a 100644 --- a/webpush/store_postgres.go +++ b/webpush/store_postgres.go @@ -4,7 +4,7 @@ import ( "database/sql" "fmt" - ntfydb "heckel.io/ntfy/v2/db" + "heckel.io/ntfy/v2/db" ) const ( @@ -73,7 +73,7 @@ const ( ) // NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool. -func NewPostgresStore(d *ntfydb.DB) (*Store, error) { +func NewPostgresStore(d *db.DB) (*Store, error) { if err := setupPostgres(d.Primary()); err != nil { return nil, err } @@ -97,11 +97,11 @@ func NewPostgresStore(d *ntfydb.DB) (*Store, error) { }, nil } -func setupPostgres(db *sql.DB) error { +func setupPostgres(d *sql.DB) error { var schemaVersion int - err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion) + err := d.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion) if err != nil { - return setupNewPostgres(db) + return setupNewPostgres(d) } if schemaVersion > pgCurrentSchemaVersion { return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion) @@ -109,8 +109,8 @@ func setupPostgres(db *sql.DB) error { return nil } -func setupNewPostgres(sqlDB *sql.DB) error { - return ntfydb.ExecTx(sqlDB, func(tx *sql.Tx) error { +func setupNewPostgres(d *sql.DB) error { + return db.ExecTx(d, func(tx *sql.Tx) error { if _, err := tx.Exec(postgresCreateTablesQuery); err != nil { return err } diff --git a/webpush/store_sqlite.go b/webpush/store_sqlite.go index 81c44704..fcf49fcf 100644 --- a/webpush/store_sqlite.go +++ b/webpush/store_sqlite.go @@ -79,18 +79,18 @@ const ( // NewSQLiteStore creates a new SQLite-backed web push store. func NewSQLiteStore(filename, startupQueries string) (*Store, error) { - sqlDB, err := sql.Open("sqlite3", filename) + d, err := sql.Open("sqlite3", filename) if err != nil { return nil, err } - if err := setupSQLite(sqlDB); err != nil { + if err := setupSQLite(d); err != nil { return nil, err } - if err := runSQLiteStartupQueries(sqlDB, startupQueries); err != nil { + if err := runSQLiteStartupQueries(d, startupQueries); err != nil { return nil, err } return &Store{ - db: db.New(&db.Host{DB: sqlDB}, nil), + db: db.New(&db.Host{DB: d}, nil), queries: queries{ selectSubscriptionIDByEndpoint: sqliteSelectSubscriptionIDByEndpointQuery, selectSubscriptionCountBySubscriberIP: sqliteSelectSubscriptionCountBySubscriberIPQuery,