Compare commits

...

4 Commits

Author SHA1 Message Date
binwiederhier
8b9f23f2e0 Derp 2025-12-30 22:01:57 -05:00
binwiederhier
4bf4899c08 Switch lib 2025-12-30 21:33:22 -05:00
binwiederhier
c6d9559a6c Merge branch 'main' of github.com:binwiederhier/ntfy into postgres 2025-12-30 20:55:57 -05:00
binwiederhier
e595d3ae15 Postgres 2025-12-27 21:55:30 -05:00
18 changed files with 1980 additions and 1294 deletions

4
go.mod
View File

@@ -32,6 +32,7 @@ require github.com/pkg/errors v0.9.1 // indirect
require ( require (
firebase.google.com/go/v4 v4.18.0 firebase.google.com/go/v4 v4.18.0
github.com/SherClockHolmes/webpush-go v1.4.0 github.com/SherClockHolmes/webpush-go v1.4.0
github.com/jackc/pgx/v5 v5.8.0
github.com/microcosm-cc/bluemonday v1.0.27 github.com/microcosm-cc/bluemonday v1.0.27
github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_golang v1.23.2
github.com/stripe/stripe-go/v74 v74.30.0 github.com/stripe/stripe-go/v74 v74.30.0
@@ -72,6 +73,9 @@ require (
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
github.com/googleapis/gax-go/v2 v2.16.0 // indirect github.com/googleapis/gax-go/v2 v2.16.0 // indirect
github.com/gorilla/css v1.0.1 // indirect github.com/gorilla/css v1.0.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect

9
go.sum
View File

@@ -104,6 +104,14 @@ github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -144,6 +152,7 @@ github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xI
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=

View File

@@ -4,351 +4,89 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"net/netip" "net/netip"
"path/filepath"
"strings" "strings"
"sync" "sync"
"time" "time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util" "heckel.io/ntfy/v2/util"
) )
var ( // MessageCache is the interface for message storage backends
errUnexpectedMessageType = errors.New("unexpected message type") type MessageCache interface {
errMessageNotFound = errors.New("message not found") AddMessage(m *message) error
errNoRows = errors.New("no rows found") AddMessages(ms []*message) error
) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error)
MessagesDue() ([]*message, error)
// Messages cache MessagesExpired() ([]string, error)
const ( Message(id string) (*message, error)
createMessagesTableQuery = ` MarkPublished(m *message) error
BEGIN; MessageCounts() (map[string]int, error)
CREATE TABLE IF NOT EXISTS messages ( Topics() (map[string]*topic, error)
id INTEGER PRIMARY KEY AUTOINCREMENT, DeleteMessages(ids ...string) error
mid TEXT NOT NULL, ExpireMessages(topics ...string) error
time INT NOT NULL, AttachmentsExpired() ([]string, error)
expires INT NOT NULL, MarkAttachmentsDeleted(ids ...string) error
topic TEXT NOT NULL, AttachmentBytesUsedBySender(sender string) (int64, error)
message TEXT NOT NULL, AttachmentBytesUsedByUser(userID string) (int64, error)
title TEXT NOT NULL, UpdateStats(messages int64) error
priority INT NOT NULL, Stats() (messages int64, err error)
tags TEXT NOT NULL, DB() *sql.DB
click TEXT NOT NULL, Close() error
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_deleted INT NOT NULL,
sender TEXT NOT NULL,
user TEXT NOT NULL,
content_type TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
COMMIT;
`
insertMessageQuery = `
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
deleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
selectMessagesByIDQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE mid = ?
`
selectMessagesSinceTimeQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ? AND published = 1
ORDER BY time, id
`
selectMessagesSinceTimeIncludeScheduledQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ?
ORDER BY time, id
`
selectMessagesSinceIDQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND id > ? AND published = 1
ORDER BY time, id
`
selectMessagesSinceIDIncludeScheduledQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND (id > ? OR published = 0)
ORDER BY time, id
`
selectMessagesLatestQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND published = 1
ORDER BY time DESC, id DESC
LIMIT 1
`
selectMessagesDueQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE time <= ? AND published = 0
ORDER BY time, id
`
selectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
selectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
updateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
)
// Schema management queries
const (
currentSchemaVersion = 13
createSchemaVersionTableQuery = `
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
`
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
// 0 -> 1
migrate0To1AlterMessagesTableQuery = `
BEGIN;
ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0);
ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT('');
COMMIT;
`
// 1 -> 2
migrate1To2AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1);
`
// 2 -> 3
migrate2To3AlterMessagesTableQuery = `
BEGIN;
ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT('');
COMMIT;
`
// 3 -> 4
migrate3To4AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT('');
`
// 4 -> 5
migrate4To5AlterMessagesTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages_new (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
time INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_owner TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid);
CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic);
INSERT
INTO messages_new (
mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published)
SELECT
id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published
FROM messages;
DROP TABLE messages;
ALTER TABLE messages_new RENAME TO messages;
COMMIT;
`
// 5 -> 6
migrate5To6AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT('');
`
// 6 -> 7
migrate6To7AlterMessagesTableQuery = `
ALTER TABLE messages RENAME COLUMN attachment_owner TO sender;
`
// 7 -> 8
migrate7To8AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT('');
`
// 8 -> 9
migrate8To9AlterMessagesTableQuery = `
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
`
// 9 -> 10
migrate9To10AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
`
migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
// 10 -> 11
migrate10To11AlterMessagesTableQuery = `
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
`
// 11 -> 12
migrate11To12AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
`
// 12 -> 13
migrate12To13AlterMessagesTableQuery = `
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
`
)
var (
migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
0: migrateFrom0,
1: migrateFrom1,
2: migrateFrom2,
3: migrateFrom3,
4: migrateFrom4,
5: migrateFrom5,
6: migrateFrom6,
7: migrateFrom7,
8: migrateFrom8,
9: migrateFrom9,
10: migrateFrom10,
11: migrateFrom11,
12: migrateFrom12,
}
)
type messageCache struct {
db *sql.DB
queue *util.BatchingQueue[*message]
nop bool
mu sync.Mutex
} }
// newSqliteCache creates a SQLite file-backed cache // commonMessageCache contains shared logic for all message cache implementations
func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*messageCache, error) { type commonMessageCache struct {
// Check the parent directory of the database file (makes for friendly error messages) db *sql.DB
parentDir := filepath.Dir(filename) queue *util.BatchingQueue[*message]
if !util.FileExists(parentDir) { queries *messageCacheQueries
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir) nop bool // If true, cache ignores all messages
} mu sync.Mutex // Lock for concurrent access
// Open database
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupMessagesDB(db, startupQueries, cacheDuration); err != nil {
return nil, err
}
var queue *util.BatchingQueue[*message]
if batchSize > 0 || batchTimeout > 0 {
queue = util.NewBatchingQueue[*message](batchSize, batchTimeout)
}
cache := &messageCache{
db: db,
queue: queue,
nop: nop,
}
go cache.processMessageBatches()
return cache, nil
} }
// newMemCache creates an in-memory cache var _ MessageCache = (*commonMessageCache)(nil)
func newMemCache() (*messageCache, error) {
return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, false)
}
// newNopCache creates an in-memory cache that discards all messages; // messageCacheQueries holds database-specific SQL queries
// it is always empty and can be used if caching is entirely disabled type messageCacheQueries struct {
func newNopCache() (*messageCache, error) { insertMessage string
return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, true) deleteMessage string
} updateMessagesForTopicExpiry string
selectRowIDFromMessageID string // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
selectMessagesByID string
selectMessagesSinceTime string
selectMessagesSinceTimeIncludeScheduled string
selectMessagesSinceID string
selectMessagesSinceIDIncludeScheduled string
selectMessagesLatest string
selectMessagesDue string
selectMessagesExpired string
updateMessagePublished string
selectMessageCountPerTopic string
selectTopics string
// createMemoryFilename creates a unique memory filename to use for the SQLite backend. updateAttachmentDeleted string
// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory selectAttachmentsExpired string
// sql database, so if the stdlib's sql engine happens to open another connection and selectAttachmentsSizeBySender string
// you've only specified ":memory:", that connection will see a brand new database. selectAttachmentsSizeByUserID string
// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared").
// Every connection to this string will point to the same in-memory database." selectStats string
func createMemoryFilename() string { updateStats string
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
} }
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asyncronously. // AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asyncronously.
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor. // The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
func (c *messageCache) AddMessage(m *message) error { func (c *commonMessageCache) AddMessage(m *message) error {
if c.queue != nil { if c.queue != nil {
c.queue.Enqueue(m) c.queue.Enqueue(m)
return nil return nil
} }
return c.addMessages([]*message{m}) return c.AddMessages([]*message{m})
} }
// addMessages synchronously stores a match of messages. If the database is locked, the transaction waits until // AddMessages synchronously stores a batch of messages. If the database is locked, the transaction waits until
// SQLite's busy_timeout is exceeded before erroring out. // the timeout is exceeded before erroring out.
func (c *messageCache) addMessages(ms []*message) error { func (c *commonMessageCache) AddMessages(ms []*message) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.nop { if c.nop {
@@ -363,7 +101,7 @@ func (c *messageCache) addMessages(ms []*message) error {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
stmt, err := tx.Prepare(insertMessageQuery) stmt, err := tx.Prepare(c.queries.insertMessage)
if err != nil { if err != nil {
return err return err
} }
@@ -375,7 +113,8 @@ func (c *messageCache) addMessages(ms []*message) error {
published := m.Time <= time.Now().Unix() published := m.Time <= time.Now().Unix()
tags := strings.Join(m.Tags, ",") tags := strings.Join(m.Tags, ",")
var attachmentName, attachmentType, attachmentURL string var attachmentName, attachmentType, attachmentURL string
var attachmentSize, attachmentExpires, attachmentDeleted int64 var attachmentSize, attachmentExpires int64
var attachmentDeleted bool
if m.Attachment != nil { if m.Attachment != nil {
attachmentName = m.Attachment.Name attachmentName = m.Attachment.Name
attachmentType = m.Attachment.Type attachmentType = m.Attachment.Type
@@ -412,7 +151,7 @@ func (c *messageCache) addMessages(ms []*message) error {
attachmentSize, attachmentSize,
attachmentExpires, attachmentExpires,
attachmentURL, attachmentURL,
attachmentDeleted, // Always zero attachmentDeleted, // Always false
sender, sender,
m.User, m.User,
m.ContentType, m.ContentType,
@@ -431,7 +170,7 @@ func (c *messageCache) addMessages(ms []*message) error {
return nil return nil
} }
func (c *messageCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) { func (c *commonMessageCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
if since.IsNone() { if since.IsNone() {
return make([]*message, 0), nil return make([]*message, 0), nil
} else if since.IsLatest() { } else if since.IsLatest() {
@@ -442,13 +181,21 @@ func (c *messageCache) Messages(topic string, since sinceMarker, scheduled bool)
return c.messagesSinceTime(topic, since, scheduled) return c.messagesSinceTime(topic, since, scheduled)
} }
func (c *messageCache) messagesSinceTime(topic string, since sinceMarker, scheduled bool) ([]*message, error) { func (c *commonMessageCache) messagesLatest(topic string) ([]*message, error) {
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
if err != nil {
return nil, err
}
return readMessages(rows)
}
func (c *commonMessageCache) messagesSinceTime(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
var rows *sql.Rows var rows *sql.Rows
var err error var err error
if scheduled { if scheduled {
rows, err = c.db.Query(selectMessagesSinceTimeIncludeScheduledQuery, topic, since.Time().Unix()) rows, err = c.db.Query(c.queries.selectMessagesSinceTimeIncludeScheduled, topic, since.Time().Unix())
} else { } else {
rows, err = c.db.Query(selectMessagesSinceTimeQuery, topic, since.Time().Unix()) rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
} }
if err != nil { if err != nil {
return nil, err return nil, err
@@ -456,8 +203,8 @@ func (c *messageCache) messagesSinceTime(topic string, since sinceMarker, schedu
return readMessages(rows) return readMessages(rows)
} }
func (c *messageCache) messagesSinceID(topic string, since sinceMarker, scheduled bool) ([]*message, error) { func (c *commonMessageCache) messagesSinceID(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
idrows, err := c.db.Query(selectRowIDFromMessageID, since.ID()) idrows, err := c.db.Query(c.queries.selectRowIDFromMessageID, since.ID())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -472,9 +219,9 @@ func (c *messageCache) messagesSinceID(topic string, since sinceMarker, schedule
idrows.Close() idrows.Close()
var rows *sql.Rows var rows *sql.Rows
if scheduled { if scheduled {
rows, err = c.db.Query(selectMessagesSinceIDIncludeScheduledQuery, topic, rowID) rows, err = c.db.Query(c.queries.selectMessagesSinceIDIncludeScheduled, topic, rowID)
} else { } else {
rows, err = c.db.Query(selectMessagesSinceIDQuery, topic, rowID) rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, rowID)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@@ -482,25 +229,17 @@ func (c *messageCache) messagesSinceID(topic string, since sinceMarker, schedule
return readMessages(rows) return readMessages(rows)
} }
func (c *messageCache) messagesLatest(topic string) ([]*message, error) { func (c *commonMessageCache) MessagesDue() ([]*message, error) {
rows, err := c.db.Query(selectMessagesLatestQuery, topic) rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return readMessages(rows) return readMessages(rows)
} }
func (c *messageCache) MessagesDue() ([]*message, error) { // MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
rows, err := c.db.Query(selectMessagesDueQuery, time.Now().Unix()) func (c *commonMessageCache) MessagesExpired() ([]string, error) {
if err != nil { rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
return nil, err
}
return readMessages(rows)
}
// MessagesExpired returns a list of IDs for messages that have expires (should be deleted)
func (c *messageCache) MessagesExpired() ([]string, error) {
rows, err := c.db.Query(selectMessagesExpiredQuery, time.Now().Unix())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -519,27 +258,24 @@ func (c *messageCache) MessagesExpired() ([]string, error) {
return ids, nil return ids, nil
} }
func (c *messageCache) Message(id string) (*message, error) { func (c *commonMessageCache) Message(id string) (*message, error) {
rows, err := c.db.Query(selectMessagesByIDQuery, id) rows, err := c.db.Query(c.queries.selectMessagesByID, id)
if err != nil { if err != nil {
return nil, err return nil, err
} } else if !rows.Next() {
if !rows.Next() {
return nil, errMessageNotFound return nil, errMessageNotFound
} }
defer rows.Close() defer rows.Close()
return readMessage(rows) return readMessage(rows)
} }
func (c *messageCache) MarkPublished(m *message) error { func (c *commonMessageCache) MarkPublished(m *message) error {
c.mu.Lock() _, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
defer c.mu.Unlock()
_, err := c.db.Exec(updateMessagePublishedQuery, m.ID)
return err return err
} }
func (c *messageCache) MessageCounts() (map[string]int, error) { func (c *commonMessageCache) MessageCounts() (map[string]int, error) {
rows, err := c.db.Query(selectMessageCountPerTopicQuery) rows, err := c.db.Query(c.queries.selectMessageCountPerTopic)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -558,8 +294,8 @@ func (c *messageCache) MessageCounts() (map[string]int, error) {
return counts, nil return counts, nil
} }
func (c *messageCache) Topics() (map[string]*topic, error) { func (c *commonMessageCache) Topics() (map[string]*topic, error) {
rows, err := c.db.Query(selectTopicsQuery) rows, err := c.db.Query(c.queries.selectTopics)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -578,40 +314,36 @@ func (c *messageCache) Topics() (map[string]*topic, error) {
return topics, nil return topics, nil
} }
func (c *messageCache) DeleteMessages(ids ...string) error { func (c *commonMessageCache) DeleteMessages(ids ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Begin() tx, err := c.db.Begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
for _, id := range ids { for _, id := range ids {
if _, err := tx.Exec(deleteMessageQuery, id); err != nil { if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
return err return err
} }
} }
return tx.Commit() return tx.Commit()
} }
func (c *messageCache) ExpireMessages(topics ...string) error { func (c *commonMessageCache) ExpireMessages(topics ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Begin() tx, err := c.db.Begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
for _, t := range topics { for _, t := range topics {
if _, err := tx.Exec(updateMessagesForTopicExpiryQuery, time.Now().Unix()-1, t); err != nil { if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
return err return err
} }
} }
return tx.Commit() return tx.Commit()
} }
func (c *messageCache) AttachmentsExpired() ([]string, error) { func (c *commonMessageCache) AttachmentsExpired() ([]string, error) {
rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix()) rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -630,39 +362,37 @@ func (c *messageCache) AttachmentsExpired() ([]string, error) {
return ids, nil return ids, nil
} }
func (c *messageCache) MarkAttachmentsDeleted(ids ...string) error { func (c *commonMessageCache) MarkAttachmentsDeleted(ids ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Begin() tx, err := c.db.Begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
for _, id := range ids { for _, id := range ids {
if _, err := tx.Exec(updateAttachmentDeleted, id); err != nil { if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
return err return err
} }
} }
return tx.Commit() return tx.Commit()
} }
func (c *messageCache) AttachmentBytesUsedBySender(sender string) (int64, error) { func (c *commonMessageCache) AttachmentBytesUsedBySender(sender string) (int64, error) {
rows, err := c.db.Query(selectAttachmentsSizeBySenderQuery, sender, time.Now().Unix()) rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
if err != nil { if err != nil {
return 0, err return 0, err
} }
return c.readAttachmentBytesUsed(rows) return c.readAttachmentBytesUsed(rows)
} }
func (c *messageCache) AttachmentBytesUsedByUser(userID string) (int64, error) { func (c *commonMessageCache) AttachmentBytesUsedByUser(userID string) (int64, error) {
rows, err := c.db.Query(selectAttachmentsSizeByUserIDQuery, userID, time.Now().Unix()) rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
if err != nil { if err != nil {
return 0, err return 0, err
} }
return c.readAttachmentBytesUsed(rows) return c.readAttachmentBytesUsed(rows)
} }
func (c *messageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) { func (c *commonMessageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
defer rows.Close() defer rows.Close()
var size int64 var size int64
if !rows.Next() { if !rows.Next() {
@@ -676,17 +406,45 @@ func (c *messageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
return size, nil return size, nil
} }
func (c *messageCache) processMessageBatches() { func (c *commonMessageCache) processMessageBatches() {
if c.queue == nil { if c.queue == nil {
return return
} }
for messages := range c.queue.Dequeue() { for messages := range c.queue.Dequeue() {
if err := c.addMessages(messages); err != nil { if err := c.AddMessages(messages); err != nil {
log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch") log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch")
} }
} }
} }
func (c *commonMessageCache) UpdateStats(messages int64) error {
_, err := c.db.Exec(c.queries.updateStats, messages)
return err
}
func (c *commonMessageCache) Stats() (messages int64, err error) {
rows, err := c.db.Query(c.queries.selectStats)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
if err := rows.Scan(&messages); err != nil {
return 0, err
}
return messages, nil
}
func (c *commonMessageCache) DB() *sql.DB {
return c.db
}
func (c *commonMessageCache) Close() error {
return c.db.Close()
}
func readMessages(rows *sql.Rows) ([]*message, error) { func readMessages(rows *sql.Rows) ([]*message, error) {
defer rows.Close() defer rows.Close()
messages := make([]*message, 0) messages := make([]*message, 0)
@@ -776,257 +534,3 @@ func readMessage(rows *sql.Rows) (*message, error) {
Encoding: encoding, Encoding: encoding,
}, nil }, nil
} }
func (c *messageCache) UpdateStats(messages int64) error {
c.mu.Lock()
defer c.mu.Unlock()
_, err := c.db.Exec(updateStatsQuery, messages)
return err
}
func (c *messageCache) Stats() (messages int64, err error) {
rows, err := c.db.Query(selectStatsQuery)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
if err := rows.Scan(&messages); err != nil {
return 0, err
}
return messages, nil
}
func (c *messageCache) Close() error {
return c.db.Close()
}
func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
// Run startup queries
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
// If 'messages' table does not exist, this must be a new database
rowsMC, err := db.Query(selectMessagesCountQuery)
if err != nil {
return setupNewCacheDB(db)
}
rowsMC.Close()
// If 'messages' table exists, check 'schemaVersion' table
schemaVersion := 0
rowsSV, err := db.Query(selectSchemaVersionQuery)
if err == nil {
defer rowsSV.Close()
if !rowsSV.Next() {
return errors.New("cannot determine schema version: cache file may be corrupt")
}
if err := rowsSV.Scan(&schemaVersion); err != nil {
return err
}
rowsSV.Close()
}
// Do migrations
if schemaVersion == currentSchemaVersion {
return nil
} else if schemaVersion > currentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
}
for i := schemaVersion; i < currentSchemaVersion; i++ {
fn, ok := migrations[i]
if !ok {
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
} else if err := fn(db, cacheDuration); err != nil {
return err
}
}
return nil
}
func setupNewCacheDB(db *sql.DB) error {
if _, err := db.Exec(createMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(createSchemaVersionTableQuery); err != nil {
return err
}
if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
return err
}
return nil
}
func migrateFrom0(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1")
if _, err := db.Exec(migrate0To1AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(createSchemaVersionTableQuery); err != nil {
return err
}
if _, err := db.Exec(insertSchemaVersion, 1); err != nil {
return err
}
return nil
}
func migrateFrom1(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2")
if _, err := db.Exec(migrate1To2AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 2); err != nil {
return err
}
return nil
}
func migrateFrom2(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3")
if _, err := db.Exec(migrate2To3AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 3); err != nil {
return err
}
return nil
}
func migrateFrom3(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4")
if _, err := db.Exec(migrate3To4AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 4); err != nil {
return err
}
return nil
}
func migrateFrom4(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5")
if _, err := db.Exec(migrate4To5AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 5); err != nil {
return err
}
return nil
}
func migrateFrom5(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6")
if _, err := db.Exec(migrate5To6AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 6); err != nil {
return err
}
return nil
}
func migrateFrom6(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7")
if _, err := db.Exec(migrate6To7AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 7); err != nil {
return err
}
return nil
}
func migrateFrom7(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8")
if _, err := db.Exec(migrate7To8AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 8); err != nil {
return err
}
return nil
}
func migrateFrom8(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9")
if _, err := db.Exec(migrate8To9AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 9); err != nil {
return err
}
return nil
}
func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate9To10AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(migrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 10); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom10(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 11); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom11(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 12); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom12(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate12To13AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 13); err != nil {
return err
}
return tx.Commit()
}

193
server/message_cache_pg.go Normal file
View File

@@ -0,0 +1,193 @@
package server
import (
"database/sql"
"strings"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
"heckel.io/ntfy/v2/util"
)
// PostgreSQL schema version (starts at latest, no migrations needed initially)
const pgCurrentSchemaVersion = 1
// Messages cache
const (
pgCreateMessagesTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id SERIAL PRIMARY KEY,
mid TEXT NOT NULL,
time INT NOT NULL,
expires INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_deleted BOOLEAN NOT NULL,
sender TEXT NOT NULL,
"user" TEXT NOT NULL,
content_type TEXT NOT NULL,
encoding TEXT NOT NULL,
published BOOLEAN NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages ("user");
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schemaVersion VALUES (1, 1);
COMMIT;
`
pgSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
)
var (
pgMessageCacheQueries = &messageCacheQueries{
insertMessage: `
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, "user", content_type, encoding, published)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22)
`,
deleteMessage: `DELETE FROM messages WHERE mid = $1`,
updateMessagesForTopicExpiry: `UPDATE messages SET expires = $1 WHERE topic = $2`,
selectRowIDFromMessageID: `SELECT id FROM messages WHERE mid = $1`,
selectMessagesByID: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE mid = $1
`,
selectMessagesSinceTime: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE topic = $1 AND time >= $2 AND published = TRUE
ORDER BY time, id
`,
selectMessagesSinceTimeIncludeScheduled: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE topic = $1 AND time >= $2
ORDER BY time, id
`,
selectMessagesSinceID: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE topic = $1 AND id > $2 AND published = TRUE
ORDER BY time, id
`,
selectMessagesSinceIDIncludeScheduled: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE topic = $1 AND (id > $2 OR published = FALSE)
ORDER BY time, id
`,
selectMessagesLatest: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE topic = $1 AND published = TRUE
ORDER BY time DESC, id DESC
LIMIT 1
`,
selectMessagesDue: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE time <= $1 AND published = FALSE
ORDER BY time, id
`,
selectMessagesExpired: `SELECT mid FROM messages WHERE expires <= $1 AND published = TRUE`,
updateMessagePublished: `UPDATE messages SET published = TRUE WHERE mid = $1`,
selectMessageCountPerTopic: `SELECT topic, COUNT(*) FROM messages GROUP BY topic`,
selectTopics: `SELECT topic FROM messages GROUP BY topic`,
updateAttachmentDeleted: `UPDATE messages SET attachment_deleted = TRUE WHERE mid = $1`,
selectAttachmentsExpired: `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`,
selectAttachmentsSizeBySender: `SELECT COALESCE(SUM(attachment_size), 0) FROM messages WHERE "user" = '' AND sender = $1 AND attachment_expires >= $2`,
selectAttachmentsSizeByUserID: `SELECT COALESCE(SUM(attachment_size), 0) FROM messages WHERE "user" = $1 AND attachment_expires >= $2`,
selectStats: `SELECT value FROM stats WHERE key = 'messages'`,
updateStats: `UPDATE stats SET value = $1 WHERE key = 'messages'`,
}
)
type pgMessageCache struct {
*commonMessageCache
}
var _ MessageCache = (*pgMessageCache)(nil)
// newPgCache creates a PostgreSQL-backed message cache
func newPgCache(connectionString, startupQueries string, batchSize int, batchTimeout time.Duration) (*pgMessageCache, error) {
db, err := sql.Open("pgx", connectionString)
if err != nil {
return nil, err
}
if err := setupPgMessagesDB(db, startupQueries); err != nil {
return nil, err
}
var queue *util.BatchingQueue[*message]
if batchSize > 0 || batchTimeout > 0 {
queue = util.NewBatchingQueue[*message](batchSize, batchTimeout)
}
cache := &pgMessageCache{
commonMessageCache: &commonMessageCache{
db: db,
queue: queue,
queries: pgMessageCacheQueries,
},
}
go cache.processMessageBatches()
return cache, nil
}
func setupPgMessagesDB(db *sql.DB, startupQueries string) error {
// Run startup queries
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
// If 'messages' table does not exist, this must be a new database
rowsMC, err := db.Query(pgSelectMessagesCountQuery)
if err != nil {
return setupNewPgCacheDB(db)
}
rowsMC.Close()
// Future: Add migrations here when schema changes
return nil
}
func setupNewPgCacheDB(db *sql.DB) error {
if _, err := db.Exec(pgCreateMessagesTableQuery); err != nil {
return err
}
return nil
}
// isPostgres checks if the connection string indicates a PostgreSQL database
func isPostgres(connectionString string) bool {
return strings.HasPrefix(connectionString, "postgres:")
}

View File

@@ -0,0 +1,571 @@
package server
import (
"database/sql"
"errors"
"fmt"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
)
var (
errUnexpectedMessageType = errors.New("unexpected message type")
errMessageNotFound = errors.New("message not found")
errNoRows = errors.New("no rows found")
)
// Messages cache
const (
createMessagesTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
time INT NOT NULL,
expires INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_deleted INT NOT NULL,
sender TEXT NOT NULL,
user TEXT NOT NULL,
content_type TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
COMMIT;
`
)
var (
sqliteMessageCacheQueries = &messageCacheQueries{
insertMessage: `
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
deleteMessage: `DELETE FROM messages WHERE mid = ?`,
updateMessagesForTopicExpiry: `UPDATE messages SET expires = ? WHERE topic = ?`,
selectRowIDFromMessageID: `SELECT id FROM messages WHERE mid = ?`, // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
selectMessagesByID: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE mid = ?
`,
selectMessagesSinceTime: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ? AND published = 1
ORDER BY time, id
`,
selectMessagesSinceTimeIncludeScheduled: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ?
ORDER BY time, id
`,
selectMessagesSinceID: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND id > ? AND published = 1
ORDER BY time, id
`,
selectMessagesSinceIDIncludeScheduled: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND (id > ? OR published = 0)
ORDER BY time, id
`,
selectMessagesLatest: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND published = 1
ORDER BY time DESC, id DESC
LIMIT 1
`,
selectMessagesDue: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE time <= ? AND published = 0
ORDER BY time, id
`,
selectMessagesExpired: `SELECT mid FROM messages WHERE expires <= ? AND published = 1`,
updateMessagePublished: `UPDATE messages SET published = 1 WHERE mid = ?`,
selectMessageCountPerTopic: `SELECT topic, COUNT(*) FROM messages GROUP BY topic`,
selectTopics: `SELECT topic FROM messages GROUP BY topic`,
updateAttachmentDeleted: `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`,
selectAttachmentsExpired: `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`,
selectAttachmentsSizeBySender: `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`,
selectAttachmentsSizeByUserID: `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`,
selectStats: `SELECT value FROM stats WHERE key = 'messages'`,
updateStats: `UPDATE stats SET value = ? WHERE key = 'messages'`,
}
)
// Schema management queries
const (
currentSchemaVersion = 13
createSchemaVersionTableQuery = `
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
`
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
// 0 -> 1
migrate0To1AlterMessagesTableQuery = `
BEGIN;
ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0);
ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT('');
COMMIT;
`
// 1 -> 2
migrate1To2AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1);
`
// 2 -> 3
migrate2To3AlterMessagesTableQuery = `
BEGIN;
ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT('');
COMMIT;
`
// 3 -> 4
migrate3To4AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT('');
`
// 4 -> 5
migrate4To5AlterMessagesTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages_new (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
time INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_owner TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid);
CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic);
INSERT
INTO messages_new (
mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published)
SELECT
id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published
FROM messages;
DROP TABLE messages;
ALTER TABLE messages_new RENAME TO messages;
COMMIT;
`
// 5 -> 6
migrate5To6AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT('');
`
// 6 -> 7
migrate6To7AlterMessagesTableQuery = `
ALTER TABLE messages RENAME COLUMN attachment_owner TO sender;
`
// 7 -> 8
migrate7To8AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT('');
`
// 8 -> 9
migrate8To9AlterMessagesTableQuery = `
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
`
// 9 -> 10
migrate9To10AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
`
migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
// 10 -> 11
migrate10To11AlterMessagesTableQuery = `
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
`
// 11 -> 12
migrate11To12AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
`
// 12 -> 13
migrate12To13AlterMessagesTableQuery = `
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
`
)
var (
migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
0: migrateFrom0,
1: migrateFrom1,
2: migrateFrom2,
3: migrateFrom3,
4: migrateFrom4,
5: migrateFrom5,
6: migrateFrom6,
7: migrateFrom7,
8: migrateFrom8,
9: migrateFrom9,
10: migrateFrom10,
11: migrateFrom11,
12: migrateFrom12,
}
)
type sqliteMessageCache struct {
*commonMessageCache
}
var _ MessageCache = (*sqliteMessageCache)(nil)
// newSqliteCache creates a SQLite file-backed cache
func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*sqliteMessageCache, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupMessagesDB(db, startupQueries, cacheDuration); err != nil {
return nil, err
}
var queue *util.BatchingQueue[*message]
if batchSize > 0 || batchTimeout > 0 {
queue = util.NewBatchingQueue[*message](batchSize, batchTimeout)
}
cache := &sqliteMessageCache{
commonMessageCache: &commonMessageCache{
db: db,
queue: queue,
queries: sqliteMessageCacheQueries,
nop: nop,
},
}
go cache.processMessageBatches()
return cache, nil
}
// newMemCache creates an in-memory cache
func newMemCache() (*sqliteMessageCache, error) {
return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, false)
}
// newNopCache creates an in-memory cache that discards all messages;
// it is always empty and can be used if caching is entirely disabled
func newNopCache() (*sqliteMessageCache, error) {
return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, true)
}
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory
// sql database, so if the stdlib's sql engine happens to open another connection and
// you've only specified ":memory:", that connection will see a brand new database.
// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared").
// Every connection to this string will point to the same in-memory database."
func createMemoryFilename() string {
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
}
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asyncronously.
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
func (c *sqliteMessageCache) AddMessage(m *message) error {
if c.nop {
return nil
}
return c.commonMessageCache.AddMessage(m)
}
func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
// Run startup queries
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
// If 'messages' table does not exist, this must be a new database
rowsMC, err := db.Query(selectMessagesCountQuery)
if err != nil {
return setupNewCacheDB(db)
}
rowsMC.Close()
// If 'messages' table exists, check 'schemaVersion' table
schemaVersion := 0
rowsSV, err := db.Query(selectSchemaVersionQuery)
if err == nil {
defer rowsSV.Close()
if !rowsSV.Next() {
return errors.New("cannot determine schema version: cache file may be corrupt")
}
if err := rowsSV.Scan(&schemaVersion); err != nil {
return err
}
rowsSV.Close()
}
// Do migrations
if schemaVersion == currentSchemaVersion {
return nil
} else if schemaVersion > currentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
}
for i := schemaVersion; i < currentSchemaVersion; i++ {
fn, ok := migrations[i]
if !ok {
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
} else if err := fn(db, cacheDuration); err != nil {
return err
}
}
return nil
}
func setupNewCacheDB(db *sql.DB) error {
if _, err := db.Exec(createMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(createSchemaVersionTableQuery); err != nil {
return err
}
if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
return err
}
return nil
}
func migrateFrom0(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1")
if _, err := db.Exec(migrate0To1AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(createSchemaVersionTableQuery); err != nil {
return err
}
if _, err := db.Exec(insertSchemaVersion, 1); err != nil {
return err
}
return nil
}
func migrateFrom1(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2")
if _, err := db.Exec(migrate1To2AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 2); err != nil {
return err
}
return nil
}
func migrateFrom2(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3")
if _, err := db.Exec(migrate2To3AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 3); err != nil {
return err
}
return nil
}
func migrateFrom3(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4")
if _, err := db.Exec(migrate3To4AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 4); err != nil {
return err
}
return nil
}
func migrateFrom4(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5")
if _, err := db.Exec(migrate4To5AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 5); err != nil {
return err
}
return nil
}
func migrateFrom5(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6")
if _, err := db.Exec(migrate5To6AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 6); err != nil {
return err
}
return nil
}
func migrateFrom6(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7")
if _, err := db.Exec(migrate6To7AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 7); err != nil {
return err
}
return nil
}
func migrateFrom7(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8")
if _, err := db.Exec(migrate7To8AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 8); err != nil {
return err
}
return nil
}
func migrateFrom8(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9")
if _, err := db.Exec(migrate8To9AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 9); err != nil {
return err
}
return nil
}
func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate9To10AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(migrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 10); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom10(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 11); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom11(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 12); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom12(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate12To13AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 13); err != nil {
return err
}
return tx.Commit()
}

View File

@@ -21,7 +21,7 @@ func TestMemCache_Messages(t *testing.T) {
testCacheMessages(t, newMemTestCache(t)) testCacheMessages(t, newMemTestCache(t))
} }
func testCacheMessages(t *testing.T, c *messageCache) { func testCacheMessages(t *testing.T, c MessageCache) {
m1 := newDefaultMessage("mytopic", "my message") m1 := newDefaultMessage("mytopic", "my message")
m1.Time = 1 m1.Time = 1
@@ -100,7 +100,7 @@ func TestMemCache_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newMemTestCache(t)) testCacheMessagesLock(t, newMemTestCache(t))
} }
func testCacheMessagesLock(t *testing.T, c *messageCache) { func testCacheMessagesLock(t *testing.T, c MessageCache) {
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 5000; i++ { for i := 0; i < 5000; i++ {
wg.Add(1) wg.Add(1)
@@ -120,7 +120,7 @@ func TestMemCache_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newMemTestCache(t)) testCacheMessagesScheduled(t, newMemTestCache(t))
} }
func testCacheMessagesScheduled(t *testing.T, c *messageCache) { func testCacheMessagesScheduled(t *testing.T, c MessageCache) {
m1 := newDefaultMessage("mytopic", "message 1") m1 := newDefaultMessage("mytopic", "message 1")
m2 := newDefaultMessage("mytopic", "message 2") m2 := newDefaultMessage("mytopic", "message 2")
m2.Time = time.Now().Add(time.Hour).Unix() m2.Time = time.Now().Add(time.Hour).Unix()
@@ -154,7 +154,7 @@ func TestMemCache_Topics(t *testing.T) {
testCacheTopics(t, newMemTestCache(t)) testCacheTopics(t, newMemTestCache(t))
} }
func testCacheTopics(t *testing.T, c *messageCache) { func testCacheTopics(t *testing.T, c MessageCache) {
require.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message"))) require.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1"))) require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2"))) require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2")))
@@ -177,7 +177,7 @@ func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newMemTestCache(t)) testCacheMessagesTagsPrioAndTitle(t, newMemTestCache(t))
} }
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c *messageCache) { func testCacheMessagesTagsPrioAndTitle(t *testing.T, c MessageCache) {
m := newDefaultMessage("mytopic", "some message") m := newDefaultMessage("mytopic", "some message")
m.Tags = []string{"tag1", "tag2"} m.Tags = []string{"tag1", "tag2"}
m.Priority = 5 m.Priority = 5
@@ -198,7 +198,7 @@ func TestMemCache_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newMemTestCache(t)) testCacheMessagesSinceID(t, newMemTestCache(t))
} }
func testCacheMessagesSinceID(t *testing.T, c *messageCache) { func testCacheMessagesSinceID(t *testing.T, c MessageCache) {
m1 := newDefaultMessage("mytopic", "message 1") m1 := newDefaultMessage("mytopic", "message 1")
m1.Time = 100 m1.Time = 100
m2 := newDefaultMessage("mytopic", "message 2") m2 := newDefaultMessage("mytopic", "message 2")
@@ -268,7 +268,7 @@ func TestMemCache_Prune(t *testing.T) {
testCachePrune(t, newMemTestCache(t)) testCachePrune(t, newMemTestCache(t))
} }
func testCachePrune(t *testing.T, c *messageCache) { func testCachePrune(t *testing.T, c MessageCache) {
now := time.Now().Unix() now := time.Now().Unix()
m1 := newDefaultMessage("mytopic", "my message") m1 := newDefaultMessage("mytopic", "my message")
@@ -315,7 +315,7 @@ func TestMemCache_Attachments(t *testing.T) {
testCacheAttachments(t, newMemTestCache(t)) testCacheAttachments(t, newMemTestCache(t))
} }
func testCacheAttachments(t *testing.T, c *messageCache) { func testCacheAttachments(t *testing.T, c MessageCache) {
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
m := newDefaultMessage("mytopic", "flower for you") m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1" m.ID = "m1"
@@ -397,7 +397,7 @@ func TestMemCache_Attachments_Expired(t *testing.T) {
testCacheAttachmentsExpired(t, newMemTestCache(t)) testCacheAttachmentsExpired(t, newMemTestCache(t))
} }
func testCacheAttachmentsExpired(t *testing.T, c *messageCache) { func testCacheAttachmentsExpired(t *testing.T, c MessageCache) {
m := newDefaultMessage("mytopic", "flower for you") m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1" m.ID = "m1"
m.Expires = time.Now().Add(time.Hour).Unix() m.Expires = time.Now().Add(time.Hour).Unix()
@@ -473,7 +473,7 @@ func TestSqliteCache_Migration_From0(t *testing.T) {
// Create cache to trigger migration // Create cache to trigger migration
c := newSqliteTestCacheFromFile(t, filename, "") c := newSqliteTestCacheFromFile(t, filename, "")
checkSchemaVersion(t, c.db) checkSchemaVersion(t, c.DB())
messages, err := c.Messages("mytopic", sinceAllMessages, false) messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err) require.Nil(t, err)
@@ -519,7 +519,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
// Create cache to trigger migration // Create cache to trigger migration
c := newSqliteTestCacheFromFile(t, filename, "") c := newSqliteTestCacheFromFile(t, filename, "")
checkSchemaVersion(t, c.db) checkSchemaVersion(t, c.DB())
// Add delayed message // Add delayed message
delayedMessage := newDefaultMessage("mytopic", "some delayed message") delayedMessage := newDefaultMessage("mytopic", "some delayed message")
@@ -537,7 +537,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
require.Equal(t, 11, len(messages)) require.Equal(t, 11, len(messages))
// Check that index "idx_topic" exists // Check that index "idx_topic" exists
rows, err := c.db.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`) rows, err := c.DB().Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
require.Nil(t, err) require.Nil(t, err)
require.True(t, rows.Next()) require.True(t, rows.Next())
var indexName string var indexName string
@@ -623,7 +623,7 @@ func TestSqliteCache_Migration_From9(t *testing.T) {
cacheDuration := 17 * time.Hour cacheDuration := 17 * time.Hour
c, err := newSqliteCache(filename, "", cacheDuration, 0, 0, false) c, err := newSqliteCache(filename, "", cacheDuration, 0, 0, false)
require.Nil(t, err) require.Nil(t, err)
checkSchemaVersion(t, c.db) checkSchemaVersion(t, c.DB())
// Check version // Check version
rows, err := db.Query(`SELECT version FROM main.schemaVersion WHERE id = 1`) rows, err := db.Query(`SELECT version FROM main.schemaVersion WHERE id = 1`)
@@ -681,7 +681,7 @@ func TestMemCache_Sender(t *testing.T) {
testSender(t, newMemTestCache(t)) testSender(t, newMemTestCache(t))
} }
func testSender(t *testing.T, c *messageCache) { func testSender(t *testing.T, c MessageCache) {
m1 := newDefaultMessage("mytopic", "mymessage") m1 := newDefaultMessage("mytopic", "mymessage")
m1.Sender = netip.MustParseAddr("1.2.3.4") m1.Sender = netip.MustParseAddr("1.2.3.4")
require.Nil(t, c.AddMessage(m1)) require.Nil(t, c.AddMessage(m1))
@@ -720,7 +720,7 @@ func TestMemCache_NopCache(t *testing.T) {
require.Empty(t, topics) require.Empty(t, topics)
} }
func newSqliteTestCache(t *testing.T) *messageCache { func newSqliteTestCache(t *testing.T) MessageCache {
c, err := newSqliteCache(newSqliteTestCacheFile(t), "", time.Hour, 0, 0, false) c, err := newSqliteCache(newSqliteTestCacheFile(t), "", time.Hour, 0, 0, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -732,13 +732,13 @@ func newSqliteTestCacheFile(t *testing.T) string {
return filepath.Join(t.TempDir(), "cache.db") return filepath.Join(t.TempDir(), "cache.db")
} }
func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) *messageCache { func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) MessageCache {
c, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false) c, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
require.Nil(t, err) require.Nil(t, err)
return c return c
} }
func newMemTestCache(t *testing.T) *messageCache { func newMemTestCache(t *testing.T) MessageCache {
c, err := newMemCache() c, err := newMemCache()
require.Nil(t, err) require.Nil(t, err)
return c return c

View File

@@ -56,8 +56,8 @@ type Server struct {
messages int64 // Total number of messages (persisted if messageCache enabled) messages int64 // Total number of messages (persisted if messageCache enabled)
messagesHistory []int64 // Last n values of the messages counter, used to determine rate messagesHistory []int64 // Last n values of the messages counter, used to determine rate
userManager *user.Manager // Might be nil! userManager *user.Manager // Might be nil!
messageCache *messageCache // Database that stores the messages messageCache MessageCache // Database that stores the messages
webPush *webPushStore // Database that stores web push subscriptions webPush WebPushStore // Database that stores web push subscriptions
fileCache *fileCache // File system based cache that stores attachments fileCache *fileCache // File system based cache that stores attachments
stripe stripeAPI // Stripe API, can be replaced with a mock stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!) priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
@@ -173,7 +173,7 @@ func New(conf *Config) (*Server, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
var webPush *webPushStore var webPush WebPushStore
if conf.WebPushPublicKey != "" { if conf.WebPushPublicKey != "" {
webPush, err = newWebPushStore(conf.WebPushFile, conf.WebPushStartupQueries) webPush, err = newWebPushStore(conf.WebPushFile, conf.WebPushStartupQueries)
if err != nil { if err != nil {
@@ -245,9 +245,11 @@ func New(conf *Config) (*Server, error) {
return s, nil return s, nil
} }
func createMessageCache(conf *Config) (*messageCache, error) { func createMessageCache(conf *Config) (MessageCache, error) {
if conf.CacheDuration == 0 { if conf.CacheDuration == 0 {
return newNopCache() return newNopCache()
} else if isPostgres(conf.CacheFile) {
return newPgCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheBatchSize, conf.CacheBatchTimeout)
} else if conf.CacheFile != "" { } else if conf.CacheFile != "" {
return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false) return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
} }

View File

@@ -381,7 +381,7 @@ func TestServer_PublishAt(t *testing.T) {
// Update message time to the past // Update message time to the past
fakeTime := time.Now().Add(-10 * time.Second).Unix() fakeTime := time.Now().Add(-10 * time.Second).Unix()
_, err := s.messageCache.db.Exec(`UPDATE messages SET time=?`, fakeTime) _, err := s.messageCache.DB().Exec(`UPDATE messages SET time=?`, fakeTime)
require.Nil(t, err) require.Nil(t, err)
// Trigger delayed message sending // Trigger delayed message sending
@@ -417,7 +417,7 @@ func TestServer_PublishAt_FromUser(t *testing.T) {
// Update message time to the past // Update message time to the past
fakeTime := time.Now().Add(-10 * time.Second).Unix() fakeTime := time.Now().Add(-10 * time.Second).Unix()
_, err := s.messageCache.db.Exec(`UPDATE messages SET time=?`, fakeTime) _, err := s.messageCache.DB().Exec(`UPDATE messages SET time=?`, fakeTime)
require.Nil(t, err) require.Nil(t, err)
// Trigger delayed message sending // Trigger delayed message sending
@@ -2336,7 +2336,7 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
messages = append(messages, newDefaultMessage(topicID, "some message")) messages = append(messages, newDefaultMessage(topicID, "some message"))
} }
require.Nil(t, s.messageCache.addMessages(messages)) require.Nil(t, s.messageCache.AddMessages(messages))
log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond)) log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
// Update stats // Update stats

View File

@@ -238,7 +238,7 @@ func TestServer_WebPush_Expiry(t *testing.T) {
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic") addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1) requireSubscriptionCount(t, s, "test-topic", 1)
_, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-55*24*time.Hour).Unix()) _, err := s.webPush.DB().Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-55*24*time.Hour).Unix())
require.Nil(t, err) require.Nil(t, err)
s.pruneAndNotifyWebPushSubscriptions() s.pruneAndNotifyWebPushSubscriptions()
@@ -248,7 +248,7 @@ func TestServer_WebPush_Expiry(t *testing.T) {
return received.Load() return received.Load()
}) })
_, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-60*24*time.Hour).Unix()) _, err = s.webPush.DB().Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-60*24*time.Hour).Unix())
require.Nil(t, err) require.Nil(t, err)
s.pruneAndNotifyWebPushSubscriptions() s.pruneAndNotifyWebPushSubscriptions()

View File

@@ -53,7 +53,7 @@ const (
// visitor represents an API user, and its associated rate.Limiter used for rate limiting // visitor represents an API user, and its associated rate.Limiter used for rate limiting
type visitor struct { type visitor struct {
config *Config config *Config
messageCache *messageCache messageCache MessageCache
userManager *user.Manager // May be nil userManager *user.Manager // May be nil
ip netip.Addr // Visitor IP address ip netip.Addr // Visitor IP address
user *user.User // Only set if authenticated user, otherwise nil user *user.User // Only set if authenticated user, otherwise nil
@@ -114,7 +114,7 @@ const (
visitorLimitBasisTier = visitorLimitBasis("tier") visitorLimitBasisTier = visitorLimitBasis("tier")
) )
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { func newVisitor(conf *Config, messageCache MessageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
var messages, emails, calls int64 var messages, emails, calls int64
if user != nil { if user != nil {
messages = user.Stats.Messages messages = user.Stats.Messages

View File

@@ -3,11 +3,11 @@ package server
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"heckel.io/ntfy/v2/util"
"net/netip" "net/netip"
"strings"
"time" "time"
_ "github.com/mattn/go-sqlite3" // SQLite driver "heckel.io/ntfy/v2/util"
) )
const ( const (
@@ -22,126 +22,49 @@ var (
errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty") errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
) )
const ( // WebPushStore is an interface for storing web push subscriptions
createWebPushSubscriptionsTableQuery = ` type WebPushStore interface {
BEGIN; UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error
CREATE TABLE IF NOT EXISTS subscription ( SubscriptionsForTopic(topic string) ([]*webPushSubscription, error)
id TEXT PRIMARY KEY, SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error)
endpoint TEXT NOT NULL, MarkExpiryWarningSent(subscriptions []*webPushSubscription) error
key_auth TEXT NOT NULL, RemoveSubscriptionsByEndpoint(endpoint string) error
key_p256dh TEXT NOT NULL, RemoveSubscriptionsByUserID(userID string) error
user_id TEXT NOT NULL, RemoveExpiredSubscriptions(expireAfter time.Duration) error
subscriber_ip TEXT NOT NULL, DB() *sql.DB
updated_at INT NOT NULL, Close() error
warned_at INT NOT NULL DEFAULT 0 }
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic),
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
`
builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?` // webPushQueries holds all the SQL queries used by webPushStore
selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?` type webPushQueries struct {
selectWebPushSubscriptionsForTopicQuery = ` selectSubscriptionIDByEndpoint string
SELECT id, endpoint, key_auth, key_p256dh, user_id selectSubscriptionCountBySubscriberIP string
FROM subscription_topic st selectSubscriptionsForTopic string
JOIN subscription s ON s.id = st.subscription_id selectSubscriptionsExpiringSoon string
WHERE st.topic = ? insertSubscription string
ORDER BY endpoint updateSubscriptionWarningSent string
` deleteSubscriptionByEndpoint string
selectWebPushSubscriptionsExpiringSoonQuery = ` deleteSubscriptionByUserID string
SELECT id, endpoint, key_auth, key_p256dh, user_id deleteSubscriptionByAge string
FROM subscription insertSubscriptionTopic string
WHERE warned_at = 0 AND updated_at <= ? deleteSubscriptionTopicAll string
` deleteSubscriptionTopicWithoutSub string
insertWebPushSubscriptionQuery = ` }
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
`
updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
deleteWebPushSubscriptionTopicWithoutSubscription = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`
)
// Schema management queries
const (
currentWebPushSchemaVersion = 1
insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
type webPushStore struct { type webPushStore struct {
db *sql.DB db *sql.DB
queries *webPushQueries
} }
func newWebPushStore(filename, startupQueries string) (*webPushStore, error) { // newWebPushStore creates a new webPushStore based on the connection string
db, err := sql.Open("sqlite3", filename) func newWebPushStore(filename, startupQueries string) (WebPushStore, error) {
if err != nil { if strings.HasPrefix(filename, "postgres:") {
return nil, err return newPgWebPushStore(strings.TrimPrefix(filename, "postgres:"), startupQueries)
} }
if err := setupWebPushDB(db); err != nil { return newSqliteWebPushStore(filename, startupQueries)
return nil, err
}
if err := runWebPushStartupQueries(db, startupQueries); err != nil {
return nil, err
}
return &webPushStore{
db: db,
}, nil
} }
func setupWebPushDB(db *sql.DB) error { // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID
// If 'schemaVersion' table does not exist, this must be a new database
rows, err := db.Query(selectWebPushSchemaVersionQuery)
if err != nil {
return setupNewWebPushDB(db)
}
return rows.Close()
}
func setupNewWebPushDB(db *sql.DB) error {
if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
return err
}
if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
return err
}
return nil
}
func runWebPushStartupQueries(db *sql.DB, startupQueries string) error {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
if _, err := db.Exec(builtinStartupQueries); err != nil {
return err
}
return nil
}
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
// existing entries for a given endpoint.
func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error { func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
tx, err := c.db.Begin() tx, err := c.db.Begin()
if err != nil { if err != nil {
@@ -149,7 +72,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
} }
defer tx.Rollback() defer tx.Rollback()
// Read number of subscriptions for subscriber IP address // Read number of subscriptions for subscriber IP address
rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String()) rowsCount, err := tx.Query(c.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String())
if err != nil { if err != nil {
return err return err
} }
@@ -165,7 +88,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
return err return err
} }
// Read existing subscription ID for endpoint (or create new ID) // Read existing subscription ID for endpoint (or create new ID)
rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint) rows, err := tx.Query(c.queries.selectSubscriptionIDByEndpoint, endpoint)
if err != nil { if err != nil {
return err return err
} }
@@ -186,15 +109,15 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
} }
// Insert or update subscription // Insert or update subscription
updatedAt, warnedAt := time.Now().Unix(), 0 updatedAt, warnedAt := time.Now().Unix(), 0
if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil { if _, err = tx.Exec(c.queries.insertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
return err return err
} }
// Replace all subscription topics // Replace all subscription topics
if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil { if _, err := tx.Exec(c.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
return err return err
} }
for _, topic := range topics { for _, topic := range topics {
if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil { if _, err = tx.Exec(c.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
return err return err
} }
} }
@@ -203,7 +126,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
// SubscriptionsForTopic returns all subscriptions for the given topic // SubscriptionsForTopic returns all subscriptions for the given topic
func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) { func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic) rows, err := c.db.Query(c.queries.selectSubscriptionsForTopic, topic)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -213,7 +136,7 @@ func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscripti
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period // SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) { func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix()) rows, err := c.db.Query(c.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -229,7 +152,7 @@ func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscriptio
} }
defer tx.Rollback() defer tx.Rollback()
for _, subscription := range subscriptions { for _, subscription := range subscriptions {
if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil { if _, err := tx.Exec(c.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
return err return err
} }
} }
@@ -256,7 +179,7 @@ func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscrip
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint // RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error { func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint) _, err := c.db.Exec(c.queries.deleteSubscriptionByEndpoint, endpoint)
return err return err
} }
@@ -265,20 +188,25 @@ func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
if userID == "" { if userID == "" {
return errWebPushUserIDCannotBeEmpty return errWebPushUserIDCannotBeEmpty
} }
_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID) _, err := c.db.Exec(c.queries.deleteSubscriptionByUserID, userID)
return err return err
} }
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period // RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error { func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix()) _, err := c.db.Exec(c.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix())
if err != nil { if err != nil {
return err return err
} }
_, err = c.db.Exec(deleteWebPushSubscriptionTopicWithoutSubscription) _, err = c.db.Exec(c.queries.deleteSubscriptionTopicWithoutSub)
return err return err
} }
// DB returns the underlying database connection (for testing)
func (c *webPushStore) DB() *sql.DB {
return c.db
}
// Close closes the underlying database connection // Close closes the underlying database connection
func (c *webPushStore) Close() error { func (c *webPushStore) Close() error {
return c.db.Close() return c.db.Close()

130
server/webpush_store_pg.go Normal file
View File

@@ -0,0 +1,130 @@
package server
import (
"database/sql"
"fmt"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
)
// PostgreSQL-specific queries
const (
pgCreateWebPushSubscriptionsTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at INT NOT NULL,
warned_at INT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic),
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schema_version (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
`
// Schema management queries
pgCurrentWebPushSchemaVersion = 1
pgInsertWebPushSchemaVersion = `INSERT INTO schema_version VALUES (1, $1)`
pgSelectWebPushSchemaVersionQuery = `SELECT version FROM schema_version WHERE id = 1`
)
// PostgreSQL-specific webpush queries
var pgWebPushQueries = &webPushQueries{
selectSubscriptionIDByEndpoint: `SELECT id FROM subscription WHERE endpoint = $1`,
selectSubscriptionCountBySubscriberIP: `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = $1`,
selectSubscriptionsForTopic: `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id
WHERE st.topic = $1
ORDER BY endpoint
`,
selectSubscriptionsExpiringSoon: `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription
WHERE warned_at = 0 AND updated_at <= $1
`,
insertSubscription: `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = EXCLUDED.key_auth, key_p256dh = EXCLUDED.key_p256dh, user_id = EXCLUDED.user_id, subscriber_ip = EXCLUDED.subscriber_ip, updated_at = EXCLUDED.updated_at, warned_at = EXCLUDED.warned_at
`,
updateSubscriptionWarningSent: `UPDATE subscription SET warned_at = $1 WHERE id = $2`,
deleteSubscriptionByEndpoint: `DELETE FROM subscription WHERE endpoint = $1`,
deleteSubscriptionByUserID: `DELETE FROM subscription WHERE user_id = $1`,
deleteSubscriptionByAge: `DELETE FROM subscription WHERE updated_at <= $1`,
insertSubscriptionTopic: `INSERT INTO subscription_topic (subscription_id, topic) VALUES ($1, $2)`,
deleteSubscriptionTopicAll: `DELETE FROM subscription_topic WHERE subscription_id = $1`,
deleteSubscriptionTopicWithoutSub: `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`,
}
func newPgWebPushStore(connStr, startupQueries string) (*webPushStore, error) {
db, err := sql.Open("pgx", connStr)
if err != nil {
return nil, err
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err)
}
if err := setupPgWebPushDB(db); err != nil {
return nil, err
}
if err := runPgWebPushStartupQueries(db, startupQueries); err != nil {
return nil, err
}
return &webPushStore{
db: db,
queries: pgWebPushQueries,
}, nil
}
func setupPgWebPushDB(db *sql.DB) error {
// If 'schema_version' table does not exist, this must be a new database
rows, err := db.Query(pgSelectWebPushSchemaVersionQuery)
if err != nil {
return setupNewPgWebPushDB(db)
}
defer rows.Close()
// If table exists but no rows, also create new
if !rows.Next() {
return setupNewPgWebPushDB(db)
}
return nil
}
func setupNewPgWebPushDB(db *sql.DB) error {
if _, err := db.Exec(pgCreateWebPushSubscriptionsTableQuery); err != nil {
return err
}
if _, err := db.Exec(pgInsertWebPushSchemaVersion, pgCurrentWebPushSchemaVersion); err != nil {
return err
}
return nil
}
func runPgWebPushStartupQueries(db *sql.DB, startupQueries string) error {
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,126 @@
package server
import (
"database/sql"
_ "github.com/mattn/go-sqlite3" // SQLite driver
)
// SQLite-specific queries
const (
sqliteCreateWebPushSubscriptionsTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at INT NOT NULL,
warned_at INT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic),
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
`
sqliteWebPushBuiltinStartupQueries = `
PRAGMA foreign_keys = ON;
`
// Schema management queries
sqliteCurrentWebPushSchemaVersion = 1
sqliteInsertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
sqliteSelectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
// SQLite-specific webpush queries
var sqliteWebPushQueries = &webPushQueries{
selectSubscriptionIDByEndpoint: `SELECT id FROM subscription WHERE endpoint = ?`,
selectSubscriptionCountBySubscriberIP: `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`,
selectSubscriptionsForTopic: `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id
WHERE st.topic = ?
ORDER BY endpoint
`,
selectSubscriptionsExpiringSoon: `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription
WHERE warned_at = 0 AND updated_at <= ?
`,
insertSubscription: `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
`,
updateSubscriptionWarningSent: `UPDATE subscription SET warned_at = ? WHERE id = ?`,
deleteSubscriptionByEndpoint: `DELETE FROM subscription WHERE endpoint = ?`,
deleteSubscriptionByUserID: `DELETE FROM subscription WHERE user_id = ?`,
deleteSubscriptionByAge: `DELETE FROM subscription WHERE updated_at <= ?`,
insertSubscriptionTopic: `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`,
deleteSubscriptionTopicAll: `DELETE FROM subscription_topic WHERE subscription_id = ?`,
deleteSubscriptionTopicWithoutSub: `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`,
}
func newSqliteWebPushStore(filename, startupQueries string) (*webPushStore, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupSqliteWebPushDB(db); err != nil {
return nil, err
}
if err := runSqliteWebPushStartupQueries(db, startupQueries); err != nil {
return nil, err
}
return &webPushStore{
db: db,
queries: sqliteWebPushQueries,
}, nil
}
func setupSqliteWebPushDB(db *sql.DB) error {
// If 'schemaVersion' table does not exist, this must be a new database
rows, err := db.Query(sqliteSelectWebPushSchemaVersionQuery)
if err != nil {
return setupNewSqliteWebPushDB(db)
}
return rows.Close()
}
func setupNewSqliteWebPushDB(db *sql.DB) error {
if _, err := db.Exec(sqliteCreateWebPushSubscriptionsTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteInsertWebPushSchemaVersion, sqliteCurrentWebPushSchemaVersion); err != nil {
return err
}
return nil
}
func runSqliteWebPushStartupQueries(db *sql.DB, startupQueries string) error {
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
if _, err := db.Exec(sqliteWebPushBuiltinStartupQueries); err != nil {
return err
}
return nil
}

View File

@@ -134,7 +134,7 @@ func TestWebPushStore_MarkExpiryWarningSent(t *testing.T) {
// Mark them as warning sent // Mark them as warning sent
require.Nil(t, webPush.MarkExpiryWarningSent(subs)) require.Nil(t, webPush.MarkExpiryWarningSent(subs))
rows, err := webPush.db.Query("SELECT endpoint FROM subscription WHERE warned_at > 0") rows, err := webPush.DB().Query("SELECT endpoint FROM subscription WHERE warned_at > 0")
require.Nil(t, err) require.Nil(t, err)
defer rows.Close() defer rows.Close()
var endpoint string var endpoint string
@@ -156,7 +156,7 @@ func TestWebPushStore_SubscriptionsExpiring(t *testing.T) {
require.Len(t, subs, 1) require.Len(t, subs, 1)
// Fake-mark them as soon-to-expire // Fake-mark them as soon-to-expire
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint) _, err = webPush.DB().Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint)
require.Nil(t, err) require.Nil(t, err)
// Should not be cleaned up yet // Should not be cleaned up yet
@@ -180,7 +180,7 @@ func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) {
require.Len(t, subs, 1) require.Len(t, subs, 1)
// Fake-mark them as expired // Fake-mark them as expired
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint) _, err = webPush.DB().Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint)
require.Nil(t, err) require.Nil(t, err)
// Run expiration // Run expiration
@@ -192,7 +192,7 @@ func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) {
require.Len(t, subs, 0) require.Len(t, subs, 0)
} }
func newTestWebPushStore(t *testing.T) *webPushStore { func newTestWebPushStore(t *testing.T) WebPushStore {
webPush, err := newWebPushStore(filepath.Join(t.TempDir(), "webpush.db"), "") webPush, err := newWebPushStore(filepath.Join(t.TempDir(), "webpush.db"), "")
require.Nil(t, err) require.Nil(t, err)
return webPush return webPush

View File

@@ -6,17 +6,17 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/netip"
"slices"
"strings"
"sync"
"time"
"github.com/mattn/go-sqlite3" "github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/payments" "heckel.io/ntfy/v2/payments"
"heckel.io/ntfy/v2/util" "heckel.io/ntfy/v2/util"
"net/netip"
"path/filepath"
"slices"
"strings"
"sync"
"time"
) )
const ( const (
@@ -43,100 +43,10 @@ const (
var ( var (
errNoTokenProvided = errors.New("no token provided") errNoTokenProvided = errors.New("no token provided")
errTopicOwnedByOthers = errors.New("topic owned by others") errTopicOwnedByOthers = errors.New("topic owned by others")
errNoRows = errors.New("no rows found")
) )
// Manager-related queries // Manager-related queries
const ( const (
createTablesQueries = `
BEGIN;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
calls_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_monthly_price_id TEXT,
stripe_yearly_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
COMMIT;
`
builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
selectUserByIDQuery = ` selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u FROM user u
@@ -326,229 +236,6 @@ const (
` `
) )
// Schema management queries
const (
currentSchemaVersion = 6
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
// 1 -> 2 (complex migration!)
migrate1To2CreateTablesQueries = `
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
migrate1To2InsertUserNoTx = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
`
migrate1To2InsertFromOldTablesAndDropNoTx = `
INSERT INTO user_access (user_id, topic, read, write)
SELECT u.id, a.topic, a.read, a.write
FROM user u
JOIN access a ON u.user = a.user;
DROP TABLE access;
DROP TABLE user_old;
`
// 2 -> 3
migrate2To3UpdateQueries = `
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
DROP INDEX IF EXISTS idx_tier_price_id;
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
`
// 3 -> 4
migrate3To4UpdateQueries = `
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
`
// 4 -> 5
migrate4To5UpdateQueries = `
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
`
// 5 -> 6
migrate5To6UpdateQueries = `
PRAGMA foreign_keys=off;
-- Alter user table: Add provisioned column
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
INSERT INTO user
SELECT
id,
tier_id,
user,
pass,
role,
prefs,
sync_topic,
0, -- provisioned
stats_messages,
stats_emails,
stats_calls,
stripe_customer_id,
stripe_subscription_id,
stripe_subscription_status,
stripe_subscription_interval,
stripe_subscription_paid_until,
stripe_subscription_cancel_at,
created,
deleted
FROM user_old;
DROP TABLE user_old;
-- Alter user_access table: Add provisioned column
ALTER TABLE user_access RENAME TO user_access_old;
CREATE TABLE user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INTEGER NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
DROP TABLE user_access_old;
-- Alter user_token table: Add provisioned column
ALTER TABLE user_token RENAME TO user_token_old;
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
INSERT INTO user_token SELECT *, 0 FROM user_token_old;
DROP TABLE user_token_old;
-- Recreate indices
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
-- Re-enable foreign keys
PRAGMA foreign_keys=on;
`
)
var (
migrations = map[int]func(db *sql.DB) error{
1: migrateFrom1,
2: migrateFrom2,
3: migrateFrom3,
4: migrateFrom4,
5: migrateFrom5,
}
)
// Manager is an implementation of Manager. It stores users and access control list // Manager is an implementation of Manager. It stores users and access control list
// in a SQLite database. // in a SQLite database.
type Manager struct { type Manager struct {
@@ -583,28 +270,20 @@ func NewManager(config *Config) (*Manager, error) {
if config.QueueWriterInterval.Seconds() <= 0 { if config.QueueWriterInterval.Seconds() <= 0 {
config.QueueWriterInterval = DefaultUserStatsQueueWriterInterval config.QueueWriterInterval = DefaultUserStatsQueueWriterInterval
} }
// Check the parent directory of the database file (makes for friendly error messages)
parentDir := filepath.Dir(config.Filename) var manager *Manager
if !util.FileExists(parentDir) { var err error
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
// Select database backend based on connection string
if strings.HasPrefix(config.Filename, "postgres:") {
manager, err = newPgManager(config)
} else {
manager, err = newSqliteManager(config)
} }
// Open DB and run setup queries
db, err := sql.Open("sqlite3", config.Filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := setupDB(db); err != nil {
return nil, err
}
if err := runStartupQueries(db, config.StartupQueries); err != nil {
return nil, err
}
manager := &Manager{
db: db,
config: config,
statsQueue: make(map[string]*Stats),
tokenQueue: make(map[string]*TokenUpdate),
}
if err := manager.maybeProvisionUsersAccessAndTokens(); err != nil { if err := manager.maybeProvisionUsersAccessAndTokens(); err != nil {
return nil, err return nil, err
} }
@@ -1929,173 +1608,6 @@ func unescapeUnderscore(s string) string {
return strings.ReplaceAll(s, "\\_", "_") return strings.ReplaceAll(s, "\\_", "_")
} }
func runStartupQueries(db *sql.DB, startupQueries string) error {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
if _, err := db.Exec(builtinStartupQueries); err != nil {
return err
}
return nil
}
func setupDB(db *sql.DB) error {
// If 'schemaVersion' table does not exist, this must be a new database
rowsSV, err := db.Query(selectSchemaVersionQuery)
if err != nil {
return setupNewDB(db)
}
defer rowsSV.Close()
// If 'schemaVersion' table exists, read version and potentially upgrade
schemaVersion := 0
if !rowsSV.Next() {
return errors.New("cannot determine schema version: database file may be corrupt")
}
if err := rowsSV.Scan(&schemaVersion); err != nil {
return err
}
rowsSV.Close()
// Do migrations
if schemaVersion == currentSchemaVersion {
return nil
} else if schemaVersion > currentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
}
for i := schemaVersion; i < currentSchemaVersion; i++ {
fn, ok := migrations[i]
if !ok {
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
} else if err := fn(db); err != nil {
return err
}
}
return nil
}
func setupNewDB(db *sql.DB) error {
if _, err := db.Exec(createTablesQueries); err != nil {
return err
}
if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
return err
}
return nil
}
func migrateFrom1(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Rename user -> user_old, and create new tables
if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
return err
}
// Insert users from user_old into new user table, with ID and sync_topic
rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
if err != nil {
return err
}
defer rows.Close()
usernames := make([]string, 0)
for rows.Next() {
var username string
if err := rows.Scan(&username); err != nil {
return err
}
usernames = append(usernames, username)
}
if err := rows.Close(); err != nil {
return err
}
for _, username := range usernames {
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
return err
}
}
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func migrateFrom2(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 3); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom3(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 4); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom4(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 5); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom5(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 6); err != nil {
return err
}
return tx.Commit()
}
func nullString(s string) sql.NullString { func nullString(s string) sql.NullString {
if s == "" { if s == "" {
return sql.NullString{} return sql.NullString{}

175
user/manager_pg.go Normal file
View File

@@ -0,0 +1,175 @@
package user
import (
"database/sql"
"fmt"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
"heckel.io/ntfy/v2/log"
)
// PostgreSQL-specific queries
const (
pgCreateTablesQueries = `
BEGIN;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
calls_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_monthly_price_id TEXT,
stripe_yearly_price_id TEXT
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_tier_code ON tier (code);
CREATE UNIQUE INDEX IF NOT EXISTS idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX IF NOT EXISTS idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
CREATE TABLE IF NOT EXISTS "user" (
id TEXT PRIMARY KEY,
tier_id TEXT,
"user" TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT 0,
stats_emails INT NOT NULL DEFAULT 0,
stats_calls INT NOT NULL DEFAULT 0,
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_user ON "user" ("user");
CREATE UNIQUE INDEX IF NOT EXISTS idx_user_stripe_customer_id ON "user" (stripe_customer_id);
CREATE UNIQUE INDEX IF NOT EXISTS idx_user_stripe_subscription_id ON "user" (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id TEXT,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES "user" (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES "user" (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES "user" (id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_user_token ON user_token (token);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES "user" (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schema_version (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO "user" (id, "user", pass, role, sync_topic, provisioned, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', 0, EXTRACT(EPOCH FROM NOW())::INT)
ON CONFLICT (id) DO NOTHING;
COMMIT;
`
pgCurrentSchemaVersion = 1
pgInsertSchemaVersion = `INSERT INTO schema_version VALUES (1, $1)`
pgUpdateSchemaVersion = `UPDATE schema_version SET version = $1 WHERE id = 1`
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE id = 1`
)
// newPgManager creates a new PostgreSQL-backed user manager
func newPgManager(config *Config) (*Manager, error) {
db, err := sql.Open("pgx", config.Filename)
if err != nil {
return nil, err
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err)
}
if err := setupPgDB(db); err != nil {
return nil, err
}
if err := runPgStartupQueries(db, config.StartupQueries); err != nil {
return nil, err
}
return &Manager{
config: config,
db: db,
statsQueue: make(map[string]*Stats),
tokenQueue: make(map[string]*TokenUpdate),
}, nil
}
func runPgStartupQueries(db *sql.DB, startupQueries string) error {
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
return nil
}
func setupPgDB(db *sql.DB) error {
// If 'schema_version' table does not exist, this must be a new database
rowsSV, err := db.Query(pgSelectSchemaVersionQuery)
if err != nil {
return setupNewPgDB(db)
}
defer rowsSV.Close()
// If 'schema_version' table exists, read version and potentially upgrade
schemaVersion := 0
if !rowsSV.Next() {
// Table exists but no rows, insert version
return setupNewPgDB(db)
}
if err := rowsSV.Scan(&schemaVersion); err != nil {
return err
}
rowsSV.Close()
// Do migrations
if schemaVersion == pgCurrentSchemaVersion {
return nil
} else if schemaVersion > pgCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
}
// No migrations needed yet for PG (starting at version 1)
log.Tag(tag).Info("PostgreSQL user database schema is up to date (version %d)", schemaVersion)
return nil
}
func setupNewPgDB(db *sql.DB) error {
if _, err := db.Exec(pgCreateTablesQueries); err != nil {
return err
}
if _, err := db.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil {
return err
}
return nil
}

532
user/manager_sqlite.go Normal file
View File

@@ -0,0 +1,532 @@
package user
import (
"database/sql"
"errors"
"fmt"
"path/filepath"
"github.com/mattn/go-sqlite3"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
)
var (
errNoRows = errors.New("no rows found")
)
// SQLite-specific queries
const (
sqliteCreateTablesQueries = `
BEGIN;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
calls_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_monthly_price_id TEXT,
stripe_yearly_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
COMMIT;
`
sqliteBuiltinStartupQueries = `
PRAGMA foreign_keys = ON;
`
)
// Schema management queries
const (
sqliteCurrentSchemaVersion = 6
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
// Migration queries (1->2 through 5->6)
// These are SQLite-specific due to PRAGMA, ALTER TABLE syntax, etc.
migrate1To2CreateTablesQueries = `
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
migrate1To2InsertUserNoTx = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
`
migrate1To2InsertFromOldTablesAndDropNoTx = `
INSERT INTO user_access (user_id, topic, read, write)
SELECT u.id, a.topic, a.read, a.write
FROM user u
JOIN access a ON u.user = a.user;
DROP TABLE access;
DROP TABLE user_old;
`
migrate2To3UpdateQueries = `
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
DROP INDEX IF EXISTS idx_tier_price_id;
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
`
migrate3To4UpdateQueries = `
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
`
migrate4To5UpdateQueries = `
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
`
migrate5To6UpdateQueries = `
PRAGMA foreign_keys=off;
-- Alter user table: Add provisioned column
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
INSERT INTO user
SELECT
id,
tier_id,
user,
pass,
role,
prefs,
sync_topic,
0, -- provisioned
stats_messages,
stats_emails,
stats_calls,
stripe_customer_id,
stripe_subscription_id,
stripe_subscription_status,
stripe_subscription_interval,
stripe_subscription_paid_until,
stripe_subscription_cancel_at,
created,
deleted
FROM user_old;
DROP TABLE user_old;
-- Alter user_access table: Add provisioned column
ALTER TABLE user_access RENAME TO user_access_old;
CREATE TABLE user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INTEGER NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
DROP TABLE user_access_old;
-- Alter user_token table: Add provisioned column
ALTER TABLE user_token RENAME TO user_token_old;
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
INSERT INTO user_token SELECT *, 0 FROM user_token_old;
DROP TABLE user_token_old;
-- Recreate indices
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
-- Re-enable foreign keys
PRAGMA foreign_keys=on;
`
)
var (
sqliteMigrations = map[int]func(db *sql.DB) error{
1: migrateFrom1,
2: migrateFrom2,
3: migrateFrom3,
4: migrateFrom4,
5: migrateFrom5,
}
)
// newSqliteManager creates a new SQLite-backed Manager
func newSqliteManager(config *Config) (*Manager, error) {
// Check the parent directory of the database file (makes for friendly error messages)
parentDir := filepath.Dir(config.Filename)
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
}
// Open DB and run setup queries
db, err := sql.Open("sqlite3", config.Filename)
if err != nil {
return nil, err
}
if err := setupSqliteDB(db); err != nil {
return nil, err
}
if err := runSqliteStartupQueries(db, config.StartupQueries); err != nil {
return nil, err
}
return &Manager{
db: db,
config: config,
statsQueue: make(map[string]*Stats),
tokenQueue: make(map[string]*TokenUpdate),
}, nil
}
func runSqliteStartupQueries(db *sql.DB, startupQueries string) error {
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil {
return err
}
return nil
}
func setupSqliteDB(db *sql.DB) error {
// If 'schemaVersion' table does not exist, this must be a new database
rowsSV, err := db.Query(sqliteSelectSchemaVersionQuery)
if err != nil {
return setupNewSqliteDB(db)
}
defer rowsSV.Close()
// If 'schemaVersion' table exists, read version and potentially upgrade
schemaVersion := 0
if !rowsSV.Next() {
return errors.New("cannot determine schema version: database file may be corrupt")
}
if err := rowsSV.Scan(&schemaVersion); err != nil {
return err
}
rowsSV.Close()
// Do migrations
if schemaVersion == sqliteCurrentSchemaVersion {
return nil
} else if schemaVersion > sqliteCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
}
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
fn, ok := sqliteMigrations[i]
if !ok {
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
} else if err := fn(db); err != nil {
return err
}
}
return nil
}
func setupNewSqliteDB(db *sql.DB) error {
if _, err := db.Exec(sqliteCreateTablesQueries); err != nil {
return err
}
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
return err
}
return nil
}
func migrateFrom1(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Rename user -> user_old, and create new tables
if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
return err
}
// Insert users from user_old into new user table, with ID and sync_topic
rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
if err != nil {
return err
}
defer rows.Close()
usernames := make([]string, 0)
for rows.Next() {
var username string
if err := rows.Scan(&username); err != nil {
return err
}
usernames = append(usernames, username)
}
if err := rows.Close(); err != nil {
return err
}
for _, username := range usernames {
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
return err
}
}
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func migrateFrom2(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom3(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom4(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom5(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
return err
}
return tx.Commit()
}
// isSqliteConstraintUniqueError checks if the error is a SQLite unique constraint error
func isSqliteConstraintUniqueError(err error) bool {
if sqliteErr, ok := err.(sqlite3.Error); ok && sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique {
return true
}
return errors.Is(err, sqlite3.ErrConstraintUnique)
}

View File

@@ -1578,7 +1578,7 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) {
var schemaVersion int var schemaVersion int
require.Nil(t, rows.Scan(&schemaVersion)) require.Nil(t, rows.Scan(&schemaVersion))
require.Equal(t, currentSchemaVersion, schemaVersion) require.Equal(t, sqliteCurrentSchemaVersion, schemaVersion)
require.Nil(t, rows.Close()) require.Nil(t, rows.Close())
} }