mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-18 21:30:44 +01:00
WIP: Postgres read-only replica
This commit is contained in:
@@ -40,6 +40,7 @@ var flagsServe = append(
|
|||||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cert-file", Aliases: []string{"cert_file", "E"}, EnvVars: []string{"NTFY_CERT_FILE"}, Usage: "certificate file, if listen-https is set"}),
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "cert-file", Aliases: []string{"cert_file", "E"}, EnvVars: []string{"NTFY_CERT_FILE"}, Usage: "certificate file, if listen-https is set"}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"firebase_key_file", "F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"firebase_key_file", "F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores (e.g. postgres://user:pass@host:5432/ntfy)"}),
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores (e.g. postgres://user:pass@host:5432/ntfy)"}),
|
||||||
|
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "database-replica-urls", Aliases: []string{"database_replica_urls"}, EnvVars: []string{"NTFY_DATABASE_REPLICA_URLS"}, Usage: "PostgreSQL read replica connection strings for offloading read queries"}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file", "C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file", "C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-duration", Aliases: []string{"cache_duration", "b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: util.FormatDuration(server.DefaultCacheDuration), Usage: "buffer messages for this time to allow `since` requests"}),
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-duration", Aliases: []string{"cache_duration", "b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: util.FormatDuration(server.DefaultCacheDuration), Usage: "buffer messages for this time to allow `since` requests"}),
|
||||||
altsrc.NewIntFlag(&cli.IntFlag{Name: "cache-batch-size", Aliases: []string{"cache_batch_size"}, EnvVars: []string{"NTFY_BATCH_SIZE"}, Usage: "max size of messages to batch together when writing to message cache (if zero, writes are synchronous)"}),
|
altsrc.NewIntFlag(&cli.IntFlag{Name: "cache-batch-size", Aliases: []string{"cache_batch_size"}, EnvVars: []string{"NTFY_BATCH_SIZE"}, Usage: "max size of messages to batch together when writing to message cache (if zero, writes are synchronous)"}),
|
||||||
@@ -145,6 +146,7 @@ func execServe(c *cli.Context) error {
|
|||||||
certFile := c.String("cert-file")
|
certFile := c.String("cert-file")
|
||||||
firebaseKeyFile := c.String("firebase-key-file")
|
firebaseKeyFile := c.String("firebase-key-file")
|
||||||
databaseURL := c.String("database-url")
|
databaseURL := c.String("database-url")
|
||||||
|
databaseReplicaURLs := c.StringSlice("database-replica-urls")
|
||||||
webPushPrivateKey := c.String("web-push-private-key")
|
webPushPrivateKey := c.String("web-push-private-key")
|
||||||
webPushPublicKey := c.String("web-push-public-key")
|
webPushPublicKey := c.String("web-push-public-key")
|
||||||
webPushFile := c.String("web-push-file")
|
webPushFile := c.String("web-push-file")
|
||||||
@@ -282,7 +284,9 @@ func execServe(c *cli.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check values
|
// Check values
|
||||||
if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") {
|
if len(databaseReplicaURLs) > 0 && databaseURL == "" {
|
||||||
|
return errors.New("database-replica-urls can only be used if database-url is also set")
|
||||||
|
} else if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") {
|
||||||
return errors.New("if database-url is set, auth-file, cache-file, and web-push-file must not be set")
|
return errors.New("if database-url is set, auth-file, cache-file, and web-push-file must not be set")
|
||||||
} else if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
|
} else if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
|
||||||
return errors.New("if set, FCM key file must exist")
|
return errors.New("if set, FCM key file must exist")
|
||||||
@@ -502,6 +506,7 @@ func execServe(c *cli.Context) error {
|
|||||||
conf.MetricsListenHTTP = metricsListenHTTP
|
conf.MetricsListenHTTP = metricsListenHTTP
|
||||||
conf.ProfileListenHTTP = profileListenHTTP
|
conf.ProfileListenHTTP = profileListenHTTP
|
||||||
conf.DatabaseURL = databaseURL
|
conf.DatabaseURL = databaseURL
|
||||||
|
conf.DatabaseReplicaURLs = databaseReplicaURLs
|
||||||
conf.WebPushPrivateKey = webPushPrivateKey
|
conf.WebPushPrivateKey = webPushPrivateKey
|
||||||
conf.WebPushPublicKey = webPushPublicKey
|
conf.WebPushPublicKey = webPushPublicKey
|
||||||
conf.WebPushFile = webPushFile
|
conf.WebPushFile = webPushFile
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
"github.com/urfave/cli/v2/altsrc"
|
"github.com/urfave/cli/v2/altsrc"
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
"heckel.io/ntfy/v2/db/pg"
|
"heckel.io/ntfy/v2/db/pg"
|
||||||
"heckel.io/ntfy/v2/server"
|
"heckel.io/ntfy/v2/server"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
@@ -383,7 +384,7 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
|||||||
if dbErr != nil {
|
if dbErr != nil {
|
||||||
return nil, dbErr
|
return nil, dbErr
|
||||||
}
|
}
|
||||||
return user.NewPostgresManager(pool, authConfig)
|
return user.NewPostgresManager(db.NewDB(pool, nil), authConfig)
|
||||||
} else if authFile != "" {
|
} else if authFile != "" {
|
||||||
if !util.FileExists(authFile) {
|
if !util.FileExists(authFile) {
|
||||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||||
|
|||||||
122
db/db.go
122
db/db.go
@@ -2,11 +2,129 @@ package db
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
replicaHealthCheckInterval = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Beginner is an interface for types that can begin a database transaction.
|
||||||
|
// Both *sql.DB and *DB implement this.
|
||||||
|
type Beginner interface {
|
||||||
|
Begin() (*sql.Tx, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DB wraps a primary *sql.DB and optional read replicas. All standard query/exec methods
|
||||||
|
// delegate to the primary. The ReadOnly() method returns a *sql.DB from a healthy replica
|
||||||
|
// (round-robin), falling back to the primary if no replicas are configured or all are unhealthy.
|
||||||
|
type DB struct {
|
||||||
|
primary *sql.DB
|
||||||
|
replicas []*replica
|
||||||
|
counter atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type replica struct {
|
||||||
|
db *sql.DB
|
||||||
|
healthy atomic.Bool
|
||||||
|
lastChecked atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDB creates a new DB that wraps the given primary and optional replica connections.
|
||||||
|
// If replicas is nil or empty, ReadOnly() simply returns the primary.
|
||||||
|
func NewDB(primary *sql.DB, replicas []*sql.DB) *DB {
|
||||||
|
d := &DB{
|
||||||
|
primary: primary,
|
||||||
|
replicas: make([]*replica, len(replicas)),
|
||||||
|
}
|
||||||
|
for i, r := range replicas {
|
||||||
|
rep := &replica{db: r}
|
||||||
|
rep.healthy.Store(true)
|
||||||
|
d.replicas[i] = rep
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query delegates to the primary database.
|
||||||
|
func (d *DB) Query(query string, args ...any) (*sql.Rows, error) {
|
||||||
|
return d.primary.Query(query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRow delegates to the primary database.
|
||||||
|
func (d *DB) QueryRow(query string, args ...any) *sql.Row {
|
||||||
|
return d.primary.QueryRow(query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec delegates to the primary database.
|
||||||
|
func (d *DB) Exec(query string, args ...any) (sql.Result, error) {
|
||||||
|
return d.primary.Exec(query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Begin delegates to the primary database.
|
||||||
|
func (d *DB) Begin() (*sql.Tx, error) {
|
||||||
|
return d.primary.Begin()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ping delegates to the primary database.
|
||||||
|
func (d *DB) Ping() error {
|
||||||
|
return d.primary.Ping()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the primary database and all replicas.
|
||||||
|
func (d *DB) Close() error {
|
||||||
|
for _, r := range d.replicas {
|
||||||
|
r.db.Close()
|
||||||
|
}
|
||||||
|
return d.primary.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupPrimary returns the underlying primary *sql.DB. This is only intended for
|
||||||
|
// one-time schema setup during store initialization, not for regular queries.
|
||||||
|
func (d *DB) SetupPrimary() *sql.DB {
|
||||||
|
return d.primary
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadOnly returns a *sql.DB suitable for read-only queries. It round-robins across healthy
|
||||||
|
// replicas. If a replica's health status is stale (older than replicaHealthCheckInterval), it
|
||||||
|
// is re-checked with a ping. If all replicas are unhealthy or none are configured, the primary
|
||||||
|
// is returned.
|
||||||
|
func (d *DB) ReadOnly() *sql.DB {
|
||||||
|
if len(d.replicas) == 0 {
|
||||||
|
return d.primary
|
||||||
|
}
|
||||||
|
n := len(d.replicas)
|
||||||
|
start := int(d.counter.Add(1) - 1)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
r := d.replicas[(start+i)%n]
|
||||||
|
if d.isHealthy(r) {
|
||||||
|
return r.db
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return d.primary
|
||||||
|
}
|
||||||
|
|
||||||
|
// isHealthy returns whether the replica is healthy. If the cached health status is stale,
|
||||||
|
// it pings the replica and updates the cache.
|
||||||
|
func (d *DB) isHealthy(r *replica) bool {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
lastChecked := r.lastChecked.Load()
|
||||||
|
if now-lastChecked >= int64(replicaHealthCheckInterval.Seconds()) {
|
||||||
|
if r.lastChecked.CompareAndSwap(lastChecked, now) {
|
||||||
|
if err := r.db.Ping(); err != nil {
|
||||||
|
r.healthy.Store(false)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
r.healthy.Store(true)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r.healthy.Load()
|
||||||
|
}
|
||||||
|
|
||||||
// ExecTx executes a function within a database transaction. If the function returns an error,
|
// ExecTx executes a function within a database transaction. If the function returns an error,
|
||||||
// the transaction is rolled back. Otherwise, the transaction is committed.
|
// the transaction is rolled back. Otherwise, the transaction is committed.
|
||||||
func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
func ExecTx(db Beginner, f func(tx *sql.Tx) error) error {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -20,7 +138,7 @@ func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
|||||||
|
|
||||||
// QueryTx executes a function within a database transaction and returns the result. If the function
|
// QueryTx executes a function within a database transaction and returns the result. If the function
|
||||||
// returns an error, the transaction is rolled back. Otherwise, the transaction is committed.
|
// returns an error, the transaction is rolled back. Otherwise, the transaction is committed.
|
||||||
func QueryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
|
func QueryTx[T any](db Beginner, f func(tx *sql.Tx) (T, error)) (T, error) {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var zero T
|
var zero T
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
package dbtest
|
package dbtest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
"heckel.io/ntfy/v2/db/pg"
|
"heckel.io/ntfy/v2/db/pg"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
@@ -48,16 +48,17 @@ func CreateTestPostgresSchema(t *testing.T) string {
|
|||||||
return schemaDSN
|
return schemaDSN
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTestPostgres creates a temporary PostgreSQL schema and returns an open *sql.DB connection to it.
|
// CreateTestPostgres creates a temporary PostgreSQL schema and returns an open *db.DB connection to it.
|
||||||
// It registers cleanup functions to close the DB and drop the schema when the test finishes.
|
// It registers cleanup functions to close the DB and drop the schema when the test finishes.
|
||||||
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
|
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
|
||||||
func CreateTestPostgres(t *testing.T) *sql.DB {
|
func CreateTestPostgres(t *testing.T) *db.DB {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
schemaDSN := CreateTestPostgresSchema(t)
|
schemaDSN := CreateTestPostgresSchema(t)
|
||||||
testDB, err := pg.Open(schemaDSN)
|
testDB, err := pg.Open(schemaDSN)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
d := db.NewDB(testDB, nil)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
testDB.Close()
|
d.Close()
|
||||||
})
|
})
|
||||||
return testDB
|
return d
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,14 +50,14 @@ type queries struct {
|
|||||||
|
|
||||||
// Cache stores published messages
|
// Cache stores published messages
|
||||||
type Cache struct {
|
type Cache struct {
|
||||||
db *sql.DB
|
db *db.DB
|
||||||
queue *util.BatchingQueue[*model.Message]
|
queue *util.BatchingQueue[*model.Message]
|
||||||
nop bool
|
nop bool
|
||||||
mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer)
|
mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer)
|
||||||
queries queries
|
queries queries
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCache(db *sql.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache {
|
func newCache(db *db.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache {
|
||||||
var queue *util.BatchingQueue[*model.Message]
|
var queue *util.BatchingQueue[*model.Message]
|
||||||
if batchSize > 0 || batchTimeout > 0 {
|
if batchSize > 0 || batchTimeout > 0 {
|
||||||
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
|
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
|
||||||
@@ -201,10 +201,11 @@ func (c *Cache) Messages(topic string, since model.SinceMarker, scheduled bool)
|
|||||||
func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var err error
|
var err error
|
||||||
|
rdb := c.db.ReadOnly()
|
||||||
if scheduled {
|
if scheduled {
|
||||||
rows, err = c.db.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
|
rows, err = rdb.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
|
||||||
} else {
|
} else {
|
||||||
rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
|
rows, err = rdb.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -215,10 +216,11 @@ func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, schedul
|
|||||||
func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var err error
|
var err error
|
||||||
|
rdb := c.db.ReadOnly()
|
||||||
if scheduled {
|
if scheduled {
|
||||||
rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID())
|
rows, err = rdb.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID())
|
||||||
} else {
|
} else {
|
||||||
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, since.ID())
|
rows, err = rdb.Query(c.queries.selectMessagesSinceID, topic, since.ID())
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -227,7 +229,7 @@ func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) messagesLatest(topic string) ([]*model.Message, error) {
|
func (c *Cache) messagesLatest(topic string) ([]*model.Message, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
|
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesLatest, topic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -266,7 +268,7 @@ func (c *Cache) MessagesExpired() ([]string, error) {
|
|||||||
|
|
||||||
// Message returns the message with the given ID, or ErrMessageNotFound if not found
|
// Message returns the message with the given ID, or ErrMessageNotFound if not found
|
||||||
func (c *Cache) Message(id string) (*model.Message, error) {
|
func (c *Cache) Message(id string) (*model.Message, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
|
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesByID, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -295,7 +297,7 @@ func (c *Cache) MarkPublished(m *model.Message) error {
|
|||||||
|
|
||||||
// MessagesCount returns the total number of messages in the cache
|
// MessagesCount returns the total number of messages in the cache
|
||||||
func (c *Cache) MessagesCount() (int, error) {
|
func (c *Cache) MessagesCount() (int, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectMessagesCount)
|
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -312,7 +314,7 @@ func (c *Cache) MessagesCount() (int, error) {
|
|||||||
|
|
||||||
// Topics returns a list of all topics with messages in the cache
|
// Topics returns a list of all topics with messages in the cache
|
||||||
func (c *Cache) Topics() ([]string, error) {
|
func (c *Cache) Topics() ([]string, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectTopics)
|
rows, err := c.db.ReadOnly().Query(c.queries.selectTopics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -426,7 +428,7 @@ func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
|
|||||||
|
|
||||||
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
|
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
|
||||||
func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -435,7 +437,7 @@ func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
|||||||
|
|
||||||
// AttachmentBytesUsedByUser returns the total size of active attachments for the given user
|
// AttachmentBytesUsedByUser returns the total size of active attachments for the given user
|
||||||
func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -466,7 +468,7 @@ func (c *Cache) UpdateStats(messages int64) error {
|
|||||||
|
|
||||||
// Stats returns the total message count statistic
|
// Stats returns the total message count statistic
|
||||||
func (c *Cache) Stats() (messages int64, err error) {
|
func (c *Cache) Stats() (messages int64, err error) {
|
||||||
rows, err := c.db.Query(c.queries.selectStats)
|
rows, err := c.db.ReadOnly().Query(c.queries.selectStats)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
package message
|
package message
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PostgreSQL runtime query constants
|
// PostgreSQL runtime query constants
|
||||||
@@ -102,9 +103,9 @@ var postgresQueries = queries{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.
|
// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.
|
||||||
func NewPostgresStore(db *sql.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) {
|
func NewPostgresStore(d *db.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) {
|
||||||
if err := setupPostgres(db); err != nil {
|
if err := setupPostgres(d.SetupPrimary()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return newCache(db, postgresQueries, nil, batchSize, batchTimeout, false), nil
|
return newCache(d, postgresQueries, nil, batchSize, batchTimeout, false), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -110,14 +111,14 @@ func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration
|
|||||||
if !util.FileExists(parentDir) {
|
if !util.FileExists(parentDir) {
|
||||||
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
|
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
|
||||||
}
|
}
|
||||||
db, err := sql.Open("sqlite3", filename)
|
sqlDB, err := sql.Open("sqlite3", filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := setupSQLite(db, startupQueries, cacheDuration); err != nil {
|
if err := setupSQLite(sqlDB, startupQueries, cacheDuration); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return newCache(db, sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil
|
return newCache(db.NewDB(sqlDB, nil), sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMemStore creates an in-memory cache
|
// NewMemStore creates an in-memory cache
|
||||||
|
|||||||
@@ -95,7 +95,8 @@ type Config struct {
|
|||||||
ListenUnixMode fs.FileMode
|
ListenUnixMode fs.FileMode
|
||||||
KeyFile string
|
KeyFile string
|
||||||
CertFile string
|
CertFile string
|
||||||
DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy")
|
DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy")
|
||||||
|
DatabaseReplicaURLs []string // PostgreSQL read replica connection strings
|
||||||
FirebaseKeyFile string
|
FirebaseKeyFile string
|
||||||
CacheFile string
|
CacheFile string
|
||||||
CacheDuration time.Duration
|
CacheDuration time.Duration
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ import (
|
|||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
"heckel.io/ntfy/v2/db/pg"
|
"heckel.io/ntfy/v2/db/pg"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
"heckel.io/ntfy/v2/message"
|
"heckel.io/ntfy/v2/message"
|
||||||
@@ -47,7 +48,7 @@ import (
|
|||||||
// Server is the main server, providing the UI and API for ntfy
|
// Server is the main server, providing the UI and API for ntfy
|
||||||
type Server struct {
|
type Server struct {
|
||||||
config *Config
|
config *Config
|
||||||
db *sql.DB // Shared PostgreSQL connection pool, nil when using SQLite
|
db *db.DB // Shared PostgreSQL connection pool (with optional replicas), nil when using SQLite
|
||||||
httpServer *http.Server
|
httpServer *http.Server
|
||||||
httpsServer *http.Server
|
httpsServer *http.Server
|
||||||
httpMetricsServer *http.Server
|
httpMetricsServer *http.Server
|
||||||
@@ -179,13 +180,26 @@ func New(conf *Config) (*Server, error) {
|
|||||||
stripe = newStripeAPI()
|
stripe = newStripeAPI()
|
||||||
}
|
}
|
||||||
// Open shared PostgreSQL connection pool if configured
|
// Open shared PostgreSQL connection pool if configured
|
||||||
var pool *sql.DB
|
var pool *db.DB
|
||||||
if conf.DatabaseURL != "" {
|
if conf.DatabaseURL != "" {
|
||||||
var err error
|
primary, err := pg.Open(conf.DatabaseURL)
|
||||||
pool, err = pg.Open(conf.DatabaseURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
var replicas []*sql.DB
|
||||||
|
for _, replicaURL := range conf.DatabaseReplicaURLs {
|
||||||
|
r, err := pg.Open(replicaURL)
|
||||||
|
if err != nil {
|
||||||
|
// Close already-opened replicas before returning
|
||||||
|
for _, opened := range replicas {
|
||||||
|
opened.Close()
|
||||||
|
}
|
||||||
|
primary.Close()
|
||||||
|
return nil, fmt.Errorf("failed to open database replica: %w", err)
|
||||||
|
}
|
||||||
|
replicas = append(replicas, r)
|
||||||
|
}
|
||||||
|
pool = db.NewDB(primary, replicas)
|
||||||
}
|
}
|
||||||
messageCache, err := createMessageCache(conf, pool)
|
messageCache, err := createMessageCache(conf, pool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -277,7 +291,7 @@ func New(conf *Config) (*Server, error) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createMessageCache(conf *Config, pool *sql.DB) (*message.Cache, error) {
|
func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) {
|
||||||
if conf.CacheDuration == 0 {
|
if conf.CacheDuration == 0 {
|
||||||
return message.NewNopStore()
|
return message.NewNopStore()
|
||||||
} else if pool != nil {
|
} else if pool != nil {
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ var (
|
|||||||
// Manager handles user authentication, authorization, and management
|
// Manager handles user authentication, authorization, and management
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
config *Config
|
config *Config
|
||||||
db *sql.DB
|
db *db.DB
|
||||||
queries queries
|
queries queries
|
||||||
statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
|
statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
|
||||||
tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
|
tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
|
||||||
@@ -58,7 +58,7 @@ type Manager struct {
|
|||||||
|
|
||||||
var _ Auther = (*Manager)(nil)
|
var _ Auther = (*Manager)(nil)
|
||||||
|
|
||||||
func newManager(db *sql.DB, queries queries, config *Config) (*Manager, error) {
|
func newManager(d *db.DB, queries queries, config *Config) (*Manager, error) {
|
||||||
if config.BcryptCost <= 0 {
|
if config.BcryptCost <= 0 {
|
||||||
config.BcryptCost = DefaultUserPasswordBcryptCost
|
config.BcryptCost = DefaultUserPasswordBcryptCost
|
||||||
}
|
}
|
||||||
@@ -67,7 +67,7 @@ func newManager(db *sql.DB, queries queries, config *Config) (*Manager, error) {
|
|||||||
}
|
}
|
||||||
manager := &Manager{
|
manager := &Manager{
|
||||||
config: config,
|
config: config,
|
||||||
db: db,
|
db: d,
|
||||||
statsQueue: make(map[string]*Stats),
|
statsQueue: make(map[string]*Stats),
|
||||||
tokenQueue: make(map[string]*TokenUpdate),
|
tokenQueue: make(map[string]*TokenUpdate),
|
||||||
queries: queries,
|
queries: queries,
|
||||||
@@ -388,7 +388,7 @@ func (a *Manager) writeUserStatsQueue() error {
|
|||||||
|
|
||||||
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
|
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
|
||||||
func (a *Manager) User(username string) (*User, error) {
|
func (a *Manager) User(username string) (*User, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectUserByName, username)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByName, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -397,7 +397,7 @@ func (a *Manager) User(username string) (*User, error) {
|
|||||||
|
|
||||||
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
|
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
|
||||||
func (a *Manager) UserByID(id string) (*User, error) {
|
func (a *Manager) UserByID(id string) (*User, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectUserByID, id)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByID, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -406,7 +406,7 @@ func (a *Manager) UserByID(id string) (*User, error) {
|
|||||||
|
|
||||||
// userByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise
|
// userByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise
|
||||||
func (a *Manager) userByToken(token string) (*User, error) {
|
func (a *Manager) userByToken(token string) (*User, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectUserByToken, token, time.Now().Unix())
|
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByToken, token, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -415,7 +415,7 @@ func (a *Manager) userByToken(token string) (*User, error) {
|
|||||||
|
|
||||||
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
|
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
|
||||||
func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
|
func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectUserByStripeCustomerID, customerID)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByStripeCustomerID, customerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -425,7 +425,7 @@ func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
|
|||||||
// Users returns a list of users. It loads all users in a single query
|
// Users returns a list of users. It loads all users in a single query
|
||||||
// rather than one query per user to avoid N+1 performance issues.
|
// rather than one query per user to avoid N+1 performance issues.
|
||||||
func (a *Manager) Users() ([]*User, error) {
|
func (a *Manager) Users() ([]*User, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectUsers)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectUsers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -434,7 +434,7 @@ func (a *Manager) Users() ([]*User, error) {
|
|||||||
|
|
||||||
// UsersCount returns the number of users in the database
|
// UsersCount returns the number of users in the database
|
||||||
func (a *Manager) UsersCount() (int64, error) {
|
func (a *Manager) UsersCount() (int64, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectUserCount)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectUserCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -642,7 +642,7 @@ func (a *Manager) AllowReservation(username string, topic string) error {
|
|||||||
// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
|
// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
|
||||||
// - It also prioritizes write permissions over read permissions
|
// - It also prioritizes write permissions over read permissions
|
||||||
func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
|
func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
|
||||||
rows, err := a.db.Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, false, false, err
|
return false, false, false, err
|
||||||
}
|
}
|
||||||
@@ -660,7 +660,7 @@ func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read,
|
|||||||
|
|
||||||
// AllGrants returns all user-specific access control entries, mapped to their respective user IDs
|
// AllGrants returns all user-specific access control entries, mapped to their respective user IDs
|
||||||
func (a *Manager) AllGrants() (map[string][]Grant, error) {
|
func (a *Manager) AllGrants() (map[string][]Grant, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectUserAllAccess)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectUserAllAccess)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -688,7 +688,7 @@ func (a *Manager) AllGrants() (map[string][]Grant, error) {
|
|||||||
|
|
||||||
// Grants returns all user-specific access control entries
|
// Grants returns all user-specific access control entries
|
||||||
func (a *Manager) Grants(username string) ([]Grant, error) {
|
func (a *Manager) Grants(username string) ([]Grant, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectUserAccess, username)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectUserAccess, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -753,7 +753,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
|
|||||||
|
|
||||||
// Reservations returns all user-owned topics, and the associated everyone-access
|
// Reservations returns all user-owned topics, and the associated everyone-access
|
||||||
func (a *Manager) Reservations(username string) ([]Reservation, error) {
|
func (a *Manager) Reservations(username string) ([]Reservation, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectUserReservations, Everyone, username)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservations, Everyone, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -779,7 +779,7 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) {
|
|||||||
|
|
||||||
// HasReservation returns true if the given topic access is owned by the user
|
// HasReservation returns true if the given topic access is owned by the user
|
||||||
func (a *Manager) HasReservation(username, topic string) (bool, error) {
|
func (a *Manager) HasReservation(username, topic string) (bool, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
rows, err := a.db.ReadOnly().Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -796,7 +796,7 @@ func (a *Manager) HasReservation(username, topic string) (bool, error) {
|
|||||||
|
|
||||||
// ReservationsCount returns the number of reservations owned by this user
|
// ReservationsCount returns the number of reservations owned by this user
|
||||||
func (a *Manager) ReservationsCount(username string) (int64, error) {
|
func (a *Manager) ReservationsCount(username string) (int64, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectUserReservationsCount, username)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsCount, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -813,7 +813,7 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
|
|||||||
|
|
||||||
// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone
|
// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone
|
||||||
func (a *Manager) ReservationOwner(topic string) (string, error) {
|
func (a *Manager) ReservationOwner(topic string) (string, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic))
|
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -830,7 +830,7 @@ func (a *Manager) ReservationOwner(topic string) (string, error) {
|
|||||||
|
|
||||||
// otherAccessCount returns the number of access entries for the given topic that are not owned by the user
|
// otherAccessCount returns the number of access entries for the given topic that are not owned by the user
|
||||||
func (a *Manager) otherAccessCount(username, topic string) (int, error) {
|
func (a *Manager) otherAccessCount(username, topic string) (int, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -962,7 +962,7 @@ func (a *Manager) canChangeToken(userID, token string) error {
|
|||||||
|
|
||||||
// Token returns a specific token for a user
|
// Token returns a specific token for a user
|
||||||
func (a *Manager) Token(userID, token string) (*Token, error) {
|
func (a *Manager) Token(userID, token string) (*Token, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectToken, userID, token)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectToken, userID, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -972,7 +972,7 @@ func (a *Manager) Token(userID, token string) (*Token, error) {
|
|||||||
|
|
||||||
// Tokens returns all existing tokens for the user with the given user ID
|
// Tokens returns all existing tokens for the user with the given user ID
|
||||||
func (a *Manager) Tokens(userID string) ([]*Token, error) {
|
func (a *Manager) Tokens(userID string) ([]*Token, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectTokens, userID)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectTokens, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -991,7 +991,7 @@ func (a *Manager) Tokens(userID string) ([]*Token, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Manager) allProvisionedTokens() ([]*Token, error) {
|
func (a *Manager) allProvisionedTokens() ([]*Token, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectAllProvisionedTokens)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectAllProvisionedTokens)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1114,7 +1114,7 @@ func (a *Manager) RemoveTier(code string) error {
|
|||||||
|
|
||||||
// Tiers returns a list of all Tier structs
|
// Tiers returns a list of all Tier structs
|
||||||
func (a *Manager) Tiers() ([]*Tier, error) {
|
func (a *Manager) Tiers() ([]*Tier, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectTiers)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectTiers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1134,7 +1134,7 @@ func (a *Manager) Tiers() ([]*Tier, error) {
|
|||||||
|
|
||||||
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
|
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
|
||||||
func (a *Manager) Tier(code string) (*Tier, error) {
|
func (a *Manager) Tier(code string) (*Tier, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectTierByCode, code)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByCode, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1144,7 +1144,7 @@ func (a *Manager) Tier(code string) (*Tier, error) {
|
|||||||
|
|
||||||
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
|
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
|
||||||
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
|
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectTierByPriceID, priceID, priceID)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByPriceID, priceID, priceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1185,7 +1185,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
|
|||||||
|
|
||||||
// PhoneNumbers returns all phone numbers for the user with the given user ID
|
// PhoneNumbers returns all phone numbers for the user with the given user ID
|
||||||
func (a *Manager) PhoneNumbers(userID string) ([]string, error) {
|
func (a *Manager) PhoneNumbers(userID string) ([]string, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectPhoneNumbers, userID)
|
rows, err := a.db.ReadOnly().Query(a.queries.selectPhoneNumbers, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package user
|
package user
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"heckel.io/ntfy/v2/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PostgreSQL queries
|
// PostgreSQL queries
|
||||||
@@ -278,9 +278,9 @@ var postgresQueries = queries{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewPostgresManager creates a new Manager backed by a PostgreSQL database
|
// NewPostgresManager creates a new Manager backed by a PostgreSQL database
|
||||||
func NewPostgresManager(db *sql.DB, config *Config) (*Manager, error) {
|
func NewPostgresManager(d *db.DB, config *Config) (*Manager, error) {
|
||||||
if err := setupPostgres(db); err != nil {
|
if err := setupPostgres(d.SetupPrimary()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return newManager(db, postgresQueries, config)
|
return newManager(d, postgresQueries, config)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -280,15 +281,15 @@ func NewSQLiteManager(filename, startupQueries string, config *Config) (*Manager
|
|||||||
if !util.FileExists(parentDir) {
|
if !util.FileExists(parentDir) {
|
||||||
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
|
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
|
||||||
}
|
}
|
||||||
db, err := sql.Open("sqlite3", filename)
|
sqlDB, err := sql.Open("sqlite3", filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := setupSQLite(db); err != nil {
|
if err := setupSQLite(sqlDB); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
if err := runSQLiteStartupQueries(sqlDB, startupQueries); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return newManager(db, sqliteQueries, config)
|
return newManager(db.NewDB(sqlDB, nil), sqliteQueries, config)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
"heckel.io/ntfy/v2/db/pg"
|
"heckel.io/ntfy/v2/db/pg"
|
||||||
dbtest "heckel.io/ntfy/v2/db/test"
|
dbtest "heckel.io/ntfy/v2/db/test"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
@@ -38,7 +39,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc
|
|||||||
f(t, func(config *Config) *Manager {
|
f(t, func(config *Config) *Manager {
|
||||||
pool, err := pg.Open(schemaDSN)
|
pool, err := pg.Open(schemaDSN)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
a, err := NewPostgresManager(pool, config)
|
a, err := NewPostgresManager(db.NewDB(pool, nil), config)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
return a
|
return a
|
||||||
})
|
})
|
||||||
@@ -1734,8 +1735,8 @@ func TestMigrationFrom4(t *testing.T) {
|
|||||||
require.Nil(t, a.Authorize(nil, "up", PermissionRead)) // % matches 0 or more characters
|
require.Nil(t, a.Authorize(nil, "up", PermissionRead)) // % matches 0 or more characters
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkSchemaVersion(t *testing.T, db *sql.DB) {
|
func checkSchemaVersion(t *testing.T, d *db.DB) {
|
||||||
rows, err := db.Query(`SELECT version FROM schemaVersion`)
|
rows, err := d.Query(`SELECT version FROM schemaVersion`)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.True(t, rows.Next())
|
require.True(t, rows.Next())
|
||||||
|
|
||||||
@@ -1771,7 +1772,7 @@ func newTestManagerFromConfig(t *testing.T, newManager newManagerFunc, conf *Con
|
|||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
func testDB(a *Manager) *sql.DB {
|
func testDB(a *Manager) *db.DB {
|
||||||
return a.db
|
return a.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ var (
|
|||||||
|
|
||||||
// Store holds the database connection and queries for web push subscriptions.
|
// Store holds the database connection and queries for web push subscriptions.
|
||||||
type Store struct {
|
type Store struct {
|
||||||
db *sql.DB
|
db *db.DB
|
||||||
queries queries
|
queries queries
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,7 +83,7 @@ func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string,
|
|||||||
|
|
||||||
// SubscriptionsForTopic returns all subscriptions for the given topic.
|
// SubscriptionsForTopic returns all subscriptions for the given topic.
|
||||||
func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
|
func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
|
||||||
rows, err := s.db.Query(s.queries.selectSubscriptionsForTopic, topic)
|
rows, err := s.db.ReadOnly().Query(s.queries.selectSubscriptionsForTopic, topic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -93,7 +93,7 @@ func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
|
|||||||
|
|
||||||
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period.
|
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period.
|
||||||
func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) {
|
func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) {
|
||||||
rows, err := s.db.Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
|
rows, err := s.db.ReadOnly().Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"heckel.io/ntfy/v2/db"
|
ntfydb "heckel.io/ntfy/v2/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -73,12 +73,12 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool.
|
// NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool.
|
||||||
func NewPostgresStore(db *sql.DB) (*Store, error) {
|
func NewPostgresStore(d *ntfydb.DB) (*Store, error) {
|
||||||
if err := setupPostgres(db); err != nil {
|
if err := setupPostgres(d.SetupPrimary()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Store{
|
return &Store{
|
||||||
db: db,
|
db: d,
|
||||||
queries: queries{
|
queries: queries{
|
||||||
selectSubscriptionIDByEndpoint: postgresSelectSubscriptionIDByEndpointQuery,
|
selectSubscriptionIDByEndpoint: postgresSelectSubscriptionIDByEndpointQuery,
|
||||||
selectSubscriptionCountBySubscriberIP: postgresSelectSubscriptionCountBySubscriberIPQuery,
|
selectSubscriptionCountBySubscriberIP: postgresSelectSubscriptionCountBySubscriberIPQuery,
|
||||||
@@ -110,7 +110,7 @@ func setupPostgres(db *sql.DB) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setupNewPostgres(sqlDB *sql.DB) error {
|
func setupNewPostgres(sqlDB *sql.DB) error {
|
||||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
return ntfydb.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||||
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,18 +79,18 @@ const (
|
|||||||
|
|
||||||
// NewSQLiteStore creates a new SQLite-backed web push store.
|
// NewSQLiteStore creates a new SQLite-backed web push store.
|
||||||
func NewSQLiteStore(filename, startupQueries string) (*Store, error) {
|
func NewSQLiteStore(filename, startupQueries string) (*Store, error) {
|
||||||
db, err := sql.Open("sqlite3", filename)
|
sqlDB, err := sql.Open("sqlite3", filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := setupSQLite(db); err != nil {
|
if err := setupSQLite(sqlDB); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
if err := runSQLiteStartupQueries(sqlDB, startupQueries); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Store{
|
return &Store{
|
||||||
db: db,
|
db: db.NewDB(sqlDB, nil),
|
||||||
queries: queries{
|
queries: queries{
|
||||||
selectSubscriptionIDByEndpoint: sqliteSelectSubscriptionIDByEndpointQuery,
|
selectSubscriptionIDByEndpoint: sqliteSelectSubscriptionIDByEndpointQuery,
|
||||||
selectSubscriptionCountBySubscriberIP: sqliteSelectSubscriptionCountBySubscriberIPQuery,
|
selectSubscriptionCountBySubscriberIP: sqliteSelectSubscriptionCountBySubscriberIPQuery,
|
||||||
|
|||||||
Reference in New Issue
Block a user