Users() optimization to help with startup time

This commit is contained in:
binwiederhier
2026-03-06 14:46:53 -05:00
parent 28c3fd5cbe
commit 0d4b1b00e6
11 changed files with 876 additions and 41 deletions

1
.gitignore vendored
View File

@@ -8,6 +8,7 @@ server/docs/
server/site/
tools/fbsend/fbsend
tools/pgimport/pgimport
tools/loadtest/loadtest
playground/
secrets/
*.iml

5
tools/loadtest/go.mod Normal file
View File

@@ -0,0 +1,5 @@
module loadtest
go 1.25.2
require github.com/gorilla/websocket v1.5.3

2
tools/loadtest/go.sum Normal file
View File

@@ -0,0 +1,2 @@
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=

522
tools/loadtest/main.go Normal file
View File

@@ -0,0 +1,522 @@
// Load test program for ntfy staging server.
// Replicates production traffic patterns derived from access.log analysis.
//
// Traffic profile (from ~5M requests over 20 hours):
// ~71 req/sec average, ~4,300 req/min
// 49.6% poll requests (GET /TOPIC/json?poll=1&since=ID)
// 21.4% publish POST (POST /TOPIC with small body)
// 6.2% subscribe stream (GET /TOPIC/json?since=X, long-lived)
// 4.1% config check (GET /v1/config)
// 2.3% other topic GET (GET /TOPIC)
// 2.2% account check (GET /v1/account)
// 1.9% websocket sub (GET /TOPIC/ws?since=X)
// 1.5% publish PUT (PUT /TOPIC with small body)
// 1.5% raw subscribe (GET /TOPIC/raw?since=X)
// 1.1% json subscribe (GET /TOPIC/json, no since)
// 0.7% SSE subscribe (GET /TOPIC/sse?since=X)
// remaining: static, PATCH, OPTIONS, etc. (omitted)
package main
import (
"context"
"crypto/rand"
"encoding/hex"
"flag"
"fmt"
"io"
"math/big"
mrand "math/rand"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
)
var (
	// Flags, bound in main before flag.Parse.
	baseURL    string        // base URL of the ntfy server under test
	rps        float64       // target requests/sec for short-lived requests (scaled)
	scale      float64       // multiplier applied to rps and all stream counts
	numTopics  int           // number of distinct random topics to publish/subscribe to
	subStreams int           // concurrent JSON streaming subscriptions
	wsStreams  int           // concurrent WebSocket subscriptions
	sseStreams int           // concurrent SSE subscriptions
	rawStreams int           // concurrent raw subscriptions
	duration   time.Duration // total test duration

	// Global counters, updated concurrently from many goroutines.
	totalRequests atomic.Int64
	totalErrors   atomic.Int64
	activeStreams atomic.Int64

	// Error tracking by category; errMu guards errorCounts.
	// NOTE: a recentErrors slice was previously declared here but never
	// written or read anywhere in this file, so it has been removed.
	errMu       sync.Mutex
	errorCounts = make(map[string]int64)
)
// main parses flags, applies the scale factor, then starts (a) long-lived
// streaming subscriptions (json/ws/sse/raw), (b) one rate-limited generator
// goroutine per short-lived request type weighted by the production traffic
// profile documented at the top of this file, and (c) a periodic stats
// reporter. It waits for all of them to finish when the duration elapses or
// Ctrl+C is pressed.
func main() {
	flag.StringVar(&baseURL, "url", "https://staging.ntfy.sh", "Base URL of ntfy server")
	flag.Float64Var(&rps, "rps", 71, "Target requests per second (default: prod average)")
	flag.Float64Var(&scale, "scale", 1.0, "Scale factor for all load (0.5 = half load, 2.0 = double)")
	flag.IntVar(&numTopics, "topics", 500, "Number of unique topics to use")
	flag.IntVar(&subStreams, "sub-streams", 200, "Number of concurrent JSON streaming subscriptions")
	flag.IntVar(&wsStreams, "ws-streams", 50, "Number of concurrent WebSocket subscriptions")
	flag.IntVar(&sseStreams, "sse-streams", 20, "Number of concurrent SSE subscriptions")
	flag.IntVar(&rawStreams, "raw-streams", 30, "Number of concurrent raw subscriptions")
	flag.DurationVar(&duration, "duration", 10*time.Minute, "Test duration")
	flag.Parse()
	// Scale the request rate and all long-lived stream counts together.
	rps *= scale
	subStreams = int(float64(subStreams) * scale)
	wsStreams = int(float64(wsStreams) * scale)
	sseStreams = int(float64(sseStreams) * scale)
	rawStreams = int(float64(rawStreams) * scale)
	topics := generateTopics(numTopics)
	fmt.Printf("ntfy load test\n")
	fmt.Printf(" Target: %s\n", baseURL)
	fmt.Printf(" RPS: %.1f\n", rps)
	fmt.Printf(" Scale: %.1fx\n", scale)
	fmt.Printf(" Topics: %d\n", numTopics)
	fmt.Printf(" Sub streams: %d json, %d ws, %d sse, %d raw\n", subStreams, wsStreams, sseStreams, rawStreams)
	fmt.Printf(" Duration: %s\n", duration)
	fmt.Println()
	// The run ends when the duration elapses or on Ctrl+C, whichever first.
	ctx, cancel := context.WithTimeout(context.Background(), duration)
	defer cancel()
	// Also handle Ctrl+C
	sigCtx, sigCancel := signal.NotifyContext(ctx, os.Interrupt)
	defer sigCancel()
	ctx = sigCtx
	// Client for short-lived requests: hard 10s timeout, large idle pool so
	// connections are reused across the many small requests.
	client := &http.Client{
		Timeout: 10 * time.Second,
		Transport: &http.Transport{
			MaxIdleConns:        1000,
			MaxIdleConnsPerHost: 1000,
			IdleConnTimeout:     90 * time.Second,
		},
	}
	// Long-lived streaming client (no timeout)
	streamClient := &http.Client{
		Timeout: 0,
		Transport: &http.Transport{
			MaxIdleConns:        500,
			MaxIdleConnsPerHost: 500,
			IdleConnTimeout:     0,
		},
	}
	var wg sync.WaitGroup
	// Start long-lived streaming subscriptions
	for i := 0; i < subStreams; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			streamSubscription(ctx, streamClient, topics, "json")
		}()
	}
	for i := 0; i < wsStreams; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			wsSubscription(ctx, topics)
		}()
	}
	for i := 0; i < sseStreams; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			streamSubscription(ctx, streamClient, topics, "sse")
		}()
	}
	for i := 0; i < rawStreams; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			streamSubscription(ctx, streamClient, topics, "raw")
		}()
	}
	// Start request generators based on traffic weights
	// Weights from log analysis (normalized to sum ~100):
	// poll=49.6, publish_post=21.4, config=4.1, other_get=2.3, account=2.2, publish_put=1.5
	// Total short-lived weight ≈ 81.1
	type requestType struct {
		name   string
		weight float64
		fn     func(ctx context.Context, client *http.Client, topics []string)
	}
	types := []requestType{
		{"poll", 49.6, doPoll},
		{"publish_post", 21.4, doPublishPost},
		{"config", 4.1, doConfig},
		{"other_get", 2.3, doOtherGet},
		{"account", 2.2, doAccountCheck},
		{"publish_put", 1.5, doPublishPut},
	}
	totalWeight := 0.0
	for _, t := range types {
		totalWeight += t.weight
	}
	for _, t := range types {
		t := t // loop-variable capture; redundant since Go 1.22 but harmless
		// Each type gets a share of the total RPS proportional to its
		// observed production weight; negligible rates are skipped because
		// runAtRate's ticker interval would be excessively long.
		typeRPS := rps * (t.weight / totalWeight)
		if typeRPS < 0.1 {
			continue
		}
		wg.Add(1)
		go func() {
			defer wg.Done()
			runAtRate(ctx, typeRPS, func() {
				t.fn(ctx, client, topics)
			})
		}()
	}
	// Stats reporter
	wg.Add(1)
	go func() {
		defer wg.Done()
		reportStats(ctx)
	}()
	wg.Wait()
	fmt.Printf("\nDone. Total requests: %d, errors: %d\n", totalRequests.Load(), totalErrors.Load())
}
// trackError records err under the given category: it bumps the global
// error counter and increments the per-category breakdown map, keyed by
// "category: truncated error text".
func trackError(category string, err error) {
	totalErrors.Add(1)
	errMu.Lock()
	defer errMu.Unlock()
	errorCounts[category+": "+truncateErr(err)]++
}
// trackErrorMsg records a preformatted error message under the given
// category, bumping the global error counter and the breakdown map.
func trackErrorMsg(category string, msg string) {
	totalErrors.Add(1)
	errMu.Lock()
	defer errMu.Unlock()
	errorCounts[category+": "+msg]++
}
func truncateErr(err error) string {
s := err.Error()
if len(s) > 120 {
s = s[:120] + "..."
}
return s
}
// generateTopics creates n random topic names of the form
// "loadtest-<16 hex chars>" so each run publishes to fresh topics.
func generateTopics(n int) []string {
	out := make([]string, 0, n)
	for i := 0; i < n; i++ {
		id := make([]byte, 8)
		rand.Read(id) // crypto/rand.Read does not fail in practice
		out = append(out, "loadtest-"+hex.EncodeToString(id))
	}
	return out
}
// pickTopic returns a uniformly random element of topics, drawing the
// index from crypto/rand.
func pickTopic(topics []string) string {
	upper := big.NewInt(int64(len(topics)))
	idx, _ := rand.Int(rand.Reader, upper)
	return topics[int(idx.Int64())]
}
// randomSince returns a random 12-character hex string used as a fake
// "since" message ID in poll/subscribe URLs.
func randomSince() string {
	var raw [6]byte
	rand.Read(raw[:])
	return hex.EncodeToString(raw[:])
}
// randomMessage returns one of a fixed set of realistic-looking
// notification bodies, chosen uniformly at random.
func randomMessage() string {
	candidates := [...]string{
		"Test notification",
		"Server backup completed successfully",
		"Deployment finished",
		"Alert: disk usage above 80%",
		"Build #1234 passed",
		"New order received",
		"Temperature sensor reading: 72F",
		"Cron job completed",
	}
	return candidates[mrand.Intn(len(candidates))]
}
// runAtRate executes fn at approximately the given rate per second
func runAtRate(ctx context.Context, rate float64, fn func()) {
interval := time.Duration(float64(time.Second) / rate)
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
go fn()
}
}
}
// --- Short-lived request types ---
func doPoll(ctx context.Context, client *http.Client, topics []string) {
topic := pickTopic(topics)
url := fmt.Sprintf("%s/%s/json?poll=1&since=%s", baseURL, topic, randomSince())
doGet(ctx, client, url)
}
func doPublishPost(ctx context.Context, client *http.Client, topics []string) {
topic := pickTopic(topics)
url := fmt.Sprintf("%s/%s", baseURL, topic)
req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(randomMessage()))
if err != nil {
trackError("publish_post_req", err)
return
}
// Some messages have titles/priorities like real traffic
if mrand.Float32() < 0.3 {
req.Header.Set("X-Title", "Load Test")
}
if mrand.Float32() < 0.1 {
req.Header.Set("X-Priority", fmt.Sprintf("%d", mrand.Intn(5)+1))
}
resp, err := client.Do(req)
totalRequests.Add(1)
if err != nil {
trackError("publish_post", err)
return
}
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
if resp.StatusCode >= 400 {
trackErrorMsg("publish_post_http", fmt.Sprintf("status %d", resp.StatusCode))
}
}
func doPublishPut(ctx context.Context, client *http.Client, topics []string) {
topic := pickTopic(topics)
url := fmt.Sprintf("%s/%s", baseURL, topic)
req, err := http.NewRequestWithContext(ctx, "PUT", url, strings.NewReader(randomMessage()))
if err != nil {
trackError("publish_put_req", err)
return
}
resp, err := client.Do(req)
totalRequests.Add(1)
if err != nil {
trackError("publish_put", err)
return
}
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
if resp.StatusCode >= 400 {
trackErrorMsg("publish_put_http", fmt.Sprintf("status %d", resp.StatusCode))
}
}
func doConfig(ctx context.Context, client *http.Client, topics []string) {
url := fmt.Sprintf("%s/v1/config", baseURL)
doGet(ctx, client, url)
}
func doAccountCheck(ctx context.Context, client *http.Client, topics []string) {
url := fmt.Sprintf("%s/v1/account", baseURL)
doGet(ctx, client, url)
}
func doOtherGet(ctx context.Context, client *http.Client, topics []string) {
topic := pickTopic(topics)
url := fmt.Sprintf("%s/%s", baseURL, topic)
doGet(ctx, client, url)
}
// doGet performs a single GET, draining and closing the response body so
// the connection can be reused, and records transport errors and HTTP
// statuses >= 400 in the error breakdown.
func doGet(ctx context.Context, client *http.Client, url string) {
	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
	if err != nil {
		trackError("get_req", err)
		return
	}
	resp, err := client.Do(req)
	totalRequests.Add(1)
	if err != nil {
		trackError("get", err)
		return
	}
	defer resp.Body.Close()
	io.Copy(io.Discard, resp.Body)
	if resp.StatusCode >= 400 {
		trackErrorMsg("get_http", fmt.Sprintf("status %d for %s", resp.StatusCode, url))
	}
}
// --- Long-lived streaming subscriptions ---
// streamSubscription opens a long-lived HTTP streaming subscription
// (GET /TOPIC/<format>?since=all) against a random topic and reads from it
// until the connection drops or ctx is cancelled, then reconnects to a new
// random topic after a small randomized delay — forever, until ctx is done.
// format is "json", "sse" or "raw". The activeStreams gauge is incremented
// for the lifetime of each connection attempt and decremented on every exit
// path from that attempt.
func streamSubscription(ctx context.Context, client *http.Client, topics []string, format string) {
	for {
		if ctx.Err() != nil {
			return
		}
		topic := pickTopic(topics)
		url := fmt.Sprintf("%s/%s/%s?since=all", baseURL, topic, format)
		req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
		if err != nil {
			time.Sleep(time.Second)
			continue
		}
		activeStreams.Add(1)
		resp, err := client.Do(req)
		if err != nil {
			activeStreams.Add(-1)
			// Only count as an error while the test is still running;
			// context cancellation produces expected errors at shutdown.
			if ctx.Err() == nil {
				trackError("stream_"+format+"_connect", err)
			}
			time.Sleep(time.Second)
			continue
		}
		if resp.StatusCode >= 400 {
			trackErrorMsg("stream_"+format+"_http", fmt.Sprintf("status %d", resp.StatusCode))
			resp.Body.Close()
			activeStreams.Add(-1)
			time.Sleep(time.Second)
			continue
		}
		// Read from stream until context cancelled or connection drops.
		// The data itself is discarded; only connection liveness matters.
		buf := make([]byte, 4096)
		for {
			_, err := resp.Body.Read(buf)
			if err != nil {
				if ctx.Err() == nil {
					trackError("stream_"+format+"_read", err)
				}
				break
			}
		}
		resp.Body.Close()
		activeStreams.Add(-1)
		// Reconnect with small delay (like real clients do)
		select {
		case <-ctx.Done():
			return
		case <-time.After(time.Duration(mrand.Intn(3000)) * time.Millisecond):
		}
	}
}
// wsSubscription opens a long-lived WebSocket subscription
// (GET /TOPIC/ws?since=all) against a random topic and reads messages until
// the connection drops or ctx is cancelled, then reconnects to a new random
// topic after a small randomized delay — forever, until ctx is done. The
// activeStreams gauge is incremented per connection attempt and decremented
// on every exit path.
func wsSubscription(ctx context.Context, topics []string) {
	// Derive the ws(s):// endpoint from the configured http(s) base URL.
	wsURL := strings.Replace(baseURL, "https://", "wss://", 1)
	wsURL = strings.Replace(wsURL, "http://", "ws://", 1)
	for {
		if ctx.Err() != nil {
			return
		}
		topic := pickTopic(topics)
		url := fmt.Sprintf("%s/%s/ws?since=all", wsURL, topic)
		dialer := websocket.Dialer{
			HandshakeTimeout: 10 * time.Second,
		}
		activeStreams.Add(1)
		conn, _, err := dialer.DialContext(ctx, url, nil)
		if err != nil {
			activeStreams.Add(-1)
			// Only count as an error while the test is still running;
			// cancellation produces expected dial errors at shutdown.
			if ctx.Err() == nil {
				trackError("ws_connect", err)
			}
			time.Sleep(time.Second)
			continue
		}
		// Read messages until context cancelled or error. The reader runs in
		// its own goroutine and signals termination by closing done, because
		// ReadMessage cannot be interrupted by ctx directly.
		done := make(chan struct{})
		go func() {
			defer close(done)
			for {
				// Refresh the deadline each iteration so an idle but healthy
				// connection is dropped only after 5 minutes of silence.
				conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
				_, _, err := conn.ReadMessage()
				if err != nil {
					return
				}
			}
		}()
		select {
		case <-ctx.Done():
			// Closing the connection unblocks the reader goroutine.
			conn.Close()
			activeStreams.Add(-1)
			return
		case <-done:
			conn.Close()
			activeStreams.Add(-1)
		}
		// Reconnect with a small randomized delay, like real clients do.
		select {
		case <-ctx.Done():
			return
		case <-time.After(time.Duration(mrand.Intn(3000)) * time.Millisecond):
		}
	}
}
// reportStats prints a one-line throughput summary every 5 seconds and,
// on every 6th report (roughly every 30s), a per-category error breakdown.
// It returns when ctx is cancelled.
func reportStats(ctx context.Context) {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	var prevRequests, prevErrors int64
	prevTime := time.Now()
	reports := 0
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			now := time.Now()
			reqs := totalRequests.Load()
			errs := totalErrors.Load()
			secs := now.Sub(prevTime).Seconds()
			fmt.Printf("[%s] rps=%.1f err/s=%.1f total=%d errors=%d streams=%d\n",
				now.Format("15:04:05"),
				float64(reqs-prevRequests)/secs,
				float64(errs-prevErrors)/secs,
				reqs,
				errs,
				activeStreams.Load(),
			)
			// Print error breakdown every 30 seconds
			reports++
			if reports%6 == 0 && errs > 0 {
				errMu.Lock()
				fmt.Printf(" Error breakdown:\n")
				for k, v := range errorCounts {
					fmt.Printf(" %s: %d\n", k, v)
				}
				errMu.Unlock()
			}
			prevRequests, prevErrors, prevTime = reqs, errs, now
		}
	}
}

View File

@@ -22,13 +22,22 @@ pgimport \
--auth-file /var/lib/ntfy/user.db \
--web-push-file /var/lib/ntfy/webpush.db
# Using --create-schema to set up PostgreSQL schema automatically
pgimport \
--create-schema \
--database-url "postgres://user:pass@host:5432/ntfy?sslmode=require" \
--cache-file /var/cache/ntfy/cache.db \
--auth-file /var/lib/ntfy/user.db \
--web-push-file /var/lib/ntfy/webpush.db
# Using server.yml (flags override config values)
pgimport --config /etc/ntfy/server.yml
```
## Prerequisites
- PostgreSQL schema must already be set up (run ntfy with `database-url` once)
- PostgreSQL schema must already be set up, either by running ntfy with `database-url` once,
or by passing `--create-schema` to pgimport to create the initial schema automatically
- ntfy must not be running during the import
- All three SQLite files are optional; only the ones specified will be imported

View File

@@ -23,6 +23,159 @@ const (
expectedMessageSchemaVersion = 14
expectedUserSchemaVersion = 6
expectedWebPushSchemaVersion = 1
everyoneID = "u_everyone"
// Initial PostgreSQL schema for message store (from message/cache_postgres_schema.go)
createMessageSchemaQuery = `
CREATE TABLE IF NOT EXISTS message (
id BIGSERIAL PRIMARY KEY,
mid TEXT NOT NULL,
sequence_id TEXT NOT NULL,
time BIGINT NOT NULL,
event TEXT NOT NULL,
expires BIGINT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size BIGINT NOT NULL,
attachment_expires BIGINT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_deleted BOOLEAN NOT NULL DEFAULT FALSE,
sender TEXT NOT NULL,
user_id TEXT NOT NULL,
content_type TEXT NOT NULL,
encoding TEXT NOT NULL,
published BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE INDEX IF NOT EXISTS idx_message_mid ON message (mid);
CREATE INDEX IF NOT EXISTS idx_message_sequence_id ON message (sequence_id);
CREATE INDEX IF NOT EXISTS idx_message_topic_published_time ON message (topic, published, time, id);
CREATE INDEX IF NOT EXISTS idx_message_published_expires ON message (published, expires);
CREATE INDEX IF NOT EXISTS idx_message_sender_attachment_expires ON message (sender, attachment_expires) WHERE user_id = '';
CREATE INDEX IF NOT EXISTS idx_message_user_id_attachment_expires ON message (user_id, attachment_expires);
CREATE TABLE IF NOT EXISTS message_stats (
key TEXT PRIMARY KEY,
value BIGINT
);
INSERT INTO message_stats (key, value) VALUES ('messages', 0);
CREATE TABLE IF NOT EXISTS schema_version (
store TEXT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schema_version (store, version) VALUES ('message', 14);
`
// Initial PostgreSQL schema for user store (from user/manager_postgres_schema.go)
createUserSchemaQuery = `
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit BIGINT NOT NULL,
messages_expiry_duration BIGINT NOT NULL,
emails_limit BIGINT NOT NULL,
calls_limit BIGINT NOT NULL,
reservations_limit BIGINT NOT NULL,
attachment_file_size_limit BIGINT NOT NULL,
attachment_total_size_limit BIGINT NOT NULL,
attachment_expiry_duration BIGINT NOT NULL,
attachment_bandwidth_limit BIGINT NOT NULL,
stripe_monthly_price_id TEXT,
stripe_yearly_price_id TEXT,
UNIQUE(code),
UNIQUE(stripe_monthly_price_id),
UNIQUE(stripe_yearly_price_id)
);
CREATE TABLE IF NOT EXISTS "user" (
id TEXT PRIMARY KEY,
tier_id TEXT REFERENCES tier(id),
user_name TEXT NOT NULL UNIQUE,
pass TEXT NOT NULL,
role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')),
prefs JSONB NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned BOOLEAN NOT NULL,
stats_messages BIGINT NOT NULL DEFAULT 0,
stats_emails BIGINT NOT NULL DEFAULT 0,
stats_calls BIGINT NOT NULL DEFAULT 0,
stripe_customer_id TEXT UNIQUE,
stripe_subscription_id TEXT UNIQUE,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until BIGINT,
stripe_subscription_cancel_at BIGINT,
created BIGINT NOT NULL,
deleted BIGINT
);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
topic TEXT NOT NULL,
read BOOLEAN NOT NULL,
write BOOLEAN NOT NULL,
owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE,
provisioned BOOLEAN NOT NULL,
PRIMARY KEY (user_id, topic)
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
token TEXT NOT NULL UNIQUE,
label TEXT NOT NULL,
last_access BIGINT NOT NULL,
last_origin TEXT NOT NULL,
expires BIGINT NOT NULL,
provisioned BOOLEAN NOT NULL,
PRIMARY KEY (user_id, token)
);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number)
);
CREATE TABLE IF NOT EXISTS schema_version (
store TEXT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT)
ON CONFLICT (id) DO NOTHING;
INSERT INTO schema_version (store, version) VALUES ('user', 6);
`
// Initial PostgreSQL schema for web push store (from webpush/store_postgres.go)
createWebPushSchemaQuery = `
CREATE TABLE IF NOT EXISTS webpush_subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL UNIQUE,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at BIGINT NOT NULL,
warned_at BIGINT NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_webpush_subscriber_ip ON webpush_subscription (subscriber_ip);
CREATE INDEX IF NOT EXISTS idx_webpush_updated_at ON webpush_subscription (updated_at);
CREATE INDEX IF NOT EXISTS idx_webpush_user_id ON webpush_subscription (user_id);
CREATE TABLE IF NOT EXISTS webpush_subscription_topic (
subscription_id TEXT NOT NULL REFERENCES webpush_subscription (id) ON DELETE CASCADE,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic)
);
CREATE INDEX IF NOT EXISTS idx_webpush_topic ON webpush_subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schema_version (
store TEXT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schema_version (store, version) VALUES ('webpush', 1);
`
)
var flags = []cli.Flag{
@@ -31,6 +184,7 @@ var flags = []cli.Flag{
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file"}, Usage: "SQLite message cache file path"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file"}, Usage: "SQLite user/auth database file path"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "web-push-file", Aliases: []string{"web_push_file"}, Usage: "SQLite web push database file path"}),
&cli.BoolFlag{Name: "create-schema", Usage: "create initial PostgreSQL schema before importing"},
}
func main() {
@@ -88,6 +242,12 @@ func execImport(c *cli.Context) error {
}
defer pgDB.Close()
if c.Bool("create-schema") {
if err := createSchema(pgDB, cacheFile, authFile, webPushFile); err != nil {
return fmt.Errorf("cannot create schema: %w", err)
}
}
if authFile != "" {
if err := verifySchemaVersion(pgDB, "user", expectedUserSchemaVersion); err != nil {
return err
@@ -139,6 +299,34 @@ func execImport(c *cli.Context) error {
return nil
}
// createSchema creates the initial PostgreSQL schema for each store whose
// SQLite source file was provided (user/auth, message cache, web push). It
// runs when --create-schema is passed, before the schema version checks.
func createSchema(pgDB *sql.DB, cacheFile, authFile, webPushFile string) error {
	fmt.Println("Creating initial PostgreSQL schema ...")
	// User schema must be created before message schema, because message_stats and
	// schema_version use "INSERT INTO" without "ON CONFLICT", so user schema (which
	// also creates the schema_version table) must come first.
	// NOTE(review): those plain INSERTs also mean re-running --create-schema
	// against an already-initialized database will fail on duplicate keys —
	// confirm that failing fast there is the intended behavior.
	if authFile != "" {
		fmt.Println(" Creating user schema ...")
		if _, err := pgDB.Exec(createUserSchemaQuery); err != nil {
			return fmt.Errorf("creating user schema: %w", err)
		}
	}
	if cacheFile != "" {
		fmt.Println(" Creating message schema ...")
		if _, err := pgDB.Exec(createMessageSchemaQuery); err != nil {
			return fmt.Errorf("creating message schema: %w", err)
		}
	}
	if webPushFile != "" {
		fmt.Println(" Creating web push schema ...")
		if _, err := pgDB.Exec(createWebPushSchemaQuery); err != nil {
			return fmt.Errorf("creating web push schema: %w", err)
		}
	}
	fmt.Println(" Schema creation complete.")
	fmt.Println()
	return nil
}
func loadConfigFile(configFlag string, flags []cli.Flag) cli.BeforeFunc {
return func(c *cli.Context) error {
configFile := c.String(configFlag)

View File

@@ -422,33 +422,14 @@ func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
return a.readUser(rows)
}
// Users returns a list of users
// Users returns a list of users. It loads all users in a single query
// rather than one query per user to avoid N+1 performance issues.
func (a *Manager) Users() ([]*User, error) {
rows, err := a.db.Query(a.queries.selectUsernames)
rows, err := a.db.Query(a.queries.selectUsers)
if err != nil {
return nil, err
}
defer rows.Close()
usernames := make([]string, 0)
for rows.Next() {
var username string
if err := rows.Scan(&username); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
usernames = append(usernames, username)
}
rows.Close()
users := make([]*User, 0)
for _, username := range usernames {
user, err := a.User(username)
if err != nil {
return nil, err
}
users = append(users, user)
}
return users, nil
return a.readUsers(rows)
}
// UsersCount returns the number of users in the database
@@ -470,14 +451,35 @@ func (a *Manager) UsersCount() (int64, error) {
// readUser reads exactly one user from rows, returning ErrUserNotFound when
// the result set is empty. rows is always closed before returning.
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
	defer rows.Close()
	if !rows.Next() {
		return nil, ErrUserNotFound
	}
	// NOTE(review): scanUser as shown later in this diff also calls
	// rows.Next() before scanning; if that is still the case, the cursor is
	// advanced twice here and a single-row result would incorrectly return
	// ErrUserNotFound — confirm scanUser no longer calls Next.
	user, err := a.scanUser(rows)
	if err != nil {
		return nil, err
	}
	return user, nil
}
// readUsers reads all users from rows, returning an empty (non-nil) slice
// when there are none. rows is always closed before returning.
func (a *Manager) readUsers(rows *sql.Rows) ([]*User, error) {
	defer rows.Close()
	users := make([]*User, 0)
	for rows.Next() {
		// NOTE(review): scanUser as shown later in this diff also calls
		// rows.Next(); if that is still the case, every other row would be
		// skipped here — confirm scanUser no longer advances the cursor.
		user, err := a.scanUser(rows)
		if err != nil {
			return nil, err
		}
		users = append(users, user)
	}
	// NOTE(review): consider checking rows.Err() after the loop to surface
	// iteration errors that end the loop silently.
	return users, nil
}
func (a *Manager) scanUser(rows *sql.Rows) (*User, error) {
var id, username, hash, role, prefs, syncTopic string
var provisioned bool
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
var messages, emails, calls int64
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() {
return nil, ErrUserNotFound
}
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
@@ -1244,6 +1246,12 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
if !a.config.ProvisionEnabled {
return nil
}
// If there is nothing to provision, remove any previously provisioned items using
// cheap targeted queries, avoiding the expensive Users() call that loads all users.
if len(a.config.Users) == 0 && len(a.config.Access) == 0 && len(a.config.Tokens) == 0 {
return a.removeAllProvisioned()
}
// If there are provisioned users, do it the slow way
existingUsers, err := a.Users()
if err != nil {
return err
@@ -1269,6 +1277,23 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
})
}
// removeAllProvisioned deletes all provisioned access entries, tokens, and
// users in a single transaction. This is the fast path for when nothing is
// configured to be provisioned, avoiding the expensive Users() call.
func (a *Manager) removeAllProvisioned() error {
	return db.ExecTx(a.db, func(tx *sql.Tx) error {
		// Order matters: access entries and tokens reference users, so they
		// are removed before the users themselves.
		deletions := []string{
			a.queries.deleteUserAccessProvisioned,
			a.queries.deleteAllProvisionedTokens,
			a.queries.deleteUsersProvisioned,
		}
		for _, query := range deletions {
			if _, err := tx.Exec(query); err != nil {
				return err
			}
		}
		return nil
	})
}
// maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them.
// It also removes users that are provisioned, but not in the config anymore.
func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, existingUsers []*User) error {

View File

@@ -7,6 +7,17 @@ import (
// PostgreSQL queries
const (
// User queries
postgresSelectUsersQuery = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
LEFT JOIN tier t on t.id = u.tier_id
ORDER BY
CASE u.role
WHEN 'admin' THEN 1
WHEN 'anonymous' THEN 3
ELSE 2
END, u.user_name
`
postgresSelectUserByIDQuery = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
@@ -56,6 +67,7 @@ const (
postgresDeleteUserQuery = `DELETE FROM "user" WHERE user_name = $1`
postgresDeleteUserTierQuery = `UPDATE "user" SET tier_id = null WHERE user_name = $1`
postgresDeleteUsersMarkedQuery = `DELETE FROM "user" WHERE deleted < $1`
postgresDeleteUsersProvisionedQuery = `DELETE FROM "user" WHERE provisioned = true`
// Access queries
postgresSelectTopicPermsQuery = `
@@ -146,13 +158,14 @@ const (
ON CONFLICT (user_id, token)
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned
`
postgresUpdateTokenQuery = `UPDATE user_token SET label = $1, expires = $2 WHERE user_id = $3 AND token = $4`
postgresUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
postgresDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
postgresDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = $1`
postgresDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = $1`
postgresDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
postgresDeleteExcessTokensQuery = `
postgresUpdateTokenQuery = `UPDATE user_token SET label = $1, expires = $2 WHERE user_id = $3 AND token = $4`
postgresUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
postgresDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
postgresDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = $1`
postgresDeleteAllProvisionedTokensQuery = `DELETE FROM user_token WHERE provisioned = true`
postgresDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = $1`
postgresDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
postgresDeleteExcessTokensQuery = `
DELETE FROM user_token
WHERE user_id = $1
AND (user_id, token) NOT IN (
@@ -210,6 +223,7 @@ var postgresQueries = queries{
selectUserByToken: postgresSelectUserByTokenQuery,
selectUserByStripeCustomerID: postgresSelectUserByStripeCustomerIDQuery,
selectUsernames: postgresSelectUsernamesQuery,
selectUsers: postgresSelectUsersQuery,
selectUserCount: postgresSelectUserCountQuery,
selectUserIDFromUsername: postgresSelectUserIDFromUsernameQuery,
insertUser: postgresInsertUserQuery,
@@ -224,6 +238,7 @@ var postgresQueries = queries{
deleteUser: postgresDeleteUserQuery,
deleteUserTier: postgresDeleteUserTierQuery,
deleteUsersMarked: postgresDeleteUsersMarkedQuery,
deleteUsersProvisioned: postgresDeleteUsersProvisionedQuery,
selectTopicPerms: postgresSelectTopicPermsQuery,
selectUserAllAccess: postgresSelectUserAllAccessQuery,
selectUserAccess: postgresSelectUserAccessQuery,
@@ -246,6 +261,7 @@ var postgresQueries = queries{
updateTokenLastAccess: postgresUpdateTokenLastAccessQuery,
deleteToken: postgresDeleteTokenQuery,
deleteProvisionedToken: postgresDeleteProvisionedTokenQuery,
deleteAllProvisionedTokens: postgresDeleteAllProvisionedTokensQuery,
deleteAllToken: postgresDeleteAllTokenQuery,
deleteExpiredTokens: postgresDeleteExpiredTokensQuery,
deleteExcessTokens: postgresDeleteExcessTokensQuery,

View File

@@ -12,6 +12,17 @@ import (
const (
// User queries
sqliteSelectUsersQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
ORDER BY
CASE u.role
WHEN 'admin' THEN 1
WHEN 'anonymous' THEN 3
ELSE 2
END, u.user
`
sqliteSelectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
@@ -61,6 +72,7 @@ const (
sqliteDeleteUserQuery = `DELETE FROM user WHERE user = ?`
sqliteDeleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
sqliteDeleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
sqliteDeleteUsersProvisionedQuery = `DELETE FROM user WHERE provisioned = 1`
// Access queries
sqliteSelectTopicPermsQuery = `
@@ -144,13 +156,14 @@ const (
ON CONFLICT (user_id, token)
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned
`
sqliteUpdateTokenQuery = `UPDATE user_token SET label = ?, expires = ? WHERE user_id = ? AND token = ?`
sqliteUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
sqliteDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
sqliteDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?`
sqliteDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
sqliteDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
sqliteDeleteExcessTokensQuery = `
sqliteUpdateTokenQuery = `UPDATE user_token SET label = ?, expires = ? WHERE user_id = ? AND token = ?`
sqliteUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
sqliteDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
sqliteDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?`
sqliteDeleteAllProvisionedTokensQuery = `DELETE FROM user_token WHERE provisioned = 1`
sqliteDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
sqliteDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
sqliteDeleteExcessTokensQuery = `
DELETE FROM user_token
WHERE user_id = ?
AND (user_id, token) NOT IN (
@@ -207,6 +220,7 @@ var sqliteQueries = queries{
selectUserByToken: sqliteSelectUserByTokenQuery,
selectUserByStripeCustomerID: sqliteSelectUserByStripeCustomerIDQuery,
selectUsernames: sqliteSelectUsernamesQuery,
selectUsers: sqliteSelectUsersQuery,
selectUserCount: sqliteSelectUserCountQuery,
selectUserIDFromUsername: sqliteSelectUserIDFromUsernameQuery,
insertUser: sqliteInsertUserQuery,
@@ -221,6 +235,7 @@ var sqliteQueries = queries{
deleteUser: sqliteDeleteUserQuery,
deleteUserTier: sqliteDeleteUserTierQuery,
deleteUsersMarked: sqliteDeleteUsersMarkedQuery,
deleteUsersProvisioned: sqliteDeleteUsersProvisionedQuery,
selectTopicPerms: sqliteSelectTopicPermsQuery,
selectUserAllAccess: sqliteSelectUserAllAccessQuery,
selectUserAccess: sqliteSelectUserAccessQuery,
@@ -243,6 +258,7 @@ var sqliteQueries = queries{
updateTokenLastAccess: sqliteUpdateTokenLastAccessQuery,
deleteToken: sqliteDeleteTokenQuery,
deleteProvisionedToken: sqliteDeleteProvisionedTokenQuery,
deleteAllProvisionedTokens: sqliteDeleteAllProvisionedTokensQuery,
deleteAllToken: sqliteDeleteAllTokenQuery,
deleteExpiredTokens: sqliteDeleteExpiredTokensQuery,
deleteExcessTokens: sqliteDeleteExcessTokensQuery,

View File

@@ -1441,6 +1441,54 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
})
}
// TestManager_RemoveProvisionedOnEmptyConfig verifies that users, grants and
// tokens previously created via provisioning are purged when the manager is
// re-opened with an empty provisioning config, while manually created users
// survive the restart.
func TestManager_RemoveProvisionedOnEmptyConfig(t *testing.T) {
	forEachBackend(t, func(t *testing.T, newManager newManagerFunc) {
		// Initial config: one provisioned user with one grant and one token
		cfg := &Config{
			DefaultAccess:    PermissionReadWrite,
			ProvisionEnabled: true,
			BcryptCost:       bcrypt.MinCost,
			Users: []*User{
				{Name: "provuser", Hash: "$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", Role: RoleUser},
			},
			Access: map[string][]*Grant{
				"provuser": {
					{TopicPattern: "stats", Permission: PermissionReadWrite},
				},
			},
			Tokens: map[string][]*Token{
				"provuser": {
					{Value: "tk_op56p8lz5bf3cxkz9je99v9oc37lo", Label: "Provisioned token"},
				},
			},
		}
		m := newTestManagerFromConfig(t, newManager, cfg)
		// Create a second user by hand (not via provisioning)
		require.Nil(t, m.AddUser("manualuser", "manual", RoleUser, false))
		// Sanity check: both users plus the implicit "everyone" user exist
		all, err := m.Users()
		require.Nil(t, err)
		require.Len(t, all, 3) // provuser, manualuser, everyone
		// Simulate a config change: restart the manager with provisioning emptied out
		require.Nil(t, m.Close())
		cfg.Users = nil
		cfg.Access = nil
		cfg.Tokens = nil
		m = newTestManagerFromConfig(t, newManager, cfg)
		// The provisioned user must be gone; the manual user must still be there
		all, err = m.Users()
		require.Nil(t, err)
		require.Len(t, all, 2)
		require.Equal(t, "manualuser", all[0].Name)
		require.False(t, all[0].Provisioned)
		require.Equal(t, "*", all[1].Name) // everyone
	})
}
func TestToFromSQLWildcard(t *testing.T) {
require.Equal(t, "up%", toSQLWildcard("up*"))
require.Equal(t, "up\\_%", toSQLWildcard("up_*"))

View File

@@ -283,6 +283,7 @@ type queries struct {
selectUserByToken string
selectUserByStripeCustomerID string
selectUsernames string
selectUsers string
selectUserCount string
selectUserIDFromUsername string
insertUser string
@@ -297,6 +298,7 @@ type queries struct {
deleteUser string
deleteUserTier string
deleteUsersMarked string
deleteUsersProvisioned string
// Access queries
selectTopicPerms string
@@ -323,6 +325,7 @@ type queries struct {
updateTokenLastAccess string
deleteToken string
deleteProvisionedToken string
deleteAllProvisionedTokens string
deleteAllToken string
deleteExpiredTokens string
deleteExcessTokens string