Compare commits

..

2 Commits

Author SHA1 Message Date
binwiederhier
2a940ad289 Show provisioned users 2025-12-30 11:30:36 -05:00
binwiederhier
75b2ca7dec Admin web app 2025-12-30 11:10:41 -05:00
27 changed files with 2041 additions and 1994 deletions

View File

@@ -181,7 +181,6 @@ I've added a ⭐ to projects or posts that have a significant following, or had
- [ntfy-heartbeat-monitor](https://codeberg.org/RockWolf/ntfy-heartbeat-monitor) - Application for implementing heartbeat monitoring/alerting by utilizing ntfy
- [ntfy-bridge](https://github.com/AlexGaudon/ntfy-bridge) - An application to bridge Discord messages (or webhooks) to ntfy.
- [ntailfy](https://github.com/leukosaima/ntailfy) - ntfy notifications when Tailscale devices connect/disconnect (Go)
- [BRun](https://github.com/cbrake/brun) - Native Linux automation platform connecting triggers to actions without containers (Go)
## Blog + forum posts

4
go.mod
View File

@@ -32,7 +32,6 @@ require github.com/pkg/errors v0.9.1 // indirect
require (
firebase.google.com/go/v4 v4.18.0
github.com/SherClockHolmes/webpush-go v1.4.0
github.com/jackc/pgx/v5 v5.8.0
github.com/microcosm-cc/bluemonday v1.0.27
github.com/prometheus/client_golang v1.23.2
github.com/stripe/stripe-go/v74 v74.30.0
@@ -73,9 +72,6 @@ require (
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
github.com/googleapis/gax-go/v2 v2.16.0 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect

9
go.sum
View File

@@ -104,14 +104,6 @@ github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -152,7 +144,6 @@ github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xI
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=

View File

@@ -4,89 +4,351 @@ import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/netip"
"path/filepath"
"strings"
"sync"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
)
// MessageCache is the interface for message storage backends
type MessageCache interface {
AddMessage(m *message) error
AddMessages(ms []*message) error
Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error)
MessagesDue() ([]*message, error)
MessagesExpired() ([]string, error)
Message(id string) (*message, error)
MarkPublished(m *message) error
MessageCounts() (map[string]int, error)
Topics() (map[string]*topic, error)
DeleteMessages(ids ...string) error
ExpireMessages(topics ...string) error
AttachmentsExpired() ([]string, error)
MarkAttachmentsDeleted(ids ...string) error
AttachmentBytesUsedBySender(sender string) (int64, error)
AttachmentBytesUsedByUser(userID string) (int64, error)
UpdateStats(messages int64) error
Stats() (messages int64, err error)
DB() *sql.DB
Close() error
var (
errUnexpectedMessageType = errors.New("unexpected message type")
errMessageNotFound = errors.New("message not found")
errNoRows = errors.New("no rows found")
)
// Messages cache
const (
createMessagesTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
time INT NOT NULL,
expires INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_deleted INT NOT NULL,
sender TEXT NOT NULL,
user TEXT NOT NULL,
content_type TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
COMMIT;
`
insertMessageQuery = `
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
deleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
selectMessagesByIDQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE mid = ?
`
selectMessagesSinceTimeQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ? AND published = 1
ORDER BY time, id
`
selectMessagesSinceTimeIncludeScheduledQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ?
ORDER BY time, id
`
selectMessagesSinceIDQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND id > ? AND published = 1
ORDER BY time, id
`
selectMessagesSinceIDIncludeScheduledQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND (id > ? OR published = 0)
ORDER BY time, id
`
selectMessagesLatestQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND published = 1
ORDER BY time DESC, id DESC
LIMIT 1
`
selectMessagesDueQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE time <= ? AND published = 0
ORDER BY time, id
`
selectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
selectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
updateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
)
// Schema management queries
const (
currentSchemaVersion = 13
createSchemaVersionTableQuery = `
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
`
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
// 0 -> 1
migrate0To1AlterMessagesTableQuery = `
BEGIN;
ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0);
ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT('');
COMMIT;
`
// 1 -> 2
migrate1To2AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1);
`
// 2 -> 3
migrate2To3AlterMessagesTableQuery = `
BEGIN;
ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT('');
COMMIT;
`
// 3 -> 4
migrate3To4AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT('');
`
// 4 -> 5
migrate4To5AlterMessagesTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages_new (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
time INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_owner TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid);
CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic);
INSERT
INTO messages_new (
mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published)
SELECT
id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published
FROM messages;
DROP TABLE messages;
ALTER TABLE messages_new RENAME TO messages;
COMMIT;
`
// 5 -> 6
migrate5To6AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT('');
`
// 6 -> 7
migrate6To7AlterMessagesTableQuery = `
ALTER TABLE messages RENAME COLUMN attachment_owner TO sender;
`
// 7 -> 8
migrate7To8AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT('');
`
// 8 -> 9
migrate8To9AlterMessagesTableQuery = `
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
`
// 9 -> 10
migrate9To10AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
`
migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
// 10 -> 11
migrate10To11AlterMessagesTableQuery = `
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
`
// 11 -> 12
migrate11To12AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
`
// 12 -> 13
migrate12To13AlterMessagesTableQuery = `
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
`
)
var (
migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
0: migrateFrom0,
1: migrateFrom1,
2: migrateFrom2,
3: migrateFrom3,
4: migrateFrom4,
5: migrateFrom5,
6: migrateFrom6,
7: migrateFrom7,
8: migrateFrom8,
9: migrateFrom9,
10: migrateFrom10,
11: migrateFrom11,
12: migrateFrom12,
}
)
type messageCache struct {
db *sql.DB
queue *util.BatchingQueue[*message]
nop bool
mu sync.Mutex
}
// commonMessageCache contains shared logic for all message cache implementations
type commonMessageCache struct {
db *sql.DB
queue *util.BatchingQueue[*message]
queries *messageCacheQueries
nop bool // If true, cache ignores all messages
mu sync.Mutex // Lock for concurrent access
// newSqliteCache creates a SQLite file-backed cache
func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*messageCache, error) {
// Check the parent directory of the database file (makes for friendly error messages)
parentDir := filepath.Dir(filename)
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
}
// Open database
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupMessagesDB(db, startupQueries, cacheDuration); err != nil {
return nil, err
}
var queue *util.BatchingQueue[*message]
if batchSize > 0 || batchTimeout > 0 {
queue = util.NewBatchingQueue[*message](batchSize, batchTimeout)
}
cache := &messageCache{
db: db,
queue: queue,
nop: nop,
}
go cache.processMessageBatches()
return cache, nil
}
var _ MessageCache = (*commonMessageCache)(nil)
// newMemCache creates an in-memory cache
func newMemCache() (*messageCache, error) {
return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, false)
}
// messageCacheQueries holds database-specific SQL queries
type messageCacheQueries struct {
insertMessage string
deleteMessage string
updateMessagesForTopicExpiry string
selectRowIDFromMessageID string // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
selectMessagesByID string
selectMessagesSinceTime string
selectMessagesSinceTimeIncludeScheduled string
selectMessagesSinceID string
selectMessagesSinceIDIncludeScheduled string
selectMessagesLatest string
selectMessagesDue string
selectMessagesExpired string
updateMessagePublished string
selectMessageCountPerTopic string
selectTopics string
// newNopCache creates an in-memory cache that discards all messages;
// it is always empty and can be used if caching is entirely disabled
func newNopCache() (*messageCache, error) {
return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, true)
}
updateAttachmentDeleted string
selectAttachmentsExpired string
selectAttachmentsSizeBySender string
selectAttachmentsSizeByUserID string
selectStats string
updateStats string
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory
// sql database, so if the stdlib's sql engine happens to open another connection and
// you've only specified ":memory:", that connection will see a brand new database.
// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared").
// Every connection to this string will point to the same in-memory database."
func createMemoryFilename() string {
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
}
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asyncronously.
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
func (c *commonMessageCache) AddMessage(m *message) error {
func (c *messageCache) AddMessage(m *message) error {
if c.queue != nil {
c.queue.Enqueue(m)
return nil
}
return c.AddMessages([]*message{m})
return c.addMessages([]*message{m})
}
// AddMessages synchronously stores a batch of messages. If the database is locked, the transaction waits until
// the timeout is exceeded before erroring out.
func (c *commonMessageCache) AddMessages(ms []*message) error {
// addMessages synchronously stores a match of messages. If the database is locked, the transaction waits until
// SQLite's busy_timeout is exceeded before erroring out.
func (c *messageCache) addMessages(ms []*message) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.nop {
@@ -101,7 +363,7 @@ func (c *commonMessageCache) AddMessages(ms []*message) error {
return err
}
defer tx.Rollback()
stmt, err := tx.Prepare(c.queries.insertMessage)
stmt, err := tx.Prepare(insertMessageQuery)
if err != nil {
return err
}
@@ -113,8 +375,7 @@ func (c *commonMessageCache) AddMessages(ms []*message) error {
published := m.Time <= time.Now().Unix()
tags := strings.Join(m.Tags, ",")
var attachmentName, attachmentType, attachmentURL string
var attachmentSize, attachmentExpires int64
var attachmentDeleted bool
var attachmentSize, attachmentExpires, attachmentDeleted int64
if m.Attachment != nil {
attachmentName = m.Attachment.Name
attachmentType = m.Attachment.Type
@@ -151,7 +412,7 @@ func (c *commonMessageCache) AddMessages(ms []*message) error {
attachmentSize,
attachmentExpires,
attachmentURL,
attachmentDeleted, // Always false
attachmentDeleted, // Always zero
sender,
m.User,
m.ContentType,
@@ -170,7 +431,7 @@ func (c *commonMessageCache) AddMessages(ms []*message) error {
return nil
}
func (c *commonMessageCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
func (c *messageCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
if since.IsNone() {
return make([]*message, 0), nil
} else if since.IsLatest() {
@@ -181,21 +442,13 @@ func (c *commonMessageCache) Messages(topic string, since sinceMarker, scheduled
return c.messagesSinceTime(topic, since, scheduled)
}
func (c *commonMessageCache) messagesLatest(topic string) ([]*message, error) {
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
if err != nil {
return nil, err
}
return readMessages(rows)
}
func (c *commonMessageCache) messagesSinceTime(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
func (c *messageCache) messagesSinceTime(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
var rows *sql.Rows
var err error
if scheduled {
rows, err = c.db.Query(c.queries.selectMessagesSinceTimeIncludeScheduled, topic, since.Time().Unix())
rows, err = c.db.Query(selectMessagesSinceTimeIncludeScheduledQuery, topic, since.Time().Unix())
} else {
rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
rows, err = c.db.Query(selectMessagesSinceTimeQuery, topic, since.Time().Unix())
}
if err != nil {
return nil, err
@@ -203,8 +456,8 @@ func (c *commonMessageCache) messagesSinceTime(topic string, since sinceMarker,
return readMessages(rows)
}
func (c *commonMessageCache) messagesSinceID(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
idrows, err := c.db.Query(c.queries.selectRowIDFromMessageID, since.ID())
func (c *messageCache) messagesSinceID(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
idrows, err := c.db.Query(selectRowIDFromMessageID, since.ID())
if err != nil {
return nil, err
}
@@ -219,9 +472,9 @@ func (c *commonMessageCache) messagesSinceID(topic string, since sinceMarker, sc
idrows.Close()
var rows *sql.Rows
if scheduled {
rows, err = c.db.Query(c.queries.selectMessagesSinceIDIncludeScheduled, topic, rowID)
rows, err = c.db.Query(selectMessagesSinceIDIncludeScheduledQuery, topic, rowID)
} else {
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, rowID)
rows, err = c.db.Query(selectMessagesSinceIDQuery, topic, rowID)
}
if err != nil {
return nil, err
@@ -229,17 +482,25 @@ func (c *commonMessageCache) messagesSinceID(topic string, since sinceMarker, sc
return readMessages(rows)
}
func (c *commonMessageCache) MessagesDue() ([]*message, error) {
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
func (c *messageCache) messagesLatest(topic string) ([]*message, error) {
rows, err := c.db.Query(selectMessagesLatestQuery, topic)
if err != nil {
return nil, err
}
return readMessages(rows)
}
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
func (c *commonMessageCache) MessagesExpired() ([]string, error) {
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
func (c *messageCache) MessagesDue() ([]*message, error) {
rows, err := c.db.Query(selectMessagesDueQuery, time.Now().Unix())
if err != nil {
return nil, err
}
return readMessages(rows)
}
// MessagesExpired returns a list of IDs for messages that have expires (should be deleted)
func (c *messageCache) MessagesExpired() ([]string, error) {
rows, err := c.db.Query(selectMessagesExpiredQuery, time.Now().Unix())
if err != nil {
return nil, err
}
@@ -258,24 +519,27 @@ func (c *commonMessageCache) MessagesExpired() ([]string, error) {
return ids, nil
}
func (c *commonMessageCache) Message(id string) (*message, error) {
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
func (c *messageCache) Message(id string) (*message, error) {
rows, err := c.db.Query(selectMessagesByIDQuery, id)
if err != nil {
return nil, err
} else if !rows.Next() {
}
if !rows.Next() {
return nil, errMessageNotFound
}
defer rows.Close()
return readMessage(rows)
}
func (c *commonMessageCache) MarkPublished(m *message) error {
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
func (c *messageCache) MarkPublished(m *message) error {
c.mu.Lock()
defer c.mu.Unlock()
_, err := c.db.Exec(updateMessagePublishedQuery, m.ID)
return err
}
func (c *commonMessageCache) MessageCounts() (map[string]int, error) {
rows, err := c.db.Query(c.queries.selectMessageCountPerTopic)
func (c *messageCache) MessageCounts() (map[string]int, error) {
rows, err := c.db.Query(selectMessageCountPerTopicQuery)
if err != nil {
return nil, err
}
@@ -294,8 +558,8 @@ func (c *commonMessageCache) MessageCounts() (map[string]int, error) {
return counts, nil
}
func (c *commonMessageCache) Topics() (map[string]*topic, error) {
rows, err := c.db.Query(c.queries.selectTopics)
func (c *messageCache) Topics() (map[string]*topic, error) {
rows, err := c.db.Query(selectTopicsQuery)
if err != nil {
return nil, err
}
@@ -314,36 +578,40 @@ func (c *commonMessageCache) Topics() (map[string]*topic, error) {
return topics, nil
}
func (c *commonMessageCache) DeleteMessages(ids ...string) error {
func (c *messageCache) DeleteMessages(ids ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, id := range ids {
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
if _, err := tx.Exec(deleteMessageQuery, id); err != nil {
return err
}
}
return tx.Commit()
}
func (c *commonMessageCache) ExpireMessages(topics ...string) error {
func (c *messageCache) ExpireMessages(topics ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, t := range topics {
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
if _, err := tx.Exec(updateMessagesForTopicExpiryQuery, time.Now().Unix()-1, t); err != nil {
return err
}
}
return tx.Commit()
}
func (c *commonMessageCache) AttachmentsExpired() ([]string, error) {
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
func (c *messageCache) AttachmentsExpired() ([]string, error) {
rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix())
if err != nil {
return nil, err
}
@@ -362,37 +630,39 @@ func (c *commonMessageCache) AttachmentsExpired() ([]string, error) {
return ids, nil
}
func (c *commonMessageCache) MarkAttachmentsDeleted(ids ...string) error {
func (c *messageCache) MarkAttachmentsDeleted(ids ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, id := range ids {
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
if _, err := tx.Exec(updateAttachmentDeleted, id); err != nil {
return err
}
}
return tx.Commit()
}
func (c *commonMessageCache) AttachmentBytesUsedBySender(sender string) (int64, error) {
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
func (c *messageCache) AttachmentBytesUsedBySender(sender string) (int64, error) {
rows, err := c.db.Query(selectAttachmentsSizeBySenderQuery, sender, time.Now().Unix())
if err != nil {
return 0, err
}
return c.readAttachmentBytesUsed(rows)
}
func (c *commonMessageCache) AttachmentBytesUsedByUser(userID string) (int64, error) {
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
func (c *messageCache) AttachmentBytesUsedByUser(userID string) (int64, error) {
rows, err := c.db.Query(selectAttachmentsSizeByUserIDQuery, userID, time.Now().Unix())
if err != nil {
return 0, err
}
return c.readAttachmentBytesUsed(rows)
}
func (c *commonMessageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
func (c *messageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
defer rows.Close()
var size int64
if !rows.Next() {
@@ -406,45 +676,17 @@ func (c *commonMessageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, err
return size, nil
}
func (c *commonMessageCache) processMessageBatches() {
func (c *messageCache) processMessageBatches() {
if c.queue == nil {
return
}
for messages := range c.queue.Dequeue() {
if err := c.AddMessages(messages); err != nil {
if err := c.addMessages(messages); err != nil {
log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch")
}
}
}
func (c *commonMessageCache) UpdateStats(messages int64) error {
_, err := c.db.Exec(c.queries.updateStats, messages)
return err
}
func (c *commonMessageCache) Stats() (messages int64, err error) {
rows, err := c.db.Query(c.queries.selectStats)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
if err := rows.Scan(&messages); err != nil {
return 0, err
}
return messages, nil
}
func (c *commonMessageCache) DB() *sql.DB {
return c.db
}
func (c *commonMessageCache) Close() error {
return c.db.Close()
}
func readMessages(rows *sql.Rows) ([]*message, error) {
defer rows.Close()
messages := make([]*message, 0)
@@ -534,3 +776,257 @@ func readMessage(rows *sql.Rows) (*message, error) {
Encoding: encoding,
}, nil
}
func (c *messageCache) UpdateStats(messages int64) error {
c.mu.Lock()
defer c.mu.Unlock()
_, err := c.db.Exec(updateStatsQuery, messages)
return err
}
func (c *messageCache) Stats() (messages int64, err error) {
rows, err := c.db.Query(selectStatsQuery)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
if err := rows.Scan(&messages); err != nil {
return 0, err
}
return messages, nil
}
func (c *messageCache) Close() error {
return c.db.Close()
}
func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
// Run startup queries
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
// If 'messages' table does not exist, this must be a new database
rowsMC, err := db.Query(selectMessagesCountQuery)
if err != nil {
return setupNewCacheDB(db)
}
rowsMC.Close()
// If 'messages' table exists, check 'schemaVersion' table
schemaVersion := 0
rowsSV, err := db.Query(selectSchemaVersionQuery)
if err == nil {
defer rowsSV.Close()
if !rowsSV.Next() {
return errors.New("cannot determine schema version: cache file may be corrupt")
}
if err := rowsSV.Scan(&schemaVersion); err != nil {
return err
}
rowsSV.Close()
}
// Do migrations
if schemaVersion == currentSchemaVersion {
return nil
} else if schemaVersion > currentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
}
for i := schemaVersion; i < currentSchemaVersion; i++ {
fn, ok := migrations[i]
if !ok {
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
} else if err := fn(db, cacheDuration); err != nil {
return err
}
}
return nil
}
func setupNewCacheDB(db *sql.DB) error {
if _, err := db.Exec(createMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(createSchemaVersionTableQuery); err != nil {
return err
}
if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
return err
}
return nil
}
func migrateFrom0(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1")
if _, err := db.Exec(migrate0To1AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(createSchemaVersionTableQuery); err != nil {
return err
}
if _, err := db.Exec(insertSchemaVersion, 1); err != nil {
return err
}
return nil
}
func migrateFrom1(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2")
if _, err := db.Exec(migrate1To2AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 2); err != nil {
return err
}
return nil
}
func migrateFrom2(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3")
if _, err := db.Exec(migrate2To3AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 3); err != nil {
return err
}
return nil
}
func migrateFrom3(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4")
if _, err := db.Exec(migrate3To4AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 4); err != nil {
return err
}
return nil
}
func migrateFrom4(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5")
if _, err := db.Exec(migrate4To5AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 5); err != nil {
return err
}
return nil
}
func migrateFrom5(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6")
if _, err := db.Exec(migrate5To6AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 6); err != nil {
return err
}
return nil
}
func migrateFrom6(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7")
if _, err := db.Exec(migrate6To7AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 7); err != nil {
return err
}
return nil
}
func migrateFrom7(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8")
if _, err := db.Exec(migrate7To8AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 8); err != nil {
return err
}
return nil
}
func migrateFrom8(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9")
if _, err := db.Exec(migrate8To9AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 9); err != nil {
return err
}
return nil
}
func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate9To10AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(migrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 10); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom10(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 11); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom11(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 12); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom12(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate12To13AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 13); err != nil {
return err
}
return tx.Commit()
}

View File

@@ -1,193 +0,0 @@
package server
import (
"database/sql"
"strings"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
"heckel.io/ntfy/v2/util"
)
// PostgreSQL schema version (starts at latest, no migrations needed initially)
const pgCurrentSchemaVersion = 1
// Messages cache
const (
pgCreateMessagesTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id SERIAL PRIMARY KEY,
mid TEXT NOT NULL,
time INT NOT NULL,
expires INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_deleted BOOLEAN NOT NULL,
sender TEXT NOT NULL,
"user" TEXT NOT NULL,
content_type TEXT NOT NULL,
encoding TEXT NOT NULL,
published BOOLEAN NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages ("user");
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schemaVersion VALUES (1, 1);
COMMIT;
`
pgSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
)
var (
pgMessageCacheQueries = &messageCacheQueries{
insertMessage: `
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, "user", content_type, encoding, published)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22)
`,
deleteMessage: `DELETE FROM messages WHERE mid = $1`,
updateMessagesForTopicExpiry: `UPDATE messages SET expires = $1 WHERE topic = $2`,
selectRowIDFromMessageID: `SELECT id FROM messages WHERE mid = $1`,
selectMessagesByID: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE mid = $1
`,
selectMessagesSinceTime: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE topic = $1 AND time >= $2 AND published = TRUE
ORDER BY time, id
`,
selectMessagesSinceTimeIncludeScheduled: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE topic = $1 AND time >= $2
ORDER BY time, id
`,
selectMessagesSinceID: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE topic = $1 AND id > $2 AND published = TRUE
ORDER BY time, id
`,
selectMessagesSinceIDIncludeScheduled: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE topic = $1 AND (id > $2 OR published = FALSE)
ORDER BY time, id
`,
selectMessagesLatest: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE topic = $1 AND published = TRUE
ORDER BY time DESC, id DESC
LIMIT 1
`,
selectMessagesDue: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
FROM messages
WHERE time <= $1 AND published = FALSE
ORDER BY time, id
`,
selectMessagesExpired: `SELECT mid FROM messages WHERE expires <= $1 AND published = TRUE`,
updateMessagePublished: `UPDATE messages SET published = TRUE WHERE mid = $1`,
selectMessageCountPerTopic: `SELECT topic, COUNT(*) FROM messages GROUP BY topic`,
selectTopics: `SELECT topic FROM messages GROUP BY topic`,
updateAttachmentDeleted: `UPDATE messages SET attachment_deleted = TRUE WHERE mid = $1`,
selectAttachmentsExpired: `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`,
selectAttachmentsSizeBySender: `SELECT COALESCE(SUM(attachment_size), 0) FROM messages WHERE "user" = '' AND sender = $1 AND attachment_expires >= $2`,
selectAttachmentsSizeByUserID: `SELECT COALESCE(SUM(attachment_size), 0) FROM messages WHERE "user" = $1 AND attachment_expires >= $2`,
selectStats: `SELECT value FROM stats WHERE key = 'messages'`,
updateStats: `UPDATE stats SET value = $1 WHERE key = 'messages'`,
}
)
type pgMessageCache struct {
*commonMessageCache
}
var _ MessageCache = (*pgMessageCache)(nil)
// newPgCache creates a PostgreSQL-backed message cache
func newPgCache(connectionString, startupQueries string, batchSize int, batchTimeout time.Duration) (*pgMessageCache, error) {
db, err := sql.Open("pgx", connectionString)
if err != nil {
return nil, err
}
if err := setupPgMessagesDB(db, startupQueries); err != nil {
return nil, err
}
var queue *util.BatchingQueue[*message]
if batchSize > 0 || batchTimeout > 0 {
queue = util.NewBatchingQueue[*message](batchSize, batchTimeout)
}
cache := &pgMessageCache{
commonMessageCache: &commonMessageCache{
db: db,
queue: queue,
queries: pgMessageCacheQueries,
},
}
go cache.processMessageBatches()
return cache, nil
}
func setupPgMessagesDB(db *sql.DB, startupQueries string) error {
// Run startup queries
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
// If 'messages' table does not exist, this must be a new database
rowsMC, err := db.Query(pgSelectMessagesCountQuery)
if err != nil {
return setupNewPgCacheDB(db)
}
rowsMC.Close()
// Future: Add migrations here when schema changes
return nil
}
func setupNewPgCacheDB(db *sql.DB) error {
if _, err := db.Exec(pgCreateMessagesTableQuery); err != nil {
return err
}
return nil
}
// isPostgres checks if the connection string indicates a PostgreSQL database
func isPostgres(connectionString string) bool {
return strings.HasPrefix(connectionString, "postgres:")
}

View File

@@ -1,571 +0,0 @@
package server
import (
"database/sql"
"errors"
"fmt"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
)
var (
errUnexpectedMessageType = errors.New("unexpected message type")
errMessageNotFound = errors.New("message not found")
errNoRows = errors.New("no rows found")
)
// Messages cache
const (
createMessagesTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
time INT NOT NULL,
expires INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_deleted INT NOT NULL,
sender TEXT NOT NULL,
user TEXT NOT NULL,
content_type TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
COMMIT;
`
)
var (
sqliteMessageCacheQueries = &messageCacheQueries{
insertMessage: `
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
deleteMessage: `DELETE FROM messages WHERE mid = ?`,
updateMessagesForTopicExpiry: `UPDATE messages SET expires = ? WHERE topic = ?`,
selectRowIDFromMessageID: `SELECT id FROM messages WHERE mid = ?`, // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
selectMessagesByID: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE mid = ?
`,
selectMessagesSinceTime: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ? AND published = 1
ORDER BY time, id
`,
selectMessagesSinceTimeIncludeScheduled: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ?
ORDER BY time, id
`,
selectMessagesSinceID: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND id > ? AND published = 1
ORDER BY time, id
`,
selectMessagesSinceIDIncludeScheduled: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND (id > ? OR published = 0)
ORDER BY time, id
`,
selectMessagesLatest: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND published = 1
ORDER BY time DESC, id DESC
LIMIT 1
`,
selectMessagesDue: `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE time <= ? AND published = 0
ORDER BY time, id
`,
selectMessagesExpired: `SELECT mid FROM messages WHERE expires <= ? AND published = 1`,
updateMessagePublished: `UPDATE messages SET published = 1 WHERE mid = ?`,
selectMessageCountPerTopic: `SELECT topic, COUNT(*) FROM messages GROUP BY topic`,
selectTopics: `SELECT topic FROM messages GROUP BY topic`,
updateAttachmentDeleted: `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`,
selectAttachmentsExpired: `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`,
selectAttachmentsSizeBySender: `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`,
selectAttachmentsSizeByUserID: `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`,
selectStats: `SELECT value FROM stats WHERE key = 'messages'`,
updateStats: `UPDATE stats SET value = ? WHERE key = 'messages'`,
}
)
// Schema management queries
const (
currentSchemaVersion = 13
createSchemaVersionTableQuery = `
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
`
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
// 0 -> 1
migrate0To1AlterMessagesTableQuery = `
BEGIN;
ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0);
ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT('');
COMMIT;
`
// 1 -> 2
migrate1To2AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1);
`
// 2 -> 3
migrate2To3AlterMessagesTableQuery = `
BEGIN;
ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT('');
COMMIT;
`
// 3 -> 4
migrate3To4AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT('');
`
// 4 -> 5
migrate4To5AlterMessagesTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages_new (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
time INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_owner TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid);
CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic);
INSERT
INTO messages_new (
mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published)
SELECT
id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published
FROM messages;
DROP TABLE messages;
ALTER TABLE messages_new RENAME TO messages;
COMMIT;
`
// 5 -> 6
migrate5To6AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT('');
`
// 6 -> 7
migrate6To7AlterMessagesTableQuery = `
ALTER TABLE messages RENAME COLUMN attachment_owner TO sender;
`
// 7 -> 8
migrate7To8AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT('');
`
// 8 -> 9
migrate8To9AlterMessagesTableQuery = `
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
`
// 9 -> 10
migrate9To10AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
`
migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
// 10 -> 11
migrate10To11AlterMessagesTableQuery = `
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
`
// 11 -> 12
migrate11To12AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
`
// 12 -> 13
migrate12To13AlterMessagesTableQuery = `
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
`
)
var (
migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
0: migrateFrom0,
1: migrateFrom1,
2: migrateFrom2,
3: migrateFrom3,
4: migrateFrom4,
5: migrateFrom5,
6: migrateFrom6,
7: migrateFrom7,
8: migrateFrom8,
9: migrateFrom9,
10: migrateFrom10,
11: migrateFrom11,
12: migrateFrom12,
}
)
type sqliteMessageCache struct {
*commonMessageCache
}
var _ MessageCache = (*sqliteMessageCache)(nil)
// newSqliteCache creates a SQLite file-backed cache
func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*sqliteMessageCache, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupMessagesDB(db, startupQueries, cacheDuration); err != nil {
return nil, err
}
var queue *util.BatchingQueue[*message]
if batchSize > 0 || batchTimeout > 0 {
queue = util.NewBatchingQueue[*message](batchSize, batchTimeout)
}
cache := &sqliteMessageCache{
commonMessageCache: &commonMessageCache{
db: db,
queue: queue,
queries: sqliteMessageCacheQueries,
nop: nop,
},
}
go cache.processMessageBatches()
return cache, nil
}
// newMemCache creates an in-memory cache
func newMemCache() (*sqliteMessageCache, error) {
return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, false)
}
// newNopCache creates an in-memory cache that discards all messages;
// it is always empty and can be used if caching is entirely disabled
func newNopCache() (*sqliteMessageCache, error) {
return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, true)
}
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory
// sql database, so if the stdlib's sql engine happens to open another connection and
// you've only specified ":memory:", that connection will see a brand new database.
// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared").
// Every connection to this string will point to the same in-memory database."
func createMemoryFilename() string {
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
}
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asyncronously.
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
func (c *sqliteMessageCache) AddMessage(m *message) error {
if c.nop {
return nil
}
return c.commonMessageCache.AddMessage(m)
}
func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
// Run startup queries
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
// If 'messages' table does not exist, this must be a new database
rowsMC, err := db.Query(selectMessagesCountQuery)
if err != nil {
return setupNewCacheDB(db)
}
rowsMC.Close()
// If 'messages' table exists, check 'schemaVersion' table
schemaVersion := 0
rowsSV, err := db.Query(selectSchemaVersionQuery)
if err == nil {
defer rowsSV.Close()
if !rowsSV.Next() {
return errors.New("cannot determine schema version: cache file may be corrupt")
}
if err := rowsSV.Scan(&schemaVersion); err != nil {
return err
}
rowsSV.Close()
}
// Do migrations
if schemaVersion == currentSchemaVersion {
return nil
} else if schemaVersion > currentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
}
for i := schemaVersion; i < currentSchemaVersion; i++ {
fn, ok := migrations[i]
if !ok {
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
} else if err := fn(db, cacheDuration); err != nil {
return err
}
}
return nil
}
func setupNewCacheDB(db *sql.DB) error {
if _, err := db.Exec(createMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(createSchemaVersionTableQuery); err != nil {
return err
}
if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
return err
}
return nil
}
func migrateFrom0(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1")
if _, err := db.Exec(migrate0To1AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(createSchemaVersionTableQuery); err != nil {
return err
}
if _, err := db.Exec(insertSchemaVersion, 1); err != nil {
return err
}
return nil
}
func migrateFrom1(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2")
if _, err := db.Exec(migrate1To2AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 2); err != nil {
return err
}
return nil
}
func migrateFrom2(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3")
if _, err := db.Exec(migrate2To3AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 3); err != nil {
return err
}
return nil
}
func migrateFrom3(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4")
if _, err := db.Exec(migrate3To4AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 4); err != nil {
return err
}
return nil
}
func migrateFrom4(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5")
if _, err := db.Exec(migrate4To5AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 5); err != nil {
return err
}
return nil
}
func migrateFrom5(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6")
if _, err := db.Exec(migrate5To6AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 6); err != nil {
return err
}
return nil
}
func migrateFrom6(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7")
if _, err := db.Exec(migrate6To7AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 7); err != nil {
return err
}
return nil
}
func migrateFrom7(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8")
if _, err := db.Exec(migrate7To8AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 8); err != nil {
return err
}
return nil
}
func migrateFrom8(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9")
if _, err := db.Exec(migrate8To9AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(updateSchemaVersion, 9); err != nil {
return err
}
return nil
}
func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate9To10AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(migrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 10); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom10(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 11); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom11(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 12); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom12(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate12To13AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 13); err != nil {
return err
}
return tx.Commit()
}

View File

@@ -21,7 +21,7 @@ func TestMemCache_Messages(t *testing.T) {
testCacheMessages(t, newMemTestCache(t))
}
func testCacheMessages(t *testing.T, c MessageCache) {
func testCacheMessages(t *testing.T, c *messageCache) {
m1 := newDefaultMessage("mytopic", "my message")
m1.Time = 1
@@ -100,7 +100,7 @@ func TestMemCache_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newMemTestCache(t))
}
func testCacheMessagesLock(t *testing.T, c MessageCache) {
func testCacheMessagesLock(t *testing.T, c *messageCache) {
var wg sync.WaitGroup
for i := 0; i < 5000; i++ {
wg.Add(1)
@@ -120,7 +120,7 @@ func TestMemCache_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newMemTestCache(t))
}
func testCacheMessagesScheduled(t *testing.T, c MessageCache) {
func testCacheMessagesScheduled(t *testing.T, c *messageCache) {
m1 := newDefaultMessage("mytopic", "message 1")
m2 := newDefaultMessage("mytopic", "message 2")
m2.Time = time.Now().Add(time.Hour).Unix()
@@ -154,7 +154,7 @@ func TestMemCache_Topics(t *testing.T) {
testCacheTopics(t, newMemTestCache(t))
}
func testCacheTopics(t *testing.T, c MessageCache) {
func testCacheTopics(t *testing.T, c *messageCache) {
require.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2")))
@@ -177,7 +177,7 @@ func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newMemTestCache(t))
}
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c MessageCache) {
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c *messageCache) {
m := newDefaultMessage("mytopic", "some message")
m.Tags = []string{"tag1", "tag2"}
m.Priority = 5
@@ -198,7 +198,7 @@ func TestMemCache_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newMemTestCache(t))
}
func testCacheMessagesSinceID(t *testing.T, c MessageCache) {
func testCacheMessagesSinceID(t *testing.T, c *messageCache) {
m1 := newDefaultMessage("mytopic", "message 1")
m1.Time = 100
m2 := newDefaultMessage("mytopic", "message 2")
@@ -268,7 +268,7 @@ func TestMemCache_Prune(t *testing.T) {
testCachePrune(t, newMemTestCache(t))
}
func testCachePrune(t *testing.T, c MessageCache) {
func testCachePrune(t *testing.T, c *messageCache) {
now := time.Now().Unix()
m1 := newDefaultMessage("mytopic", "my message")
@@ -315,7 +315,7 @@ func TestMemCache_Attachments(t *testing.T) {
testCacheAttachments(t, newMemTestCache(t))
}
func testCacheAttachments(t *testing.T, c MessageCache) {
func testCacheAttachments(t *testing.T, c *messageCache) {
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
@@ -397,7 +397,7 @@ func TestMemCache_Attachments_Expired(t *testing.T) {
testCacheAttachmentsExpired(t, newMemTestCache(t))
}
func testCacheAttachmentsExpired(t *testing.T, c MessageCache) {
func testCacheAttachmentsExpired(t *testing.T, c *messageCache) {
m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.Expires = time.Now().Add(time.Hour).Unix()
@@ -473,7 +473,7 @@ func TestSqliteCache_Migration_From0(t *testing.T) {
// Create cache to trigger migration
c := newSqliteTestCacheFromFile(t, filename, "")
checkSchemaVersion(t, c.DB())
checkSchemaVersion(t, c.db)
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
@@ -519,7 +519,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
// Create cache to trigger migration
c := newSqliteTestCacheFromFile(t, filename, "")
checkSchemaVersion(t, c.DB())
checkSchemaVersion(t, c.db)
// Add delayed message
delayedMessage := newDefaultMessage("mytopic", "some delayed message")
@@ -537,7 +537,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
require.Equal(t, 11, len(messages))
// Check that index "idx_topic" exists
rows, err := c.DB().Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
rows, err := c.db.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
require.Nil(t, err)
require.True(t, rows.Next())
var indexName string
@@ -623,7 +623,7 @@ func TestSqliteCache_Migration_From9(t *testing.T) {
cacheDuration := 17 * time.Hour
c, err := newSqliteCache(filename, "", cacheDuration, 0, 0, false)
require.Nil(t, err)
checkSchemaVersion(t, c.DB())
checkSchemaVersion(t, c.db)
// Check version
rows, err := db.Query(`SELECT version FROM main.schemaVersion WHERE id = 1`)
@@ -681,7 +681,7 @@ func TestMemCache_Sender(t *testing.T) {
testSender(t, newMemTestCache(t))
}
func testSender(t *testing.T, c MessageCache) {
func testSender(t *testing.T, c *messageCache) {
m1 := newDefaultMessage("mytopic", "mymessage")
m1.Sender = netip.MustParseAddr("1.2.3.4")
require.Nil(t, c.AddMessage(m1))
@@ -720,7 +720,7 @@ func TestMemCache_NopCache(t *testing.T) {
require.Empty(t, topics)
}
func newSqliteTestCache(t *testing.T) MessageCache {
func newSqliteTestCache(t *testing.T) *messageCache {
c, err := newSqliteCache(newSqliteTestCacheFile(t), "", time.Hour, 0, 0, false)
if err != nil {
t.Fatal(err)
@@ -732,13 +732,13 @@ func newSqliteTestCacheFile(t *testing.T) string {
return filepath.Join(t.TempDir(), "cache.db")
}
func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) MessageCache {
func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) *messageCache {
c, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
require.Nil(t, err)
return c
}
func newMemTestCache(t *testing.T) MessageCache {
func newMemTestCache(t *testing.T) *messageCache {
c, err := newMemCache()
require.Nil(t, err)
return c

View File

@@ -56,8 +56,8 @@ type Server struct {
messages int64 // Total number of messages (persisted if messageCache enabled)
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
userManager *user.Manager // Might be nil!
messageCache MessageCache // Database that stores the messages
webPush WebPushStore // Database that stores web push subscriptions
messageCache *messageCache // Database that stores the messages
webPush *webPushStore // Database that stores web push subscriptions
fileCache *fileCache // File system based cache that stores attachments
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
@@ -173,7 +173,7 @@ func New(conf *Config) (*Server, error) {
if err != nil {
return nil, err
}
var webPush WebPushStore
var webPush *webPushStore
if conf.WebPushPublicKey != "" {
webPush, err = newWebPushStore(conf.WebPushFile, conf.WebPushStartupQueries)
if err != nil {
@@ -245,11 +245,9 @@ func New(conf *Config) (*Server, error) {
return s, nil
}
func createMessageCache(conf *Config) (MessageCache, error) {
func createMessageCache(conf *Config) (*messageCache, error) {
if conf.CacheDuration == 0 {
return newNopCache()
} else if isPostgres(conf.CacheFile) {
return newPgCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheBatchSize, conf.CacheBatchTimeout)
} else if conf.CacheFile != "" {
return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
}

View File

@@ -24,15 +24,17 @@ func (s *Server) handleUsersGet(w http.ResponseWriter, r *http.Request, v *visit
userGrants := make([]*apiUserGrantResponse, len(grants[u.ID]))
for i, g := range grants[u.ID] {
userGrants[i] = &apiUserGrantResponse{
Topic: g.TopicPattern,
Permission: g.Permission.String(),
Topic: g.TopicPattern,
Permission: g.Permission.String(),
Provisioned: g.Provisioned,
}
}
usersResponse[i] = &apiUserResponse{
Username: u.Name,
Role: string(u.Role),
Tier: tier,
Grants: userGrants,
Username: u.Name,
Role: string(u.Role),
Tier: tier,
Grants: userGrants,
Provisioned: u.Provisioned,
}
}
return s.writeJSON(w, usersResponse)

View File

@@ -381,7 +381,7 @@ func TestServer_PublishAt(t *testing.T) {
// Update message time to the past
fakeTime := time.Now().Add(-10 * time.Second).Unix()
_, err := s.messageCache.DB().Exec(`UPDATE messages SET time=?`, fakeTime)
_, err := s.messageCache.db.Exec(`UPDATE messages SET time=?`, fakeTime)
require.Nil(t, err)
// Trigger delayed message sending
@@ -417,7 +417,7 @@ func TestServer_PublishAt_FromUser(t *testing.T) {
// Update message time to the past
fakeTime := time.Now().Add(-10 * time.Second).Unix()
_, err := s.messageCache.DB().Exec(`UPDATE messages SET time=?`, fakeTime)
_, err := s.messageCache.db.Exec(`UPDATE messages SET time=?`, fakeTime)
require.Nil(t, err)
// Trigger delayed message sending
@@ -2336,7 +2336,7 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
require.Nil(t, err)
messages = append(messages, newDefaultMessage(topicID, "some message"))
}
require.Nil(t, s.messageCache.AddMessages(messages))
require.Nil(t, s.messageCache.addMessages(messages))
log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
// Update stats

View File

@@ -238,7 +238,7 @@ func TestServer_WebPush_Expiry(t *testing.T) {
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
_, err := s.webPush.DB().Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-55*24*time.Hour).Unix())
_, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-55*24*time.Hour).Unix())
require.Nil(t, err)
s.pruneAndNotifyWebPushSubscriptions()
@@ -248,7 +248,7 @@ func TestServer_WebPush_Expiry(t *testing.T) {
return received.Load()
})
_, err = s.webPush.DB().Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-60*24*time.Hour).Unix())
_, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-60*24*time.Hour).Unix())
require.Nil(t, err)
s.pruneAndNotifyWebPushSubscriptions()

View File

@@ -308,15 +308,17 @@ type apiUserAddOrUpdateRequest struct {
}
type apiUserResponse struct {
Username string `json:"username"`
Role string `json:"role"`
Tier string `json:"tier,omitempty"`
Grants []*apiUserGrantResponse `json:"grants,omitempty"`
Username string `json:"username"`
Role string `json:"role"`
Tier string `json:"tier,omitempty"`
Grants []*apiUserGrantResponse `json:"grants,omitempty"`
Provisioned bool `json:"provisioned,omitempty"`
}
type apiUserGrantResponse struct {
Topic string `json:"topic"` // This may be a pattern
Permission string `json:"permission"`
Topic string `json:"topic"` // This may be a pattern
Permission string `json:"permission"`
Provisioned bool `json:"provisioned,omitempty"`
}
type apiUserDeleteRequest struct {

View File

@@ -53,7 +53,7 @@ const (
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
type visitor struct {
config *Config
messageCache MessageCache
messageCache *messageCache
userManager *user.Manager // May be nil
ip netip.Addr // Visitor IP address
user *user.User // Only set if authenticated user, otherwise nil
@@ -114,7 +114,7 @@ const (
visitorLimitBasisTier = visitorLimitBasis("tier")
)
func newVisitor(conf *Config, messageCache MessageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
var messages, emails, calls int64
if user != nil {
messages = user.Stats.Messages

View File

@@ -3,11 +3,11 @@ package server
import (
"database/sql"
"errors"
"heckel.io/ntfy/v2/util"
"net/netip"
"strings"
"time"
"heckel.io/ntfy/v2/util"
_ "github.com/mattn/go-sqlite3" // SQLite driver
)
const (
@@ -22,49 +22,126 @@ var (
errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
)
// WebPushStore is an interface for storing web push subscriptions
type WebPushStore interface {
UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error
SubscriptionsForTopic(topic string) ([]*webPushSubscription, error)
SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error)
MarkExpiryWarningSent(subscriptions []*webPushSubscription) error
RemoveSubscriptionsByEndpoint(endpoint string) error
RemoveSubscriptionsByUserID(userID string) error
RemoveExpiredSubscriptions(expireAfter time.Duration) error
DB() *sql.DB
Close() error
}
const (
createWebPushSubscriptionsTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at INT NOT NULL,
warned_at INT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic),
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
`
builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
// webPushQueries holds all the SQL queries used by webPushStore
type webPushQueries struct {
selectSubscriptionIDByEndpoint string
selectSubscriptionCountBySubscriberIP string
selectSubscriptionsForTopic string
selectSubscriptionsExpiringSoon string
insertSubscription string
updateSubscriptionWarningSent string
deleteSubscriptionByEndpoint string
deleteSubscriptionByUserID string
deleteSubscriptionByAge string
insertSubscriptionTopic string
deleteSubscriptionTopicAll string
deleteSubscriptionTopicWithoutSub string
}
selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
selectWebPushSubscriptionsForTopicQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id
WHERE st.topic = ?
ORDER BY endpoint
`
selectWebPushSubscriptionsExpiringSoonQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription
WHERE warned_at = 0 AND updated_at <= ?
`
insertWebPushSubscriptionQuery = `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
`
updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
deleteWebPushSubscriptionTopicWithoutSubscription = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`
)
// Schema management queries
const (
currentWebPushSchemaVersion = 1
insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
type webPushStore struct {
db *sql.DB
queries *webPushQueries
db *sql.DB
}
// newWebPushStore creates a new webPushStore based on the connection string
func newWebPushStore(filename, startupQueries string) (WebPushStore, error) {
if strings.HasPrefix(filename, "postgres:") {
return newPgWebPushStore(strings.TrimPrefix(filename, "postgres:"), startupQueries)
func newWebPushStore(filename, startupQueries string) (*webPushStore, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
return newSqliteWebPushStore(filename, startupQueries)
if err := setupWebPushDB(db); err != nil {
return nil, err
}
if err := runWebPushStartupQueries(db, startupQueries); err != nil {
return nil, err
}
return &webPushStore{
db: db,
}, nil
}
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID
func setupWebPushDB(db *sql.DB) error {
// If 'schemaVersion' table does not exist, this must be a new database
rows, err := db.Query(selectWebPushSchemaVersionQuery)
if err != nil {
return setupNewWebPushDB(db)
}
return rows.Close()
}
func setupNewWebPushDB(db *sql.DB) error {
if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
return err
}
if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
return err
}
return nil
}
func runWebPushStartupQueries(db *sql.DB, startupQueries string) error {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
if _, err := db.Exec(builtinStartupQueries); err != nil {
return err
}
return nil
}
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
// existing entries for a given endpoint.
func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
tx, err := c.db.Begin()
if err != nil {
@@ -72,7 +149,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
}
defer tx.Rollback()
// Read number of subscriptions for subscriber IP address
rowsCount, err := tx.Query(c.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String())
rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
if err != nil {
return err
}
@@ -88,7 +165,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
return err
}
// Read existing subscription ID for endpoint (or create new ID)
rows, err := tx.Query(c.queries.selectSubscriptionIDByEndpoint, endpoint)
rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
if err != nil {
return err
}
@@ -109,15 +186,15 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
}
// Insert or update subscription
updatedAt, warnedAt := time.Now().Unix(), 0
if _, err = tx.Exec(c.queries.insertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
return err
}
// Replace all subscription topics
if _, err := tx.Exec(c.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
return err
}
for _, topic := range topics {
if _, err = tx.Exec(c.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
return err
}
}
@@ -126,7 +203,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
// SubscriptionsForTopic returns all subscriptions for the given topic
func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
rows, err := c.db.Query(c.queries.selectSubscriptionsForTopic, topic)
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
if err != nil {
return nil, err
}
@@ -136,7 +213,7 @@ func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscripti
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
rows, err := c.db.Query(c.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
if err != nil {
return nil, err
}
@@ -152,7 +229,7 @@ func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscriptio
}
defer tx.Rollback()
for _, subscription := range subscriptions {
if _, err := tx.Exec(c.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
return err
}
}
@@ -179,7 +256,7 @@ func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscrip
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
_, err := c.db.Exec(c.queries.deleteSubscriptionByEndpoint, endpoint)
_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
return err
}
@@ -188,25 +265,20 @@ func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
if userID == "" {
return errWebPushUserIDCannotBeEmpty
}
_, err := c.db.Exec(c.queries.deleteSubscriptionByUserID, userID)
_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
return err
}
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
_, err := c.db.Exec(c.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix())
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
if err != nil {
return err
}
_, err = c.db.Exec(c.queries.deleteSubscriptionTopicWithoutSub)
_, err = c.db.Exec(deleteWebPushSubscriptionTopicWithoutSubscription)
return err
}
// DB returns the underlying database connection (for testing)
func (c *webPushStore) DB() *sql.DB {
return c.db
}
// Close closes the underlying database connection
func (c *webPushStore) Close() error {
return c.db.Close()

View File

@@ -1,130 +0,0 @@
package server
import (
"database/sql"
"fmt"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
)
// PostgreSQL-specific queries
const (
pgCreateWebPushSubscriptionsTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at INT NOT NULL,
warned_at INT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic),
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schema_version (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
`
// Schema management queries
pgCurrentWebPushSchemaVersion = 1
pgInsertWebPushSchemaVersion = `INSERT INTO schema_version VALUES (1, $1)`
pgSelectWebPushSchemaVersionQuery = `SELECT version FROM schema_version WHERE id = 1`
)
// PostgreSQL-specific webpush queries
var pgWebPushQueries = &webPushQueries{
selectSubscriptionIDByEndpoint: `SELECT id FROM subscription WHERE endpoint = $1`,
selectSubscriptionCountBySubscriberIP: `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = $1`,
selectSubscriptionsForTopic: `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id
WHERE st.topic = $1
ORDER BY endpoint
`,
selectSubscriptionsExpiringSoon: `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription
WHERE warned_at = 0 AND updated_at <= $1
`,
insertSubscription: `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = EXCLUDED.key_auth, key_p256dh = EXCLUDED.key_p256dh, user_id = EXCLUDED.user_id, subscriber_ip = EXCLUDED.subscriber_ip, updated_at = EXCLUDED.updated_at, warned_at = EXCLUDED.warned_at
`,
updateSubscriptionWarningSent: `UPDATE subscription SET warned_at = $1 WHERE id = $2`,
deleteSubscriptionByEndpoint: `DELETE FROM subscription WHERE endpoint = $1`,
deleteSubscriptionByUserID: `DELETE FROM subscription WHERE user_id = $1`,
deleteSubscriptionByAge: `DELETE FROM subscription WHERE updated_at <= $1`,
insertSubscriptionTopic: `INSERT INTO subscription_topic (subscription_id, topic) VALUES ($1, $2)`,
deleteSubscriptionTopicAll: `DELETE FROM subscription_topic WHERE subscription_id = $1`,
deleteSubscriptionTopicWithoutSub: `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`,
}
func newPgWebPushStore(connStr, startupQueries string) (*webPushStore, error) {
db, err := sql.Open("pgx", connStr)
if err != nil {
return nil, err
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err)
}
if err := setupPgWebPushDB(db); err != nil {
return nil, err
}
if err := runPgWebPushStartupQueries(db, startupQueries); err != nil {
return nil, err
}
return &webPushStore{
db: db,
queries: pgWebPushQueries,
}, nil
}
func setupPgWebPushDB(db *sql.DB) error {
// If 'schema_version' table does not exist, this must be a new database
rows, err := db.Query(pgSelectWebPushSchemaVersionQuery)
if err != nil {
return setupNewPgWebPushDB(db)
}
defer rows.Close()
// If table exists but no rows, also create new
if !rows.Next() {
return setupNewPgWebPushDB(db)
}
return nil
}
func setupNewPgWebPushDB(db *sql.DB) error {
if _, err := db.Exec(pgCreateWebPushSubscriptionsTableQuery); err != nil {
return err
}
if _, err := db.Exec(pgInsertWebPushSchemaVersion, pgCurrentWebPushSchemaVersion); err != nil {
return err
}
return nil
}
func runPgWebPushStartupQueries(db *sql.DB, startupQueries string) error {
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
return nil
}

View File

@@ -1,126 +0,0 @@
package server
import (
"database/sql"
_ "github.com/mattn/go-sqlite3" // SQLite driver
)
// SQLite-specific queries
const (
sqliteCreateWebPushSubscriptionsTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at INT NOT NULL,
warned_at INT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic),
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
`
sqliteWebPushBuiltinStartupQueries = `
PRAGMA foreign_keys = ON;
`
// Schema management queries
sqliteCurrentWebPushSchemaVersion = 1
sqliteInsertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
sqliteSelectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
// SQLite-specific webpush queries
var sqliteWebPushQueries = &webPushQueries{
selectSubscriptionIDByEndpoint: `SELECT id FROM subscription WHERE endpoint = ?`,
selectSubscriptionCountBySubscriberIP: `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`,
selectSubscriptionsForTopic: `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id
WHERE st.topic = ?
ORDER BY endpoint
`,
selectSubscriptionsExpiringSoon: `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription
WHERE warned_at = 0 AND updated_at <= ?
`,
insertSubscription: `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
`,
updateSubscriptionWarningSent: `UPDATE subscription SET warned_at = ? WHERE id = ?`,
deleteSubscriptionByEndpoint: `DELETE FROM subscription WHERE endpoint = ?`,
deleteSubscriptionByUserID: `DELETE FROM subscription WHERE user_id = ?`,
deleteSubscriptionByAge: `DELETE FROM subscription WHERE updated_at <= ?`,
insertSubscriptionTopic: `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`,
deleteSubscriptionTopicAll: `DELETE FROM subscription_topic WHERE subscription_id = ?`,
deleteSubscriptionTopicWithoutSub: `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`,
}
func newSqliteWebPushStore(filename, startupQueries string) (*webPushStore, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupSqliteWebPushDB(db); err != nil {
return nil, err
}
if err := runSqliteWebPushStartupQueries(db, startupQueries); err != nil {
return nil, err
}
return &webPushStore{
db: db,
queries: sqliteWebPushQueries,
}, nil
}
func setupSqliteWebPushDB(db *sql.DB) error {
// If 'schemaVersion' table does not exist, this must be a new database
rows, err := db.Query(sqliteSelectWebPushSchemaVersionQuery)
if err != nil {
return setupNewSqliteWebPushDB(db)
}
return rows.Close()
}
func setupNewSqliteWebPushDB(db *sql.DB) error {
if _, err := db.Exec(sqliteCreateWebPushSubscriptionsTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteInsertWebPushSchemaVersion, sqliteCurrentWebPushSchemaVersion); err != nil {
return err
}
return nil
}
func runSqliteWebPushStartupQueries(db *sql.DB, startupQueries string) error {
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
if _, err := db.Exec(sqliteWebPushBuiltinStartupQueries); err != nil {
return err
}
return nil
}

View File

@@ -134,7 +134,7 @@ func TestWebPushStore_MarkExpiryWarningSent(t *testing.T) {
// Mark them as warning sent
require.Nil(t, webPush.MarkExpiryWarningSent(subs))
rows, err := webPush.DB().Query("SELECT endpoint FROM subscription WHERE warned_at > 0")
rows, err := webPush.db.Query("SELECT endpoint FROM subscription WHERE warned_at > 0")
require.Nil(t, err)
defer rows.Close()
var endpoint string
@@ -156,7 +156,7 @@ func TestWebPushStore_SubscriptionsExpiring(t *testing.T) {
require.Len(t, subs, 1)
// Fake-mark them as soon-to-expire
_, err = webPush.DB().Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint)
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint)
require.Nil(t, err)
// Should not be cleaned up yet
@@ -180,7 +180,7 @@ func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) {
require.Len(t, subs, 1)
// Fake-mark them as expired
_, err = webPush.DB().Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint)
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint)
require.Nil(t, err)
// Run expiration
@@ -192,7 +192,7 @@ func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) {
require.Len(t, subs, 0)
}
func newTestWebPushStore(t *testing.T) WebPushStore {
func newTestWebPushStore(t *testing.T) *webPushStore {
webPush, err := newWebPushStore(filepath.Join(t.TempDir(), "webpush.db"), "")
require.Nil(t, err)
return webPush

View File

@@ -6,17 +6,17 @@ import (
"encoding/json"
"errors"
"fmt"
"net/netip"
"slices"
"strings"
"sync"
"time"
"github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/payments"
"heckel.io/ntfy/v2/util"
"net/netip"
"path/filepath"
"slices"
"strings"
"sync"
"time"
)
const (
@@ -43,10 +43,100 @@ const (
var (
errNoTokenProvided = errors.New("no token provided")
errTopicOwnedByOthers = errors.New("topic owned by others")
errNoRows = errors.New("no rows found")
)
// Manager-related queries
const (
createTablesQueries = `
BEGIN;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
calls_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_monthly_price_id TEXT,
stripe_yearly_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
COMMIT;
`
builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
@@ -236,6 +326,229 @@ const (
`
)
// Schema management queries
const (
currentSchemaVersion = 6
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
// 1 -> 2 (complex migration!)
migrate1To2CreateTablesQueries = `
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
migrate1To2InsertUserNoTx = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
`
migrate1To2InsertFromOldTablesAndDropNoTx = `
INSERT INTO user_access (user_id, topic, read, write)
SELECT u.id, a.topic, a.read, a.write
FROM user u
JOIN access a ON u.user = a.user;
DROP TABLE access;
DROP TABLE user_old;
`
// 2 -> 3
migrate2To3UpdateQueries = `
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
DROP INDEX IF EXISTS idx_tier_price_id;
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
`
// 3 -> 4
migrate3To4UpdateQueries = `
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
`
// 4 -> 5
migrate4To5UpdateQueries = `
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
`
// 5 -> 6
migrate5To6UpdateQueries = `
PRAGMA foreign_keys=off;
-- Alter user table: Add provisioned column
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
INSERT INTO user
SELECT
id,
tier_id,
user,
pass,
role,
prefs,
sync_topic,
0, -- provisioned
stats_messages,
stats_emails,
stats_calls,
stripe_customer_id,
stripe_subscription_id,
stripe_subscription_status,
stripe_subscription_interval,
stripe_subscription_paid_until,
stripe_subscription_cancel_at,
created,
deleted
FROM user_old;
DROP TABLE user_old;
-- Alter user_access table: Add provisioned column
ALTER TABLE user_access RENAME TO user_access_old;
CREATE TABLE user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INTEGER NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
DROP TABLE user_access_old;
-- Alter user_token table: Add provisioned column
ALTER TABLE user_token RENAME TO user_token_old;
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
INSERT INTO user_token SELECT *, 0 FROM user_token_old;
DROP TABLE user_token_old;
-- Recreate indices
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
-- Re-enable foreign keys
PRAGMA foreign_keys=on;
`
)
var (
migrations = map[int]func(db *sql.DB) error{
1: migrateFrom1,
2: migrateFrom2,
3: migrateFrom3,
4: migrateFrom4,
5: migrateFrom5,
}
)
// Manager is an implementation of Manager. It stores users and access control list
// in a SQLite database.
type Manager struct {
@@ -270,20 +583,28 @@ func NewManager(config *Config) (*Manager, error) {
if config.QueueWriterInterval.Seconds() <= 0 {
config.QueueWriterInterval = DefaultUserStatsQueueWriterInterval
}
var manager *Manager
var err error
// Select database backend based on connection string
if strings.HasPrefix(config.Filename, "postgres:") {
manager, err = newPgManager(config)
} else {
manager, err = newSqliteManager(config)
// Check the parent directory of the database file (makes for friendly error messages)
parentDir := filepath.Dir(config.Filename)
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
}
// Open DB and run setup queries
db, err := sql.Open("sqlite3", config.Filename)
if err != nil {
return nil, err
}
if err := setupDB(db); err != nil {
return nil, err
}
if err := runStartupQueries(db, config.StartupQueries); err != nil {
return nil, err
}
manager := &Manager{
db: db,
config: config,
statsQueue: make(map[string]*Stats),
tokenQueue: make(map[string]*TokenUpdate),
}
if err := manager.maybeProvisionUsersAccessAndTokens(); err != nil {
return nil, err
}
@@ -1608,6 +1929,173 @@ func unescapeUnderscore(s string) string {
return strings.ReplaceAll(s, "\\_", "_")
}
func runStartupQueries(db *sql.DB, startupQueries string) error {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
if _, err := db.Exec(builtinStartupQueries); err != nil {
return err
}
return nil
}
func setupDB(db *sql.DB) error {
// If 'schemaVersion' table does not exist, this must be a new database
rowsSV, err := db.Query(selectSchemaVersionQuery)
if err != nil {
return setupNewDB(db)
}
defer rowsSV.Close()
// If 'schemaVersion' table exists, read version and potentially upgrade
schemaVersion := 0
if !rowsSV.Next() {
return errors.New("cannot determine schema version: database file may be corrupt")
}
if err := rowsSV.Scan(&schemaVersion); err != nil {
return err
}
rowsSV.Close()
// Do migrations
if schemaVersion == currentSchemaVersion {
return nil
} else if schemaVersion > currentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
}
for i := schemaVersion; i < currentSchemaVersion; i++ {
fn, ok := migrations[i]
if !ok {
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
} else if err := fn(db); err != nil {
return err
}
}
return nil
}
func setupNewDB(db *sql.DB) error {
if _, err := db.Exec(createTablesQueries); err != nil {
return err
}
if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
return err
}
return nil
}
func migrateFrom1(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Rename user -> user_old, and create new tables
if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
return err
}
// Insert users from user_old into new user table, with ID and sync_topic
rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
if err != nil {
return err
}
defer rows.Close()
usernames := make([]string, 0)
for rows.Next() {
var username string
if err := rows.Scan(&username); err != nil {
return err
}
usernames = append(usernames, username)
}
if err := rows.Close(); err != nil {
return err
}
for _, username := range usernames {
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
return err
}
}
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func migrateFrom2(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 3); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom3(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 4); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom4(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 5); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom5(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 6); err != nil {
return err
}
return tx.Commit()
}
func nullString(s string) sql.NullString {
if s == "" {
return sql.NullString{}

View File

@@ -1,175 +0,0 @@
package user
import (
"database/sql"
"fmt"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
"heckel.io/ntfy/v2/log"
)
// PostgreSQL-specific queries
const (
pgCreateTablesQueries = `
BEGIN;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
calls_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_monthly_price_id TEXT,
stripe_yearly_price_id TEXT
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_tier_code ON tier (code);
CREATE UNIQUE INDEX IF NOT EXISTS idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX IF NOT EXISTS idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
CREATE TABLE IF NOT EXISTS "user" (
id TEXT PRIMARY KEY,
tier_id TEXT,
"user" TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT 0,
stats_emails INT NOT NULL DEFAULT 0,
stats_calls INT NOT NULL DEFAULT 0,
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_user ON "user" ("user");
CREATE UNIQUE INDEX IF NOT EXISTS idx_user_stripe_customer_id ON "user" (stripe_customer_id);
CREATE UNIQUE INDEX IF NOT EXISTS idx_user_stripe_subscription_id ON "user" (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id TEXT,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES "user" (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES "user" (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES "user" (id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_user_token ON user_token (token);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES "user" (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schema_version (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO "user" (id, "user", pass, role, sync_topic, provisioned, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', 0, EXTRACT(EPOCH FROM NOW())::INT)
ON CONFLICT (id) DO NOTHING;
COMMIT;
`
pgCurrentSchemaVersion = 1
pgInsertSchemaVersion = `INSERT INTO schema_version VALUES (1, $1)`
pgUpdateSchemaVersion = `UPDATE schema_version SET version = $1 WHERE id = 1`
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE id = 1`
)
// newPgManager creates a new PostgreSQL-backed user manager
func newPgManager(config *Config) (*Manager, error) {
db, err := sql.Open("pgx", config.Filename)
if err != nil {
return nil, err
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err)
}
if err := setupPgDB(db); err != nil {
return nil, err
}
if err := runPgStartupQueries(db, config.StartupQueries); err != nil {
return nil, err
}
return &Manager{
config: config,
db: db,
statsQueue: make(map[string]*Stats),
tokenQueue: make(map[string]*TokenUpdate),
}, nil
}
func runPgStartupQueries(db *sql.DB, startupQueries string) error {
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
return nil
}
func setupPgDB(db *sql.DB) error {
// If 'schema_version' table does not exist, this must be a new database
rowsSV, err := db.Query(pgSelectSchemaVersionQuery)
if err != nil {
return setupNewPgDB(db)
}
defer rowsSV.Close()
// If 'schema_version' table exists, read version and potentially upgrade
schemaVersion := 0
if !rowsSV.Next() {
// Table exists but no rows, insert version
return setupNewPgDB(db)
}
if err := rowsSV.Scan(&schemaVersion); err != nil {
return err
}
rowsSV.Close()
// Do migrations
if schemaVersion == pgCurrentSchemaVersion {
return nil
} else if schemaVersion > pgCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
}
// No migrations needed yet for PG (starting at version 1)
log.Tag(tag).Info("PostgreSQL user database schema is up to date (version %d)", schemaVersion)
return nil
}
func setupNewPgDB(db *sql.DB) error {
if _, err := db.Exec(pgCreateTablesQueries); err != nil {
return err
}
if _, err := db.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil {
return err
}
return nil
}

View File

@@ -1,532 +0,0 @@
package user
import (
"database/sql"
"errors"
"fmt"
"path/filepath"
"github.com/mattn/go-sqlite3"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
)
var (
errNoRows = errors.New("no rows found")
)
// SQLite-specific queries
const (
sqliteCreateTablesQueries = `
BEGIN;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
calls_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_monthly_price_id TEXT,
stripe_yearly_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
COMMIT;
`
sqliteBuiltinStartupQueries = `
PRAGMA foreign_keys = ON;
`
)
// Schema management queries
const (
sqliteCurrentSchemaVersion = 6
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
// Migration queries (1->2 through 5->6)
// These are SQLite-specific due to PRAGMA, ALTER TABLE syntax, etc.
migrate1To2CreateTablesQueries = `
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
migrate1To2InsertUserNoTx = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
`
migrate1To2InsertFromOldTablesAndDropNoTx = `
INSERT INTO user_access (user_id, topic, read, write)
SELECT u.id, a.topic, a.read, a.write
FROM user u
JOIN access a ON u.user = a.user;
DROP TABLE access;
DROP TABLE user_old;
`
migrate2To3UpdateQueries = `
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
DROP INDEX IF EXISTS idx_tier_price_id;
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
`
migrate3To4UpdateQueries = `
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
`
migrate4To5UpdateQueries = `
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
`
migrate5To6UpdateQueries = `
PRAGMA foreign_keys=off;
-- Alter user table: Add provisioned column
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
INSERT INTO user
SELECT
id,
tier_id,
user,
pass,
role,
prefs,
sync_topic,
0, -- provisioned
stats_messages,
stats_emails,
stats_calls,
stripe_customer_id,
stripe_subscription_id,
stripe_subscription_status,
stripe_subscription_interval,
stripe_subscription_paid_until,
stripe_subscription_cancel_at,
created,
deleted
FROM user_old;
DROP TABLE user_old;
-- Alter user_access table: Add provisioned column
ALTER TABLE user_access RENAME TO user_access_old;
CREATE TABLE user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INTEGER NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
DROP TABLE user_access_old;
-- Alter user_token table: Add provisioned column
ALTER TABLE user_token RENAME TO user_token_old;
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
INSERT INTO user_token SELECT *, 0 FROM user_token_old;
DROP TABLE user_token_old;
-- Recreate indices
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
-- Re-enable foreign keys
PRAGMA foreign_keys=on;
`
)
var (
sqliteMigrations = map[int]func(db *sql.DB) error{
1: migrateFrom1,
2: migrateFrom2,
3: migrateFrom3,
4: migrateFrom4,
5: migrateFrom5,
}
)
// newSqliteManager creates a new SQLite-backed Manager
func newSqliteManager(config *Config) (*Manager, error) {
// Check the parent directory of the database file (makes for friendly error messages)
parentDir := filepath.Dir(config.Filename)
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
}
// Open DB and run setup queries
db, err := sql.Open("sqlite3", config.Filename)
if err != nil {
return nil, err
}
if err := setupSqliteDB(db); err != nil {
return nil, err
}
if err := runSqliteStartupQueries(db, config.StartupQueries); err != nil {
return nil, err
}
return &Manager{
db: db,
config: config,
statsQueue: make(map[string]*Stats),
tokenQueue: make(map[string]*TokenUpdate),
}, nil
}
func runSqliteStartupQueries(db *sql.DB, startupQueries string) error {
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil {
return err
}
return nil
}
func setupSqliteDB(db *sql.DB) error {
// If 'schemaVersion' table does not exist, this must be a new database
rowsSV, err := db.Query(sqliteSelectSchemaVersionQuery)
if err != nil {
return setupNewSqliteDB(db)
}
defer rowsSV.Close()
// If 'schemaVersion' table exists, read version and potentially upgrade
schemaVersion := 0
if !rowsSV.Next() {
return errors.New("cannot determine schema version: database file may be corrupt")
}
if err := rowsSV.Scan(&schemaVersion); err != nil {
return err
}
rowsSV.Close()
// Do migrations
if schemaVersion == sqliteCurrentSchemaVersion {
return nil
} else if schemaVersion > sqliteCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
}
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
fn, ok := sqliteMigrations[i]
if !ok {
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
} else if err := fn(db); err != nil {
return err
}
}
return nil
}
func setupNewSqliteDB(db *sql.DB) error {
if _, err := db.Exec(sqliteCreateTablesQueries); err != nil {
return err
}
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
return err
}
return nil
}
func migrateFrom1(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Rename user -> user_old, and create new tables
if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
return err
}
// Insert users from user_old into new user table, with ID and sync_topic
rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
if err != nil {
return err
}
defer rows.Close()
usernames := make([]string, 0)
for rows.Next() {
var username string
if err := rows.Scan(&username); err != nil {
return err
}
usernames = append(usernames, username)
}
if err := rows.Close(); err != nil {
return err
}
for _, username := range usernames {
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
return err
}
}
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func migrateFrom2(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom3(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom4(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom5(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
return err
}
return tx.Commit()
}
// isSqliteConstraintUniqueError checks if the error is a SQLite unique constraint error
func isSqliteConstraintUniqueError(err error) bool {
if sqliteErr, ok := err.(sqlite3.Error); ok && sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique {
return true
}
return errors.Is(err, sqlite3.ErrConstraintUnique)
}

View File

@@ -1578,7 +1578,7 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) {
var schemaVersion int
require.Nil(t, rows.Scan(&schemaVersion))
require.Equal(t, sqliteCurrentSchemaVersion, schemaVersion)
require.Equal(t, currentSchemaVersion, schemaVersion)
require.Nil(t, rows.Close())
}

View File

@@ -405,5 +405,48 @@
"web_push_subscription_expiring_title": "Notifications will be paused",
"web_push_subscription_expiring_body": "Open ntfy to continue receiving notifications",
"web_push_unknown_notification_title": "Unknown notification received from server",
"web_push_unknown_notification_body": "You may need to update ntfy by opening the web app"
"web_push_unknown_notification_body": "You may need to update ntfy by opening the web app",
"nav_button_admin": "Admin",
"admin_users_title": "Users",
"admin_users_description": "Manage users and their access permissions. Admin users cannot be modified via the web interface.",
"admin_users_table_username_header": "Username",
"admin_users_table_role_header": "Role",
"admin_users_table_tier_header": "Tier",
"admin_users_table_grants_header": "Access grants",
"admin_users_table_actions_header": "Actions",
"admin_users_table_grant_tooltip": "Permission: {{permission}}",
"admin_users_table_grant_provisioned_tooltip": "Permission: {{permission}} (provisioned, cannot be changed)",
"admin_users_table_add_access_tooltip": "Add access grant",
"admin_users_table_edit_tooltip": "Edit user",
"admin_users_table_delete_tooltip": "Delete user",
"admin_users_table_admin_no_actions": "Cannot modify admin users",
"admin_users_provisioned_tooltip": "Provisioned user (defined in server config)",
"admin_users_provisioned_cannot_edit": "Provisioned users cannot be edited or deleted",
"admin_users_role_admin": "Admin",
"admin_users_role_user": "User",
"admin_users_add_button": "Add user",
"admin_users_add_dialog_title": "Add user",
"admin_users_add_dialog_username_label": "Username",
"admin_users_add_dialog_password_label": "Password",
"admin_users_add_dialog_tier_label": "Tier",
"admin_users_add_dialog_tier_helper": "Optional. Leave empty for no tier.",
"admin_users_edit_dialog_title": "Edit user {{username}}",
"admin_users_edit_dialog_password_label": "New password",
"admin_users_edit_dialog_password_helper": "Leave empty to keep current password",
"admin_users_edit_dialog_tier_label": "Tier",
"admin_users_edit_dialog_tier_helper": "Leave empty to keep current tier",
"admin_users_delete_dialog_title": "Delete user",
"admin_users_delete_dialog_description": "Are you sure you want to delete user {{username}}? This action cannot be undone.",
"admin_users_delete_dialog_button": "Delete user",
"admin_access_add_dialog_title": "Add access for {{username}}",
"admin_access_add_dialog_topic_label": "Topic",
"admin_access_add_dialog_topic_helper": "Topic name or pattern (e.g. mytopic or alerts-*)",
"admin_access_add_dialog_permission_label": "Permission",
"admin_access_permission_read_write": "Read & Write",
"admin_access_permission_read_only": "Read only",
"admin_access_permission_write_only": "Write only",
"admin_access_permission_deny_all": "Deny all",
"admin_access_delete_dialog_title": "Remove access",
"admin_access_delete_dialog_description": "Are you sure you want to remove access to topic {{topic}} for user {{username}}?",
"admin_access_delete_dialog_button": "Remove access"
}

82
web/src/app/AdminApi.js Normal file
View File

@@ -0,0 +1,82 @@
import { fetchOrThrow } from "./errors";
import { withBearerAuth } from "./utils";
import session from "./Session";
const usersUrl = (baseUrl) => `${baseUrl}/v1/users`;
const usersAccessUrl = (baseUrl) => `${baseUrl}/v1/users/access`;
class AdminApi {
async getUsers() {
const url = usersUrl(config.base_url);
console.log(`[AdminApi] Fetching users ${url}`);
const response = await fetchOrThrow(url, {
headers: withBearerAuth({}, session.token()),
});
return response.json();
}
async addUser(username, password, tier) {
const url = usersUrl(config.base_url);
const body = { username, password };
if (tier) {
body.tier = tier;
}
console.log(`[AdminApi] Adding user ${url}`);
await fetchOrThrow(url, {
method: "POST",
headers: withBearerAuth({}, session.token()),
body: JSON.stringify(body),
});
}
async updateUser(username, password, tier) {
const url = usersUrl(config.base_url);
const body = { username };
if (password) {
body.password = password;
}
if (tier) {
body.tier = tier;
}
console.log(`[AdminApi] Updating user ${url}`);
await fetchOrThrow(url, {
method: "PUT",
headers: withBearerAuth({}, session.token()),
body: JSON.stringify(body),
});
}
async deleteUser(username) {
const url = usersUrl(config.base_url);
console.log(`[AdminApi] Deleting user ${url}`);
await fetchOrThrow(url, {
method: "DELETE",
headers: withBearerAuth({}, session.token()),
body: JSON.stringify({ username }),
});
}
async allowAccess(username, topic, permission) {
const url = usersAccessUrl(config.base_url);
console.log(`[AdminApi] Allowing access ${url}`);
await fetchOrThrow(url, {
method: "PUT",
headers: withBearerAuth({}, session.token()),
body: JSON.stringify({ username, topic, permission }),
});
}
async resetAccess(username, topic) {
const url = usersAccessUrl(config.base_url);
console.log(`[AdminApi] Resetting access ${url}`);
await fetchOrThrow(url, {
method: "DELETE",
headers: withBearerAuth({}, session.token()),
body: JSON.stringify({ username, topic }),
});
}
}
const adminApi = new AdminApi();
export default adminApi;

View File

@@ -0,0 +1,593 @@
import * as React from "react";
import { useContext, useEffect, useState } from "react";
import {
Alert,
CardActions,
CardContent,
Chip,
FormControl,
Select,
Table,
TableBody,
TableCell,
TableHead,
TableRow,
Tooltip,
Typography,
Container,
Card,
Button,
Dialog,
DialogTitle,
DialogContent,
TextField,
IconButton,
MenuItem,
DialogContentText,
useMediaQuery,
useTheme,
Stack,
CircularProgress,
Box,
} from "@mui/material";
import EditIcon from "@mui/icons-material/Edit";
import { useTranslation } from "react-i18next";
import DeleteOutlineIcon from "@mui/icons-material/DeleteOutline";
import AddIcon from "@mui/icons-material/Add";
import CloseIcon from "@mui/icons-material/Close";
import LockIcon from "@mui/icons-material/Lock";
import routes from "./routes";
import { AccountContext } from "./App";
import DialogFooter from "./DialogFooter";
import { Paragraph } from "./styles";
import { UnauthorizedError } from "../app/errors";
import session from "../app/Session";
import adminApi from "../app/AdminApi";
import { Role } from "../app/AccountApi";
const Admin = () => {
const { account } = useContext(AccountContext);
// Redirect non-admins away
if (!session.exists() || (account && account.role !== Role.ADMIN)) {
window.location.href = routes.app;
return null;
}
// Wait for account to load
if (!account) {
return (
<Box sx={{ display: "flex", justifyContent: "center", alignItems: "center", height: "100vh" }}>
<CircularProgress />
</Box>
);
}
return (
<Container maxWidth="lg" sx={{ marginTop: 3, marginBottom: 3 }}>
<Stack spacing={3}>
<Users />
</Stack>
</Container>
);
};
const Users = () => {
const { t } = useTranslation();
const [users, setUsers] = useState(null);
const [loading, setLoading] = useState(true);
const [error, setError] = useState("");
const [addDialogKey, setAddDialogKey] = useState(0);
const [addDialogOpen, setAddDialogOpen] = useState(false);
const loadUsers = async () => {
try {
setLoading(true);
const data = await adminApi.getUsers();
setUsers(data);
setError("");
} catch (e) {
console.log(`[Admin] Error loading users`, e);
if (e instanceof UnauthorizedError) {
await session.resetAndRedirect(routes.login);
} else {
setError(e.message);
}
} finally {
setLoading(false);
}
};
useEffect(() => {
loadUsers();
}, []);
const handleAddClick = () => {
setAddDialogKey((prev) => prev + 1);
setAddDialogOpen(true);
};
const handleDialogClose = () => {
setAddDialogOpen(false);
loadUsers();
};
return (
<Card sx={{ padding: 1 }} aria-label={t("admin_users_title")}>
<CardContent sx={{ paddingBottom: 1 }}>
<Typography variant="h5" sx={{ marginBottom: 2 }}>
{t("admin_users_title")}
</Typography>
<Paragraph>{t("admin_users_description")}</Paragraph>
{error && (
<Alert severity="error" sx={{ mb: 2 }}>
{error}
</Alert>
)}
{loading && (
<Box sx={{ display: "flex", justifyContent: "center", p: 3 }}>
<CircularProgress />
</Box>
)}
{!loading && users && (
<div style={{ width: "100%", overflowX: "auto" }}>
<UsersTable users={users} onUserChanged={loadUsers} />
</div>
)}
</CardContent>
<CardActions>
<Button onClick={handleAddClick} startIcon={<AddIcon />}>
{t("admin_users_add_button")}
</Button>
</CardActions>
<AddUserDialog key={`addUserDialog${addDialogKey}`} open={addDialogOpen} onClose={handleDialogClose} />
</Card>
);
};
const UsersTable = (props) => {
const { t } = useTranslation();
const [editDialogKey, setEditDialogKey] = useState(0);
const [editDialogOpen, setEditDialogOpen] = useState(false);
const [deleteDialogOpen, setDeleteDialogOpen] = useState(false);
const [accessDialogKey, setAccessDialogKey] = useState(0);
const [accessDialogOpen, setAccessDialogOpen] = useState(false);
const [deleteAccessDialogOpen, setDeleteAccessDialogOpen] = useState(false);
const [selectedUser, setSelectedUser] = useState(null);
const [selectedGrant, setSelectedGrant] = useState(null);
const { users } = props;
const handleEditClick = (user) => {
setEditDialogKey((prev) => prev + 1);
setSelectedUser(user);
setEditDialogOpen(true);
};
const handleDeleteClick = (user) => {
setSelectedUser(user);
setDeleteDialogOpen(true);
};
const handleAddAccessClick = (user) => {
setAccessDialogKey((prev) => prev + 1);
setSelectedUser(user);
setAccessDialogOpen(true);
};
const handleDeleteAccessClick = (user, grant) => {
setSelectedUser(user);
setSelectedGrant(grant);
setDeleteAccessDialogOpen(true);
};
const handleDialogClose = () => {
setEditDialogOpen(false);
setDeleteDialogOpen(false);
setAccessDialogOpen(false);
setDeleteAccessDialogOpen(false);
setSelectedUser(null);
setSelectedGrant(null);
props.onUserChanged();
};
return (
<>
<Table size="small" aria-label={t("admin_users_title")}>
<TableHead>
<TableRow>
<TableCell sx={{ paddingLeft: 0 }}>{t("admin_users_table_username_header")}</TableCell>
<TableCell>{t("admin_users_table_role_header")}</TableCell>
<TableCell>{t("admin_users_table_tier_header")}</TableCell>
<TableCell>{t("admin_users_table_grants_header")}</TableCell>
<TableCell align="right">{t("admin_users_table_actions_header")}</TableCell>
</TableRow>
</TableHead>
<TableBody>
{users.map((user) => (
<TableRow key={user.username} sx={{ "&:last-child td, &:last-child th": { border: 0 } }}>
<TableCell component="th" scope="row" sx={{ paddingLeft: 0 }}>
<Stack direction="row" spacing={0.5} alignItems="center">
<span>{user.username}</span>
{user.provisioned && (
<Tooltip title={t("admin_users_provisioned_tooltip")}>
<LockIcon fontSize="small" color="disabled" />
</Tooltip>
)}
</Stack>
</TableCell>
<TableCell>
<RoleChip role={user.role} />
</TableCell>
<TableCell>{user.tier || "-"}</TableCell>
<TableCell>
{user.grants && user.grants.length > 0 ? (
<Stack direction="row" spacing={0.5} flexWrap="wrap" useFlexGap>
{user.grants.map((grant, idx) => {
const canDelete = user.role !== "admin" && !grant.provisioned;
const tooltipText = grant.provisioned
? t("admin_users_table_grant_provisioned_tooltip", { permission: grant.permission })
: t("admin_users_table_grant_tooltip", { permission: grant.permission });
return (
<Tooltip key={idx} title={tooltipText}>
<Chip
label={grant.topic}
size="small"
variant={grant.provisioned ? "filled" : "outlined"}
color={grant.provisioned ? "default" : "default"}
icon={grant.provisioned ? <LockIcon fontSize="small" /> : undefined}
onDelete={canDelete ? () => handleDeleteAccessClick(user, grant) : undefined}
/>
</Tooltip>
);
})}
</Stack>
) : (
"-"
)}
</TableCell>
<TableCell align="right" sx={{ whiteSpace: "nowrap" }}>
{user.role !== "admin" && !user.provisioned ? (
<>
<Tooltip title={t("admin_users_table_add_access_tooltip")}>
<IconButton onClick={() => handleAddAccessClick(user)} size="small">
<AddIcon />
</IconButton>
</Tooltip>
<Tooltip title={t("admin_users_table_edit_tooltip")}>
<IconButton onClick={() => handleEditClick(user)} size="small">
<EditIcon />
</IconButton>
</Tooltip>
<Tooltip title={t("admin_users_table_delete_tooltip")}>
<IconButton onClick={() => handleDeleteClick(user)} size="small">
<DeleteOutlineIcon />
</IconButton>
</Tooltip>
</>
) : user.role !== "admin" && user.provisioned ? (
<>
<Tooltip title={t("admin_users_table_add_access_tooltip")}>
<IconButton onClick={() => handleAddAccessClick(user)} size="small">
<AddIcon />
</IconButton>
</Tooltip>
<Tooltip title={t("admin_users_provisioned_cannot_edit")}>
<span>
<IconButton disabled size="small">
<EditIcon />
</IconButton>
<IconButton disabled size="small">
<DeleteOutlineIcon />
</IconButton>
</span>
</Tooltip>
</>
) : (
<Tooltip title={t("admin_users_table_admin_no_actions")}>
<span>
<IconButton disabled size="small">
<EditIcon />
</IconButton>
</span>
</Tooltip>
)}
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
<EditUserDialog key={`editUserDialog${editDialogKey}`} open={editDialogOpen} user={selectedUser} onClose={handleDialogClose} />
<DeleteUserDialog open={deleteDialogOpen} user={selectedUser} onClose={handleDialogClose} />
<AddAccessDialog key={`addAccessDialog${accessDialogKey}`} open={accessDialogOpen} user={selectedUser} onClose={handleDialogClose} />
<DeleteAccessDialog open={deleteAccessDialogOpen} user={selectedUser} grant={selectedGrant} onClose={handleDialogClose} />
</>
);
};
const RoleChip = ({ role }) => {
const { t } = useTranslation();
if (role === "admin") {
return <Chip label={t("admin_users_role_admin")} size="small" color="primary" />;
}
return <Chip label={t("admin_users_role_user")} size="small" variant="outlined" />;
};
const AddUserDialog = (props) => {
const theme = useTheme();
const { t } = useTranslation();
const [error, setError] = useState("");
const [username, setUsername] = useState("");
const [password, setPassword] = useState("");
const [tier, setTier] = useState("");
const fullScreen = useMediaQuery(theme.breakpoints.down("sm"));
const handleSubmit = async () => {
try {
await adminApi.addUser(username, password, tier || undefined);
props.onClose();
} catch (e) {
console.log(`[Admin] Error adding user`, e);
if (e instanceof UnauthorizedError) {
await session.resetAndRedirect(routes.login);
} else {
setError(e.message);
}
}
};
return (
<Dialog open={props.open} onClose={props.onClose} maxWidth="sm" fullWidth fullScreen={fullScreen}>
<DialogTitle>{t("admin_users_add_dialog_title")}</DialogTitle>
<DialogContent>
<TextField
margin="dense"
id="username"
label={t("admin_users_add_dialog_username_label")}
type="text"
value={username}
onChange={(ev) => setUsername(ev.target.value)}
fullWidth
variant="standard"
autoFocus
/>
<TextField
margin="dense"
id="password"
label={t("admin_users_add_dialog_password_label")}
type="password"
value={password}
onChange={(ev) => setPassword(ev.target.value)}
fullWidth
variant="standard"
/>
<TextField
margin="dense"
id="tier"
label={t("admin_users_add_dialog_tier_label")}
type="text"
value={tier}
onChange={(ev) => setTier(ev.target.value)}
fullWidth
variant="standard"
helperText={t("admin_users_add_dialog_tier_helper")}
/>
</DialogContent>
<DialogFooter status={error}>
<Button onClick={props.onClose}>{t("common_cancel")}</Button>
<Button onClick={handleSubmit} disabled={!username || !password}>
{t("common_add")}
</Button>
</DialogFooter>
</Dialog>
);
};
const EditUserDialog = (props) => {
const theme = useTheme();
const { t } = useTranslation();
const [error, setError] = useState("");
const [password, setPassword] = useState("");
const [tier, setTier] = useState(props.user?.tier || "");
const fullScreen = useMediaQuery(theme.breakpoints.down("sm"));
const handleSubmit = async () => {
try {
await adminApi.updateUser(props.user.username, password || undefined, tier || undefined);
props.onClose();
} catch (e) {
console.log(`[Admin] Error updating user`, e);
if (e instanceof UnauthorizedError) {
await session.resetAndRedirect(routes.login);
} else {
setError(e.message);
}
}
};
if (!props.user) {
return null;
}
return (
<Dialog open={props.open} onClose={props.onClose} maxWidth="sm" fullWidth fullScreen={fullScreen}>
<DialogTitle>{t("admin_users_edit_dialog_title", { username: props.user.username })}</DialogTitle>
<DialogContent>
<TextField
margin="dense"
id="password"
label={t("admin_users_edit_dialog_password_label")}
type="password"
value={password}
onChange={(ev) => setPassword(ev.target.value)}
fullWidth
variant="standard"
helperText={t("admin_users_edit_dialog_password_helper")}
/>
<TextField
margin="dense"
id="tier"
label={t("admin_users_edit_dialog_tier_label")}
type="text"
value={tier}
onChange={(ev) => setTier(ev.target.value)}
fullWidth
variant="standard"
helperText={t("admin_users_edit_dialog_tier_helper")}
/>
</DialogContent>
<DialogFooter status={error}>
<Button onClick={props.onClose}>{t("common_cancel")}</Button>
<Button onClick={handleSubmit} disabled={!password && !tier}>
{t("common_save")}
</Button>
</DialogFooter>
</Dialog>
);
};
const DeleteUserDialog = (props) => {
const { t } = useTranslation();
const [error, setError] = useState("");
const handleSubmit = async () => {
try {
await adminApi.deleteUser(props.user.username);
props.onClose();
} catch (e) {
console.log(`[Admin] Error deleting user`, e);
if (e instanceof UnauthorizedError) {
await session.resetAndRedirect(routes.login);
} else {
setError(e.message);
}
}
};
if (!props.user) {
return null;
}
return (
<Dialog open={props.open} onClose={props.onClose}>
<DialogTitle>{t("admin_users_delete_dialog_title")}</DialogTitle>
<DialogContent>
<DialogContentText>{t("admin_users_delete_dialog_description", { username: props.user.username })}</DialogContentText>
</DialogContent>
<DialogFooter status={error}>
<Button onClick={props.onClose}>{t("common_cancel")}</Button>
<Button onClick={handleSubmit} color="error">
{t("admin_users_delete_dialog_button")}
</Button>
</DialogFooter>
</Dialog>
);
};
const AddAccessDialog = (props) => {
const theme = useTheme();
const { t } = useTranslation();
const [error, setError] = useState("");
const [topic, setTopic] = useState("");
const [permission, setPermission] = useState("read-write");
const fullScreen = useMediaQuery(theme.breakpoints.down("sm"));
const handleSubmit = async () => {
try {
await adminApi.allowAccess(props.user.username, topic, permission);
props.onClose();
} catch (e) {
console.log(`[Admin] Error adding access`, e);
if (e instanceof UnauthorizedError) {
await session.resetAndRedirect(routes.login);
} else {
setError(e.message);
}
}
};
if (!props.user) {
return null;
}
return (
<Dialog open={props.open} onClose={props.onClose} maxWidth="sm" fullWidth fullScreen={fullScreen}>
<DialogTitle>{t("admin_access_add_dialog_title", { username: props.user.username })}</DialogTitle>
<DialogContent>
<TextField
margin="dense"
id="topic"
label={t("admin_access_add_dialog_topic_label")}
type="text"
value={topic}
onChange={(ev) => setTopic(ev.target.value)}
fullWidth
variant="standard"
autoFocus
helperText={t("admin_access_add_dialog_topic_helper")}
/>
<FormControl fullWidth variant="standard" sx={{ mt: 2 }}>
<Select
value={permission}
onChange={(ev) => setPermission(ev.target.value)}
label={t("admin_access_add_dialog_permission_label")}
>
<MenuItem value="read-write">{t("admin_access_permission_read_write")}</MenuItem>
<MenuItem value="read-only">{t("admin_access_permission_read_only")}</MenuItem>
<MenuItem value="write-only">{t("admin_access_permission_write_only")}</MenuItem>
<MenuItem value="deny-all">{t("admin_access_permission_deny_all")}</MenuItem>
</Select>
</FormControl>
</DialogContent>
<DialogFooter status={error}>
<Button onClick={props.onClose}>{t("common_cancel")}</Button>
<Button onClick={handleSubmit} disabled={!topic}>
{t("common_add")}
</Button>
</DialogFooter>
</Dialog>
);
};
const DeleteAccessDialog = (props) => {
const { t } = useTranslation();
const [error, setError] = useState("");
const handleSubmit = async () => {
try {
await adminApi.resetAccess(props.user.username, props.grant.topic);
props.onClose();
} catch (e) {
console.log(`[Admin] Error removing access`, e);
if (e instanceof UnauthorizedError) {
await session.resetAndRedirect(routes.login);
} else {
setError(e.message);
}
}
};
if (!props.user || !props.grant) {
return null;
}
return (
<Dialog open={props.open} onClose={props.onClose}>
<DialogTitle>{t("admin_access_delete_dialog_title")}</DialogTitle>
<DialogContent>
<DialogContentText>
{t("admin_access_delete_dialog_description", { username: props.user.username, topic: props.grant.topic })}
</DialogContentText>
</DialogContent>
<DialogFooter status={error}>
<Button onClick={props.onClose}>{t("common_cancel")}</Button>
<Button onClick={handleSubmit} color="error">
{t("admin_access_delete_dialog_button")}
</Button>
</DialogFooter>
</Dialog>
);
};
export default Admin;

View File

@@ -20,6 +20,7 @@ import Messaging from "./Messaging";
import Login from "./Login";
import Signup from "./Signup";
import Account from "./Account";
import Admin from "./Admin";
import initI18n from "../app/i18n"; // Translations!
import prefs, { THEME } from "../app/Prefs";
import RTLCacheProvider from "./RTLCacheProvider";
@@ -80,6 +81,7 @@ const App = () => {
<Route element={<Layout />}>
<Route path={routes.app} element={<AllSubscriptions />} />
<Route path={routes.account} element={<Account />} />
<Route path={routes.admin} element={<Admin />} />
<Route path={routes.settings} element={<Preferences />} />
<Route path={routes.subscription} element={<SingleSubscription />} />
<Route path={routes.subscriptionExternal} element={<SingleSubscription />} />

View File

@@ -25,6 +25,7 @@ import { useContext, useState } from "react";
import ChatBubbleOutlineIcon from "@mui/icons-material/ChatBubbleOutline";
import Person from "@mui/icons-material/Person";
import SettingsIcon from "@mui/icons-material/Settings";
import AdminPanelSettingsIcon from "@mui/icons-material/AdminPanelSettings";
import AddIcon from "@mui/icons-material/Add";
import { useLocation, useNavigate } from "react-router-dom";
import { ChatBubble, MoreVert, NotificationsOffOutlined, Send } from "@mui/icons-material";
@@ -164,6 +165,14 @@ const NavList = (props) => {
<ListItemText primary={t("nav_button_account")} />
</ListItemButton>
)}
{session.exists() && isAdmin && (
<ListItemButton onClick={() => navigate(routes.admin)} selected={location.pathname === routes.admin}>
<ListItemIcon>
<AdminPanelSettingsIcon />
</ListItemIcon>
<ListItemText primary={t("nav_button_admin")} />
</ListItemButton>
)}
<ListItemButton onClick={() => navigate(routes.settings)} selected={location.pathname === routes.settings}>
<ListItemIcon>
<SettingsIcon />

View File

@@ -6,6 +6,7 @@ const routes = {
signup: "/signup",
app: config.app_root,
account: "/account",
admin: "/admin",
settings: "/settings",
subscription: "/:topic",
subscriptionExternal: "/:baseUrl/:topic",