mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-19 13:50:46 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2a940ad289 | ||
|
|
75b2ca7dec |
@@ -181,7 +181,6 @@ I've added a ⭐ to projects or posts that have a significant following, or had
|
||||
- [ntfy-heartbeat-monitor](https://codeberg.org/RockWolf/ntfy-heartbeat-monitor) - Application for implementing heartbeat monitoring/alerting by utilizing ntfy
|
||||
- [ntfy-bridge](https://github.com/AlexGaudon/ntfy-bridge) - An application to bridge Discord messages (or webhooks) to ntfy.
|
||||
- [ntailfy](https://github.com/leukosaima/ntailfy) - ntfy notifications when Tailscale devices connect/disconnect (Go)
|
||||
- [BRun](https://github.com/cbrake/brun) - Native Linux automation platform connecting triggers to actions without containers (Go)
|
||||
|
||||
## Blog + forum posts
|
||||
|
||||
|
||||
4
go.mod
4
go.mod
@@ -32,7 +32,6 @@ require github.com/pkg/errors v0.9.1 // indirect
|
||||
require (
|
||||
firebase.google.com/go/v4 v4.18.0
|
||||
github.com/SherClockHolmes/webpush-go v1.4.0
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/microcosm-cc/bluemonday v1.0.27
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/stripe/stripe-go/v74 v74.30.0
|
||||
@@ -73,9 +72,6 @@ require (
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.16.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
|
||||
9
go.sum
9
go.sum
@@ -104,14 +104,6 @@ github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
@@ -152,7 +144,6 @@ github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xI
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
|
||||
@@ -4,89 +4,351 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// MessageCache is the interface for message storage backends
|
||||
type MessageCache interface {
|
||||
AddMessage(m *message) error
|
||||
AddMessages(ms []*message) error
|
||||
Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error)
|
||||
MessagesDue() ([]*message, error)
|
||||
MessagesExpired() ([]string, error)
|
||||
Message(id string) (*message, error)
|
||||
MarkPublished(m *message) error
|
||||
MessageCounts() (map[string]int, error)
|
||||
Topics() (map[string]*topic, error)
|
||||
DeleteMessages(ids ...string) error
|
||||
ExpireMessages(topics ...string) error
|
||||
AttachmentsExpired() ([]string, error)
|
||||
MarkAttachmentsDeleted(ids ...string) error
|
||||
AttachmentBytesUsedBySender(sender string) (int64, error)
|
||||
AttachmentBytesUsedByUser(userID string) (int64, error)
|
||||
UpdateStats(messages int64) error
|
||||
Stats() (messages int64, err error)
|
||||
DB() *sql.DB
|
||||
Close() error
|
||||
var (
|
||||
errUnexpectedMessageType = errors.New("unexpected message type")
|
||||
errMessageNotFound = errors.New("message not found")
|
||||
errNoRows = errors.New("no rows found")
|
||||
)
|
||||
|
||||
// Messages cache
|
||||
const (
|
||||
createMessagesTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_deleted INT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
user TEXT NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
|
||||
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||
CREATE TABLE IF NOT EXISTS stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value INT
|
||||
);
|
||||
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||
COMMIT;
|
||||
`
|
||||
insertMessageQuery = `
|
||||
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
deleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
|
||||
updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
|
||||
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
|
||||
selectMessagesByIDQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE mid = ?
|
||||
`
|
||||
selectMessagesSinceTimeQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND time >= ? AND published = 1
|
||||
ORDER BY time, id
|
||||
`
|
||||
selectMessagesSinceTimeIncludeScheduledQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND time >= ?
|
||||
ORDER BY time, id
|
||||
`
|
||||
selectMessagesSinceIDQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND id > ? AND published = 1
|
||||
ORDER BY time, id
|
||||
`
|
||||
selectMessagesSinceIDIncludeScheduledQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND (id > ? OR published = 0)
|
||||
ORDER BY time, id
|
||||
`
|
||||
selectMessagesLatestQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND published = 1
|
||||
ORDER BY time DESC, id DESC
|
||||
LIMIT 1
|
||||
`
|
||||
selectMessagesDueQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE time <= ? AND published = 0
|
||||
ORDER BY time, id
|
||||
`
|
||||
selectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
|
||||
updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
|
||||
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
|
||||
selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
|
||||
selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
|
||||
|
||||
updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
|
||||
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
|
||||
selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
|
||||
selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
|
||||
|
||||
selectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
|
||||
updateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
const (
|
||||
currentSchemaVersion = 13
|
||||
createSchemaVersionTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
`
|
||||
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
|
||||
// 0 -> 1
|
||||
migrate0To1AlterMessagesTableQuery = `
|
||||
BEGIN;
|
||||
ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0);
|
||||
ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT('');
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
// 1 -> 2
|
||||
migrate1To2AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1);
|
||||
`
|
||||
|
||||
// 2 -> 3
|
||||
migrate2To3AlterMessagesTableQuery = `
|
||||
BEGIN;
|
||||
ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT('');
|
||||
COMMIT;
|
||||
`
|
||||
// 3 -> 4
|
||||
migrate3To4AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 4 -> 5
|
||||
migrate4To5AlterMessagesTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages_new (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_owner TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic);
|
||||
INSERT
|
||||
INTO messages_new (
|
||||
mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
|
||||
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published)
|
||||
SELECT
|
||||
id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
|
||||
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published
|
||||
FROM messages;
|
||||
DROP TABLE messages;
|
||||
ALTER TABLE messages_new RENAME TO messages;
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
// 5 -> 6
|
||||
migrate5To6AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 6 -> 7
|
||||
migrate6To7AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages RENAME COLUMN attachment_owner TO sender;
|
||||
`
|
||||
|
||||
// 7 -> 8
|
||||
migrate7To8AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 8 -> 9
|
||||
migrate8To9AlterMessagesTableQuery = `
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
`
|
||||
|
||||
// 9 -> 10
|
||||
migrate9To10AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
|
||||
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
|
||||
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||
`
|
||||
migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
|
||||
|
||||
// 10 -> 11
|
||||
migrate10To11AlterMessagesTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value INT
|
||||
);
|
||||
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||
`
|
||||
|
||||
// 11 -> 12
|
||||
migrate11To12AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 12 -> 13
|
||||
migrate12To13AlterMessagesTableQuery = `
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
`
|
||||
)
|
||||
|
||||
var (
|
||||
migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
|
||||
0: migrateFrom0,
|
||||
1: migrateFrom1,
|
||||
2: migrateFrom2,
|
||||
3: migrateFrom3,
|
||||
4: migrateFrom4,
|
||||
5: migrateFrom5,
|
||||
6: migrateFrom6,
|
||||
7: migrateFrom7,
|
||||
8: migrateFrom8,
|
||||
9: migrateFrom9,
|
||||
10: migrateFrom10,
|
||||
11: migrateFrom11,
|
||||
12: migrateFrom12,
|
||||
}
|
||||
)
|
||||
|
||||
type messageCache struct {
|
||||
db *sql.DB
|
||||
queue *util.BatchingQueue[*message]
|
||||
nop bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// commonMessageCache contains shared logic for all message cache implementations
|
||||
type commonMessageCache struct {
|
||||
db *sql.DB
|
||||
queue *util.BatchingQueue[*message]
|
||||
queries *messageCacheQueries
|
||||
nop bool // If true, cache ignores all messages
|
||||
mu sync.Mutex // Lock for concurrent access
|
||||
// newSqliteCache creates a SQLite file-backed cache
|
||||
func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*messageCache, error) {
|
||||
// Check the parent directory of the database file (makes for friendly error messages)
|
||||
parentDir := filepath.Dir(filename)
|
||||
if !util.FileExists(parentDir) {
|
||||
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
|
||||
}
|
||||
// Open database
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupMessagesDB(db, startupQueries, cacheDuration); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var queue *util.BatchingQueue[*message]
|
||||
if batchSize > 0 || batchTimeout > 0 {
|
||||
queue = util.NewBatchingQueue[*message](batchSize, batchTimeout)
|
||||
}
|
||||
cache := &messageCache{
|
||||
db: db,
|
||||
queue: queue,
|
||||
nop: nop,
|
||||
}
|
||||
go cache.processMessageBatches()
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
var _ MessageCache = (*commonMessageCache)(nil)
|
||||
// newMemCache creates an in-memory cache
|
||||
func newMemCache() (*messageCache, error) {
|
||||
return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, false)
|
||||
}
|
||||
|
||||
// messageCacheQueries holds database-specific SQL queries
|
||||
type messageCacheQueries struct {
|
||||
insertMessage string
|
||||
deleteMessage string
|
||||
updateMessagesForTopicExpiry string
|
||||
selectRowIDFromMessageID string // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
|
||||
selectMessagesByID string
|
||||
selectMessagesSinceTime string
|
||||
selectMessagesSinceTimeIncludeScheduled string
|
||||
selectMessagesSinceID string
|
||||
selectMessagesSinceIDIncludeScheduled string
|
||||
selectMessagesLatest string
|
||||
selectMessagesDue string
|
||||
selectMessagesExpired string
|
||||
updateMessagePublished string
|
||||
selectMessageCountPerTopic string
|
||||
selectTopics string
|
||||
// newNopCache creates an in-memory cache that discards all messages;
|
||||
// it is always empty and can be used if caching is entirely disabled
|
||||
func newNopCache() (*messageCache, error) {
|
||||
return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, true)
|
||||
}
|
||||
|
||||
updateAttachmentDeleted string
|
||||
selectAttachmentsExpired string
|
||||
selectAttachmentsSizeBySender string
|
||||
selectAttachmentsSizeByUserID string
|
||||
|
||||
selectStats string
|
||||
updateStats string
|
||||
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
|
||||
// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory
|
||||
// sql database, so if the stdlib's sql engine happens to open another connection and
|
||||
// you've only specified ":memory:", that connection will see a brand new database.
|
||||
// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared").
|
||||
// Every connection to this string will point to the same in-memory database."
|
||||
func createMemoryFilename() string {
|
||||
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
|
||||
}
|
||||
|
||||
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asyncronously.
|
||||
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
|
||||
func (c *commonMessageCache) AddMessage(m *message) error {
|
||||
func (c *messageCache) AddMessage(m *message) error {
|
||||
if c.queue != nil {
|
||||
c.queue.Enqueue(m)
|
||||
return nil
|
||||
}
|
||||
return c.AddMessages([]*message{m})
|
||||
return c.addMessages([]*message{m})
|
||||
}
|
||||
|
||||
// AddMessages synchronously stores a batch of messages. If the database is locked, the transaction waits until
|
||||
// the timeout is exceeded before erroring out.
|
||||
func (c *commonMessageCache) AddMessages(ms []*message) error {
|
||||
// addMessages synchronously stores a match of messages. If the database is locked, the transaction waits until
|
||||
// SQLite's busy_timeout is exceeded before erroring out.
|
||||
func (c *messageCache) addMessages(ms []*message) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.nop {
|
||||
@@ -101,7 +363,7 @@ func (c *commonMessageCache) AddMessages(ms []*message) error {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
stmt, err := tx.Prepare(c.queries.insertMessage)
|
||||
stmt, err := tx.Prepare(insertMessageQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -113,8 +375,7 @@ func (c *commonMessageCache) AddMessages(ms []*message) error {
|
||||
published := m.Time <= time.Now().Unix()
|
||||
tags := strings.Join(m.Tags, ",")
|
||||
var attachmentName, attachmentType, attachmentURL string
|
||||
var attachmentSize, attachmentExpires int64
|
||||
var attachmentDeleted bool
|
||||
var attachmentSize, attachmentExpires, attachmentDeleted int64
|
||||
if m.Attachment != nil {
|
||||
attachmentName = m.Attachment.Name
|
||||
attachmentType = m.Attachment.Type
|
||||
@@ -151,7 +412,7 @@ func (c *commonMessageCache) AddMessages(ms []*message) error {
|
||||
attachmentSize,
|
||||
attachmentExpires,
|
||||
attachmentURL,
|
||||
attachmentDeleted, // Always false
|
||||
attachmentDeleted, // Always zero
|
||||
sender,
|
||||
m.User,
|
||||
m.ContentType,
|
||||
@@ -170,7 +431,7 @@ func (c *commonMessageCache) AddMessages(ms []*message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
|
||||
func (c *messageCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
|
||||
if since.IsNone() {
|
||||
return make([]*message, 0), nil
|
||||
} else if since.IsLatest() {
|
||||
@@ -181,21 +442,13 @@ func (c *commonMessageCache) Messages(topic string, since sinceMarker, scheduled
|
||||
return c.messagesSinceTime(topic, since, scheduled)
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) messagesLatest(topic string) ([]*message, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) messagesSinceTime(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
|
||||
func (c *messageCache) messagesSinceTime(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if scheduled {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceTimeIncludeScheduled, topic, since.Time().Unix())
|
||||
rows, err = c.db.Query(selectMessagesSinceTimeIncludeScheduledQuery, topic, since.Time().Unix())
|
||||
} else {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
|
||||
rows, err = c.db.Query(selectMessagesSinceTimeQuery, topic, since.Time().Unix())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -203,8 +456,8 @@ func (c *commonMessageCache) messagesSinceTime(topic string, since sinceMarker,
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) messagesSinceID(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
|
||||
idrows, err := c.db.Query(c.queries.selectRowIDFromMessageID, since.ID())
|
||||
func (c *messageCache) messagesSinceID(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
|
||||
idrows, err := c.db.Query(selectRowIDFromMessageID, since.ID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -219,9 +472,9 @@ func (c *commonMessageCache) messagesSinceID(topic string, since sinceMarker, sc
|
||||
idrows.Close()
|
||||
var rows *sql.Rows
|
||||
if scheduled {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceIDIncludeScheduled, topic, rowID)
|
||||
rows, err = c.db.Query(selectMessagesSinceIDIncludeScheduledQuery, topic, rowID)
|
||||
} else {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, rowID)
|
||||
rows, err = c.db.Query(selectMessagesSinceIDQuery, topic, rowID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -229,17 +482,25 @@ func (c *commonMessageCache) messagesSinceID(topic string, since sinceMarker, sc
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) MessagesDue() ([]*message, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
|
||||
func (c *messageCache) messagesLatest(topic string) ([]*message, error) {
|
||||
rows, err := c.db.Query(selectMessagesLatestQuery, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
|
||||
func (c *commonMessageCache) MessagesExpired() ([]string, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
|
||||
func (c *messageCache) MessagesDue() ([]*message, error) {
|
||||
rows, err := c.db.Query(selectMessagesDueQuery, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
// MessagesExpired returns a list of IDs for messages that have expires (should be deleted)
|
||||
func (c *messageCache) MessagesExpired() ([]string, error) {
|
||||
rows, err := c.db.Query(selectMessagesExpiredQuery, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -258,24 +519,27 @@ func (c *commonMessageCache) MessagesExpired() ([]string, error) {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) Message(id string) (*message, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
|
||||
func (c *messageCache) Message(id string) (*message, error) {
|
||||
rows, err := c.db.Query(selectMessagesByIDQuery, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !rows.Next() {
|
||||
}
|
||||
if !rows.Next() {
|
||||
return nil, errMessageNotFound
|
||||
}
|
||||
defer rows.Close()
|
||||
return readMessage(rows)
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) MarkPublished(m *message) error {
|
||||
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
|
||||
func (c *messageCache) MarkPublished(m *message) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
_, err := c.db.Exec(updateMessagePublishedQuery, m.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) MessageCounts() (map[string]int, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessageCountPerTopic)
|
||||
func (c *messageCache) MessageCounts() (map[string]int, error) {
|
||||
rows, err := c.db.Query(selectMessageCountPerTopicQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -294,8 +558,8 @@ func (c *commonMessageCache) MessageCounts() (map[string]int, error) {
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) Topics() (map[string]*topic, error) {
|
||||
rows, err := c.db.Query(c.queries.selectTopics)
|
||||
func (c *messageCache) Topics() (map[string]*topic, error) {
|
||||
rows, err := c.db.Query(selectTopicsQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -314,36 +578,40 @@ func (c *commonMessageCache) Topics() (map[string]*topic, error) {
|
||||
return topics, nil
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) DeleteMessages(ids ...string) error {
|
||||
func (c *messageCache) DeleteMessages(ids ...string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, id := range ids {
|
||||
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
|
||||
if _, err := tx.Exec(deleteMessageQuery, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) ExpireMessages(topics ...string) error {
|
||||
func (c *messageCache) ExpireMessages(topics ...string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, t := range topics {
|
||||
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
|
||||
if _, err := tx.Exec(updateMessagesForTopicExpiryQuery, time.Now().Unix()-1, t); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) AttachmentsExpired() ([]string, error) {
|
||||
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
|
||||
func (c *messageCache) AttachmentsExpired() ([]string, error) {
|
||||
rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -362,37 +630,39 @@ func (c *commonMessageCache) AttachmentsExpired() ([]string, error) {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) MarkAttachmentsDeleted(ids ...string) error {
|
||||
func (c *messageCache) MarkAttachmentsDeleted(ids ...string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, id := range ids {
|
||||
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
|
||||
if _, err := tx.Exec(updateAttachmentDeleted, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
||||
func (c *messageCache) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
||||
rows, err := c.db.Query(selectAttachmentsSizeBySenderQuery, sender, time.Now().Unix())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.readAttachmentBytesUsed(rows)
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
||||
func (c *messageCache) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
||||
rows, err := c.db.Query(selectAttachmentsSizeByUserIDQuery, userID, time.Now().Unix())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.readAttachmentBytesUsed(rows)
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
|
||||
func (c *messageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
|
||||
defer rows.Close()
|
||||
var size int64
|
||||
if !rows.Next() {
|
||||
@@ -406,45 +676,17 @@ func (c *commonMessageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, err
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) processMessageBatches() {
|
||||
func (c *messageCache) processMessageBatches() {
|
||||
if c.queue == nil {
|
||||
return
|
||||
}
|
||||
for messages := range c.queue.Dequeue() {
|
||||
if err := c.AddMessages(messages); err != nil {
|
||||
if err := c.addMessages(messages); err != nil {
|
||||
log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) UpdateStats(messages int64) error {
|
||||
_, err := c.db.Exec(c.queries.updateStats, messages)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) Stats() (messages int64, err error) {
|
||||
rows, err := c.db.Query(c.queries.selectStats)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
if err := rows.Scan(&messages); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) DB() *sql.DB {
|
||||
return c.db
|
||||
}
|
||||
|
||||
func (c *commonMessageCache) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func readMessages(rows *sql.Rows) ([]*message, error) {
|
||||
defer rows.Close()
|
||||
messages := make([]*message, 0)
|
||||
@@ -534,3 +776,257 @@ func readMessage(rows *sql.Rows) (*message, error) {
|
||||
Encoding: encoding,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *messageCache) UpdateStats(messages int64) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
_, err := c.db.Exec(updateStatsQuery, messages)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *messageCache) Stats() (messages int64, err error) {
|
||||
rows, err := c.db.Query(selectStatsQuery)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
if err := rows.Scan(&messages); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (c *messageCache) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
|
||||
// Run startup queries
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If 'messages' table does not exist, this must be a new database
|
||||
rowsMC, err := db.Query(selectMessagesCountQuery)
|
||||
if err != nil {
|
||||
return setupNewCacheDB(db)
|
||||
}
|
||||
rowsMC.Close()
|
||||
|
||||
// If 'messages' table exists, check 'schemaVersion' table
|
||||
schemaVersion := 0
|
||||
rowsSV, err := db.Query(selectSchemaVersionQuery)
|
||||
if err == nil {
|
||||
defer rowsSV.Close()
|
||||
if !rowsSV.Next() {
|
||||
return errors.New("cannot determine schema version: cache file may be corrupt")
|
||||
}
|
||||
if err := rowsSV.Scan(&schemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
rowsSV.Close()
|
||||
}
|
||||
|
||||
// Do migrations
|
||||
if schemaVersion == currentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > currentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
|
||||
}
|
||||
for i := schemaVersion; i < currentSchemaVersion; i++ {
|
||||
fn, ok := migrations[i]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||
} else if err := fn(db, cacheDuration); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewCacheDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(createMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(createSchemaVersionTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom0(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1")
|
||||
if _, err := db.Exec(migrate0To1AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(createSchemaVersionTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(insertSchemaVersion, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom1(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2")
|
||||
if _, err := db.Exec(migrate1To2AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom2(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3")
|
||||
if _, err := db.Exec(migrate2To3AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom3(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4")
|
||||
if _, err := db.Exec(migrate3To4AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom4(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5")
|
||||
if _, err := db.Exec(migrate4To5AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 5); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom5(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6")
|
||||
if _, err := db.Exec(migrate5To6AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 6); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom6(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7")
|
||||
if _, err := db.Exec(migrate6To7AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 7); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom7(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8")
|
||||
if _, err := db.Exec(migrate7To8AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 8); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom8(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9")
|
||||
if _, err := db.Exec(migrate8To9AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 9); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate9To10AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(migrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 10); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom10(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 11); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom11(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate11To12AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 12); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom12(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate12To13AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 13); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
@@ -1,193 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// PostgreSQL schema version (starts at latest, no migrations needed initially)
|
||||
const pgCurrentSchemaVersion = 1
|
||||
|
||||
// Messages cache
|
||||
const (
|
||||
pgCreateMessagesTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id SERIAL PRIMARY KEY,
|
||||
mid TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_deleted BOOLEAN NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
"user" TEXT NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published BOOLEAN NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_user ON messages ("user");
|
||||
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||
CREATE TABLE IF NOT EXISTS stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value INT
|
||||
);
|
||||
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO schemaVersion VALUES (1, 1);
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
pgSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
|
||||
)
|
||||
|
||||
var (
|
||||
pgMessageCacheQueries = &messageCacheQueries{
|
||||
insertMessage: `
|
||||
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, "user", content_type, encoding, published)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22)
|
||||
`,
|
||||
deleteMessage: `DELETE FROM messages WHERE mid = $1`,
|
||||
updateMessagesForTopicExpiry: `UPDATE messages SET expires = $1 WHERE topic = $2`,
|
||||
selectRowIDFromMessageID: `SELECT id FROM messages WHERE mid = $1`,
|
||||
selectMessagesByID: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
|
||||
FROM messages
|
||||
WHERE mid = $1
|
||||
`,
|
||||
selectMessagesSinceTime: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = $1 AND time >= $2 AND published = TRUE
|
||||
ORDER BY time, id
|
||||
`,
|
||||
selectMessagesSinceTimeIncludeScheduled: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = $1 AND time >= $2
|
||||
ORDER BY time, id
|
||||
`,
|
||||
selectMessagesSinceID: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = $1 AND id > $2 AND published = TRUE
|
||||
ORDER BY time, id
|
||||
`,
|
||||
selectMessagesSinceIDIncludeScheduled: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = $1 AND (id > $2 OR published = FALSE)
|
||||
ORDER BY time, id
|
||||
`,
|
||||
selectMessagesLatest: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = $1 AND published = TRUE
|
||||
ORDER BY time DESC, id DESC
|
||||
LIMIT 1
|
||||
`,
|
||||
selectMessagesDue: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding
|
||||
FROM messages
|
||||
WHERE time <= $1 AND published = FALSE
|
||||
ORDER BY time, id
|
||||
`,
|
||||
selectMessagesExpired: `SELECT mid FROM messages WHERE expires <= $1 AND published = TRUE`,
|
||||
updateMessagePublished: `UPDATE messages SET published = TRUE WHERE mid = $1`,
|
||||
selectMessageCountPerTopic: `SELECT topic, COUNT(*) FROM messages GROUP BY topic`,
|
||||
selectTopics: `SELECT topic FROM messages GROUP BY topic`,
|
||||
|
||||
updateAttachmentDeleted: `UPDATE messages SET attachment_deleted = TRUE WHERE mid = $1`,
|
||||
selectAttachmentsExpired: `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`,
|
||||
selectAttachmentsSizeBySender: `SELECT COALESCE(SUM(attachment_size), 0) FROM messages WHERE "user" = '' AND sender = $1 AND attachment_expires >= $2`,
|
||||
selectAttachmentsSizeByUserID: `SELECT COALESCE(SUM(attachment_size), 0) FROM messages WHERE "user" = $1 AND attachment_expires >= $2`,
|
||||
|
||||
selectStats: `SELECT value FROM stats WHERE key = 'messages'`,
|
||||
updateStats: `UPDATE stats SET value = $1 WHERE key = 'messages'`,
|
||||
}
|
||||
)
|
||||
|
||||
type pgMessageCache struct {
|
||||
*commonMessageCache
|
||||
}
|
||||
|
||||
var _ MessageCache = (*pgMessageCache)(nil)
|
||||
|
||||
// newPgCache creates a PostgreSQL-backed message cache
|
||||
func newPgCache(connectionString, startupQueries string, batchSize int, batchTimeout time.Duration) (*pgMessageCache, error) {
|
||||
db, err := sql.Open("pgx", connectionString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupPgMessagesDB(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var queue *util.BatchingQueue[*message]
|
||||
if batchSize > 0 || batchTimeout > 0 {
|
||||
queue = util.NewBatchingQueue[*message](batchSize, batchTimeout)
|
||||
}
|
||||
cache := &pgMessageCache{
|
||||
commonMessageCache: &commonMessageCache{
|
||||
db: db,
|
||||
queue: queue,
|
||||
queries: pgMessageCacheQueries,
|
||||
},
|
||||
}
|
||||
go cache.processMessageBatches()
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
func setupPgMessagesDB(db *sql.DB, startupQueries string) error {
|
||||
// Run startup queries
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If 'messages' table does not exist, this must be a new database
|
||||
rowsMC, err := db.Query(pgSelectMessagesCountQuery)
|
||||
if err != nil {
|
||||
return setupNewPgCacheDB(db)
|
||||
}
|
||||
rowsMC.Close()
|
||||
|
||||
// Future: Add migrations here when schema changes
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPgCacheDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(pgCreateMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isPostgres checks if the connection string indicates a PostgreSQL database
|
||||
func isPostgres(connectionString string) bool {
|
||||
return strings.HasPrefix(connectionString, "postgres:")
|
||||
}
|
||||
@@ -1,571 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
var (
|
||||
errUnexpectedMessageType = errors.New("unexpected message type")
|
||||
errMessageNotFound = errors.New("message not found")
|
||||
errNoRows = errors.New("no rows found")
|
||||
)
|
||||
|
||||
// Messages cache
|
||||
const (
|
||||
createMessagesTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_deleted INT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
user TEXT NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
|
||||
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||
CREATE TABLE IF NOT EXISTS stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value INT
|
||||
);
|
||||
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||
COMMIT;
|
||||
`
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteMessageCacheQueries = &messageCacheQueries{
|
||||
insertMessage: `
|
||||
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`,
|
||||
deleteMessage: `DELETE FROM messages WHERE mid = ?`,
|
||||
updateMessagesForTopicExpiry: `UPDATE messages SET expires = ? WHERE topic = ?`,
|
||||
selectRowIDFromMessageID: `SELECT id FROM messages WHERE mid = ?`, // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
|
||||
selectMessagesByID: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE mid = ?
|
||||
`,
|
||||
selectMessagesSinceTime: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND time >= ? AND published = 1
|
||||
ORDER BY time, id
|
||||
`,
|
||||
selectMessagesSinceTimeIncludeScheduled: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND time >= ?
|
||||
ORDER BY time, id
|
||||
`,
|
||||
selectMessagesSinceID: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND id > ? AND published = 1
|
||||
ORDER BY time, id
|
||||
`,
|
||||
selectMessagesSinceIDIncludeScheduled: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND (id > ? OR published = 0)
|
||||
ORDER BY time, id
|
||||
`,
|
||||
selectMessagesLatest: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND published = 1
|
||||
ORDER BY time DESC, id DESC
|
||||
LIMIT 1
|
||||
`,
|
||||
selectMessagesDue: `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE time <= ? AND published = 0
|
||||
ORDER BY time, id
|
||||
`,
|
||||
selectMessagesExpired: `SELECT mid FROM messages WHERE expires <= ? AND published = 1`,
|
||||
updateMessagePublished: `UPDATE messages SET published = 1 WHERE mid = ?`,
|
||||
selectMessageCountPerTopic: `SELECT topic, COUNT(*) FROM messages GROUP BY topic`,
|
||||
selectTopics: `SELECT topic FROM messages GROUP BY topic`,
|
||||
|
||||
updateAttachmentDeleted: `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`,
|
||||
selectAttachmentsExpired: `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`,
|
||||
selectAttachmentsSizeBySender: `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`,
|
||||
selectAttachmentsSizeByUserID: `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`,
|
||||
|
||||
selectStats: `SELECT value FROM stats WHERE key = 'messages'`,
|
||||
updateStats: `UPDATE stats SET value = ? WHERE key = 'messages'`,
|
||||
}
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
const (
|
||||
currentSchemaVersion = 13
|
||||
createSchemaVersionTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
`
|
||||
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
selectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
|
||||
|
||||
// 0 -> 1
|
||||
migrate0To1AlterMessagesTableQuery = `
|
||||
BEGIN;
|
||||
ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0);
|
||||
ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT('');
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
// 1 -> 2
|
||||
migrate1To2AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1);
|
||||
`
|
||||
|
||||
// 2 -> 3
|
||||
migrate2To3AlterMessagesTableQuery = `
|
||||
BEGIN;
|
||||
ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT('');
|
||||
COMMIT;
|
||||
`
|
||||
// 3 -> 4
|
||||
migrate3To4AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 4 -> 5
|
||||
migrate4To5AlterMessagesTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages_new (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_owner TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic);
|
||||
INSERT
|
||||
INTO messages_new (
|
||||
mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
|
||||
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published)
|
||||
SELECT
|
||||
id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
|
||||
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published
|
||||
FROM messages;
|
||||
DROP TABLE messages;
|
||||
ALTER TABLE messages_new RENAME TO messages;
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
// 5 -> 6
|
||||
migrate5To6AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 6 -> 7
|
||||
migrate6To7AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages RENAME COLUMN attachment_owner TO sender;
|
||||
`
|
||||
|
||||
// 7 -> 8
|
||||
migrate7To8AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 8 -> 9
|
||||
migrate8To9AlterMessagesTableQuery = `
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
`
|
||||
|
||||
// 9 -> 10
|
||||
migrate9To10AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
|
||||
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
|
||||
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||
`
|
||||
migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
|
||||
|
||||
// 10 -> 11
|
||||
migrate10To11AlterMessagesTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value INT
|
||||
);
|
||||
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||
`
|
||||
|
||||
// 11 -> 12
|
||||
migrate11To12AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 12 -> 13
|
||||
migrate12To13AlterMessagesTableQuery = `
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
`
|
||||
)
|
||||
|
||||
var (
|
||||
migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
|
||||
0: migrateFrom0,
|
||||
1: migrateFrom1,
|
||||
2: migrateFrom2,
|
||||
3: migrateFrom3,
|
||||
4: migrateFrom4,
|
||||
5: migrateFrom5,
|
||||
6: migrateFrom6,
|
||||
7: migrateFrom7,
|
||||
8: migrateFrom8,
|
||||
9: migrateFrom9,
|
||||
10: migrateFrom10,
|
||||
11: migrateFrom11,
|
||||
12: migrateFrom12,
|
||||
}
|
||||
)
|
||||
|
||||
type sqliteMessageCache struct {
|
||||
*commonMessageCache
|
||||
}
|
||||
|
||||
var _ MessageCache = (*sqliteMessageCache)(nil)
|
||||
|
||||
// newSqliteCache creates a SQLite file-backed cache
|
||||
func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*sqliteMessageCache, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupMessagesDB(db, startupQueries, cacheDuration); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var queue *util.BatchingQueue[*message]
|
||||
if batchSize > 0 || batchTimeout > 0 {
|
||||
queue = util.NewBatchingQueue[*message](batchSize, batchTimeout)
|
||||
}
|
||||
cache := &sqliteMessageCache{
|
||||
commonMessageCache: &commonMessageCache{
|
||||
db: db,
|
||||
queue: queue,
|
||||
queries: sqliteMessageCacheQueries,
|
||||
nop: nop,
|
||||
},
|
||||
}
|
||||
go cache.processMessageBatches()
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
// newMemCache creates an in-memory cache
|
||||
func newMemCache() (*sqliteMessageCache, error) {
|
||||
return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, false)
|
||||
}
|
||||
|
||||
// newNopCache creates an in-memory cache that discards all messages;
|
||||
// it is always empty and can be used if caching is entirely disabled
|
||||
func newNopCache() (*sqliteMessageCache, error) {
|
||||
return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, true)
|
||||
}
|
||||
|
||||
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
|
||||
// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory
|
||||
// sql database, so if the stdlib's sql engine happens to open another connection and
|
||||
// you've only specified ":memory:", that connection will see a brand new database.
|
||||
// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared").
|
||||
// Every connection to this string will point to the same in-memory database."
|
||||
func createMemoryFilename() string {
|
||||
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
|
||||
}
|
||||
|
||||
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asyncronously.
|
||||
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
|
||||
func (c *sqliteMessageCache) AddMessage(m *message) error {
|
||||
if c.nop {
|
||||
return nil
|
||||
}
|
||||
return c.commonMessageCache.AddMessage(m)
|
||||
}
|
||||
|
||||
func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
|
||||
// Run startup queries
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If 'messages' table does not exist, this must be a new database
|
||||
rowsMC, err := db.Query(selectMessagesCountQuery)
|
||||
if err != nil {
|
||||
return setupNewCacheDB(db)
|
||||
}
|
||||
rowsMC.Close()
|
||||
|
||||
// If 'messages' table exists, check 'schemaVersion' table
|
||||
schemaVersion := 0
|
||||
rowsSV, err := db.Query(selectSchemaVersionQuery)
|
||||
if err == nil {
|
||||
defer rowsSV.Close()
|
||||
if !rowsSV.Next() {
|
||||
return errors.New("cannot determine schema version: cache file may be corrupt")
|
||||
}
|
||||
if err := rowsSV.Scan(&schemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
rowsSV.Close()
|
||||
}
|
||||
|
||||
// Do migrations
|
||||
if schemaVersion == currentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > currentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
|
||||
}
|
||||
for i := schemaVersion; i < currentSchemaVersion; i++ {
|
||||
fn, ok := migrations[i]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||
} else if err := fn(db, cacheDuration); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewCacheDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(createMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(createSchemaVersionTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom0(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1")
|
||||
if _, err := db.Exec(migrate0To1AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(createSchemaVersionTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(insertSchemaVersion, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom1(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2")
|
||||
if _, err := db.Exec(migrate1To2AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom2(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3")
|
||||
if _, err := db.Exec(migrate2To3AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom3(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4")
|
||||
if _, err := db.Exec(migrate3To4AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom4(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5")
|
||||
if _, err := db.Exec(migrate4To5AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 5); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom5(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6")
|
||||
if _, err := db.Exec(migrate5To6AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 6); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom6(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7")
|
||||
if _, err := db.Exec(migrate6To7AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 7); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom7(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8")
|
||||
if _, err := db.Exec(migrate7To8AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 8); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom8(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9")
|
||||
if _, err := db.Exec(migrate8To9AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(updateSchemaVersion, 9); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate9To10AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(migrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 10); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom10(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 11); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom11(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate11To12AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 12); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom12(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate12To13AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 13); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
@@ -21,7 +21,7 @@ func TestMemCache_Messages(t *testing.T) {
|
||||
testCacheMessages(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessages(t *testing.T, c MessageCache) {
|
||||
func testCacheMessages(t *testing.T, c *messageCache) {
|
||||
m1 := newDefaultMessage("mytopic", "my message")
|
||||
m1.Time = 1
|
||||
|
||||
@@ -100,7 +100,7 @@ func TestMemCache_MessagesLock(t *testing.T) {
|
||||
testCacheMessagesLock(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessagesLock(t *testing.T, c MessageCache) {
|
||||
func testCacheMessagesLock(t *testing.T, c *messageCache) {
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5000; i++ {
|
||||
wg.Add(1)
|
||||
@@ -120,7 +120,7 @@ func TestMemCache_MessagesScheduled(t *testing.T) {
|
||||
testCacheMessagesScheduled(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessagesScheduled(t *testing.T, c MessageCache) {
|
||||
func testCacheMessagesScheduled(t *testing.T, c *messageCache) {
|
||||
m1 := newDefaultMessage("mytopic", "message 1")
|
||||
m2 := newDefaultMessage("mytopic", "message 2")
|
||||
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||
@@ -154,7 +154,7 @@ func TestMemCache_Topics(t *testing.T) {
|
||||
testCacheTopics(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheTopics(t *testing.T, c MessageCache) {
|
||||
func testCacheTopics(t *testing.T, c *messageCache) {
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message")))
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1")))
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2")))
|
||||
@@ -177,7 +177,7 @@ func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
testCacheMessagesTagsPrioAndTitle(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c MessageCache) {
|
||||
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c *messageCache) {
|
||||
m := newDefaultMessage("mytopic", "some message")
|
||||
m.Tags = []string{"tag1", "tag2"}
|
||||
m.Priority = 5
|
||||
@@ -198,7 +198,7 @@ func TestMemCache_MessagesSinceID(t *testing.T) {
|
||||
testCacheMessagesSinceID(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessagesSinceID(t *testing.T, c MessageCache) {
|
||||
func testCacheMessagesSinceID(t *testing.T, c *messageCache) {
|
||||
m1 := newDefaultMessage("mytopic", "message 1")
|
||||
m1.Time = 100
|
||||
m2 := newDefaultMessage("mytopic", "message 2")
|
||||
@@ -268,7 +268,7 @@ func TestMemCache_Prune(t *testing.T) {
|
||||
testCachePrune(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCachePrune(t *testing.T, c MessageCache) {
|
||||
func testCachePrune(t *testing.T, c *messageCache) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
m1 := newDefaultMessage("mytopic", "my message")
|
||||
@@ -315,7 +315,7 @@ func TestMemCache_Attachments(t *testing.T) {
|
||||
testCacheAttachments(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheAttachments(t *testing.T, c MessageCache) {
|
||||
func testCacheAttachments(t *testing.T, c *messageCache) {
|
||||
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
|
||||
m := newDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
@@ -397,7 +397,7 @@ func TestMemCache_Attachments_Expired(t *testing.T) {
|
||||
testCacheAttachmentsExpired(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheAttachmentsExpired(t *testing.T, c MessageCache) {
|
||||
func testCacheAttachmentsExpired(t *testing.T, c *messageCache) {
|
||||
m := newDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
m.Expires = time.Now().Add(time.Hour).Unix()
|
||||
@@ -473,7 +473,7 @@ func TestSqliteCache_Migration_From0(t *testing.T) {
|
||||
|
||||
// Create cache to trigger migration
|
||||
c := newSqliteTestCacheFromFile(t, filename, "")
|
||||
checkSchemaVersion(t, c.DB())
|
||||
checkSchemaVersion(t, c.db)
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
@@ -519,7 +519,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
|
||||
|
||||
// Create cache to trigger migration
|
||||
c := newSqliteTestCacheFromFile(t, filename, "")
|
||||
checkSchemaVersion(t, c.DB())
|
||||
checkSchemaVersion(t, c.db)
|
||||
|
||||
// Add delayed message
|
||||
delayedMessage := newDefaultMessage("mytopic", "some delayed message")
|
||||
@@ -537,7 +537,7 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
|
||||
require.Equal(t, 11, len(messages))
|
||||
|
||||
// Check that index "idx_topic" exists
|
||||
rows, err := c.DB().Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
|
||||
rows, err := c.db.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var indexName string
|
||||
@@ -623,7 +623,7 @@ func TestSqliteCache_Migration_From9(t *testing.T) {
|
||||
cacheDuration := 17 * time.Hour
|
||||
c, err := newSqliteCache(filename, "", cacheDuration, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
checkSchemaVersion(t, c.DB())
|
||||
checkSchemaVersion(t, c.db)
|
||||
|
||||
// Check version
|
||||
rows, err := db.Query(`SELECT version FROM main.schemaVersion WHERE id = 1`)
|
||||
@@ -681,7 +681,7 @@ func TestMemCache_Sender(t *testing.T) {
|
||||
testSender(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testSender(t *testing.T, c MessageCache) {
|
||||
func testSender(t *testing.T, c *messageCache) {
|
||||
m1 := newDefaultMessage("mytopic", "mymessage")
|
||||
m1.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
require.Nil(t, c.AddMessage(m1))
|
||||
@@ -720,7 +720,7 @@ func TestMemCache_NopCache(t *testing.T) {
|
||||
require.Empty(t, topics)
|
||||
}
|
||||
|
||||
func newSqliteTestCache(t *testing.T) MessageCache {
|
||||
func newSqliteTestCache(t *testing.T) *messageCache {
|
||||
c, err := newSqliteCache(newSqliteTestCacheFile(t), "", time.Hour, 0, 0, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -732,13 +732,13 @@ func newSqliteTestCacheFile(t *testing.T) string {
|
||||
return filepath.Join(t.TempDir(), "cache.db")
|
||||
}
|
||||
|
||||
func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) MessageCache {
|
||||
func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) *messageCache {
|
||||
c, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
return c
|
||||
}
|
||||
|
||||
func newMemTestCache(t *testing.T) MessageCache {
|
||||
func newMemTestCache(t *testing.T) *messageCache {
|
||||
c, err := newMemCache()
|
||||
require.Nil(t, err)
|
||||
return c
|
||||
|
||||
@@ -56,8 +56,8 @@ type Server struct {
|
||||
messages int64 // Total number of messages (persisted if messageCache enabled)
|
||||
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
|
||||
userManager *user.Manager // Might be nil!
|
||||
messageCache MessageCache // Database that stores the messages
|
||||
webPush WebPushStore // Database that stores web push subscriptions
|
||||
messageCache *messageCache // Database that stores the messages
|
||||
webPush *webPushStore // Database that stores web push subscriptions
|
||||
fileCache *fileCache // File system based cache that stores attachments
|
||||
stripe stripeAPI // Stripe API, can be replaced with a mock
|
||||
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
|
||||
@@ -173,7 +173,7 @@ func New(conf *Config) (*Server, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var webPush WebPushStore
|
||||
var webPush *webPushStore
|
||||
if conf.WebPushPublicKey != "" {
|
||||
webPush, err = newWebPushStore(conf.WebPushFile, conf.WebPushStartupQueries)
|
||||
if err != nil {
|
||||
@@ -245,11 +245,9 @@ func New(conf *Config) (*Server, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func createMessageCache(conf *Config) (MessageCache, error) {
|
||||
func createMessageCache(conf *Config) (*messageCache, error) {
|
||||
if conf.CacheDuration == 0 {
|
||||
return newNopCache()
|
||||
} else if isPostgres(conf.CacheFile) {
|
||||
return newPgCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheBatchSize, conf.CacheBatchTimeout)
|
||||
} else if conf.CacheFile != "" {
|
||||
return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
|
||||
}
|
||||
|
||||
@@ -24,15 +24,17 @@ func (s *Server) handleUsersGet(w http.ResponseWriter, r *http.Request, v *visit
|
||||
userGrants := make([]*apiUserGrantResponse, len(grants[u.ID]))
|
||||
for i, g := range grants[u.ID] {
|
||||
userGrants[i] = &apiUserGrantResponse{
|
||||
Topic: g.TopicPattern,
|
||||
Permission: g.Permission.String(),
|
||||
Topic: g.TopicPattern,
|
||||
Permission: g.Permission.String(),
|
||||
Provisioned: g.Provisioned,
|
||||
}
|
||||
}
|
||||
usersResponse[i] = &apiUserResponse{
|
||||
Username: u.Name,
|
||||
Role: string(u.Role),
|
||||
Tier: tier,
|
||||
Grants: userGrants,
|
||||
Username: u.Name,
|
||||
Role: string(u.Role),
|
||||
Tier: tier,
|
||||
Grants: userGrants,
|
||||
Provisioned: u.Provisioned,
|
||||
}
|
||||
}
|
||||
return s.writeJSON(w, usersResponse)
|
||||
|
||||
@@ -381,7 +381,7 @@ func TestServer_PublishAt(t *testing.T) {
|
||||
|
||||
// Update message time to the past
|
||||
fakeTime := time.Now().Add(-10 * time.Second).Unix()
|
||||
_, err := s.messageCache.DB().Exec(`UPDATE messages SET time=?`, fakeTime)
|
||||
_, err := s.messageCache.db.Exec(`UPDATE messages SET time=?`, fakeTime)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Trigger delayed message sending
|
||||
@@ -417,7 +417,7 @@ func TestServer_PublishAt_FromUser(t *testing.T) {
|
||||
|
||||
// Update message time to the past
|
||||
fakeTime := time.Now().Add(-10 * time.Second).Unix()
|
||||
_, err := s.messageCache.DB().Exec(`UPDATE messages SET time=?`, fakeTime)
|
||||
_, err := s.messageCache.db.Exec(`UPDATE messages SET time=?`, fakeTime)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Trigger delayed message sending
|
||||
@@ -2336,7 +2336,7 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
messages = append(messages, newDefaultMessage(topicID, "some message"))
|
||||
}
|
||||
require.Nil(t, s.messageCache.AddMessages(messages))
|
||||
require.Nil(t, s.messageCache.addMessages(messages))
|
||||
log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
|
||||
|
||||
// Update stats
|
||||
|
||||
@@ -238,7 +238,7 @@ func TestServer_WebPush_Expiry(t *testing.T) {
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
_, err := s.webPush.DB().Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-55*24*time.Hour).Unix())
|
||||
_, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-55*24*time.Hour).Unix())
|
||||
require.Nil(t, err)
|
||||
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
@@ -248,7 +248,7 @@ func TestServer_WebPush_Expiry(t *testing.T) {
|
||||
return received.Load()
|
||||
})
|
||||
|
||||
_, err = s.webPush.DB().Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-60*24*time.Hour).Unix())
|
||||
_, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-60*24*time.Hour).Unix())
|
||||
require.Nil(t, err)
|
||||
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
|
||||
@@ -308,15 +308,17 @@ type apiUserAddOrUpdateRequest struct {
|
||||
}
|
||||
|
||||
type apiUserResponse struct {
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role"`
|
||||
Tier string `json:"tier,omitempty"`
|
||||
Grants []*apiUserGrantResponse `json:"grants,omitempty"`
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role"`
|
||||
Tier string `json:"tier,omitempty"`
|
||||
Grants []*apiUserGrantResponse `json:"grants,omitempty"`
|
||||
Provisioned bool `json:"provisioned,omitempty"`
|
||||
}
|
||||
|
||||
type apiUserGrantResponse struct {
|
||||
Topic string `json:"topic"` // This may be a pattern
|
||||
Permission string `json:"permission"`
|
||||
Topic string `json:"topic"` // This may be a pattern
|
||||
Permission string `json:"permission"`
|
||||
Provisioned bool `json:"provisioned,omitempty"`
|
||||
}
|
||||
|
||||
type apiUserDeleteRequest struct {
|
||||
|
||||
@@ -53,7 +53,7 @@ const (
|
||||
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
|
||||
type visitor struct {
|
||||
config *Config
|
||||
messageCache MessageCache
|
||||
messageCache *messageCache
|
||||
userManager *user.Manager // May be nil
|
||||
ip netip.Addr // Visitor IP address
|
||||
user *user.User // Only set if authenticated user, otherwise nil
|
||||
@@ -114,7 +114,7 @@ const (
|
||||
visitorLimitBasisTier = visitorLimitBasis("tier")
|
||||
)
|
||||
|
||||
func newVisitor(conf *Config, messageCache MessageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
||||
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
||||
var messages, emails, calls int64
|
||||
if user != nil {
|
||||
messages = user.Stats.Messages
|
||||
|
||||
@@ -3,11 +3,11 @@ package server
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/util"
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -22,49 +22,126 @@ var (
|
||||
errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
|
||||
)
|
||||
|
||||
// WebPushStore is an interface for storing web push subscriptions
|
||||
type WebPushStore interface {
|
||||
UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error
|
||||
SubscriptionsForTopic(topic string) ([]*webPushSubscription, error)
|
||||
SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error)
|
||||
MarkExpiryWarningSent(subscriptions []*webPushSubscription) error
|
||||
RemoveSubscriptionsByEndpoint(endpoint string) error
|
||||
RemoveSubscriptionsByUserID(userID string) error
|
||||
RemoveExpiredSubscriptions(expireAfter time.Duration) error
|
||||
DB() *sql.DB
|
||||
Close() error
|
||||
}
|
||||
const (
|
||||
createWebPushSubscriptionsTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS subscription (
|
||||
id TEXT PRIMARY KEY,
|
||||
endpoint TEXT NOT NULL,
|
||||
key_auth TEXT NOT NULL,
|
||||
key_p256dh TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
subscriber_ip TEXT NOT NULL,
|
||||
updated_at INT NOT NULL,
|
||||
warned_at INT NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
|
||||
CREATE TABLE IF NOT EXISTS subscription_topic (
|
||||
subscription_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
PRIMARY KEY (subscription_id, topic),
|
||||
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
COMMIT;
|
||||
`
|
||||
builtinStartupQueries = `
|
||||
PRAGMA foreign_keys = ON;
|
||||
`
|
||||
|
||||
// webPushQueries holds all the SQL queries used by webPushStore
|
||||
type webPushQueries struct {
|
||||
selectSubscriptionIDByEndpoint string
|
||||
selectSubscriptionCountBySubscriberIP string
|
||||
selectSubscriptionsForTopic string
|
||||
selectSubscriptionsExpiringSoon string
|
||||
insertSubscription string
|
||||
updateSubscriptionWarningSent string
|
||||
deleteSubscriptionByEndpoint string
|
||||
deleteSubscriptionByUserID string
|
||||
deleteSubscriptionByAge string
|
||||
insertSubscriptionTopic string
|
||||
deleteSubscriptionTopicAll string
|
||||
deleteSubscriptionTopicWithoutSub string
|
||||
}
|
||||
selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
|
||||
selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
|
||||
selectWebPushSubscriptionsForTopicQuery = `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription_topic st
|
||||
JOIN subscription s ON s.id = st.subscription_id
|
||||
WHERE st.topic = ?
|
||||
ORDER BY endpoint
|
||||
`
|
||||
selectWebPushSubscriptionsExpiringSoonQuery = `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription
|
||||
WHERE warned_at = 0 AND updated_at <= ?
|
||||
`
|
||||
insertWebPushSubscriptionQuery = `
|
||||
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (endpoint)
|
||||
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
||||
`
|
||||
updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
|
||||
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
|
||||
deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
|
||||
deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
|
||||
|
||||
insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
|
||||
deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
|
||||
deleteWebPushSubscriptionTopicWithoutSubscription = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
const (
|
||||
currentWebPushSchemaVersion = 1
|
||||
insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
type webPushStore struct {
|
||||
db *sql.DB
|
||||
queries *webPushQueries
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// newWebPushStore creates a new webPushStore based on the connection string
|
||||
func newWebPushStore(filename, startupQueries string) (WebPushStore, error) {
|
||||
if strings.HasPrefix(filename, "postgres:") {
|
||||
return newPgWebPushStore(strings.TrimPrefix(filename, "postgres:"), startupQueries)
|
||||
func newWebPushStore(filename, startupQueries string) (*webPushStore, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newSqliteWebPushStore(filename, startupQueries)
|
||||
if err := setupWebPushDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runWebPushStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &webPushStore{
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID
|
||||
func setupWebPushDB(db *sql.DB) error {
|
||||
// If 'schemaVersion' table does not exist, this must be a new database
|
||||
rows, err := db.Query(selectWebPushSchemaVersionQuery)
|
||||
if err != nil {
|
||||
return setupNewWebPushDB(db)
|
||||
}
|
||||
return rows.Close()
|
||||
}
|
||||
|
||||
func setupNewWebPushDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runWebPushStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(builtinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
|
||||
// existing entries for a given endpoint.
|
||||
func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
@@ -72,7 +149,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Read number of subscriptions for subscriber IP address
|
||||
rowsCount, err := tx.Query(c.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String())
|
||||
rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -88,7 +165,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
|
||||
return err
|
||||
}
|
||||
// Read existing subscription ID for endpoint (or create new ID)
|
||||
rows, err := tx.Query(c.queries.selectSubscriptionIDByEndpoint, endpoint)
|
||||
rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -109,15 +186,15 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
|
||||
}
|
||||
// Insert or update subscription
|
||||
updatedAt, warnedAt := time.Now().Unix(), 0
|
||||
if _, err = tx.Exec(c.queries.insertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
||||
if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
// Replace all subscription topics
|
||||
if _, err := tx.Exec(c.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
|
||||
if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, topic := range topics {
|
||||
if _, err = tx.Exec(c.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
|
||||
if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -126,7 +203,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
|
||||
|
||||
// SubscriptionsForTopic returns all subscriptions for the given topic
|
||||
func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
|
||||
rows, err := c.db.Query(c.queries.selectSubscriptionsForTopic, topic)
|
||||
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -136,7 +213,7 @@ func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscripti
|
||||
|
||||
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
|
||||
func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
|
||||
rows, err := c.db.Query(c.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
|
||||
rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -152,7 +229,7 @@ func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscriptio
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, subscription := range subscriptions {
|
||||
if _, err := tx.Exec(c.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
|
||||
if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -179,7 +256,7 @@ func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscrip
|
||||
|
||||
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
|
||||
func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
|
||||
_, err := c.db.Exec(c.queries.deleteSubscriptionByEndpoint, endpoint)
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -188,25 +265,20 @@ func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
|
||||
if userID == "" {
|
||||
return errWebPushUserIDCannotBeEmpty
|
||||
}
|
||||
_, err := c.db.Exec(c.queries.deleteSubscriptionByUserID, userID)
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
|
||||
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
|
||||
_, err := c.db.Exec(c.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix())
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = c.db.Exec(c.queries.deleteSubscriptionTopicWithoutSub)
|
||||
_, err = c.db.Exec(deleteWebPushSubscriptionTopicWithoutSubscription)
|
||||
return err
|
||||
}
|
||||
|
||||
// DB returns the underlying database connection (for testing)
|
||||
func (c *webPushStore) DB() *sql.DB {
|
||||
return c.db
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection
|
||||
func (c *webPushStore) Close() error {
|
||||
return c.db.Close()
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
|
||||
// PostgreSQL-specific queries
|
||||
const (
|
||||
pgCreateWebPushSubscriptionsTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS subscription (
|
||||
id TEXT PRIMARY KEY,
|
||||
endpoint TEXT NOT NULL,
|
||||
key_auth TEXT NOT NULL,
|
||||
key_p256dh TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
subscriber_ip TEXT NOT NULL,
|
||||
updated_at INT NOT NULL,
|
||||
warned_at INT NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
|
||||
CREATE TABLE IF NOT EXISTS subscription_topic (
|
||||
subscription_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
PRIMARY KEY (subscription_id, topic),
|
||||
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
// Schema management queries
|
||||
pgCurrentWebPushSchemaVersion = 1
|
||||
pgInsertWebPushSchemaVersion = `INSERT INTO schema_version VALUES (1, $1)`
|
||||
pgSelectWebPushSchemaVersionQuery = `SELECT version FROM schema_version WHERE id = 1`
|
||||
)
|
||||
|
||||
// PostgreSQL-specific webpush queries
|
||||
var pgWebPushQueries = &webPushQueries{
|
||||
selectSubscriptionIDByEndpoint: `SELECT id FROM subscription WHERE endpoint = $1`,
|
||||
selectSubscriptionCountBySubscriberIP: `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = $1`,
|
||||
selectSubscriptionsForTopic: `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription_topic st
|
||||
JOIN subscription s ON s.id = st.subscription_id
|
||||
WHERE st.topic = $1
|
||||
ORDER BY endpoint
|
||||
`,
|
||||
selectSubscriptionsExpiringSoon: `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription
|
||||
WHERE warned_at = 0 AND updated_at <= $1
|
||||
`,
|
||||
insertSubscription: `
|
||||
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
ON CONFLICT (endpoint)
|
||||
DO UPDATE SET key_auth = EXCLUDED.key_auth, key_p256dh = EXCLUDED.key_p256dh, user_id = EXCLUDED.user_id, subscriber_ip = EXCLUDED.subscriber_ip, updated_at = EXCLUDED.updated_at, warned_at = EXCLUDED.warned_at
|
||||
`,
|
||||
updateSubscriptionWarningSent: `UPDATE subscription SET warned_at = $1 WHERE id = $2`,
|
||||
deleteSubscriptionByEndpoint: `DELETE FROM subscription WHERE endpoint = $1`,
|
||||
deleteSubscriptionByUserID: `DELETE FROM subscription WHERE user_id = $1`,
|
||||
deleteSubscriptionByAge: `DELETE FROM subscription WHERE updated_at <= $1`,
|
||||
insertSubscriptionTopic: `INSERT INTO subscription_topic (subscription_id, topic) VALUES ($1, $2)`,
|
||||
deleteSubscriptionTopicAll: `DELETE FROM subscription_topic WHERE subscription_id = $1`,
|
||||
deleteSubscriptionTopicWithoutSub: `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`,
|
||||
}
|
||||
|
||||
func newPgWebPushStore(connStr, startupQueries string) (*webPushStore, error) {
|
||||
db, err := sql.Open("pgx", connStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err)
|
||||
}
|
||||
if err := setupPgWebPushDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runPgWebPushStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &webPushStore{
|
||||
db: db,
|
||||
queries: pgWebPushQueries,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupPgWebPushDB(db *sql.DB) error {
|
||||
// If 'schema_version' table does not exist, this must be a new database
|
||||
rows, err := db.Query(pgSelectWebPushSchemaVersionQuery)
|
||||
if err != nil {
|
||||
return setupNewPgWebPushDB(db)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// If table exists but no rows, also create new
|
||||
if !rows.Next() {
|
||||
return setupNewPgWebPushDB(db)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPgWebPushDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(pgCreateWebPushSubscriptionsTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(pgInsertWebPushSchemaVersion, pgCurrentWebPushSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runPgWebPushStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
// SQLite-specific queries
|
||||
const (
|
||||
sqliteCreateWebPushSubscriptionsTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS subscription (
|
||||
id TEXT PRIMARY KEY,
|
||||
endpoint TEXT NOT NULL,
|
||||
key_auth TEXT NOT NULL,
|
||||
key_p256dh TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
subscriber_ip TEXT NOT NULL,
|
||||
updated_at INT NOT NULL,
|
||||
warned_at INT NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
|
||||
CREATE TABLE IF NOT EXISTS subscription_topic (
|
||||
subscription_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
PRIMARY KEY (subscription_id, topic),
|
||||
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
COMMIT;
|
||||
`
|
||||
sqliteWebPushBuiltinStartupQueries = `
|
||||
PRAGMA foreign_keys = ON;
|
||||
`
|
||||
|
||||
// Schema management queries
|
||||
sqliteCurrentWebPushSchemaVersion = 1
|
||||
sqliteInsertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
sqliteSelectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
// SQLite-specific webpush queries
|
||||
var sqliteWebPushQueries = &webPushQueries{
|
||||
selectSubscriptionIDByEndpoint: `SELECT id FROM subscription WHERE endpoint = ?`,
|
||||
selectSubscriptionCountBySubscriberIP: `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`,
|
||||
selectSubscriptionsForTopic: `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription_topic st
|
||||
JOIN subscription s ON s.id = st.subscription_id
|
||||
WHERE st.topic = ?
|
||||
ORDER BY endpoint
|
||||
`,
|
||||
selectSubscriptionsExpiringSoon: `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription
|
||||
WHERE warned_at = 0 AND updated_at <= ?
|
||||
`,
|
||||
insertSubscription: `
|
||||
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (endpoint)
|
||||
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
||||
`,
|
||||
updateSubscriptionWarningSent: `UPDATE subscription SET warned_at = ? WHERE id = ?`,
|
||||
deleteSubscriptionByEndpoint: `DELETE FROM subscription WHERE endpoint = ?`,
|
||||
deleteSubscriptionByUserID: `DELETE FROM subscription WHERE user_id = ?`,
|
||||
deleteSubscriptionByAge: `DELETE FROM subscription WHERE updated_at <= ?`,
|
||||
insertSubscriptionTopic: `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`,
|
||||
deleteSubscriptionTopicAll: `DELETE FROM subscription_topic WHERE subscription_id = ?`,
|
||||
deleteSubscriptionTopicWithoutSub: `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`,
|
||||
}
|
||||
|
||||
func newSqliteWebPushStore(filename, startupQueries string) (*webPushStore, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSqliteWebPushDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runSqliteWebPushStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &webPushStore{
|
||||
db: db,
|
||||
queries: sqliteWebPushQueries,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupSqliteWebPushDB(db *sql.DB) error {
|
||||
// If 'schemaVersion' table does not exist, this must be a new database
|
||||
rows, err := db.Query(sqliteSelectWebPushSchemaVersionQuery)
|
||||
if err != nil {
|
||||
return setupNewSqliteWebPushDB(db)
|
||||
}
|
||||
return rows.Close()
|
||||
}
|
||||
|
||||
func setupNewSqliteWebPushDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(sqliteCreateWebPushSubscriptionsTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteInsertWebPushSchemaVersion, sqliteCurrentWebPushSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSqliteWebPushStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := db.Exec(sqliteWebPushBuiltinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -134,7 +134,7 @@ func TestWebPushStore_MarkExpiryWarningSent(t *testing.T) {
|
||||
// Mark them as warning sent
|
||||
require.Nil(t, webPush.MarkExpiryWarningSent(subs))
|
||||
|
||||
rows, err := webPush.DB().Query("SELECT endpoint FROM subscription WHERE warned_at > 0")
|
||||
rows, err := webPush.db.Query("SELECT endpoint FROM subscription WHERE warned_at > 0")
|
||||
require.Nil(t, err)
|
||||
defer rows.Close()
|
||||
var endpoint string
|
||||
@@ -156,7 +156,7 @@ func TestWebPushStore_SubscriptionsExpiring(t *testing.T) {
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Fake-mark them as soon-to-expire
|
||||
_, err = webPush.DB().Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint)
|
||||
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Should not be cleaned up yet
|
||||
@@ -180,7 +180,7 @@ func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) {
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Fake-mark them as expired
|
||||
_, err = webPush.DB().Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint)
|
||||
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Run expiration
|
||||
@@ -192,7 +192,7 @@ func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) {
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func newTestWebPushStore(t *testing.T) WebPushStore {
|
||||
func newTestWebPushStore(t *testing.T) *webPushStore {
|
||||
webPush, err := newWebPushStore(filepath.Join(t.TempDir(), "webpush.db"), "")
|
||||
require.Nil(t, err)
|
||||
return webPush
|
||||
|
||||
520
user/manager.go
520
user/manager.go
@@ -6,17 +6,17 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mattn/go-sqlite3"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/payments"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -43,10 +43,100 @@ const (
|
||||
var (
|
||||
errNoTokenProvided = errors.New("no token provided")
|
||||
errTopicOwnedByOthers = errors.New("topic owned by others")
|
||||
errNoRows = errors.New("no rows found")
|
||||
)
|
||||
|
||||
// Manager-related queries
|
||||
const (
|
||||
createTablesQueries = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit INT NOT NULL,
|
||||
messages_expiry_duration INT NOT NULL,
|
||||
emails_limit INT NOT NULL,
|
||||
calls_limit INT NOT NULL,
|
||||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stats_calls INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
builtinStartupQueries = `
|
||||
PRAGMA foreign_keys = ON;
|
||||
`
|
||||
|
||||
selectUserByIDQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
@@ -236,6 +326,229 @@ const (
|
||||
`
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
const (
|
||||
currentSchemaVersion = 6
|
||||
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
|
||||
// 1 -> 2 (complex migration!)
|
||||
migrate1To2CreateTablesQueries = `
|
||||
ALTER TABLE user RENAME TO user_old;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit INT NOT NULL,
|
||||
messages_expiry_duration INT NOT NULL,
|
||||
emails_limit INT NOT NULL,
|
||||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
|
||||
migrate1To2InsertUserNoTx = `
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
|
||||
`
|
||||
migrate1To2InsertFromOldTablesAndDropNoTx = `
|
||||
INSERT INTO user_access (user_id, topic, read, write)
|
||||
SELECT u.id, a.topic, a.read, a.write
|
||||
FROM user u
|
||||
JOIN access a ON u.user = a.user;
|
||||
|
||||
DROP TABLE access;
|
||||
DROP TABLE user_old;
|
||||
`
|
||||
|
||||
// 2 -> 3
|
||||
migrate2To3UpdateQueries = `
|
||||
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
|
||||
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
|
||||
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
|
||||
DROP INDEX IF EXISTS idx_tier_price_id;
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||
`
|
||||
|
||||
// 3 -> 4
|
||||
migrate3To4UpdateQueries = `
|
||||
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
|
||||
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
`
|
||||
|
||||
// 4 -> 5
|
||||
migrate4To5UpdateQueries = `
|
||||
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
|
||||
`
|
||||
|
||||
// 5 -> 6
|
||||
migrate5To6UpdateQueries = `
|
||||
PRAGMA foreign_keys=off;
|
||||
|
||||
-- Alter user table: Add provisioned column
|
||||
ALTER TABLE user RENAME TO user_old;
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stats_calls INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
INSERT INTO user
|
||||
SELECT
|
||||
id,
|
||||
tier_id,
|
||||
user,
|
||||
pass,
|
||||
role,
|
||||
prefs,
|
||||
sync_topic,
|
||||
0, -- provisioned
|
||||
stats_messages,
|
||||
stats_emails,
|
||||
stats_calls,
|
||||
stripe_customer_id,
|
||||
stripe_subscription_id,
|
||||
stripe_subscription_status,
|
||||
stripe_subscription_interval,
|
||||
stripe_subscription_paid_until,
|
||||
stripe_subscription_cancel_at,
|
||||
created,
|
||||
deleted
|
||||
FROM user_old;
|
||||
DROP TABLE user_old;
|
||||
|
||||
-- Alter user_access table: Add provisioned column
|
||||
ALTER TABLE user_access RENAME TO user_access_old;
|
||||
CREATE TABLE user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
provisioned INTEGER NOT NULL,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
|
||||
DROP TABLE user_access_old;
|
||||
|
||||
-- Alter user_token table: Add provisioned column
|
||||
ALTER TABLE user_token RENAME TO user_token_old;
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
INSERT INTO user_token SELECT *, 0 FROM user_token_old;
|
||||
DROP TABLE user_token_old;
|
||||
|
||||
-- Recreate indices
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
|
||||
-- Re-enable foreign keys
|
||||
PRAGMA foreign_keys=on;
|
||||
`
|
||||
)
|
||||
|
||||
var (
|
||||
migrations = map[int]func(db *sql.DB) error{
|
||||
1: migrateFrom1,
|
||||
2: migrateFrom2,
|
||||
3: migrateFrom3,
|
||||
4: migrateFrom4,
|
||||
5: migrateFrom5,
|
||||
}
|
||||
)
|
||||
|
||||
// Manager is an implementation of Manager. It stores users and access control list
|
||||
// in a SQLite database.
|
||||
type Manager struct {
|
||||
@@ -270,20 +583,28 @@ func NewManager(config *Config) (*Manager, error) {
|
||||
if config.QueueWriterInterval.Seconds() <= 0 {
|
||||
config.QueueWriterInterval = DefaultUserStatsQueueWriterInterval
|
||||
}
|
||||
|
||||
var manager *Manager
|
||||
var err error
|
||||
|
||||
// Select database backend based on connection string
|
||||
if strings.HasPrefix(config.Filename, "postgres:") {
|
||||
manager, err = newPgManager(config)
|
||||
} else {
|
||||
manager, err = newSqliteManager(config)
|
||||
// Check the parent directory of the database file (makes for friendly error messages)
|
||||
parentDir := filepath.Dir(config.Filename)
|
||||
if !util.FileExists(parentDir) {
|
||||
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
|
||||
}
|
||||
// Open DB and run setup queries
|
||||
db, err := sql.Open("sqlite3", config.Filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := setupDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runStartupQueries(db, config.StartupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
manager := &Manager{
|
||||
db: db,
|
||||
config: config,
|
||||
statsQueue: make(map[string]*Stats),
|
||||
tokenQueue: make(map[string]*TokenUpdate),
|
||||
}
|
||||
if err := manager.maybeProvisionUsersAccessAndTokens(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1608,6 +1929,173 @@ func unescapeUnderscore(s string) string {
|
||||
return strings.ReplaceAll(s, "\\_", "_")
|
||||
}
|
||||
|
||||
func runStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(builtinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupDB(db *sql.DB) error {
|
||||
// If 'schemaVersion' table does not exist, this must be a new database
|
||||
rowsSV, err := db.Query(selectSchemaVersionQuery)
|
||||
if err != nil {
|
||||
return setupNewDB(db)
|
||||
}
|
||||
defer rowsSV.Close()
|
||||
|
||||
// If 'schemaVersion' table exists, read version and potentially upgrade
|
||||
schemaVersion := 0
|
||||
if !rowsSV.Next() {
|
||||
return errors.New("cannot determine schema version: database file may be corrupt")
|
||||
}
|
||||
if err := rowsSV.Scan(&schemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
rowsSV.Close()
|
||||
|
||||
// Do migrations
|
||||
if schemaVersion == currentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > currentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
|
||||
}
|
||||
for i := schemaVersion; i < currentSchemaVersion; i++ {
|
||||
fn, ok := migrations[i]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||
} else if err := fn(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(createTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom1(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Rename user -> user_old, and create new tables
|
||||
if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert users from user_old into new user table, with ID and sync_topic
|
||||
rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
usernames := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var username string
|
||||
if err := rows.Scan(&username); err != nil {
|
||||
return err
|
||||
}
|
||||
usernames = append(usernames, username)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, username := range usernames {
|
||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||
if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
|
||||
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom2(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom3(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom4(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 5); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom5(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 6); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func nullString(s string) sql.NullString {
|
||||
if s == "" {
|
||||
return sql.NullString{}
|
||||
|
||||
@@ -1,175 +0,0 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
"heckel.io/ntfy/v2/log"
|
||||
)
|
||||
|
||||
// PostgreSQL-specific queries
|
||||
const (
|
||||
pgCreateTablesQueries = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit INT NOT NULL,
|
||||
messages_expiry_duration INT NOT NULL,
|
||||
emails_limit INT NOT NULL,
|
||||
calls_limit INT NOT NULL,
|
||||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||
CREATE TABLE IF NOT EXISTS "user" (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
"user" TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT 0,
|
||||
stats_emails INT NOT NULL DEFAULT 0,
|
||||
stats_calls INT NOT NULL DEFAULT 0,
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_user ON "user" ("user");
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_user_stripe_customer_id ON "user" (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_user_stripe_subscription_id ON "user" (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id TEXT,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES "user" (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES "user" (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES "user" (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_user_token ON user_token (token);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number),
|
||||
FOREIGN KEY (user_id) REFERENCES "user" (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO "user" (id, "user", pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', 0, EXTRACT(EPOCH FROM NOW())::INT)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
pgCurrentSchemaVersion = 1
|
||||
pgInsertSchemaVersion = `INSERT INTO schema_version VALUES (1, $1)`
|
||||
pgUpdateSchemaVersion = `UPDATE schema_version SET version = $1 WHERE id = 1`
|
||||
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE id = 1`
|
||||
)
|
||||
|
||||
// newPgManager creates a new PostgreSQL-backed user manager
|
||||
func newPgManager(config *Config) (*Manager, error) {
|
||||
db, err := sql.Open("pgx", config.Filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err)
|
||||
}
|
||||
if err := setupPgDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runPgStartupQueries(db, config.StartupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Manager{
|
||||
config: config,
|
||||
db: db,
|
||||
statsQueue: make(map[string]*Stats),
|
||||
tokenQueue: make(map[string]*TokenUpdate),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func runPgStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupPgDB(db *sql.DB) error {
|
||||
// If 'schema_version' table does not exist, this must be a new database
|
||||
rowsSV, err := db.Query(pgSelectSchemaVersionQuery)
|
||||
if err != nil {
|
||||
return setupNewPgDB(db)
|
||||
}
|
||||
defer rowsSV.Close()
|
||||
|
||||
// If 'schema_version' table exists, read version and potentially upgrade
|
||||
schemaVersion := 0
|
||||
if !rowsSV.Next() {
|
||||
// Table exists but no rows, insert version
|
||||
return setupNewPgDB(db)
|
||||
}
|
||||
if err := rowsSV.Scan(&schemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
rowsSV.Close()
|
||||
|
||||
// Do migrations
|
||||
if schemaVersion == pgCurrentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > pgCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
|
||||
}
|
||||
|
||||
// No migrations needed yet for PG (starting at version 1)
|
||||
log.Tag(tag).Info("PostgreSQL user database schema is up to date (version %d)", schemaVersion)
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPgDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(pgCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,532 +0,0 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mattn/go-sqlite3"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
var (
|
||||
errNoRows = errors.New("no rows found")
|
||||
)
|
||||
|
||||
// SQLite-specific queries
|
||||
const (
|
||||
sqliteCreateTablesQueries = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit INT NOT NULL,
|
||||
messages_expiry_duration INT NOT NULL,
|
||||
emails_limit INT NOT NULL,
|
||||
calls_limit INT NOT NULL,
|
||||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stats_calls INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
sqliteBuiltinStartupQueries = `
|
||||
PRAGMA foreign_keys = ON;
|
||||
`
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
const (
|
||||
sqliteCurrentSchemaVersion = 6
|
||||
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
|
||||
// Migration queries (1->2 through 5->6)
|
||||
// These are SQLite-specific due to PRAGMA, ALTER TABLE syntax, etc.
|
||||
migrate1To2CreateTablesQueries = `
|
||||
ALTER TABLE user RENAME TO user_old;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit INT NOT NULL,
|
||||
messages_expiry_duration INT NOT NULL,
|
||||
emails_limit INT NOT NULL,
|
||||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
|
||||
migrate1To2InsertUserNoTx = `
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
|
||||
`
|
||||
migrate1To2InsertFromOldTablesAndDropNoTx = `
|
||||
INSERT INTO user_access (user_id, topic, read, write)
|
||||
SELECT u.id, a.topic, a.read, a.write
|
||||
FROM user u
|
||||
JOIN access a ON u.user = a.user;
|
||||
|
||||
DROP TABLE access;
|
||||
DROP TABLE user_old;
|
||||
`
|
||||
|
||||
migrate2To3UpdateQueries = `
|
||||
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
|
||||
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
|
||||
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
|
||||
DROP INDEX IF EXISTS idx_tier_price_id;
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||
`
|
||||
|
||||
migrate3To4UpdateQueries = `
|
||||
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
|
||||
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
`
|
||||
|
||||
migrate4To5UpdateQueries = `
|
||||
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
|
||||
`
|
||||
|
||||
migrate5To6UpdateQueries = `
|
||||
PRAGMA foreign_keys=off;
|
||||
|
||||
-- Alter user table: Add provisioned column
|
||||
ALTER TABLE user RENAME TO user_old;
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stats_calls INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
INSERT INTO user
|
||||
SELECT
|
||||
id,
|
||||
tier_id,
|
||||
user,
|
||||
pass,
|
||||
role,
|
||||
prefs,
|
||||
sync_topic,
|
||||
0, -- provisioned
|
||||
stats_messages,
|
||||
stats_emails,
|
||||
stats_calls,
|
||||
stripe_customer_id,
|
||||
stripe_subscription_id,
|
||||
stripe_subscription_status,
|
||||
stripe_subscription_interval,
|
||||
stripe_subscription_paid_until,
|
||||
stripe_subscription_cancel_at,
|
||||
created,
|
||||
deleted
|
||||
FROM user_old;
|
||||
DROP TABLE user_old;
|
||||
|
||||
-- Alter user_access table: Add provisioned column
|
||||
ALTER TABLE user_access RENAME TO user_access_old;
|
||||
CREATE TABLE user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
provisioned INTEGER NOT NULL,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
|
||||
DROP TABLE user_access_old;
|
||||
|
||||
-- Alter user_token table: Add provisioned column
|
||||
ALTER TABLE user_token RENAME TO user_token_old;
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
INSERT INTO user_token SELECT *, 0 FROM user_token_old;
|
||||
DROP TABLE user_token_old;
|
||||
|
||||
-- Recreate indices
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
|
||||
-- Re-enable foreign keys
|
||||
PRAGMA foreign_keys=on;
|
||||
`
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteMigrations = map[int]func(db *sql.DB) error{
|
||||
1: migrateFrom1,
|
||||
2: migrateFrom2,
|
||||
3: migrateFrom3,
|
||||
4: migrateFrom4,
|
||||
5: migrateFrom5,
|
||||
}
|
||||
)
|
||||
|
||||
// newSqliteManager creates a new SQLite-backed Manager
|
||||
func newSqliteManager(config *Config) (*Manager, error) {
|
||||
// Check the parent directory of the database file (makes for friendly error messages)
|
||||
parentDir := filepath.Dir(config.Filename)
|
||||
if !util.FileExists(parentDir) {
|
||||
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
|
||||
}
|
||||
// Open DB and run setup queries
|
||||
db, err := sql.Open("sqlite3", config.Filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSqliteDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runSqliteStartupQueries(db, config.StartupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Manager{
|
||||
db: db,
|
||||
config: config,
|
||||
statsQueue: make(map[string]*Stats),
|
||||
tokenQueue: make(map[string]*TokenUpdate),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func runSqliteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupSqliteDB(db *sql.DB) error {
|
||||
// If 'schemaVersion' table does not exist, this must be a new database
|
||||
rowsSV, err := db.Query(sqliteSelectSchemaVersionQuery)
|
||||
if err != nil {
|
||||
return setupNewSqliteDB(db)
|
||||
}
|
||||
defer rowsSV.Close()
|
||||
|
||||
// If 'schemaVersion' table exists, read version and potentially upgrade
|
||||
schemaVersion := 0
|
||||
if !rowsSV.Next() {
|
||||
return errors.New("cannot determine schema version: database file may be corrupt")
|
||||
}
|
||||
if err := rowsSV.Scan(&schemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
rowsSV.Close()
|
||||
|
||||
// Do migrations
|
||||
if schemaVersion == sqliteCurrentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > sqliteCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
|
||||
}
|
||||
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
|
||||
fn, ok := sqliteMigrations[i]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||
} else if err := fn(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSqliteDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(sqliteCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom1(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Rename user -> user_old, and create new tables
|
||||
if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert users from user_old into new user table, with ID and sync_topic
|
||||
rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
usernames := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var username string
|
||||
if err := rows.Scan(&username); err != nil {
|
||||
return err
|
||||
}
|
||||
usernames = append(usernames, username)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, username := range usernames {
|
||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||
if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
|
||||
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom2(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom3(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom4(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom5(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// isSqliteConstraintUniqueError checks if the error is a SQLite unique constraint error
|
||||
func isSqliteConstraintUniqueError(err error) bool {
|
||||
if sqliteErr, ok := err.(sqlite3.Error); ok && sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique {
|
||||
return true
|
||||
}
|
||||
return errors.Is(err, sqlite3.ErrConstraintUnique)
|
||||
}
|
||||
|
||||
@@ -1578,7 +1578,7 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) {
|
||||
|
||||
var schemaVersion int
|
||||
require.Nil(t, rows.Scan(&schemaVersion))
|
||||
require.Equal(t, sqliteCurrentSchemaVersion, schemaVersion)
|
||||
require.Equal(t, currentSchemaVersion, schemaVersion)
|
||||
require.Nil(t, rows.Close())
|
||||
}
|
||||
|
||||
|
||||
@@ -405,5 +405,48 @@
|
||||
"web_push_subscription_expiring_title": "Notifications will be paused",
|
||||
"web_push_subscription_expiring_body": "Open ntfy to continue receiving notifications",
|
||||
"web_push_unknown_notification_title": "Unknown notification received from server",
|
||||
"web_push_unknown_notification_body": "You may need to update ntfy by opening the web app"
|
||||
"web_push_unknown_notification_body": "You may need to update ntfy by opening the web app",
|
||||
"nav_button_admin": "Admin",
|
||||
"admin_users_title": "Users",
|
||||
"admin_users_description": "Manage users and their access permissions. Admin users cannot be modified via the web interface.",
|
||||
"admin_users_table_username_header": "Username",
|
||||
"admin_users_table_role_header": "Role",
|
||||
"admin_users_table_tier_header": "Tier",
|
||||
"admin_users_table_grants_header": "Access grants",
|
||||
"admin_users_table_actions_header": "Actions",
|
||||
"admin_users_table_grant_tooltip": "Permission: {{permission}}",
|
||||
"admin_users_table_grant_provisioned_tooltip": "Permission: {{permission}} (provisioned, cannot be changed)",
|
||||
"admin_users_table_add_access_tooltip": "Add access grant",
|
||||
"admin_users_table_edit_tooltip": "Edit user",
|
||||
"admin_users_table_delete_tooltip": "Delete user",
|
||||
"admin_users_table_admin_no_actions": "Cannot modify admin users",
|
||||
"admin_users_provisioned_tooltip": "Provisioned user (defined in server config)",
|
||||
"admin_users_provisioned_cannot_edit": "Provisioned users cannot be edited or deleted",
|
||||
"admin_users_role_admin": "Admin",
|
||||
"admin_users_role_user": "User",
|
||||
"admin_users_add_button": "Add user",
|
||||
"admin_users_add_dialog_title": "Add user",
|
||||
"admin_users_add_dialog_username_label": "Username",
|
||||
"admin_users_add_dialog_password_label": "Password",
|
||||
"admin_users_add_dialog_tier_label": "Tier",
|
||||
"admin_users_add_dialog_tier_helper": "Optional. Leave empty for no tier.",
|
||||
"admin_users_edit_dialog_title": "Edit user {{username}}",
|
||||
"admin_users_edit_dialog_password_label": "New password",
|
||||
"admin_users_edit_dialog_password_helper": "Leave empty to keep current password",
|
||||
"admin_users_edit_dialog_tier_label": "Tier",
|
||||
"admin_users_edit_dialog_tier_helper": "Leave empty to keep current tier",
|
||||
"admin_users_delete_dialog_title": "Delete user",
|
||||
"admin_users_delete_dialog_description": "Are you sure you want to delete user {{username}}? This action cannot be undone.",
|
||||
"admin_users_delete_dialog_button": "Delete user",
|
||||
"admin_access_add_dialog_title": "Add access for {{username}}",
|
||||
"admin_access_add_dialog_topic_label": "Topic",
|
||||
"admin_access_add_dialog_topic_helper": "Topic name or pattern (e.g. mytopic or alerts-*)",
|
||||
"admin_access_add_dialog_permission_label": "Permission",
|
||||
"admin_access_permission_read_write": "Read & Write",
|
||||
"admin_access_permission_read_only": "Read only",
|
||||
"admin_access_permission_write_only": "Write only",
|
||||
"admin_access_permission_deny_all": "Deny all",
|
||||
"admin_access_delete_dialog_title": "Remove access",
|
||||
"admin_access_delete_dialog_description": "Are you sure you want to remove access to topic {{topic}} for user {{username}}?",
|
||||
"admin_access_delete_dialog_button": "Remove access"
|
||||
}
|
||||
|
||||
82
web/src/app/AdminApi.js
Normal file
82
web/src/app/AdminApi.js
Normal file
@@ -0,0 +1,82 @@
|
||||
import { fetchOrThrow } from "./errors";
|
||||
import { withBearerAuth } from "./utils";
|
||||
import session from "./Session";
|
||||
|
||||
const usersUrl = (baseUrl) => `${baseUrl}/v1/users`;
|
||||
const usersAccessUrl = (baseUrl) => `${baseUrl}/v1/users/access`;
|
||||
|
||||
class AdminApi {
|
||||
async getUsers() {
|
||||
const url = usersUrl(config.base_url);
|
||||
console.log(`[AdminApi] Fetching users ${url}`);
|
||||
const response = await fetchOrThrow(url, {
|
||||
headers: withBearerAuth({}, session.token()),
|
||||
});
|
||||
return response.json();
|
||||
}
|
||||
|
||||
async addUser(username, password, tier) {
|
||||
const url = usersUrl(config.base_url);
|
||||
const body = { username, password };
|
||||
if (tier) {
|
||||
body.tier = tier;
|
||||
}
|
||||
console.log(`[AdminApi] Adding user ${url}`);
|
||||
await fetchOrThrow(url, {
|
||||
method: "POST",
|
||||
headers: withBearerAuth({}, session.token()),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
}
|
||||
|
||||
async updateUser(username, password, tier) {
|
||||
const url = usersUrl(config.base_url);
|
||||
const body = { username };
|
||||
if (password) {
|
||||
body.password = password;
|
||||
}
|
||||
if (tier) {
|
||||
body.tier = tier;
|
||||
}
|
||||
console.log(`[AdminApi] Updating user ${url}`);
|
||||
await fetchOrThrow(url, {
|
||||
method: "PUT",
|
||||
headers: withBearerAuth({}, session.token()),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
}
|
||||
|
||||
async deleteUser(username) {
|
||||
const url = usersUrl(config.base_url);
|
||||
console.log(`[AdminApi] Deleting user ${url}`);
|
||||
await fetchOrThrow(url, {
|
||||
method: "DELETE",
|
||||
headers: withBearerAuth({}, session.token()),
|
||||
body: JSON.stringify({ username }),
|
||||
});
|
||||
}
|
||||
|
||||
async allowAccess(username, topic, permission) {
|
||||
const url = usersAccessUrl(config.base_url);
|
||||
console.log(`[AdminApi] Allowing access ${url}`);
|
||||
await fetchOrThrow(url, {
|
||||
method: "PUT",
|
||||
headers: withBearerAuth({}, session.token()),
|
||||
body: JSON.stringify({ username, topic, permission }),
|
||||
});
|
||||
}
|
||||
|
||||
async resetAccess(username, topic) {
|
||||
const url = usersAccessUrl(config.base_url);
|
||||
console.log(`[AdminApi] Resetting access ${url}`);
|
||||
await fetchOrThrow(url, {
|
||||
method: "DELETE",
|
||||
headers: withBearerAuth({}, session.token()),
|
||||
body: JSON.stringify({ username, topic }),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const adminApi = new AdminApi();
|
||||
export default adminApi;
|
||||
|
||||
593
web/src/components/Admin.jsx
Normal file
593
web/src/components/Admin.jsx
Normal file
@@ -0,0 +1,593 @@
|
||||
import * as React from "react";
|
||||
import { useContext, useEffect, useState } from "react";
|
||||
import {
|
||||
Alert,
|
||||
CardActions,
|
||||
CardContent,
|
||||
Chip,
|
||||
FormControl,
|
||||
Select,
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableRow,
|
||||
Tooltip,
|
||||
Typography,
|
||||
Container,
|
||||
Card,
|
||||
Button,
|
||||
Dialog,
|
||||
DialogTitle,
|
||||
DialogContent,
|
||||
TextField,
|
||||
IconButton,
|
||||
MenuItem,
|
||||
DialogContentText,
|
||||
useMediaQuery,
|
||||
useTheme,
|
||||
Stack,
|
||||
CircularProgress,
|
||||
Box,
|
||||
} from "@mui/material";
|
||||
import EditIcon from "@mui/icons-material/Edit";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import DeleteOutlineIcon from "@mui/icons-material/DeleteOutline";
|
||||
import AddIcon from "@mui/icons-material/Add";
|
||||
import CloseIcon from "@mui/icons-material/Close";
|
||||
import LockIcon from "@mui/icons-material/Lock";
|
||||
import routes from "./routes";
|
||||
import { AccountContext } from "./App";
|
||||
import DialogFooter from "./DialogFooter";
|
||||
import { Paragraph } from "./styles";
|
||||
import { UnauthorizedError } from "../app/errors";
|
||||
import session from "../app/Session";
|
||||
import adminApi from "../app/AdminApi";
|
||||
import { Role } from "../app/AccountApi";
|
||||
|
||||
const Admin = () => {
|
||||
const { account } = useContext(AccountContext);
|
||||
|
||||
// Redirect non-admins away
|
||||
if (!session.exists() || (account && account.role !== Role.ADMIN)) {
|
||||
window.location.href = routes.app;
|
||||
return null;
|
||||
}
|
||||
|
||||
// Wait for account to load
|
||||
if (!account) {
|
||||
return (
|
||||
<Box sx={{ display: "flex", justifyContent: "center", alignItems: "center", height: "100vh" }}>
|
||||
<CircularProgress />
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Container maxWidth="lg" sx={{ marginTop: 3, marginBottom: 3 }}>
|
||||
<Stack spacing={3}>
|
||||
<Users />
|
||||
</Stack>
|
||||
</Container>
|
||||
);
|
||||
};
|
||||
|
||||
const Users = () => {
|
||||
const { t } = useTranslation();
|
||||
const [users, setUsers] = useState(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState("");
|
||||
const [addDialogKey, setAddDialogKey] = useState(0);
|
||||
const [addDialogOpen, setAddDialogOpen] = useState(false);
|
||||
|
||||
const loadUsers = async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
const data = await adminApi.getUsers();
|
||||
setUsers(data);
|
||||
setError("");
|
||||
} catch (e) {
|
||||
console.log(`[Admin] Error loading users`, e);
|
||||
if (e instanceof UnauthorizedError) {
|
||||
await session.resetAndRedirect(routes.login);
|
||||
} else {
|
||||
setError(e.message);
|
||||
}
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
loadUsers();
|
||||
}, []);
|
||||
|
||||
const handleAddClick = () => {
|
||||
setAddDialogKey((prev) => prev + 1);
|
||||
setAddDialogOpen(true);
|
||||
};
|
||||
|
||||
const handleDialogClose = () => {
|
||||
setAddDialogOpen(false);
|
||||
loadUsers();
|
||||
};
|
||||
|
||||
return (
|
||||
<Card sx={{ padding: 1 }} aria-label={t("admin_users_title")}>
|
||||
<CardContent sx={{ paddingBottom: 1 }}>
|
||||
<Typography variant="h5" sx={{ marginBottom: 2 }}>
|
||||
{t("admin_users_title")}
|
||||
</Typography>
|
||||
<Paragraph>{t("admin_users_description")}</Paragraph>
|
||||
{error && (
|
||||
<Alert severity="error" sx={{ mb: 2 }}>
|
||||
{error}
|
||||
</Alert>
|
||||
)}
|
||||
{loading && (
|
||||
<Box sx={{ display: "flex", justifyContent: "center", p: 3 }}>
|
||||
<CircularProgress />
|
||||
</Box>
|
||||
)}
|
||||
{!loading && users && (
|
||||
<div style={{ width: "100%", overflowX: "auto" }}>
|
||||
<UsersTable users={users} onUserChanged={loadUsers} />
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
<CardActions>
|
||||
<Button onClick={handleAddClick} startIcon={<AddIcon />}>
|
||||
{t("admin_users_add_button")}
|
||||
</Button>
|
||||
</CardActions>
|
||||
<AddUserDialog key={`addUserDialog${addDialogKey}`} open={addDialogOpen} onClose={handleDialogClose} />
|
||||
</Card>
|
||||
);
|
||||
};
|
||||
|
||||
const UsersTable = (props) => {
|
||||
const { t } = useTranslation();
|
||||
const [editDialogKey, setEditDialogKey] = useState(0);
|
||||
const [editDialogOpen, setEditDialogOpen] = useState(false);
|
||||
const [deleteDialogOpen, setDeleteDialogOpen] = useState(false);
|
||||
const [accessDialogKey, setAccessDialogKey] = useState(0);
|
||||
const [accessDialogOpen, setAccessDialogOpen] = useState(false);
|
||||
const [deleteAccessDialogOpen, setDeleteAccessDialogOpen] = useState(false);
|
||||
const [selectedUser, setSelectedUser] = useState(null);
|
||||
const [selectedGrant, setSelectedGrant] = useState(null);
|
||||
|
||||
const { users } = props;
|
||||
|
||||
const handleEditClick = (user) => {
|
||||
setEditDialogKey((prev) => prev + 1);
|
||||
setSelectedUser(user);
|
||||
setEditDialogOpen(true);
|
||||
};
|
||||
|
||||
const handleDeleteClick = (user) => {
|
||||
setSelectedUser(user);
|
||||
setDeleteDialogOpen(true);
|
||||
};
|
||||
|
||||
const handleAddAccessClick = (user) => {
|
||||
setAccessDialogKey((prev) => prev + 1);
|
||||
setSelectedUser(user);
|
||||
setAccessDialogOpen(true);
|
||||
};
|
||||
|
||||
const handleDeleteAccessClick = (user, grant) => {
|
||||
setSelectedUser(user);
|
||||
setSelectedGrant(grant);
|
||||
setDeleteAccessDialogOpen(true);
|
||||
};
|
||||
|
||||
const handleDialogClose = () => {
|
||||
setEditDialogOpen(false);
|
||||
setDeleteDialogOpen(false);
|
||||
setAccessDialogOpen(false);
|
||||
setDeleteAccessDialogOpen(false);
|
||||
setSelectedUser(null);
|
||||
setSelectedGrant(null);
|
||||
props.onUserChanged();
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<Table size="small" aria-label={t("admin_users_title")}>
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableCell sx={{ paddingLeft: 0 }}>{t("admin_users_table_username_header")}</TableCell>
|
||||
<TableCell>{t("admin_users_table_role_header")}</TableCell>
|
||||
<TableCell>{t("admin_users_table_tier_header")}</TableCell>
|
||||
<TableCell>{t("admin_users_table_grants_header")}</TableCell>
|
||||
<TableCell align="right">{t("admin_users_table_actions_header")}</TableCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{users.map((user) => (
|
||||
<TableRow key={user.username} sx={{ "&:last-child td, &:last-child th": { border: 0 } }}>
|
||||
<TableCell component="th" scope="row" sx={{ paddingLeft: 0 }}>
|
||||
<Stack direction="row" spacing={0.5} alignItems="center">
|
||||
<span>{user.username}</span>
|
||||
{user.provisioned && (
|
||||
<Tooltip title={t("admin_users_provisioned_tooltip")}>
|
||||
<LockIcon fontSize="small" color="disabled" />
|
||||
</Tooltip>
|
||||
)}
|
||||
</Stack>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<RoleChip role={user.role} />
|
||||
</TableCell>
|
||||
<TableCell>{user.tier || "-"}</TableCell>
|
||||
<TableCell>
|
||||
{user.grants && user.grants.length > 0 ? (
|
||||
<Stack direction="row" spacing={0.5} flexWrap="wrap" useFlexGap>
|
||||
{user.grants.map((grant, idx) => {
|
||||
const canDelete = user.role !== "admin" && !grant.provisioned;
|
||||
const tooltipText = grant.provisioned
|
||||
? t("admin_users_table_grant_provisioned_tooltip", { permission: grant.permission })
|
||||
: t("admin_users_table_grant_tooltip", { permission: grant.permission });
|
||||
return (
|
||||
<Tooltip key={idx} title={tooltipText}>
|
||||
<Chip
|
||||
label={grant.topic}
|
||||
size="small"
|
||||
variant={grant.provisioned ? "filled" : "outlined"}
|
||||
color={grant.provisioned ? "default" : "default"}
|
||||
icon={grant.provisioned ? <LockIcon fontSize="small" /> : undefined}
|
||||
onDelete={canDelete ? () => handleDeleteAccessClick(user, grant) : undefined}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
})}
|
||||
</Stack>
|
||||
) : (
|
||||
"-"
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell align="right" sx={{ whiteSpace: "nowrap" }}>
|
||||
{user.role !== "admin" && !user.provisioned ? (
|
||||
<>
|
||||
<Tooltip title={t("admin_users_table_add_access_tooltip")}>
|
||||
<IconButton onClick={() => handleAddAccessClick(user)} size="small">
|
||||
<AddIcon />
|
||||
</IconButton>
|
||||
</Tooltip>
|
||||
<Tooltip title={t("admin_users_table_edit_tooltip")}>
|
||||
<IconButton onClick={() => handleEditClick(user)} size="small">
|
||||
<EditIcon />
|
||||
</IconButton>
|
||||
</Tooltip>
|
||||
<Tooltip title={t("admin_users_table_delete_tooltip")}>
|
||||
<IconButton onClick={() => handleDeleteClick(user)} size="small">
|
||||
<DeleteOutlineIcon />
|
||||
</IconButton>
|
||||
</Tooltip>
|
||||
</>
|
||||
) : user.role !== "admin" && user.provisioned ? (
|
||||
<>
|
||||
<Tooltip title={t("admin_users_table_add_access_tooltip")}>
|
||||
<IconButton onClick={() => handleAddAccessClick(user)} size="small">
|
||||
<AddIcon />
|
||||
</IconButton>
|
||||
</Tooltip>
|
||||
<Tooltip title={t("admin_users_provisioned_cannot_edit")}>
|
||||
<span>
|
||||
<IconButton disabled size="small">
|
||||
<EditIcon />
|
||||
</IconButton>
|
||||
<IconButton disabled size="small">
|
||||
<DeleteOutlineIcon />
|
||||
</IconButton>
|
||||
</span>
|
||||
</Tooltip>
|
||||
</>
|
||||
) : (
|
||||
<Tooltip title={t("admin_users_table_admin_no_actions")}>
|
||||
<span>
|
||||
<IconButton disabled size="small">
|
||||
<EditIcon />
|
||||
</IconButton>
|
||||
</span>
|
||||
</Tooltip>
|
||||
)}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
<EditUserDialog key={`editUserDialog${editDialogKey}`} open={editDialogOpen} user={selectedUser} onClose={handleDialogClose} />
|
||||
<DeleteUserDialog open={deleteDialogOpen} user={selectedUser} onClose={handleDialogClose} />
|
||||
<AddAccessDialog key={`addAccessDialog${accessDialogKey}`} open={accessDialogOpen} user={selectedUser} onClose={handleDialogClose} />
|
||||
<DeleteAccessDialog open={deleteAccessDialogOpen} user={selectedUser} grant={selectedGrant} onClose={handleDialogClose} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const RoleChip = ({ role }) => {
|
||||
const { t } = useTranslation();
|
||||
if (role === "admin") {
|
||||
return <Chip label={t("admin_users_role_admin")} size="small" color="primary" />;
|
||||
}
|
||||
return <Chip label={t("admin_users_role_user")} size="small" variant="outlined" />;
|
||||
};
|
||||
|
||||
const AddUserDialog = (props) => {
|
||||
const theme = useTheme();
|
||||
const { t } = useTranslation();
|
||||
const [error, setError] = useState("");
|
||||
const [username, setUsername] = useState("");
|
||||
const [password, setPassword] = useState("");
|
||||
const [tier, setTier] = useState("");
|
||||
const fullScreen = useMediaQuery(theme.breakpoints.down("sm"));
|
||||
|
||||
const handleSubmit = async () => {
|
||||
try {
|
||||
await adminApi.addUser(username, password, tier || undefined);
|
||||
props.onClose();
|
||||
} catch (e) {
|
||||
console.log(`[Admin] Error adding user`, e);
|
||||
if (e instanceof UnauthorizedError) {
|
||||
await session.resetAndRedirect(routes.login);
|
||||
} else {
|
||||
setError(e.message);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Dialog open={props.open} onClose={props.onClose} maxWidth="sm" fullWidth fullScreen={fullScreen}>
|
||||
<DialogTitle>{t("admin_users_add_dialog_title")}</DialogTitle>
|
||||
<DialogContent>
|
||||
<TextField
|
||||
margin="dense"
|
||||
id="username"
|
||||
label={t("admin_users_add_dialog_username_label")}
|
||||
type="text"
|
||||
value={username}
|
||||
onChange={(ev) => setUsername(ev.target.value)}
|
||||
fullWidth
|
||||
variant="standard"
|
||||
autoFocus
|
||||
/>
|
||||
<TextField
|
||||
margin="dense"
|
||||
id="password"
|
||||
label={t("admin_users_add_dialog_password_label")}
|
||||
type="password"
|
||||
value={password}
|
||||
onChange={(ev) => setPassword(ev.target.value)}
|
||||
fullWidth
|
||||
variant="standard"
|
||||
/>
|
||||
<TextField
|
||||
margin="dense"
|
||||
id="tier"
|
||||
label={t("admin_users_add_dialog_tier_label")}
|
||||
type="text"
|
||||
value={tier}
|
||||
onChange={(ev) => setTier(ev.target.value)}
|
||||
fullWidth
|
||||
variant="standard"
|
||||
helperText={t("admin_users_add_dialog_tier_helper")}
|
||||
/>
|
||||
</DialogContent>
|
||||
<DialogFooter status={error}>
|
||||
<Button onClick={props.onClose}>{t("common_cancel")}</Button>
|
||||
<Button onClick={handleSubmit} disabled={!username || !password}>
|
||||
{t("common_add")}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
const EditUserDialog = (props) => {
|
||||
const theme = useTheme();
|
||||
const { t } = useTranslation();
|
||||
const [error, setError] = useState("");
|
||||
const [password, setPassword] = useState("");
|
||||
const [tier, setTier] = useState(props.user?.tier || "");
|
||||
const fullScreen = useMediaQuery(theme.breakpoints.down("sm"));
|
||||
|
||||
const handleSubmit = async () => {
|
||||
try {
|
||||
await adminApi.updateUser(props.user.username, password || undefined, tier || undefined);
|
||||
props.onClose();
|
||||
} catch (e) {
|
||||
console.log(`[Admin] Error updating user`, e);
|
||||
if (e instanceof UnauthorizedError) {
|
||||
await session.resetAndRedirect(routes.login);
|
||||
} else {
|
||||
setError(e.message);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (!props.user) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog open={props.open} onClose={props.onClose} maxWidth="sm" fullWidth fullScreen={fullScreen}>
|
||||
<DialogTitle>{t("admin_users_edit_dialog_title", { username: props.user.username })}</DialogTitle>
|
||||
<DialogContent>
|
||||
<TextField
|
||||
margin="dense"
|
||||
id="password"
|
||||
label={t("admin_users_edit_dialog_password_label")}
|
||||
type="password"
|
||||
value={password}
|
||||
onChange={(ev) => setPassword(ev.target.value)}
|
||||
fullWidth
|
||||
variant="standard"
|
||||
helperText={t("admin_users_edit_dialog_password_helper")}
|
||||
/>
|
||||
<TextField
|
||||
margin="dense"
|
||||
id="tier"
|
||||
label={t("admin_users_edit_dialog_tier_label")}
|
||||
type="text"
|
||||
value={tier}
|
||||
onChange={(ev) => setTier(ev.target.value)}
|
||||
fullWidth
|
||||
variant="standard"
|
||||
helperText={t("admin_users_edit_dialog_tier_helper")}
|
||||
/>
|
||||
</DialogContent>
|
||||
<DialogFooter status={error}>
|
||||
<Button onClick={props.onClose}>{t("common_cancel")}</Button>
|
||||
<Button onClick={handleSubmit} disabled={!password && !tier}>
|
||||
{t("common_save")}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
const DeleteUserDialog = (props) => {
|
||||
const { t } = useTranslation();
|
||||
const [error, setError] = useState("");
|
||||
|
||||
const handleSubmit = async () => {
|
||||
try {
|
||||
await adminApi.deleteUser(props.user.username);
|
||||
props.onClose();
|
||||
} catch (e) {
|
||||
console.log(`[Admin] Error deleting user`, e);
|
||||
if (e instanceof UnauthorizedError) {
|
||||
await session.resetAndRedirect(routes.login);
|
||||
} else {
|
||||
setError(e.message);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (!props.user) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog open={props.open} onClose={props.onClose}>
|
||||
<DialogTitle>{t("admin_users_delete_dialog_title")}</DialogTitle>
|
||||
<DialogContent>
|
||||
<DialogContentText>{t("admin_users_delete_dialog_description", { username: props.user.username })}</DialogContentText>
|
||||
</DialogContent>
|
||||
<DialogFooter status={error}>
|
||||
<Button onClick={props.onClose}>{t("common_cancel")}</Button>
|
||||
<Button onClick={handleSubmit} color="error">
|
||||
{t("admin_users_delete_dialog_button")}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
const AddAccessDialog = (props) => {
|
||||
const theme = useTheme();
|
||||
const { t } = useTranslation();
|
||||
const [error, setError] = useState("");
|
||||
const [topic, setTopic] = useState("");
|
||||
const [permission, setPermission] = useState("read-write");
|
||||
const fullScreen = useMediaQuery(theme.breakpoints.down("sm"));
|
||||
|
||||
const handleSubmit = async () => {
|
||||
try {
|
||||
await adminApi.allowAccess(props.user.username, topic, permission);
|
||||
props.onClose();
|
||||
} catch (e) {
|
||||
console.log(`[Admin] Error adding access`, e);
|
||||
if (e instanceof UnauthorizedError) {
|
||||
await session.resetAndRedirect(routes.login);
|
||||
} else {
|
||||
setError(e.message);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (!props.user) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog open={props.open} onClose={props.onClose} maxWidth="sm" fullWidth fullScreen={fullScreen}>
|
||||
<DialogTitle>{t("admin_access_add_dialog_title", { username: props.user.username })}</DialogTitle>
|
||||
<DialogContent>
|
||||
<TextField
|
||||
margin="dense"
|
||||
id="topic"
|
||||
label={t("admin_access_add_dialog_topic_label")}
|
||||
type="text"
|
||||
value={topic}
|
||||
onChange={(ev) => setTopic(ev.target.value)}
|
||||
fullWidth
|
||||
variant="standard"
|
||||
autoFocus
|
||||
helperText={t("admin_access_add_dialog_topic_helper")}
|
||||
/>
|
||||
<FormControl fullWidth variant="standard" sx={{ mt: 2 }}>
|
||||
<Select
|
||||
value={permission}
|
||||
onChange={(ev) => setPermission(ev.target.value)}
|
||||
label={t("admin_access_add_dialog_permission_label")}
|
||||
>
|
||||
<MenuItem value="read-write">{t("admin_access_permission_read_write")}</MenuItem>
|
||||
<MenuItem value="read-only">{t("admin_access_permission_read_only")}</MenuItem>
|
||||
<MenuItem value="write-only">{t("admin_access_permission_write_only")}</MenuItem>
|
||||
<MenuItem value="deny-all">{t("admin_access_permission_deny_all")}</MenuItem>
|
||||
</Select>
|
||||
</FormControl>
|
||||
</DialogContent>
|
||||
<DialogFooter status={error}>
|
||||
<Button onClick={props.onClose}>{t("common_cancel")}</Button>
|
||||
<Button onClick={handleSubmit} disabled={!topic}>
|
||||
{t("common_add")}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
const DeleteAccessDialog = (props) => {
|
||||
const { t } = useTranslation();
|
||||
const [error, setError] = useState("");
|
||||
|
||||
const handleSubmit = async () => {
|
||||
try {
|
||||
await adminApi.resetAccess(props.user.username, props.grant.topic);
|
||||
props.onClose();
|
||||
} catch (e) {
|
||||
console.log(`[Admin] Error removing access`, e);
|
||||
if (e instanceof UnauthorizedError) {
|
||||
await session.resetAndRedirect(routes.login);
|
||||
} else {
|
||||
setError(e.message);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (!props.user || !props.grant) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog open={props.open} onClose={props.onClose}>
|
||||
<DialogTitle>{t("admin_access_delete_dialog_title")}</DialogTitle>
|
||||
<DialogContent>
|
||||
<DialogContentText>
|
||||
{t("admin_access_delete_dialog_description", { username: props.user.username, topic: props.grant.topic })}
|
||||
</DialogContentText>
|
||||
</DialogContent>
|
||||
<DialogFooter status={error}>
|
||||
<Button onClick={props.onClose}>{t("common_cancel")}</Button>
|
||||
<Button onClick={handleSubmit} color="error">
|
||||
{t("admin_access_delete_dialog_button")}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
export default Admin;
|
||||
|
||||
@@ -20,6 +20,7 @@ import Messaging from "./Messaging";
|
||||
import Login from "./Login";
|
||||
import Signup from "./Signup";
|
||||
import Account from "./Account";
|
||||
import Admin from "./Admin";
|
||||
import initI18n from "../app/i18n"; // Translations!
|
||||
import prefs, { THEME } from "../app/Prefs";
|
||||
import RTLCacheProvider from "./RTLCacheProvider";
|
||||
@@ -80,6 +81,7 @@ const App = () => {
|
||||
<Route element={<Layout />}>
|
||||
<Route path={routes.app} element={<AllSubscriptions />} />
|
||||
<Route path={routes.account} element={<Account />} />
|
||||
<Route path={routes.admin} element={<Admin />} />
|
||||
<Route path={routes.settings} element={<Preferences />} />
|
||||
<Route path={routes.subscription} element={<SingleSubscription />} />
|
||||
<Route path={routes.subscriptionExternal} element={<SingleSubscription />} />
|
||||
|
||||
@@ -25,6 +25,7 @@ import { useContext, useState } from "react";
|
||||
import ChatBubbleOutlineIcon from "@mui/icons-material/ChatBubbleOutline";
|
||||
import Person from "@mui/icons-material/Person";
|
||||
import SettingsIcon from "@mui/icons-material/Settings";
|
||||
import AdminPanelSettingsIcon from "@mui/icons-material/AdminPanelSettings";
|
||||
import AddIcon from "@mui/icons-material/Add";
|
||||
import { useLocation, useNavigate } from "react-router-dom";
|
||||
import { ChatBubble, MoreVert, NotificationsOffOutlined, Send } from "@mui/icons-material";
|
||||
@@ -164,6 +165,14 @@ const NavList = (props) => {
|
||||
<ListItemText primary={t("nav_button_account")} />
|
||||
</ListItemButton>
|
||||
)}
|
||||
{session.exists() && isAdmin && (
|
||||
<ListItemButton onClick={() => navigate(routes.admin)} selected={location.pathname === routes.admin}>
|
||||
<ListItemIcon>
|
||||
<AdminPanelSettingsIcon />
|
||||
</ListItemIcon>
|
||||
<ListItemText primary={t("nav_button_admin")} />
|
||||
</ListItemButton>
|
||||
)}
|
||||
<ListItemButton onClick={() => navigate(routes.settings)} selected={location.pathname === routes.settings}>
|
||||
<ListItemIcon>
|
||||
<SettingsIcon />
|
||||
|
||||
@@ -6,6 +6,7 @@ const routes = {
|
||||
signup: "/signup",
|
||||
app: config.app_root,
|
||||
account: "/account",
|
||||
admin: "/admin",
|
||||
settings: "/settings",
|
||||
subscription: "/:topic",
|
||||
subscriptionExternal: "/:baseUrl/:topic",
|
||||
|
||||
Reference in New Issue
Block a user