diff --git a/tools/loadtest/main.go b/tools/loadtest/main.go index 1cfbdd56..be8a65fc 100644 --- a/tools/loadtest/main.go +++ b/tools/loadtest/main.go @@ -41,6 +41,8 @@ import ( var ( baseURL string + username string + password string rps float64 scale float64 numTopics int @@ -62,6 +64,8 @@ var ( func main() { flag.StringVar(&baseURL, "url", "https://staging.ntfy.sh", "Base URL of ntfy server") + flag.StringVar(&username, "user", "", "Username for authentication") + flag.StringVar(&password, "pass", "", "Password for authentication") flag.Float64Var(&rps, "rps", 71, "Target requests per second (default: prod average)") flag.Float64Var(&scale, "scale", 1.0, "Scale factor for all load (0.5 = half load, 2.0 = double)") flag.IntVar(&numTopics, "topics", 500, "Number of unique topics to use") @@ -222,6 +226,12 @@ func truncateErr(err error) string { return s } +func setAuth(req *http.Request) { + if username != "" && password != "" { + req.SetBasicAuth(username, password) + } +} + func generateTopics(n int) []string { topics := make([]string, n) for i := 0; i < n; i++ { @@ -289,6 +299,7 @@ func doPublishPost(ctx context.Context, client *http.Client, topics []string) { trackError("publish_post_req", err) return } + setAuth(req) // Some messages have titles/priorities like real traffic if mrand.Float32() < 0.3 { req.Header.Set("X-Title", "Load Test") @@ -317,6 +328,7 @@ func doPublishPut(ctx context.Context, client *http.Client, topics []string) { trackError("publish_put_req", err) return } + setAuth(req) resp, err := client.Do(req) totalRequests.Add(1) if err != nil { @@ -352,6 +364,7 @@ func doGet(ctx context.Context, client *http.Client, url string) { trackError("get_req", err) return } + setAuth(req) resp, err := client.Do(req) totalRequests.Add(1) if err != nil { @@ -379,6 +392,7 @@ func streamSubscription(ctx context.Context, client *http.Client, topics []strin time.Sleep(time.Second) continue } + setAuth(req) activeStreams.Add(1) resp, err := client.Do(req) if err != nil { @@ -432,8 +446,15 @@ func wsSubscription(ctx context.Context, topics []string) { dialer := websocket.Dialer{ HandshakeTimeout: 10 * time.Second, } + var wsHeader http.Header + if username != "" && password != "" { + wsHeader = http.Header{} + req, _ := http.NewRequest("GET", url, nil) + req.SetBasicAuth(username, password) + wsHeader.Set("Authorization", req.Header.Get("Authorization")) + } activeStreams.Add(1) - conn, _, err := dialer.DialContext(ctx, url, nil) + conn, _, err := dialer.DialContext(ctx, url, wsHeader) if err != nil { activeStreams.Add(-1) if ctx.Err() == nil {