WIP: Postgres read-only replica

This commit is contained in:
binwiederhier
2026-03-10 22:17:40 -04:00
parent 997e20fa3f
commit f1865749d7
16 changed files with 229 additions and 83 deletions

View File

@@ -40,6 +40,7 @@ var flagsServe = append(
altsrc.NewStringFlag(&cli.StringFlag{Name: "cert-file", Aliases: []string{"cert_file", "E"}, EnvVars: []string{"NTFY_CERT_FILE"}, Usage: "certificate file, if listen-https is set"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"firebase_key_file", "F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores (e.g. postgres://user:pass@host:5432/ntfy)"}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "database-replica-urls", Aliases: []string{"database_replica_urls"}, EnvVars: []string{"NTFY_DATABASE_REPLICA_URLS"}, Usage: "PostgreSQL read replica connection strings for offloading read queries"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file", "C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-duration", Aliases: []string{"cache_duration", "b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: util.FormatDuration(server.DefaultCacheDuration), Usage: "buffer messages for this time to allow `since` requests"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "cache-batch-size", Aliases: []string{"cache_batch_size"}, EnvVars: []string{"NTFY_BATCH_SIZE"}, Usage: "max size of messages to batch together when writing to message cache (if zero, writes are synchronous)"}),
@@ -145,6 +146,7 @@ func execServe(c *cli.Context) error {
certFile := c.String("cert-file")
firebaseKeyFile := c.String("firebase-key-file")
databaseURL := c.String("database-url")
databaseReplicaURLs := c.StringSlice("database-replica-urls")
webPushPrivateKey := c.String("web-push-private-key")
webPushPublicKey := c.String("web-push-public-key")
webPushFile := c.String("web-push-file")
@@ -282,7 +284,9 @@ func execServe(c *cli.Context) error {
}
// Check values
if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") {
if len(databaseReplicaURLs) > 0 && databaseURL == "" {
return errors.New("database-replica-urls can only be used if database-url is also set")
} else if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") {
return errors.New("if database-url is set, auth-file, cache-file, and web-push-file must not be set")
} else if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
return errors.New("if set, FCM key file must exist")
@@ -502,6 +506,7 @@ func execServe(c *cli.Context) error {
conf.MetricsListenHTTP = metricsListenHTTP
conf.ProfileListenHTTP = profileListenHTTP
conf.DatabaseURL = databaseURL
conf.DatabaseReplicaURLs = databaseReplicaURLs
conf.WebPushPrivateKey = webPushPrivateKey
conf.WebPushPublicKey = webPushPublicKey
conf.WebPushFile = webPushFile

View File

@@ -11,6 +11,7 @@ import (
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/server"
"heckel.io/ntfy/v2/user"
@@ -383,7 +384,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
if dbErr != nil {
return nil, dbErr
}
return user.NewPostgresManager(pool, authConfig)
return user.NewPostgresManager(db.NewDB(pool, nil), authConfig)
} else if authFile != "" {
if !util.FileExists(authFile) {
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")

122
db/db.go
View File

@@ -2,11 +2,129 @@ package db
import (
"database/sql"
"sync/atomic"
"time"
)
const (
replicaHealthCheckInterval = 5 * time.Second
)
// Beginner is an interface for types that can begin a database transaction.
// Both *sql.DB and *DB implement this.
type Beginner interface {
Begin() (*sql.Tx, error)
}
// DB wraps a primary *sql.DB and optional read replicas. All standard query/exec methods
// delegate to the primary. The ReadOnly() method returns a *sql.DB from a healthy replica
// (round-robin), falling back to the primary if no replicas are configured or all are unhealthy.
type DB struct {
primary *sql.DB
replicas []*replica
counter atomic.Uint64
}
type replica struct {
db *sql.DB
healthy atomic.Bool
lastChecked atomic.Int64
}
// NewDB creates a new DB that wraps the given primary and optional replica connections.
// If replicas is nil or empty, ReadOnly() simply returns the primary.
func NewDB(primary *sql.DB, replicas []*sql.DB) *DB {
d := &DB{
primary: primary,
replicas: make([]*replica, len(replicas)),
}
for i, r := range replicas {
rep := &replica{db: r}
rep.healthy.Store(true)
d.replicas[i] = rep
}
return d
}
// Query delegates to the primary database.
func (d *DB) Query(query string, args ...any) (*sql.Rows, error) {
return d.primary.Query(query, args...)
}
// QueryRow delegates to the primary database.
func (d *DB) QueryRow(query string, args ...any) *sql.Row {
return d.primary.QueryRow(query, args...)
}
// Exec delegates to the primary database.
func (d *DB) Exec(query string, args ...any) (sql.Result, error) {
return d.primary.Exec(query, args...)
}
// Begin delegates to the primary database.
func (d *DB) Begin() (*sql.Tx, error) {
return d.primary.Begin()
}
// Ping delegates to the primary database.
func (d *DB) Ping() error {
return d.primary.Ping()
}
// Close closes the primary database and all replicas.
func (d *DB) Close() error {
for _, r := range d.replicas {
r.db.Close()
}
return d.primary.Close()
}
// SetupPrimary returns the underlying primary *sql.DB. This is only intended for
// one-time schema setup during store initialization, not for regular queries.
func (d *DB) SetupPrimary() *sql.DB {
return d.primary
}
// ReadOnly returns a *sql.DB suitable for read-only queries. It round-robins across healthy
// replicas. If a replica's health status is stale (older than replicaHealthCheckInterval), it
// is re-checked with a ping. If all replicas are unhealthy or none are configured, the primary
// is returned.
func (d *DB) ReadOnly() *sql.DB {
if len(d.replicas) == 0 {
return d.primary
}
n := len(d.replicas)
start := int(d.counter.Add(1) - 1)
for i := 0; i < n; i++ {
r := d.replicas[(start+i)%n]
if d.isHealthy(r) {
return r.db
}
}
return d.primary
}
// isHealthy returns whether the replica is healthy. If the cached health status is stale,
// it pings the replica and updates the cache.
func (d *DB) isHealthy(r *replica) bool {
now := time.Now().Unix()
lastChecked := r.lastChecked.Load()
if now-lastChecked >= int64(replicaHealthCheckInterval.Seconds()) {
if r.lastChecked.CompareAndSwap(lastChecked, now) {
if err := r.db.Ping(); err != nil {
r.healthy.Store(false)
return false
}
r.healthy.Store(true)
return true
}
}
return r.healthy.Load()
}
// ExecTx executes a function within a database transaction. If the function returns an error,
// the transaction is rolled back. Otherwise, the transaction is committed.
func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
func ExecTx(db Beginner, f func(tx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
@@ -20,7 +138,7 @@ func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
// QueryTx executes a function within a database transaction and returns the result. If the function
// returns an error, the transaction is rolled back. Otherwise, the transaction is committed.
func QueryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
func QueryTx[T any](db Beginner, f func(tx *sql.Tx) (T, error)) (T, error) {
tx, err := db.Begin()
if err != nil {
var zero T

View File

@@ -1,13 +1,13 @@
package dbtest
import (
"database/sql"
"fmt"
"net/url"
"os"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/util"
)
@@ -48,16 +48,17 @@ func CreateTestPostgresSchema(t *testing.T) string {
return schemaDSN
}
// CreateTestPostgres creates a temporary PostgreSQL schema and returns an open *sql.DB connection to it.
// CreateTestPostgres creates a temporary PostgreSQL schema and returns an open *db.DB connection to it.
// It registers cleanup functions to close the DB and drop the schema when the test finishes.
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
func CreateTestPostgres(t *testing.T) *sql.DB {
func CreateTestPostgres(t *testing.T) *db.DB {
t.Helper()
schemaDSN := CreateTestPostgresSchema(t)
testDB, err := pg.Open(schemaDSN)
require.Nil(t, err)
d := db.NewDB(testDB, nil)
t.Cleanup(func() {
testDB.Close()
d.Close()
})
return testDB
return d
}

View File

@@ -50,14 +50,14 @@ type queries struct {
// Cache stores published messages
type Cache struct {
db *sql.DB
db *db.DB
queue *util.BatchingQueue[*model.Message]
nop bool
mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer)
queries queries
}
func newCache(db *sql.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache {
func newCache(db *db.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache {
var queue *util.BatchingQueue[*model.Message]
if batchSize > 0 || batchTimeout > 0 {
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
@@ -201,10 +201,11 @@ func (c *Cache) Messages(topic string, since model.SinceMarker, scheduled bool)
func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
var rows *sql.Rows
var err error
rdb := c.db.ReadOnly()
if scheduled {
rows, err = c.db.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
rows, err = rdb.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
} else {
rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
rows, err = rdb.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
}
if err != nil {
return nil, err
@@ -215,10 +216,11 @@ func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, schedul
func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
var rows *sql.Rows
var err error
rdb := c.db.ReadOnly()
if scheduled {
rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID())
rows, err = rdb.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID())
} else {
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, since.ID())
rows, err = rdb.Query(c.queries.selectMessagesSinceID, topic, since.ID())
}
if err != nil {
return nil, err
@@ -227,7 +229,7 @@ func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled
}
func (c *Cache) messagesLatest(topic string) ([]*model.Message, error) {
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesLatest, topic)
if err != nil {
return nil, err
}
@@ -266,7 +268,7 @@ func (c *Cache) MessagesExpired() ([]string, error) {
// Message returns the message with the given ID, or ErrMessageNotFound if not found
func (c *Cache) Message(id string) (*model.Message, error) {
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesByID, id)
if err != nil {
return nil, err
}
@@ -295,7 +297,7 @@ func (c *Cache) MarkPublished(m *model.Message) error {
// MessagesCount returns the total number of messages in the cache
func (c *Cache) MessagesCount() (int, error) {
rows, err := c.db.Query(c.queries.selectMessagesCount)
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesCount)
if err != nil {
return 0, err
}
@@ -312,7 +314,7 @@ func (c *Cache) MessagesCount() (int, error) {
// Topics returns a list of all topics with messages in the cache
func (c *Cache) Topics() ([]string, error) {
rows, err := c.db.Query(c.queries.selectTopics)
rows, err := c.db.ReadOnly().Query(c.queries.selectTopics)
if err != nil {
return nil, err
}
@@ -426,7 +428,7 @@ func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
if err != nil {
return 0, err
}
@@ -435,7 +437,7 @@ func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
// AttachmentBytesUsedByUser returns the total size of active attachments for the given user
func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) {
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
if err != nil {
return 0, err
}
@@ -466,7 +468,7 @@ func (c *Cache) UpdateStats(messages int64) error {
// Stats returns the total message count statistic
func (c *Cache) Stats() (messages int64, err error) {
rows, err := c.db.Query(c.queries.selectStats)
rows, err := c.db.ReadOnly().Query(c.queries.selectStats)
if err != nil {
return 0, err
}

View File

@@ -1,8 +1,9 @@
package message
import (
"database/sql"
"time"
"heckel.io/ntfy/v2/db"
)
// PostgreSQL runtime query constants
@@ -102,9 +103,9 @@ var postgresQueries = queries{
}
// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.
func NewPostgresStore(db *sql.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) {
if err := setupPostgres(db); err != nil {
func NewPostgresStore(d *db.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) {
if err := setupPostgres(d.SetupPrimary()); err != nil {
return nil, err
}
return newCache(db, postgresQueries, nil, batchSize, batchTimeout, false), nil
return newCache(d, postgresQueries, nil, batchSize, batchTimeout, false), nil
}

View File

@@ -8,6 +8,7 @@ import (
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/util"
)
@@ -110,14 +111,14 @@ func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
}
db, err := sql.Open("sqlite3", filename)
sqlDB, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupSQLite(db, startupQueries, cacheDuration); err != nil {
if err := setupSQLite(sqlDB, startupQueries, cacheDuration); err != nil {
return nil, err
}
return newCache(db, sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil
return newCache(db.NewDB(sqlDB, nil), sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil
}
// NewMemStore creates an in-memory cache

View File

@@ -95,7 +95,8 @@ type Config struct {
ListenUnixMode fs.FileMode
KeyFile string
CertFile string
DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy")
DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy")
DatabaseReplicaURLs []string // PostgreSQL read replica connection strings
FirebaseKeyFile string
CacheFile string
CacheDuration time.Duration

View File

@@ -33,6 +33,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/message"
@@ -47,7 +48,7 @@ import (
// Server is the main server, providing the UI and API for ntfy
type Server struct {
config *Config
db *sql.DB // Shared PostgreSQL connection pool, nil when using SQLite
db *db.DB // Shared PostgreSQL connection pool (with optional replicas), nil when using SQLite
httpServer *http.Server
httpsServer *http.Server
httpMetricsServer *http.Server
@@ -179,13 +180,26 @@ func New(conf *Config) (*Server, error) {
stripe = newStripeAPI()
}
// Open shared PostgreSQL connection pool if configured
var pool *sql.DB
var pool *db.DB
if conf.DatabaseURL != "" {
var err error
pool, err = pg.Open(conf.DatabaseURL)
primary, err := pg.Open(conf.DatabaseURL)
if err != nil {
return nil, err
}
var replicas []*sql.DB
for _, replicaURL := range conf.DatabaseReplicaURLs {
r, err := pg.Open(replicaURL)
if err != nil {
// Close already-opened replicas before returning
for _, opened := range replicas {
opened.Close()
}
primary.Close()
return nil, fmt.Errorf("failed to open database replica: %w", err)
}
replicas = append(replicas, r)
}
pool = db.NewDB(primary, replicas)
}
messageCache, err := createMessageCache(conf, pool)
if err != nil {
@@ -277,7 +291,7 @@ func New(conf *Config) (*Server, error) {
return s, nil
}
func createMessageCache(conf *Config, pool *sql.DB) (*message.Cache, error) {
func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) {
if conf.CacheDuration == 0 {
return message.NewNopStore()
} else if pool != nil {

View File

@@ -49,7 +49,7 @@ var (
// Manager handles user authentication, authorization, and management
type Manager struct {
config *Config
db *sql.DB
db *db.DB
queries queries
statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
@@ -58,7 +58,7 @@ type Manager struct {
var _ Auther = (*Manager)(nil)
func newManager(db *sql.DB, queries queries, config *Config) (*Manager, error) {
func newManager(d *db.DB, queries queries, config *Config) (*Manager, error) {
if config.BcryptCost <= 0 {
config.BcryptCost = DefaultUserPasswordBcryptCost
}
@@ -67,7 +67,7 @@ func newManager(db *sql.DB, queries queries, config *Config) (*Manager, error) {
}
manager := &Manager{
config: config,
db: db,
db: d,
statsQueue: make(map[string]*Stats),
tokenQueue: make(map[string]*TokenUpdate),
queries: queries,
@@ -388,7 +388,7 @@ func (a *Manager) writeUserStatsQueue() error {
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
func (a *Manager) User(username string) (*User, error) {
rows, err := a.db.Query(a.queries.selectUserByName, username)
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByName, username)
if err != nil {
return nil, err
}
@@ -397,7 +397,7 @@ func (a *Manager) User(username string) (*User, error) {
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
func (a *Manager) UserByID(id string) (*User, error) {
rows, err := a.db.Query(a.queries.selectUserByID, id)
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByID, id)
if err != nil {
return nil, err
}
@@ -406,7 +406,7 @@ func (a *Manager) UserByID(id string) (*User, error) {
// userByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise
func (a *Manager) userByToken(token string) (*User, error) {
rows, err := a.db.Query(a.queries.selectUserByToken, token, time.Now().Unix())
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByToken, token, time.Now().Unix())
if err != nil {
return nil, err
}
@@ -415,7 +415,7 @@ func (a *Manager) userByToken(token string) (*User, error) {
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
rows, err := a.db.Query(a.queries.selectUserByStripeCustomerID, customerID)
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByStripeCustomerID, customerID)
if err != nil {
return nil, err
}
@@ -425,7 +425,7 @@ func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
// Users returns a list of users. It loads all users in a single query
// rather than one query per user to avoid N+1 performance issues.
func (a *Manager) Users() ([]*User, error) {
rows, err := a.db.Query(a.queries.selectUsers)
rows, err := a.db.ReadOnly().Query(a.queries.selectUsers)
if err != nil {
return nil, err
}
@@ -434,7 +434,7 @@ func (a *Manager) Users() ([]*User, error) {
// UsersCount returns the number of users in the database
func (a *Manager) UsersCount() (int64, error) {
rows, err := a.db.Query(a.queries.selectUserCount)
rows, err := a.db.ReadOnly().Query(a.queries.selectUserCount)
if err != nil {
return 0, err
}
@@ -642,7 +642,7 @@ func (a *Manager) AllowReservation(username string, topic string) error {
// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
// - It also prioritizes write permissions over read permissions
func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
rows, err := a.db.Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
rows, err := a.db.ReadOnly().Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
if err != nil {
return false, false, false, err
}
@@ -660,7 +660,7 @@ func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read,
// AllGrants returns all user-specific access control entries, mapped to their respective user IDs
func (a *Manager) AllGrants() (map[string][]Grant, error) {
rows, err := a.db.Query(a.queries.selectUserAllAccess)
rows, err := a.db.ReadOnly().Query(a.queries.selectUserAllAccess)
if err != nil {
return nil, err
}
@@ -688,7 +688,7 @@ func (a *Manager) AllGrants() (map[string][]Grant, error) {
// Grants returns all user-specific access control entries
func (a *Manager) Grants(username string) ([]Grant, error) {
rows, err := a.db.Query(a.queries.selectUserAccess, username)
rows, err := a.db.ReadOnly().Query(a.queries.selectUserAccess, username)
if err != nil {
return nil, err
}
@@ -753,7 +753,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
// Reservations returns all user-owned topics, and the associated everyone-access
func (a *Manager) Reservations(username string) ([]Reservation, error) {
rows, err := a.db.Query(a.queries.selectUserReservations, Everyone, username)
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservations, Everyone, username)
if err != nil {
return nil, err
}
@@ -779,7 +779,7 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) {
// HasReservation returns true if the given topic access is owned by the user
func (a *Manager) HasReservation(username, topic string) (bool, error) {
rows, err := a.db.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
rows, err := a.db.ReadOnly().Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
if err != nil {
return false, err
}
@@ -796,7 +796,7 @@ func (a *Manager) HasReservation(username, topic string) (bool, error) {
// ReservationsCount returns the number of reservations owned by this user
func (a *Manager) ReservationsCount(username string) (int64, error) {
rows, err := a.db.Query(a.queries.selectUserReservationsCount, username)
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsCount, username)
if err != nil {
return 0, err
}
@@ -813,7 +813,7 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone
func (a *Manager) ReservationOwner(topic string) (string, error) {
rows, err := a.db.Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic))
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic))
if err != nil {
return "", err
}
@@ -830,7 +830,7 @@ func (a *Manager) ReservationOwner(topic string) (string, error) {
// otherAccessCount returns the number of access entries for the given topic that are not owned by the user
func (a *Manager) otherAccessCount(username, topic string) (int, error) {
rows, err := a.db.Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
rows, err := a.db.ReadOnly().Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
if err != nil {
return 0, err
}
@@ -962,7 +962,7 @@ func (a *Manager) canChangeToken(userID, token string) error {
// Token returns a specific token for a user
func (a *Manager) Token(userID, token string) (*Token, error) {
rows, err := a.db.Query(a.queries.selectToken, userID, token)
rows, err := a.db.ReadOnly().Query(a.queries.selectToken, userID, token)
if err != nil {
return nil, err
}
@@ -972,7 +972,7 @@ func (a *Manager) Token(userID, token string) (*Token, error) {
// Tokens returns all existing tokens for the user with the given user ID
func (a *Manager) Tokens(userID string) ([]*Token, error) {
rows, err := a.db.Query(a.queries.selectTokens, userID)
rows, err := a.db.ReadOnly().Query(a.queries.selectTokens, userID)
if err != nil {
return nil, err
}
@@ -991,7 +991,7 @@ func (a *Manager) Tokens(userID string) ([]*Token, error) {
}
func (a *Manager) allProvisionedTokens() ([]*Token, error) {
rows, err := a.db.Query(a.queries.selectAllProvisionedTokens)
rows, err := a.db.ReadOnly().Query(a.queries.selectAllProvisionedTokens)
if err != nil {
return nil, err
}
@@ -1114,7 +1114,7 @@ func (a *Manager) RemoveTier(code string) error {
// Tiers returns a list of all Tier structs
func (a *Manager) Tiers() ([]*Tier, error) {
rows, err := a.db.Query(a.queries.selectTiers)
rows, err := a.db.ReadOnly().Query(a.queries.selectTiers)
if err != nil {
return nil, err
}
@@ -1134,7 +1134,7 @@ func (a *Manager) Tiers() ([]*Tier, error) {
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
func (a *Manager) Tier(code string) (*Tier, error) {
rows, err := a.db.Query(a.queries.selectTierByCode, code)
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByCode, code)
if err != nil {
return nil, err
}
@@ -1144,7 +1144,7 @@ func (a *Manager) Tier(code string) (*Tier, error) {
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
rows, err := a.db.Query(a.queries.selectTierByPriceID, priceID, priceID)
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByPriceID, priceID, priceID)
if err != nil {
return nil, err
}
@@ -1185,7 +1185,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
// PhoneNumbers returns all phone numbers for the user with the given user ID
func (a *Manager) PhoneNumbers(userID string) ([]string, error) {
rows, err := a.db.Query(a.queries.selectPhoneNumbers, userID)
rows, err := a.db.ReadOnly().Query(a.queries.selectPhoneNumbers, userID)
if err != nil {
return nil, err
}

View File

@@ -1,7 +1,7 @@
package user
import (
"database/sql"
"heckel.io/ntfy/v2/db"
)
// PostgreSQL queries
@@ -278,9 +278,9 @@ var postgresQueries = queries{
}
// NewPostgresManager creates a new Manager backed by a PostgreSQL database
func NewPostgresManager(db *sql.DB, config *Config) (*Manager, error) {
if err := setupPostgres(db); err != nil {
func NewPostgresManager(d *db.DB, config *Config) (*Manager, error) {
if err := setupPostgres(d.SetupPrimary()); err != nil {
return nil, err
}
return newManager(db, postgresQueries, config)
return newManager(d, postgresQueries, config)
}

View File

@@ -7,6 +7,7 @@ import (
_ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/util"
)
@@ -280,15 +281,15 @@ func NewSQLiteManager(filename, startupQueries string, config *Config) (*Manager
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
}
db, err := sql.Open("sqlite3", filename)
sqlDB, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupSQLite(db); err != nil {
if err := setupSQLite(sqlDB); err != nil {
return nil, err
}
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
if err := runSQLiteStartupQueries(sqlDB, startupQueries); err != nil {
return nil, err
}
return newManager(db, sqliteQueries, config)
return newManager(db.NewDB(sqlDB, nil), sqliteQueries, config)
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/util"
@@ -38,7 +39,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc
f(t, func(config *Config) *Manager {
pool, err := pg.Open(schemaDSN)
require.Nil(t, err)
a, err := NewPostgresManager(pool, config)
a, err := NewPostgresManager(db.NewDB(pool, nil), config)
require.Nil(t, err)
return a
})
@@ -1734,8 +1735,8 @@ func TestMigrationFrom4(t *testing.T) {
require.Nil(t, a.Authorize(nil, "up", PermissionRead)) // % matches 0 or more characters
}
func checkSchemaVersion(t *testing.T, db *sql.DB) {
rows, err := db.Query(`SELECT version FROM schemaVersion`)
func checkSchemaVersion(t *testing.T, d *db.DB) {
rows, err := d.Query(`SELECT version FROM schemaVersion`)
require.Nil(t, err)
require.True(t, rows.Next())
@@ -1771,7 +1772,7 @@ func newTestManagerFromConfig(t *testing.T, newManager newManagerFunc, conf *Con
return a
}
func testDB(a *Manager) *sql.DB {
func testDB(a *Manager) *db.DB {
return a.db
}

View File

@@ -24,7 +24,7 @@ var (
// Store holds the database connection and queries for web push subscriptions.
type Store struct {
db *sql.DB
db *db.DB
queries queries
}
@@ -83,7 +83,7 @@ func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string,
// SubscriptionsForTopic returns all subscriptions for the given topic.
func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
rows, err := s.db.Query(s.queries.selectSubscriptionsForTopic, topic)
rows, err := s.db.ReadOnly().Query(s.queries.selectSubscriptionsForTopic, topic)
if err != nil {
return nil, err
}
@@ -93,7 +93,7 @@ func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period.
func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) {
rows, err := s.db.Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
rows, err := s.db.ReadOnly().Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
if err != nil {
return nil, err
}

View File

@@ -4,7 +4,7 @@ import (
"database/sql"
"fmt"
"heckel.io/ntfy/v2/db"
ntfydb "heckel.io/ntfy/v2/db"
)
const (
@@ -73,12 +73,12 @@ const (
)
// NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool.
func NewPostgresStore(db *sql.DB) (*Store, error) {
if err := setupPostgres(db); err != nil {
func NewPostgresStore(d *ntfydb.DB) (*Store, error) {
if err := setupPostgres(d.SetupPrimary()); err != nil {
return nil, err
}
return &Store{
db: db,
db: d,
queries: queries{
selectSubscriptionIDByEndpoint: postgresSelectSubscriptionIDByEndpointQuery,
selectSubscriptionCountBySubscriberIP: postgresSelectSubscriptionCountBySubscriberIPQuery,
@@ -110,7 +110,7 @@ func setupPostgres(db *sql.DB) error {
}
func setupNewPostgres(sqlDB *sql.DB) error {
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
return ntfydb.ExecTx(sqlDB, func(tx *sql.Tx) error {
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}

View File

@@ -79,18 +79,18 @@ const (
// NewSQLiteStore creates a new SQLite-backed web push store.
func NewSQLiteStore(filename, startupQueries string) (*Store, error) {
db, err := sql.Open("sqlite3", filename)
sqlDB, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupSQLite(db); err != nil {
if err := setupSQLite(sqlDB); err != nil {
return nil, err
}
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
if err := runSQLiteStartupQueries(sqlDB, startupQueries); err != nil {
return nil, err
}
return &Store{
db: db,
db: db.NewDB(sqlDB, nil),
queries: queries{
selectSubscriptionIDByEndpoint: sqliteSelectSubscriptionIDByEndpointQuery,
selectSubscriptionCountBySubscriberIP: sqliteSelectSubscriptionCountBySubscriberIPQuery,