From f1865749d72dcc02b1038dc45d01a9e594307e6a Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Tue, 10 Mar 2026 22:17:40 -0400 Subject: [PATCH] WIP: Postgres read-only replica --- cmd/serve.go | 7 ++- cmd/user.go | 3 +- db/db.go | 122 +++++++++++++++++++++++++++++++++++++- db/test/test.go | 11 ++-- message/cache.go | 28 +++++---- message/cache_postgres.go | 9 +-- message/cache_sqlite.go | 7 ++- server/config.go | 3 +- server/server.go | 24 ++++++-- user/manager.go | 48 +++++++-------- user/manager_postgres.go | 8 +-- user/manager_sqlite.go | 9 +-- user/manager_test.go | 9 +-- webpush/store.go | 6 +- webpush/store_postgres.go | 10 ++-- webpush/store_sqlite.go | 8 +-- 16 files changed, 229 insertions(+), 83 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index 313ec835..1ac4f86e 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -40,6 +40,7 @@ var flagsServe = append( altsrc.NewStringFlag(&cli.StringFlag{Name: "cert-file", Aliases: []string{"cert_file", "E"}, EnvVars: []string{"NTFY_CERT_FILE"}, Usage: "certificate file, if listen-https is set"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"firebase_key_file", "F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores (e.g. postgres://user:pass@host:5432/ntfy)"}), + altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "database-replica-urls", Aliases: []string{"database_replica_urls"}, EnvVars: []string{"NTFY_DATABASE_REPLICA_URLS"}, Usage: "PostgreSQL read replica connection strings for offloading read queries"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file", "C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-duration", Aliases: []string{"cache_duration", "b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: util.FormatDuration(server.DefaultCacheDuration), Usage: "buffer messages for this time to allow `since` requests"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "cache-batch-size", Aliases: []string{"cache_batch_size"}, EnvVars: []string{"NTFY_BATCH_SIZE"}, Usage: "max size of messages to batch together when writing to message cache (if zero, writes are synchronous)"}), @@ -145,6 +146,7 @@ func execServe(c *cli.Context) error { certFile := c.String("cert-file") firebaseKeyFile := c.String("firebase-key-file") databaseURL := c.String("database-url") + databaseReplicaURLs := c.StringSlice("database-replica-urls") webPushPrivateKey := c.String("web-push-private-key") webPushPublicKey := c.String("web-push-public-key") webPushFile := c.String("web-push-file") @@ -282,7 +284,9 @@ func execServe(c *cli.Context) error { } // Check values - if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") { + if len(databaseReplicaURLs) > 0 && databaseURL == "" { + return errors.New("database-replica-urls can only be used if database-url is also set") + } else if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") { return errors.New("if database-url is set, auth-file, cache-file, and web-push-file must not be set") } else if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) { return errors.New("if set, FCM key file must exist") @@ -502,6 +506,7 @@ func execServe(c *cli.Context) error { conf.MetricsListenHTTP = metricsListenHTTP conf.ProfileListenHTTP = profileListenHTTP conf.DatabaseURL = databaseURL + conf.DatabaseReplicaURLs = databaseReplicaURLs conf.WebPushPrivateKey = webPushPrivateKey conf.WebPushPublicKey = webPushPublicKey conf.WebPushFile = webPushFile diff --git a/cmd/user.go b/cmd/user.go index 1ffc3e6b..f108650b 100644 --- a/cmd/user.go +++ b/cmd/user.go @@ -11,6 +11,7 @@ import ( "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" + "heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/db/pg" "heckel.io/ntfy/v2/server" "heckel.io/ntfy/v2/user" @@ -383,7 +384,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) { if dbErr != nil { return nil, dbErr } - return user.NewPostgresManager(pool, authConfig) + return user.NewPostgresManager(db.NewDB(pool, nil), authConfig) } else if authFile != "" { if !util.FileExists(authFile) { return nil, errors.New("auth-file does not exist; please start the server at least once to create it") diff --git a/db/db.go b/db/db.go index 00ae91f4..f2a539e5 100644 --- a/db/db.go +++ b/db/db.go @@ -2,11 +2,129 @@ package db import ( "database/sql" + "sync/atomic" + "time" ) +const ( + replicaHealthCheckInterval = 5 * time.Second +) + +// Beginner is an interface for types that can begin a database transaction. +// Both *sql.DB and *DB implement this. +type Beginner interface { + Begin() (*sql.Tx, error) +} + +// DB wraps a primary *sql.DB and optional read replicas. All standard query/exec methods +// delegate to the primary. The ReadOnly() method returns a *sql.DB from a healthy replica +// (round-robin), falling back to the primary if no replicas are configured or all are unhealthy. +type DB struct { + primary *sql.DB + replicas []*replica + counter atomic.Uint64 +} + +type replica struct { + db *sql.DB + healthy atomic.Bool + lastChecked atomic.Int64 +} + +// NewDB creates a new DB that wraps the given primary and optional replica connections. +// If replicas is nil or empty, ReadOnly() simply returns the primary. +func NewDB(primary *sql.DB, replicas []*sql.DB) *DB { + d := &DB{ + primary: primary, + replicas: make([]*replica, len(replicas)), + } + for i, r := range replicas { + rep := &replica{db: r} + rep.healthy.Store(true) + d.replicas[i] = rep + } + return d +} + +// Query delegates to the primary database. +func (d *DB) Query(query string, args ...any) (*sql.Rows, error) { + return d.primary.Query(query, args...) +} + +// QueryRow delegates to the primary database. +func (d *DB) QueryRow(query string, args ...any) *sql.Row { + return d.primary.QueryRow(query, args...) +} + +// Exec delegates to the primary database. +func (d *DB) Exec(query string, args ...any) (sql.Result, error) { + return d.primary.Exec(query, args...) +} + +// Begin delegates to the primary database. +func (d *DB) Begin() (*sql.Tx, error) { + return d.primary.Begin() +} + +// Ping delegates to the primary database. +func (d *DB) Ping() error { + return d.primary.Ping() +} + +// Close closes the primary database and all replicas. +func (d *DB) Close() error { + for _, r := range d.replicas { + r.db.Close() + } + return d.primary.Close() +} + +// SetupPrimary returns the underlying primary *sql.DB. This is only intended for +// one-time schema setup during store initialization, not for regular queries. +func (d *DB) SetupPrimary() *sql.DB { + return d.primary +} + +// ReadOnly returns a *sql.DB suitable for read-only queries. It round-robins across healthy +// replicas. If a replica's health status is stale (older than replicaHealthCheckInterval), it +// is re-checked with a ping. If all replicas are unhealthy or none are configured, the primary +// is returned. +func (d *DB) ReadOnly() *sql.DB { + if len(d.replicas) == 0 { + return d.primary + } + n := len(d.replicas) + start := int(d.counter.Add(1) - 1) + for i := 0; i < n; i++ { + r := d.replicas[(start+i)%n] + if d.isHealthy(r) { + return r.db + } + } + return d.primary +} + +// isHealthy returns whether the replica is healthy. If the cached health status is stale, +// it pings the replica and updates the cache. +func (d *DB) isHealthy(r *replica) bool { + now := time.Now().Unix() + lastChecked := r.lastChecked.Load() + if now-lastChecked >= int64(replicaHealthCheckInterval.Seconds()) { + if r.lastChecked.CompareAndSwap(lastChecked, now) { + if err := r.db.Ping(); err != nil { + r.healthy.Store(false) + return false + } + r.healthy.Store(true) + return true + } + } + return r.healthy.Load() +} + // ExecTx executes a function within a database transaction. If the function returns an error, // the transaction is rolled back. Otherwise, the transaction is committed. -func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error { +func ExecTx(db Beginner, f func(tx *sql.Tx) error) error { tx, err := db.Begin() if err != nil { return err @@ -20,7 +138,7 @@ func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error { // QueryTx executes a function within a database transaction and returns the result. If the function // returns an error, the transaction is rolled back. Otherwise, the transaction is committed. -func QueryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) { +func QueryTx[T any](db Beginner, f func(tx *sql.Tx) (T, error)) (T, error) { tx, err := db.Begin() if err != nil { var zero T diff --git a/db/test/test.go b/db/test/test.go index 36c3fc86..0bfa8e11 100644 --- a/db/test/test.go +++ b/db/test/test.go @@ -1,13 +1,13 @@ package dbtest import ( - "database/sql" "fmt" "net/url" "os" "testing" "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/db/pg" "heckel.io/ntfy/v2/util" ) @@ -48,16 +48,17 @@ func CreateTestPostgresSchema(t *testing.T) string { return schemaDSN } -// CreateTestPostgres creates a temporary PostgreSQL schema and returns an open *sql.DB connection to it. +// CreateTestPostgres creates a temporary PostgreSQL schema and returns an open *db.DB connection to it. // It registers cleanup functions to close the DB and drop the schema when the test finishes. // If NTFY_TEST_DATABASE_URL is not set, the test is skipped. -func CreateTestPostgres(t *testing.T) *sql.DB { +func CreateTestPostgres(t *testing.T) *db.DB { t.Helper() schemaDSN := CreateTestPostgresSchema(t) testDB, err := pg.Open(schemaDSN) require.Nil(t, err) + d := db.NewDB(testDB, nil) t.Cleanup(func() { - testDB.Close() + d.Close() }) - return testDB + return d } diff --git a/message/cache.go b/message/cache.go index 3b12af3e..b123fba4 100644 --- a/message/cache.go +++ b/message/cache.go @@ -50,14 +50,14 @@ type queries struct { // Cache stores published messages type Cache struct { - db *sql.DB + db *db.DB queue *util.BatchingQueue[*model.Message] nop bool mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer) queries queries } -func newCache(db *sql.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache { +func newCache(db *db.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache { var queue *util.BatchingQueue[*model.Message] if batchSize > 0 || batchTimeout > 0 { queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout) @@ -201,10 +201,11 @@ func (c *Cache) Messages(topic string, since model.SinceMarker, scheduled bool) func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { var rows *sql.Rows var err error + rdb := c.db.ReadOnly() if scheduled { - rows, err = c.db.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix()) + rows, err = rdb.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix()) } else { - rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix()) + rows, err = rdb.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix()) } if err != nil { return nil, err @@ -215,10 +216,11 @@ func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, schedul func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { var rows *sql.Rows var err error + rdb := c.db.ReadOnly() if scheduled { - rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID()) + rows, err = rdb.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID()) } else { - rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, since.ID()) + rows, err = rdb.Query(c.queries.selectMessagesSinceID, topic, since.ID()) } if err != nil { return nil, err @@ -227,7 +229,7 @@ func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled } func (c *Cache) messagesLatest(topic string) ([]*model.Message, error) { - rows, err := c.db.Query(c.queries.selectMessagesLatest, topic) + rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesLatest, topic) if err != nil { return nil, err } @@ -266,7 +268,7 @@ func (c *Cache) MessagesExpired() ([]string, error) { // Message returns the message with the given ID, or ErrMessageNotFound if not found func (c *Cache) Message(id string) (*model.Message, error) { - rows, err := c.db.Query(c.queries.selectMessagesByID, id) + rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesByID, id) if err != nil { return nil, err } @@ -295,7 +297,7 @@ func (c *Cache) MarkPublished(m *model.Message) error { // MessagesCount returns the total number of messages in the cache func (c *Cache) MessagesCount() (int, error) { - rows, err := c.db.Query(c.queries.selectMessagesCount) + rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesCount) if err != nil { return 0, err } @@ -312,7 +314,7 @@ func (c *Cache) MessagesCount() (int, error) { // Topics returns a list of all topics with messages in the cache func (c *Cache) Topics() ([]string, error) { - rows, err := c.db.Query(c.queries.selectTopics) + rows, err := c.db.ReadOnly().Query(c.queries.selectTopics) if err != nil { return nil, err } @@ -426,7 +428,7 @@ func (c *Cache) MarkAttachmentsDeleted(ids ...string) error { // AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) { - rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix()) + rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix()) if err != nil { return 0, err } @@ -435,7 +437,7 @@ func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) { // AttachmentBytesUsedByUser returns the total size of active attachments for the given user func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) { - rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix()) + rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix()) if err != nil { return 0, err } @@ -466,7 +468,7 @@ func (c *Cache) UpdateStats(messages int64) error { // Stats returns the total message count statistic func (c *Cache) Stats() (messages int64, err error) { - rows, err := c.db.Query(c.queries.selectStats) + rows, err := c.db.ReadOnly().Query(c.queries.selectStats) if err != nil { return 0, err } diff --git a/message/cache_postgres.go b/message/cache_postgres.go index 0146f409..5b7a0293 100644 --- a/message/cache_postgres.go +++ b/message/cache_postgres.go @@ -1,8 +1,9 @@ package message import ( - "database/sql" "time" + + "heckel.io/ntfy/v2/db" ) // PostgreSQL runtime query constants @@ -102,9 +103,9 @@ var postgresQueries = queries{ } // NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool. -func NewPostgresStore(db *sql.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) { - if err := setupPostgres(db); err != nil { +func NewPostgresStore(d *db.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) { + if err := setupPostgres(d.SetupPrimary()); err != nil { return nil, err } - return newCache(db, postgresQueries, nil, batchSize, batchTimeout, false), nil + return newCache(d, postgresQueries, nil, batchSize, batchTimeout, false), nil } diff --git a/message/cache_sqlite.go b/message/cache_sqlite.go index f9d8605e..375902d6 100644 --- a/message/cache_sqlite.go +++ b/message/cache_sqlite.go @@ -8,6 +8,7 @@ import ( "time" _ "github.com/mattn/go-sqlite3" // SQLite driver + "heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/util" ) @@ -110,14 +111,14 @@ func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration if !util.FileExists(parentDir) { return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir) } - db, err := sql.Open("sqlite3", filename) + sqlDB, err := sql.Open("sqlite3", filename) if err != nil { return nil, err } - if err := setupSQLite(db, startupQueries, cacheDuration); err != nil { + if err := setupSQLite(sqlDB, startupQueries, cacheDuration); err != nil { return nil, err } - return newCache(db, sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil + return newCache(db.NewDB(sqlDB, nil), sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil } // NewMemStore creates an in-memory cache diff --git a/server/config.go b/server/config.go index 786f0d78..8ead312c 100644 --- a/server/config.go +++ b/server/config.go @@ -95,7 +95,8 @@ type Config struct { ListenUnixMode fs.FileMode KeyFile string CertFile string - DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy") + DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy") + DatabaseReplicaURLs []string // PostgreSQL read replica connection strings FirebaseKeyFile string CacheFile string CacheDuration time.Duration diff --git a/server/server.go b/server/server.go index 329b0ab5..62fba3c2 100644 --- a/server/server.go +++ b/server/server.go @@ -33,6 +33,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/sync/errgroup" "gopkg.in/yaml.v2" + "heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/db/pg" "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/message" @@ -47,7 +48,7 @@ import ( // Server is the main server, providing the UI and API for ntfy type Server struct { config *Config - db *sql.DB // Shared PostgreSQL connection pool, nil when using SQLite + db *db.DB // Shared PostgreSQL connection pool (with optional replicas), nil when using SQLite httpServer *http.Server httpsServer *http.Server httpMetricsServer *http.Server @@ -179,13 +180,26 @@ func New(conf *Config) (*Server, error) { stripe = newStripeAPI() } // Open shared PostgreSQL connection pool if configured - var pool *sql.DB + var pool *db.DB if conf.DatabaseURL != "" { - var err error - pool, err = pg.Open(conf.DatabaseURL) + primary, err := pg.Open(conf.DatabaseURL) if err != nil { return nil, err } + var replicas []*sql.DB + for _, replicaURL := range conf.DatabaseReplicaURLs { + r, err := pg.Open(replicaURL) + if err != nil { + // Close already-opened replicas before returning + for _, opened := range replicas { + opened.Close() + } + primary.Close() + return nil, fmt.Errorf("failed to open database replica: %w", err) + } + replicas = append(replicas, r) + } + pool = db.NewDB(primary, replicas) } messageCache, err := createMessageCache(conf, pool) if err != nil { @@ -277,7 +291,7 @@ func New(conf *Config) (*Server, error) { return s, nil } -func createMessageCache(conf *Config, pool *sql.DB) (*message.Cache, error) { +func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) { if conf.CacheDuration == 0 { return message.NewNopStore() } else if pool != nil { diff --git a/user/manager.go b/user/manager.go index 0ee6a6e1..31106983 100644 --- a/user/manager.go +++ b/user/manager.go @@ -49,7 +49,7 @@ var ( // Manager handles user authentication, authorization, and management type Manager struct { config *Config - db *sql.DB + db *db.DB queries queries statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats) tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate) @@ -58,7 +58,7 @@ type Manager struct { var _ Auther = (*Manager)(nil) -func newManager(db *sql.DB, queries queries, config *Config) (*Manager, error) { +func newManager(d *db.DB, queries queries, config *Config) (*Manager, error) { if config.BcryptCost <= 0 { config.BcryptCost = DefaultUserPasswordBcryptCost } @@ -67,7 +67,7 @@ func newManager(db *sql.DB, queries queries, config *Config) (*Manager, error) { } manager := &Manager{ config: config, - db: db, + db: d, statsQueue: make(map[string]*Stats), tokenQueue: make(map[string]*TokenUpdate), queries: queries, @@ -388,7 +388,7 @@ func (a *Manager) writeUserStatsQueue() error { // User returns the user with the given username if it exists, or ErrUserNotFound otherwise func (a *Manager) User(username string) (*User, error) { - rows, err := a.db.Query(a.queries.selectUserByName, username) + rows, err := a.db.ReadOnly().Query(a.queries.selectUserByName, username) if err != nil { return nil, err } @@ -397,7 +397,7 @@ func (a *Manager) User(username string) (*User, error) { // UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise func (a *Manager) UserByID(id string) (*User, error) { - rows, err := a.db.Query(a.queries.selectUserByID, id) + rows, err := a.db.ReadOnly().Query(a.queries.selectUserByID, id) if err != nil { return nil, err } @@ -406,7 +406,7 @@ func (a *Manager) UserByID(id string) (*User, error) { // userByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise func (a *Manager) userByToken(token string) (*User, error) { - rows, err := a.db.Query(a.queries.selectUserByToken, token, time.Now().Unix()) + rows, err := a.db.ReadOnly().Query(a.queries.selectUserByToken, token, time.Now().Unix()) if err != nil { return nil, err } @@ -415,7 +415,7 @@ func (a *Manager) userByToken(token string) (*User, error) { // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) { - rows, err := a.db.Query(a.queries.selectUserByStripeCustomerID, customerID) + rows, err := a.db.ReadOnly().Query(a.queries.selectUserByStripeCustomerID, customerID) if err != nil { return nil, err } @@ -425,7 +425,7 @@ func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) { // Users returns a list of users. It loads all users in a single query // rather than one query per user to avoid N+1 performance issues. func (a *Manager) Users() ([]*User, error) { - rows, err := a.db.Query(a.queries.selectUsers) + rows, err := a.db.ReadOnly().Query(a.queries.selectUsers) if err != nil { return nil, err } @@ -434,7 +434,7 @@ func (a *Manager) Users() ([]*User, error) { // UsersCount returns the number of users in the database func (a *Manager) UsersCount() (int64, error) { - rows, err := a.db.Query(a.queries.selectUserCount) + rows, err := a.db.ReadOnly().Query(a.queries.selectUserCount) if err != nil { return 0, err } @@ -642,7 +642,7 @@ func (a *Manager) AllowReservation(username string, topic string) error { // - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*" // - It also prioritizes write permissions over read permissions func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) { - rows, err := a.db.Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic) + rows, err := a.db.ReadOnly().Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic) if err != nil { return false, false, false, err } @@ -660,7 +660,7 @@ func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read, // AllGrants returns all user-specific access control entries, mapped to their respective user IDs func (a *Manager) AllGrants() (map[string][]Grant, error) { - rows, err := a.db.Query(a.queries.selectUserAllAccess) + rows, err := a.db.ReadOnly().Query(a.queries.selectUserAllAccess) if err != nil { return nil, err } @@ -688,7 +688,7 @@ func (a *Manager) AllGrants() (map[string][]Grant, error) { // Grants returns all user-specific access control entries func (a *Manager) Grants(username string) ([]Grant, error) { - rows, err := a.db.Query(a.queries.selectUserAccess, username) + rows, err := a.db.ReadOnly().Query(a.queries.selectUserAccess, username) if err != nil { return nil, err } @@ -753,7 +753,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error { // Reservations returns all user-owned topics, and the associated everyone-access func (a *Manager) Reservations(username string) ([]Reservation, error) { - rows, err := a.db.Query(a.queries.selectUserReservations, Everyone, username) + rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservations, Everyone, username) if err != nil { return nil, err } @@ -779,7 +779,7 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) { // HasReservation returns true if the given topic access is owned by the user func (a *Manager) HasReservation(username, topic string) (bool, error) { - rows, err := a.db.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic)) + rows, err := a.db.ReadOnly().Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic)) if err != nil { return false, err } @@ -796,7 +796,7 @@ func (a *Manager) HasReservation(username, topic string) (bool, error) { // ReservationsCount returns the number of reservations owned by this user func (a *Manager) ReservationsCount(username string) (int64, error) { - rows, err := a.db.Query(a.queries.selectUserReservationsCount, username) + rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsCount, username) if err != nil { return 0, err } @@ -813,7 +813,7 @@ func (a *Manager) ReservationsCount(username string) (int64, error) { // ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone func (a *Manager) ReservationOwner(topic string) (string, error) { - rows, err := a.db.Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic)) + rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic)) if err != nil { return "", err } @@ -830,7 +830,7 @@ func (a *Manager) ReservationOwner(topic string) (string, error) { // otherAccessCount returns the number of access entries for the given topic that are not owned by the user func (a *Manager) otherAccessCount(username, topic string) (int, error) { - rows, err := a.db.Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username) + rows, err := a.db.ReadOnly().Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username) if err != nil { return 0, err } @@ -962,7 +962,7 @@ func (a *Manager) canChangeToken(userID, token string) error { // Token returns a specific token for a user func (a *Manager) Token(userID, token string) (*Token, error) { - rows, err := a.db.Query(a.queries.selectToken, userID, token) + rows, err := a.db.ReadOnly().Query(a.queries.selectToken, userID, token) if err != nil { return nil, err } @@ -972,7 +972,7 @@ func (a *Manager) Token(userID, token string) (*Token, error) { // Tokens returns all existing tokens for the user with the given user ID func (a *Manager) Tokens(userID string) ([]*Token, error) { - rows, err := a.db.Query(a.queries.selectTokens, userID) + rows, err := a.db.ReadOnly().Query(a.queries.selectTokens, userID) if err != nil { return nil, err } @@ -991,7 +991,7 @@ func (a *Manager) Tokens(userID string) ([]*Token, error) { } func (a *Manager) allProvisionedTokens() ([]*Token, error) { - rows, err := a.db.Query(a.queries.selectAllProvisionedTokens) + rows, err := a.db.ReadOnly().Query(a.queries.selectAllProvisionedTokens) if err != nil { return nil, err } @@ -1114,7 +1114,7 @@ func (a *Manager) RemoveTier(code string) error { // Tiers returns a list of all Tier structs func (a *Manager) Tiers() ([]*Tier, error) { - rows, err := a.db.Query(a.queries.selectTiers) + rows, err := a.db.ReadOnly().Query(a.queries.selectTiers) if err != nil { return nil, err } @@ -1134,7 +1134,7 @@ func (a *Manager) Tiers() ([]*Tier, error) { // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist func (a *Manager) Tier(code string) (*Tier, error) { - rows, err := a.db.Query(a.queries.selectTierByCode, code) + rows, err := a.db.ReadOnly().Query(a.queries.selectTierByCode, code) if err != nil { return nil, err } @@ -1144,7 +1144,7 @@ func (a *Manager) Tier(code string) (*Tier, error) { // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { - rows, err := a.db.Query(a.queries.selectTierByPriceID, priceID, priceID) + rows, err := a.db.ReadOnly().Query(a.queries.selectTierByPriceID, priceID, priceID) if err != nil { return nil, err } @@ -1185,7 +1185,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { // PhoneNumbers returns all phone numbers for the user with the given user ID func (a *Manager) PhoneNumbers(userID string) ([]string, error) { - rows, err := a.db.Query(a.queries.selectPhoneNumbers, userID) + rows, err := a.db.ReadOnly().Query(a.queries.selectPhoneNumbers, userID) if err != nil { return nil, err } diff --git a/user/manager_postgres.go b/user/manager_postgres.go index 7138ae2c..7a332070 100644 --- a/user/manager_postgres.go +++ b/user/manager_postgres.go @@ -1,7 +1,7 @@ package user import ( - "database/sql" + "heckel.io/ntfy/v2/db" ) // PostgreSQL queries @@ -278,9 +278,9 @@ var postgresQueries = queries{ } // NewPostgresManager creates a new Manager backed by a PostgreSQL database -func NewPostgresManager(db *sql.DB, config *Config) (*Manager, error) { - if err := setupPostgres(db); err != nil { +func NewPostgresManager(d *db.DB, config *Config) (*Manager, error) { + if err := setupPostgres(d.SetupPrimary()); err != nil { return nil, err } - return newManager(db, postgresQueries, config) + return newManager(d, postgresQueries, config) } diff --git a/user/manager_sqlite.go b/user/manager_sqlite.go index b4068599..573d1e41 100644 --- a/user/manager_sqlite.go +++ b/user/manager_sqlite.go @@ -7,6 +7,7 @@ import ( _ "github.com/mattn/go-sqlite3" // SQLite driver + "heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/util" ) @@ -280,15 +281,15 @@ func NewSQLiteManager(filename, startupQueries string, config *Config) (*Manager if !util.FileExists(parentDir) { return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir) } - db, err := sql.Open("sqlite3", filename) + sqlDB, err := sql.Open("sqlite3", filename) if err != nil { return nil, err } - if err := setupSQLite(db); err != nil { + if err := setupSQLite(sqlDB); err != nil { return nil, err } - if err := runSQLiteStartupQueries(db, startupQueries); err != nil { + if err := runSQLiteStartupQueries(sqlDB, startupQueries); err != nil { return nil, err } - return newManager(db, sqliteQueries, config) + return newManager(db.NewDB(sqlDB, nil), sqliteQueries, config) } diff --git a/user/manager_test.go b/user/manager_test.go index 53cae1d1..73df3df1 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" + "heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/db/pg" dbtest "heckel.io/ntfy/v2/db/test" "heckel.io/ntfy/v2/util" @@ -38,7 +39,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc f(t, func(config *Config) *Manager { pool, err := pg.Open(schemaDSN) require.Nil(t, err) - a, err := NewPostgresManager(pool, config) + a, err := NewPostgresManager(db.NewDB(pool, nil), config) require.Nil(t, err) return a }) @@ -1734,8 +1735,8 @@ func TestMigrationFrom4(t *testing.T) { require.Nil(t, a.Authorize(nil, "up", PermissionRead)) // % matches 0 or more characters } -func checkSchemaVersion(t *testing.T, db *sql.DB) { - rows, err := db.Query(`SELECT version FROM schemaVersion`) +func checkSchemaVersion(t *testing.T, d *db.DB) { + rows, err := d.Query(`SELECT version FROM schemaVersion`) require.Nil(t, err) require.True(t, rows.Next()) @@ -1771,7 +1772,7 @@ func newTestManagerFromConfig(t *testing.T, newManager newManagerFunc, conf *Con return a } -func testDB(a *Manager) *sql.DB { +func testDB(a *Manager) *db.DB { return a.db } diff --git a/webpush/store.go b/webpush/store.go index 9a93a074..02b7552e 100644 --- a/webpush/store.go +++ b/webpush/store.go @@ -24,7 +24,7 @@ var ( // Store holds the database connection and queries for web push subscriptions. type Store struct { - db *sql.DB + db *db.DB queries queries } @@ -83,7 +83,7 @@ func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, // SubscriptionsForTopic returns all subscriptions for the given topic. func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) { - rows, err := s.db.Query(s.queries.selectSubscriptionsForTopic, topic) + rows, err := s.db.ReadOnly().Query(s.queries.selectSubscriptionsForTopic, topic) if err != nil { return nil, err } @@ -93,7 +93,7 @@ func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) { // SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period. func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) { - rows, err := s.db.Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix()) + rows, err := s.db.ReadOnly().Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix()) if err != nil { return nil, err } diff --git a/webpush/store_postgres.go b/webpush/store_postgres.go index ec541d37..cce9ec73 100644 --- a/webpush/store_postgres.go +++ b/webpush/store_postgres.go @@ -4,7 +4,7 @@ import ( "database/sql" "fmt" - "heckel.io/ntfy/v2/db" + ntfydb "heckel.io/ntfy/v2/db" ) const ( @@ -73,12 +73,12 @@ const ( ) // NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool. -func NewPostgresStore(db *sql.DB) (*Store, error) { - if err := setupPostgres(db); err != nil { +func NewPostgresStore(d *ntfydb.DB) (*Store, error) { + if err := setupPostgres(d.SetupPrimary()); err != nil { return nil, err } return &Store{ - db: db, + db: d, queries: queries{ selectSubscriptionIDByEndpoint: postgresSelectSubscriptionIDByEndpointQuery, selectSubscriptionCountBySubscriberIP: postgresSelectSubscriptionCountBySubscriberIPQuery, @@ -110,7 +110,7 @@ func setupPostgres(db *sql.DB) error { } func setupNewPostgres(sqlDB *sql.DB) error { - return db.ExecTx(sqlDB, func(tx *sql.Tx) error { + return ntfydb.ExecTx(sqlDB, func(tx *sql.Tx) error { if _, err := tx.Exec(postgresCreateTablesQuery); err != nil { return err } diff --git a/webpush/store_sqlite.go b/webpush/store_sqlite.go index 4ef78140..c2d105ff 100644 --- a/webpush/store_sqlite.go +++ b/webpush/store_sqlite.go @@ -79,18 +79,18 @@ const ( // NewSQLiteStore creates a new SQLite-backed web push store. func NewSQLiteStore(filename, startupQueries string) (*Store, error) { - db, err := sql.Open("sqlite3", filename) + sqlDB, err := sql.Open("sqlite3", filename) if err != nil { return nil, err } - if err := setupSQLite(db); err != nil { + if err := setupSQLite(sqlDB); err != nil { return nil, err } - if err := runSQLiteStartupQueries(db, startupQueries); err != nil { + if err := runSQLiteStartupQueries(sqlDB, startupQueries); err != nil { return nil, err } return &Store{ - db: db, + db: db.NewDB(sqlDB, nil), queries: queries{ selectSubscriptionIDByEndpoint: sqliteSelectSubscriptionIDByEndpointQuery, selectSubscriptionCountBySubscriberIP: sqliteSelectSubscriptionCountBySubscriberIPQuery,