mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-18 21:30:44 +01:00
Users() optimization to help with startup time
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -8,6 +8,7 @@ server/docs/
|
|||||||
server/site/
|
server/site/
|
||||||
tools/fbsend/fbsend
|
tools/fbsend/fbsend
|
||||||
tools/pgimport/pgimport
|
tools/pgimport/pgimport
|
||||||
|
tools/loadtest/loadtest
|
||||||
playground/
|
playground/
|
||||||
secrets/
|
secrets/
|
||||||
*.iml
|
*.iml
|
||||||
|
|||||||
5
tools/loadtest/go.mod
Normal file
5
tools/loadtest/go.mod
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
module loadtest
|
||||||
|
|
||||||
|
go 1.25.2
|
||||||
|
|
||||||
|
require github.com/gorilla/websocket v1.5.3
|
||||||
2
tools/loadtest/go.sum
Normal file
2
tools/loadtest/go.sum
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
522
tools/loadtest/main.go
Normal file
522
tools/loadtest/main.go
Normal file
@@ -0,0 +1,522 @@
|
|||||||
|
// Load test program for ntfy staging server.
|
||||||
|
// Replicates production traffic patterns derived from access.log analysis.
|
||||||
|
//
|
||||||
|
// Traffic profile (from ~5M requests over 20 hours):
|
||||||
|
// ~71 req/sec average, ~4,300 req/min
|
||||||
|
// 49.6% poll requests (GET /TOPIC/json?poll=1&since=ID)
|
||||||
|
// 21.4% publish POST (POST /TOPIC with small body)
|
||||||
|
// 6.2% subscribe stream (GET /TOPIC/json?since=X, long-lived)
|
||||||
|
// 4.1% config check (GET /v1/config)
|
||||||
|
// 2.3% other topic GET (GET /TOPIC)
|
||||||
|
// 2.2% account check (GET /v1/account)
|
||||||
|
// 1.9% websocket sub (GET /TOPIC/ws?since=X)
|
||||||
|
// 1.5% publish PUT (PUT /TOPIC with small body)
|
||||||
|
// 1.5% raw subscribe (GET /TOPIC/raw?since=X)
|
||||||
|
// 1.1% json subscribe (GET /TOPIC/json, no since)
|
||||||
|
// 0.7% SSE subscribe (GET /TOPIC/sse?since=X)
|
||||||
|
// remaining: static, PATCH, OPTIONS, etc. (omitted)
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"math/big"
|
||||||
|
mrand "math/rand"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
baseURL string
|
||||||
|
rps float64
|
||||||
|
scale float64
|
||||||
|
numTopics int
|
||||||
|
subStreams int
|
||||||
|
wsStreams int
|
||||||
|
sseStreams int
|
||||||
|
rawStreams int
|
||||||
|
duration time.Duration
|
||||||
|
|
||||||
|
totalRequests atomic.Int64
|
||||||
|
totalErrors atomic.Int64
|
||||||
|
activeStreams atomic.Int64
|
||||||
|
|
||||||
|
// Error tracking by category
|
||||||
|
errMu sync.Mutex
|
||||||
|
recentErrors []string // last N unique error messages
|
||||||
|
errorCounts = make(map[string]int64)
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.StringVar(&baseURL, "url", "https://staging.ntfy.sh", "Base URL of ntfy server")
|
||||||
|
flag.Float64Var(&rps, "rps", 71, "Target requests per second (default: prod average)")
|
||||||
|
flag.Float64Var(&scale, "scale", 1.0, "Scale factor for all load (0.5 = half load, 2.0 = double)")
|
||||||
|
flag.IntVar(&numTopics, "topics", 500, "Number of unique topics to use")
|
||||||
|
flag.IntVar(&subStreams, "sub-streams", 200, "Number of concurrent JSON streaming subscriptions")
|
||||||
|
flag.IntVar(&wsStreams, "ws-streams", 50, "Number of concurrent WebSocket subscriptions")
|
||||||
|
flag.IntVar(&sseStreams, "sse-streams", 20, "Number of concurrent SSE subscriptions")
|
||||||
|
flag.IntVar(&rawStreams, "raw-streams", 30, "Number of concurrent raw subscriptions")
|
||||||
|
flag.DurationVar(&duration, "duration", 10*time.Minute, "Test duration")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
rps *= scale
|
||||||
|
subStreams = int(float64(subStreams) * scale)
|
||||||
|
wsStreams = int(float64(wsStreams) * scale)
|
||||||
|
sseStreams = int(float64(sseStreams) * scale)
|
||||||
|
rawStreams = int(float64(rawStreams) * scale)
|
||||||
|
|
||||||
|
topics := generateTopics(numTopics)
|
||||||
|
|
||||||
|
fmt.Printf("ntfy load test\n")
|
||||||
|
fmt.Printf(" Target: %s\n", baseURL)
|
||||||
|
fmt.Printf(" RPS: %.1f\n", rps)
|
||||||
|
fmt.Printf(" Scale: %.1fx\n", scale)
|
||||||
|
fmt.Printf(" Topics: %d\n", numTopics)
|
||||||
|
fmt.Printf(" Sub streams: %d json, %d ws, %d sse, %d raw\n", subStreams, wsStreams, sseStreams, rawStreams)
|
||||||
|
fmt.Printf(" Duration: %s\n", duration)
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), duration)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Also handle Ctrl+C
|
||||||
|
sigCtx, sigCancel := signal.NotifyContext(ctx, os.Interrupt)
|
||||||
|
defer sigCancel()
|
||||||
|
ctx = sigCtx
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
MaxIdleConns: 1000,
|
||||||
|
MaxIdleConnsPerHost: 1000,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Long-lived streaming client (no timeout)
|
||||||
|
streamClient := &http.Client{
|
||||||
|
Timeout: 0,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
MaxIdleConns: 500,
|
||||||
|
MaxIdleConnsPerHost: 500,
|
||||||
|
IdleConnTimeout: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Start long-lived streaming subscriptions
|
||||||
|
for i := 0; i < subStreams; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
streamSubscription(ctx, streamClient, topics, "json")
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
for i := 0; i < wsStreams; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
wsSubscription(ctx, topics)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
for i := 0; i < sseStreams; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
streamSubscription(ctx, streamClient, topics, "sse")
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
for i := 0; i < rawStreams; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
streamSubscription(ctx, streamClient, topics, "raw")
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start request generators based on traffic weights
|
||||||
|
// Weights from log analysis (normalized to sum ~100):
|
||||||
|
// poll=49.6, publish_post=21.4, config=4.1, other_get=2.3, account=2.2, publish_put=1.5
|
||||||
|
// Total short-lived weight ≈ 81.1
|
||||||
|
type requestType struct {
|
||||||
|
name string
|
||||||
|
weight float64
|
||||||
|
fn func(ctx context.Context, client *http.Client, topics []string)
|
||||||
|
}
|
||||||
|
|
||||||
|
types := []requestType{
|
||||||
|
{"poll", 49.6, doPoll},
|
||||||
|
{"publish_post", 21.4, doPublishPost},
|
||||||
|
{"config", 4.1, doConfig},
|
||||||
|
{"other_get", 2.3, doOtherGet},
|
||||||
|
{"account", 2.2, doAccountCheck},
|
||||||
|
{"publish_put", 1.5, doPublishPut},
|
||||||
|
}
|
||||||
|
|
||||||
|
totalWeight := 0.0
|
||||||
|
for _, t := range types {
|
||||||
|
totalWeight += t.weight
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range types {
|
||||||
|
t := t
|
||||||
|
typeRPS := rps * (t.weight / totalWeight)
|
||||||
|
if typeRPS < 0.1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
runAtRate(ctx, typeRPS, func() {
|
||||||
|
t.fn(ctx, client, topics)
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats reporter
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
reportStats(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
fmt.Printf("\nDone. Total requests: %d, errors: %d\n", totalRequests.Load(), totalErrors.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func trackError(category string, err error) {
|
||||||
|
totalErrors.Add(1)
|
||||||
|
key := fmt.Sprintf("%s: %s", category, truncateErr(err))
|
||||||
|
errMu.Lock()
|
||||||
|
errorCounts[key]++
|
||||||
|
errMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func trackErrorMsg(category string, msg string) {
|
||||||
|
totalErrors.Add(1)
|
||||||
|
key := fmt.Sprintf("%s: %s", category, msg)
|
||||||
|
errMu.Lock()
|
||||||
|
errorCounts[key]++
|
||||||
|
errMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateErr(err error) string {
|
||||||
|
s := err.Error()
|
||||||
|
if len(s) > 120 {
|
||||||
|
s = s[:120] + "..."
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateTopics(n int) []string {
|
||||||
|
topics := make([]string, n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
b := make([]byte, 8)
|
||||||
|
rand.Read(b)
|
||||||
|
topics[i] = "loadtest-" + hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
return topics
|
||||||
|
}
|
||||||
|
|
||||||
|
func pickTopic(topics []string) string {
|
||||||
|
n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(topics))))
|
||||||
|
return topics[n.Int64()]
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomSince() string {
|
||||||
|
b := make([]byte, 6)
|
||||||
|
rand.Read(b)
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomMessage() string {
|
||||||
|
messages := []string{
|
||||||
|
"Test notification",
|
||||||
|
"Server backup completed successfully",
|
||||||
|
"Deployment finished",
|
||||||
|
"Alert: disk usage above 80%",
|
||||||
|
"Build #1234 passed",
|
||||||
|
"New order received",
|
||||||
|
"Temperature sensor reading: 72F",
|
||||||
|
"Cron job completed",
|
||||||
|
}
|
||||||
|
return messages[mrand.Intn(len(messages))]
|
||||||
|
}
|
||||||
|
|
||||||
|
// runAtRate executes fn at approximately the given rate per second
|
||||||
|
func runAtRate(ctx context.Context, rate float64, fn func()) {
|
||||||
|
interval := time.Duration(float64(time.Second) / rate)
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
go fn()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Short-lived request types ---
|
||||||
|
|
||||||
|
func doPoll(ctx context.Context, client *http.Client, topics []string) {
|
||||||
|
topic := pickTopic(topics)
|
||||||
|
url := fmt.Sprintf("%s/%s/json?poll=1&since=%s", baseURL, topic, randomSince())
|
||||||
|
doGet(ctx, client, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
func doPublishPost(ctx context.Context, client *http.Client, topics []string) {
|
||||||
|
topic := pickTopic(topics)
|
||||||
|
url := fmt.Sprintf("%s/%s", baseURL, topic)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(randomMessage()))
|
||||||
|
if err != nil {
|
||||||
|
trackError("publish_post_req", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Some messages have titles/priorities like real traffic
|
||||||
|
if mrand.Float32() < 0.3 {
|
||||||
|
req.Header.Set("X-Title", "Load Test")
|
||||||
|
}
|
||||||
|
if mrand.Float32() < 0.1 {
|
||||||
|
req.Header.Set("X-Priority", fmt.Sprintf("%d", mrand.Intn(5)+1))
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
totalRequests.Add(1)
|
||||||
|
if err != nil {
|
||||||
|
trackError("publish_post", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
io.Copy(io.Discard, resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
trackErrorMsg("publish_post_http", fmt.Sprintf("status %d", resp.StatusCode))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func doPublishPut(ctx context.Context, client *http.Client, topics []string) {
|
||||||
|
topic := pickTopic(topics)
|
||||||
|
url := fmt.Sprintf("%s/%s", baseURL, topic)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "PUT", url, strings.NewReader(randomMessage()))
|
||||||
|
if err != nil {
|
||||||
|
trackError("publish_put_req", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
totalRequests.Add(1)
|
||||||
|
if err != nil {
|
||||||
|
trackError("publish_put", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
io.Copy(io.Discard, resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
trackErrorMsg("publish_put_http", fmt.Sprintf("status %d", resp.StatusCode))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func doConfig(ctx context.Context, client *http.Client, topics []string) {
|
||||||
|
url := fmt.Sprintf("%s/v1/config", baseURL)
|
||||||
|
doGet(ctx, client, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
func doAccountCheck(ctx context.Context, client *http.Client, topics []string) {
|
||||||
|
url := fmt.Sprintf("%s/v1/account", baseURL)
|
||||||
|
doGet(ctx, client, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
func doOtherGet(ctx context.Context, client *http.Client, topics []string) {
|
||||||
|
topic := pickTopic(topics)
|
||||||
|
url := fmt.Sprintf("%s/%s", baseURL, topic)
|
||||||
|
doGet(ctx, client, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
func doGet(ctx context.Context, client *http.Client, url string) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
trackError("get_req", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
totalRequests.Add(1)
|
||||||
|
if err != nil {
|
||||||
|
trackError("get", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
io.Copy(io.Discard, resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
trackErrorMsg("get_http", fmt.Sprintf("status %d for %s", resp.StatusCode, url))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Long-lived streaming subscriptions ---
|
||||||
|
|
||||||
|
func streamSubscription(ctx context.Context, client *http.Client, topics []string, format string) {
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
topic := pickTopic(topics)
|
||||||
|
url := fmt.Sprintf("%s/%s/%s?since=all", baseURL, topic, format)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
activeStreams.Add(1)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
activeStreams.Add(-1)
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
trackError("stream_"+format+"_connect", err)
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
trackErrorMsg("stream_"+format+"_http", fmt.Sprintf("status %d", resp.StatusCode))
|
||||||
|
resp.Body.Close()
|
||||||
|
activeStreams.Add(-1)
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Read from stream until context cancelled or connection drops
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
for {
|
||||||
|
_, err := resp.Body.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
trackError("stream_"+format+"_read", err)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
activeStreams.Add(-1)
|
||||||
|
// Reconnect with small delay (like real clients do)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(time.Duration(mrand.Intn(3000)) * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func wsSubscription(ctx context.Context, topics []string) {
|
||||||
|
wsURL := strings.Replace(baseURL, "https://", "wss://", 1)
|
||||||
|
wsURL = strings.Replace(wsURL, "http://", "ws://", 1)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
topic := pickTopic(topics)
|
||||||
|
url := fmt.Sprintf("%s/%s/ws?since=all", wsURL, topic)
|
||||||
|
|
||||||
|
dialer := websocket.Dialer{
|
||||||
|
HandshakeTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
activeStreams.Add(1)
|
||||||
|
conn, _, err := dialer.DialContext(ctx, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
activeStreams.Add(-1)
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
trackError("ws_connect", err)
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read messages until context cancelled or error
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
for {
|
||||||
|
conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
|
||||||
|
_, _, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
conn.Close()
|
||||||
|
activeStreams.Add(-1)
|
||||||
|
return
|
||||||
|
case <-done:
|
||||||
|
conn.Close()
|
||||||
|
activeStreams.Add(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(time.Duration(mrand.Intn(3000)) * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func reportStats(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(5 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
var lastRequests, lastErrors int64
|
||||||
|
lastTime := time.Now()
|
||||||
|
reportCount := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
now := time.Now()
|
||||||
|
currentRequests := totalRequests.Load()
|
||||||
|
currentErrors := totalErrors.Load()
|
||||||
|
elapsed := now.Sub(lastTime).Seconds()
|
||||||
|
currentRPS := float64(currentRequests-lastRequests) / elapsed
|
||||||
|
errorRate := float64(currentErrors-lastErrors) / elapsed
|
||||||
|
|
||||||
|
fmt.Printf("[%s] rps=%.1f err/s=%.1f total=%d errors=%d streams=%d\n",
|
||||||
|
now.Format("15:04:05"),
|
||||||
|
currentRPS,
|
||||||
|
errorRate,
|
||||||
|
currentRequests,
|
||||||
|
currentErrors,
|
||||||
|
activeStreams.Load(),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Print error breakdown every 30 seconds
|
||||||
|
reportCount++
|
||||||
|
if reportCount%6 == 0 && currentErrors > 0 {
|
||||||
|
errMu.Lock()
|
||||||
|
fmt.Printf(" Error breakdown:\n")
|
||||||
|
for k, v := range errorCounts {
|
||||||
|
fmt.Printf(" %s: %d\n", k, v)
|
||||||
|
}
|
||||||
|
errMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
lastRequests = currentRequests
|
||||||
|
lastErrors = currentErrors
|
||||||
|
lastTime = now
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -22,13 +22,22 @@ pgimport \
|
|||||||
--auth-file /var/lib/ntfy/user.db \
|
--auth-file /var/lib/ntfy/user.db \
|
||||||
--web-push-file /var/lib/ntfy/webpush.db
|
--web-push-file /var/lib/ntfy/webpush.db
|
||||||
|
|
||||||
|
# Using --create-schema to set up PostgreSQL schema automatically
|
||||||
|
pgimport \
|
||||||
|
--create-schema \
|
||||||
|
--database-url "postgres://user:pass@host:5432/ntfy?sslmode=require" \
|
||||||
|
--cache-file /var/cache/ntfy/cache.db \
|
||||||
|
--auth-file /var/lib/ntfy/user.db \
|
||||||
|
--web-push-file /var/lib/ntfy/webpush.db
|
||||||
|
|
||||||
# Using server.yml (flags override config values)
|
# Using server.yml (flags override config values)
|
||||||
pgimport --config /etc/ntfy/server.yml
|
pgimport --config /etc/ntfy/server.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
- PostgreSQL schema must already be set up (run ntfy with `database-url` once)
|
- PostgreSQL schema must already be set up, either by running ntfy with `database-url` once,
|
||||||
|
or by passing `--create-schema` to pgimport to create the initial schema automatically
|
||||||
- ntfy must not be running during the import
|
- ntfy must not be running during the import
|
||||||
- All three SQLite files are optional; only the ones specified will be imported
|
- All three SQLite files are optional; only the ones specified will be imported
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,159 @@ const (
|
|||||||
expectedMessageSchemaVersion = 14
|
expectedMessageSchemaVersion = 14
|
||||||
expectedUserSchemaVersion = 6
|
expectedUserSchemaVersion = 6
|
||||||
expectedWebPushSchemaVersion = 1
|
expectedWebPushSchemaVersion = 1
|
||||||
|
|
||||||
|
everyoneID = "u_everyone"
|
||||||
|
|
||||||
|
// Initial PostgreSQL schema for message store (from message/cache_postgres_schema.go)
|
||||||
|
createMessageSchemaQuery = `
|
||||||
|
CREATE TABLE IF NOT EXISTS message (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
mid TEXT NOT NULL,
|
||||||
|
sequence_id TEXT NOT NULL,
|
||||||
|
time BIGINT NOT NULL,
|
||||||
|
event TEXT NOT NULL,
|
||||||
|
expires BIGINT NOT NULL,
|
||||||
|
topic TEXT NOT NULL,
|
||||||
|
message TEXT NOT NULL,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
priority INT NOT NULL,
|
||||||
|
tags TEXT NOT NULL,
|
||||||
|
click TEXT NOT NULL,
|
||||||
|
icon TEXT NOT NULL,
|
||||||
|
actions TEXT NOT NULL,
|
||||||
|
attachment_name TEXT NOT NULL,
|
||||||
|
attachment_type TEXT NOT NULL,
|
||||||
|
attachment_size BIGINT NOT NULL,
|
||||||
|
attachment_expires BIGINT NOT NULL,
|
||||||
|
attachment_url TEXT NOT NULL,
|
||||||
|
attachment_deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
sender TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
content_type TEXT NOT NULL,
|
||||||
|
encoding TEXT NOT NULL,
|
||||||
|
published BOOLEAN NOT NULL DEFAULT FALSE
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_mid ON message (mid);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_sequence_id ON message (sequence_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_topic_published_time ON message (topic, published, time, id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_published_expires ON message (published, expires);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_sender_attachment_expires ON message (sender, attachment_expires) WHERE user_id = '';
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_user_id_attachment_expires ON message (user_id, attachment_expires);
|
||||||
|
CREATE TABLE IF NOT EXISTS message_stats (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
value BIGINT
|
||||||
|
);
|
||||||
|
INSERT INTO message_stats (key, value) VALUES ('messages', 0);
|
||||||
|
CREATE TABLE IF NOT EXISTS schema_version (
|
||||||
|
store TEXT PRIMARY KEY,
|
||||||
|
version INT NOT NULL
|
||||||
|
);
|
||||||
|
INSERT INTO schema_version (store, version) VALUES ('message', 14);
|
||||||
|
`
|
||||||
|
|
||||||
|
// Initial PostgreSQL schema for user store (from user/manager_postgres_schema.go)
|
||||||
|
createUserSchemaQuery = `
|
||||||
|
CREATE TABLE IF NOT EXISTS tier (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
code TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
messages_limit BIGINT NOT NULL,
|
||||||
|
messages_expiry_duration BIGINT NOT NULL,
|
||||||
|
emails_limit BIGINT NOT NULL,
|
||||||
|
calls_limit BIGINT NOT NULL,
|
||||||
|
reservations_limit BIGINT NOT NULL,
|
||||||
|
attachment_file_size_limit BIGINT NOT NULL,
|
||||||
|
attachment_total_size_limit BIGINT NOT NULL,
|
||||||
|
attachment_expiry_duration BIGINT NOT NULL,
|
||||||
|
attachment_bandwidth_limit BIGINT NOT NULL,
|
||||||
|
stripe_monthly_price_id TEXT,
|
||||||
|
stripe_yearly_price_id TEXT,
|
||||||
|
UNIQUE(code),
|
||||||
|
UNIQUE(stripe_monthly_price_id),
|
||||||
|
UNIQUE(stripe_yearly_price_id)
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS "user" (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
tier_id TEXT REFERENCES tier(id),
|
||||||
|
user_name TEXT NOT NULL UNIQUE,
|
||||||
|
pass TEXT NOT NULL,
|
||||||
|
role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')),
|
||||||
|
prefs JSONB NOT NULL DEFAULT '{}',
|
||||||
|
sync_topic TEXT NOT NULL,
|
||||||
|
provisioned BOOLEAN NOT NULL,
|
||||||
|
stats_messages BIGINT NOT NULL DEFAULT 0,
|
||||||
|
stats_emails BIGINT NOT NULL DEFAULT 0,
|
||||||
|
stats_calls BIGINT NOT NULL DEFAULT 0,
|
||||||
|
stripe_customer_id TEXT UNIQUE,
|
||||||
|
stripe_subscription_id TEXT UNIQUE,
|
||||||
|
stripe_subscription_status TEXT,
|
||||||
|
stripe_subscription_interval TEXT,
|
||||||
|
stripe_subscription_paid_until BIGINT,
|
||||||
|
stripe_subscription_cancel_at BIGINT,
|
||||||
|
created BIGINT NOT NULL,
|
||||||
|
deleted BIGINT
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS user_access (
|
||||||
|
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||||
|
topic TEXT NOT NULL,
|
||||||
|
read BOOLEAN NOT NULL,
|
||||||
|
write BOOLEAN NOT NULL,
|
||||||
|
owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE,
|
||||||
|
provisioned BOOLEAN NOT NULL,
|
||||||
|
PRIMARY KEY (user_id, topic)
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS user_token (
|
||||||
|
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||||
|
token TEXT NOT NULL UNIQUE,
|
||||||
|
label TEXT NOT NULL,
|
||||||
|
last_access BIGINT NOT NULL,
|
||||||
|
last_origin TEXT NOT NULL,
|
||||||
|
expires BIGINT NOT NULL,
|
||||||
|
provisioned BOOLEAN NOT NULL,
|
||||||
|
PRIMARY KEY (user_id, token)
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS user_phone (
|
||||||
|
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||||
|
phone_number TEXT NOT NULL,
|
||||||
|
PRIMARY KEY (user_id, phone_number)
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS schema_version (
|
||||||
|
store TEXT PRIMARY KEY,
|
||||||
|
version INT NOT NULL
|
||||||
|
);
|
||||||
|
INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created)
|
||||||
|
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT)
|
||||||
|
ON CONFLICT (id) DO NOTHING;
|
||||||
|
INSERT INTO schema_version (store, version) VALUES ('user', 6);
|
||||||
|
`
|
||||||
|
|
||||||
|
// Initial PostgreSQL schema for web push store (from webpush/store_postgres.go)
|
||||||
|
createWebPushSchemaQuery = `
|
||||||
|
CREATE TABLE IF NOT EXISTS webpush_subscription (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
endpoint TEXT NOT NULL UNIQUE,
|
||||||
|
key_auth TEXT NOT NULL,
|
||||||
|
key_p256dh TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
subscriber_ip TEXT NOT NULL,
|
||||||
|
updated_at BIGINT NOT NULL,
|
||||||
|
warned_at BIGINT NOT NULL DEFAULT 0
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webpush_subscriber_ip ON webpush_subscription (subscriber_ip);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webpush_updated_at ON webpush_subscription (updated_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webpush_user_id ON webpush_subscription (user_id);
|
||||||
|
CREATE TABLE IF NOT EXISTS webpush_subscription_topic (
|
||||||
|
subscription_id TEXT NOT NULL REFERENCES webpush_subscription (id) ON DELETE CASCADE,
|
||||||
|
topic TEXT NOT NULL,
|
||||||
|
PRIMARY KEY (subscription_id, topic)
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webpush_topic ON webpush_subscription_topic (topic);
|
||||||
|
CREATE TABLE IF NOT EXISTS schema_version (
|
||||||
|
store TEXT PRIMARY KEY,
|
||||||
|
version INT NOT NULL
|
||||||
|
);
|
||||||
|
INSERT INTO schema_version (store, version) VALUES ('webpush', 1);
|
||||||
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
var flags = []cli.Flag{
|
var flags = []cli.Flag{
|
||||||
@@ -31,6 +184,7 @@ var flags = []cli.Flag{
|
|||||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file"}, Usage: "SQLite message cache file path"}),
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file"}, Usage: "SQLite message cache file path"}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file"}, Usage: "SQLite user/auth database file path"}),
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file"}, Usage: "SQLite user/auth database file path"}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "web-push-file", Aliases: []string{"web_push_file"}, Usage: "SQLite web push database file path"}),
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "web-push-file", Aliases: []string{"web_push_file"}, Usage: "SQLite web push database file path"}),
|
||||||
|
&cli.BoolFlag{Name: "create-schema", Usage: "create initial PostgreSQL schema before importing"},
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -88,6 +242,12 @@ func execImport(c *cli.Context) error {
|
|||||||
}
|
}
|
||||||
defer pgDB.Close()
|
defer pgDB.Close()
|
||||||
|
|
||||||
|
if c.Bool("create-schema") {
|
||||||
|
if err := createSchema(pgDB, cacheFile, authFile, webPushFile); err != nil {
|
||||||
|
return fmt.Errorf("cannot create schema: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if authFile != "" {
|
if authFile != "" {
|
||||||
if err := verifySchemaVersion(pgDB, "user", expectedUserSchemaVersion); err != nil {
|
if err := verifySchemaVersion(pgDB, "user", expectedUserSchemaVersion); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -139,6 +299,34 @@ func execImport(c *cli.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createSchema(pgDB *sql.DB, cacheFile, authFile, webPushFile string) error {
|
||||||
|
fmt.Println("Creating initial PostgreSQL schema ...")
|
||||||
|
// User schema must be created before message schema, because message_stats and
|
||||||
|
// schema_version use "INSERT INTO" without "ON CONFLICT", so user schema (which
|
||||||
|
// also creates the schema_version table) must come first.
|
||||||
|
if authFile != "" {
|
||||||
|
fmt.Println(" Creating user schema ...")
|
||||||
|
if _, err := pgDB.Exec(createUserSchemaQuery); err != nil {
|
||||||
|
return fmt.Errorf("creating user schema: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cacheFile != "" {
|
||||||
|
fmt.Println(" Creating message schema ...")
|
||||||
|
if _, err := pgDB.Exec(createMessageSchemaQuery); err != nil {
|
||||||
|
return fmt.Errorf("creating message schema: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if webPushFile != "" {
|
||||||
|
fmt.Println(" Creating web push schema ...")
|
||||||
|
if _, err := pgDB.Exec(createWebPushSchemaQuery); err != nil {
|
||||||
|
return fmt.Errorf("creating web push schema: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println(" Schema creation complete.")
|
||||||
|
fmt.Println()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func loadConfigFile(configFlag string, flags []cli.Flag) cli.BeforeFunc {
|
func loadConfigFile(configFlag string, flags []cli.Flag) cli.BeforeFunc {
|
||||||
return func(c *cli.Context) error {
|
return func(c *cli.Context) error {
|
||||||
configFile := c.String(configFlag)
|
configFile := c.String(configFlag)
|
||||||
|
|||||||
@@ -422,33 +422,14 @@ func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
|
|||||||
return a.readUser(rows)
|
return a.readUser(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Users returns a list of users
|
// Users returns a list of users. It loads all users in a single query
|
||||||
|
// rather than one query per user to avoid N+1 performance issues.
|
||||||
func (a *Manager) Users() ([]*User, error) {
|
func (a *Manager) Users() ([]*User, error) {
|
||||||
rows, err := a.db.Query(a.queries.selectUsernames)
|
rows, err := a.db.Query(a.queries.selectUsers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
return a.readUsers(rows)
|
||||||
usernames := make([]string, 0)
|
|
||||||
for rows.Next() {
|
|
||||||
var username string
|
|
||||||
if err := rows.Scan(&username); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
usernames = append(usernames, username)
|
|
||||||
}
|
|
||||||
rows.Close()
|
|
||||||
users := make([]*User, 0)
|
|
||||||
for _, username := range usernames {
|
|
||||||
user, err := a.User(username)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
users = append(users, user)
|
|
||||||
}
|
|
||||||
return users, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsersCount returns the number of users in the database
|
// UsersCount returns the number of users in the database
|
||||||
@@ -470,14 +451,35 @@ func (a *Manager) UsersCount() (int64, error) {
|
|||||||
|
|
||||||
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
return nil, ErrUserNotFound
|
||||||
|
}
|
||||||
|
user, err := a.scanUser(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Manager) readUsers(rows *sql.Rows) ([]*User, error) {
|
||||||
|
defer rows.Close()
|
||||||
|
users := make([]*User, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
user, err := a.scanUser(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
users = append(users, user)
|
||||||
|
}
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Manager) scanUser(rows *sql.Rows) (*User, error) {
|
||||||
var id, username, hash, role, prefs, syncTopic string
|
var id, username, hash, role, prefs, syncTopic string
|
||||||
var provisioned bool
|
var provisioned bool
|
||||||
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
|
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
|
||||||
var messages, emails, calls int64
|
var messages, emails, calls int64
|
||||||
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
|
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
|
||||||
if !rows.Next() {
|
|
||||||
return nil, ErrUserNotFound
|
|
||||||
}
|
|
||||||
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
|
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if err := rows.Err(); err != nil {
|
} else if err := rows.Err(); err != nil {
|
||||||
@@ -1244,6 +1246,12 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
|
|||||||
if !a.config.ProvisionEnabled {
|
if !a.config.ProvisionEnabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
// If there is nothing to provision, remove any previously provisioned items using
|
||||||
|
// cheap targeted queries, avoiding the expensive Users() call that loads all users.
|
||||||
|
if len(a.config.Users) == 0 && len(a.config.Access) == 0 && len(a.config.Tokens) == 0 {
|
||||||
|
return a.removeAllProvisioned()
|
||||||
|
}
|
||||||
|
// If there are provisioned users, do it the slow way
|
||||||
existingUsers, err := a.Users()
|
existingUsers, err := a.Users()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -1269,6 +1277,23 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeAllProvisioned removes all provisioned users, access entries, and tokens. This is the fast path
|
||||||
|
// for when there is nothing to provision, avoiding the expensive Users() call.
|
||||||
|
func (a *Manager) removeAllProvisioned() error {
|
||||||
|
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||||
|
if _, err := tx.Exec(a.queries.deleteUserAccessProvisioned); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(a.queries.deleteAllProvisionedTokens); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(a.queries.deleteUsersProvisioned); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them.
|
// maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them.
|
||||||
// It also removes users that are provisioned, but not in the config anymore.
|
// It also removes users that are provisioned, but not in the config anymore.
|
||||||
func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, existingUsers []*User) error {
|
func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, existingUsers []*User) error {
|
||||||
|
|||||||
@@ -7,6 +7,17 @@ import (
|
|||||||
// PostgreSQL queries
|
// PostgreSQL queries
|
||||||
const (
|
const (
|
||||||
// User queries
|
// User queries
|
||||||
|
postgresSelectUsersQuery = `
|
||||||
|
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
|
FROM "user" u
|
||||||
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
|
ORDER BY
|
||||||
|
CASE u.role
|
||||||
|
WHEN 'admin' THEN 1
|
||||||
|
WHEN 'anonymous' THEN 3
|
||||||
|
ELSE 2
|
||||||
|
END, u.user_name
|
||||||
|
`
|
||||||
postgresSelectUserByIDQuery = `
|
postgresSelectUserByIDQuery = `
|
||||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
FROM "user" u
|
FROM "user" u
|
||||||
@@ -56,6 +67,7 @@ const (
|
|||||||
postgresDeleteUserQuery = `DELETE FROM "user" WHERE user_name = $1`
|
postgresDeleteUserQuery = `DELETE FROM "user" WHERE user_name = $1`
|
||||||
postgresDeleteUserTierQuery = `UPDATE "user" SET tier_id = null WHERE user_name = $1`
|
postgresDeleteUserTierQuery = `UPDATE "user" SET tier_id = null WHERE user_name = $1`
|
||||||
postgresDeleteUsersMarkedQuery = `DELETE FROM "user" WHERE deleted < $1`
|
postgresDeleteUsersMarkedQuery = `DELETE FROM "user" WHERE deleted < $1`
|
||||||
|
postgresDeleteUsersProvisionedQuery = `DELETE FROM "user" WHERE provisioned = true`
|
||||||
|
|
||||||
// Access queries
|
// Access queries
|
||||||
postgresSelectTopicPermsQuery = `
|
postgresSelectTopicPermsQuery = `
|
||||||
@@ -150,6 +162,7 @@ const (
|
|||||||
postgresUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
|
postgresUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
|
||||||
postgresDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
|
postgresDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
|
||||||
postgresDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = $1`
|
postgresDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = $1`
|
||||||
|
postgresDeleteAllProvisionedTokensQuery = `DELETE FROM user_token WHERE provisioned = true`
|
||||||
postgresDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = $1`
|
postgresDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = $1`
|
||||||
postgresDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
|
postgresDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
|
||||||
postgresDeleteExcessTokensQuery = `
|
postgresDeleteExcessTokensQuery = `
|
||||||
@@ -210,6 +223,7 @@ var postgresQueries = queries{
|
|||||||
selectUserByToken: postgresSelectUserByTokenQuery,
|
selectUserByToken: postgresSelectUserByTokenQuery,
|
||||||
selectUserByStripeCustomerID: postgresSelectUserByStripeCustomerIDQuery,
|
selectUserByStripeCustomerID: postgresSelectUserByStripeCustomerIDQuery,
|
||||||
selectUsernames: postgresSelectUsernamesQuery,
|
selectUsernames: postgresSelectUsernamesQuery,
|
||||||
|
selectUsers: postgresSelectUsersQuery,
|
||||||
selectUserCount: postgresSelectUserCountQuery,
|
selectUserCount: postgresSelectUserCountQuery,
|
||||||
selectUserIDFromUsername: postgresSelectUserIDFromUsernameQuery,
|
selectUserIDFromUsername: postgresSelectUserIDFromUsernameQuery,
|
||||||
insertUser: postgresInsertUserQuery,
|
insertUser: postgresInsertUserQuery,
|
||||||
@@ -224,6 +238,7 @@ var postgresQueries = queries{
|
|||||||
deleteUser: postgresDeleteUserQuery,
|
deleteUser: postgresDeleteUserQuery,
|
||||||
deleteUserTier: postgresDeleteUserTierQuery,
|
deleteUserTier: postgresDeleteUserTierQuery,
|
||||||
deleteUsersMarked: postgresDeleteUsersMarkedQuery,
|
deleteUsersMarked: postgresDeleteUsersMarkedQuery,
|
||||||
|
deleteUsersProvisioned: postgresDeleteUsersProvisionedQuery,
|
||||||
selectTopicPerms: postgresSelectTopicPermsQuery,
|
selectTopicPerms: postgresSelectTopicPermsQuery,
|
||||||
selectUserAllAccess: postgresSelectUserAllAccessQuery,
|
selectUserAllAccess: postgresSelectUserAllAccessQuery,
|
||||||
selectUserAccess: postgresSelectUserAccessQuery,
|
selectUserAccess: postgresSelectUserAccessQuery,
|
||||||
@@ -246,6 +261,7 @@ var postgresQueries = queries{
|
|||||||
updateTokenLastAccess: postgresUpdateTokenLastAccessQuery,
|
updateTokenLastAccess: postgresUpdateTokenLastAccessQuery,
|
||||||
deleteToken: postgresDeleteTokenQuery,
|
deleteToken: postgresDeleteTokenQuery,
|
||||||
deleteProvisionedToken: postgresDeleteProvisionedTokenQuery,
|
deleteProvisionedToken: postgresDeleteProvisionedTokenQuery,
|
||||||
|
deleteAllProvisionedTokens: postgresDeleteAllProvisionedTokensQuery,
|
||||||
deleteAllToken: postgresDeleteAllTokenQuery,
|
deleteAllToken: postgresDeleteAllTokenQuery,
|
||||||
deleteExpiredTokens: postgresDeleteExpiredTokensQuery,
|
deleteExpiredTokens: postgresDeleteExpiredTokensQuery,
|
||||||
deleteExcessTokens: postgresDeleteExcessTokensQuery,
|
deleteExcessTokens: postgresDeleteExcessTokensQuery,
|
||||||
|
|||||||
@@ -12,6 +12,17 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
// User queries
|
// User queries
|
||||||
|
sqliteSelectUsersQuery = `
|
||||||
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
|
FROM user u
|
||||||
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
|
ORDER BY
|
||||||
|
CASE u.role
|
||||||
|
WHEN 'admin' THEN 1
|
||||||
|
WHEN 'anonymous' THEN 3
|
||||||
|
ELSE 2
|
||||||
|
END, u.user
|
||||||
|
`
|
||||||
sqliteSelectUserByIDQuery = `
|
sqliteSelectUserByIDQuery = `
|
||||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
FROM user u
|
FROM user u
|
||||||
@@ -61,6 +72,7 @@ const (
|
|||||||
sqliteDeleteUserQuery = `DELETE FROM user WHERE user = ?`
|
sqliteDeleteUserQuery = `DELETE FROM user WHERE user = ?`
|
||||||
sqliteDeleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
|
sqliteDeleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
|
||||||
sqliteDeleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
|
sqliteDeleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
|
||||||
|
sqliteDeleteUsersProvisionedQuery = `DELETE FROM user WHERE provisioned = 1`
|
||||||
|
|
||||||
// Access queries
|
// Access queries
|
||||||
sqliteSelectTopicPermsQuery = `
|
sqliteSelectTopicPermsQuery = `
|
||||||
@@ -148,6 +160,7 @@ const (
|
|||||||
sqliteUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
sqliteUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
||||||
sqliteDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
sqliteDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||||
sqliteDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?`
|
sqliteDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?`
|
||||||
|
sqliteDeleteAllProvisionedTokensQuery = `DELETE FROM user_token WHERE provisioned = 1`
|
||||||
sqliteDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
sqliteDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
||||||
sqliteDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
sqliteDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||||
sqliteDeleteExcessTokensQuery = `
|
sqliteDeleteExcessTokensQuery = `
|
||||||
@@ -207,6 +220,7 @@ var sqliteQueries = queries{
|
|||||||
selectUserByToken: sqliteSelectUserByTokenQuery,
|
selectUserByToken: sqliteSelectUserByTokenQuery,
|
||||||
selectUserByStripeCustomerID: sqliteSelectUserByStripeCustomerIDQuery,
|
selectUserByStripeCustomerID: sqliteSelectUserByStripeCustomerIDQuery,
|
||||||
selectUsernames: sqliteSelectUsernamesQuery,
|
selectUsernames: sqliteSelectUsernamesQuery,
|
||||||
|
selectUsers: sqliteSelectUsersQuery,
|
||||||
selectUserCount: sqliteSelectUserCountQuery,
|
selectUserCount: sqliteSelectUserCountQuery,
|
||||||
selectUserIDFromUsername: sqliteSelectUserIDFromUsernameQuery,
|
selectUserIDFromUsername: sqliteSelectUserIDFromUsernameQuery,
|
||||||
insertUser: sqliteInsertUserQuery,
|
insertUser: sqliteInsertUserQuery,
|
||||||
@@ -221,6 +235,7 @@ var sqliteQueries = queries{
|
|||||||
deleteUser: sqliteDeleteUserQuery,
|
deleteUser: sqliteDeleteUserQuery,
|
||||||
deleteUserTier: sqliteDeleteUserTierQuery,
|
deleteUserTier: sqliteDeleteUserTierQuery,
|
||||||
deleteUsersMarked: sqliteDeleteUsersMarkedQuery,
|
deleteUsersMarked: sqliteDeleteUsersMarkedQuery,
|
||||||
|
deleteUsersProvisioned: sqliteDeleteUsersProvisionedQuery,
|
||||||
selectTopicPerms: sqliteSelectTopicPermsQuery,
|
selectTopicPerms: sqliteSelectTopicPermsQuery,
|
||||||
selectUserAllAccess: sqliteSelectUserAllAccessQuery,
|
selectUserAllAccess: sqliteSelectUserAllAccessQuery,
|
||||||
selectUserAccess: sqliteSelectUserAccessQuery,
|
selectUserAccess: sqliteSelectUserAccessQuery,
|
||||||
@@ -243,6 +258,7 @@ var sqliteQueries = queries{
|
|||||||
updateTokenLastAccess: sqliteUpdateTokenLastAccessQuery,
|
updateTokenLastAccess: sqliteUpdateTokenLastAccessQuery,
|
||||||
deleteToken: sqliteDeleteTokenQuery,
|
deleteToken: sqliteDeleteTokenQuery,
|
||||||
deleteProvisionedToken: sqliteDeleteProvisionedTokenQuery,
|
deleteProvisionedToken: sqliteDeleteProvisionedTokenQuery,
|
||||||
|
deleteAllProvisionedTokens: sqliteDeleteAllProvisionedTokensQuery,
|
||||||
deleteAllToken: sqliteDeleteAllTokenQuery,
|
deleteAllToken: sqliteDeleteAllTokenQuery,
|
||||||
deleteExpiredTokens: sqliteDeleteExpiredTokensQuery,
|
deleteExpiredTokens: sqliteDeleteExpiredTokensQuery,
|
||||||
deleteExcessTokens: sqliteDeleteExcessTokensQuery,
|
deleteExcessTokens: sqliteDeleteExcessTokensQuery,
|
||||||
|
|||||||
@@ -1441,6 +1441,54 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestManager_RemoveProvisionedOnEmptyConfig(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, newManager newManagerFunc) {
|
||||||
|
// Start with provisioned users, access, and tokens
|
||||||
|
conf := &Config{
|
||||||
|
DefaultAccess: PermissionReadWrite,
|
||||||
|
ProvisionEnabled: true,
|
||||||
|
BcryptCost: bcrypt.MinCost,
|
||||||
|
Users: []*User{
|
||||||
|
{Name: "provuser", Hash: "$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", Role: RoleUser},
|
||||||
|
},
|
||||||
|
Access: map[string][]*Grant{
|
||||||
|
"provuser": {
|
||||||
|
{TopicPattern: "stats", Permission: PermissionReadWrite},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Tokens: map[string][]*Token{
|
||||||
|
"provuser": {
|
||||||
|
{Value: "tk_op56p8lz5bf3cxkz9je99v9oc37lo", Label: "Provisioned token"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
a := newTestManagerFromConfig(t, newManager, conf)
|
||||||
|
|
||||||
|
// Also add a manual (non-provisioned) user
|
||||||
|
require.Nil(t, a.AddUser("manualuser", "manual", RoleUser, false))
|
||||||
|
|
||||||
|
// Verify initial state
|
||||||
|
users, err := a.Users()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, users, 3) // provuser, manualuser, everyone
|
||||||
|
|
||||||
|
// Re-open with empty provisioning config (simulates config change)
|
||||||
|
require.Nil(t, a.Close())
|
||||||
|
conf.Users = nil
|
||||||
|
conf.Access = nil
|
||||||
|
conf.Tokens = nil
|
||||||
|
a = newTestManagerFromConfig(t, newManager, conf)
|
||||||
|
|
||||||
|
// Provisioned user should be removed, manual user should remain
|
||||||
|
users, err = a.Users()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, users, 2)
|
||||||
|
require.Equal(t, "manualuser", users[0].Name)
|
||||||
|
require.False(t, users[0].Provisioned)
|
||||||
|
require.Equal(t, "*", users[1].Name) // everyone
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestToFromSQLWildcard(t *testing.T) {
|
func TestToFromSQLWildcard(t *testing.T) {
|
||||||
require.Equal(t, "up%", toSQLWildcard("up*"))
|
require.Equal(t, "up%", toSQLWildcard("up*"))
|
||||||
require.Equal(t, "up\\_%", toSQLWildcard("up_*"))
|
require.Equal(t, "up\\_%", toSQLWildcard("up_*"))
|
||||||
|
|||||||
@@ -283,6 +283,7 @@ type queries struct {
|
|||||||
selectUserByToken string
|
selectUserByToken string
|
||||||
selectUserByStripeCustomerID string
|
selectUserByStripeCustomerID string
|
||||||
selectUsernames string
|
selectUsernames string
|
||||||
|
selectUsers string
|
||||||
selectUserCount string
|
selectUserCount string
|
||||||
selectUserIDFromUsername string
|
selectUserIDFromUsername string
|
||||||
insertUser string
|
insertUser string
|
||||||
@@ -297,6 +298,7 @@ type queries struct {
|
|||||||
deleteUser string
|
deleteUser string
|
||||||
deleteUserTier string
|
deleteUserTier string
|
||||||
deleteUsersMarked string
|
deleteUsersMarked string
|
||||||
|
deleteUsersProvisioned string
|
||||||
|
|
||||||
// Access queries
|
// Access queries
|
||||||
selectTopicPerms string
|
selectTopicPerms string
|
||||||
@@ -323,6 +325,7 @@ type queries struct {
|
|||||||
updateTokenLastAccess string
|
updateTokenLastAccess string
|
||||||
deleteToken string
|
deleteToken string
|
||||||
deleteProvisionedToken string
|
deleteProvisionedToken string
|
||||||
|
deleteAllProvisionedTokens string
|
||||||
deleteAllToken string
|
deleteAllToken string
|
||||||
deleteExpiredTokens string
|
deleteExpiredTokens string
|
||||||
deleteExcessTokens string
|
deleteExcessTokens string
|
||||||
|
|||||||
Reference in New Issue
Block a user