mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-18 13:20:48 +01:00
WIP: Postgres read-only replica
This commit is contained in:
@@ -40,6 +40,7 @@ var flagsServe = append(
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cert-file", Aliases: []string{"cert_file", "E"}, EnvVars: []string{"NTFY_CERT_FILE"}, Usage: "certificate file, if listen-https is set"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"firebase_key_file", "F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores (e.g. postgres://user:pass@host:5432/ntfy)"}),
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "database-replica-urls", Aliases: []string{"database_replica_urls"}, EnvVars: []string{"NTFY_DATABASE_REPLICA_URLS"}, Usage: "PostgreSQL read replica connection strings for offloading read queries"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file", "C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-duration", Aliases: []string{"cache_duration", "b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: util.FormatDuration(server.DefaultCacheDuration), Usage: "buffer messages for this time to allow `since` requests"}),
|
||||
altsrc.NewIntFlag(&cli.IntFlag{Name: "cache-batch-size", Aliases: []string{"cache_batch_size"}, EnvVars: []string{"NTFY_BATCH_SIZE"}, Usage: "max size of messages to batch together when writing to message cache (if zero, writes are synchronous)"}),
|
||||
@@ -145,6 +146,7 @@ func execServe(c *cli.Context) error {
|
||||
certFile := c.String("cert-file")
|
||||
firebaseKeyFile := c.String("firebase-key-file")
|
||||
databaseURL := c.String("database-url")
|
||||
databaseReplicaURLs := c.StringSlice("database-replica-urls")
|
||||
webPushPrivateKey := c.String("web-push-private-key")
|
||||
webPushPublicKey := c.String("web-push-public-key")
|
||||
webPushFile := c.String("web-push-file")
|
||||
@@ -282,7 +284,9 @@ func execServe(c *cli.Context) error {
|
||||
}
|
||||
|
||||
// Check values
|
||||
if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") {
|
||||
if len(databaseReplicaURLs) > 0 && databaseURL == "" {
|
||||
return errors.New("database-replica-urls can only be used if database-url is also set")
|
||||
} else if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") {
|
||||
return errors.New("if database-url is set, auth-file, cache-file, and web-push-file must not be set")
|
||||
} else if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
|
||||
return errors.New("if set, FCM key file must exist")
|
||||
@@ -502,6 +506,7 @@ func execServe(c *cli.Context) error {
|
||||
conf.MetricsListenHTTP = metricsListenHTTP
|
||||
conf.ProfileListenHTTP = profileListenHTTP
|
||||
conf.DatabaseURL = databaseURL
|
||||
conf.DatabaseReplicaURLs = databaseReplicaURLs
|
||||
conf.WebPushPrivateKey = webPushPrivateKey
|
||||
conf.WebPushPublicKey = webPushPublicKey
|
||||
conf.WebPushFile = webPushFile
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
"github.com/urfave/cli/v2/altsrc"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
"heckel.io/ntfy/v2/server"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
@@ -383,7 +384,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
||||
if dbErr != nil {
|
||||
return nil, dbErr
|
||||
}
|
||||
return user.NewPostgresManager(pool, authConfig)
|
||||
return user.NewPostgresManager(db.NewDB(pool, nil), authConfig)
|
||||
} else if authFile != "" {
|
||||
if !util.FileExists(authFile) {
|
||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||
|
||||
122
db/db.go
122
db/db.go
@@ -2,11 +2,129 @@ package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
replicaHealthCheckInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
// Beginner is an interface for types that can begin a database transaction.
|
||||
// Both *sql.DB and *DB implement this.
|
||||
type Beginner interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
}
|
||||
|
||||
// DB wraps a primary *sql.DB and optional read replicas. All standard query/exec methods
|
||||
// delegate to the primary. The ReadOnly() method returns a *sql.DB from a healthy replica
|
||||
// (round-robin), falling back to the primary if no replicas are configured or all are unhealthy.
|
||||
type DB struct {
|
||||
primary *sql.DB
|
||||
replicas []*replica
|
||||
counter atomic.Uint64
|
||||
}
|
||||
|
||||
type replica struct {
|
||||
db *sql.DB
|
||||
healthy atomic.Bool
|
||||
lastChecked atomic.Int64
|
||||
}
|
||||
|
||||
// NewDB creates a new DB that wraps the given primary and optional replica connections.
|
||||
// If replicas is nil or empty, ReadOnly() simply returns the primary.
|
||||
func NewDB(primary *sql.DB, replicas []*sql.DB) *DB {
|
||||
d := &DB{
|
||||
primary: primary,
|
||||
replicas: make([]*replica, len(replicas)),
|
||||
}
|
||||
for i, r := range replicas {
|
||||
rep := &replica{db: r}
|
||||
rep.healthy.Store(true)
|
||||
d.replicas[i] = rep
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// Query delegates to the primary database.
|
||||
func (d *DB) Query(query string, args ...any) (*sql.Rows, error) {
|
||||
return d.primary.Query(query, args...)
|
||||
}
|
||||
|
||||
// QueryRow delegates to the primary database.
|
||||
func (d *DB) QueryRow(query string, args ...any) *sql.Row {
|
||||
return d.primary.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
// Exec delegates to the primary database.
|
||||
func (d *DB) Exec(query string, args ...any) (sql.Result, error) {
|
||||
return d.primary.Exec(query, args...)
|
||||
}
|
||||
|
||||
// Begin delegates to the primary database.
|
||||
func (d *DB) Begin() (*sql.Tx, error) {
|
||||
return d.primary.Begin()
|
||||
}
|
||||
|
||||
// Ping delegates to the primary database.
|
||||
func (d *DB) Ping() error {
|
||||
return d.primary.Ping()
|
||||
}
|
||||
|
||||
// Close closes the primary database and all replicas.
|
||||
func (d *DB) Close() error {
|
||||
for _, r := range d.replicas {
|
||||
r.db.Close()
|
||||
}
|
||||
return d.primary.Close()
|
||||
}
|
||||
|
||||
// SetupPrimary returns the underlying primary *sql.DB. This is only intended for
|
||||
// one-time schema setup during store initialization, not for regular queries.
|
||||
func (d *DB) SetupPrimary() *sql.DB {
|
||||
return d.primary
|
||||
}
|
||||
|
||||
// ReadOnly returns a *sql.DB suitable for read-only queries. It round-robins across healthy
|
||||
// replicas. If a replica's health status is stale (older than replicaHealthCheckInterval), it
|
||||
// is re-checked with a ping. If all replicas are unhealthy or none are configured, the primary
|
||||
// is returned.
|
||||
func (d *DB) ReadOnly() *sql.DB {
|
||||
if len(d.replicas) == 0 {
|
||||
return d.primary
|
||||
}
|
||||
n := len(d.replicas)
|
||||
start := int(d.counter.Add(1) - 1)
|
||||
for i := 0; i < n; i++ {
|
||||
r := d.replicas[(start+i)%n]
|
||||
if d.isHealthy(r) {
|
||||
return r.db
|
||||
}
|
||||
}
|
||||
return d.primary
|
||||
}
|
||||
|
||||
// isHealthy returns whether the replica is healthy. If the cached health status is stale,
|
||||
// it pings the replica and updates the cache.
|
||||
func (d *DB) isHealthy(r *replica) bool {
|
||||
now := time.Now().Unix()
|
||||
lastChecked := r.lastChecked.Load()
|
||||
if now-lastChecked >= int64(replicaHealthCheckInterval.Seconds()) {
|
||||
if r.lastChecked.CompareAndSwap(lastChecked, now) {
|
||||
if err := r.db.Ping(); err != nil {
|
||||
r.healthy.Store(false)
|
||||
return false
|
||||
}
|
||||
r.healthy.Store(true)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return r.healthy.Load()
|
||||
}
|
||||
|
||||
// ExecTx executes a function within a database transaction. If the function returns an error,
|
||||
// the transaction is rolled back. Otherwise, the transaction is committed.
|
||||
func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
||||
func ExecTx(db Beginner, f func(tx *sql.Tx) error) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -20,7 +138,7 @@ func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
||||
|
||||
// QueryTx executes a function within a database transaction and returns the result. If the function
|
||||
// returns an error, the transaction is rolled back. Otherwise, the transaction is committed.
|
||||
func QueryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
|
||||
func QueryTx[T any](db Beginner, f func(tx *sql.Tx) (T, error)) (T, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
var zero T
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package dbtest
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
@@ -48,16 +48,17 @@ func CreateTestPostgresSchema(t *testing.T) string {
|
||||
return schemaDSN
|
||||
}
|
||||
|
||||
// CreateTestPostgres creates a temporary PostgreSQL schema and returns an open *sql.DB connection to it.
|
||||
// CreateTestPostgres creates a temporary PostgreSQL schema and returns an open *db.DB connection to it.
|
||||
// It registers cleanup functions to close the DB and drop the schema when the test finishes.
|
||||
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
|
||||
func CreateTestPostgres(t *testing.T) *sql.DB {
|
||||
func CreateTestPostgres(t *testing.T) *db.DB {
|
||||
t.Helper()
|
||||
schemaDSN := CreateTestPostgresSchema(t)
|
||||
testDB, err := pg.Open(schemaDSN)
|
||||
require.Nil(t, err)
|
||||
d := db.NewDB(testDB, nil)
|
||||
t.Cleanup(func() {
|
||||
testDB.Close()
|
||||
d.Close()
|
||||
})
|
||||
return testDB
|
||||
return d
|
||||
}
|
||||
|
||||
@@ -50,14 +50,14 @@ type queries struct {
|
||||
|
||||
// Cache stores published messages
|
||||
type Cache struct {
|
||||
db *sql.DB
|
||||
db *db.DB
|
||||
queue *util.BatchingQueue[*model.Message]
|
||||
nop bool
|
||||
mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer)
|
||||
queries queries
|
||||
}
|
||||
|
||||
func newCache(db *sql.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache {
|
||||
func newCache(db *db.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache {
|
||||
var queue *util.BatchingQueue[*model.Message]
|
||||
if batchSize > 0 || batchTimeout > 0 {
|
||||
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
|
||||
@@ -201,10 +201,11 @@ func (c *Cache) Messages(topic string, since model.SinceMarker, scheduled bool)
|
||||
func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
rdb := c.db.ReadOnly()
|
||||
if scheduled {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
|
||||
rows, err = rdb.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
|
||||
} else {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
|
||||
rows, err = rdb.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -215,10 +216,11 @@ func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, schedul
|
||||
func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
rdb := c.db.ReadOnly()
|
||||
if scheduled {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID())
|
||||
rows, err = rdb.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID())
|
||||
} else {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, since.ID())
|
||||
rows, err = rdb.Query(c.queries.selectMessagesSinceID, topic, since.ID())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -227,7 +229,7 @@ func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled
|
||||
}
|
||||
|
||||
func (c *Cache) messagesLatest(topic string) ([]*model.Message, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesLatest, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -266,7 +268,7 @@ func (c *Cache) MessagesExpired() ([]string, error) {
|
||||
|
||||
// Message returns the message with the given ID, or ErrMessageNotFound if not found
|
||||
func (c *Cache) Message(id string) (*model.Message, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -295,7 +297,7 @@ func (c *Cache) MarkPublished(m *model.Message) error {
|
||||
|
||||
// MessagesCount returns the total number of messages in the cache
|
||||
func (c *Cache) MessagesCount() (int, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesCount)
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesCount)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -312,7 +314,7 @@ func (c *Cache) MessagesCount() (int, error) {
|
||||
|
||||
// Topics returns a list of all topics with messages in the cache
|
||||
func (c *Cache) Topics() ([]string, error) {
|
||||
rows, err := c.db.Query(c.queries.selectTopics)
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectTopics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -426,7 +428,7 @@ func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
|
||||
|
||||
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
|
||||
func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -435,7 +437,7 @@ func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
||||
|
||||
// AttachmentBytesUsedByUser returns the total size of active attachments for the given user
|
||||
func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -466,7 +468,7 @@ func (c *Cache) UpdateStats(messages int64) error {
|
||||
|
||||
// Stats returns the total message count statistic
|
||||
func (c *Cache) Stats() (messages int64, err error) {
|
||||
rows, err := c.db.Query(c.queries.selectStats)
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectStats)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
)
|
||||
|
||||
// PostgreSQL runtime query constants
|
||||
@@ -102,9 +103,9 @@ var postgresQueries = queries{
|
||||
}
|
||||
|
||||
// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.
|
||||
func NewPostgresStore(db *sql.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) {
|
||||
if err := setupPostgres(db); err != nil {
|
||||
func NewPostgresStore(d *db.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) {
|
||||
if err := setupPostgres(d.SetupPrimary()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newCache(db, postgresQueries, nil, batchSize, batchTimeout, false), nil
|
||||
return newCache(d, postgresQueries, nil, batchSize, batchTimeout, false), nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
@@ -110,14 +111,14 @@ func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration
|
||||
if !util.FileExists(parentDir) {
|
||||
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
|
||||
}
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
sqlDB, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLite(db, startupQueries, cacheDuration); err != nil {
|
||||
if err := setupSQLite(sqlDB, startupQueries, cacheDuration); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newCache(db, sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil
|
||||
return newCache(db.NewDB(sqlDB, nil), sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil
|
||||
}
|
||||
|
||||
// NewMemStore creates an in-memory cache
|
||||
|
||||
@@ -95,7 +95,8 @@ type Config struct {
|
||||
ListenUnixMode fs.FileMode
|
||||
KeyFile string
|
||||
CertFile string
|
||||
DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy")
|
||||
DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy")
|
||||
DatabaseReplicaURLs []string // PostgreSQL read replica connection strings
|
||||
FirebaseKeyFile string
|
||||
CacheFile string
|
||||
CacheDuration time.Duration
|
||||
|
||||
@@ -33,6 +33,7 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gopkg.in/yaml.v2"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
@@ -47,7 +48,7 @@ import (
|
||||
// Server is the main server, providing the UI and API for ntfy
|
||||
type Server struct {
|
||||
config *Config
|
||||
db *sql.DB // Shared PostgreSQL connection pool, nil when using SQLite
|
||||
db *db.DB // Shared PostgreSQL connection pool (with optional replicas), nil when using SQLite
|
||||
httpServer *http.Server
|
||||
httpsServer *http.Server
|
||||
httpMetricsServer *http.Server
|
||||
@@ -179,13 +180,26 @@ func New(conf *Config) (*Server, error) {
|
||||
stripe = newStripeAPI()
|
||||
}
|
||||
// Open shared PostgreSQL connection pool if configured
|
||||
var pool *sql.DB
|
||||
var pool *db.DB
|
||||
if conf.DatabaseURL != "" {
|
||||
var err error
|
||||
pool, err = pg.Open(conf.DatabaseURL)
|
||||
primary, err := pg.Open(conf.DatabaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var replicas []*sql.DB
|
||||
for _, replicaURL := range conf.DatabaseReplicaURLs {
|
||||
r, err := pg.Open(replicaURL)
|
||||
if err != nil {
|
||||
// Close already-opened replicas before returning
|
||||
for _, opened := range replicas {
|
||||
opened.Close()
|
||||
}
|
||||
primary.Close()
|
||||
return nil, fmt.Errorf("failed to open database replica: %w", err)
|
||||
}
|
||||
replicas = append(replicas, r)
|
||||
}
|
||||
pool = db.NewDB(primary, replicas)
|
||||
}
|
||||
messageCache, err := createMessageCache(conf, pool)
|
||||
if err != nil {
|
||||
@@ -277,7 +291,7 @@ func New(conf *Config) (*Server, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func createMessageCache(conf *Config, pool *sql.DB) (*message.Cache, error) {
|
||||
func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) {
|
||||
if conf.CacheDuration == 0 {
|
||||
return message.NewNopStore()
|
||||
} else if pool != nil {
|
||||
|
||||
@@ -49,7 +49,7 @@ var (
|
||||
// Manager handles user authentication, authorization, and management
|
||||
type Manager struct {
|
||||
config *Config
|
||||
db *sql.DB
|
||||
db *db.DB
|
||||
queries queries
|
||||
statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
|
||||
tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
|
||||
@@ -58,7 +58,7 @@ type Manager struct {
|
||||
|
||||
var _ Auther = (*Manager)(nil)
|
||||
|
||||
func newManager(db *sql.DB, queries queries, config *Config) (*Manager, error) {
|
||||
func newManager(d *db.DB, queries queries, config *Config) (*Manager, error) {
|
||||
if config.BcryptCost <= 0 {
|
||||
config.BcryptCost = DefaultUserPasswordBcryptCost
|
||||
}
|
||||
@@ -67,7 +67,7 @@ func newManager(db *sql.DB, queries queries, config *Config) (*Manager, error) {
|
||||
}
|
||||
manager := &Manager{
|
||||
config: config,
|
||||
db: db,
|
||||
db: d,
|
||||
statsQueue: make(map[string]*Stats),
|
||||
tokenQueue: make(map[string]*TokenUpdate),
|
||||
queries: queries,
|
||||
@@ -388,7 +388,7 @@ func (a *Manager) writeUserStatsQueue() error {
|
||||
|
||||
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
|
||||
func (a *Manager) User(username string) (*User, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserByName, username)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByName, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -397,7 +397,7 @@ func (a *Manager) User(username string) (*User, error) {
|
||||
|
||||
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
|
||||
func (a *Manager) UserByID(id string) (*User, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserByID, id)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -406,7 +406,7 @@ func (a *Manager) UserByID(id string) (*User, error) {
|
||||
|
||||
// userByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise
|
||||
func (a *Manager) userByToken(token string) (*User, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserByToken, token, time.Now().Unix())
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByToken, token, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -415,7 +415,7 @@ func (a *Manager) userByToken(token string) (*User, error) {
|
||||
|
||||
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
|
||||
func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserByStripeCustomerID, customerID)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByStripeCustomerID, customerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -425,7 +425,7 @@ func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
|
||||
// Users returns a list of users. It loads all users in a single query
|
||||
// rather than one query per user to avoid N+1 performance issues.
|
||||
func (a *Manager) Users() ([]*User, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUsers)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -434,7 +434,7 @@ func (a *Manager) Users() ([]*User, error) {
|
||||
|
||||
// UsersCount returns the number of users in the database
|
||||
func (a *Manager) UsersCount() (int64, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserCount)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserCount)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -642,7 +642,7 @@ func (a *Manager) AllowReservation(username string, topic string) error {
|
||||
// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
|
||||
// - It also prioritizes write permissions over read permissions
|
||||
func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
|
||||
rows, err := a.db.Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
|
||||
if err != nil {
|
||||
return false, false, false, err
|
||||
}
|
||||
@@ -660,7 +660,7 @@ func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read,
|
||||
|
||||
// AllGrants returns all user-specific access control entries, mapped to their respective user IDs
|
||||
func (a *Manager) AllGrants() (map[string][]Grant, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserAllAccess)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserAllAccess)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -688,7 +688,7 @@ func (a *Manager) AllGrants() (map[string][]Grant, error) {
|
||||
|
||||
// Grants returns all user-specific access control entries
|
||||
func (a *Manager) Grants(username string) ([]Grant, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserAccess, username)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserAccess, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -753,7 +753,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
|
||||
|
||||
// Reservations returns all user-owned topics, and the associated everyone-access
|
||||
func (a *Manager) Reservations(username string) ([]Reservation, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserReservations, Everyone, username)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservations, Everyone, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -779,7 +779,7 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) {
|
||||
|
||||
// HasReservation returns true if the given topic access is owned by the user
|
||||
func (a *Manager) HasReservation(username, topic string) (bool, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -796,7 +796,7 @@ func (a *Manager) HasReservation(username, topic string) (bool, error) {
|
||||
|
||||
// ReservationsCount returns the number of reservations owned by this user
|
||||
func (a *Manager) ReservationsCount(username string) (int64, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserReservationsCount, username)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsCount, username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -813,7 +813,7 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
|
||||
|
||||
// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone
|
||||
func (a *Manager) ReservationOwner(topic string) (string, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic))
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -830,7 +830,7 @@ func (a *Manager) ReservationOwner(topic string) (string, error) {
|
||||
|
||||
// otherAccessCount returns the number of access entries for the given topic that are not owned by the user
|
||||
func (a *Manager) otherAccessCount(username, topic string) (int, error) {
|
||||
rows, err := a.db.Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -962,7 +962,7 @@ func (a *Manager) canChangeToken(userID, token string) error {
|
||||
|
||||
// Token returns a specific token for a user
|
||||
func (a *Manager) Token(userID, token string) (*Token, error) {
|
||||
rows, err := a.db.Query(a.queries.selectToken, userID, token)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectToken, userID, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -972,7 +972,7 @@ func (a *Manager) Token(userID, token string) (*Token, error) {
|
||||
|
||||
// Tokens returns all existing tokens for the user with the given user ID
|
||||
func (a *Manager) Tokens(userID string) ([]*Token, error) {
|
||||
rows, err := a.db.Query(a.queries.selectTokens, userID)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectTokens, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -991,7 +991,7 @@ func (a *Manager) Tokens(userID string) ([]*Token, error) {
|
||||
}
|
||||
|
||||
func (a *Manager) allProvisionedTokens() ([]*Token, error) {
|
||||
rows, err := a.db.Query(a.queries.selectAllProvisionedTokens)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectAllProvisionedTokens)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1114,7 +1114,7 @@ func (a *Manager) RemoveTier(code string) error {
|
||||
|
||||
// Tiers returns a list of all Tier structs
|
||||
func (a *Manager) Tiers() ([]*Tier, error) {
|
||||
rows, err := a.db.Query(a.queries.selectTiers)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectTiers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1134,7 +1134,7 @@ func (a *Manager) Tiers() ([]*Tier, error) {
|
||||
|
||||
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
|
||||
func (a *Manager) Tier(code string) (*Tier, error) {
|
||||
rows, err := a.db.Query(a.queries.selectTierByCode, code)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByCode, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1144,7 +1144,7 @@ func (a *Manager) Tier(code string) (*Tier, error) {
|
||||
|
||||
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
|
||||
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
|
||||
rows, err := a.db.Query(a.queries.selectTierByPriceID, priceID, priceID)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByPriceID, priceID, priceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1185,7 +1185,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
|
||||
|
||||
// PhoneNumbers returns all phone numbers for the user with the given user ID
|
||||
func (a *Manager) PhoneNumbers(userID string) ([]string, error) {
|
||||
rows, err := a.db.Query(a.queries.selectPhoneNumbers, userID)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectPhoneNumbers, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
)
|
||||
|
||||
// PostgreSQL queries
|
||||
@@ -278,9 +278,9 @@ var postgresQueries = queries{
|
||||
}
|
||||
|
||||
// NewPostgresManager creates a new Manager backed by a PostgreSQL database
|
||||
func NewPostgresManager(db *sql.DB, config *Config) (*Manager, error) {
|
||||
if err := setupPostgres(db); err != nil {
|
||||
func NewPostgresManager(d *db.DB, config *Config) (*Manager, error) {
|
||||
if err := setupPostgres(d.SetupPrimary()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newManager(db, postgresQueries, config)
|
||||
return newManager(d, postgresQueries, config)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
@@ -280,15 +281,15 @@ func NewSQLiteManager(filename, startupQueries string, config *Config) (*Manager
|
||||
if !util.FileExists(parentDir) {
|
||||
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
|
||||
}
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
sqlDB, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLite(db); err != nil {
|
||||
if err := setupSQLite(sqlDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||
if err := runSQLiteStartupQueries(sqlDB, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newManager(db, sqliteQueries, config)
|
||||
return newManager(db.NewDB(sqlDB, nil), sqliteQueries, config)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
dbtest "heckel.io/ntfy/v2/db/test"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
@@ -38,7 +39,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc
|
||||
f(t, func(config *Config) *Manager {
|
||||
pool, err := pg.Open(schemaDSN)
|
||||
require.Nil(t, err)
|
||||
a, err := NewPostgresManager(pool, config)
|
||||
a, err := NewPostgresManager(db.NewDB(pool, nil), config)
|
||||
require.Nil(t, err)
|
||||
return a
|
||||
})
|
||||
@@ -1734,8 +1735,8 @@ func TestMigrationFrom4(t *testing.T) {
|
||||
require.Nil(t, a.Authorize(nil, "up", PermissionRead)) // % matches 0 or more characters
|
||||
}
|
||||
|
||||
func checkSchemaVersion(t *testing.T, db *sql.DB) {
|
||||
rows, err := db.Query(`SELECT version FROM schemaVersion`)
|
||||
func checkSchemaVersion(t *testing.T, d *db.DB) {
|
||||
rows, err := d.Query(`SELECT version FROM schemaVersion`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
|
||||
@@ -1771,7 +1772,7 @@ func newTestManagerFromConfig(t *testing.T, newManager newManagerFunc, conf *Con
|
||||
return a
|
||||
}
|
||||
|
||||
func testDB(a *Manager) *sql.DB {
|
||||
func testDB(a *Manager) *db.DB {
|
||||
return a.db
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ var (
|
||||
|
||||
// Store holds the database connection and queries for web push subscriptions.
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
db *db.DB
|
||||
queries queries
|
||||
}
|
||||
|
||||
@@ -83,7 +83,7 @@ func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string,
|
||||
|
||||
// SubscriptionsForTopic returns all subscriptions for the given topic.
|
||||
func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
|
||||
rows, err := s.db.Query(s.queries.selectSubscriptionsForTopic, topic)
|
||||
rows, err := s.db.ReadOnly().Query(s.queries.selectSubscriptionsForTopic, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -93,7 +93,7 @@ func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
|
||||
|
||||
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period.
|
||||
func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) {
|
||||
rows, err := s.db.Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
|
||||
rows, err := s.db.ReadOnly().Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
ntfydb "heckel.io/ntfy/v2/db"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -73,12 +73,12 @@ const (
|
||||
)
|
||||
|
||||
// NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool.
|
||||
func NewPostgresStore(db *sql.DB) (*Store, error) {
|
||||
if err := setupPostgres(db); err != nil {
|
||||
func NewPostgresStore(d *ntfydb.DB) (*Store, error) {
|
||||
if err := setupPostgres(d.SetupPrimary()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Store{
|
||||
db: db,
|
||||
db: d,
|
||||
queries: queries{
|
||||
selectSubscriptionIDByEndpoint: postgresSelectSubscriptionIDByEndpointQuery,
|
||||
selectSubscriptionCountBySubscriberIP: postgresSelectSubscriptionCountBySubscriberIPQuery,
|
||||
@@ -110,7 +110,7 @@ func setupPostgres(db *sql.DB) error {
|
||||
}
|
||||
|
||||
func setupNewPostgres(sqlDB *sql.DB) error {
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
return ntfydb.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -79,18 +79,18 @@ const (
|
||||
|
||||
// NewSQLiteStore creates a new SQLite-backed web push store.
|
||||
func NewSQLiteStore(filename, startupQueries string) (*Store, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
sqlDB, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLite(db); err != nil {
|
||||
if err := setupSQLite(sqlDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||
if err := runSQLiteStartupQueries(sqlDB, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Store{
|
||||
db: db,
|
||||
db: db.NewDB(sqlDB, nil),
|
||||
queries: queries{
|
||||
selectSubscriptionIDByEndpoint: sqliteSelectSubscriptionIDByEndpointQuery,
|
||||
selectSubscriptionCountBySubscriberIP: sqliteSelectSubscriptionCountBySubscriberIPQuery,
|
||||
|
||||
Reference in New Issue
Block a user