From 0d4b1b00e6ecb71b6240c6ffa17564b6a437d6f2 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Fri, 6 Mar 2026 14:46:53 -0500 Subject: [PATCH] Users() optimization to help with startup time --- .gitignore | 1 + tools/loadtest/go.mod | 5 + tools/loadtest/go.sum | 2 + tools/loadtest/main.go | 522 +++++++++++++++++++++++++++++++++++++++ tools/pgimport/README.md | 11 +- tools/pgimport/main.go | 188 ++++++++++++++ user/manager.go | 77 ++++-- user/manager_postgres.go | 30 ++- user/manager_sqlite.go | 30 ++- user/manager_test.go | 48 ++++ user/types.go | 3 + 11 files changed, 876 insertions(+), 41 deletions(-) create mode 100644 tools/loadtest/go.mod create mode 100644 tools/loadtest/go.sum create mode 100644 tools/loadtest/main.go diff --git a/.gitignore b/.gitignore index 3a362286..ed17b2d4 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ server/docs/ server/site/ tools/fbsend/fbsend tools/pgimport/pgimport +tools/loadtest/loadtest playground/ secrets/ *.iml diff --git a/tools/loadtest/go.mod b/tools/loadtest/go.mod new file mode 100644 index 00000000..3c3034a9 --- /dev/null +++ b/tools/loadtest/go.mod @@ -0,0 +1,5 @@ +module loadtest + +go 1.25.2 + +require github.com/gorilla/websocket v1.5.3 diff --git a/tools/loadtest/go.sum b/tools/loadtest/go.sum new file mode 100644 index 00000000..25a9fc4b --- /dev/null +++ b/tools/loadtest/go.sum @@ -0,0 +1,2 @@ +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/tools/loadtest/main.go b/tools/loadtest/main.go new file mode 100644 index 00000000..1cfbdd56 --- /dev/null +++ b/tools/loadtest/main.go @@ -0,0 +1,522 @@ +// Load test program for ntfy staging server. +// Replicates production traffic patterns derived from access.log analysis. +// +// Traffic profile (from ~5M requests over 20 hours): +// ~71 req/sec average, ~4,300 req/min +// 49.6% poll requests (GET /TOPIC/json?poll=1&since=ID) +// 21.4% publish POST (POST /TOPIC with small body) +// 6.2% subscribe stream (GET /TOPIC/json?since=X, long-lived) +// 4.1% config check (GET /v1/config) +// 2.3% other topic GET (GET /TOPIC) +// 2.2% account check (GET /v1/account) +// 1.9% websocket sub (GET /TOPIC/ws?since=X) +// 1.5% publish PUT (PUT /TOPIC with small body) +// 1.5% raw subscribe (GET /TOPIC/raw?since=X) +// 1.1% json subscribe (GET /TOPIC/json, no since) +// 0.7% SSE subscribe (GET /TOPIC/sse?since=X) +// remaining: static, PATCH, OPTIONS, etc. (omitted) + +package main + +import ( + "context" + "crypto/rand" + "encoding/hex" + "flag" + "fmt" + "io" + + "math/big" + mrand "math/rand" + "net/http" + "os" + "os/signal" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" +) + +var ( + baseURL string + rps float64 + scale float64 + numTopics int + subStreams int + wsStreams int + sseStreams int + rawStreams int + duration time.Duration + + totalRequests atomic.Int64 + totalErrors atomic.Int64 + activeStreams atomic.Int64 + + // Error tracking by category + errMu sync.Mutex + recentErrors []string // last N unique error messages + errorCounts = make(map[string]int64) +) + +func main() { + flag.StringVar(&baseURL, "url", "https://staging.ntfy.sh", "Base URL of ntfy server") + flag.Float64Var(&rps, "rps", 71, "Target requests per second (default: prod average)") + flag.Float64Var(&scale, "scale", 1.0, "Scale factor for all load (0.5 = half load, 2.0 = double)") + flag.IntVar(&numTopics, "topics", 500, "Number of unique topics to use") + flag.IntVar(&subStreams, "sub-streams", 200, "Number of concurrent JSON streaming subscriptions") + flag.IntVar(&wsStreams, "ws-streams", 50, "Number of concurrent WebSocket subscriptions") + flag.IntVar(&sseStreams, "sse-streams", 20, "Number of concurrent SSE subscriptions") + flag.IntVar(&rawStreams, "raw-streams", 30, "Number of concurrent raw subscriptions") + flag.DurationVar(&duration, "duration", 10*time.Minute, "Test duration") + flag.Parse() + + rps *= scale + subStreams = int(float64(subStreams) * scale) + wsStreams = int(float64(wsStreams) * scale) + sseStreams = int(float64(sseStreams) * scale) + rawStreams = int(float64(rawStreams) * scale) + + topics := generateTopics(numTopics) + + fmt.Printf("ntfy load test\n") + fmt.Printf(" Target: %s\n", baseURL) + fmt.Printf(" RPS: %.1f\n", rps) + fmt.Printf(" Scale: %.1fx\n", scale) + fmt.Printf(" Topics: %d\n", numTopics) + fmt.Printf(" Sub streams: %d json, %d ws, %d sse, %d raw\n", subStreams, wsStreams, sseStreams, rawStreams) + fmt.Printf(" Duration: %s\n", duration) + fmt.Println() + + ctx, cancel := context.WithTimeout(context.Background(), duration) + defer cancel() + + // Also handle Ctrl+C + sigCtx, sigCancel := signal.NotifyContext(ctx, os.Interrupt) + defer sigCancel() + ctx = sigCtx + + client := &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 1000, + MaxIdleConnsPerHost: 1000, + IdleConnTimeout: 90 * time.Second, + }, + } + + // Long-lived streaming client (no timeout) + streamClient := &http.Client{ + Timeout: 0, + Transport: &http.Transport{ + MaxIdleConns: 500, + MaxIdleConnsPerHost: 500, + IdleConnTimeout: 0, + }, + } + + var wg sync.WaitGroup + + // Start long-lived streaming subscriptions + for i := 0; i < subStreams; i++ { + wg.Add(1) + go func() { + defer wg.Done() + streamSubscription(ctx, streamClient, topics, "json") + }() + } + for i := 0; i < wsStreams; i++ { + wg.Add(1) + go func() { + defer wg.Done() + wsSubscription(ctx, topics) + }() + } + for i := 0; i < sseStreams; i++ { + wg.Add(1) + go func() { + defer wg.Done() + streamSubscription(ctx, streamClient, topics, "sse") + }() + } + for i := 0; i < rawStreams; i++ { + wg.Add(1) + go func() { + defer wg.Done() + streamSubscription(ctx, streamClient, topics, "raw") + }() + } + + // Start request generators based on traffic weights + // Weights from log analysis (normalized to sum ~100): + // poll=49.6, publish_post=21.4, config=4.1, other_get=2.3, account=2.2, publish_put=1.5 + // Total short-lived weight ≈ 81.1 + type requestType struct { + name string + weight float64 + fn func(ctx context.Context, client *http.Client, topics []string) + } + + types := []requestType{ + {"poll", 49.6, doPoll}, + {"publish_post", 21.4, doPublishPost}, + {"config", 4.1, doConfig}, + {"other_get", 2.3, doOtherGet}, + {"account", 2.2, doAccountCheck}, + {"publish_put", 1.5, doPublishPut}, + } + + totalWeight := 0.0 + for _, t := range types { + totalWeight += t.weight + } + + for _, t := range types { + t := t + typeRPS := rps * (t.weight / totalWeight) + if typeRPS < 0.1 { + continue + } + wg.Add(1) + go func() { + defer wg.Done() + runAtRate(ctx, typeRPS, func() { + t.fn(ctx, client, topics) + }) + }() + } + + // Stats reporter + wg.Add(1) + go func() { + defer wg.Done() + reportStats(ctx) + }() + + wg.Wait() + fmt.Printf("\nDone. Total requests: %d, errors: %d\n", totalRequests.Load(), totalErrors.Load()) +} + +func trackError(category string, err error) { + totalErrors.Add(1) + key := fmt.Sprintf("%s: %s", category, truncateErr(err)) + errMu.Lock() + errorCounts[key]++ + errMu.Unlock() +} + +func trackErrorMsg(category string, msg string) { + totalErrors.Add(1) + key := fmt.Sprintf("%s: %s", category, msg) + errMu.Lock() + errorCounts[key]++ + errMu.Unlock() +} + +func truncateErr(err error) string { + s := err.Error() + if len(s) > 120 { + s = s[:120] + "..." + } + return s +} + +func generateTopics(n int) []string { + topics := make([]string, n) + for i := 0; i < n; i++ { + b := make([]byte, 8) + rand.Read(b) + topics[i] = "loadtest-" + hex.EncodeToString(b) + } + return topics +} + +func pickTopic(topics []string) string { + n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(topics)))) + return topics[n.Int64()] +} + +func randomSince() string { + b := make([]byte, 6) + rand.Read(b) + return hex.EncodeToString(b) +} + +func randomMessage() string { + messages := []string{ + "Test notification", + "Server backup completed successfully", + "Deployment finished", + "Alert: disk usage above 80%", + "Build #1234 passed", + "New order received", + "Temperature sensor reading: 72F", + "Cron job completed", + } + return messages[mrand.Intn(len(messages))] +} + +// runAtRate executes fn at approximately the given rate per second +func runAtRate(ctx context.Context, rate float64, fn func()) { + interval := time.Duration(float64(time.Second) / rate) + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + go fn() + } + } +} + +// --- Short-lived request types --- + +func doPoll(ctx context.Context, client *http.Client, topics []string) { + topic := pickTopic(topics) + url := fmt.Sprintf("%s/%s/json?poll=1&since=%s", baseURL, topic, randomSince()) + doGet(ctx, client, url) +} + +func doPublishPost(ctx context.Context, client *http.Client, topics []string) { + topic := pickTopic(topics) + url := fmt.Sprintf("%s/%s", baseURL, topic) + req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(randomMessage())) + if err != nil { + trackError("publish_post_req", err) + return + } + // Some messages have titles/priorities like real traffic + if mrand.Float32() < 0.3 { + req.Header.Set("X-Title", "Load Test") + } + if mrand.Float32() < 0.1 { + req.Header.Set("X-Priority", fmt.Sprintf("%d", mrand.Intn(5)+1)) + } + resp, err := client.Do(req) + totalRequests.Add(1) + if err != nil { + trackError("publish_post", err) + return + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + if resp.StatusCode >= 400 { + trackErrorMsg("publish_post_http", fmt.Sprintf("status %d", resp.StatusCode)) + } +} + +func doPublishPut(ctx context.Context, client *http.Client, topics []string) { + topic := pickTopic(topics) + url := fmt.Sprintf("%s/%s", baseURL, topic) + req, err := http.NewRequestWithContext(ctx, "PUT", url, strings.NewReader(randomMessage())) + if err != nil { + trackError("publish_put_req", err) + return + } + resp, err := client.Do(req) + totalRequests.Add(1) + if err != nil { + trackError("publish_put", err) + return + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + if resp.StatusCode >= 400 { + trackErrorMsg("publish_put_http", fmt.Sprintf("status %d", resp.StatusCode)) + } +} + +func doConfig(ctx context.Context, client *http.Client, topics []string) { + url := fmt.Sprintf("%s/v1/config", baseURL) + doGet(ctx, client, url) +} + +func doAccountCheck(ctx context.Context, client *http.Client, topics []string) { + url := fmt.Sprintf("%s/v1/account", baseURL) + doGet(ctx, client, url) +} + +func doOtherGet(ctx context.Context, client *http.Client, topics []string) { + topic := pickTopic(topics) + url := fmt.Sprintf("%s/%s", baseURL, topic) + doGet(ctx, client, url) +} + +func doGet(ctx context.Context, client *http.Client, url string) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + trackError("get_req", err) + return + } + resp, err := client.Do(req) + totalRequests.Add(1) + if err != nil { + trackError("get", err) + return + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + if resp.StatusCode >= 400 { + trackErrorMsg("get_http", fmt.Sprintf("status %d for %s", resp.StatusCode, url)) + } +} + +// --- Long-lived streaming subscriptions --- + +func streamSubscription(ctx context.Context, client *http.Client, topics []string, format string) { + for { + if ctx.Err() != nil { + return + } + topic := pickTopic(topics) + url := fmt.Sprintf("%s/%s/%s?since=all", baseURL, topic, format) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + time.Sleep(time.Second) + continue + } + activeStreams.Add(1) + resp, err := client.Do(req) + if err != nil { + activeStreams.Add(-1) + if ctx.Err() == nil { + trackError("stream_"+format+"_connect", err) + } + time.Sleep(time.Second) + continue + } + if resp.StatusCode >= 400 { + trackErrorMsg("stream_"+format+"_http", fmt.Sprintf("status %d", resp.StatusCode)) + resp.Body.Close() + activeStreams.Add(-1) + time.Sleep(time.Second) + continue + } + // Read from stream until context cancelled or connection drops + buf := make([]byte, 4096) + for { + _, err := resp.Body.Read(buf) + if err != nil { + if ctx.Err() == nil { + trackError("stream_"+format+"_read", err) + } + break + } + } + resp.Body.Close() + activeStreams.Add(-1) + // Reconnect with small delay (like real clients do) + select { + case <-ctx.Done(): + return + case <-time.After(time.Duration(mrand.Intn(3000)) * time.Millisecond): + } + } +} + +func wsSubscription(ctx context.Context, topics []string) { + wsURL := strings.Replace(baseURL, "https://", "wss://", 1) + wsURL = strings.Replace(wsURL, "http://", "ws://", 1) + + for { + if ctx.Err() != nil { + return + } + topic := pickTopic(topics) + url := fmt.Sprintf("%s/%s/ws?since=all", wsURL, topic) + + dialer := websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + } + activeStreams.Add(1) + conn, _, err := dialer.DialContext(ctx, url, nil) + if err != nil { + activeStreams.Add(-1) + if ctx.Err() == nil { + trackError("ws_connect", err) + } + time.Sleep(time.Second) + continue + } + + // Read messages until context cancelled or error + done := make(chan struct{}) + go func() { + defer close(done) + for { + conn.SetReadDeadline(time.Now().Add(5 * time.Minute)) + _, _, err := conn.ReadMessage() + if err != nil { + return + } + } + }() + + select { + case <-ctx.Done(): + conn.Close() + activeStreams.Add(-1) + return + case <-done: + conn.Close() + activeStreams.Add(-1) + } + + select { + case <-ctx.Done(): + return + case <-time.After(time.Duration(mrand.Intn(3000)) * time.Millisecond): + } + } +} + +func reportStats(ctx context.Context) { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + var lastRequests, lastErrors int64 + lastTime := time.Now() + reportCount := 0 + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + now := time.Now() + currentRequests := totalRequests.Load() + currentErrors := totalErrors.Load() + elapsed := now.Sub(lastTime).Seconds() + currentRPS := float64(currentRequests-lastRequests) / elapsed + errorRate := float64(currentErrors-lastErrors) / elapsed + + fmt.Printf("[%s] rps=%.1f err/s=%.1f total=%d errors=%d streams=%d\n", + now.Format("15:04:05"), + currentRPS, + errorRate, + currentRequests, + currentErrors, + activeStreams.Load(), + ) + + // Print error breakdown every 30 seconds + reportCount++ + if reportCount%6 == 0 && currentErrors > 0 { + errMu.Lock() + fmt.Printf(" Error breakdown:\n") + for k, v := range errorCounts { + fmt.Printf(" %s: %d\n", k, v) + } + errMu.Unlock() + } + + lastRequests = currentRequests + lastErrors = currentErrors + lastTime = now + } + } +} diff --git a/tools/pgimport/README.md b/tools/pgimport/README.md index f7352111..eff6ccad 100644 --- a/tools/pgimport/README.md +++ b/tools/pgimport/README.md @@ -22,13 +22,22 @@ pgimport \ --auth-file /var/lib/ntfy/user.db \ --web-push-file /var/lib/ntfy/webpush.db +# Using --create-schema to set up PostgreSQL schema automatically +pgimport \ + --create-schema \ + --database-url "postgres://user:pass@host:5432/ntfy?sslmode=require" \ + --cache-file /var/cache/ntfy/cache.db \ + --auth-file /var/lib/ntfy/user.db \ + --web-push-file /var/lib/ntfy/webpush.db + # Using server.yml (flags override config values) pgimport --config /etc/ntfy/server.yml ``` ## Prerequisites -- PostgreSQL schema must already be set up (run ntfy with `database-url` once) +- PostgreSQL schema must already be set up, either by running ntfy with `database-url` once, + or by passing `--create-schema` to pgimport to create the initial schema automatically - ntfy must not be running during the import - All three SQLite files are optional; only the ones specified will be imported diff --git a/tools/pgimport/main.go b/tools/pgimport/main.go index bf92a2a7..cbc171dd 100644 --- a/tools/pgimport/main.go +++ b/tools/pgimport/main.go @@ -23,6 +23,159 @@ const ( expectedMessageSchemaVersion = 14 expectedUserSchemaVersion = 6 expectedWebPushSchemaVersion = 1 + + everyoneID = "u_everyone" + + // Initial PostgreSQL schema for message store (from message/cache_postgres_schema.go) + createMessageSchemaQuery = ` + CREATE TABLE IF NOT EXISTS message ( + id BIGSERIAL PRIMARY KEY, + mid TEXT NOT NULL, + sequence_id TEXT NOT NULL, + time BIGINT NOT NULL, + event TEXT NOT NULL, + expires BIGINT NOT NULL, + topic TEXT NOT NULL, + message TEXT NOT NULL, + title TEXT NOT NULL, + priority INT NOT NULL, + tags TEXT NOT NULL, + click TEXT NOT NULL, + icon TEXT NOT NULL, + actions TEXT NOT NULL, + attachment_name TEXT NOT NULL, + attachment_type TEXT NOT NULL, + attachment_size BIGINT NOT NULL, + attachment_expires BIGINT NOT NULL, + attachment_url TEXT NOT NULL, + attachment_deleted BOOLEAN NOT NULL DEFAULT FALSE, + sender TEXT NOT NULL, + user_id TEXT NOT NULL, + content_type TEXT NOT NULL, + encoding TEXT NOT NULL, + published BOOLEAN NOT NULL DEFAULT FALSE + ); + CREATE INDEX IF NOT EXISTS idx_message_mid ON message (mid); + CREATE INDEX IF NOT EXISTS idx_message_sequence_id ON message (sequence_id); + CREATE INDEX IF NOT EXISTS idx_message_topic_published_time ON message (topic, published, time, id); + CREATE INDEX IF NOT EXISTS idx_message_published_expires ON message (published, expires); + CREATE INDEX IF NOT EXISTS idx_message_sender_attachment_expires ON message (sender, attachment_expires) WHERE user_id = ''; + CREATE INDEX IF NOT EXISTS idx_message_user_id_attachment_expires ON message (user_id, attachment_expires); + CREATE TABLE IF NOT EXISTS message_stats ( + key TEXT PRIMARY KEY, + value BIGINT + ); + INSERT INTO message_stats (key, value) VALUES ('messages', 0); + CREATE TABLE IF NOT EXISTS schema_version ( + store TEXT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO schema_version (store, version) VALUES ('message', 14); + ` + + // Initial PostgreSQL schema for user store (from user/manager_postgres_schema.go) + createUserSchemaQuery = ` + CREATE TABLE IF NOT EXISTS tier ( + id TEXT PRIMARY KEY, + code TEXT NOT NULL, + name TEXT NOT NULL, + messages_limit BIGINT NOT NULL, + messages_expiry_duration BIGINT NOT NULL, + emails_limit BIGINT NOT NULL, + calls_limit BIGINT NOT NULL, + reservations_limit BIGINT NOT NULL, + attachment_file_size_limit BIGINT NOT NULL, + attachment_total_size_limit BIGINT NOT NULL, + attachment_expiry_duration BIGINT NOT NULL, + attachment_bandwidth_limit BIGINT NOT NULL, + stripe_monthly_price_id TEXT, + stripe_yearly_price_id TEXT, + UNIQUE(code), + UNIQUE(stripe_monthly_price_id), + UNIQUE(stripe_yearly_price_id) + ); + CREATE TABLE IF NOT EXISTS "user" ( + id TEXT PRIMARY KEY, + tier_id TEXT REFERENCES tier(id), + user_name TEXT NOT NULL UNIQUE, + pass TEXT NOT NULL, + role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')), + prefs JSONB NOT NULL DEFAULT '{}', + sync_topic TEXT NOT NULL, + provisioned BOOLEAN NOT NULL, + stats_messages BIGINT NOT NULL DEFAULT 0, + stats_emails BIGINT NOT NULL DEFAULT 0, + stats_calls BIGINT NOT NULL DEFAULT 0, + stripe_customer_id TEXT UNIQUE, + stripe_subscription_id TEXT UNIQUE, + stripe_subscription_status TEXT, + stripe_subscription_interval TEXT, + stripe_subscription_paid_until BIGINT, + stripe_subscription_cancel_at BIGINT, + created BIGINT NOT NULL, + deleted BIGINT + ); + CREATE TABLE IF NOT EXISTS user_access ( + user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + topic TEXT NOT NULL, + read BOOLEAN NOT NULL, + write BOOLEAN NOT NULL, + owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE, + provisioned BOOLEAN NOT NULL, + PRIMARY KEY (user_id, topic) + ); + CREATE TABLE IF NOT EXISTS user_token ( + user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + token TEXT NOT NULL UNIQUE, + label TEXT NOT NULL, + last_access BIGINT NOT NULL, + last_origin TEXT NOT NULL, + expires BIGINT NOT NULL, + provisioned BOOLEAN NOT NULL, + PRIMARY KEY (user_id, token) + ); + CREATE TABLE IF NOT EXISTS user_phone ( + user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + phone_number TEXT NOT NULL, + PRIMARY KEY (user_id, phone_number) + ); + CREATE TABLE IF NOT EXISTS schema_version ( + store TEXT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) + VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT) + ON CONFLICT (id) DO NOTHING; + INSERT INTO schema_version (store, version) VALUES ('user', 6); + ` + + // Initial PostgreSQL schema for web push store (from webpush/store_postgres.go) + createWebPushSchemaQuery = ` + CREATE TABLE IF NOT EXISTS webpush_subscription ( + id TEXT PRIMARY KEY, + endpoint TEXT NOT NULL UNIQUE, + key_auth TEXT NOT NULL, + key_p256dh TEXT NOT NULL, + user_id TEXT NOT NULL, + subscriber_ip TEXT NOT NULL, + updated_at BIGINT NOT NULL, + warned_at BIGINT NOT NULL DEFAULT 0 + ); + CREATE INDEX IF NOT EXISTS idx_webpush_subscriber_ip ON webpush_subscription (subscriber_ip); + CREATE INDEX IF NOT EXISTS idx_webpush_updated_at ON webpush_subscription (updated_at); + CREATE INDEX IF NOT EXISTS idx_webpush_user_id ON webpush_subscription (user_id); + CREATE TABLE IF NOT EXISTS webpush_subscription_topic ( + subscription_id TEXT NOT NULL REFERENCES webpush_subscription (id) ON DELETE CASCADE, + topic TEXT NOT NULL, + PRIMARY KEY (subscription_id, topic) + ); + CREATE INDEX IF NOT EXISTS idx_webpush_topic ON webpush_subscription_topic (topic); + CREATE TABLE IF NOT EXISTS schema_version ( + store TEXT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO schema_version (store, version) VALUES ('webpush', 1); + ` ) var flags = []cli.Flag{ @@ -31,6 +184,7 @@ var flags = []cli.Flag{ altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file"}, Usage: "SQLite message cache file path"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file"}, Usage: "SQLite user/auth database file path"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "web-push-file", Aliases: []string{"web_push_file"}, Usage: "SQLite web push database file path"}), + &cli.BoolFlag{Name: "create-schema", Usage: "create initial PostgreSQL schema before importing"}, } func main() { @@ -88,6 +242,12 @@ func execImport(c *cli.Context) error { } defer pgDB.Close() + if c.Bool("create-schema") { + if err := createSchema(pgDB, cacheFile, authFile, webPushFile); err != nil { + return fmt.Errorf("cannot create schema: %w", err) + } + } + if authFile != "" { if err := verifySchemaVersion(pgDB, "user", expectedUserSchemaVersion); err != nil { return err @@ -139,6 +299,34 @@ func execImport(c *cli.Context) error { return nil } +func createSchema(pgDB *sql.DB, cacheFile, authFile, webPushFile string) error { + fmt.Println("Creating initial PostgreSQL schema ...") + // User schema must be created before message schema, because message_stats and + // schema_version use "INSERT INTO" without "ON CONFLICT", so user schema (which + // also creates the schema_version table) must come first. + if authFile != "" { + fmt.Println(" Creating user schema ...") + if _, err := pgDB.Exec(createUserSchemaQuery); err != nil { + return fmt.Errorf("creating user schema: %w", err) + } + } + if cacheFile != "" { + fmt.Println(" Creating message schema ...") + if _, err := pgDB.Exec(createMessageSchemaQuery); err != nil { + return fmt.Errorf("creating message schema: %w", err) + } + } + if webPushFile != "" { + fmt.Println(" Creating web push schema ...") + if _, err := pgDB.Exec(createWebPushSchemaQuery); err != nil { + return fmt.Errorf("creating web push schema: %w", err) + } + } + fmt.Println(" Schema creation complete.") + fmt.Println() + return nil +} + func loadConfigFile(configFlag string, flags []cli.Flag) cli.BeforeFunc { return func(c *cli.Context) error { configFile := c.String(configFlag) diff --git a/user/manager.go b/user/manager.go index cd37cb3b..0ee6a6e1 100644 --- a/user/manager.go +++ b/user/manager.go @@ -422,33 +422,14 @@ func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) { return a.readUser(rows) } -// Users returns a list of users +// Users returns a list of users. It loads all users in a single query +// rather than one query per user to avoid N+1 performance issues. func (a *Manager) Users() ([]*User, error) { - rows, err := a.db.Query(a.queries.selectUsernames) + rows, err := a.db.Query(a.queries.selectUsers) if err != nil { return nil, err } - defer rows.Close() - usernames := make([]string, 0) - for rows.Next() { - var username string - if err := rows.Scan(&username); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - usernames = append(usernames, username) - } - rows.Close() - users := make([]*User, 0) - for _, username := range usernames { - user, err := a.User(username) - if err != nil { - return nil, err - } - users = append(users, user) - } - return users, nil + return a.readUsers(rows) } // UsersCount returns the number of users in the database @@ -470,14 +451,35 @@ func (a *Manager) UsersCount() (int64, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) { defer rows.Close() + if !rows.Next() { + return nil, ErrUserNotFound + } + user, err := a.scanUser(rows) + if err != nil { + return nil, err + } + return user, nil +} + +func (a *Manager) readUsers(rows *sql.Rows) ([]*User, error) { + defer rows.Close() + users := make([]*User, 0) + for rows.Next() { + user, err := a.scanUser(rows) + if err != nil { + return nil, err + } + users = append(users, user) + } + return users, nil +} + +func (a *Manager) scanUser(rows *sql.Rows) (*User, error) { var id, username, hash, role, prefs, syncTopic string var provisioned bool var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString var messages, emails, calls int64 var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 - if !rows.Next() { - return nil, ErrUserNotFound - } if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { @@ -1244,6 +1246,12 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error { if !a.config.ProvisionEnabled { return nil } + // If there is nothing to provision, remove any previously provisioned items using + // cheap targeted queries, avoiding the expensive Users() call that loads all users. + if len(a.config.Users) == 0 && len(a.config.Access) == 0 && len(a.config.Tokens) == 0 { + return a.removeAllProvisioned() + } + // If there are provisioned users, do it the slow way existingUsers, err := a.Users() if err != nil { return err @@ -1269,6 +1277,23 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error { }) } +// removeAllProvisioned removes all provisioned users, access entries, and tokens. This is the fast path +// for when there is nothing to provision, avoiding the expensive Users() call. +func (a *Manager) removeAllProvisioned() error { + return db.ExecTx(a.db, func(tx *sql.Tx) error { + if _, err := tx.Exec(a.queries.deleteUserAccessProvisioned); err != nil { + return err + } + if _, err := tx.Exec(a.queries.deleteAllProvisionedTokens); err != nil { + return err + } + if _, err := tx.Exec(a.queries.deleteUsersProvisioned); err != nil { + return err + } + return nil + }) +} + // maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them. // It also removes users that are provisioned, but not in the config anymore. func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, existingUsers []*User) error { diff --git a/user/manager_postgres.go b/user/manager_postgres.go index bc8e3852..7138ae2c 100644 --- a/user/manager_postgres.go +++ b/user/manager_postgres.go @@ -7,6 +7,17 @@ import ( // PostgreSQL queries const ( // User queries + postgresSelectUsersQuery = ` + SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + FROM "user" u + LEFT JOIN tier t on t.id = u.tier_id + ORDER BY + CASE u.role + WHEN 'admin' THEN 1 + WHEN 'anonymous' THEN 3 + ELSE 2 + END, u.user_name + ` postgresSelectUserByIDQuery = ` SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM "user" u @@ -56,6 +67,7 @@ const ( postgresDeleteUserQuery = `DELETE FROM "user" WHERE user_name = $1` postgresDeleteUserTierQuery = `UPDATE "user" SET tier_id = null WHERE user_name = $1` postgresDeleteUsersMarkedQuery = `DELETE FROM "user" WHERE deleted < $1` + postgresDeleteUsersProvisionedQuery = `DELETE FROM "user" WHERE provisioned = true` // Access queries postgresSelectTopicPermsQuery = ` @@ -146,13 +158,14 @@ const ( ON CONFLICT (user_id, token) DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned ` - postgresUpdateTokenQuery = `UPDATE user_token SET label = $1, expires = $2 WHERE user_id = $3 AND token = $4` - postgresUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3` - postgresDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = $1 AND token = $2` - postgresDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = $1` - postgresDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = $1` - postgresDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < $1` - postgresDeleteExcessTokensQuery = ` + postgresUpdateTokenQuery = `UPDATE user_token SET label = $1, expires = $2 WHERE user_id = $3 AND token = $4` + postgresUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3` + postgresDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = $1 AND token = $2` + postgresDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = $1` + postgresDeleteAllProvisionedTokensQuery = `DELETE FROM user_token WHERE provisioned = true` + postgresDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = $1` + postgresDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < $1` + postgresDeleteExcessTokensQuery = ` DELETE FROM user_token WHERE user_id = $1 AND (user_id, token) NOT IN ( @@ -210,6 +223,7 @@ var postgresQueries = queries{ selectUserByToken: postgresSelectUserByTokenQuery, selectUserByStripeCustomerID: postgresSelectUserByStripeCustomerIDQuery, selectUsernames: postgresSelectUsernamesQuery, + selectUsers: postgresSelectUsersQuery, selectUserCount: postgresSelectUserCountQuery, selectUserIDFromUsername: postgresSelectUserIDFromUsernameQuery, insertUser: postgresInsertUserQuery, @@ -224,6 +238,7 @@ var postgresQueries = queries{ deleteUser: postgresDeleteUserQuery, deleteUserTier: postgresDeleteUserTierQuery, deleteUsersMarked: postgresDeleteUsersMarkedQuery, + deleteUsersProvisioned: postgresDeleteUsersProvisionedQuery, selectTopicPerms: postgresSelectTopicPermsQuery, selectUserAllAccess: postgresSelectUserAllAccessQuery, selectUserAccess: postgresSelectUserAccessQuery, @@ -246,6 +261,7 @@ var postgresQueries = queries{ updateTokenLastAccess: postgresUpdateTokenLastAccessQuery, deleteToken: postgresDeleteTokenQuery, deleteProvisionedToken: postgresDeleteProvisionedTokenQuery, + deleteAllProvisionedTokens: postgresDeleteAllProvisionedTokensQuery, deleteAllToken: postgresDeleteAllTokenQuery, deleteExpiredTokens: postgresDeleteExpiredTokensQuery, deleteExcessTokens: postgresDeleteExcessTokensQuery, diff --git a/user/manager_sqlite.go b/user/manager_sqlite.go index e09c682c..b4068599 100644 --- a/user/manager_sqlite.go +++ b/user/manager_sqlite.go @@ -12,6 +12,17 @@ import ( const ( // User queries + sqliteSelectUsersQuery = ` + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + FROM user u + LEFT JOIN tier t on t.id = u.tier_id + ORDER BY + CASE u.role + WHEN 'admin' THEN 1 + WHEN 'anonymous' THEN 3 + ELSE 2 + END, u.user + ` sqliteSelectUserByIDQuery = ` SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM user u @@ -61,6 +72,7 @@ const ( sqliteDeleteUserQuery = `DELETE FROM user WHERE user = ?` sqliteDeleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?` sqliteDeleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?` + sqliteDeleteUsersProvisionedQuery = `DELETE FROM user WHERE provisioned = 1` // Access queries sqliteSelectTopicPermsQuery = ` @@ -144,13 +156,14 @@ const ( ON CONFLICT (user_id, token) DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned ` - sqliteUpdateTokenQuery = `UPDATE user_token SET label = ?, expires = ? WHERE user_id = ? AND token = ?` - sqliteUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?` - sqliteDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?` - sqliteDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?` - sqliteDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?` - sqliteDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?` - sqliteDeleteExcessTokensQuery = ` + sqliteUpdateTokenQuery = `UPDATE user_token SET label = ?, expires = ? WHERE user_id = ? AND token = ?` + sqliteUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?` + sqliteDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?` + sqliteDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?` + sqliteDeleteAllProvisionedTokensQuery = `DELETE FROM user_token WHERE provisioned = 1` + sqliteDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?` + sqliteDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?` + sqliteDeleteExcessTokensQuery = ` DELETE FROM user_token WHERE user_id = ? AND (user_id, token) NOT IN ( @@ -207,6 +220,7 @@ var sqliteQueries = queries{ selectUserByToken: sqliteSelectUserByTokenQuery, selectUserByStripeCustomerID: sqliteSelectUserByStripeCustomerIDQuery, selectUsernames: sqliteSelectUsernamesQuery, + selectUsers: sqliteSelectUsersQuery, selectUserCount: sqliteSelectUserCountQuery, selectUserIDFromUsername: sqliteSelectUserIDFromUsernameQuery, insertUser: sqliteInsertUserQuery, @@ -221,6 +235,7 @@ var sqliteQueries = queries{ deleteUser: sqliteDeleteUserQuery, deleteUserTier: sqliteDeleteUserTierQuery, deleteUsersMarked: sqliteDeleteUsersMarkedQuery, + deleteUsersProvisioned: sqliteDeleteUsersProvisionedQuery, selectTopicPerms: sqliteSelectTopicPermsQuery, selectUserAllAccess: sqliteSelectUserAllAccessQuery, selectUserAccess: sqliteSelectUserAccessQuery, @@ -243,6 +258,7 @@ var sqliteQueries = queries{ updateTokenLastAccess: sqliteUpdateTokenLastAccessQuery, deleteToken: sqliteDeleteTokenQuery, deleteProvisionedToken: sqliteDeleteProvisionedTokenQuery, + deleteAllProvisionedTokens: sqliteDeleteAllProvisionedTokensQuery, deleteAllToken: sqliteDeleteAllTokenQuery, deleteExpiredTokens: sqliteDeleteExpiredTokensQuery, deleteExcessTokens: sqliteDeleteExcessTokensQuery, diff --git a/user/manager_test.go b/user/manager_test.go index c353acb8..53cae1d1 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -1441,6 +1441,54 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) { }) } +func TestManager_RemoveProvisionedOnEmptyConfig(t *testing.T) { + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + // Start with provisioned users, access, and tokens + conf := &Config{ + DefaultAccess: PermissionReadWrite, + ProvisionEnabled: true, + BcryptCost: bcrypt.MinCost, + Users: []*User{ + {Name: "provuser", Hash: "$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", Role: RoleUser}, + }, + Access: map[string][]*Grant{ + "provuser": { + {TopicPattern: "stats", Permission: PermissionReadWrite}, + }, + }, + Tokens: map[string][]*Token{ + "provuser": { + {Value: "tk_op56p8lz5bf3cxkz9je99v9oc37lo", Label: "Provisioned token"}, + }, + }, + } + a := newTestManagerFromConfig(t, newManager, conf) + + // Also add a manual (non-provisioned) user + require.Nil(t, a.AddUser("manualuser", "manual", RoleUser, false)) + + // Verify initial state + users, err := a.Users() + require.Nil(t, err) + require.Len(t, users, 3) // provuser, manualuser, everyone + + // Re-open with empty provisioning config (simulates config change) + require.Nil(t, a.Close()) + conf.Users = nil + conf.Access = nil + conf.Tokens = nil + a = newTestManagerFromConfig(t, newManager, conf) + + // Provisioned user should be removed, manual user should remain + users, err = a.Users() + require.Nil(t, err) + require.Len(t, users, 2) + require.Equal(t, "manualuser", users[0].Name) + require.False(t, users[0].Provisioned) + require.Equal(t, "*", users[1].Name) // everyone + }) +} + func TestToFromSQLWildcard(t *testing.T) { require.Equal(t, "up%", toSQLWildcard("up*")) require.Equal(t, "up\\_%", toSQLWildcard("up_*")) diff --git a/user/types.go b/user/types.go index e909ec78..08c65220 100644 --- a/user/types.go +++ b/user/types.go @@ -283,6 +283,7 @@ type queries struct { selectUserByToken string selectUserByStripeCustomerID string selectUsernames string + selectUsers string selectUserCount string selectUserIDFromUsername string insertUser string @@ -297,6 +298,7 @@ type queries struct { deleteUser string deleteUserTier string deleteUsersMarked string + deleteUsersProvisioned string // Access queries selectTopicPerms string @@ -323,6 +325,7 @@ type queries struct { updateTokenLastAccess string deleteToken string deleteProvisionedToken string + deleteAllProvisionedTokens string deleteAllToken string deleteExpiredTokens string deleteExcessTokens string