mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-18 21:30:44 +01:00
Compare commits
178 Commits
1364-copy-
...
attachment
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b11bc7468 | ||
|
|
d9efe50848 | ||
|
|
2ad78edca1 | ||
|
|
86015e100c | ||
|
|
458fbad770 | ||
|
|
9b1a32ec56 | ||
|
|
3d9ce69042 | ||
|
|
59ce581ba2 | ||
|
|
df82fdf44c | ||
|
|
3a37ea32f7 | ||
|
|
790ba243c7 | ||
|
|
4487299a80 | ||
|
|
6b38acb23a | ||
|
|
f5c255c53c | ||
|
|
fd0a49244e | ||
|
|
4699ed3ffd | ||
|
|
1afb99db67 | ||
|
|
66208e6f88 | ||
|
|
ce24594c32 | ||
|
|
888850d8bc | ||
|
|
be09acd411 | ||
|
|
bf19a5be2d | ||
|
|
b4ec6fa8df | ||
|
|
d517ce4a2a | ||
|
|
fd8f356d1f | ||
|
|
3296d158c5 | ||
|
|
45f045a5a4 | ||
|
|
f7b6e9bbe3 | ||
|
|
22868f4742 | ||
|
|
3801a28958 | ||
|
|
2bf8f6271b | ||
|
|
13be9747e4 | ||
|
|
26dd017401 | ||
|
|
d00cd64220 | ||
|
|
fab08e862d | ||
|
|
143935b917 | ||
|
|
a82ede8a14 | ||
|
|
8a34dfe3f8 | ||
|
|
270fec51a6 | ||
|
|
9eaadd74cf | ||
|
|
1f483dcbd3 | ||
|
|
85bdfc61ce | ||
|
|
ac65df1e83 | ||
|
|
ab33ac7ae5 | ||
|
|
f1865749d7 | ||
|
|
997e20fa3f | ||
|
|
3402510b47 | ||
|
|
19d1618bb8 | ||
|
|
612afb1435 | ||
|
|
2b36ad9eb9 | ||
|
|
bcd07115c2 | ||
|
|
109271a930 | ||
|
|
fcf95dc9b8 | ||
|
|
79c3ab9ecc | ||
|
|
d51465fb6a | ||
|
|
0b189f65ff | ||
|
|
0d4b1b00e6 | ||
|
|
28c3fd5cbe | ||
|
|
62bb335675 | ||
|
|
70fb2732af | ||
|
|
8e91e028a0 | ||
|
|
6d22f568f9 | ||
|
|
59e6c16633 | ||
|
|
2e0b934bc2 | ||
|
|
4f4a093f8d | ||
|
|
5610b7c56d | ||
|
|
d4038f566c | ||
|
|
bff2b47eb6 | ||
|
|
33b19814c7 | ||
|
|
fb26e7ef3a | ||
|
|
66449bd19b | ||
|
|
bedbb121e4 | ||
|
|
c4b8cfa756 | ||
|
|
c864a9baeb | ||
|
|
8afeb813d9 | ||
|
|
ea4739f79b | ||
|
|
31f0234098 | ||
|
|
0e6b467cf0 | ||
|
|
11c79a6369 | ||
|
|
10a6939d8e | ||
|
|
9736973286 | ||
|
|
941c43c10b | ||
|
|
af76aa011d | ||
|
|
039566bcaf | ||
|
|
544ce112b5 | ||
|
|
c19377109e | ||
|
|
ccbd02331c | ||
|
|
542aa403d2 | ||
|
|
ebb48e217d | ||
|
|
7710ace184 | ||
|
|
b937b44f2d | ||
|
|
e618cf1a39 | ||
|
|
e9cf2b5523 | ||
|
|
c49a8179cf | ||
|
|
a1cca7972d | ||
|
|
da1c7b1949 | ||
|
|
5c26e70fe7 | ||
|
|
c66fa92341 | ||
|
|
a7d5a9c5d8 | ||
|
|
391cd2c920 | ||
|
|
9eec72adcc | ||
|
|
28a436c0d2 | ||
|
|
b02366b42b | ||
|
|
90d0eca14d | ||
|
|
0d375d3a08 | ||
|
|
811c7ae25a | ||
|
|
850a9d4cc4 | ||
|
|
43280fbc0a | ||
|
|
35a54407a8 | ||
|
|
f726cc768e | ||
|
|
5d301e7dce | ||
|
|
8b12bdeb3a | ||
|
|
390cff0604 | ||
|
|
6375c2ce60 | ||
|
|
459c80ef9b | ||
|
|
b1eb90addc | ||
|
|
4b6979aa89 | ||
|
|
c76e39bb0e | ||
|
|
b82e1c3915 | ||
|
|
07e60ba041 | ||
|
|
4e22e7f4c8 | ||
|
|
cf3ae187ce | ||
|
|
d19b825192 | ||
|
|
2e499389fc | ||
|
|
7c69a76345 | ||
|
|
a28d8e7924 | ||
|
|
13a3062a7f | ||
|
|
eb6e1ac44a | ||
|
|
a4c836b531 | ||
|
|
e818b063f7 | ||
|
|
039d555689 | ||
|
|
209d5a4c62 | ||
|
|
bf265449ac | ||
|
|
4cbd80c68e | ||
|
|
305e3fc9af | ||
|
|
9e4a48b058 | ||
|
|
93cd7f99f8 | ||
|
|
28e85df36e | ||
|
|
939b3d1117 | ||
|
|
9cc9891f49 | ||
|
|
0d1f3444f2 | ||
|
|
2716ede6e1 | ||
|
|
2bc7b5217b | ||
|
|
046c0e8c79 | ||
|
|
652b2097ad | ||
|
|
ae5e1fe8d8 | ||
|
|
e3a402ed95 | ||
|
|
1abc1005d0 | ||
|
|
909c3fe17b | ||
|
|
07c3e280bf | ||
|
|
b567b4e904 | ||
|
|
60fa50f0d5 | ||
|
|
ceda5ec3d8 | ||
|
|
3d72845c81 | ||
|
|
82e15d84bd | ||
|
|
4e5f95ba0c | ||
|
|
869b972a50 | ||
|
|
bdd20197b3 | ||
|
|
a8dcecdb6d | ||
|
|
5331437664 | ||
|
|
e432bf2886 | ||
|
|
0edad84d86 | ||
|
|
ddf728acd1 | ||
|
|
b1d3671dbb | ||
|
|
3e6b46ec0c | ||
|
|
b16d381626 | ||
|
|
3bd1a1ea03 | ||
|
|
7adb37b94b | ||
|
|
bc08819525 | ||
|
|
a03a37feb1 | ||
|
|
4cd556f5aa | ||
|
|
90aeb811ff | ||
|
|
c6ab380ea4 | ||
|
|
7860f2142c | ||
|
|
18d5d31bd2 | ||
|
|
cfdc364e3f | ||
|
|
763215ecfa | ||
|
|
49991d5aa7 |
16
.github/workflows/release.yaml
vendored
16
.github/workflows/release.yaml
vendored
@@ -6,6 +6,22 @@ on:
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17
|
||||
env:
|
||||
POSTGRES_USER: ntfy
|
||||
POSTGRES_PASSWORD: ntfy
|
||||
POSTGRES_DB: ntfy_test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U ntfy"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
env:
|
||||
NTFY_TEST_DATABASE_URL: "postgres://ntfy:ntfy@localhost:5432/ntfy_test?sslmode=disable"
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
20
.github/workflows/test.yaml
vendored
20
.github/workflows/test.yaml
vendored
@@ -3,6 +3,22 @@ on: [ push, pull_request ]
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17
|
||||
env:
|
||||
POSTGRES_USER: ntfy
|
||||
POSTGRES_PASSWORD: ntfy
|
||||
POSTGRES_DB: ntfy_test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U ntfy"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
env:
|
||||
NTFY_TEST_DATABASE_URL: "postgres://ntfy:ntfy@localhost:5432/ntfy_test?sslmode=disable"
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
@@ -23,8 +39,6 @@ jobs:
|
||||
- name: Build web app (required for tests)
|
||||
run: make web
|
||||
- name: Run tests, formatting, vetting and linting
|
||||
run: make check
|
||||
run: make checkv
|
||||
- name: Run coverage
|
||||
run: make coverage
|
||||
- name: Upload coverage to codecov.io
|
||||
run: make coverage-upload
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -7,6 +7,9 @@ build/
|
||||
server/docs/
|
||||
server/site/
|
||||
tools/fbsend/fbsend
|
||||
tools/pgimport/pgimport
|
||||
tools/loadtest/loadtest
|
||||
tools/s3cli/s3cli
|
||||
playground/
|
||||
secrets/
|
||||
*.iml
|
||||
|
||||
@@ -40,6 +40,7 @@ ADD ./log ./log
|
||||
ADD ./server ./server
|
||||
ADD ./user ./user
|
||||
ADD ./util ./util
|
||||
ADD ./payments ./payments
|
||||
RUN make VERSION=$VERSION COMMIT=$COMMIT cli-linux-server
|
||||
|
||||
FROM alpine
|
||||
|
||||
25
Makefile
25
Makefile
@@ -1,4 +1,5 @@
|
||||
MAKEFLAGS := --jobs=1
|
||||
NPM := npm
|
||||
PYTHON := python3
|
||||
PIP := pip3
|
||||
VERSION := $(shell git describe --tag)
|
||||
@@ -137,7 +138,7 @@ web: web-deps web-build
|
||||
|
||||
web-build:
|
||||
cd web \
|
||||
&& npm run build \
|
||||
&& $(NPM) run build \
|
||||
&& mv build/index.html build/app.html \
|
||||
&& rm -rf ../server/site \
|
||||
&& mv build ../server/site \
|
||||
@@ -145,20 +146,20 @@ web-build:
|
||||
../server/site/config.js
|
||||
|
||||
web-deps:
|
||||
cd web && npm install
|
||||
cd web && $(NPM) install
|
||||
# If this fails for .svg files, optimize them with svgo
|
||||
|
||||
web-deps-update:
|
||||
cd web && npm update
|
||||
cd web && $(NPM) update
|
||||
|
||||
web-fmt:
|
||||
cd web && npm run format
|
||||
cd web && $(NPM) run format
|
||||
|
||||
web-fmt-check:
|
||||
cd web && npm run format:check
|
||||
cd web && $(NPM) run format:check
|
||||
|
||||
web-lint:
|
||||
cd web && npm run lint
|
||||
cd web && $(NPM) run lint
|
||||
|
||||
# Main server/client build
|
||||
|
||||
@@ -264,23 +265,25 @@ cli-build-results:
|
||||
|
||||
check: test web-fmt-check fmt-check vet web-lint lint staticcheck
|
||||
|
||||
checkv: testv web-fmt-check fmt-check vet web-lint lint staticcheck
|
||||
|
||||
test: .PHONY
|
||||
go test $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||
go test $(shell go list -f '{{if .TestGoFiles}}{{.ImportPath}}{{end}}' ./... | grep -vE 'ntfy/v2/(test|examples|tools)')
|
||||
|
||||
testv: .PHONY
|
||||
go test -v $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||
go test -v $(shell go list -f '{{if .TestGoFiles}}{{.ImportPath}}{{end}}' ./... | grep -vE 'ntfy/v2/(test|examples|tools)')
|
||||
|
||||
race: .PHONY
|
||||
go test -v -race $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||
go test -v -race $(shell go list -f '{{if .TestGoFiles}}{{.ImportPath}}{{end}}' ./... | grep -vE 'ntfy/v2/(test|examples|tools)')
|
||||
|
||||
coverage:
|
||||
mkdir -p build/coverage
|
||||
go test -v -race -coverprofile=build/coverage/coverage.txt -covermode=atomic $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||
go test -v -race -coverprofile=build/coverage/coverage.txt -covermode=atomic $(shell go list -f '{{if .TestGoFiles}}{{.ImportPath}}{{end}}' ./... | grep -vE 'ntfy/v2/(test|examples|tools|web)')
|
||||
go tool cover -func build/coverage/coverage.txt
|
||||
|
||||
coverage-html:
|
||||
mkdir -p build/coverage
|
||||
go test -race -coverprofile=build/coverage/coverage.txt -covermode=atomic $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||
go test -race -coverprofile=build/coverage/coverage.txt -covermode=atomic $(shell go list -f '{{if .TestGoFiles}}{{.ImportPath}}{{end}}' ./... | grep -vE 'ntfy/v2/(test|examples|tools)')
|
||||
go tool cover -html build/coverage/coverage.txt
|
||||
|
||||
coverage-upload:
|
||||
|
||||
25
attachment/store.go
Normal file
25
attachment/store.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package attachment
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// Store is an interface for storing and retrieving attachment files
|
||||
type Store interface {
|
||||
Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error)
|
||||
Read(id string) (io.ReadCloser, int64, error)
|
||||
Remove(ids ...string) error
|
||||
Size() int64
|
||||
Remaining() int64
|
||||
}
|
||||
|
||||
var (
|
||||
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
|
||||
errInvalidFileID = errors.New("invalid file ID")
|
||||
)
|
||||
@@ -1,31 +1,29 @@
|
||||
package server
|
||||
package attachment
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sync"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
var (
|
||||
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, messageIDLength))
|
||||
errInvalidFileID = errors.New("invalid file ID")
|
||||
errFileExists = errors.New("file exists")
|
||||
)
|
||||
const tagFileStore = "file_store"
|
||||
|
||||
type fileCache struct {
|
||||
var errFileExists = errors.New("file exists")
|
||||
|
||||
type fileStore struct {
|
||||
dir string
|
||||
totalSizeCurrent int64
|
||||
totalSizeLimit int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) {
|
||||
// NewFileStore creates a new file-system backed attachment store
|
||||
func NewFileStore(dir string, totalSizeLimit int64) (Store, error) {
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -33,18 +31,18 @@ func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fileCache{
|
||||
return &fileStore{
|
||||
dir: dir,
|
||||
totalSizeCurrent: size,
|
||||
totalSizeLimit: totalSizeLimit,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
|
||||
func (c *fileStore) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return 0, errInvalidFileID
|
||||
}
|
||||
log.Tag(tagFileCache).Field("message_id", id).Debug("Writing attachment")
|
||||
log.Tag(tagFileStore).Field("message_id", id).Debug("Writing attachment")
|
||||
file := filepath.Join(c.dir, id)
|
||||
if _, err := os.Stat(file); err == nil {
|
||||
return 0, errFileExists
|
||||
@@ -67,20 +65,35 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (in
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent += size
|
||||
mset(metricAttachmentsTotalSize, c.totalSizeCurrent)
|
||||
c.mu.Unlock()
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func (c *fileCache) Remove(ids ...string) error {
|
||||
func (c *fileStore) Read(id string) (io.ReadCloser, int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return nil, 0, errInvalidFileID
|
||||
}
|
||||
file := filepath.Join(c.dir, id)
|
||||
stat, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return f, stat.Size(), nil
|
||||
}
|
||||
|
||||
func (c *fileStore) Remove(ids ...string) error {
|
||||
for _, id := range ids {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return errInvalidFileID
|
||||
}
|
||||
log.Tag(tagFileCache).Field("message_id", id).Debug("Deleting attachment")
|
||||
log.Tag(tagFileStore).Field("message_id", id).Debug("Deleting attachment")
|
||||
file := filepath.Join(c.dir, id)
|
||||
if err := os.Remove(file); err != nil {
|
||||
log.Tag(tagFileCache).Field("message_id", id).Err(err).Debug("Error deleting attachment")
|
||||
log.Tag(tagFileStore).Field("message_id", id).Err(err).Debug("Error deleting attachment")
|
||||
}
|
||||
}
|
||||
size, err := dirSize(c.dir)
|
||||
@@ -90,17 +103,16 @@ func (c *fileCache) Remove(ids ...string) error {
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent = size
|
||||
c.mu.Unlock()
|
||||
mset(metricAttachmentsTotalSize, size)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fileCache) Size() int64 {
|
||||
func (c *fileStore) Size() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.totalSizeCurrent
|
||||
}
|
||||
|
||||
func (c *fileCache) Remaining() int64 {
|
||||
func (c *fileStore) Remaining() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
remaining := c.totalSizeLimit - c.totalSizeCurrent
|
||||
99
attachment/store_file_test.go
Normal file
99
attachment/store_file_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package attachment
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
var (
|
||||
oneKilobyteArray = make([]byte, 1024)
|
||||
)
|
||||
|
||||
func TestFileStore_Write_Success(t *testing.T) {
|
||||
dir, s := newTestFileStore(t)
|
||||
size, err := s.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
|
||||
require.Equal(t, int64(11), s.Size())
|
||||
require.Equal(t, int64(10229), s.Remaining())
|
||||
}
|
||||
|
||||
func TestFileStore_Write_Read_Success(t *testing.T) {
|
||||
_, s := newTestFileStore(t)
|
||||
size, err := s.Write("abcdefghijkl", strings.NewReader("hello world"))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
|
||||
reader, readSize, err := s.Read("abcdefghijkl")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), readSize)
|
||||
defer reader.Close()
|
||||
data, err := io.ReadAll(reader)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "hello world", string(data))
|
||||
}
|
||||
|
||||
func TestFileStore_Write_Remove_Success(t *testing.T) {
|
||||
dir, s := newTestFileStore(t) // max = 10k (10240), each = 1k (1024)
|
||||
for i := 0; i < 10; i++ { // 10x999 = 9990
|
||||
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(999), size)
|
||||
}
|
||||
require.Equal(t, int64(9990), s.Size())
|
||||
require.Equal(t, int64(250), s.Remaining())
|
||||
require.FileExists(t, dir+"/abcdefghijk1")
|
||||
require.FileExists(t, dir+"/abcdefghijk5")
|
||||
|
||||
require.Nil(t, s.Remove("abcdefghijk1", "abcdefghijk5"))
|
||||
require.NoFileExists(t, dir+"/abcdefghijk1")
|
||||
require.NoFileExists(t, dir+"/abcdefghijk5")
|
||||
require.Equal(t, int64(7992), s.Size())
|
||||
require.Equal(t, int64(2248), s.Remaining())
|
||||
}
|
||||
|
||||
func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) {
|
||||
dir, s := newTestFileStore(t)
|
||||
for i := 0; i < 10; i++ {
|
||||
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(1024), size)
|
||||
}
|
||||
_, err := s.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
require.NoFileExists(t, dir+"/abcdefghijkX")
|
||||
}
|
||||
|
||||
func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) {
|
||||
dir, s := newTestFileStore(t)
|
||||
_, err := s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
require.NoFileExists(t, dir+"/abcdefghijkl")
|
||||
}
|
||||
|
||||
func TestFileStore_Read_NotFound(t *testing.T) {
|
||||
_, s := newTestFileStore(t)
|
||||
_, _, err := s.Read("abcdefghijkl")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func newTestFileStore(t *testing.T) (dir string, store Store) {
|
||||
dir = t.TempDir()
|
||||
store, err := NewFileStore(dir, 10*1024)
|
||||
require.Nil(t, err)
|
||||
return dir, store
|
||||
}
|
||||
|
||||
func readFile(t *testing.T, f string) string {
|
||||
b, err := os.ReadFile(f)
|
||||
require.Nil(t, err)
|
||||
return string(b)
|
||||
}
|
||||
150
attachment/store_s3.go
Normal file
150
attachment/store_s3.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package attachment
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/s3"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
tagS3Store = "s3_store"
|
||||
)
|
||||
|
||||
type s3Store struct {
|
||||
client *s3.Client
|
||||
totalSizeCurrent int64
|
||||
totalSizeLimit int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewS3Store creates a new S3-backed attachment store. The s3URL must be in the format:
|
||||
//
|
||||
// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
|
||||
func NewS3Store(s3URL string, totalSizeLimit int64) (Store, error) {
|
||||
cfg, err := s3.ParseURL(s3URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
store := &s3Store{
|
||||
client: s3.New(cfg),
|
||||
totalSizeLimit: totalSizeLimit,
|
||||
}
|
||||
if totalSizeLimit > 0 {
|
||||
size, err := store.computeSize()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s3 store: failed to compute initial size: %w", err)
|
||||
}
|
||||
store.totalSizeCurrent = size
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return 0, errInvalidFileID
|
||||
}
|
||||
log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3")
|
||||
|
||||
// Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked
|
||||
// uploads, so no temp file or Content-Length is needed.
|
||||
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
|
||||
pr, pw := io.Pipe()
|
||||
lw := util.NewLimitWriter(pw, limiters...)
|
||||
var size int64
|
||||
var copyErr error
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
size, copyErr = io.Copy(lw, in)
|
||||
if copyErr != nil {
|
||||
pw.CloseWithError(copyErr)
|
||||
} else {
|
||||
pw.Close()
|
||||
}
|
||||
}()
|
||||
putErr := c.client.PutObject(context.Background(), id, pr)
|
||||
pr.Close()
|
||||
<-done
|
||||
if copyErr != nil {
|
||||
return 0, copyErr
|
||||
}
|
||||
if putErr != nil {
|
||||
return 0, putErr
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent += size
|
||||
c.mu.Unlock()
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func (c *s3Store) Read(id string) (io.ReadCloser, int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return nil, 0, errInvalidFileID
|
||||
}
|
||||
return c.client.GetObject(context.Background(), id)
|
||||
}
|
||||
|
||||
func (c *s3Store) Remove(ids ...string) error {
|
||||
for _, id := range ids {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return errInvalidFileID
|
||||
}
|
||||
}
|
||||
// S3 DeleteObjects supports up to 1000 keys per call
|
||||
for i := 0; i < len(ids); i += 1000 {
|
||||
end := i + 1000
|
||||
if end > len(ids) {
|
||||
end = len(ids)
|
||||
}
|
||||
batch := ids[i:end]
|
||||
for _, id := range batch {
|
||||
log.Tag(tagS3Store).Field("message_id", id).Debug("Deleting attachment from S3")
|
||||
}
|
||||
if err := c.client.DeleteObjects(context.Background(), batch); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Recalculate totalSizeCurrent via ListObjectsV2 (matches fileStore's dirSize rescan pattern)
|
||||
size, err := c.computeSize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3 store: failed to compute size after remove: %w", err)
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent = size
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *s3Store) Size() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.totalSizeCurrent
|
||||
}
|
||||
|
||||
func (c *s3Store) Remaining() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
remaining := c.totalSizeLimit - c.totalSizeCurrent
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
// computeSize uses ListAllObjects to sum up the total size of all objects with our prefix.
|
||||
func (c *s3Store) computeSize() (int64, error) {
|
||||
objects, err := c.client.ListAllObjects(context.Background())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var totalSize int64
|
||||
for _, obj := range objects {
|
||||
totalSize += obj.Size
|
||||
}
|
||||
return totalSize, nil
|
||||
}
|
||||
367
attachment/store_s3_test.go
Normal file
367
attachment/store_s3_test.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package attachment
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/s3"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// --- Integration tests using a mock S3 server ---
|
||||
|
||||
func TestS3Store_WriteReadRemove(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
// Write
|
||||
size, err := store.Write("abcdefghijkl", strings.NewReader("hello world"))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
require.Equal(t, int64(11), store.Size())
|
||||
|
||||
// Read back
|
||||
reader, readSize, err := store.Read("abcdefghijkl")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), readSize)
|
||||
data, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "hello world", string(data))
|
||||
|
||||
// Remove
|
||||
require.Nil(t, store.Remove("abcdefghijkl"))
|
||||
require.Equal(t, int64(0), store.Size())
|
||||
|
||||
// Read after remove should fail
|
||||
_, _, err = store.Read("abcdefghijkl")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestS3Store_WriteNoPrefix(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "", 10*1024)
|
||||
|
||||
size, err := store.Write("abcdefghijkl", strings.NewReader("test"))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(4), size)
|
||||
|
||||
reader, _, err := store.Read("abcdefghijkl")
|
||||
require.Nil(t, err)
|
||||
data, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "test", string(data))
|
||||
}
|
||||
|
||||
func TestS3Store_WriteTotalSizeLimit(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 100)
|
||||
|
||||
// First write fits
|
||||
_, err := store.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(80), store.Size())
|
||||
require.Equal(t, int64(20), store.Remaining())
|
||||
|
||||
// Second write exceeds total limit
|
||||
_, err = store.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
}
|
||||
|
||||
func TestS3Store_WriteFileSizeLimit(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
_, err := store.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
}
|
||||
|
||||
func TestS3Store_WriteRemoveMultiple(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err := store.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Equal(t, int64(500), store.Size())
|
||||
|
||||
require.Nil(t, store.Remove("abcdefghijk1", "abcdefghijk3"))
|
||||
require.Equal(t, int64(300), store.Size())
|
||||
}
|
||||
|
||||
func TestS3Store_ReadNotFound(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
_, _, err := store.Read("abcdefghijkl")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestS3Store_InvalidID(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
_, err := store.Write("bad", strings.NewReader("x"))
|
||||
require.Equal(t, errInvalidFileID, err)
|
||||
|
||||
_, _, err = store.Read("bad")
|
||||
require.Equal(t, errInvalidFileID, err)
|
||||
|
||||
err = store.Remove("bad")
|
||||
require.Equal(t, errInvalidFileID, err)
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) Store {
|
||||
t.Helper()
|
||||
// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
|
||||
host := strings.TrimPrefix(server.URL, "https://")
|
||||
s := &s3Store{
|
||||
client: &s3.Client{
|
||||
AccessKey: "AKID",
|
||||
SecretKey: "SECRET",
|
||||
Region: "us-east-1",
|
||||
Endpoint: host,
|
||||
Bucket: bucket,
|
||||
Prefix: prefix,
|
||||
PathStyle: true,
|
||||
HTTPClient: server.Client(),
|
||||
},
|
||||
totalSizeLimit: totalSizeLimit,
|
||||
}
|
||||
// Compute initial size (should be 0 for fresh mock)
|
||||
size, err := s.computeSize()
|
||||
require.Nil(t, err)
|
||||
s.totalSizeCurrent = size
|
||||
return s
|
||||
}
|
||||
|
||||
// --- Mock S3 server ---
|
||||
//
|
||||
// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and
|
||||
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
|
||||
|
||||
type mockS3Server struct {
|
||||
objects map[string][]byte // full key (bucket/key) -> body
|
||||
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
|
||||
nextID int // counter for generating upload IDs
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockS3Server() *httptest.Server {
|
||||
m := &mockS3Server{
|
||||
objects: make(map[string][]byte),
|
||||
uploads: make(map[string]map[int][]byte),
|
||||
}
|
||||
return httptest.NewTLSServer(m)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Path is /{bucket}[/{key...}]
|
||||
path := strings.TrimPrefix(r.URL.Path, "/")
|
||||
q := r.URL.Query()
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodPut && q.Has("partNumber"):
|
||||
m.handleUploadPart(w, r, path)
|
||||
case r.Method == http.MethodPut:
|
||||
m.handlePut(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("uploads"):
|
||||
m.handleInitiateMultipart(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("uploadId"):
|
||||
m.handleCompleteMultipart(w, r, path)
|
||||
case r.Method == http.MethodDelete && q.Has("uploadId"):
|
||||
m.handleAbortMultipart(w, r, path)
|
||||
case r.Method == http.MethodGet && q.Get("list-type") == "2":
|
||||
m.handleList(w, r, path)
|
||||
case r.Method == http.MethodGet:
|
||||
m.handleGet(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("delete"):
|
||||
m.handleDelete(w, r, path)
|
||||
default:
|
||||
http.Error(w, "not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.objects[path] = body
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
m.mu.Lock()
|
||||
m.nextID++
|
||||
uploadID := fmt.Sprintf("upload-%d", m.nextID)
|
||||
m.uploads[uploadID] = make(map[int][]byte)
|
||||
m.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
var partNumber int
|
||||
fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
parts, ok := m.uploads[uploadID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
http.Error(w, "NoSuchUpload", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
parts[partNumber] = body
|
||||
m.mu.Unlock()
|
||||
|
||||
etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
|
||||
w.Header().Set("ETag", etag)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
|
||||
m.mu.Lock()
|
||||
parts, ok := m.uploads[uploadID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
http.Error(w, "NoSuchUpload", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Assemble parts in order
|
||||
var assembled []byte
|
||||
for i := 1; i <= len(parts); i++ {
|
||||
assembled = append(assembled, parts[i]...)
|
||||
}
|
||||
m.objects[path] = assembled
|
||||
delete(m.uploads, uploadID)
|
||||
m.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
m.mu.Lock()
|
||||
delete(m.uploads, uploadID)
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
|
||||
m.mu.RLock()
|
||||
body, ok := m.objects[path]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(body)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) {
|
||||
// bucketPath is just the bucket name
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Objects []struct {
|
||||
Key string `xml:"Key"`
|
||||
} `xml:"Object"`
|
||||
}
|
||||
if err := xml.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
for _, obj := range req.Objects {
|
||||
delete(m.objects, bucketPath+"/"+obj.Key)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></DeleteResult>`))
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) {
|
||||
prefix := r.URL.Query().Get("prefix")
|
||||
m.mu.RLock()
|
||||
var contents []s3ListObject
|
||||
for key, body := range m.objects {
|
||||
// key is "bucket/objectkey", strip bucket prefix
|
||||
objKey := strings.TrimPrefix(key, bucketPath+"/")
|
||||
if objKey == key {
|
||||
continue // different bucket
|
||||
}
|
||||
if prefix == "" || strings.HasPrefix(objKey, prefix) {
|
||||
contents = append(contents, s3ListObject{Key: objKey, Size: int64(len(body))})
|
||||
}
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
resp := s3ListResponse{
|
||||
Contents: contents,
|
||||
IsTruncated: false,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
xml.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
type s3ListResponse struct {
|
||||
XMLName xml.Name `xml:"ListBucketResult"`
|
||||
Contents []s3ListObject `xml:"Contents"`
|
||||
IsTruncated bool `xml:"IsTruncated"`
|
||||
}
|
||||
|
||||
type s3ListObject struct {
|
||||
Key string `xml:"Key"`
|
||||
Size int64 `xml:"Size"`
|
||||
}
|
||||
@@ -2,9 +2,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/test"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -14,9 +11,14 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/test"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
func TestCLI_Publish_Subscribe_Poll_Real_Server(t *testing.T) {
|
||||
t.Skip("temporarily disabled") // FIXME
|
||||
testMessage := util.RandomString(10)
|
||||
app, _, _, _ := newTestApp()
|
||||
require.Nil(t, app.Run([]string{"ntfy", "publish", "ntfytest", "ntfy unit test " + testMessage}))
|
||||
|
||||
50
cmd/serve.go
50
cmd/serve.go
@@ -39,6 +39,8 @@ var flagsServe = append(
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "key-file", Aliases: []string{"key_file", "K"}, EnvVars: []string{"NTFY_KEY_FILE"}, Usage: "private key file, if listen-https is set"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cert-file", Aliases: []string{"cert_file", "E"}, EnvVars: []string{"NTFY_CERT_FILE"}, Usage: "certificate file, if listen-https is set"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"firebase_key_file", "F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores (e.g. postgres://user:pass@host:5432/ntfy)"}),
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "database-replica-urls", Aliases: []string{"database_replica_urls"}, EnvVars: []string{"NTFY_DATABASE_REPLICA_URLS"}, Usage: "PostgreSQL read replica connection strings for offloading read queries"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file", "C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-duration", Aliases: []string{"cache_duration", "b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: util.FormatDuration(server.DefaultCacheDuration), Usage: "buffer messages for this time to allow `since` requests"}),
|
||||
altsrc.NewIntFlag(&cli.IntFlag{Name: "cache-batch-size", Aliases: []string{"cache_batch_size"}, EnvVars: []string{"NTFY_BATCH_SIZE"}, Usage: "max size of messages to batch together when writing to message cache (if zero, writes are synchronous)"}),
|
||||
@@ -51,6 +53,7 @@ var flagsServe = append(
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-access", Aliases: []string{"auth_access"}, EnvVars: []string{"NTFY_AUTH_ACCESS"}, Usage: "pre-provisioned declarative access control entries"}),
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-tokens", Aliases: []string{"auth_tokens"}, EnvVars: []string{"NTFY_AUTH_TOKENS"}, Usage: "pre-provisioned declarative access tokens"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-s3-url", Aliases: []string{"attachment_s3_url"}, EnvVars: []string{"NTFY_ATTACHMENT_S3_URL"}, Usage: "S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION)"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-total-size-limit", Aliases: []string{"attachment_total_size_limit", "A"}, EnvVars: []string{"NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentTotalSizeLimit), Usage: "limit of the on-disk attachment cache"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-file-size-limit", Aliases: []string{"attachment_file_size_limit", "Y"}, EnvVars: []string{"NTFY_ATTACHMENT_FILE_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentFileSizeLimit), Usage: "per-file attachment size limit (e.g. 300k, 2M, 100M)"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-expiry-duration", Aliases: []string{"attachment_expiry_duration", "X"}, EnvVars: []string{"NTFY_ATTACHMENT_EXPIRY_DURATION"}, Value: util.FormatDuration(server.DefaultAttachmentExpiryDuration), Usage: "duration after which uploaded attachments will be deleted (e.g. 3h, 20h)"}),
|
||||
@@ -143,6 +146,8 @@ func execServe(c *cli.Context) error {
|
||||
keyFile := c.String("key-file")
|
||||
certFile := c.String("cert-file")
|
||||
firebaseKeyFile := c.String("firebase-key-file")
|
||||
databaseURL := c.String("database-url")
|
||||
databaseReplicaURLs := c.StringSlice("database-replica-urls")
|
||||
webPushPrivateKey := c.String("web-push-private-key")
|
||||
webPushPublicKey := c.String("web-push-public-key")
|
||||
webPushFile := c.String("web-push-file")
|
||||
@@ -162,6 +167,7 @@ func execServe(c *cli.Context) error {
|
||||
authAccessRaw := c.StringSlice("auth-access")
|
||||
authTokensRaw := c.StringSlice("auth-tokens")
|
||||
attachmentCacheDir := c.String("attachment-cache-dir")
|
||||
attachmentS3URL := c.String("attachment-s3-url")
|
||||
attachmentTotalSizeLimitStr := c.String("attachment-total-size-limit")
|
||||
attachmentFileSizeLimitStr := c.String("attachment-file-size-limit")
|
||||
attachmentExpiryDurationStr := c.String("attachment-expiry-duration")
|
||||
@@ -280,12 +286,18 @@ func execServe(c *cli.Context) error {
|
||||
}
|
||||
|
||||
// Check values
|
||||
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
|
||||
if databaseURL != "" && !strings.HasPrefix(databaseURL, "postgres://") && !strings.HasPrefix(databaseURL, "postgresql://") {
|
||||
return errors.New("if database-url is set, it must start with postgres:// or postgresql://")
|
||||
} else if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") {
|
||||
return errors.New("if database-url is set, auth-file, cache-file, and web-push-file must not be set")
|
||||
} else if len(databaseReplicaURLs) > 0 && databaseURL == "" {
|
||||
return errors.New("database-replica-urls can only be used if database-url is also set")
|
||||
} else if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
|
||||
return errors.New("if set, FCM key file must exist")
|
||||
} else if firebaseKeyFile != "" && !server.FirebaseAvailable {
|
||||
return errors.New("cannot set firebase-key-file, support for Firebase is not available (nofirebase)")
|
||||
} else if webPushPublicKey != "" && (webPushPrivateKey == "" || webPushFile == "" || webPushEmailAddress == "" || baseURL == "") {
|
||||
return errors.New("if web push is enabled, web-push-private-key, web-push-public-key, web-push-file, web-push-email-address, and base-url should be set. run 'ntfy webpush keys' to generate keys")
|
||||
} else if webPushPublicKey != "" && (webPushPrivateKey == "" || (webPushFile == "" && databaseURL == "") || webPushEmailAddress == "" || baseURL == "") {
|
||||
return errors.New("if web push is enabled, web-push-private-key, web-push-public-key, web-push-file (or database-url), web-push-email-address, and base-url should be set. run 'ntfy webpush keys' to generate keys")
|
||||
} else if keepaliveInterval < 5*time.Second {
|
||||
return errors.New("keepalive interval cannot be lower than five seconds")
|
||||
} else if managerInterval < 5*time.Second {
|
||||
@@ -304,6 +316,10 @@ func execServe(c *cli.Context) error {
|
||||
return errors.New("if smtp-server-listen is set, smtp-server-domain must also be set")
|
||||
} else if attachmentCacheDir != "" && baseURL == "" {
|
||||
return errors.New("if attachment-cache-dir is set, base-url must also be set")
|
||||
} else if attachmentS3URL != "" && baseURL == "" {
|
||||
return errors.New("if attachment-s3-url is set, base-url must also be set")
|
||||
} else if attachmentS3URL != "" && attachmentCacheDir != "" {
|
||||
return errors.New("attachment-cache-dir and attachment-s3-url are mutually exclusive")
|
||||
} else if baseURL != "" {
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
@@ -321,8 +337,8 @@ func execServe(c *cli.Context) error {
|
||||
return errors.New("if upstream-base-url is set, base-url must also be set")
|
||||
} else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL {
|
||||
return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications")
|
||||
} else if authFile == "" && (enableSignup || enableLogin || requireLogin || enableReservations || stripeSecretKey != "") {
|
||||
return errors.New("cannot set enable-signup, enable-login, require-login, enable-reserve-topics, or stripe-secret-key if auth-file is not set")
|
||||
} else if authFile == "" && databaseURL == "" && (enableSignup || enableLogin || requireLogin || enableReservations || stripeSecretKey != "") {
|
||||
return errors.New("cannot set enable-signup, enable-login, require-login, enable-reserve-topics, or stripe-secret-key if auth-file or database-url is not set")
|
||||
} else if enableSignup && !enableLogin {
|
||||
return errors.New("cannot set enable-signup without also setting enable-login")
|
||||
} else if requireLogin && !enableLogin {
|
||||
@@ -331,8 +347,8 @@ func execServe(c *cli.Context) error {
|
||||
return errors.New("cannot set stripe-secret-key or stripe-webhook-key, support for payments is not available in this build (nopayments)")
|
||||
} else if stripeSecretKey != "" && (stripeWebhookKey == "" || baseURL == "") {
|
||||
return errors.New("if stripe-secret-key is set, stripe-webhook-key and base-url must also be set")
|
||||
} else if twilioAccount != "" && (twilioAuthToken == "" || twilioPhoneNumber == "" || twilioVerifyService == "" || baseURL == "" || authFile == "") {
|
||||
return errors.New("if twilio-account is set, twilio-auth-token, twilio-phone-number, twilio-verify-service, base-url, and auth-file must also be set")
|
||||
} else if twilioAccount != "" && (twilioAuthToken == "" || twilioPhoneNumber == "" || twilioVerifyService == "" || baseURL == "" || (authFile == "" && databaseURL == "")) {
|
||||
return errors.New("if twilio-account is set, twilio-auth-token, twilio-phone-number, twilio-verify-service, base-url, and auth-file (or database-url) must also be set")
|
||||
} else if messageSizeLimit > server.DefaultMessageSizeLimit {
|
||||
log.Warn("message-size-limit is greater than 4K, this is not recommended and largely untested, and may lead to issues with some clients")
|
||||
if messageSizeLimit > 5*1024*1024 {
|
||||
@@ -412,6 +428,15 @@ func execServe(c *cli.Context) error {
|
||||
payments.Setup(stripeSecretKey)
|
||||
}
|
||||
|
||||
// Parse Twilio template
|
||||
var twilioCallFormatTemplate *template.Template
|
||||
if twilioCallFormat != "" {
|
||||
twilioCallFormatTemplate, err = template.New("").Parse(twilioCallFormat)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse twilio-call-format template: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add default forbidden topics
|
||||
disallowedTopics = append(disallowedTopics, server.DefaultDisallowedTopics...)
|
||||
|
||||
@@ -438,6 +463,7 @@ func execServe(c *cli.Context) error {
|
||||
conf.AuthAccess = authAccess
|
||||
conf.AuthTokens = authTokens
|
||||
conf.AttachmentCacheDir = attachmentCacheDir
|
||||
conf.AttachmentS3URL = attachmentS3URL
|
||||
conf.AttachmentTotalSizeLimit = attachmentTotalSizeLimit
|
||||
conf.AttachmentFileSizeLimit = attachmentFileSizeLimit
|
||||
conf.AttachmentExpiryDuration = attachmentExpiryDuration
|
||||
@@ -459,13 +485,7 @@ func execServe(c *cli.Context) error {
|
||||
conf.TwilioAuthToken = twilioAuthToken
|
||||
conf.TwilioPhoneNumber = twilioPhoneNumber
|
||||
conf.TwilioVerifyService = twilioVerifyService
|
||||
if twilioCallFormat != "" {
|
||||
tmpl, err := template.New("twiml").Parse(twilioCallFormat)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse twilio-call-format template: %w", err)
|
||||
}
|
||||
conf.TwilioCallFormat = tmpl
|
||||
}
|
||||
conf.TwilioCallFormat = twilioCallFormatTemplate
|
||||
conf.MessageSizeLimit = int(messageSizeLimit)
|
||||
conf.MessageDelayMax = messageDelayLimit
|
||||
conf.TotalTopicLimit = totalTopicLimit
|
||||
@@ -494,6 +514,8 @@ func execServe(c *cli.Context) error {
|
||||
conf.EnableMetrics = enableMetrics
|
||||
conf.MetricsListenHTTP = metricsListenHTTP
|
||||
conf.ProfileListenHTTP = profileListenHTTP
|
||||
conf.DatabaseURL = databaseURL
|
||||
conf.DatabaseReplicaURLs = databaseReplicaURLs
|
||||
conf.WebPushPrivateKey = webPushPrivateKey
|
||||
conf.WebPushPublicKey = webPushPublicKey
|
||||
conf.WebPushFile = webPushFile
|
||||
|
||||
29
cmd/user.go
29
cmd/user.go
@@ -6,13 +6,15 @@ import (
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/v2/server"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
"github.com/urfave/cli/v2/altsrc"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
"heckel.io/ntfy/v2/server"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
@@ -29,6 +31,7 @@ var flagsUser = append(
|
||||
&cli.StringFlag{Name: "config", Aliases: []string{"c"}, EnvVars: []string{"NTFY_CONFIG_FILE"}, Value: server.DefaultConfigFile, DefaultText: server.DefaultConfigFile, Usage: "config file"},
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file", "H"}, EnvVars: []string{"NTFY_AUTH_FILE"}, Usage: "auth database file used for access control"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are found"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores"}),
|
||||
)
|
||||
|
||||
var cmdUser = &cli.Command{
|
||||
@@ -365,24 +368,30 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
||||
authFile := c.String("auth-file")
|
||||
authStartupQueries := c.String("auth-startup-queries")
|
||||
authDefaultAccess := c.String("auth-default-access")
|
||||
if authFile == "" {
|
||||
return nil, errors.New("option auth-file not set; auth is unconfigured for this server")
|
||||
} else if !util.FileExists(authFile) {
|
||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||
}
|
||||
databaseURL := c.String("database-url")
|
||||
authDefault, err := user.ParsePermission(authDefaultAccess)
|
||||
if err != nil {
|
||||
return nil, errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'")
|
||||
}
|
||||
authConfig := &user.Config{
|
||||
Filename: authFile,
|
||||
StartupQueries: authStartupQueries,
|
||||
DefaultAccess: authDefault,
|
||||
ProvisionEnabled: false, // Hack: Do not re-provision users on manager initialization
|
||||
BcryptCost: user.DefaultUserPasswordBcryptCost,
|
||||
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
||||
}
|
||||
return user.NewManager(authConfig)
|
||||
if databaseURL != "" {
|
||||
host, dbErr := pg.Open(databaseURL)
|
||||
if dbErr != nil {
|
||||
return nil, dbErr
|
||||
}
|
||||
return user.NewPostgresManager(db.New(host, nil), authConfig)
|
||||
} else if authFile != "" {
|
||||
if !util.FileExists(authFile) {
|
||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||
}
|
||||
return user.NewSQLiteManager(authFile, authStartupQueries, authConfig)
|
||||
}
|
||||
return nil, errors.New("option database-url or auth-file not set; auth is unconfigured for this server")
|
||||
}
|
||||
|
||||
func readPasswordAndConfirm(c *cli.Context) (string, error) {
|
||||
|
||||
137
db/db.go
Normal file
137
db/db.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
)
|
||||
|
||||
const (
|
||||
tag = "db"
|
||||
replicaHealthCheckInitialDelay = 5 * time.Second
|
||||
replicaHealthCheckInterval = 30 * time.Second
|
||||
replicaHealthCheckTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// DB wraps a primary *sql.DB and optional read replicas. All standard query/exec methods
|
||||
// delegate to the primary. The ReadOnly() method returns a *sql.DB from a healthy replica
|
||||
// (round-robin), falling back to the primary if no replicas are configured or all are unhealthy.
|
||||
type DB struct {
|
||||
primary *Host
|
||||
replicas []*Host
|
||||
counter atomic.Uint64
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// New creates a new DB that wraps the given primary and optional replica connections.
|
||||
// If replicas is nil or empty, ReadOnly() simply returns the primary.
|
||||
// Replicas start unhealthy and are checked immediately by a background goroutine.
|
||||
func New(primary *Host, replicas []*Host) *DB {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
d := &DB{
|
||||
primary: primary,
|
||||
replicas: replicas,
|
||||
cancel: cancel,
|
||||
}
|
||||
if len(d.replicas) > 0 {
|
||||
go d.healthCheckLoop(ctx)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// Query delegates to the primary database.
|
||||
func (d *DB) Query(query string, args ...any) (*sql.Rows, error) {
|
||||
return d.primary.DB.Query(query, args...)
|
||||
}
|
||||
|
||||
// QueryRow delegates to the primary database.
|
||||
func (d *DB) QueryRow(query string, args ...any) *sql.Row {
|
||||
return d.primary.DB.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
// Exec delegates to the primary database.
|
||||
func (d *DB) Exec(query string, args ...any) (sql.Result, error) {
|
||||
return d.primary.DB.Exec(query, args...)
|
||||
}
|
||||
|
||||
// Begin delegates to the primary database.
|
||||
func (d *DB) Begin() (*sql.Tx, error) {
|
||||
return d.primary.DB.Begin()
|
||||
}
|
||||
|
||||
// Ping delegates to the primary database.
|
||||
func (d *DB) Ping() error {
|
||||
return d.primary.DB.Ping()
|
||||
}
|
||||
|
||||
// Primary returns the underlying primary *sql.DB. This is only intended for
|
||||
// one-time schema setup during store initialization, not for regular queries.
|
||||
func (d *DB) Primary() *sql.DB {
|
||||
return d.primary.DB
|
||||
}
|
||||
|
||||
// ReadOnly returns a *sql.DB suitable for read-only queries. It round-robins across healthy
|
||||
// replicas. If all replicas are unhealthy or none are configured, the primary is returned.
|
||||
func (d *DB) ReadOnly() *sql.DB {
|
||||
if len(d.replicas) == 0 {
|
||||
return d.primary.DB
|
||||
}
|
||||
n := len(d.replicas)
|
||||
start := int(d.counter.Add(1) - 1)
|
||||
for i := 0; i < n; i++ {
|
||||
r := d.replicas[(start+i)%n]
|
||||
if r.healthy.Load() {
|
||||
return r.DB
|
||||
}
|
||||
}
|
||||
return d.primary.DB
|
||||
}
|
||||
|
||||
// Close closes the primary database and all replicas, and stops the health-check goroutine.
|
||||
func (d *DB) Close() error {
|
||||
d.cancel()
|
||||
for _, r := range d.replicas {
|
||||
r.DB.Close()
|
||||
}
|
||||
return d.primary.DB.Close()
|
||||
}
|
||||
|
||||
// healthCheckLoop checks replicas immediately, then periodically on a ticker.
|
||||
func (d *DB) healthCheckLoop(ctx context.Context) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(replicaHealthCheckInitialDelay):
|
||||
d.checkReplicas(ctx)
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(replicaHealthCheckInterval):
|
||||
d.checkReplicas(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkReplicas pings each replica with a timeout and updates its health status.
|
||||
func (d *DB) checkReplicas(ctx context.Context) {
|
||||
for _, r := range d.replicas {
|
||||
wasHealthy := r.healthy.Load()
|
||||
pingCtx, cancel := context.WithTimeout(ctx, replicaHealthCheckTimeout)
|
||||
err := r.DB.PingContext(pingCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
r.healthy.Store(false)
|
||||
log.Tag(tag).Error("Database replica %s is unhealthy: %s", r.Addr, err)
|
||||
} else {
|
||||
r.healthy.Store(true)
|
||||
if !wasHealthy {
|
||||
log.Tag(tag).Info("Database replica %s is healthy", r.Addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
120
db/pg/pg.go
Normal file
120
db/pg/pg.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package pg
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
)
|
||||
|
||||
// Open opens a PostgreSQL connection pool for a primary database. It pings the database
|
||||
// to verify connectivity before returning.
|
||||
func Open(dsn string) (*db.Host, error) {
|
||||
d, err := open(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
if err := d.DB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("database ping failed on %v: %w", d.Addr, err)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// OpenReplica opens a PostgreSQL connection pool for a read replica. Unlike Open, it does
|
||||
// not ping the database, since replicas are health-checked in the background by db.DB.
|
||||
func OpenReplica(dsn string) (*db.Host, error) {
|
||||
return open(dsn)
|
||||
}
|
||||
|
||||
// open opens a PostgreSQL database connection pool from a DSN string. It supports custom
|
||||
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
|
||||
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
|
||||
// the DSN before passing it to the driver.
|
||||
func open(dsn string) (*db.Host, error) {
|
||||
u, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid database URL: %w", err)
|
||||
}
|
||||
switch u.Scheme {
|
||||
case "postgres", "postgresql":
|
||||
// OK
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid database URL scheme %q, must be \"postgres\" or \"postgresql\" (URL: %s)", u.Scheme, censorPassword(u))
|
||||
}
|
||||
q := u.Query()
|
||||
maxOpenConns, err := extractIntParam(q, "pool_max_conns", 10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
maxIdleConns, err := extractIntParam(q, "pool_max_idle_conns", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connMaxLifetime, err := extractDurationParam(q, "pool_conn_max_lifetime", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connMaxIdleTime, err := extractDurationParam(q, "pool_conn_max_idle_time", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
d, err := sql.Open("pgx", u.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.SetMaxOpenConns(maxOpenConns)
|
||||
if maxIdleConns > 0 {
|
||||
d.SetMaxIdleConns(maxIdleConns)
|
||||
}
|
||||
if connMaxLifetime > 0 {
|
||||
d.SetConnMaxLifetime(connMaxLifetime)
|
||||
}
|
||||
if connMaxIdleTime > 0 {
|
||||
d.SetConnMaxIdleTime(connMaxIdleTime)
|
||||
}
|
||||
return &db.Host{
|
||||
Addr: u.Host,
|
||||
DB: d,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
|
||||
s := q.Get(key)
|
||||
if s == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
q.Del(key)
|
||||
v, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// censorPassword returns a string representation of the URL with the password replaced by "*****".
|
||||
func censorPassword(u *url.URL) string {
|
||||
if password, hasPassword := u.User.Password(); hasPassword {
|
||||
return strings.Replace(u.String(), ":"+password+"@", ":*****@", 1)
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
|
||||
s := q.Get(key)
|
||||
if s == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
q.Del(key)
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
53
db/pg/pg_test.go
Normal file
53
db/pg/pg_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package pg
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpen_InvalidScheme(t *testing.T) {
|
||||
_, err := Open("postgresql+psycopg2://user:pass@localhost/db")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), `invalid database URL scheme "postgresql+psycopg2"`)
|
||||
require.Contains(t, err.Error(), "*****")
|
||||
require.NotContains(t, err.Error(), "pass")
|
||||
}
|
||||
|
||||
func TestOpen_InvalidURL(t *testing.T) {
|
||||
_, err := Open("not a valid url\x00")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid database URL")
|
||||
}
|
||||
|
||||
func TestCensorPassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with password",
|
||||
url: "postgres://user:secret@localhost/db",
|
||||
expected: "postgres://user:*****@localhost/db",
|
||||
},
|
||||
{
|
||||
name: "without password",
|
||||
url: "postgres://localhost/db",
|
||||
expected: "postgres://localhost/db",
|
||||
},
|
||||
{
|
||||
name: "user only",
|
||||
url: "postgres://user@localhost/db",
|
||||
expected: "postgres://user@localhost/db",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
u, err := url.Parse(tt.url)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expected, censorPassword(u))
|
||||
})
|
||||
}
|
||||
}
|
||||
64
db/test/test.go
Normal file
64
db/test/test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package dbtest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const testPoolMaxConns = "2"
|
||||
|
||||
// CreateTestPostgresSchema creates a temporary PostgreSQL schema and returns the DSN pointing to it.
|
||||
// It registers a cleanup function to drop the schema when the test finishes.
|
||||
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
|
||||
func CreateTestPostgresSchema(t *testing.T) string {
|
||||
t.Helper()
|
||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
t.Skip("NTFY_TEST_DATABASE_URL not set")
|
||||
}
|
||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||
u, err := url.Parse(dsn)
|
||||
require.Nil(t, err)
|
||||
q := u.Query()
|
||||
q.Set("pool_max_conns", testPoolMaxConns)
|
||||
u.RawQuery = q.Encode()
|
||||
dsn = u.String()
|
||||
setupHost, err := pg.Open(dsn)
|
||||
require.Nil(t, err)
|
||||
_, err = setupHost.DB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, setupHost.DB.Close())
|
||||
q.Set("search_path", schema)
|
||||
u.RawQuery = q.Encode()
|
||||
schemaDSN := u.String()
|
||||
t.Cleanup(func() {
|
||||
cleanHost, err := pg.Open(dsn)
|
||||
if err == nil {
|
||||
cleanHost.DB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||
cleanHost.DB.Close()
|
||||
}
|
||||
})
|
||||
return schemaDSN
|
||||
}
|
||||
|
||||
// CreateTestPostgres creates a temporary PostgreSQL schema and returns an open *db.DB connection to it.
|
||||
// It registers cleanup functions to close the DB and drop the schema when the test finishes.
|
||||
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
|
||||
func CreateTestPostgres(t *testing.T) *db.DB {
|
||||
t.Helper()
|
||||
schemaDSN := CreateTestPostgresSchema(t)
|
||||
testHost, err := pg.Open(schemaDSN)
|
||||
require.Nil(t, err)
|
||||
d := db.New(testHost, nil)
|
||||
t.Cleanup(func() {
|
||||
d.Close()
|
||||
})
|
||||
return d
|
||||
}
|
||||
25
db/types.go
Normal file
25
db/types.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Beginner is an interface for types that can begin a database transaction.
|
||||
// Both *sql.DB and *DB implement this.
|
||||
type Beginner interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
}
|
||||
|
||||
// Querier is an interface for types that can execute SQL queries.
|
||||
// *sql.DB, *sql.Tx, and *DB all implement this.
|
||||
type Querier interface {
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
// Host pairs a *sql.DB with the host:port it was opened against.
|
||||
type Host struct {
|
||||
Addr string // "host:port"
|
||||
DB *sql.DB
|
||||
healthy atomic.Bool
|
||||
}
|
||||
36
db/util.go
Normal file
36
db/util.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package db
|
||||
|
||||
import "database/sql"
|
||||
|
||||
// ExecTx executes a function within a database transaction. If the function returns an error,
|
||||
// the transaction is rolled back. Otherwise, the transaction is committed.
|
||||
func ExecTx(db Beginner, f func(tx *sql.Tx) error) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if err := f(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// QueryTx executes a function within a database transaction and returns the result. If the function
|
||||
// returns an error, the transaction is rolled back. Otherwise, the transaction is committed.
|
||||
func QueryTx[T any](db Beginner, f func(tx *sql.Tx) (T, error)) (T, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
var zero T
|
||||
return zero, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
t, err := f(tx)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return t, err
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
430
docs/config.md
430
docs/config.md
@@ -53,6 +53,16 @@ Here are a few working sample configs using a `/etc/ntfy/server.yml` file:
|
||||
behind-proxy: true
|
||||
```
|
||||
|
||||
=== "server.yml (PostgreSQL, behind proxy)"
|
||||
``` yaml
|
||||
base-url: "https://ntfy.example.com"
|
||||
listen-http: ":2586"
|
||||
database-url: "postgres://ntfy:mypassword@db.example.com:5432/ntfy?sslmode=require"
|
||||
attachment-cache-dir: "/var/cache/ntfy/attachments"
|
||||
behind-proxy: true
|
||||
auth-default-access: "deny-all"
|
||||
```
|
||||
|
||||
=== "server.yml (ntfy.sh config)"
|
||||
``` yaml
|
||||
# All the things: Behind a proxy, Firebase, cache, attachments,
|
||||
@@ -125,16 +135,349 @@ using Docker Compose (i.e. `docker-compose.yml`):
|
||||
command: serve
|
||||
```
|
||||
|
||||
## Config generator
|
||||
|
||||
This generator helps you configure your self-hosted ntfy instance. It's not fully featured, but it is a good starting point. Please refer to the relevant sections in the doc for more details.
|
||||
|
||||
<div style="text-align: center;">
|
||||
<button type="button" id="cg-open-btn" class="cg-open-btn">Open config generator</button>
|
||||
</div>
|
||||
|
||||
<figure markdown style="padding-left: 50px; padding-right: 50px; cursor: pointer;" onclick="document.getElementById('cg-open-btn').click();">
|
||||
<img src="../../static/img/config-generator.png"/>
|
||||
<figcaption>The config generator helps you create a custom config for your self-hosted ntfy instance. Click to open.</figcaption>
|
||||
</figure>
|
||||
|
||||
<div id="cg-modal" class="cg-modal" style="display:none">
|
||||
<div class="cg-modal-backdrop"></div>
|
||||
<div class="cg-modal-dialog">
|
||||
<div class="cg-modal-header">
|
||||
<div class="cg-modal-header-left">
|
||||
<span class="cg-modal-title">Config generator</span><span class="cg-badge-beta">BETA</span>
|
||||
<span class="cg-modal-desc">This generator helps you configure your self-hosted ntfy instance. It's not fully featured, but it is a good starting point.</span>
|
||||
</div>
|
||||
<div class="cg-modal-header-actions">
|
||||
<button type="button" id="cg-reset-btn" class="cg-modal-reset" title="Reset all values">Reset</button>
|
||||
<button type="button" id="cg-close-btn" class="cg-modal-close" title="Close">×</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="cg-modal-body">
|
||||
<div class="cg-mobile-toggle">
|
||||
<button class="cg-mobile-toggle-btn active" data-show="left">Edit</button>
|
||||
<button class="cg-mobile-toggle-btn" data-show="right">Preview</button>
|
||||
</div>
|
||||
<div id="cg-left">
|
||||
<div class="cg-nav">
|
||||
<div class="cg-nav-tab active" data-panel="cg-panel-general">General</div>
|
||||
<div class="cg-nav-tab cg-hidden" data-panel="cg-panel-database" id="cg-nav-database">Database</div>
|
||||
<div class="cg-nav-tab cg-hidden" data-panel="cg-panel-auth" id="cg-nav-auth">Users</div>
|
||||
<div class="cg-nav-tab cg-hidden" data-panel="cg-panel-cache" id="cg-nav-cache">Message Cache</div>
|
||||
<div class="cg-nav-tab cg-hidden" data-panel="cg-panel-attach" id="cg-nav-attach">Attachments</div>
|
||||
<div class="cg-nav-tab cg-hidden" data-panel="cg-panel-webpush" id="cg-nav-webpush">Web Push</div>
|
||||
<div class="cg-nav-tab cg-hidden" data-panel="cg-panel-email" id="cg-nav-email">Email</div>
|
||||
</div>
|
||||
<div class="cg-panels">
|
||||
<div class="cg-panel active" id="cg-panel-general">
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>What URL will ntfy be reachable on?</label>
|
||||
<input type="text" data-key="base-url" placeholder="https://ntfy.example.com">
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Will ntfy run behind a proxy (e.g. nginx, Caddy)? <a href="/config/#behind-a-proxy-tls-etc" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
|
||||
<div class="cg-btn-group">
|
||||
<label><input type="radio" name="cg-proxy" value="no" checked><span>No</span></label>
|
||||
<label><input type="radio" name="cg-proxy" value="yes"><span>Yes</span></label>
|
||||
</div>
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Will this ntfy server be open or private? <a href="/config/#access-control" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
|
||||
<div class="cg-btn-group">
|
||||
<label><input type="radio" name="cg-server-type" value="open" checked><span>Open</span></label>
|
||||
<label><input type="radio" name="cg-server-type" value="private"><span>Private</span></label>
|
||||
<label><input type="radio" name="cg-server-type" value="custom"><span>Custom</span></label>
|
||||
</div>
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Will iOS/iPhone users use this server? <a href="/config/#ios-instant-notifications" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
|
||||
<div class="cg-btn-group">
|
||||
<label><input type="radio" name="cg-ios" value="no" checked><span>No</span></label>
|
||||
<label><input type="radio" name="cg-ios" value="yes"><span>Yes</span></label>
|
||||
</div>
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Do you want to use ntfy as a UnifiedPush distributor? <a href="/config/#example-unifiedpush" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
|
||||
<div class="cg-btn-group">
|
||||
<label><input type="radio" name="cg-unifiedpush" value="no" checked><span>No</span></label>
|
||||
<label><input type="radio" name="cg-unifiedpush" value="yes"><span>Yes</span></label>
|
||||
</div>
|
||||
</div>
|
||||
<div class="cg-field">
|
||||
<label>Which features do you want to enable?</label>
|
||||
<div class="cg-feature-grid">
|
||||
<div class="cg-feature-row"><label><input type="checkbox" id="cg-feat-auth"> User management and access control</label><button type="button" class="cg-btn-configure cg-hidden" data-panel="cg-panel-auth">Configure</button></div>
|
||||
<div class="cg-feature-row"><label><input type="checkbox" id="cg-feat-cache"> Persistent message cache</label><button type="button" class="cg-btn-configure cg-hidden" data-panel="cg-panel-cache">Configure</button></div>
|
||||
<div class="cg-feature-row"><label><input type="checkbox" id="cg-feat-attach"> Attachments</label><button type="button" class="cg-btn-configure cg-hidden" data-panel="cg-panel-attach">Configure</button></div>
|
||||
<div class="cg-feature-row"><label><input type="checkbox" id="cg-feat-webpush"> Web push</label><button type="button" class="cg-btn-configure cg-hidden" data-panel="cg-panel-webpush">Configure</button></div>
|
||||
<div class="cg-feature-row"><label><input type="checkbox" id="cg-feat-smtp-out"> Email notifications</label><button type="button" class="cg-btn-configure cg-hidden" data-panel="cg-panel-email">Configure</button></div>
|
||||
<div class="cg-feature-row"><label><input type="checkbox" id="cg-feat-smtp-in"> Email publishing</label><button type="button" class="cg-btn-configure cg-hidden" data-panel="cg-panel-email">Configure</button></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field cg-hidden" id="cg-wizard-db">
|
||||
<label>Which database backend would you like to use? <a href="/config/#database-options" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
|
||||
<div class="cg-btn-group">
|
||||
<label><input type="radio" name="cg-db-type" value="sqlite" checked><span>SQLite</span></label>
|
||||
<label><input type="radio" name="cg-db-type" value="postgres"><span>PostgreSQL</span></label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="cg-panel" id="cg-panel-auth">
|
||||
<div class="cg-panel-desc">Configure user management, access control, and provisioned users/ACLs. See <a href="/config/#access-control" target="_blank">access control</a> for details.</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Where should the user database be stored?</label>
|
||||
<input type="text" data-key="auth-file" placeholder="/var/lib/ntfy/auth.db">
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>What should the default access policy be? <a href="/config/#access-control" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
|
||||
<select id="cg-default-access-select">
|
||||
<option value="read-write" selected>Read & Write</option>
|
||||
<option value="read-only">Read Only</option>
|
||||
<option value="write-only">Write Only</option>
|
||||
<option value="deny-all">Deny All</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Should login to the web app be enabled?</label>
|
||||
<div class="cg-btn-group">
|
||||
<label><input type="radio" name="cg-login-mode" value="disabled" checked><span>Disabled</span></label>
|
||||
<label><input type="radio" name="cg-login-mode" value="enabled"><span>Enabled</span></label>
|
||||
<label><input type="radio" name="cg-login-mode" value="required"><span>Required</span></label>
|
||||
</div>
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Should it be possible to sign up via the web app?</label>
|
||||
<div class="cg-btn-group">
|
||||
<label><input type="radio" name="cg-enable-signup" value="no" checked><span>No</span></label>
|
||||
<label><input type="radio" name="cg-enable-signup" value="yes"><span>Yes</span></label>
|
||||
</div>
|
||||
</div>
|
||||
<input type="hidden" data-key="auth-default-access">
|
||||
<input type="checkbox" data-key="enable-login" id="cg-enable-login-hidden" style="display:none">
|
||||
<input type="checkbox" data-key="require-login" id="cg-require-login-hidden" style="display:none">
|
||||
<input type="checkbox" data-key="enable-signup" id="cg-enable-signup-hidden" style="display:none">
|
||||
<div class="cg-field">
|
||||
<label>Provisioned users <a href="/config/#users-and-roles" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
|
||||
<div class="cg-repeatable-container" id="cg-auth-users-container"></div>
|
||||
<button type="button" class="cg-btn-add" data-add-type="user">+ Add user</button>
|
||||
</div>
|
||||
<div class="cg-field">
|
||||
<label>Provisioned ACLs <a href="/config/#access-control-list-acl" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
|
||||
<div class="cg-repeatable-container" id="cg-auth-acls-container"></div>
|
||||
<button type="button" class="cg-btn-add" data-add-type="acl">+ Add ACL</button>
|
||||
</div>
|
||||
<div class="cg-field">
|
||||
<label>Provisioned tokens <a href="/config/#access-tokens" target="_blank" class="cg-help"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" fill="currentColor" viewBox="0 0 16 16"><path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0M5.496 6.033h.825c.138 0 .248-.113.266-.25.09-.656.54-1.134 1.342-1.134.686 0 1.314.343 1.314 1.168 0 .635-.374.927-.965 1.371-.673.489-1.206 1.06-1.168 1.987l.003.217a.25.25 0 0 0 .25.246h.811a.25.25 0 0 0 .25-.25v-.105c0-.718.273-.927 1.01-1.486.609-.463 1.244-.977 1.244-2.056 0-1.511-1.276-2.241-2.673-2.241-1.267 0-2.655.59-2.75 2.286a.237.237 0 0 0 .241.247m2.325 6.443c.61 0 1.029-.394 1.029-.927 0-.552-.42-.94-1.029-.94-.584 0-1.009.388-1.009.94 0 .533.425.927 1.01.927z"/></svg></a></label>
|
||||
<div class="cg-repeatable-container" id="cg-auth-tokens-container"></div>
|
||||
<button type="button" class="cg-btn-add" data-add-type="token">+ Add token</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="cg-panel" id="cg-panel-cache">
|
||||
<div class="cg-panel-desc">Configure the message cache to allow devices to retrieve missed notifications. See <a href="/config/#message-cache" target="_blank">message cache</a> for details.</div>
|
||||
<div class="cg-field cg-inline-field" id="cg-cache-file-field">
|
||||
<label>Where should the cache be stored?</label>
|
||||
<input type="text" data-key="cache-file" placeholder="/var/cache/ntfy/cache.db">
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>How long should messages be cached?</label>
|
||||
<input type="text" data-key="cache-duration" placeholder="12h">
|
||||
</div>
|
||||
</div>
|
||||
<div class="cg-panel" id="cg-panel-attach">
|
||||
<div class="cg-panel-desc">Allow users to upload and attach files to notifications. See <a href="/config/#attachments" target="_blank">attachments</a> for details.</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Where should attachments be stored?</label>
|
||||
<input type="text" data-key="attachment-cache-dir" placeholder="/var/cache/ntfy/attachments">
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Max file size per attachment?</label>
|
||||
<input type="text" data-key="attachment-file-size-limit" placeholder="15M">
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Total attachment storage limit?</label>
|
||||
<input type="text" data-key="attachment-total-size-limit" placeholder="5G">
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>How long before attachments expire?</label>
|
||||
<input type="text" data-key="attachment-expiry-duration" placeholder="3h">
|
||||
</div>
|
||||
</div>
|
||||
<div class="cg-panel" id="cg-panel-webpush">
|
||||
<div class="cg-panel-desc">Enable browser push notifications via the Web Push API. VAPID keys are generated automatically. See <a href="/config/#web-push" target="_blank">web push</a> for details.</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Where should web push data be stored?</label>
|
||||
<input type="text" data-key="web-push-file" placeholder="/var/lib/ntfy/webpush.db">
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Contact email address</label>
|
||||
<input type="text" data-key="web-push-email-address" placeholder="admin@example.com">
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Private key</label>
|
||||
<input type="text" data-key="web-push-private-key" placeholder="Auto-generated" readonly>
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Public key</label>
|
||||
<input type="text" data-key="web-push-public-key" placeholder="Auto-generated" readonly>
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label></label>
|
||||
<button type="button" id="cg-regen-keys" class="cg-btn-add">Regenerate keys</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="cg-panel" id="cg-panel-email">
|
||||
<div class="cg-panel-desc">Configure outgoing email notifications and/or incoming email publishing. See <a href="/config/#e-mail-notifications" target="_blank">email notifications</a> and <a href="/config/#e-mail-publishing" target="_blank">email publishing</a> for details.</div>
|
||||
<div id="cg-email-out-section" class="cg-hidden">
|
||||
<div class="cg-field"><label><strong>Outgoing (notifications)</strong></label></div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>SMTP server address</label>
|
||||
<input type="text" data-key="smtp-sender-addr" placeholder="smtp.example.com:587">
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Sender email</label>
|
||||
<input type="text" data-key="smtp-sender-from" placeholder="ntfy@example.com">
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>SMTP username</label>
|
||||
<input type="text" data-key="smtp-sender-user" placeholder="Username">
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>SMTP password</label>
|
||||
<input type="password" data-key="smtp-sender-pass" placeholder="Password">
|
||||
</div>
|
||||
</div>
|
||||
<div id="cg-email-in-section" class="cg-hidden">
|
||||
<div class="cg-field"><label><strong>Incoming (publishing)</strong></label></div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Listen address</label>
|
||||
<input type="text" data-key="smtp-server-listen" placeholder=":25">
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Domain</label>
|
||||
<input type="text" data-key="smtp-server-domain" placeholder="ntfy.example.com">
|
||||
</div>
|
||||
<div class="cg-field cg-inline-field">
|
||||
<label>Address prefix</label>
|
||||
<input type="text" data-key="smtp-server-addr-prefix" placeholder="ntfy-">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="cg-panel" id="cg-panel-database">
|
||||
<div class="cg-panel-desc">Configure the PostgreSQL connection. See <a href="/config/#postgresql-experimental" target="_blank">PostgreSQL</a> for details.</div>
|
||||
<div class="cg-field">
|
||||
<label>Database URL</label>
|
||||
<input type="text" data-key="database-url" placeholder="postgres://user:pass@host:5432/ntfy">
|
||||
</div>
|
||||
</div>
|
||||
<input type="hidden" data-key="upstream-base-url">
|
||||
<input type="checkbox" data-key="behind-proxy" id="cg-behind-proxy" style="display:none">
|
||||
</div>
|
||||
</div>
|
||||
<div id="cg-right">
|
||||
<div class="cg-output-tabs">
|
||||
<div class="cg-output-tab active" data-format="server-yml">server.yml</div>
|
||||
<div class="cg-output-tab" data-format="docker-compose">docker-compose.yml</div>
|
||||
<div class="cg-output-tab" data-format="env-vars">Env variables</div>
|
||||
<button type="button" id="cg-copy-btn" class="cg-btn-copy" title="Copy to clipboard"><svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg></button>
|
||||
</div>
|
||||
<div class="cg-output-wrap">
|
||||
<pre><code id="cg-code"><span class="cg-empty-msg">Configure options on the left to generate your config...</span></code></pre>
|
||||
<div id="cg-warnings" class="cg-hidden"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Database options
|
||||
ntfy uses a database for storing messages ([message cache](#message-cache)), users and [access control](#access-control), and [web push](#web-push) subscriptions.
|
||||
You can choose between **SQLite** and **PostgreSQL** as the database backend.
|
||||
|
||||
### SQLite
|
||||
By default, ntfy uses SQLite with separate database files for each store. This is the simplest setup and requires
|
||||
no external dependencies:
|
||||
|
||||
* `cache-file`: Database file for the [message cache](#message-cache).
|
||||
* `auth-file`: Database file for authentication and [access control](#access-control). If set, enables auth.
|
||||
* `web-push-file`: Database file for [web push](#web-push) subscriptions.
|
||||
|
||||
### PostgreSQL (EXPERIMENTAL)
|
||||
As an alternative, you can configure ntfy to use PostgreSQL for **all** database-backed stores by setting the
|
||||
`database-url` option to a PostgreSQL connection string.
|
||||
|
||||
When `database-url` is set, ntfy will use PostgreSQL for the [message cache](#message-cache),
|
||||
[access control](#access-control), and [web push](#web-push) subscriptions instead of SQLite. The `cache-file`,
|
||||
`auth-file`, and `web-push-file` options **must not** be set in this case.
|
||||
|
||||
Note that setting `database-url` implicitly enables authentication and access control (equivalent to setting
|
||||
`auth-file` with SQLite). The default access is `read-write`, so anonymous users can still read and write to all
|
||||
topics. To restrict access, set `auth-default-access` to `deny-all` (see [access control](#access-control)).
|
||||
|
||||
You can also set this via the environment variable `NTFY_DATABASE_URL` or the command line flag `--database-url`.
|
||||
|
||||
To offload read-heavy queries from the primary database, you can optionally configure one or more read replicas
|
||||
using the `database-replica-urls` option. When configured, non-critical read-only queries (e.g. fetching messages, checking access permissions, etc)
|
||||
are distributed across the replicas using round-robin, while all writes and correctness-critical reads continue to go
|
||||
to the primary. If a replica becomes unhealthy, ntfy automatically falls back to the primary until the replica recovers.
|
||||
You can also set this via the environment variable `NTFY_DATABASE_REPLICA_URLS` (comma-separated) or the command line
|
||||
flag `--database-replica-urls`.
|
||||
|
||||
Examples:
|
||||
|
||||
=== "Simple"
|
||||
```yaml
|
||||
database-url: "postgres://user:pass@host:5432/ntfy"
|
||||
```
|
||||
|
||||
=== "With SSL and pool tuning"
|
||||
```yaml
|
||||
database-url: "postgres://user:pass@host:5432/ntfy?sslmode=require&pool_max_conns=50&pool_conn_max_idle_time=5m"
|
||||
```
|
||||
|
||||
=== "With CA certificate"
|
||||
```yaml
|
||||
database-url: "postgres://user:pass@host:25060/ntfy?sslmode=require&sslrootcert=/etc/ntfy/db-ca-cert.pem&pool_max_conns=30"
|
||||
```
|
||||
|
||||
=== "With read replicas"
|
||||
```yaml
|
||||
database-url: "postgres://user:pass@primary:5432/ntfy?sslmode=require&sslrootcert=/etc/ntfy/db-ca-cert.pem&pool_max_conns=30"
|
||||
database-replica-urls:
|
||||
- "postgres://user:pass@replica1:5432/ntfy?sslmode=require&sslrootcert=/etc/ntfy/db-ca-cert.pem&pool_max_conns=30"
|
||||
- "postgres://user:pass@replica2:5432/ntfy?sslmode=require&sslrootcert=/etc/ntfy/db-ca-cert.pem&pool_max_conns=30"
|
||||
```
|
||||
|
||||
The database URL supports the standard [PostgreSQL connection parameters](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS)
|
||||
as query parameters, such as `sslmode`, `connect_timeout`, `sslcert`, `sslkey`, `sslrootcert`, and `application_name`.
|
||||
See the [pgx driver documentation](https://pkg.go.dev/github.com/jackc/pgx/v5) for the full list of supported parameters.
|
||||
|
||||
In addition, ntfy supports the following custom query parameters to tune the connection pool (these apply to both
|
||||
the primary and replica URLs):
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|---------------------------|---------|----------------------------------------------------------------------------------|
|
||||
| `pool_max_conns` | 10 | Maximum number of open connections to the database |
|
||||
| `pool_max_idle_conns` | - | Maximum number of idle connections in the pool |
|
||||
| `pool_conn_max_lifetime` | - | Maximum amount of time a connection may be reused (Go duration, e.g. `5m`, `1h`) |
|
||||
| `pool_conn_max_idle_time` | - | Maximum amount of time a connection may be idle (Go duration, e.g. `30s`, `5m`) |
|
||||
|
||||
|
||||
## Message cache
|
||||
If desired, ntfy can temporarily keep notifications in an in-memory or an on-disk cache. Caching messages for a short period
|
||||
of time is important to allow [phones](subscribe/phone.md) and other devices with brittle Internet connections to be able to retrieve
|
||||
notifications that they may have missed.
|
||||
|
||||
By default, ntfy keeps messages **in-memory for 12 hours**, which means that **cached messages do not survive an application
|
||||
restart**. You can override this behavior using the following config settings:
|
||||
restart**. You can override this behavior by setting `cache-file` (SQLite) or `database-url` (PostgreSQL).
|
||||
|
||||
* `cache-file`: if set, ntfy will store messages in a SQLite based cache (default is empty, which means in-memory cache).
|
||||
**This is required if you'd like messages to be retained across restarts**.
|
||||
* `cache-duration`: defines the duration for which messages are stored in the cache (default is `12h`).
|
||||
|
||||
You can also entirely disable the cache by setting `cache-duration` to `0`. When the cache is disabled, messages are only
|
||||
@@ -146,20 +489,23 @@ Subscribers can retrieve cached messaging using the [`poll=1` parameter](subscri
|
||||
|
||||
## Attachments
|
||||
If desired, you may allow users to upload and [attach files to notifications](publish.md#attachments). To enable
|
||||
this feature, you have to simply configure an attachment cache directory and a base URL (`attachment-cache-dir`, `base-url`).
|
||||
Once these options are set and the directory is writable by the server user, you can upload attachments via PUT.
|
||||
this feature, you have to configure an attachment storage backend and a base URL (`base-url`). Attachments can be stored
|
||||
either on the local filesystem (`attachment-cache-dir`) or in an S3-compatible object store (`attachment-s3-url`).
|
||||
Once configured, you can upload attachments via PUT.
|
||||
|
||||
By default, attachments are stored in the disk-cache **for only 3 hours**. The main reason for this is to avoid legal issues
|
||||
and such when hosting user controlled content. Typically, this is more than enough time for the user (or the auto download
|
||||
By default, attachments are stored **for only 3 hours**. The main reason for this is to avoid legal issues
|
||||
and such when hosting user controlled content. Typically, this is more than enough time for the user (or the auto download
|
||||
feature) to download the file. The following config options are relevant to attachments:
|
||||
|
||||
* `base-url` is the root URL for the ntfy server; this is needed for the generated attachment URLs
|
||||
* `attachment-cache-dir` is the cache directory for attached files
|
||||
* `attachment-total-size-limit` is the size limit of the on-disk attachment cache (default: 5G)
|
||||
* `attachment-cache-dir` is the cache directory for attached files (mutually exclusive with `attachment-s3-url`)
|
||||
* `attachment-s3-url` is the S3 URL for attachment storage (mutually exclusive with `attachment-cache-dir`)
|
||||
* `attachment-total-size-limit` is the size limit of the attachment storage (default: 5G)
|
||||
* `attachment-file-size-limit` is the per-file attachment size limit (e.g. 300k, 2M, 100M, default: 15M)
|
||||
* `attachment-expiry-duration` is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h, default: 3h)
|
||||
|
||||
Here's an example config using mostly the defaults (except for the cache directory, which is empty by default):
|
||||
### Filesystem storage
|
||||
Here's an example config using the local filesystem for attachment storage:
|
||||
|
||||
=== "/etc/ntfy/server.yml (minimal)"
|
||||
``` yaml
|
||||
@@ -178,6 +524,30 @@ Here's an example config using mostly the defaults (except for the cache directo
|
||||
visitor-attachment-daily-bandwidth-limit: "500M"
|
||||
```
|
||||
|
||||
### S3 storage
|
||||
As an alternative to the local filesystem, you can store attachments in an S3-compatible object store (e.g. AWS S3,
|
||||
MinIO, DigitalOcean Spaces). This is useful for HA/cloud deployments where you don't want to rely on local disk storage.
|
||||
|
||||
The `attachment-s3-url` option uses the following format:
|
||||
|
||||
```
|
||||
s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
|
||||
```
|
||||
|
||||
When `endpoint` is specified, path-style addressing is enabled automatically (useful for MinIO and other S3-compatible stores).
|
||||
|
||||
=== "/etc/ntfy/server.yml (AWS S3)"
|
||||
``` yaml
|
||||
base-url: "https://ntfy.sh"
|
||||
attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1"
|
||||
```
|
||||
|
||||
=== "/etc/ntfy/server.yml (MinIO/custom endpoint)"
|
||||
``` yaml
|
||||
base-url: "https://ntfy.sh"
|
||||
attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com"
|
||||
```
|
||||
|
||||
Please also refer to the [rate limiting](#rate-limiting) settings below, specifically `visitor-attachment-total-size-limit`
|
||||
and `visitor-attachment-daily-bandwidth-limit`. Setting these conservatively is necessary to avoid abuse.
|
||||
|
||||
@@ -185,14 +555,15 @@ and `visitor-attachment-daily-bandwidth-limit`. Setting these conservatively is
|
||||
By default, the ntfy server is open for everyone, meaning **everyone can read and write to any topic** (this is how
|
||||
ntfy.sh is configured). To restrict access to your own server, you can optionally configure authentication and authorization.
|
||||
|
||||
ntfy's auth is implemented with a simple [SQLite](https://www.sqlite.org/)-based backend. It implements two roles
|
||||
(`user` and `admin`) and per-topic `read` and `write` permissions using an [access control list (ACL)](https://en.wikipedia.org/wiki/Access-control_list).
|
||||
Access control entries can be applied to users as well as the special everyone user (`*`), which represents anonymous API access.
|
||||
ntfy's auth implements two roles (`user` and `admin`) and per-topic `read` and `write` permissions using an
|
||||
[access control list (ACL)](https://en.wikipedia.org/wiki/Access-control_list). Access control entries can be applied
|
||||
to users as well as the special everyone user (`*`), which represents anonymous API access.
|
||||
|
||||
To set up auth, **configure the following options**:
|
||||
|
||||
* `auth-file` is the user/access database; it is created automatically if it doesn't already exist; suggested
|
||||
location `/var/lib/ntfy/user.db` (easiest if deb/rpm package is used)
|
||||
* `auth-file` is the user/access database (SQLite); it is created automatically if it doesn't already exist; suggested
|
||||
location `/var/lib/ntfy/user.db` (easiest if deb/rpm package is used). Alternatively, if `database-url` is set,
|
||||
auth is automatically enabled using PostgreSQL (see [database options](#database-options)).
|
||||
* `auth-default-access` defines the default/fallback access if no access control entry is found; it can be
|
||||
set to `read-write` (default), `read-only`, `write-only` or `deny-all`. **If you are setting up a private instance,
|
||||
you'll want to set this to `deny-all`** (see [private instance example](#example-private-instance)).
|
||||
@@ -454,7 +825,7 @@ Here's an example:
|
||||
```
|
||||
# Comma-separated list
|
||||
NTFY_AUTH_FILE='/var/lib/ntfy/user.db'
|
||||
NTFY_AUTH_USERS='phil:$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C:admin,ben:$2a$10$NKbrNb7HPMjtQXWJ0f1pouw03LDLT/WzlO9VAv44x84bRCkh19h6m:user'
|
||||
NTFY_AUTH_USERS='phil:$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C:admin,backup-service:$2a$10$NKbrNb7HPMjtQXWJ0f1pouw03LDLT/WzlO9VAv44x84bRCkh19h6m:user'
|
||||
NTFY_AUTH_TOKENS='phil:tk_3gd7d2yftt4b8ixyfe9mnmro88o76,backup-service:tk_f099we8uzj7xi5qshzajwp6jffvkz:Backup script'
|
||||
```
|
||||
|
||||
@@ -470,7 +841,8 @@ and access tokens in the `auth-tokens` section (see [access tokens via the confi
|
||||
|
||||
Here's an example that defines a single admin user `phil` with the password `mypass`, and a regular user `backup-script`
|
||||
with the password `backup-script`. The admin user has full access to all topics, while regular user can only
|
||||
access the `backups` topic with read-write permissions. The `auth-default-access` is set to `deny-all`, which means
|
||||
access the `backups` topic with read-write permissions. `phil` has a token `tk_3gd7d2yftt4b8ixyfe9mnmro88o76`
|
||||
with the label "My personal token". The `auth-default-access` is set to `deny-all`, which means
|
||||
that all other users and anonymous access are denied by default.
|
||||
|
||||
=== "Config via /etc/ntfy/server.yml"
|
||||
@@ -481,7 +853,7 @@ that all other users and anonymous access are denied by default.
|
||||
- "phil:$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C:admin"
|
||||
- "backup-script:$2a$10$/ehiQt.w7lhTmHXq.RNsOOkIwiPPeWFIzWYO3DRxNixnWKLX8.uj.:user"
|
||||
auth-access:
|
||||
- "backup-service:backups:rw"
|
||||
- "backup-script:backups:rw"
|
||||
auth-tokens:
|
||||
- "phil:tk_3gd7d2yftt4b8ixyfe9mnmro88o76:My personal token"
|
||||
```
|
||||
@@ -491,7 +863,7 @@ that all other users and anonymous access are denied by default.
|
||||
NTFY_AUTH_FILE='/var/lib/ntfy/user.db'
|
||||
NTFY_AUTH_DEFAULT_ACCESS='deny-all'
|
||||
NTFY_AUTH_USERS='phil:$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C:admin,backup-script:$2a$10$/ehiQt.w7lhTmHXq.RNsOOkIwiPPeWFIzWYO3DRxNixnWKLX8.uj.:user'
|
||||
NTFY_AUTH_ACCESS='backup-service:backups:rw'
|
||||
NTFY_AUTH_ACCESS='backup-script:backups:rw'
|
||||
NTFY_AUTH_TOKENS='phil:tk_3gd7d2yftt4b8ixyfe9mnmro88o76:My personal token'
|
||||
```
|
||||
|
||||
@@ -1141,12 +1513,15 @@ a database to keep track of the browser's subscriptions, and an admin email addr
|
||||
|
||||
- `web-push-public-key` is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890
|
||||
- `web-push-private-key` is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890
|
||||
- `web-push-file` is a database file to keep track of browser subscription endpoints, e.g. `/var/cache/ntfy/webpush.db`
|
||||
- `web-push-file` is a database file to keep track of browser subscription endpoints, e.g. `/var/cache/ntfy/webpush.db` (not required if `database-url` is set)
|
||||
- `web-push-email-address` is the admin email address send to the push provider, e.g. `sysadmin@example.com`
|
||||
- `web-push-startup-queries` is an optional list of queries to run on startup`
|
||||
- `web-push-expiry-warning-duration` defines the duration after which unused subscriptions are sent a warning (default is `55d`)
|
||||
- `web-push-expiry-duration` defines the duration after which unused subscriptions will expire (default is `60d`)
|
||||
|
||||
Alternatively, you can use PostgreSQL instead of SQLite by setting `database-url`
|
||||
(see [PostgreSQL database](#postgresql-experimental)).
|
||||
|
||||
Limitations:
|
||||
|
||||
- Like foreground browser notifications, background push notifications require the web app to be served over HTTPS. A _valid_
|
||||
@@ -1172,9 +1547,10 @@ web-push-file: /var/cache/ntfy/webpush.db
|
||||
web-push-email-address: sysadmin@example.com
|
||||
```
|
||||
|
||||
The `web-push-file` is used to store the push subscriptions. Unused subscriptions will send out a warning after 55 days,
|
||||
and will automatically expire after 60 days (default). If the gateway returns an error (e.g. 410 Gone when a user has unsubscribed),
|
||||
subscriptions are also removed automatically.
|
||||
The `web-push-file` is used to store the push subscriptions in a local SQLite database. Alternatively, if `database-url`
|
||||
is set, subscriptions are stored in PostgreSQL and `web-push-file` is not required. Unused subscriptions will send out
|
||||
a warning after 55 days, and will automatically expire after 60 days (default). If the gateway returns an error
|
||||
(e.g. 410 Gone when a user has unsubscribed), subscriptions are also removed automatically.
|
||||
|
||||
The web app refreshes subscriptions on start and regularly on an interval, but this file should be persisted across restarts. If the subscription
|
||||
file is deleted or lost, any web apps that aren't open will not receive new web push notifications until you open then.
|
||||
@@ -1755,17 +2131,20 @@ variable before running the `ntfy` command (e.g. `export NTFY_LISTEN_HTTP=:80`).
|
||||
| `key-file` | `NTFY_KEY_FILE` | *filename* | - | HTTPS/TLS private key file, only used if `listen-https` is set. |
|
||||
| `cert-file` | `NTFY_CERT_FILE` | *filename* | - | HTTPS/TLS certificate file, only used if `listen-https` is set. |
|
||||
| `firebase-key-file` | `NTFY_FIREBASE_KEY_FILE` | *filename* | - | If set, also publish messages to a Firebase Cloud Messaging (FCM) topic for your app. This is optional and only required to save battery when using the Android app. See [Firebase (FCM)](#firebase-fcm). |
|
||||
| `database-url` | `NTFY_DATABASE_URL` | *string (connection URL)* | - | PostgreSQL connection string (e.g. `postgres://user:pass@host:5432/ntfy`). If set, uses PostgreSQL for all database-backed stores (message cache, user manager, web push) instead of SQLite. See [database options](#database-options). |
|
||||
| `database-replica-urls` | `NTFY_DATABASE_REPLICA_URLS` | *list of strings (connection URLs)* | - | PostgreSQL read replica connection strings. Non-critical read-only queries are distributed across replicas (round-robin) with automatic fallback to primary. Requires `database-url`. See [read replicas](#read-replicas). |
|
||||
| `cache-file` | `NTFY_CACHE_FILE` | *filename* | - | If set, messages are cached in a local SQLite database instead of only in-memory. This allows for service restarts without losing messages in support of the since= parameter. See [message cache](#message-cache). |
|
||||
| `cache-duration` | `NTFY_CACHE_DURATION` | *duration* | 12h | Duration for which messages will be buffered before they are deleted. This is required to support the `since=...` and `poll=1` parameter. Set this to `0` to disable the cache entirely. |
|
||||
| `cache-startup-queries` | `NTFY_CACHE_STARTUP_QUERIES` | *string (SQL queries)* | - | SQL queries to run during database startup; this is useful for tuning and [enabling WAL mode](#message-cache) |
|
||||
| `cache-batch-size` | `NTFY_CACHE_BATCH_SIZE` | *int* | 0 | Max size of messages to batch together when writing to message cache (if zero, writes are synchronous) |
|
||||
| `cache-batch-timeout` | `NTFY_CACHE_BATCH_TIMEOUT` | *duration* | 0s | Timeout for batched async writes to the message cache (if zero, writes are synchronous) |
|
||||
| `auth-file` | `NTFY_AUTH_FILE` | *filename* | - | Auth database file used for access control. If set, enables authentication and access control. See [access control](#access-control). |
|
||||
| `auth-file` | `NTFY_AUTH_FILE` | *filename* | - | Auth database file used for access control (SQLite). If set, enables authentication and access control. Not required if `database-url` is set. See [access control](#access-control). |
|
||||
| `auth-default-access` | `NTFY_AUTH_DEFAULT_ACCESS` | `read-write`, `read-only`, `write-only`, `deny-all` | `read-write` | Default permissions if no matching entries in the auth database are found. Default is `read-write`. |
|
||||
| `behind-proxy` | `NTFY_BEHIND_PROXY` | *bool* | false | If set, use forwarded header (e.g. X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting) |
|
||||
| `proxy-forwarded-header` | `NTFY_PROXY_FORWARDED_HEADER` | *string* | `X-Forwarded-For` | Use specified header to determine visitor IP address (for rate limiting) |
|
||||
| `proxy-trusted-hosts` | `NTFY_PROXY_TRUSTED_HOSTS` | *comma-separated host/IP/CIDR list* | - | Comma-separated list of trusted IP addresses, hosts, or CIDRs to remove from forwarded header |
|
||||
| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory* | - | Cache directory for attached files. To enable attachments, this has to be set. |
|
||||
| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory* | - | Cache directory for attached files. Mutually exclusive with `attachment-s3-url`. |
|
||||
| `attachment-s3-url` | `NTFY_ATTACHMENT_S3_URL` | *URL* | - | S3 URL for attachment storage (format: `s3://KEY:SECRET@BUCKET[/PREFIX]?region=REGION`). Mutually exclusive with `attachment-cache-dir`. |
|
||||
| `attachment-total-size-limit` | `NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT` | *size* | 5G | Limit of the on-disk attachment cache directory. If the limits is exceeded, new attachments will be rejected. |
|
||||
| `attachment-file-size-limit` | `NTFY_ATTACHMENT_FILE_SIZE_LIMIT` | *size* | 15M | Per-file attachment size limit (e.g. 300k, 2M, 100M). Larger attachment will be rejected. |
|
||||
| `attachment-expiry-duration` | `NTFY_ATTACHMENT_EXPIRY_DURATION` | *duration* | 3h | Duration after which uploaded attachments will be deleted (e.g. 3h, 20h). Strongly affects `visitor-attachment-total-size-limit`. |
|
||||
@@ -1868,6 +2247,7 @@ OPTIONS:
|
||||
--auth-startup-queries value, --auth_startup_queries value queries run when the auth database is initialized [$NTFY_AUTH_STARTUP_QUERIES]
|
||||
--auth-default-access value, --auth_default_access value, -p value default permissions if no matching entries in the auth database are found (default: "read-write") [$NTFY_AUTH_DEFAULT_ACCESS]
|
||||
--attachment-cache-dir value, --attachment_cache_dir value cache directory for attached files [$NTFY_ATTACHMENT_CACHE_DIR]
|
||||
--attachment-s3-url value, --attachment_s3_url value S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION) [$NTFY_ATTACHMENT_S3_URL]
|
||||
--attachment-total-size-limit value, --attachment_total_size_limit value, -A value limit of the on-disk attachment cache (default: "5G") [$NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT]
|
||||
--attachment-file-size-limit value, --attachment_file_size_limit value, -Y value per-file attachment size limit (e.g. 300k, 2M, 100M) (default: "15M") [$NTFY_ATTACHMENT_FILE_SIZE_LIMIT]
|
||||
--attachment-expiry-duration value, --attachment_expiry_duration value, -X value duration after which uploaded attachments will be deleted (e.g. 3h, 20h) (default: "3h") [$NTFY_ATTACHMENT_EXPIRY_DURATION]
|
||||
|
||||
@@ -340,10 +340,6 @@ Then either follow the steps for building with or without Firebase.
|
||||
Without Firebase, you may want to still change the default `app_base_url` in [values.xml](https://github.com/binwiederhier/ntfy-android/blob/main/app/src/main/res/values/values.xml)
|
||||
if you're self-hosting the server. Then run:
|
||||
```
|
||||
# Remove Google dependencies (FCM)
|
||||
sed -i -e '/google-services/d' build.gradle
|
||||
sed -i -e '/google-services/d' app/build.gradle
|
||||
|
||||
# To build an unsigned .apk (app/build/outputs/apk/fdroid/*.apk)
|
||||
./gradlew assembleFdroidRelease
|
||||
|
||||
@@ -351,6 +347,8 @@ sed -i -e '/google-services/d' app/build.gradle
|
||||
./gradlew bundleFdroidRelease
|
||||
```
|
||||
|
||||
The F-Droid flavor automatically excludes Google Services dependencies.
|
||||
|
||||
### Build Play flavor (FCM)
|
||||
!!! info
|
||||
I do build the ntfy Android app using IntelliJ IDEA (Android Studio), so I don't know if these Gradle commands will
|
||||
|
||||
@@ -661,6 +661,8 @@ Add the following function and alias to your `.bashrc` or `.bash_profile`:
|
||||
local token=$(< ~/.ntfy_token) # Securely read the token
|
||||
local status_icon="$([ $exit_status -eq 0 ] && echo magic_wand || echo warning)"
|
||||
local last_command=$(history | tail -n1 | sed -e 's/^[[:space:]]*[0-9]\{1,\}[[:space:]]*//' -e 's/[;&|][[:space:]]*alert$//')
|
||||
# for zsh users, use the same sed pattern but get the history differently.
|
||||
# local last_command=$(history "$HISTCMD" | sed -e 's/^[[:space:]]*[0-9]\{1,\}[[:space:]]*//' -e 's/[;&|][[:space:]]*alert$//')
|
||||
|
||||
curl -s -X POST "https://n.example.dev/alerts" \
|
||||
-H "Authorization: Bearer $token" \
|
||||
@@ -692,4 +694,4 @@ To test failure notifications:
|
||||
false; alert # Always fails (exit 1)
|
||||
ls --invalid; alert # Invalid option
|
||||
cat nonexistent_file; alert # File not found
|
||||
```
|
||||
```
|
||||
|
||||
@@ -30,37 +30,37 @@ deb/rpm packages.
|
||||
|
||||
=== "x86_64/amd64"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_amd64.tar.gz
|
||||
tar zxvf ntfy_2.16.0_linux_amd64.tar.gz
|
||||
sudo cp -a ntfy_2.16.0_linux_amd64/ntfy /usr/local/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.16.0_linux_amd64/{client,server}/*.yml /etc/ntfy
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_amd64.tar.gz
|
||||
tar zxvf ntfy_2.19.2_linux_amd64.tar.gz
|
||||
sudo cp -a ntfy_2.19.2_linux_amd64/ntfy /usr/local/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_amd64/{client,server}/*.yml /etc/ntfy
|
||||
sudo ntfy serve
|
||||
```
|
||||
|
||||
=== "armv6"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv6.tar.gz
|
||||
tar zxvf ntfy_2.16.0_linux_armv6.tar.gz
|
||||
sudo cp -a ntfy_2.16.0_linux_armv6/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.16.0_linux_armv6/{client,server}/*.yml /etc/ntfy
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv6.tar.gz
|
||||
tar zxvf ntfy_2.19.2_linux_armv6.tar.gz
|
||||
sudo cp -a ntfy_2.19.2_linux_armv6/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_armv6/{client,server}/*.yml /etc/ntfy
|
||||
sudo ntfy serve
|
||||
```
|
||||
|
||||
=== "armv7/armhf"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv7.tar.gz
|
||||
tar zxvf ntfy_2.16.0_linux_armv7.tar.gz
|
||||
sudo cp -a ntfy_2.16.0_linux_armv7/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.16.0_linux_armv7/{client,server}/*.yml /etc/ntfy
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv7.tar.gz
|
||||
tar zxvf ntfy_2.19.2_linux_armv7.tar.gz
|
||||
sudo cp -a ntfy_2.19.2_linux_armv7/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_armv7/{client,server}/*.yml /etc/ntfy
|
||||
sudo ntfy serve
|
||||
```
|
||||
|
||||
=== "arm64"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_arm64.tar.gz
|
||||
tar zxvf ntfy_2.16.0_linux_arm64.tar.gz
|
||||
sudo cp -a ntfy_2.16.0_linux_arm64/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.16.0_linux_arm64/{client,server}/*.yml /etc/ntfy
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_arm64.tar.gz
|
||||
tar zxvf ntfy_2.19.2_linux_arm64.tar.gz
|
||||
sudo cp -a ntfy_2.19.2_linux_arm64/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_arm64/{client,server}/*.yml /etc/ntfy
|
||||
sudo ntfy serve
|
||||
```
|
||||
|
||||
@@ -116,7 +116,7 @@ Manually installing the .deb file:
|
||||
|
||||
=== "x86_64/amd64"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_amd64.deb
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_amd64.deb
|
||||
sudo dpkg -i ntfy_*.deb
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
@@ -124,7 +124,7 @@ Manually installing the .deb file:
|
||||
|
||||
=== "armv6"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv6.deb
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv6.deb
|
||||
sudo dpkg -i ntfy_*.deb
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
@@ -132,7 +132,7 @@ Manually installing the .deb file:
|
||||
|
||||
=== "armv7/armhf"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv7.deb
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv7.deb
|
||||
sudo dpkg -i ntfy_*.deb
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
@@ -140,7 +140,7 @@ Manually installing the .deb file:
|
||||
|
||||
=== "arm64"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_arm64.deb
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_arm64.deb
|
||||
sudo dpkg -i ntfy_*.deb
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
@@ -150,28 +150,28 @@ Manually installing the .deb file:
|
||||
|
||||
=== "x86_64/amd64"
|
||||
```bash
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_amd64.rpm
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_amd64.rpm
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
```
|
||||
|
||||
=== "armv6"
|
||||
```bash
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv6.rpm
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv6.rpm
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
```
|
||||
|
||||
=== "armv7/armhf"
|
||||
```bash
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv7.rpm
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv7.rpm
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
```
|
||||
|
||||
=== "arm64"
|
||||
```bash
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_arm64.rpm
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_arm64.rpm
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
```
|
||||
@@ -213,18 +213,18 @@ pkg install go-ntfy
|
||||
|
||||
## macOS
|
||||
The [ntfy CLI](subscribe/cli.md) (`ntfy publish` and `ntfy subscribe` only) is supported on macOS as well.
|
||||
To install, please [download the tarball](https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_darwin_all.tar.gz),
|
||||
To install, please [download the tarball](https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_darwin_all.tar.gz),
|
||||
extract it and place it somewhere in your `PATH` (e.g. `/usr/local/bin/ntfy`).
|
||||
|
||||
If run as `root`, ntfy will look for its config at `/etc/ntfy/client.yml`. For all other users, it'll look for it at
|
||||
`~/Library/Application Support/ntfy/client.yml` (sample included in the tarball).
|
||||
|
||||
```bash
|
||||
curl -L https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_darwin_all.tar.gz > ntfy_2.16.0_darwin_all.tar.gz
|
||||
tar zxvf ntfy_2.16.0_darwin_all.tar.gz
|
||||
sudo cp -a ntfy_2.16.0_darwin_all/ntfy /usr/local/bin/ntfy
|
||||
curl -L https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_darwin_all.tar.gz > ntfy_2.19.2_darwin_all.tar.gz
|
||||
tar zxvf ntfy_2.19.2_darwin_all.tar.gz
|
||||
sudo cp -a ntfy_2.19.2_darwin_all/ntfy /usr/local/bin/ntfy
|
||||
mkdir ~/Library/Application\ Support/ntfy
|
||||
cp ntfy_2.16.0_darwin_all/client/client.yml ~/Library/Application\ Support/ntfy/client.yml
|
||||
cp ntfy_2.19.2_darwin_all/client/client.yml ~/Library/Application\ Support/ntfy/client.yml
|
||||
ntfy --help
|
||||
```
|
||||
|
||||
@@ -245,7 +245,7 @@ brew install ntfy
|
||||
The ntfy server and CLI are fully supported on Windows. You can run the ntfy server directly or as a Windows service.
|
||||
To install, you can either
|
||||
|
||||
* [Download the latest ZIP](https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_windows_amd64.zip),
|
||||
* [Download the latest ZIP](https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_windows_amd64.zip),
|
||||
extract it and place the `ntfy.exe` binary somewhere in your `%Path%`.
|
||||
* Or install ntfy from the [Scoop](https://scoop.sh) main repository via `scoop install ntfy`
|
||||
|
||||
|
||||
@@ -184,6 +184,8 @@ I've added a ⭐ to projects or posts that have a significant following, or had
|
||||
- [BRun](https://github.com/cbrake/brun) - Native Linux automation platform connecting triggers to actions without containers (Go)
|
||||
- [Uptime Monitor](https://uptime-monitor.org) - Self-hosted, enterprise-grade uptime monitoring and alerting system (TS)
|
||||
- [send_to_ntfy_extension](https://github.com/TheDuffman85/send_to_ntfy_extension/) ⭐ - A browser extension to send the notifications to ntfy (JS)
|
||||
- [SIA-Server](https://github.com/ZebMcKayhan/SIA-Server) - A light weight, self-hosted notification Server for Honywell Galaxy Flex alarm systems (Python)
|
||||
- [zabbix-ntfy](https://github.com/torgrimt/zabbix-ntfy) - Zabbix server Mediatype to add support for ntfy.sh services
|
||||
|
||||
## Blog + forum posts
|
||||
|
||||
@@ -304,7 +306,7 @@ ntfy community. Thanks to everyone running a public server. **You guys rock!**
|
||||
| URL | Country |
|
||||
|---------------------------------------------------|--------------------|
|
||||
| [ntfy.sh](https://ntfy.sh/) (*Official*) | 🇺🇸 United States |
|
||||
| [ntfy.tedomum.net](https://ntfy.tedomum.net/) | 🇫🇷 France |
|
||||
| [ntfy.tedomum.fr](https://ntfy.tedomum.fr/) | 🇫🇷 France |
|
||||
| [ntfy.jae.fi](https://ntfy.jae.fi/) | 🇫🇮 Finland |
|
||||
| [ntfy.adminforge.de](https://ntfy.adminforge.de/) | 🇩🇪 Germany |
|
||||
| [ntfy.envs.net](https://ntfy.envs.net) | 🇩🇪 Germany |
|
||||
|
||||
@@ -2304,6 +2304,11 @@ _Supported on:_ :material-android: :material-firefox:
|
||||
The `copy` action **copies a given value to the clipboard when the action button is tapped**. This is useful for
|
||||
one-time passcodes, tokens, or any other value you want to quickly copy without opening the full notification.
|
||||
|
||||
!!! info
|
||||
The copy action button is only shown in the web app and Android app notification list, **not** in browser desktop
|
||||
notifications. This is because browsers do not allow clipboard access from notification actions without direct
|
||||
user interaction with the page.
|
||||
|
||||
Here's an example using the [`X-Actions` header](#using-a-header):
|
||||
|
||||
=== "Command line (curl)"
|
||||
|
||||
174
docs/releases.md
174
docs/releases.md
@@ -6,12 +6,145 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
|
||||
|
||||
| Component | Version | Release date |
|
||||
|------------------|---------|--------------|
|
||||
| ntfy server | v2.16.0 | Jan 19, 2026 |
|
||||
| ntfy Android app | v1.22.2 | Jan 25, 2026 |
|
||||
| ntfy server | v2.19.2 | Mar 16, 2026 |
|
||||
| ntfy Android app | v1.24.0 | Mar 5, 2026 |
|
||||
| ntfy iOS app | v1.3 | Nov 26, 2023 |
|
||||
|
||||
Please check out the release notes for [upcoming releases](#not-released-yet) below.
|
||||
|
||||
### ntfy server v2.19.2
|
||||
Released March 16, 2026
|
||||
|
||||
This is another small bugfix release for PostgreSQL, avoiding races between primary and read replica, as well as to
|
||||
further reduce primary load.
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Fix race condition in web push subscription causing FK constraint violation when concurrent requests hit the same endpoint
|
||||
* Route authorization query to read-only database replica to reduce primary database load
|
||||
|
||||
## ntfy server v2.19.1
|
||||
Released March 15, 2026
|
||||
|
||||
This is a bugfix release to avoid PostgreSQL insert failures due to invalid UTF-8 messages. It also fixes `database-url`
|
||||
validation incorrectly rejecting `postgresql://` connection strings.
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Fix invalid UTF-8 in HTTP headers (e.g. Latin-1 encoded text) causing PostgreSQL insert failures and dropping entire message batches
|
||||
* Fix `database-url` validation rejecting `postgresql://` connection strings ([#1657](https://github.com/binwiederhier/ntfy/issues/1657)/[#1658](https://github.com/binwiederhier/ntfy/pull/1658))
|
||||
|
||||
## ntfy server v2.19.0
|
||||
Released March 15, 2026
|
||||
|
||||
This is a fast-follow release that enables Postgres read replica support.
|
||||
|
||||
To offload read-heavy queries from the primary database, you can optionally configure one or more read replicas
|
||||
using the `database-replica-urls` option. When configured, non-critical read-only queries (e.g. fetching messages,
|
||||
checking access permissions, etc) are distributed across the replicas using round-robin, while all writes and
|
||||
correctness-critical reads continue to go to the primary. If a replica becomes unhealthy, ntfy automatically falls back
|
||||
to the primary until the replica recovers.
|
||||
|
||||
**Features:**
|
||||
|
||||
* Support [PostgreSQL read replicas](config.md#postgresql-experimental) for offloading non-critical read queries via `database-replica-urls` config option ([#1648](https://github.com/binwiederhier/ntfy/pull/1648))
|
||||
* Add interactive [config generator](config.md#config-generator) to the documentation to help create server configuration files ([#1654](https://github.com/binwiederhier/ntfy/pull/1654))
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Web: Throttle notification sound in web app to play at most once every 2 seconds (similar to [#1550](https://github.com/binwiederhier/ntfy/issues/1550), thanks to [@jlaffaye](https://github.com/jlaffaye) for reporting)
|
||||
* Web: Add hover tooltips to icon buttons in web app account and preferences pages ([#1565](https://github.com/binwiederhier/ntfy/issues/1565), thanks to [@jermanuts](https://github.com/jermanuts) for reporting)
|
||||
|
||||
## ntfy server v2.18.0
|
||||
Released March 7, 2026
|
||||
|
||||
This is the biggest release I've ever done on the server. It's 14,997 added lines of code, and 10,202 lines removed, all from
|
||||
one [pull request](https://github.com/binwiederhier/ntfy/pull/1619) that adds [PostgreSQL support](config.md#postgresql-experimental).
|
||||
|
||||
The code was written by Cursor and Claude, but reviewed and heavily tested over 2-3 weeks by me. I created comparison documents,
|
||||
went through all queries multiple times and reviewed the logic over and over again. I also did load tests and manual regression tests,
|
||||
which took lots of evenings.
|
||||
|
||||
I'll not instantly switch ntfy.sh over. Instead, I'm kindly asking the community to test the Postgres support and report back to me
|
||||
if things are working (or not working). There is a [one-off migration tool](https://github.com/binwiederhier/ntfy/tree/main/tools/pgimport) (entirely written by AI) that you can use to migrate.
|
||||
|
||||
**Features:**
|
||||
|
||||
* Add experimental [PostgreSQL support](config.md#postgresql-experimental) as an alternative database backend (message cache, user manager, web push subscriptions) via `database-url` config option ([#1114](https://github.com/binwiederhier/ntfy/issues/1114)/[#1619](https://github.com/binwiederhier/ntfy/pull/1619), thanks to [@brettinternet](https://github.com/brettinternet) for reporting)
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Preserve `<br>` line breaks in HTML-only emails received via SMTP ([#690](https://github.com/binwiederhier/ntfy/issues/690), [#1620](https://github.com/binwiederhier/ntfy/pull/1620), thanks to [@uzkikh](https://github.com/uzkikh) for the fix and to [@teastrainer](https://github.com/teastrainer) for reporting)
|
||||
|
||||
## ntfy Android v1.24.0
|
||||
Released March 5, 2026
|
||||
|
||||
This is a tiny release that will revert the "reconnecting ..." behavior of the foreground notification. Lots of people
|
||||
have complained about it, so I'm replacing it with a notification that shows up when the server connection has failed
|
||||
for >15 minutes, hoping that people will be less annoyed by that.
|
||||
|
||||
**Features:**
|
||||
|
||||
* Show notification when connection to server has been lost for 15+ minutes, with dismiss, snooze and never-show-again actions
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Fix crash in settings when fragment is detached during backup/restore or log operations
|
||||
|
||||
## ntfy Android v1.23.0
|
||||
Released February 22, 2026
|
||||
|
||||
This release adds support for search within a topic, and adds [copy action](publish.md#copy-to-clipboard) support
|
||||
to the Android app.
|
||||
|
||||
**Features:**
|
||||
|
||||
* Search within a topic ([#141](https://github.com/binwiederhier/ntfy/issues/141), [ntfy-android#153](https://github.com/binwiederhier/ntfy-android/pull/153), thanks to [@Copephobia](https://github.com/Copephobia) and [@StoyanYonkov](https://github.com/StoyanYonkov) for reporting and sponsoring)
|
||||
* Add "reconnecting to N topics ..." to foreground notification ([#1101](https://github.com/binwiederhier/ntfy/issues/1101), thanks to [@milosivanovic](https://github.com/milosivanovic) for reporting)
|
||||
* Improved default server dialog with full-screen UI and stricter URL validation ([#1582](https://github.com/binwiederhier/ntfy/issues/1582))
|
||||
* Show last notification time for UnifiedPush subscriptions ([#1230](https://github.com/binwiederhier/ntfy/issues/1230), [#1454](https://github.com/binwiederhier/ntfy/issues/1454), thanks to [@Tealk](https://github.com/Tealk) and [@user4andre](https://github.com/user4andre) for reporting)
|
||||
* Support "copy" action button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Fix `clear=true` on action buttons not marking notification as read ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
|
||||
* Fix crash when default server URL is missing scheme by auto-prepending `https://` ([#1582](https://github.com/binwiederhier/ntfy/issues/1582), thanks to [@hard-zero1](https://github.com/hard-zero1))
|
||||
* Fix notification timestamp to use original send time instead of receive time ([#1112](https://github.com/binwiederhier/ntfy/issues/1112), thanks to [@voruti](https://github.com/voruti) for reporting)
|
||||
* Fix notifications being missed after service restart by using persisted lastNotificationId ([#1591](https://github.com/binwiederhier/ntfy/issues/1591), thanks to @Epifeny for reporting)
|
||||
|
||||
## ntfy server v2.17.0
|
||||
Released February 8, 2026
|
||||
|
||||
This release adds support for templating in the priority field, a new "copy" action button to copy values to the clipboard,
|
||||
a red notification dot on the favicon for unread messages, and an admin-only version endpoint. It also includes several
|
||||
crash fixes, web app improvements, and documentation updates.
|
||||
|
||||
❤️ If you like ntfy, **please consider sponsoring me** via [GitHub Sponsors](https://github.com/sponsors/binwiederhier), [Liberapay](https://en.liberapay.com/ntfy/), Bitcoin (`1626wjrw3uWk9adyjCfYwafw4sQWujyjn8`),
|
||||
or by buying a [paid plan via the web app](https://ntfy.sh/app). ntfy will always remain open source.
|
||||
|
||||
**Features:**
|
||||
|
||||
* Server: Support templating in the priority field ([#1426](https://github.com/binwiederhier/ntfy/issues/1426), thanks to [@seantomburke](https://github.com/seantomburke) for reporting)
|
||||
* Server: Add admin-only `GET /v1/version` endpoint returning server version, build commit, and date ([#1599](https://github.com/binwiederhier/ntfy/issues/1599), thanks to [@crivchri](https://github.com/crivchri) for reporting)
|
||||
* Server/Web: [Support "copy" action](publish.md#copy-to-clipboard) button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
|
||||
* Web: Show red notification dot on favicon when there are unread messages ([#1017](https://github.com/binwiederhier/ntfy/issues/1017), thanks to [@ad-si](https://github.com/ad-si) for reporting)
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Server: Fix crash when commit string is shorter than 7 characters in non-GitHub-Action builds ([#1493](https://github.com/binwiederhier/ntfy/issues/1493), thanks to [@cyrinux](https://github.com/cyrinux) for reporting)
|
||||
* Server: Fix server crash (nil pointer panic) when subscriber disconnects during publish ([#1598](https://github.com/binwiederhier/ntfy/pull/1598))
|
||||
* Server: Fix log spam from `http: response.WriteHeader on hijacked connection` for WebSocket errors ([#1362](https://github.com/binwiederhier/ntfy/issues/1362), thanks to [@bonfiresh](https://github.com/bonfiresh) for reporting)
|
||||
* Server: Use `slices.Contains` from stdlib to simplify code ([#1406](https://github.com/binwiederhier/ntfy/pull/1406), thanks to [@tanhuaan](https://github.com/tanhuaan))
|
||||
* Web: Fix `clear=true` on action buttons not clearing the notification ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
|
||||
* Web: Fix Markdown message line height to match plain text (1.5 instead of 1.2) ([#1139](https://github.com/binwiederhier/ntfy/issues/1139), thanks to [@etfz](https://github.com/etfz) for reporting)
|
||||
* Web: Fix long lines (e.g. JSON) being truncated by adding horizontal scroll ([#1363](https://github.com/binwiederhier/ntfy/issues/1363), thanks to [@v3DJG6GL](https://github.com/v3DJG6GL) for reporting)
|
||||
* Web: Fix Windows notification icon being cut off ([#884](https://github.com/binwiederhier/ntfy/issues/884), thanks to [@ZhangTianrong](https://github.com/ZhangTianrong) for reporting)
|
||||
* Web: Use full URL in curl example on empty topic pages ([#1435](https://github.com/binwiederhier/ntfy/issues/1435), [#1535](https://github.com/binwiederhier/ntfy/pull/1535), thanks to [@elmatadoor](https://github.com/elmatadoor) for reporting and [@jjasghar](https://github.com/jjasghar) for the PR)
|
||||
* Web: Add validation feedback for service URL when adding user ([#1566](https://github.com/binwiederhier/ntfy/issues/1566), thanks to [@jermanuts](https://github.com/jermanuts))
|
||||
* Docs: Remove obsolete `version` field from docker-compose examples ([#1333](https://github.com/binwiederhier/ntfy/issues/1333), thanks to [@seals187](https://github.com/seals187) for reporting and [@cyb3rko](https://github.com/cyb3rko) for fixing)
|
||||
* Docs: Fix Kustomize config in installation docs ([#1367](https://github.com/binwiederhier/ntfy/issues/1367), thanks to [@toby-griffiths](https://github.com/toby-griffiths))
|
||||
* Docs: Use SVG F-Droid badge and add app store badges to README ([#1170](https://github.com/binwiederhier/ntfy/issues/1170), thanks to [@PanderMusubi](https://github.com/PanderMusubi) for reporting)
|
||||
|
||||
## ntfy Android app v1.22.2
|
||||
Released January 20, 2026
|
||||
|
||||
@@ -1665,43 +1798,12 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
|
||||
|
||||
## Not released yet
|
||||
|
||||
### ntfy Android v1.23.x (UNRELEASED)
|
||||
### ntfy server v2.20.x (UNRELEASED)
|
||||
|
||||
**Features:**
|
||||
|
||||
* Search within a topic ([#141](https://github.com/binwiederhier/ntfy/issues/141), [ntfy-android#153](https://github.com/binwiederhier/ntfy-android/pull/153), thanks to [@Copephobia](https://github.com/Copephobia) and [@StoyanYonkov](https://github.com/StoyanYonkov) for reporting and sponsoring)
|
||||
* Add "reconnecting to N topics ..." to foreground notification ([#1101](https://github.com/binwiederhier/ntfy/issues/1101), thanks to [@milosivanovic](https://github.com/milosivanovic) for reporting)
|
||||
* Improved default server dialog with full-screen UI and stricter URL validation ([#1582](https://github.com/binwiederhier/ntfy/issues/1582))
|
||||
* Show last notification time for UnifiedPush subscriptions ([#1230](https://github.com/binwiederhier/ntfy/issues/1230), [#1454](https://github.com/binwiederhier/ntfy/issues/1454), thanks to [@Tealk](https://github.com/Tealk) and [@user4andre](https://github.com/user4andre) for reporting)
|
||||
* Support "copy" action button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
|
||||
* Add S3-compatible object storage as an alternative attachment backend via `attachment-s3-url` config option
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Fix `clear=true` on action buttons not marking notification as read ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
|
||||
* Fix crash when default server URL is missing scheme by auto-prepending `https://` ([#1582](https://github.com/binwiederhier/ntfy/issues/1582), thanks to [@hard-zero1](https://github.com/hard-zero1))
|
||||
* Fix notification timestamp to use original send time instead of receive time ([#1112](https://github.com/binwiederhier/ntfy/issues/1112), thanks to [@voruti](https://github.com/voruti) for reporting)
|
||||
* Fix notifications being missed after service restart by using persisted lastNotificationId ([#1591](https://github.com/binwiederhier/ntfy/issues/1591), thanks to @Epifeny for reporting)
|
||||
|
||||
### ntfy server v2.17.x (UNRELEASED)
|
||||
|
||||
**Features:**
|
||||
|
||||
* Server: Support templating in the priority field ([#1426](https://github.com/binwiederhier/ntfy/issues/1426), thanks to [@seantomburke](https://github.com/seantomburke) for reporting)
|
||||
* Server/Web: Support "copy" action button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
|
||||
* Web: Show red notification dot on favicon when there are unread messages ([#1017](https://github.com/binwiederhier/ntfy/issues/1017), thanks to [@ad-si](https://github.com/ad-si) for reporting)
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Server: Fix crash when commit string is shorter than 7 characters in non-GitHub-Action builds ([#1493](https://github.com/binwiederhier/ntfy/issues/1493), thanks to [@cyrinux](https://github.com/cyrinux) for reporting)
|
||||
* Server: Fix server crash (nil pointer panic) when subscriber disconnects during publish ([#1598](https://github.com/binwiederhier/ntfy/pull/1598))
|
||||
* Server: Fix log spam from `http: response.WriteHeader on hijacked connection` for WebSocket errors ([#1362](https://github.com/binwiederhier/ntfy/issues/1362), thanks to [@bonfiresh](https://github.com/bonfiresh) for reporting)
|
||||
* Server: Use `slices.Contains` from stdlib to simplify code ([#1406](https://github.com/binwiederhier/ntfy/pull/1406), thanks to [@tanhuaan](https://github.com/tanhuaan))
|
||||
* Web: Fix `clear=true` on action buttons not clearing the notification ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
|
||||
* Web: Fix Markdown message line height to match plain text (1.5 instead of 1.2) ([#1139](https://github.com/binwiederhier/ntfy/issues/1139), thanks to [@etfz](https://github.com/etfz) for reporting)
|
||||
* Web: Fix long lines (e.g. JSON) being truncated by adding horizontal scroll ([#1363](https://github.com/binwiederhier/ntfy/issues/1363), thanks to [@v3DJG6GL](https://github.com/v3DJG6GL) for reporting)
|
||||
* Web: Fix Windows notification icon being cut off ([#884](https://github.com/binwiederhier/ntfy/issues/884), thanks to [@ZhangTianrong](https://github.com/ZhangTianrong) for reporting)
|
||||
* Web: Use full URL in curl example on empty topic pages ([#1435](https://github.com/binwiederhier/ntfy/issues/1435), [#1535](https://github.com/binwiederhier/ntfy/pull/1535), thanks to [@elmatadoor](https://github.com/elmatadoor) for reporting and [@jjasghar](https://github.com/jjasghar) for the PR)
|
||||
* Web: Add validation feedback for service URL when adding user ([#1566](https://github.com/binwiederhier/ntfy/issues/1566), thanks to [@jermanuts](https://github.com/jermanuts))
|
||||
* Docs: Remove obsolete `version` field from docker-compose examples ([#1333](https://github.com/binwiederhier/ntfy/issues/1333), thanks to [@seals187](https://github.com/seals187) for reporting and [@cyb3rko](https://github.com/cyb3rko) for fixing)
|
||||
* Docs: Fix Kustomize config in installation docs ([#1367](https://github.com/binwiederhier/ntfy/issues/1367), thanks to [@toby-griffiths](https://github.com/toby-griffiths))
|
||||
* Docs: Use SVG F-Droid badge and add app store badges to README ([#1170](https://github.com/binwiederhier/ntfy/issues/1170), thanks to [@PanderMusubi](https://github.com/PanderMusubi) for reporting)
|
||||
* Reject invalid e-mail addresses (e.g. multiple comma-separated recipients) with HTTP 400
|
||||
|
||||
853
docs/static/css/config-generator.css
vendored
Normal file
853
docs/static/css/config-generator.css
vendored
Normal file
@@ -0,0 +1,853 @@
|
||||
/* Config Generator */
|
||||
|
||||
/* Hidden utility */
|
||||
.cg-hidden {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
/* Open button */
|
||||
.cg-open-btn {
|
||||
display: inline-block;
|
||||
padding: 8px 20px;
|
||||
background: var(--md-primary-fg-color);
|
||||
color: #fff;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
font-family: inherit;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
|
||||
.cg-open-btn:hover {
|
||||
opacity: 0.85;
|
||||
}
|
||||
|
||||
/* Modal overlay */
|
||||
.cg-modal {
|
||||
position: fixed;
|
||||
inset: 0;
|
||||
z-index: 1000;
|
||||
}
|
||||
|
||||
.cg-modal-backdrop {
|
||||
position: absolute;
|
||||
inset: 0;
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
}
|
||||
|
||||
.cg-modal-dialog {
|
||||
position: absolute;
|
||||
inset: 24px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
background: #fff;
|
||||
border-radius: 10px;
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.25);
|
||||
overflow: hidden;
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.cg-modal-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 12px 20px;
|
||||
border-bottom: 1px solid #ddd;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.cg-modal-header-left {
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
gap: 12px;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.cg-modal-title {
|
||||
font-weight: 600;
|
||||
font-size: 0.95rem;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.cg-badge-beta {
|
||||
display: inline-block;
|
||||
padding: 1px 8px;
|
||||
margin-left: 8px;
|
||||
background: var(--md-primary-fg-color);
|
||||
color: #fff;
|
||||
font-size: 0.6rem;
|
||||
font-weight: 600;
|
||||
border-radius: 10px;
|
||||
letter-spacing: 0.5px;
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.cg-modal-desc {
|
||||
font-size: 0.75rem;
|
||||
color: #888;
|
||||
}
|
||||
|
||||
.cg-modal-header-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.cg-modal-reset {
|
||||
background: none;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 4px;
|
||||
font-size: 0.75rem;
|
||||
color: #777;
|
||||
cursor: pointer;
|
||||
padding: 4px 12px;
|
||||
font-family: inherit;
|
||||
transition: color 0.15s, border-color 0.15s;
|
||||
}
|
||||
|
||||
.cg-modal-reset:hover {
|
||||
color: #333;
|
||||
border-color: #999;
|
||||
}
|
||||
|
||||
.cg-modal-close {
|
||||
background: none;
|
||||
border: none;
|
||||
font-size: 1.4rem;
|
||||
color: #999;
|
||||
cursor: pointer;
|
||||
padding: 0 4px;
|
||||
line-height: 1;
|
||||
}
|
||||
|
||||
.cg-modal-close:hover {
|
||||
color: #333;
|
||||
}
|
||||
|
||||
/* Modal body: left + right */
|
||||
.cg-modal-body {
|
||||
display: flex;
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
/* Left panel */
|
||||
#cg-left {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
border-right: 1px solid #ddd;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.cg-nav {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 0;
|
||||
border-bottom: 1px solid #ddd;
|
||||
flex-shrink: 0;
|
||||
padding: 0 16px;
|
||||
}
|
||||
|
||||
.cg-nav-tab {
|
||||
padding: 9px 14px;
|
||||
cursor: pointer;
|
||||
font-size: 0.78rem;
|
||||
font-weight: 500;
|
||||
color: #777;
|
||||
border-bottom: 2px solid transparent;
|
||||
margin-bottom: -1px;
|
||||
user-select: none;
|
||||
transition: color 0.15s, border-color 0.15s;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.cg-nav-tab:hover {
|
||||
color: #444;
|
||||
}
|
||||
|
||||
.cg-nav-tab.active {
|
||||
color: var(--md-primary-fg-color);
|
||||
border-bottom-color: var(--md-primary-fg-color);
|
||||
}
|
||||
|
||||
.cg-panels {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 16px 20px;
|
||||
}
|
||||
|
||||
.cg-panel {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.cg-panel.active {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.cg-panel-desc {
|
||||
font-size: 0.75rem;
|
||||
color: #888;
|
||||
margin-bottom: 12px;
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
.cg-panel-desc a {
|
||||
color: var(--md-primary-fg-color);
|
||||
}
|
||||
|
||||
.cg-help {
|
||||
color: var(--md-primary-fg-color);
|
||||
text-decoration: none;
|
||||
margin-left: 4px;
|
||||
vertical-align: middle;
|
||||
flex-shrink: 0;
|
||||
transition: color 0.15s;
|
||||
}
|
||||
|
||||
.cg-help:hover {
|
||||
color: var(--md-primary-fg-color--dark, #2a6e5f);
|
||||
}
|
||||
|
||||
/* Right panel */
|
||||
#cg-right {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.cg-output-tabs {
|
||||
display: flex;
|
||||
border-bottom: 1px solid #ddd;
|
||||
flex-shrink: 0;
|
||||
padding: 0 16px;
|
||||
}
|
||||
|
||||
.cg-output-tab {
|
||||
padding: 9px 14px;
|
||||
cursor: pointer;
|
||||
font-size: 0.78rem;
|
||||
font-weight: 500;
|
||||
border-bottom: 2px solid transparent;
|
||||
margin-bottom: -1px;
|
||||
color: #777;
|
||||
transition: color 0.15s, border-color 0.15s;
|
||||
user-select: none;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.cg-output-tab:hover {
|
||||
color: #444;
|
||||
}
|
||||
|
||||
.cg-output-tab.active {
|
||||
color: var(--md-primary-fg-color);
|
||||
border-bottom-color: var(--md-primary-fg-color);
|
||||
}
|
||||
|
||||
.cg-btn-copy {
|
||||
margin-left: auto;
|
||||
background: none;
|
||||
color: #777;
|
||||
border: none;
|
||||
border-bottom: 2px solid transparent;
|
||||
margin-bottom: -1px;
|
||||
padding: 9px 10px;
|
||||
cursor: pointer;
|
||||
line-height: 1;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: color 0.15s;
|
||||
}
|
||||
|
||||
.cg-btn-copy:hover {
|
||||
color: #333;
|
||||
}
|
||||
|
||||
.cg-output-wrap {
|
||||
flex: 1;
|
||||
overflow: auto;
|
||||
padding: 16px 20px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.cg-output-wrap pre {
|
||||
margin: 0;
|
||||
padding: 8px 10px;
|
||||
background: #f5f5f5;
|
||||
color: #333;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 6px;
|
||||
overflow-x: auto;
|
||||
font-size: 0.76rem;
|
||||
line-height: 1.5;
|
||||
flex: 1;
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
.cg-empty-msg {
|
||||
color: #888;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.cg-warning {
|
||||
padding: 6px 10px;
|
||||
margin-top: 8px;
|
||||
background: #fff3cd;
|
||||
color: #856404;
|
||||
border: 1px solid #ffc107;
|
||||
border-radius: 4px;
|
||||
font-size: 0.76rem;
|
||||
}
|
||||
|
||||
/* Form fields */
|
||||
.cg-field {
|
||||
margin-bottom: 0;
|
||||
padding: 8px 12px;
|
||||
}
|
||||
|
||||
.cg-field:nth-child(odd) {
|
||||
background: #f8f8f8;
|
||||
}
|
||||
|
||||
.cg-field:nth-child(even) {
|
||||
background: #fff;
|
||||
}
|
||||
|
||||
.cg-field > label {
|
||||
display: block;
|
||||
font-weight: 500;
|
||||
margin-bottom: 4px;
|
||||
font-size: 0.78rem;
|
||||
color: #555;
|
||||
}
|
||||
|
||||
.cg-field input[type="text"],
|
||||
.cg-field input[type="password"],
|
||||
.cg-field select {
|
||||
width: 100%;
|
||||
padding: 6px 8px;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 4px;
|
||||
font-size: 0.78rem;
|
||||
font-family: inherit;
|
||||
box-sizing: border-box;
|
||||
background: #fff;
|
||||
}
|
||||
|
||||
.cg-field input[type="text"]:focus,
|
||||
.cg-field input[type="password"]:focus,
|
||||
.cg-field select:focus {
|
||||
border-color: var(--md-primary-fg-color);
|
||||
outline: none;
|
||||
box-shadow: 0 0 0 2px rgba(51, 133, 116, 0.15);
|
||||
}
|
||||
|
||||
.cg-checkbox {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.cg-checkbox input[type="checkbox"] {
|
||||
accent-color: var(--md-primary-fg-color);
|
||||
}
|
||||
|
||||
.cg-checkbox label {
|
||||
font-weight: 500;
|
||||
font-size: 0.78rem;
|
||||
margin: 0;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.cg-radio-group {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.cg-radio-group label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
font-weight: 400;
|
||||
font-size: 0.78rem;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.cg-radio-group input[type="radio"] {
|
||||
accent-color: var(--md-primary-fg-color);
|
||||
}
|
||||
|
||||
/* Inline field: label + control side by side */
|
||||
.cg-inline-field {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.cg-inline-field > label {
|
||||
margin-bottom: 0;
|
||||
width: 60%;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.cg-panel:not(#cg-panel-general) .cg-inline-field > label {
|
||||
width: 50%;
|
||||
}
|
||||
|
||||
.cg-inline-field > input[type="text"],
|
||||
.cg-inline-field > select {
|
||||
padding: 4px 10px;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 4px;
|
||||
font-size: 0.75rem;
|
||||
font-family: inherit;
|
||||
box-sizing: border-box;
|
||||
line-height: 1.4;
|
||||
background: #fff;
|
||||
}
|
||||
|
||||
.cg-inline-field > input[type="text"]:focus,
|
||||
.cg-inline-field > select:focus {
|
||||
border-color: var(--md-primary-fg-color);
|
||||
outline: none;
|
||||
box-shadow: 0 0 0 2px rgba(51, 133, 116, 0.15);
|
||||
}
|
||||
|
||||
#cg-email-in-section {
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.cg-pg-label {
|
||||
font-size: 0.75rem;
|
||||
color: #888;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
/* Button group toggle */
|
||||
.cg-btn-group {
|
||||
display: flex;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.cg-btn-group label {
|
||||
cursor: pointer;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.cg-btn-group input[type="radio"] {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.cg-btn-group span {
|
||||
display: block;
|
||||
padding: 4px 14px;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 500;
|
||||
line-height: 1.4;
|
||||
border: 1px solid #ccc;
|
||||
color: #555;
|
||||
background: #fff;
|
||||
transition: background 0.15s, color 0.15s, border-color 0.15s;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.cg-btn-group label:first-child span {
|
||||
border-radius: 4px 0 0 4px;
|
||||
}
|
||||
|
||||
.cg-btn-group label:last-child span {
|
||||
border-radius: 0 4px 4px 0;
|
||||
}
|
||||
|
||||
.cg-btn-group label + label span {
|
||||
margin-left: -1px;
|
||||
}
|
||||
|
||||
.cg-btn-group input[type="radio"]:checked + span {
|
||||
background: var(--md-primary-fg-color);
|
||||
color: #fff;
|
||||
border-color: var(--md-primary-fg-color);
|
||||
position: relative;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.cg-feature-grid {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 5px;
|
||||
}
|
||||
|
||||
.cg-feature-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.cg-feature-grid label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
font-size: 0.78rem;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.cg-feature-grid input[type="checkbox"] {
|
||||
accent-color: var(--md-primary-fg-color);
|
||||
}
|
||||
|
||||
.cg-btn-configure {
|
||||
background: none;
|
||||
border: 1px solid var(--md-primary-fg-color);
|
||||
border-radius: 10px;
|
||||
color: var(--md-primary-fg-color);
|
||||
font-size: 0.68rem;
|
||||
font-family: inherit;
|
||||
padding: 1px 10px;
|
||||
cursor: pointer;
|
||||
white-space: nowrap;
|
||||
transition: background 0.15s, color 0.15s;
|
||||
}
|
||||
|
||||
.cg-btn-configure:hover {
|
||||
background: var(--md-primary-fg-color);
|
||||
color: #fff;
|
||||
}
|
||||
|
||||
/* Repeatable rows */
|
||||
.cg-repeatable-row {
|
||||
display: flex;
|
||||
gap: 6px;
|
||||
align-items: center;
|
||||
margin-bottom: 6px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.cg-repeatable-row input,
|
||||
.cg-repeatable-row select {
|
||||
flex: 1;
|
||||
min-width: 80px;
|
||||
padding: 5px 6px;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 4px;
|
||||
font-size: 0.75rem;
|
||||
font-family: inherit;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.cg-repeatable-row input:focus,
|
||||
.cg-repeatable-row select:focus {
|
||||
border-color: var(--md-primary-fg-color);
|
||||
outline: none;
|
||||
}
|
||||
|
||||
.cg-repeatable-row input:disabled,
|
||||
.cg-repeatable-row select:disabled {
|
||||
background: #eee;
|
||||
color: #999;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.cg-btn-remove {
|
||||
background: none;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 0.9rem;
|
||||
padding: 5px 8px;
|
||||
color: #999;
|
||||
line-height: 1;
|
||||
}
|
||||
|
||||
.cg-btn-remove:hover {
|
||||
background: #fee;
|
||||
border-color: #c66;
|
||||
color: #c33;
|
||||
}
|
||||
|
||||
.cg-btn-add {
|
||||
background: none;
|
||||
border: 1px dashed #bbb;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
padding: 5px 10px;
|
||||
font-size: 0.75rem;
|
||||
color: #777;
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.cg-btn-add:hover {
|
||||
border-color: var(--md-primary-fg-color);
|
||||
color: var(--md-primary-fg-color);
|
||||
}
|
||||
|
||||
/* Dark mode */
|
||||
body[data-md-color-scheme="slate"] .cg-modal-dialog {
|
||||
background: #1e1e2e;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-modal-header {
|
||||
border-bottom-color: #444;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-modal-title {
|
||||
color: #ddd;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-modal-desc {
|
||||
color: #777;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-modal-reset {
|
||||
border-color: #555;
|
||||
color: #888;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-modal-reset:hover {
|
||||
border-color: #888;
|
||||
color: #ddd;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-modal-close {
|
||||
color: #777;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-modal-close:hover {
|
||||
color: #ddd;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] #cg-left {
|
||||
border-right-color: #444;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-nav {
|
||||
border-bottom-color: #444;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-nav-tab {
|
||||
color: #888;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-nav-tab:hover {
|
||||
color: #bbb;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-output-tabs {
|
||||
border-bottom-color: #444;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-output-tab {
|
||||
color: #888;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-output-tab:hover {
|
||||
color: #bbb;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-btn-copy {
|
||||
color: #777;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-btn-copy:hover {
|
||||
color: #bbb;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-output-wrap pre {
|
||||
background: #161620;
|
||||
color: #ddd;
|
||||
border-color: #444;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-field:nth-child(odd) {
|
||||
background: #232334;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-field:nth-child(even) {
|
||||
background: #1e1e2e;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-panel-desc {
|
||||
color: #777;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-field > label {
|
||||
color: #aaa;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-field input[type="text"],
|
||||
body[data-md-color-scheme="slate"] .cg-field input[type="password"],
|
||||
body[data-md-color-scheme="slate"] .cg-field select {
|
||||
background: #2a2a3a;
|
||||
border-color: #555;
|
||||
color: #ddd;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-btn-group span {
|
||||
background: #2a2a3a;
|
||||
border-color: #555;
|
||||
color: #aaa;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-btn-group input[type="radio"]:checked + span {
|
||||
background: var(--md-primary-fg-color);
|
||||
color: #fff;
|
||||
border-color: var(--md-primary-fg-color);
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-checkbox label {
|
||||
color: #ccc;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-radio-group label {
|
||||
color: #ccc;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-feature-grid label {
|
||||
color: #ccc;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-repeatable-row input,
|
||||
body[data-md-color-scheme="slate"] .cg-repeatable-row select {
|
||||
background: #2a2a3a;
|
||||
border-color: #555;
|
||||
color: #ddd;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-repeatable-row input:disabled,
|
||||
body[data-md-color-scheme="slate"] .cg-repeatable-row select:disabled {
|
||||
background: #1a1a28;
|
||||
color: #666;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-btn-remove {
|
||||
border-color: #555;
|
||||
color: #888;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-btn-add {
|
||||
border-color: #555;
|
||||
color: #888;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-warning {
|
||||
background: #3a2e00;
|
||||
color: #ffc107;
|
||||
border-color: #665200;
|
||||
}
|
||||
|
||||
/* Mobile toggle bar (hidden on desktop) */
|
||||
.cg-mobile-toggle {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* Responsive */
|
||||
@media (max-width: 900px) {
|
||||
.cg-modal-dialog {
|
||||
inset: 0;
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
.cg-modal-header {
|
||||
padding: 8px 16px;
|
||||
}
|
||||
|
||||
.cg-modal-title {
|
||||
font-size: 0.85rem;
|
||||
}
|
||||
|
||||
.cg-modal-desc {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.cg-modal-body {
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.cg-mobile-toggle {
|
||||
display: flex;
|
||||
flex-shrink: 0;
|
||||
border-bottom: 1px solid #ddd;
|
||||
}
|
||||
|
||||
.cg-mobile-toggle-btn {
|
||||
flex: 1;
|
||||
padding: 8px 0;
|
||||
border: none;
|
||||
background: #f5f5f5;
|
||||
font-size: 0.78rem;
|
||||
font-weight: 500;
|
||||
font-family: inherit;
|
||||
color: #777;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s, color 0.15s;
|
||||
}
|
||||
|
||||
.cg-mobile-toggle-btn.active {
|
||||
background: #fff;
|
||||
color: var(--md-primary-fg-color);
|
||||
box-shadow: inset 0 -2px 0 var(--md-primary-fg-color);
|
||||
}
|
||||
|
||||
#cg-left {
|
||||
border-right: none;
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
#cg-right {
|
||||
flex: 1;
|
||||
display: none;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
#cg-right.cg-mobile-active {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
#cg-left.cg-mobile-hidden {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.cg-nav {
|
||||
overflow-x: auto;
|
||||
flex-wrap: nowrap;
|
||||
-webkit-overflow-scrolling: touch;
|
||||
}
|
||||
|
||||
.cg-inline-field {
|
||||
flex-direction: column;
|
||||
align-items: stretch;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.cg-inline-field > label {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.cg-panel:not(#cg-panel-general) .cg-inline-field > label {
|
||||
width: 100%;
|
||||
}
|
||||
}
|
||||
|
||||
/* Dark mode mobile toggle */
|
||||
body[data-md-color-scheme="slate"] .cg-mobile-toggle {
|
||||
border-bottom-color: #444;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-mobile-toggle-btn {
|
||||
background: #2a2a3a;
|
||||
color: #888;
|
||||
}
|
||||
|
||||
body[data-md-color-scheme="slate"] .cg-mobile-toggle-btn.active {
|
||||
background: #1e1e2e;
|
||||
color: var(--md-primary-fg-color);
|
||||
}
|
||||
BIN
docs/static/img/config-generator.png
vendored
Normal file
BIN
docs/static/img/config-generator.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 318 KiB |
1220
docs/static/js/bcrypt.js
vendored
Normal file
1220
docs/static/js/bcrypt.js
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1357
docs/static/js/config-generator.js
vendored
Normal file
1357
docs/static/js/config-generator.js
vendored
Normal file
File diff suppressed because it is too large
Load Diff
66
go.mod
66
go.mod
@@ -1,25 +1,25 @@
|
||||
module heckel.io/ntfy/v2
|
||||
|
||||
go 1.24.6
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
cloud.google.com/go/firestore v1.21.0 // indirect
|
||||
cloud.google.com/go/storage v1.59.2 // indirect
|
||||
cloud.google.com/go/storage v1.61.3 // indirect
|
||||
github.com/BurntSushi/toml v1.6.0 // indirect
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect
|
||||
github.com/emersion/go-smtp v0.18.0
|
||||
github.com/gabriel-vasile/mimetype v1.4.13
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/mattn/go-sqlite3 v1.14.33
|
||||
github.com/mattn/go-sqlite3 v1.14.34
|
||||
github.com/olebedev/when v1.1.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/urfave/cli/v2 v2.27.7
|
||||
golang.org/x/crypto v0.47.0
|
||||
golang.org/x/oauth2 v0.34.0 // indirect
|
||||
golang.org/x/sync v0.19.0
|
||||
golang.org/x/term v0.39.0
|
||||
golang.org/x/time v0.14.0
|
||||
google.golang.org/api v0.265.0
|
||||
golang.org/x/crypto v0.49.0
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/term v0.41.0
|
||||
golang.org/x/time v0.15.0
|
||||
google.golang.org/api v0.271.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
)
|
||||
|
||||
@@ -30,17 +30,18 @@ require github.com/pkg/errors v0.9.1 // indirect
|
||||
require (
|
||||
firebase.google.com/go/v4 v4.19.0
|
||||
github.com/SherClockHolmes/webpush-go v1.4.0
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/microcosm-cc/bluemonday v1.0.27
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/stripe/stripe-go/v74 v74.30.0
|
||||
golang.org/x/sys v0.40.0
|
||||
golang.org/x/text v0.33.0
|
||||
golang.org/x/sys v0.42.0
|
||||
golang.org/x/text v0.35.0
|
||||
)
|
||||
|
||||
require (
|
||||
cel.dev/expr v0.25.1 // indirect
|
||||
cloud.google.com/go v0.123.0 // indirect
|
||||
cloud.google.com/go/auth v0.18.1 // indirect
|
||||
cloud.google.com/go/auth v0.18.3-0.20260310051336-87cdcc9f7568 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
cloud.google.com/go/iam v1.5.3 // indirect
|
||||
@@ -57,8 +58,8 @@ require (
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6 // indirect
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.37.0 // indirect
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.3 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
@@ -68,35 +69,38 @@ require (
|
||||
github.com/golang/protobuf v1.5.4 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.17.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.18.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.67.5 // indirect
|
||||
github.com/prometheus/procfs v0.19.2 // indirect
|
||||
github.com/prometheus/procfs v0.20.1 // indirect
|
||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||
github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.40.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 // indirect
|
||||
go.opentelemetry.io/otel v1.40.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.40.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.40.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk/metric v1.40.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.40.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/net v0.49.0 // indirect
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.42.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect
|
||||
go.opentelemetry.io/otel v1.42.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.42.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.42.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.42.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.4 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
google.golang.org/appengine/v2 v2.0.6 // indirect
|
||||
google.golang.org/genproto v0.0.0-20260203192932-546029d2fa20 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260203192932-546029d2fa20 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20 // indirect
|
||||
google.golang.org/grpc v1.78.0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20260311181403-84a4fc48630c // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260311181403-84a4fc48630c // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260311181403-84a4fc48630c // indirect
|
||||
google.golang.org/grpc v1.79.2 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
141
go.sum
141
go.sum
@@ -2,8 +2,8 @@ cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4=
|
||||
cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs=
|
||||
cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA=
|
||||
cloud.google.com/go/auth v0.18.3-0.20260310051336-87cdcc9f7568 h1:PJt3KrySfZkKdcEV2wlyNkfAPbMZGjtnv5oLrT4tWPg=
|
||||
cloud.google.com/go/auth v0.18.3-0.20260310051336-87cdcc9f7568/go.mod h1:/Tt0rLCp4FHXEBtdyYqvIZPcJzbpJ/fmqtgIaXseDK4=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
@@ -12,14 +12,14 @@ cloud.google.com/go/firestore v1.21.0 h1:BhopUsx7kh6NFx77ccRsHhrtkbJUmDAxNY3uapW
|
||||
cloud.google.com/go/firestore v1.21.0/go.mod h1:1xH6HNcnkf/gGyR8udd6pFO4Z7GWJSwLKQMx/u6UrP4=
|
||||
cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc=
|
||||
cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU=
|
||||
cloud.google.com/go/logging v1.13.1 h1:O7LvmO0kGLaHY/gq8cV7T0dyp6zJhYAOtZPX4TF3QtY=
|
||||
cloud.google.com/go/logging v1.13.1/go.mod h1:XAQkfkMBxQRjQek96WLPNze7vsOmay9H5PqfsNYDqvw=
|
||||
cloud.google.com/go/logging v1.13.2 h1:qqlHCBvieJT9Cdq4QqYx1KPadCQ2noD4FK02eNqHAjA=
|
||||
cloud.google.com/go/logging v1.13.2/go.mod h1:zaybliM3yun1J8mU2dVQ1/qDzjbOqEijZCn6hSBtKak=
|
||||
cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8=
|
||||
cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk=
|
||||
cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE=
|
||||
cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI=
|
||||
cloud.google.com/go/storage v1.59.2 h1:gmOAuG1opU8YvycMNpP+DvHfT9BfzzK5Cy+arP+Nocw=
|
||||
cloud.google.com/go/storage v1.59.2/go.mod h1:cMWbtM+anpC74gn6qjLh+exqYcfmB9Hqe5z6adx+CLI=
|
||||
cloud.google.com/go/storage v1.61.3 h1:VS//ZfBuPGDvakfD9xyPW1RGF1Vy3BWUoVZXgW1KMOg=
|
||||
cloud.google.com/go/storage v1.61.3/go.mod h1:JtqK8BBB7TWv0HVGHubtUdzYYrakOQIsMLffZ2Z/HWk=
|
||||
cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U=
|
||||
cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s=
|
||||
firebase.google.com/go/v4 v4.19.0 h1:f5NMlC2YHFsncz00c2+ecBr+ZYlRMhKIhj1z8Iz0lD8=
|
||||
@@ -58,14 +58,14 @@ github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6 h1:oP4q0fw+fOSWn3
|
||||
github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
|
||||
github.com/emersion/go-smtp v0.17.0 h1:tq90evlrcyqRfE6DSXaWVH54oX6OuZOQECEmhWBMEtI=
|
||||
github.com/emersion/go-smtp v0.17.0/go.mod h1:qm27SGYgoIPRot6ubfQ/GpiPy/g3PaZAVRxiO/sDUgQ=
|
||||
github.com/envoyproxy/go-control-plane v0.13.5-0.20251024222203-75eaa193e329 h1:K+fnvUM0VZ7ZFJf0n4L/BRlnsb9pL/GuDG6FqaH+PwM=
|
||||
github.com/envoyproxy/go-control-plane v0.13.5-0.20251024222203-75eaa193e329/go.mod h1:Alz8LEClvR7xKsrq3qzoc4N0guvVNSS8KmSChGYr9hs=
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g=
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98=
|
||||
github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA=
|
||||
github.com/envoyproxy/go-control-plane v0.14.0/go.mod h1:NcS5X47pLl/hfqxU70yPwL9ZMkUlwlKxtAohpi2wBEU=
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.37.0 h1:u3riX6BoYRfF4Dr7dwSOroNfdSbEPe9Yyl09/B6wBrQ=
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.37.0/go.mod h1:DReE9MMrmecPy+YvQOAOHNYMALuowAnbjjEMkkWOi6A=
|
||||
github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 h1:/G9QYbddjL25KvtKTv3an9lx6VBE2cnb8wp1vEGNYGI=
|
||||
github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.3 h1:MVQghNeW+LZcmXe7SY1V36Z+WFMDjpqGAGacLe2T0ds=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.3/go.mod h1:TsndJ/ngyIdQRhMcVVGDDHINPLWB7C82oDArY51KfB0=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM=
|
||||
@@ -96,14 +96,22 @@ github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dqh/NZMz7475yUvmRIkXr4oN2ao=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8=
|
||||
github.com/googleapis/gax-go/v2 v2.17.0 h1:RksgfBpxqff0EZkDWYuz9q/uWsTVz+kf43LsZ1J6SMc=
|
||||
github.com/googleapis/gax-go/v2 v2.17.0/go.mod h1:mzaqghpQp4JDh3HvADwrat+6M3MOIDp5YKHhb9PAgDY=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.18.0 h1:jxP5Uuo3bxm3M6gGtV94P4lliVetoCB4Wk2x8QA86LI=
|
||||
github.com/googleapis/gax-go/v2 v2.18.0/go.mod h1:uSzZN4a356eRG985CzJ3WfbFSpqkLTjsnhWGJR6EwrE=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
@@ -112,8 +120,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
|
||||
github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
|
||||
github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
@@ -133,8 +141,8 @@ github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNw
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
|
||||
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
|
||||
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
|
||||
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
|
||||
github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEycfc=
|
||||
github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
|
||||
@@ -144,6 +152,7 @@ github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xI
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
@@ -156,36 +165,36 @@ github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBi
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.40.0 h1:Awaf8gmW99tZTOWqkLCOl6aw1/rxAWVlHsHIZ3fT2sA=
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.40.0/go.mod h1:99OY9ZCqyLkzJLTh5XhECpLRSxcZl+ZDKBEO+jMBFR4=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0 h1:XmiuHzgJt067+a6kwyAzkhXooYVv3/TOw9cM2VfJgUM=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0/go.mod h1:KDgtbWKTQs4bM+VPUr6WlL9m/WXcmkCcBlIzqxPGzmI=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 h1:7iP2uCb7sGddAr30RRS6xjKy7AZ2JtTOPA3oolgVSw8=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0/go.mod h1:c7hN3ddxs/z6q9xwvfLPk+UHlWRQyaeR1LdgfL/66l0=
|
||||
go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms=
|
||||
go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.38.0 h1:wm/Q0GAAykXv83wzcKzGGqAnnfLFyFe7RslekZuv+VI=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.38.0/go.mod h1:ra3Pa40+oKjvYh+ZD3EdxFZZB0xdMfuileHAm4nNN7w=
|
||||
go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g=
|
||||
go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc=
|
||||
go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8=
|
||||
go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg=
|
||||
go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw=
|
||||
go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA=
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.42.0 h1:kpt2PEJuOuqYkPcktfJqWWDjTEd/FNgrxcniL7kQrXQ=
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.42.0/go.mod h1:W9zQ439utxymRrXsUOzZbFX4JhLxXU4+ZnCt8GG7yA8=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
|
||||
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
|
||||
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0 h1:ZrPRak/kS4xI3AVXy8F7pipuDXmDsrO8Lg+yQjBLjw0=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0/go.mod h1:3y6kQCWztq6hyW8Z9YxQDDm0Je9AJoFar2G0yDcmhRk=
|
||||
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
|
||||
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
|
||||
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
|
||||
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
||||
go.yaml.in/yaml/v2 v2.4.4 h1:tuyd0P+2Ont/d6e2rl3be67goVK4R6deVxCUX5vyPaQ=
|
||||
go.yaml.in/yaml/v2 v2.4.4/go.mod h1:gMZqIpDtDqOfM0uNfy0SkpRhvUryYH0Z6wdMYcacYXQ=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
||||
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
@@ -200,10 +209,10 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
||||
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
|
||||
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -211,8 +220,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -225,8 +234,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
@@ -236,8 +245,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
|
||||
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
@@ -249,10 +258,10 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
@@ -263,18 +272,18 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/api v0.265.0 h1:FZvfUdI8nfmuNrE34aOWFPmLC+qRBEiNm3JdivTvAAU=
|
||||
google.golang.org/api v0.265.0/go.mod h1:uAvfEl3SLUj/7n6k+lJutcswVojHPp2Sp08jWCu8hLY=
|
||||
google.golang.org/api v0.271.0 h1:cIPN4qcUc61jlh7oXu6pwOQqbJW2GqYh5PS6rB2C/JY=
|
||||
google.golang.org/api v0.271.0/go.mod h1:CGT29bhwkbF+i11qkRUJb2KMKqcJ1hdFceEIRd9u64Q=
|
||||
google.golang.org/appengine/v2 v2.0.6 h1:LvPZLGuchSBslPBp+LAhihBeGSiRh1myRoYK4NtuBIw=
|
||||
google.golang.org/appengine/v2 v2.0.6/go.mod h1:WoEXGoXNfa0mLvaH5sV3ZSGXwVmy8yf7Z1JKf3J3wLI=
|
||||
google.golang.org/genproto v0.0.0-20260203192932-546029d2fa20 h1:/CU1zrxTpGylJJbe3Ru94yy6sZRbzALq2/oxl3pGB3U=
|
||||
google.golang.org/genproto v0.0.0-20260203192932-546029d2fa20/go.mod h1:Tt+08/KdKEt3l8x3Pby3HLQxMB3uk/MzaQ4ZIv0ORTs=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260203192932-546029d2fa20 h1:7ei4lp52gK1uSejlA8AZl5AJjeLUOHBQscRQZUgAcu0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260203192932-546029d2fa20/go.mod h1:ZdbssH/1SOVnjnDlXzxDHK2MCidiqXtbYccJNzNYPEE=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20 h1:Jr5R2J6F6qWyzINc+4AM8t5pfUz6beZpHp678GNrMbE=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
|
||||
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
|
||||
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
|
||||
google.golang.org/genproto v0.0.0-20260311181403-84a4fc48630c h1:ZhFDeBMmFc/4g8/GwxnJ4rzB3O4GwQVNr+8Mh7Y5z4g=
|
||||
google.golang.org/genproto v0.0.0-20260311181403-84a4fc48630c/go.mod h1:hf4r/rBuzaTkLUWRO03771Xvcs6P5hwdQK3UUEJjqo0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260311181403-84a4fc48630c h1:OyQPd6I3pN/9gDxz6L13kYGJgqkpdrAohJRBeXyxlgI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260311181403-84a4fc48630c/go.mod h1:X2gu9Qwng7Nn009s/r3RUxqkzQNqOrAy79bluY7ojIg=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260311181403-84a4fc48630c h1:xgCzyF2LFIO/0X2UAoVRiXKU5Xg6VjToG4i2/ecSswk=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260311181403-84a4fc48630c/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.79.2 h1:fRMD94s2tITpyJGtBBn7MkMseNpOZU8ZxgC3MMBaXRU=
|
||||
google.golang.org/grpc v1.79.2/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
|
||||
592
message/cache.go
Normal file
592
message/cache.go
Normal file
@@ -0,0 +1,592 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
tagMessageCache = "message_cache"
|
||||
)
|
||||
|
||||
var errNoRows = errors.New("no rows found")
|
||||
|
||||
// queries holds the database-specific SQL queries
|
||||
type queries struct {
|
||||
insertMessage string
|
||||
deleteMessage string
|
||||
selectScheduledMessageIDsBySeqID string
|
||||
deleteScheduledBySequenceID string
|
||||
updateMessagesForTopicExpiry string
|
||||
selectMessagesByID string
|
||||
selectMessagesSinceTime string
|
||||
selectMessagesSinceTimeScheduled string
|
||||
selectMessagesSinceID string
|
||||
selectMessagesSinceIDScheduled string
|
||||
selectMessagesLatest string
|
||||
selectMessagesDue string
|
||||
selectMessagesExpired string
|
||||
updateMessagePublished string
|
||||
selectMessagesCount string
|
||||
selectTopics string
|
||||
updateAttachmentDeleted string
|
||||
selectAttachmentsExpired string
|
||||
selectAttachmentsSizeBySender string
|
||||
selectAttachmentsSizeByUserID string
|
||||
selectStats string
|
||||
updateStats string
|
||||
updateMessageTime string
|
||||
}
|
||||
|
||||
// Cache stores published messages
|
||||
type Cache struct {
|
||||
db *db.DB
|
||||
queue *util.BatchingQueue[*model.Message]
|
||||
nop bool
|
||||
mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer)
|
||||
queries queries
|
||||
}
|
||||
|
||||
func newCache(db *db.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache {
|
||||
var queue *util.BatchingQueue[*model.Message]
|
||||
if batchSize > 0 || batchTimeout > 0 {
|
||||
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
|
||||
}
|
||||
c := &Cache{
|
||||
db: db,
|
||||
queue: queue,
|
||||
nop: nop,
|
||||
mu: mu,
|
||||
queries: queries,
|
||||
}
|
||||
go c.processMessageBatches()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Cache) maybeLock() {
|
||||
if c.mu != nil {
|
||||
c.mu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) maybeUnlock() {
|
||||
if c.mu != nil {
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
|
||||
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
|
||||
func (c *Cache) AddMessage(m *model.Message) error {
|
||||
if c.queue != nil {
|
||||
c.queue.Enqueue(m)
|
||||
return nil
|
||||
}
|
||||
return c.addMessages([]*model.Message{m})
|
||||
}
|
||||
|
||||
// AddMessages synchronously stores a batch of messages to the message cache
|
||||
func (c *Cache) AddMessages(ms []*model.Message) error {
|
||||
return c.addMessages(ms)
|
||||
}
|
||||
|
||||
func (c *Cache) addMessages(ms []*model.Message) error {
|
||||
c.maybeLock()
|
||||
defer c.maybeUnlock()
|
||||
if c.nop {
|
||||
return nil
|
||||
}
|
||||
if len(ms) == 0 {
|
||||
return nil
|
||||
}
|
||||
start := time.Now()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
stmt, err := tx.Prepare(c.queries.insertMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
for _, m := range ms {
|
||||
if m.Event != model.MessageEvent && m.Event != model.MessageDeleteEvent && m.Event != model.MessageClearEvent {
|
||||
return model.ErrUnexpectedMessageType
|
||||
}
|
||||
published := m.Time <= time.Now().Unix()
|
||||
tags := util.SanitizeUTF8(strings.Join(m.Tags, ","))
|
||||
var attachmentName, attachmentType, attachmentURL string
|
||||
var attachmentSize, attachmentExpires int64
|
||||
var attachmentDeleted bool
|
||||
if m.Attachment != nil {
|
||||
attachmentName = util.SanitizeUTF8(m.Attachment.Name)
|
||||
attachmentType = util.SanitizeUTF8(m.Attachment.Type)
|
||||
attachmentSize = m.Attachment.Size
|
||||
attachmentExpires = m.Attachment.Expires
|
||||
attachmentURL = util.SanitizeUTF8(m.Attachment.URL)
|
||||
}
|
||||
var actionsStr string
|
||||
if len(m.Actions) > 0 {
|
||||
actionsBytes, err := json.Marshal(m.Actions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
actionsStr = string(actionsBytes)
|
||||
}
|
||||
var sender string
|
||||
if m.Sender.IsValid() {
|
||||
sender = m.Sender.String()
|
||||
}
|
||||
_, err := stmt.Exec(
|
||||
m.ID,
|
||||
m.SequenceID,
|
||||
m.Time,
|
||||
m.Event,
|
||||
m.Expires,
|
||||
util.SanitizeUTF8(m.Topic),
|
||||
util.SanitizeUTF8(m.Message),
|
||||
util.SanitizeUTF8(m.Title),
|
||||
m.Priority,
|
||||
tags,
|
||||
util.SanitizeUTF8(m.Click),
|
||||
util.SanitizeUTF8(m.Icon),
|
||||
actionsStr,
|
||||
attachmentName,
|
||||
attachmentType,
|
||||
attachmentSize,
|
||||
attachmentExpires,
|
||||
attachmentURL,
|
||||
attachmentDeleted, // Always zero
|
||||
sender,
|
||||
m.User,
|
||||
util.SanitizeUTF8(m.ContentType),
|
||||
m.Encoding,
|
||||
published,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
log.Tag(tagMessageCache).Err(err).Error("Writing %d message(s) failed (took %v)", len(ms), time.Since(start))
|
||||
return err
|
||||
}
|
||||
log.Tag(tagMessageCache).Debug("Wrote %d message(s) in %v", len(ms), time.Since(start))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Messages returns messages for a topic since the given marker, optionally including scheduled messages
|
||||
func (c *Cache) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||
if since.IsNone() {
|
||||
return make([]*model.Message, 0), nil
|
||||
} else if since.IsLatest() {
|
||||
return c.messagesLatest(topic)
|
||||
} else if since.IsID() {
|
||||
return c.messagesSinceID(topic, since, scheduled)
|
||||
}
|
||||
return c.messagesSinceTime(topic, since, scheduled)
|
||||
}
|
||||
|
||||
func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
rdb := c.db.ReadOnly()
|
||||
if scheduled {
|
||||
rows, err = rdb.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
|
||||
} else {
|
||||
rows, err = rdb.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
rdb := c.db.ReadOnly()
|
||||
if scheduled {
|
||||
rows, err = rdb.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID())
|
||||
} else {
|
||||
rows, err = rdb.Query(c.queries.selectMessagesSinceID, topic, since.ID())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
func (c *Cache) messagesLatest(topic string) ([]*model.Message, error) {
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesLatest, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
// MessagesDue returns all messages that are due for publishing
|
||||
func (c *Cache) MessagesDue() ([]*model.Message, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
|
||||
func (c *Cache) MessagesExpired() ([]string, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// Message returns the message with the given ID, or ErrMessageNotFound if not found
|
||||
func (c *Cache) Message(id string) (*model.Message, error) {
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !rows.Next() {
|
||||
return nil, model.ErrMessageNotFound
|
||||
}
|
||||
defer rows.Close()
|
||||
return readMessage(rows)
|
||||
}
|
||||
|
||||
// UpdateMessageTime updates the time column for a message by ID. This is only used for testing.
|
||||
func (c *Cache) UpdateMessageTime(messageID string, timestamp int64) error {
|
||||
c.maybeLock()
|
||||
defer c.maybeUnlock()
|
||||
_, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarkPublished marks a message as published
|
||||
func (c *Cache) MarkPublished(m *model.Message) error {
|
||||
c.maybeLock()
|
||||
defer c.maybeUnlock()
|
||||
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
// MessagesCount returns the total number of messages in the cache
|
||||
func (c *Cache) MessagesCount() (int, error) {
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesCount)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
var count int
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Topics returns a list of all topics with messages in the cache
|
||||
func (c *Cache) Topics() ([]string, error) {
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectTopics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
topics := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topics = append(topics, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return topics, nil
|
||||
}
|
||||
|
||||
// DeleteMessages deletes the messages with the given IDs
|
||||
func (c *Cache) DeleteMessages(ids ...string) error {
|
||||
c.maybeLock()
|
||||
defer c.maybeUnlock()
|
||||
return db.ExecTx(c.db, func(tx *sql.Tx) error {
|
||||
for _, id := range ids {
|
||||
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
|
||||
// It returns the message IDs of the deleted messages, which can be used to clean up attachment files.
|
||||
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
|
||||
c.maybeLock()
|
||||
defer c.maybeUnlock()
|
||||
return db.QueryTx(c.db, func(tx *sql.Tx) ([]string, error) {
|
||||
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows.Close() // Close rows before executing delete in same transaction
|
||||
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ExpireMessages marks messages in the given topics as expired
|
||||
func (c *Cache) ExpireMessages(topics ...string) error {
|
||||
c.maybeLock()
|
||||
defer c.maybeUnlock()
|
||||
return db.ExecTx(c.db, func(tx *sql.Tx) error {
|
||||
for _, t := range topics {
|
||||
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
|
||||
func (c *Cache) AttachmentsExpired() ([]string, error) {
|
||||
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// MarkAttachmentsDeleted marks the attachments for the given message IDs as deleted
|
||||
func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
|
||||
c.maybeLock()
|
||||
defer c.maybeUnlock()
|
||||
return db.ExecTx(c.db, func(tx *sql.Tx) error {
|
||||
for _, id := range ids {
|
||||
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
|
||||
func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.readAttachmentBytesUsed(rows)
|
||||
}
|
||||
|
||||
// AttachmentBytesUsedByUser returns the total size of active attachments for the given user
|
||||
func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.readAttachmentBytesUsed(rows)
|
||||
}
|
||||
|
||||
func (c *Cache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
|
||||
defer rows.Close()
|
||||
var size int64
|
||||
if !rows.Next() {
|
||||
return 0, errors.New("no rows found")
|
||||
}
|
||||
if err := rows.Scan(&size); err != nil {
|
||||
return 0, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return size, nil
|
||||
}
|
||||
|
||||
// UpdateStats updates the total message count statistic
|
||||
func (c *Cache) UpdateStats(messages int64) error {
|
||||
c.maybeLock()
|
||||
defer c.maybeUnlock()
|
||||
_, err := c.db.Exec(c.queries.updateStats, messages)
|
||||
return err
|
||||
}
|
||||
|
||||
// Stats returns the total message count statistic
|
||||
func (c *Cache) Stats() (messages int64, err error) {
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectStats)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
if err := rows.Scan(&messages); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection
|
||||
func (c *Cache) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func (c *Cache) processMessageBatches() {
|
||||
if c.queue == nil {
|
||||
return
|
||||
}
|
||||
for messages := range c.queue.Dequeue() {
|
||||
if err := c.addMessages(messages); err != nil {
|
||||
log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readMessages(rows *sql.Rows) ([]*model.Message, error) {
|
||||
defer rows.Close()
|
||||
messages := make([]*model.Message, 0)
|
||||
for rows.Next() {
|
||||
m, err := readMessage(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func readMessage(rows *sql.Rows) (*model.Message, error) {
|
||||
var timestamp, expires, attachmentSize, attachmentExpires int64
|
||||
var priority int
|
||||
var id, sequenceID, event, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, contentType, encoding string
|
||||
err := rows.Scan(
|
||||
&id,
|
||||
&sequenceID,
|
||||
×tamp,
|
||||
&event,
|
||||
&expires,
|
||||
&topic,
|
||||
&msg,
|
||||
&title,
|
||||
&priority,
|
||||
&tagsStr,
|
||||
&click,
|
||||
&icon,
|
||||
&actionsStr,
|
||||
&attachmentName,
|
||||
&attachmentType,
|
||||
&attachmentSize,
|
||||
&attachmentExpires,
|
||||
&attachmentURL,
|
||||
&sender,
|
||||
&user,
|
||||
&contentType,
|
||||
&encoding,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var tags []string
|
||||
if tagsStr != "" {
|
||||
tags = strings.Split(tagsStr, ",")
|
||||
}
|
||||
var actions []*model.Action
|
||||
if actionsStr != "" {
|
||||
if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
senderIP, err := netip.ParseAddr(sender)
|
||||
if err != nil {
|
||||
senderIP = netip.Addr{} // if no IP stored in database, return invalid address
|
||||
}
|
||||
var att *model.Attachment
|
||||
if attachmentName != "" && attachmentURL != "" {
|
||||
att = &model.Attachment{
|
||||
Name: attachmentName,
|
||||
Type: attachmentType,
|
||||
Size: attachmentSize,
|
||||
Expires: attachmentExpires,
|
||||
URL: attachmentURL,
|
||||
}
|
||||
}
|
||||
return &model.Message{
|
||||
ID: id,
|
||||
SequenceID: sequenceID,
|
||||
Time: timestamp,
|
||||
Expires: expires,
|
||||
Event: event,
|
||||
Topic: topic,
|
||||
Message: msg,
|
||||
Title: title,
|
||||
Priority: priority,
|
||||
Tags: tags,
|
||||
Click: click,
|
||||
Icon: icon,
|
||||
Actions: actions,
|
||||
Attachment: att,
|
||||
Sender: senderIP,
|
||||
User: user,
|
||||
ContentType: contentType,
|
||||
Encoding: encoding,
|
||||
}, nil
|
||||
}
|
||||
111
message/cache_postgres.go
Normal file
111
message/cache_postgres.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
)
|
||||
|
||||
// PostgreSQL runtime query constants
|
||||
const (
|
||||
postgresInsertMessageQuery = `
|
||||
INSERT INTO message (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user_id, content_type, encoding, published)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)
|
||||
`
|
||||
postgresDeleteMessageQuery = `DELETE FROM message WHERE mid = $1`
|
||||
postgresSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
|
||||
postgresDeleteScheduledBySequenceIDQuery = `DELETE FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
|
||||
postgresUpdateMessagesForTopicExpiryQuery = `UPDATE message SET expires = $1 WHERE topic = $2`
|
||||
postgresSelectMessagesByIDQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE mid = $1
|
||||
`
|
||||
postgresSelectMessagesSinceTimeQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1 AND time >= $2 AND published = TRUE
|
||||
ORDER BY time, id
|
||||
`
|
||||
postgresSelectMessagesSinceTimeIncludeScheduledQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1 AND time >= $2
|
||||
ORDER BY time, id
|
||||
`
|
||||
postgresSelectMessagesSinceIDQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1
|
||||
AND id > COALESCE((SELECT id FROM message WHERE mid = $2), 0)
|
||||
AND published = TRUE
|
||||
ORDER BY time, id
|
||||
`
|
||||
postgresSelectMessagesSinceIDIncludeScheduledQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1
|
||||
AND (id > COALESCE((SELECT id FROM message WHERE mid = $2), 0) OR published = FALSE)
|
||||
ORDER BY time, id
|
||||
`
|
||||
postgresSelectMessagesLatestQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1 AND published = TRUE
|
||||
ORDER BY time DESC, id DESC
|
||||
LIMIT 1
|
||||
`
|
||||
postgresSelectMessagesDueQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE time <= $1 AND published = FALSE
|
||||
ORDER BY time, id
|
||||
`
|
||||
postgresSelectMessagesExpiredQuery = `SELECT mid FROM message WHERE expires <= $1 AND published = TRUE`
|
||||
postgresUpdateMessagePublishedQuery = `UPDATE message SET published = TRUE WHERE mid = $1`
|
||||
postgresSelectMessagesCountQuery = `SELECT COUNT(*) FROM message`
|
||||
postgresSelectTopicsQuery = `SELECT topic FROM message GROUP BY topic`
|
||||
|
||||
postgresUpdateAttachmentDeletedQuery = `UPDATE message SET attachment_deleted = TRUE WHERE mid = $1`
|
||||
postgresSelectAttachmentsExpiredQuery = `SELECT mid FROM message WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`
|
||||
postgresSelectAttachmentsSizeBySenderQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = '' AND sender = $1 AND attachment_expires >= $2`
|
||||
postgresSelectAttachmentsSizeByUserIDQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = $1 AND attachment_expires >= $2`
|
||||
|
||||
postgresSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'`
|
||||
postgresUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'`
|
||||
postgresUpdateMessageTimeQuery = `UPDATE message SET time = $1 WHERE mid = $2`
|
||||
)
|
||||
|
||||
var postgresQueries = queries{
|
||||
insertMessage: postgresInsertMessageQuery,
|
||||
deleteMessage: postgresDeleteMessageQuery,
|
||||
selectScheduledMessageIDsBySeqID: postgresSelectScheduledMessageIDsBySeqIDQuery,
|
||||
deleteScheduledBySequenceID: postgresDeleteScheduledBySequenceIDQuery,
|
||||
updateMessagesForTopicExpiry: postgresUpdateMessagesForTopicExpiryQuery,
|
||||
selectMessagesByID: postgresSelectMessagesByIDQuery,
|
||||
selectMessagesSinceTime: postgresSelectMessagesSinceTimeQuery,
|
||||
selectMessagesSinceTimeScheduled: postgresSelectMessagesSinceTimeIncludeScheduledQuery,
|
||||
selectMessagesSinceID: postgresSelectMessagesSinceIDQuery,
|
||||
selectMessagesSinceIDScheduled: postgresSelectMessagesSinceIDIncludeScheduledQuery,
|
||||
selectMessagesLatest: postgresSelectMessagesLatestQuery,
|
||||
selectMessagesDue: postgresSelectMessagesDueQuery,
|
||||
selectMessagesExpired: postgresSelectMessagesExpiredQuery,
|
||||
updateMessagePublished: postgresUpdateMessagePublishedQuery,
|
||||
selectMessagesCount: postgresSelectMessagesCountQuery,
|
||||
selectTopics: postgresSelectTopicsQuery,
|
||||
updateAttachmentDeleted: postgresUpdateAttachmentDeletedQuery,
|
||||
selectAttachmentsExpired: postgresSelectAttachmentsExpiredQuery,
|
||||
selectAttachmentsSizeBySender: postgresSelectAttachmentsSizeBySenderQuery,
|
||||
selectAttachmentsSizeByUserID: postgresSelectAttachmentsSizeByUserIDQuery,
|
||||
selectStats: postgresSelectStatsQuery,
|
||||
updateStats: postgresUpdateStatsQuery,
|
||||
updateMessageTime: postgresUpdateMessageTimeQuery,
|
||||
}
|
||||
|
||||
// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.
|
||||
func NewPostgresStore(d *db.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) {
|
||||
if err := setupPostgres(d.Primary()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newCache(d, postgresQueries, nil, batchSize, batchTimeout, false), nil
|
||||
}
|
||||
85
message/cache_postgres_schema.go
Normal file
85
message/cache_postgres_schema.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
)
|
||||
|
||||
// Initial PostgreSQL schema
|
||||
const (
|
||||
postgresCreateTablesQuery = `
|
||||
CREATE TABLE IF NOT EXISTS message (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
mid TEXT NOT NULL,
|
||||
sequence_id TEXT NOT NULL,
|
||||
time BIGINT NOT NULL,
|
||||
event TEXT NOT NULL,
|
||||
expires BIGINT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size BIGINT NOT NULL,
|
||||
attachment_expires BIGINT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
sender TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_mid ON message (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_sequence_id ON message (sequence_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_topic_published_time ON message (topic, published, time, id);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_published_expires ON message (published, expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_sender_attachment_expires ON message (sender, attachment_expires) WHERE user_id = '';
|
||||
CREATE INDEX IF NOT EXISTS idx_message_user_id_attachment_expires ON message (user_id, attachment_expires);
|
||||
CREATE TABLE IF NOT EXISTS message_stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value BIGINT
|
||||
);
|
||||
INSERT INTO message_stats (key, value) VALUES ('messages', 0);
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
`
|
||||
)
|
||||
|
||||
// PostgreSQL schema management queries
|
||||
const (
|
||||
postgresCurrentSchemaVersion = 14
|
||||
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
|
||||
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
|
||||
)
|
||||
|
||||
func setupPostgres(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
if err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
|
||||
return setupNewPostgresDB(db)
|
||||
} else if schemaVersion > postgresCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPostgresDB(sqlDB *sql.DB) error {
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, postgresCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
143
message/cache_sqlite.go
Normal file
143
message/cache_sqlite.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// SQLite runtime query constants
|
||||
const (
|
||||
sqliteInsertMessageQuery = `
|
||||
INSERT INTO messages (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
sqliteDeleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
|
||||
sqliteSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
|
||||
sqliteDeleteScheduledBySequenceIDQuery = `DELETE FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
|
||||
sqliteUpdateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
|
||||
sqliteSelectMessagesByIDQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE mid = ?
|
||||
`
|
||||
sqliteSelectMessagesSinceTimeQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND time >= ? AND published = 1
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesSinceTimeIncludeScheduledQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND time >= ?
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesSinceIDQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND id > COALESCE((SELECT id FROM messages WHERE mid = ?), 0) AND published = 1
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesSinceIDIncludeScheduledQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND (id > COALESCE((SELECT id FROM messages WHERE mid = ?), 0) OR published = 0)
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesLatestQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND published = 1
|
||||
ORDER BY time DESC, id DESC
|
||||
LIMIT 1
|
||||
`
|
||||
sqliteSelectMessagesDueQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE time <= ? AND published = 0
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
|
||||
sqliteUpdateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
|
||||
sqliteSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
|
||||
sqliteSelectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
|
||||
|
||||
sqliteUpdateAttachmentDeletedQuery = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
|
||||
sqliteSelectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
|
||||
sqliteSelectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
|
||||
sqliteSelectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
|
||||
|
||||
sqliteSelectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
|
||||
sqliteUpdateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
|
||||
sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?`
|
||||
)
|
||||
|
||||
var sqliteQueries = queries{
|
||||
insertMessage: sqliteInsertMessageQuery,
|
||||
deleteMessage: sqliteDeleteMessageQuery,
|
||||
selectScheduledMessageIDsBySeqID: sqliteSelectScheduledMessageIDsBySeqIDQuery,
|
||||
deleteScheduledBySequenceID: sqliteDeleteScheduledBySequenceIDQuery,
|
||||
updateMessagesForTopicExpiry: sqliteUpdateMessagesForTopicExpiryQuery,
|
||||
selectMessagesByID: sqliteSelectMessagesByIDQuery,
|
||||
selectMessagesSinceTime: sqliteSelectMessagesSinceTimeQuery,
|
||||
selectMessagesSinceTimeScheduled: sqliteSelectMessagesSinceTimeIncludeScheduledQuery,
|
||||
selectMessagesSinceID: sqliteSelectMessagesSinceIDQuery,
|
||||
selectMessagesSinceIDScheduled: sqliteSelectMessagesSinceIDIncludeScheduledQuery,
|
||||
selectMessagesLatest: sqliteSelectMessagesLatestQuery,
|
||||
selectMessagesDue: sqliteSelectMessagesDueQuery,
|
||||
selectMessagesExpired: sqliteSelectMessagesExpiredQuery,
|
||||
updateMessagePublished: sqliteUpdateMessagePublishedQuery,
|
||||
selectMessagesCount: sqliteSelectMessagesCountQuery,
|
||||
selectTopics: sqliteSelectTopicsQuery,
|
||||
updateAttachmentDeleted: sqliteUpdateAttachmentDeletedQuery,
|
||||
selectAttachmentsExpired: sqliteSelectAttachmentsExpiredQuery,
|
||||
selectAttachmentsSizeBySender: sqliteSelectAttachmentsSizeBySenderQuery,
|
||||
selectAttachmentsSizeByUserID: sqliteSelectAttachmentsSizeByUserIDQuery,
|
||||
selectStats: sqliteSelectStatsQuery,
|
||||
updateStats: sqliteUpdateStatsQuery,
|
||||
updateMessageTime: sqliteUpdateMessageTimeQuery,
|
||||
}
|
||||
|
||||
// NewSQLiteStore creates a SQLite file-backed cache
|
||||
func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*Cache, error) {
|
||||
parentDir := filepath.Dir(filename)
|
||||
if !util.FileExists(parentDir) {
|
||||
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
|
||||
}
|
||||
d, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLite(d, startupQueries, cacheDuration); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newCache(db.New(&db.Host{DB: d}, nil), sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil
|
||||
}
|
||||
|
||||
// NewMemStore creates an in-memory cache
|
||||
func NewMemStore() (*Cache, error) {
|
||||
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, false)
|
||||
}
|
||||
|
||||
// NewNopStore creates an in-memory cache that discards all messages;
|
||||
// it is always empty and can be used if caching is entirely disabled
|
||||
func NewNopStore() (*Cache, error) {
|
||||
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, true)
|
||||
}
|
||||
|
||||
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
|
||||
// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory
|
||||
// sql database, so if the stdlib's sql engine happens to open another connection and
|
||||
// you've only specified ":memory:", that connection will see a brand new database.
|
||||
// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared").
|
||||
// Every connection to this string will point to the same in-memory database."
|
||||
func createMemoryFilename() string {
|
||||
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
|
||||
}
|
||||
453
message/cache_sqlite_schema.go
Normal file
453
message/cache_sqlite_schema.go
Normal file
@@ -0,0 +1,453 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
)
|
||||
|
||||
// Initial SQLite schema
|
||||
const (
|
||||
sqliteCreateTablesQuery = `
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
sequence_id TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
event TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_deleted INT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
user TEXT NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
|
||||
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||
CREATE TABLE IF NOT EXISTS stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value INT
|
||||
);
|
||||
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||
`
|
||||
)
|
||||
|
||||
// Schema version management for SQLite
|
||||
const (
|
||||
sqliteCurrentSchemaVersion = 14
|
||||
sqliteCreateSchemaVersionTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
`
|
||||
sqliteInsertSchemaVersionQuery = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
sqliteUpdateSchemaVersionQuery = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
// Schema migrations for SQLite
|
||||
const (
|
||||
// 0 -> 1
|
||||
sqliteMigrate0To1AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0);
|
||||
ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 1 -> 2
|
||||
sqliteMigrate1To2AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1);
|
||||
`
|
||||
|
||||
// 2 -> 3
|
||||
sqliteMigrate2To3AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
// 3 -> 4
|
||||
sqliteMigrate3To4AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 4 -> 5
|
||||
sqliteMigrate4To5AlterMessagesTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS messages_new (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_owner TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic);
|
||||
INSERT
|
||||
INTO messages_new (
|
||||
mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
|
||||
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published)
|
||||
SELECT
|
||||
id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
|
||||
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published
|
||||
FROM messages;
|
||||
DROP TABLE messages;
|
||||
ALTER TABLE messages_new RENAME TO messages;
|
||||
`
|
||||
|
||||
// 5 -> 6
|
||||
sqliteMigrate5To6AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 6 -> 7
|
||||
sqliteMigrate6To7AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages RENAME COLUMN attachment_owner TO sender;
|
||||
`
|
||||
|
||||
// 7 -> 8
|
||||
sqliteMigrate7To8AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 8 -> 9
|
||||
sqliteMigrate8To9AlterMessagesTableQuery = `
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
`
|
||||
|
||||
// 9 -> 10
|
||||
sqliteMigrate9To10AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
|
||||
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
|
||||
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||
`
|
||||
sqliteMigrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
|
||||
|
||||
// 10 -> 11
|
||||
sqliteMigrate10To11AlterMessagesTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value INT
|
||||
);
|
||||
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||
`
|
||||
|
||||
// 11 -> 12
|
||||
sqliteMigrate11To12AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 12 -> 13
|
||||
sqliteMigrate12To13AlterMessagesTableQuery = `
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
`
|
||||
|
||||
// 13 -> 14
|
||||
sqliteMigrate13To14AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN sequence_id TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN event TEXT NOT NULL DEFAULT('message');
|
||||
CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id);
|
||||
`
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteMigrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
|
||||
0: sqliteMigrateFrom0,
|
||||
1: sqliteMigrateFrom1,
|
||||
2: sqliteMigrateFrom2,
|
||||
3: sqliteMigrateFrom3,
|
||||
4: sqliteMigrateFrom4,
|
||||
5: sqliteMigrateFrom5,
|
||||
6: sqliteMigrateFrom6,
|
||||
7: sqliteMigrateFrom7,
|
||||
8: sqliteMigrateFrom8,
|
||||
9: sqliteMigrateFrom9,
|
||||
10: sqliteMigrateFrom10,
|
||||
11: sqliteMigrateFrom11,
|
||||
12: sqliteMigrateFrom12,
|
||||
13: sqliteMigrateFrom13,
|
||||
}
|
||||
)
|
||||
|
||||
func setupSQLite(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
|
||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
// If 'messages' table does not exist, this must be a new database
|
||||
var messagesCount int
|
||||
if err := db.QueryRow(sqliteSelectMessagesCountQuery).Scan(&messagesCount); err != nil {
|
||||
return setupNewSQLite(db)
|
||||
}
|
||||
// If 'messages' table exists (schema >= 0), check 'schemaVersion' table
|
||||
var schemaVersion int
|
||||
db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion) // Error means schema version is zero!
|
||||
// Do migrations
|
||||
if schemaVersion == sqliteCurrentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > sqliteCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
|
||||
}
|
||||
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
|
||||
fn, ok := sqliteMigrations[i]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||
} else if err := fn(db, cacheDuration); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSQLite(sqlDB *sql.DB) error {
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom0(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate0To1AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom1(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate1To2AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom2(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate2To3AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom3(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate3To4AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom4(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate4To5AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 5); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom5(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate5To6AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 6); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom6(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate6To7AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 7); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom7(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate7To8AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 8); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom8(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate8To9AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 9); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom9(sqlDB *sql.DB, cacheDuration time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom10(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom11(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom12(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom13(sqlDB *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
292
message/cache_sqlite_test.go
Normal file
292
message/cache_sqlite_test.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package message_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
func TestSqliteStore_Migration_From0(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 0" schema
|
||||
_, err = db.Exec(`
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id VARCHAR(20) PRIMARY KEY,
|
||||
time INT NOT NULL,
|
||||
topic VARCHAR(64) NOT NULL,
|
||||
message VARCHAR(1024) NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
COMMIT;
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`,
|
||||
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Nil(t, db.Close())
|
||||
|
||||
// Create store to trigger migration
|
||||
s := newSqliteTestStoreFromFile(t, filename, "")
|
||||
checkSqliteSchemaVersion(t, filename)
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
require.Equal(t, "some message 5", messages[5].Message)
|
||||
require.Equal(t, "", messages[5].Title)
|
||||
require.Nil(t, messages[5].Tags)
|
||||
require.Equal(t, 0, messages[5].Priority)
|
||||
}
|
||||
|
||||
func TestSqliteStore_Migration_From1(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 1" schema
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id VARCHAR(20) PRIMARY KEY,
|
||||
time INT NOT NULL,
|
||||
topic VARCHAR(64) NOT NULL,
|
||||
message VARCHAR(512) NOT NULL,
|
||||
title VARCHAR(256) NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags VARCHAR(256) NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO schemaVersion (id, version) VALUES (1, 1);
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "")
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Nil(t, db.Close())
|
||||
|
||||
// Create store to trigger migration
|
||||
s := newSqliteTestStoreFromFile(t, filename, "")
|
||||
checkSqliteSchemaVersion(t, filename)
|
||||
|
||||
// Add delayed message
|
||||
delayedMessage := model.NewDefaultMessage("mytopic", "some delayed message")
|
||||
delayedMessage.Time = time.Now().Add(time.Minute).Unix()
|
||||
require.Nil(t, s.AddMessage(delayedMessage))
|
||||
|
||||
// 10, not 11!
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
|
||||
// 11!
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 11, len(messages))
|
||||
|
||||
// Check that index "idx_topic" exists
|
||||
verifyDB, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
defer verifyDB.Close()
|
||||
rows, err := verifyDB.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var indexName string
|
||||
require.Nil(t, rows.Scan(&indexName))
|
||||
require.Equal(t, "idx_topic", indexName)
|
||||
require.Nil(t, rows.Close())
|
||||
}
|
||||
|
||||
func TestSqliteStore_Migration_From9(t *testing.T) {
|
||||
// This primarily tests the awkward migration that introduces the "expires" column.
|
||||
// The migration logic has to update the column, using the existing "cache-duration" value.
|
||||
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 9" schema
|
||||
_, err = db.Exec(`
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO schemaVersion (id, version) VALUES (1, 9);
|
||||
COMMIT;
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
insertQuery := `
|
||||
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(
|
||||
insertQuery,
|
||||
fmt.Sprintf("abcd%d", i),
|
||||
time.Now().Unix(),
|
||||
"mytopic",
|
||||
fmt.Sprintf("some message %d", i),
|
||||
"", // title
|
||||
0, // priority
|
||||
"", // tags
|
||||
"", // click
|
||||
"", // icon
|
||||
"", // actions
|
||||
"", // attachment_name
|
||||
"", // attachment_type
|
||||
0, // attachment_size
|
||||
0, // attachment_expires
|
||||
"", // attachment_url
|
||||
"9.9.9.9", // sender
|
||||
"", // encoding
|
||||
1, // published
|
||||
)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Nil(t, db.Close())
|
||||
|
||||
// Create store to trigger migration
|
||||
cacheDuration := 17 * time.Hour
|
||||
s, err := message.NewSQLiteStore(filename, "", cacheDuration, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
checkSqliteSchemaVersion(t, filename)
|
||||
|
||||
// Check version
|
||||
verifyDB, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
defer verifyDB.Close()
|
||||
rows, err := verifyDB.Query(`SELECT version FROM schemaVersion WHERE id = 1`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var version int
|
||||
require.Nil(t, rows.Scan(&version))
|
||||
require.Equal(t, 14, version)
|
||||
require.Nil(t, rows.Close())
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
for _, m := range messages {
|
||||
require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix())
|
||||
require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqliteStore_StartupQueries_WAL(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
startupQueries := `pragma journal_mode = WAL;
|
||||
pragma synchronous = normal;
|
||||
pragma temp_store = memory;`
|
||||
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message")))
|
||||
require.FileExists(t, filename)
|
||||
require.FileExists(t, filename+"-wal")
|
||||
require.FileExists(t, filename+"-shm")
|
||||
}
|
||||
|
||||
func TestSqliteStore_StartupQueries_None(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message")))
|
||||
require.FileExists(t, filename)
|
||||
require.NoFileExists(t, filename+"-wal")
|
||||
require.NoFileExists(t, filename+"-shm")
|
||||
}
|
||||
|
||||
func TestSqliteStore_StartupQueries_Fail(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
_, err := message.NewSQLiteStore(filename, `xx error`, time.Hour, 0, 0, false)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNopStore(t *testing.T) {
|
||||
s, err := message.NewNopStore()
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "my message")))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, messages)
|
||||
|
||||
topics, err := s.Topics()
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, topics)
|
||||
}
|
||||
|
||||
func newSqliteTestStoreFile(t *testing.T) string {
|
||||
return filepath.Join(t.TempDir(), "cache.db")
|
||||
}
|
||||
|
||||
func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) *message.Cache {
|
||||
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func checkSqliteSchemaVersion(t *testing.T, filename string) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
defer db.Close()
|
||||
rows, err := db.Query(`SELECT version FROM schemaVersion`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var schemaVersion int
|
||||
require.Nil(t, rows.Scan(&schemaVersion))
|
||||
require.Equal(t, 14, schemaVersion)
|
||||
require.Nil(t, rows.Close())
|
||||
}
|
||||
967
message/cache_test.go
Normal file
967
message/cache_test.go
Normal file
@@ -0,0 +1,967 @@
|
||||
package message_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
dbtest "heckel.io/ntfy/v2/db/test"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
func newSqliteTestStore(t *testing.T) *message.Cache {
|
||||
filename := filepath.Join(t.TempDir(), "cache.db")
|
||||
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func newMemTestStore(t *testing.T) *message.Cache {
|
||||
s, err := message.NewMemStore()
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func newTestPostgresStore(t *testing.T) *message.Cache {
|
||||
testDB := dbtest.CreateTestPostgres(t)
|
||||
store, err := message.NewPostgresStore(testDB, 0, 0)
|
||||
require.Nil(t, err)
|
||||
return store
|
||||
}
|
||||
|
||||
func forEachBackend(t *testing.T, f func(t *testing.T, s *message.Cache)) {
|
||||
t.Run("sqlite", func(t *testing.T) {
|
||||
f(t, newSqliteTestStore(t))
|
||||
})
|
||||
t.Run("mem", func(t *testing.T) {
|
||||
f(t, newMemTestStore(t))
|
||||
})
|
||||
t.Run("postgres", func(t *testing.T) {
|
||||
f(t, newTestPostgresStore(t))
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_Messages(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
m1 := model.NewDefaultMessage("mytopic", "my message")
|
||||
m1.Time = 1
|
||||
|
||||
m2 := model.NewDefaultMessage("mytopic", "my other message")
|
||||
m2.Time = 2
|
||||
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("example", "my example message")))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
|
||||
// Adding invalid
|
||||
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewKeepaliveMessage("mytopic"))) // These should not be added!
|
||||
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewOpenMessage("example"))) // These should not be added!
|
||||
|
||||
// count
|
||||
count, err := s.MessagesCount()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, count)
|
||||
|
||||
// mytopic: since all
|
||||
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, "my message", messages[0].Message)
|
||||
require.Equal(t, "mytopic", messages[0].Topic)
|
||||
require.Equal(t, model.MessageEvent, messages[0].Event)
|
||||
require.Equal(t, "", messages[0].Title)
|
||||
require.Equal(t, 0, messages[0].Priority)
|
||||
require.Nil(t, messages[0].Tags)
|
||||
require.Equal(t, "my other message", messages[1].Message)
|
||||
|
||||
// mytopic: since none
|
||||
messages, _ = s.Messages("mytopic", model.SinceNoMessages, false)
|
||||
require.Empty(t, messages)
|
||||
|
||||
// mytopic: since m1 (by ID)
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m1.ID), false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, m2.ID, messages[0].ID)
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
require.Equal(t, "mytopic", messages[0].Topic)
|
||||
|
||||
// mytopic: since 2
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceTime(2), false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
|
||||
// mytopic: latest
|
||||
messages, _ = s.Messages("mytopic", model.SinceLatestMessage, false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
|
||||
// example: since all
|
||||
messages, _ = s.Messages("example", model.SinceAllMessages, false)
|
||||
require.Equal(t, "my example message", messages[0].Message)
|
||||
|
||||
// non-existing: since all
|
||||
messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false)
|
||||
require.Empty(t, messages)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_MessagesLock(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5000; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
assert.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "test message")))
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_MessagesScheduled(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
||||
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
||||
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||
m3 := model.NewDefaultMessage("mytopic", "message 3")
|
||||
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
|
||||
m4 := model.NewDefaultMessage("mytopic2", "message 4")
|
||||
m4.Time = time.Now().Add(time.Minute).Unix()
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
|
||||
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) // exclude scheduled
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
|
||||
messages, _ = s.Messages("mytopic", model.SinceAllMessages, true) // include scheduled
|
||||
require.Equal(t, 3, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
require.Equal(t, "message 3", messages[1].Message) // Order!
|
||||
require.Equal(t, "message 2", messages[2].Message)
|
||||
|
||||
messages, _ = s.MessagesDue()
|
||||
require.Empty(t, messages)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_Topics(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message")))
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1")))
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2")))
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 3")))
|
||||
|
||||
topics, err := s.Topics()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.Equal(t, 2, len(topics))
|
||||
require.Contains(t, topics, "topic1")
|
||||
require.Contains(t, topics, "topic2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
m := model.NewDefaultMessage("mytopic", "some message")
|
||||
m.Tags = []string{"tag1", "tag2"}
|
||||
m.Priority = 5
|
||||
m.Title = "some title"
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
|
||||
require.Equal(t, 5, messages[0].Priority)
|
||||
require.Equal(t, "some title", messages[0].Title)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_MessagesSinceID(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
||||
m1.Time = 100
|
||||
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
||||
m2.Time = 200
|
||||
m3 := model.NewDefaultMessage("mytopic", "message 3")
|
||||
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
|
||||
m4 := model.NewDefaultMessage("mytopic", "message 4")
|
||||
m4.Time = 400
|
||||
m5 := model.NewDefaultMessage("mytopic", "message 5")
|
||||
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
|
||||
m6 := model.NewDefaultMessage("mytopic", "message 6")
|
||||
m6.Time = 600
|
||||
m7 := model.NewDefaultMessage("mytopic", "message 7")
|
||||
m7.Time = 700
|
||||
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
require.Nil(t, s.AddMessage(m4))
|
||||
require.Nil(t, s.AddMessage(m5))
|
||||
require.Nil(t, s.AddMessage(m6))
|
||||
require.Nil(t, s.AddMessage(m7))
|
||||
|
||||
// Case 1: Since ID exists, exclude scheduled
|
||||
messages, _ := s.Messages("mytopic", model.NewSinceID(m2.ID), false)
|
||||
require.Equal(t, 3, len(messages))
|
||||
require.Equal(t, "message 4", messages[0].Message)
|
||||
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
|
||||
require.Equal(t, "message 7", messages[2].Message)
|
||||
|
||||
// Case 2: Since ID exists, include scheduled
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m2.ID), true)
|
||||
require.Equal(t, 5, len(messages))
|
||||
require.Equal(t, "message 4", messages[0].Message)
|
||||
require.Equal(t, "message 6", messages[1].Message)
|
||||
require.Equal(t, "message 7", messages[2].Message)
|
||||
require.Equal(t, "message 5", messages[3].Message) // Order!
|
||||
require.Equal(t, "message 3", messages[4].Message) // Order!
|
||||
|
||||
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID("doesntexist"), true)
|
||||
require.Equal(t, 7, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
require.Equal(t, "message 2", messages[1].Message)
|
||||
require.Equal(t, "message 4", messages[2].Message)
|
||||
require.Equal(t, "message 6", messages[3].Message)
|
||||
require.Equal(t, "message 7", messages[4].Message)
|
||||
require.Equal(t, "message 5", messages[5].Message) // Order!
|
||||
require.Equal(t, "message 3", messages[6].Message) // Order!
|
||||
|
||||
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), false)
|
||||
require.Equal(t, 0, len(messages))
|
||||
|
||||
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), true)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, "message 5", messages[0].Message)
|
||||
require.Equal(t, "message 3", messages[1].Message)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_Prune(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
m1 := model.NewDefaultMessage("mytopic", "my message")
|
||||
m1.Time = now - 10
|
||||
m1.Expires = now - 5
|
||||
|
||||
m2 := model.NewDefaultMessage("mytopic", "my other message")
|
||||
m2.Time = now - 5
|
||||
m2.Expires = now + 5 // In the future
|
||||
|
||||
m3 := model.NewDefaultMessage("another_topic", "and another one")
|
||||
m3.Time = now - 12
|
||||
m3.Expires = now - 2
|
||||
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
|
||||
count, err := s.MessagesCount()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, count)
|
||||
|
||||
expiredMessageIDs, err := s.MessagesExpired()
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.DeleteMessages(expiredMessageIDs...))
|
||||
|
||||
count, err = s.MessagesCount()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_Attachments(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
|
||||
m := model.NewDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
m.SequenceID = "m1"
|
||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "flower.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 5000,
|
||||
Expires: expires1,
|
||||
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
|
||||
m = model.NewDefaultMessage("mytopic", "sending you a car")
|
||||
m.ID = "m2"
|
||||
m.SequenceID = "m2"
|
||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 10000,
|
||||
Expires: expires2,
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
|
||||
m = model.NewDefaultMessage("another-topic", "sending you another car")
|
||||
m.ID = "m3"
|
||||
m.SequenceID = "m3"
|
||||
m.User = "u_BAsbaAa"
|
||||
m.Sender = netip.MustParseAddr("5.6.7.8")
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "another-car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 20000,
|
||||
Expires: expires3,
|
||||
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
|
||||
require.Equal(t, "flower for you", messages[0].Message)
|
||||
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
|
||||
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
|
||||
require.Equal(t, int64(5000), messages[0].Attachment.Size)
|
||||
require.Equal(t, expires1, messages[0].Attachment.Expires)
|
||||
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
|
||||
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
|
||||
|
||||
require.Equal(t, "sending you a car", messages[1].Message)
|
||||
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
|
||||
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
|
||||
require.Equal(t, int64(10000), messages[1].Attachment.Size)
|
||||
require.Equal(t, expires2, messages[1].Attachment.Expires)
|
||||
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
|
||||
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
|
||||
|
||||
size, err := s.AttachmentBytesUsedBySender("1.2.3.4")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(10000), size)
|
||||
|
||||
size, err = s.AttachmentBytesUsedBySender("5.6.7.8")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
|
||||
|
||||
size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(20000), size)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_AttachmentsExpired(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
m := model.NewDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
m.SequenceID = "m1"
|
||||
m.Expires = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
m = model.NewDefaultMessage("mytopic", "message with attachment")
|
||||
m.ID = "m2"
|
||||
m.SequenceID = "m2"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 10000,
|
||||
Expires: time.Now().Add(2 * time.Hour).Unix(),
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
m = model.NewDefaultMessage("mytopic", "message with external attachment")
|
||||
m.ID = "m3"
|
||||
m.SequenceID = "m3"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Expires: 0, // Unknown!
|
||||
URL: "https://somedomain.com/car.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
m = model.NewDefaultMessage("mytopic2", "message with expired attachment")
|
||||
m.ID = "m4"
|
||||
m.SequenceID = "m4"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "expired-car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 20000,
|
||||
Expires: time.Now().Add(-1 * time.Hour).Unix(),
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
ids, err := s.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(ids))
|
||||
require.Equal(t, "m4", ids[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_Sender(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
m1 := model.NewDefaultMessage("mytopic", "mymessage")
|
||||
m1.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
|
||||
m2 := model.NewDefaultMessage("mytopic", "mymessage without sender")
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
|
||||
require.Equal(t, messages[1].Sender, netip.Addr{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_DeleteScheduledBySequenceID(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// Create a scheduled (unpublished) message
|
||||
scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message")
|
||||
scheduledMsg.ID = "scheduled1"
|
||||
scheduledMsg.SequenceID = "seq123"
|
||||
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
|
||||
require.Nil(t, s.AddMessage(scheduledMsg))
|
||||
|
||||
// Create a published message with different sequence ID
|
||||
publishedMsg := model.NewDefaultMessage("mytopic", "published message")
|
||||
publishedMsg.ID = "published1"
|
||||
publishedMsg.SequenceID = "seq456"
|
||||
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
|
||||
require.Nil(t, s.AddMessage(publishedMsg))
|
||||
|
||||
// Create a scheduled message in a different topic
|
||||
otherTopicMsg := model.NewDefaultMessage("othertopic", "other scheduled")
|
||||
otherTopicMsg.ID = "other1"
|
||||
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
|
||||
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(otherTopicMsg))
|
||||
|
||||
// Verify all messages exist (including scheduled)
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
|
||||
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
|
||||
// Delete scheduled message by sequence ID and verify returned IDs
|
||||
deletedIDs, err := s.DeleteScheduledBySequenceID("mytopic", "seq123")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(deletedIDs))
|
||||
require.Equal(t, "scheduled1", deletedIDs[0])
|
||||
|
||||
// Verify scheduled message is deleted
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "published message", messages[0].Message)
|
||||
|
||||
// Verify other topic's message still exists (topic-scoped deletion)
|
||||
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "other scheduled", messages[0].Message)
|
||||
|
||||
// Deleting non-existent sequence ID should return empty list
|
||||
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "nonexistent")
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, deletedIDs)
|
||||
|
||||
// Deleting published message should not affect it (only deletes unpublished)
|
||||
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "seq456")
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, deletedIDs)
|
||||
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "published message", messages[0].Message)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_MessageByID(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// Add a message
|
||||
m := model.NewDefaultMessage("mytopic", "some message")
|
||||
m.Title = "some title"
|
||||
m.Priority = 4
|
||||
m.Tags = []string{"tag1", "tag2"}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
// Retrieve by ID
|
||||
retrieved, err := s.Message(m.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, m.ID, retrieved.ID)
|
||||
require.Equal(t, "mytopic", retrieved.Topic)
|
||||
require.Equal(t, "some message", retrieved.Message)
|
||||
require.Equal(t, "some title", retrieved.Title)
|
||||
require.Equal(t, 4, retrieved.Priority)
|
||||
require.Equal(t, []string{"tag1", "tag2"}, retrieved.Tags)
|
||||
|
||||
// Non-existent ID returns ErrMessageNotFound
|
||||
_, err = s.Message("doesnotexist")
|
||||
require.Equal(t, model.ErrMessageNotFound, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_MarkPublished(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// Add a scheduled message (future time -> unpublished)
|
||||
m := model.NewDefaultMessage("mytopic", "scheduled message")
|
||||
m.Time = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
// Verify it does not appear in non-scheduled queries
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(messages))
|
||||
|
||||
// Verify it does appear in scheduled queries
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
|
||||
// Mark as published
|
||||
require.Nil(t, s.MarkPublished(m))
|
||||
|
||||
// Now it should appear in non-scheduled queries too
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "scheduled message", messages[0].Message)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_ExpireMessages(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// Add messages to two topics
|
||||
m1 := model.NewDefaultMessage("topic1", "message 1")
|
||||
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||
m2 := model.NewDefaultMessage("topic1", "message 2")
|
||||
m2.Expires = time.Now().Add(time.Hour).Unix()
|
||||
m3 := model.NewDefaultMessage("topic2", "message 3")
|
||||
m3.Expires = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
|
||||
// Verify all messages exist
|
||||
messages, err := s.Messages("topic1", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
|
||||
// Expire topic1 messages
|
||||
require.Nil(t, s.ExpireMessages("topic1"))
|
||||
|
||||
// topic1 messages should now be expired (expires set to past)
|
||||
expiredIDs, err := s.MessagesExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(expiredIDs))
|
||||
sort.Strings(expiredIDs)
|
||||
expectedIDs := []string{m1.ID, m2.ID}
|
||||
sort.Strings(expectedIDs)
|
||||
require.Equal(t, expectedIDs, expiredIDs)
|
||||
|
||||
// topic2 should be unaffected
|
||||
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "message 3", messages[0].Message)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_MarkAttachmentsDeleted(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// Add a message with an expired attachment (file needs cleanup)
|
||||
m1 := model.NewDefaultMessage("mytopic", "old file")
|
||||
m1.ID = "msg1"
|
||||
m1.SequenceID = "msg1"
|
||||
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||
m1.Attachment = &model.Attachment{
|
||||
Name: "old.pdf",
|
||||
Type: "application/pdf",
|
||||
Size: 50000,
|
||||
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
|
||||
URL: "https://ntfy.sh/file/old.pdf",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
|
||||
// Add a message with another expired attachment
|
||||
m2 := model.NewDefaultMessage("mytopic", "another old file")
|
||||
m2.ID = "msg2"
|
||||
m2.SequenceID = "msg2"
|
||||
m2.Expires = time.Now().Add(time.Hour).Unix()
|
||||
m2.Attachment = &model.Attachment{
|
||||
Name: "another.pdf",
|
||||
Type: "application/pdf",
|
||||
Size: 30000,
|
||||
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
|
||||
URL: "https://ntfy.sh/file/another.pdf",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
|
||||
// Both should show as expired attachments needing cleanup
|
||||
ids, err := s.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(ids))
|
||||
|
||||
// Mark msg1's attachment as deleted (file cleaned up)
|
||||
require.Nil(t, s.MarkAttachmentsDeleted("msg1"))
|
||||
|
||||
// Now only msg2 should show as needing cleanup
|
||||
ids, err = s.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(ids))
|
||||
require.Equal(t, "msg2", ids[0])
|
||||
|
||||
// Mark msg2 too
|
||||
require.Nil(t, s.MarkAttachmentsDeleted("msg2"))
|
||||
|
||||
// No more expired attachments to clean up
|
||||
ids, err = s.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(ids))
|
||||
|
||||
// Messages themselves still exist
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_Stats(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// Initial stats should be zero
|
||||
messages, err := s.Stats()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), messages)
|
||||
|
||||
// Update stats
|
||||
require.Nil(t, s.UpdateStats(42))
|
||||
messages, err = s.Stats()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(42), messages)
|
||||
|
||||
// Update again (overwrites)
|
||||
require.Nil(t, s.UpdateStats(100))
|
||||
messages, err = s.Stats()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(100), messages)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_AddMessages(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// Batch add multiple messages
|
||||
msgs := []*model.Message{
|
||||
model.NewDefaultMessage("mytopic", "batch 1"),
|
||||
model.NewDefaultMessage("mytopic", "batch 2"),
|
||||
model.NewDefaultMessage("othertopic", "batch 3"),
|
||||
}
|
||||
require.Nil(t, s.AddMessages(msgs))
|
||||
|
||||
// Verify all were inserted
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
|
||||
messages, err = s.Messages("othertopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "batch 3", messages[0].Message)
|
||||
|
||||
// Empty batch should succeed
|
||||
require.Nil(t, s.AddMessages([]*model.Message{}))
|
||||
|
||||
// Batch with invalid event type should fail
|
||||
badMsgs := []*model.Message{
|
||||
model.NewKeepaliveMessage("mytopic"),
|
||||
}
|
||||
require.NotNil(t, s.AddMessages(badMsgs))
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_MessagesDue(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// Add a message scheduled in the past (i.e. it's due now)
|
||||
m1 := model.NewDefaultMessage("mytopic", "due message")
|
||||
m1.Time = time.Now().Add(-time.Second).Unix()
|
||||
// Set expires in the future so it doesn't get pruned
|
||||
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
|
||||
// Add a message scheduled in the future (not due)
|
||||
m2 := model.NewDefaultMessage("mytopic", "future message")
|
||||
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
|
||||
// Mark m1 as published so it won't be "due"
|
||||
// (MessagesDue returns unpublished messages whose time <= now)
|
||||
// m1 is auto-published (time <= now), so it should not be due
|
||||
// m2 is unpublished (time in future), not due yet
|
||||
due, err := s.MessagesDue()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(due))
|
||||
|
||||
// Add a message that was explicitly scheduled in the past but time has "arrived"
|
||||
// We need to manipulate the database to create a truly "due" message:
|
||||
// a message with published=false and time <= now
|
||||
m3 := model.NewDefaultMessage("mytopic", "truly due message")
|
||||
m3.Time = time.Now().Add(2 * time.Second).Unix() // 2 seconds from now
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
|
||||
// Not due yet
|
||||
due, err = s.MessagesDue()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(due))
|
||||
|
||||
// Wait for it to become due
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
due, err = s.MessagesDue()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(due))
|
||||
require.Equal(t, "truly due message", due[0].Message)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_MessageFieldRoundTrip(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// Create a message with all fields populated
|
||||
m := model.NewDefaultMessage("mytopic", "hello world")
|
||||
m.SequenceID = "custom_seq_id"
|
||||
m.Title = "A Title"
|
||||
m.Priority = 4
|
||||
m.Tags = []string{"warning", "srv01"}
|
||||
m.Click = "https://example.com/click"
|
||||
m.Icon = "https://example.com/icon.png"
|
||||
m.Actions = []*model.Action{
|
||||
{
|
||||
ID: "action1",
|
||||
Action: "view",
|
||||
Label: "Open Site",
|
||||
URL: "https://example.com",
|
||||
Clear: true,
|
||||
},
|
||||
{
|
||||
ID: "action2",
|
||||
Action: "http",
|
||||
Label: "Call Webhook",
|
||||
URL: "https://example.com/hook",
|
||||
Method: "PUT",
|
||||
Headers: map[string]string{"X-Token": "secret"},
|
||||
Body: `{"key":"value"}`,
|
||||
},
|
||||
}
|
||||
m.ContentType = "text/markdown"
|
||||
m.Encoding = "base64"
|
||||
m.Sender = netip.MustParseAddr("9.8.7.6")
|
||||
m.User = "u_TestUser123"
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
// Retrieve and verify every field
|
||||
retrieved, err := s.Message(m.ID)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, m.ID, retrieved.ID)
|
||||
require.Equal(t, "custom_seq_id", retrieved.SequenceID)
|
||||
require.Equal(t, m.Time, retrieved.Time)
|
||||
require.Equal(t, m.Expires, retrieved.Expires)
|
||||
require.Equal(t, model.MessageEvent, retrieved.Event)
|
||||
require.Equal(t, "mytopic", retrieved.Topic)
|
||||
require.Equal(t, "hello world", retrieved.Message)
|
||||
require.Equal(t, "A Title", retrieved.Title)
|
||||
require.Equal(t, 4, retrieved.Priority)
|
||||
require.Equal(t, []string{"warning", "srv01"}, retrieved.Tags)
|
||||
require.Equal(t, "https://example.com/click", retrieved.Click)
|
||||
require.Equal(t, "https://example.com/icon.png", retrieved.Icon)
|
||||
require.Equal(t, "text/markdown", retrieved.ContentType)
|
||||
require.Equal(t, "base64", retrieved.Encoding)
|
||||
require.Equal(t, netip.MustParseAddr("9.8.7.6"), retrieved.Sender)
|
||||
require.Equal(t, "u_TestUser123", retrieved.User)
|
||||
|
||||
// Verify actions round-trip
|
||||
require.Equal(t, 2, len(retrieved.Actions))
|
||||
|
||||
require.Equal(t, "action1", retrieved.Actions[0].ID)
|
||||
require.Equal(t, "view", retrieved.Actions[0].Action)
|
||||
require.Equal(t, "Open Site", retrieved.Actions[0].Label)
|
||||
require.Equal(t, "https://example.com", retrieved.Actions[0].URL)
|
||||
require.Equal(t, true, retrieved.Actions[0].Clear)
|
||||
|
||||
require.Equal(t, "action2", retrieved.Actions[1].ID)
|
||||
require.Equal(t, "http", retrieved.Actions[1].Action)
|
||||
require.Equal(t, "Call Webhook", retrieved.Actions[1].Label)
|
||||
require.Equal(t, "https://example.com/hook", retrieved.Actions[1].URL)
|
||||
require.Equal(t, "PUT", retrieved.Actions[1].Method)
|
||||
require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"])
|
||||
require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_AddMessage_InvalidUTF8(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// 0xc9 0x43: Latin-1 "ÉC" — 0xc9 starts a 2-byte UTF-8 sequence but 0x43 ('C') is not a continuation byte
|
||||
m := model.NewDefaultMessage("mytopic", "\xc9Cas du serveur")
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "\uFFFDCas du serveur", messages[0].Message)
|
||||
|
||||
// 0xae: Latin-1 "®" — isolated byte above 0x7F, not a valid UTF-8 start for single byte
|
||||
m2 := model.NewDefaultMessage("mytopic", "Product\xae Pro")
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "Product\uFFFD Pro", messages[1].Message)
|
||||
|
||||
// 0xe8 0x6d 0x65: Latin-1 "ème" — 0xe8 starts a 3-byte UTF-8 sequence but 0x6d ('m') is not a continuation byte
|
||||
m3 := model.NewDefaultMessage("mytopic", "probl\xe8me critique")
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "probl\uFFFDme critique", messages[2].Message)
|
||||
|
||||
// 0xb2: Latin-1 "²" — isolated byte in 0x80-0xBF range (UTF-8 continuation byte without lead)
|
||||
m4 := model.NewDefaultMessage("mytopic", "CO\xb2 level high")
|
||||
require.Nil(t, s.AddMessage(m4))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "CO\uFFFD level high", messages[3].Message)
|
||||
|
||||
// 0xe9 0x6d 0x61: Latin-1 "éma" — 0xe9 starts a 3-byte UTF-8 sequence but 0x6d ('m') is not a continuation byte
|
||||
m5 := model.NewDefaultMessage("mytopic", "th\xe9matique")
|
||||
require.Nil(t, s.AddMessage(m5))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "th\uFFFDmatique", messages[4].Message)
|
||||
|
||||
// 0xed 0x64 0x65: Latin-1 "íde" — 0xed starts a 3-byte UTF-8 sequence but 0x64 ('d') is not a continuation byte
|
||||
m6 := model.NewDefaultMessage("mytopic", "vid\xed\x64eo surveillance")
|
||||
require.Nil(t, s.AddMessage(m6))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "vid\uFFFDdeo surveillance", messages[5].Message)
|
||||
|
||||
// 0xf3 0x6e 0x3a 0x20: Latin-1 "ón: " — 0xf3 starts a 4-byte UTF-8 sequence but 0x6e ('n') is not a continuation byte
|
||||
m7 := model.NewDefaultMessage("mytopic", "notificaci\xf3n: alerta")
|
||||
require.Nil(t, s.AddMessage(m7))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "notificaci\uFFFDn: alerta", messages[6].Message)
|
||||
|
||||
// 0xb7: Latin-1 "·" — isolated continuation byte
|
||||
m8 := model.NewDefaultMessage("mytopic", "item\xb7value")
|
||||
require.Nil(t, s.AddMessage(m8))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "item\uFFFDvalue", messages[7].Message)
|
||||
|
||||
// 0xa8: Latin-1 "¨" — isolated continuation byte
|
||||
m9 := model.NewDefaultMessage("mytopic", "na\xa8ve")
|
||||
require.Nil(t, s.AddMessage(m9))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "na\uFFFDve", messages[8].Message)
|
||||
|
||||
// 0xdf 0x64: Latin-1 "ßd" — 0xdf starts a 2-byte UTF-8 sequence but 0x64 ('d') is not a continuation byte
|
||||
m10 := model.NewDefaultMessage("mytopic", "gro\xdf\x64ruck")
|
||||
require.Nil(t, s.AddMessage(m10))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "gro\uFFFDdruck", messages[9].Message)
|
||||
|
||||
// 0xe4 0x67 0x74: Latin-1 "ägt" — 0xe4 starts a 3-byte UTF-8 sequence but 0x67 ('g') is not a continuation byte
|
||||
m11 := model.NewDefaultMessage("mytopic", "tr\xe4gt Last")
|
||||
require.Nil(t, s.AddMessage(m11))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tr\uFFFDgt Last", messages[10].Message)
|
||||
|
||||
// 0xe9 0x65 0x20: Latin-1 "ée " — 0xe9 starts a 3-byte UTF-8 sequence but 0x65 ('e') is not a continuation byte
|
||||
m12 := model.NewDefaultMessage("mytopic", "journ\xe9\x65 termin\xe9\x65")
|
||||
require.Nil(t, s.AddMessage(m12))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "journ\uFFFDe termin\uFFFDe", messages[11].Message)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_AddMessage_NullByte(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// 0x00: NUL byte — valid UTF-8 but rejected by PostgreSQL
|
||||
m := model.NewDefaultMessage("mytopic", "hello\x00world")
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "helloworld", messages[0].Message)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_AddMessage_InvalidUTF8InTitleAndTags(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// Invalid UTF-8 can arrive via HTTP headers (Title, Tags) which bypass body validation
|
||||
m := model.NewDefaultMessage("mytopic", "valid message")
|
||||
m.Title = "\xc9clipse du syst\xe8me"
|
||||
m.Tags = []string{"probl\xe8me", "syst\xe9me"}
|
||||
m.Click = "https://example.com/\xae"
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "\uFFFDclipse du syst\uFFFDme", messages[0].Title)
|
||||
require.Equal(t, "probl\uFFFDme", messages[0].Tags[0])
|
||||
require.Equal(t, "syst\uFFFDme", messages[0].Tags[1])
|
||||
require.Equal(t, "https://example.com/\uFFFD", messages[0].Click)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_AddMessage_InvalidUTF8BatchDoesNotDropValidMessages(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// Previously, a single invalid message would roll back the entire batch transaction.
|
||||
// Sanitization ensures all messages in a batch are written successfully.
|
||||
msgs := []*model.Message{
|
||||
model.NewDefaultMessage("mytopic", "valid message 1"),
|
||||
model.NewDefaultMessage("mytopic", "notificaci\xf3n: alerta"),
|
||||
model.NewDefaultMessage("mytopic", "valid message 3"),
|
||||
}
|
||||
require.Nil(t, s.AddMessages(msgs))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(messages))
|
||||
})
|
||||
}
|
||||
@@ -42,8 +42,11 @@ extra:
|
||||
link: https://github.com/binwiederhier
|
||||
extra_javascript:
|
||||
- static/js/extra.js
|
||||
- static/js/bcrypt.js
|
||||
- static/js/config-generator.js
|
||||
extra_css:
|
||||
- static/css/extra.css
|
||||
- static/css/config-generator.css
|
||||
|
||||
markdown_extensions:
|
||||
- admonition
|
||||
|
||||
232
model/model.go
Normal file
232
model/model.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// List of possible events
|
||||
const (
|
||||
OpenEvent = "open"
|
||||
KeepaliveEvent = "keepalive"
|
||||
MessageEvent = "message"
|
||||
MessageDeleteEvent = "message_delete"
|
||||
MessageClearEvent = "message_clear"
|
||||
PollRequestEvent = "poll_request"
|
||||
)
|
||||
|
||||
// MessageIDLength is the length of a randomly generated message ID
|
||||
const MessageIDLength = 12
|
||||
|
||||
// Errors for message operations
|
||||
var (
|
||||
ErrUnexpectedMessageType = errors.New("unexpected message type")
|
||||
ErrMessageNotFound = errors.New("message not found")
|
||||
)
|
||||
|
||||
// Message represents a message published to a topic
|
||||
type Message struct {
|
||||
ID string `json:"id"` // Random message ID
|
||||
SequenceID string `json:"sequence_id,omitempty"` // Message sequence ID for updating message contents (omitted if same as ID)
|
||||
Time int64 `json:"time"` // Unix time in seconds
|
||||
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
|
||||
Event string `json:"event"` // One of the above
|
||||
Topic string `json:"topic"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Priority int `json:"priority,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Click string `json:"click,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
Actions []*Action `json:"actions,omitempty"`
|
||||
Attachment *Attachment `json:"attachment,omitempty"`
|
||||
PollID string `json:"poll_id,omitempty"`
|
||||
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
|
||||
Encoding string `json:"encoding,omitempty"` // Empty for raw UTF-8, or "base64" for encoded bytes
|
||||
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
|
||||
User string `json:"-"` // UserID of the uploader, used to associated attachments
|
||||
}
|
||||
|
||||
// Context returns a log context for the message
|
||||
func (m *Message) Context() log.Context {
|
||||
fields := map[string]any{
|
||||
"topic": m.Topic,
|
||||
"message_id": m.ID,
|
||||
"message_sequence_id": m.SequenceID,
|
||||
"message_time": m.Time,
|
||||
"message_event": m.Event,
|
||||
"message_body_size": len(m.Message),
|
||||
}
|
||||
if m.Sender.IsValid() {
|
||||
fields["message_sender"] = m.Sender.String()
|
||||
}
|
||||
if m.User != "" {
|
||||
fields["message_user"] = m.User
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// SanitizeUTF8 replaces invalid UTF-8 sequences and strips NUL bytes from all user-supplied
|
||||
// string fields. This is called early in the publish path so that all downstream consumers
|
||||
// (Firebase, WebPush, SMTP, cache) receive clean UTF-8 strings.
|
||||
func (m *Message) SanitizeUTF8() {
|
||||
m.Topic = util.SanitizeUTF8(m.Topic)
|
||||
m.Message = util.SanitizeUTF8(m.Message)
|
||||
m.Title = util.SanitizeUTF8(m.Title)
|
||||
m.Click = util.SanitizeUTF8(m.Click)
|
||||
m.Icon = util.SanitizeUTF8(m.Icon)
|
||||
m.ContentType = util.SanitizeUTF8(m.ContentType)
|
||||
for i, tag := range m.Tags {
|
||||
m.Tags[i] = util.SanitizeUTF8(tag)
|
||||
}
|
||||
if m.Attachment != nil {
|
||||
m.Attachment.Name = util.SanitizeUTF8(m.Attachment.Name)
|
||||
m.Attachment.Type = util.SanitizeUTF8(m.Attachment.Type)
|
||||
m.Attachment.URL = util.SanitizeUTF8(m.Attachment.URL)
|
||||
}
|
||||
}
|
||||
|
||||
// ForJSON returns a copy of the message suitable for JSON output.
|
||||
// It clears the SequenceID if it equals the ID to reduce redundancy.
|
||||
func (m *Message) ForJSON() *Message {
|
||||
if m.SequenceID == m.ID {
|
||||
clone := *m
|
||||
clone.SequenceID = ""
|
||||
return &clone
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Attachment represents a file attachment on a message
|
||||
type Attachment struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Size int64 `json:"size,omitempty"`
|
||||
Expires int64 `json:"expires,omitempty"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// Action represents a user-defined action on a message
|
||||
type Action struct {
|
||||
ID string `json:"id"`
|
||||
Action string `json:"action"` // "view", "broadcast", "http", or "copy"
|
||||
Label string `json:"label"` // action button label
|
||||
Clear bool `json:"clear"` // clear notification after successful execution
|
||||
URL string `json:"url,omitempty"` // used in "view" and "http" actions
|
||||
Method string `json:"method,omitempty"` // used in "http" action, default is POST (!)
|
||||
Headers map[string]string `json:"headers,omitempty"` // used in "http" action
|
||||
Body string `json:"body,omitempty"` // used in "http" action
|
||||
Intent string `json:"intent,omitempty"` // used in "broadcast" action
|
||||
Extras map[string]string `json:"extras,omitempty"` // used in "broadcast" action
|
||||
Value string `json:"value,omitempty"` // used in "copy" action
|
||||
}
|
||||
|
||||
// NewAction creates a new action with initialized maps
|
||||
func NewAction() *Action {
|
||||
return &Action{
|
||||
Headers: make(map[string]string),
|
||||
Extras: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessage creates a new message with the current timestamp
|
||||
func NewMessage(event, topic, msg string) *Message {
|
||||
return &Message{
|
||||
ID: util.RandomString(MessageIDLength),
|
||||
Time: time.Now().Unix(),
|
||||
Event: event,
|
||||
Topic: topic,
|
||||
Message: msg,
|
||||
}
|
||||
}
|
||||
|
||||
// NewOpenMessage is a convenience method to create an open message
|
||||
func NewOpenMessage(topic string) *Message {
|
||||
return NewMessage(OpenEvent, topic, "")
|
||||
}
|
||||
|
||||
// NewKeepaliveMessage is a convenience method to create a keepalive message
|
||||
func NewKeepaliveMessage(topic string) *Message {
|
||||
return NewMessage(KeepaliveEvent, topic, "")
|
||||
}
|
||||
|
||||
// NewDefaultMessage is a convenience method to create a notification message
|
||||
func NewDefaultMessage(topic, msg string) *Message {
|
||||
return NewMessage(MessageEvent, topic, msg)
|
||||
}
|
||||
|
||||
// NewActionMessage creates a new action message (message_delete or message_clear)
|
||||
func NewActionMessage(event, topic, sequenceID string) *Message {
|
||||
m := NewMessage(event, topic, "")
|
||||
m.SequenceID = sequenceID
|
||||
return m
|
||||
}
|
||||
|
||||
// NewPollRequestMessage is a convenience method to create a poll request message
|
||||
func NewPollRequestMessage(topic, pollID string) *Message {
|
||||
m := NewMessage(PollRequestEvent, topic, "New message")
|
||||
m.PollID = pollID
|
||||
return m
|
||||
}
|
||||
|
||||
// ValidMessageID returns true if the given string is a valid message ID
|
||||
func ValidMessageID(s string) bool {
|
||||
return util.ValidRandomString(s, MessageIDLength)
|
||||
}
|
||||
|
||||
// SinceMarker represents a point in time or message ID from which to retrieve messages
|
||||
type SinceMarker struct {
|
||||
time time.Time
|
||||
id string
|
||||
}
|
||||
|
||||
// NewSinceTime creates a new SinceMarker from a Unix timestamp
|
||||
func NewSinceTime(timestamp int64) SinceMarker {
|
||||
return SinceMarker{time.Unix(timestamp, 0), ""}
|
||||
}
|
||||
|
||||
// NewSinceID creates a new SinceMarker from a message ID
|
||||
func NewSinceID(id string) SinceMarker {
|
||||
return SinceMarker{time.Unix(0, 0), id}
|
||||
}
|
||||
|
||||
// IsAll returns true if this is the "all messages" marker
|
||||
func (t SinceMarker) IsAll() bool {
|
||||
return t == SinceAllMessages
|
||||
}
|
||||
|
||||
// IsNone returns true if this is the "no messages" marker
|
||||
func (t SinceMarker) IsNone() bool {
|
||||
return t == SinceNoMessages
|
||||
}
|
||||
|
||||
// IsLatest returns true if this is the "latest message" marker
|
||||
func (t SinceMarker) IsLatest() bool {
|
||||
return t == SinceLatestMessage
|
||||
}
|
||||
|
||||
// IsID returns true if this marker references a specific message ID
|
||||
func (t SinceMarker) IsID() bool {
|
||||
return t.id != "" && t.id != SinceLatestMessage.id
|
||||
}
|
||||
|
||||
// Time returns the time component of the marker
|
||||
func (t SinceMarker) Time() time.Time {
|
||||
return t.time
|
||||
}
|
||||
|
||||
// ID returns the message ID component of the marker
|
||||
func (t SinceMarker) ID() string {
|
||||
return t.id
|
||||
}
|
||||
|
||||
// Common SinceMarker values for subscribing to messages
|
||||
var (
|
||||
SinceAllMessages = SinceMarker{time.Unix(0, 0), ""}
|
||||
SinceNoMessages = SinceMarker{time.Unix(1, 0), ""}
|
||||
SinceLatestMessage = SinceMarker{time.Unix(0, 0), "latest"}
|
||||
)
|
||||
488
s3/client.go
Normal file
488
s3/client.go
Normal file
@@ -0,0 +1,488 @@
|
||||
// Package s3 provides a minimal S3-compatible client that works with AWS S3, DigitalOcean Spaces,
|
||||
// GCP Cloud Storage, MinIO, Backblaze B2, and other S3-compatible providers. It uses raw HTTP
|
||||
// requests with AWS Signature V4 signing, no AWS SDK dependency required.
|
||||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5" //nolint:gosec // MD5 is required by the S3 protocol for Content-MD5 headers
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client is a minimal S3-compatible client. It supports PutObject, GetObject, DeleteObjects,
|
||||
// and ListObjectsV2 operations using AWS Signature V4 signing. The bucket and optional key prefix
|
||||
// are fixed at construction time. All operations target the same bucket and prefix.
|
||||
//
|
||||
// Fields must not be modified after the Client is passed to any method or goroutine.
|
||||
type Client struct {
|
||||
AccessKey string // AWS access key ID
|
||||
SecretKey string // AWS secret access key
|
||||
Region string // e.g. "us-east-1"
|
||||
Endpoint string // host[:port] only, e.g. "s3.amazonaws.com" or "nyc3.digitaloceanspaces.com"
|
||||
Bucket string // S3 bucket name
|
||||
Prefix string // optional key prefix (e.g. "attachments"); prepended to all keys automatically
|
||||
PathStyle bool // if true, use path-style addressing; otherwise virtual-hosted-style
|
||||
HTTPClient *http.Client // if nil, http.DefaultClient is used
|
||||
}
|
||||
|
||||
// New creates a new S3 client from the given Config.
|
||||
func New(config *Config) *Client {
|
||||
return &Client{
|
||||
AccessKey: config.AccessKey,
|
||||
SecretKey: config.SecretKey,
|
||||
Region: config.Region,
|
||||
Endpoint: config.Endpoint,
|
||||
Bucket: config.Bucket,
|
||||
Prefix: config.Prefix,
|
||||
PathStyle: config.PathStyle,
|
||||
}
|
||||
}
|
||||
|
||||
// PutObject uploads body to the given key. The key is automatically prefixed with the client's
|
||||
// configured prefix. The body size does not need to be known in advance.
|
||||
//
|
||||
// If the entire body fits in a single part (5 MB), it is uploaded with a simple PUT request.
|
||||
// Otherwise, the body is uploaded using S3 multipart upload, reading one part at a time
|
||||
// into memory.
|
||||
func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error {
|
||||
first := make([]byte, partSize)
|
||||
n, err := io.ReadFull(body, first)
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF {
|
||||
return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n))
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: PutObject read: %w", err)
|
||||
}
|
||||
combined := io.MultiReader(bytes.NewReader(first), body)
|
||||
return c.putObjectMultipart(ctx, key, combined)
|
||||
}
|
||||
|
||||
// GetObject downloads an object. The key is automatically prefixed with the client's configured
|
||||
// prefix. The caller must close the returned ReadCloser.
|
||||
func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int64, error) {
|
||||
fullKey := c.objectKey(key)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.objectURL(fullKey), nil)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("s3: GetObject request: %w", err)
|
||||
}
|
||||
c.signV4(req, emptyPayloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("s3: GetObject: %w", err)
|
||||
}
|
||||
if resp.StatusCode/100 != 2 {
|
||||
err := parseError(resp)
|
||||
resp.Body.Close()
|
||||
return nil, 0, err
|
||||
}
|
||||
return resp.Body, resp.ContentLength, nil
|
||||
}
|
||||
|
||||
// DeleteObjects removes multiple objects in a single batch request. Keys are automatically
|
||||
// prefixed with the client's configured prefix. S3 supports up to 1000 keys per call; the
|
||||
// caller is responsible for batching if needed.
|
||||
//
|
||||
// Even when S3 returns HTTP 200, individual keys may fail. If any per-key errors are present
|
||||
// in the response, they are returned as a combined error.
|
||||
func (c *Client) DeleteObjects(ctx context.Context, keys []string) error {
|
||||
var body bytes.Buffer
|
||||
body.WriteString("<Delete><Quiet>true</Quiet>")
|
||||
for _, key := range keys {
|
||||
body.WriteString("<Object><Key>")
|
||||
xml.EscapeText(&body, []byte(c.objectKey(key)))
|
||||
body.WriteString("</Key></Object>")
|
||||
}
|
||||
body.WriteString("</Delete>")
|
||||
bodyBytes := body.Bytes()
|
||||
payloadHash := sha256Hex(bodyBytes)
|
||||
|
||||
// Content-MD5 is required by the S3 protocol for DeleteObjects requests.
|
||||
md5Sum := md5.Sum(bodyBytes) //nolint:gosec
|
||||
contentMD5 := base64.StdEncoding.EncodeToString(md5Sum[:])
|
||||
|
||||
reqURL := c.bucketURL() + "?delete="
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: DeleteObjects request: %w", err)
|
||||
}
|
||||
req.ContentLength = int64(len(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/xml")
|
||||
req.Header.Set("Content-MD5", contentMD5)
|
||||
c.signV4(req, payloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: DeleteObjects: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return parseError(resp)
|
||||
}
|
||||
|
||||
// S3 may return HTTP 200 with per-key errors in the response body
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: DeleteObjects read response: %w", err)
|
||||
}
|
||||
var result deleteResult
|
||||
if err := xml.Unmarshal(respBody, &result); err != nil {
|
||||
return nil // If we can't parse, assume success (Quiet mode returns empty body on success)
|
||||
}
|
||||
if len(result.Errors) > 0 {
|
||||
var msgs []string
|
||||
for _, e := range result.Errors {
|
||||
msgs = append(msgs, fmt.Sprintf("%s: %s", e.Key, e.Message))
|
||||
}
|
||||
return fmt.Errorf("s3: DeleteObjects partial failure: %s", strings.Join(msgs, "; "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListObjects performs a single ListObjectsV2 request using the client's configured prefix.
|
||||
// Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000).
|
||||
func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxKeys int) (*ListResult, error) {
|
||||
query := url.Values{"list-type": {"2"}}
|
||||
if prefix := c.prefixForList(); prefix != "" {
|
||||
query.Set("prefix", prefix)
|
||||
}
|
||||
if continuationToken != "" {
|
||||
query.Set("continuation-token", continuationToken)
|
||||
}
|
||||
if maxKeys > 0 {
|
||||
query.Set("max-keys", strconv.Itoa(maxKeys))
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s3: ListObjects request: %w", err)
|
||||
}
|
||||
c.signV4(req, emptyPayloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s3: ListObjects: %w", err)
|
||||
}
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s3: ListObjects read: %w", err)
|
||||
}
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return nil, parseErrorFromBytes(resp.StatusCode, respBody)
|
||||
}
|
||||
var result listObjectsV2Response
|
||||
if err := xml.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("s3: ListObjects XML: %w", err)
|
||||
}
|
||||
objects := make([]Object, len(result.Contents))
|
||||
for i, obj := range result.Contents {
|
||||
objects[i] = Object(obj)
|
||||
}
|
||||
return &ListResult{
|
||||
Objects: objects,
|
||||
IsTruncated: result.IsTruncated,
|
||||
NextContinuationToken: result.NextContinuationToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListAllObjects returns all objects under the client's configured prefix by paginating through
|
||||
// ListObjectsV2 results automatically. It stops after 10,000 pages as a safety valve.
|
||||
func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) {
|
||||
const maxPages = 10000
|
||||
var all []Object
|
||||
var token string
|
||||
for page := 0; page < maxPages; page++ {
|
||||
result, err := c.ListObjects(ctx, token, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, result.Objects...)
|
||||
if !result.IsTruncated {
|
||||
return all, nil
|
||||
}
|
||||
token = result.NextContinuationToken
|
||||
}
|
||||
return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages)
|
||||
}
|
||||
|
||||
// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD.
|
||||
func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error {
|
||||
fullKey := c.objectKey(key)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: PutObject request: %w", err)
|
||||
}
|
||||
req.ContentLength = size
|
||||
c.signV4(req, unsignedPayload)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: PutObject: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return parseError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize
|
||||
// chunks, uploading each as a separate part. This allows uploading without knowing the total
|
||||
// body size in advance.
|
||||
func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error {
|
||||
fullKey := c.objectKey(key)
|
||||
|
||||
// Step 1: Initiate multipart upload
|
||||
uploadID, err := c.initiateMultipartUpload(ctx, fullKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Step 2: Upload parts
|
||||
var parts []completedPart
|
||||
buf := make([]byte, partSize)
|
||||
partNumber := 1
|
||||
for {
|
||||
n, err := io.ReadFull(body, buf)
|
||||
if n > 0 {
|
||||
etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n])
|
||||
if uploadErr != nil {
|
||||
c.abortMultipartUpload(ctx, fullKey, uploadID)
|
||||
return uploadErr
|
||||
}
|
||||
parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag})
|
||||
partNumber++
|
||||
}
|
||||
if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
c.abortMultipartUpload(ctx, fullKey, uploadID)
|
||||
return fmt.Errorf("s3: PutObject read: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Complete multipart upload
|
||||
return c.completeMultipartUpload(ctx, fullKey, uploadID, parts)
|
||||
}
|
||||
|
||||
// initiateMultipartUpload starts a new multipart upload and returns the upload ID.
|
||||
func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) {
|
||||
reqURL := c.objectURL(fullKey) + "?uploads"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err)
|
||||
}
|
||||
req.ContentLength = 0
|
||||
c.signV4(req, emptyPayloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return "", parseError(resp)
|
||||
}
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err)
|
||||
}
|
||||
var result initiateMultipartUploadResult
|
||||
if err := xml.Unmarshal(respBody, &result); err != nil {
|
||||
return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err)
|
||||
}
|
||||
return result.UploadID, nil
|
||||
}
|
||||
|
||||
// uploadPart uploads a single part of a multipart upload and returns the ETag.
|
||||
func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) {
|
||||
reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: UploadPart request: %w", err)
|
||||
}
|
||||
req.ContentLength = int64(len(data))
|
||||
c.signV4(req, unsignedPayload)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: UploadPart: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return "", parseError(resp)
|
||||
}
|
||||
etag := resp.Header.Get("ETag")
|
||||
return etag, nil
|
||||
}
|
||||
|
||||
// completeMultipartUpload finalizes a multipart upload with the given parts.
|
||||
func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error {
|
||||
var body bytes.Buffer
|
||||
body.WriteString("<CompleteMultipartUpload>")
|
||||
for _, p := range parts {
|
||||
fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag)
|
||||
}
|
||||
body.WriteString("</CompleteMultipartUpload>")
|
||||
bodyBytes := body.Bytes()
|
||||
payloadHash := sha256Hex(bodyBytes)
|
||||
|
||||
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err)
|
||||
}
|
||||
req.ContentLength = int64(len(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/xml")
|
||||
c.signV4(req, payloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: CompleteMultipartUpload: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return parseError(resp)
|
||||
}
|
||||
// Read response body to check for errors (S3 can return 200 with an error body)
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err)
|
||||
}
|
||||
// Check if the response contains an error
|
||||
var errResp ErrorResponse
|
||||
if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
|
||||
errResp.StatusCode = resp.StatusCode
|
||||
return &errResp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
|
||||
func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) {
|
||||
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.signV4(req, emptyPayloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256
|
||||
// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads.
|
||||
func (c *Client) signV4(req *http.Request, payloadHash string) {
|
||||
now := time.Now().UTC()
|
||||
datestamp := now.Format("20060102")
|
||||
amzDate := now.Format("20060102T150405Z")
|
||||
|
||||
// Required headers
|
||||
req.Header.Set("Host", c.hostHeader())
|
||||
req.Header.Set("X-Amz-Date", amzDate)
|
||||
req.Header.Set("X-Amz-Content-Sha256", payloadHash)
|
||||
|
||||
// Canonical headers (all headers we set, sorted by lowercase key)
|
||||
signedKeys := make([]string, 0, len(req.Header))
|
||||
canonHeaders := make(map[string]string, len(req.Header))
|
||||
for k := range req.Header {
|
||||
lk := strings.ToLower(k)
|
||||
signedKeys = append(signedKeys, lk)
|
||||
canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k))
|
||||
}
|
||||
sort.Strings(signedKeys)
|
||||
signedHeadersStr := strings.Join(signedKeys, ";")
|
||||
var chBuf strings.Builder
|
||||
for _, k := range signedKeys {
|
||||
chBuf.WriteString(k)
|
||||
chBuf.WriteByte(':')
|
||||
chBuf.WriteString(canonHeaders[k])
|
||||
chBuf.WriteByte('\n')
|
||||
}
|
||||
|
||||
// Canonical request
|
||||
canonicalRequest := strings.Join([]string{
|
||||
req.Method,
|
||||
canonicalURI(req.URL),
|
||||
canonicalQueryString(req.URL.Query()),
|
||||
chBuf.String(),
|
||||
signedHeadersStr,
|
||||
payloadHash,
|
||||
}, "\n")
|
||||
|
||||
// String to sign
|
||||
credentialScope := datestamp + "/" + c.Region + "/s3/aws4_request"
|
||||
stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest))
|
||||
|
||||
// Signing key
|
||||
signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256(
|
||||
[]byte("AWS4"+c.SecretKey), []byte(datestamp)),
|
||||
[]byte(c.Region)),
|
||||
[]byte("s3")),
|
||||
[]byte("aws4_request"))
|
||||
|
||||
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
|
||||
req.Header.Set("Authorization", fmt.Sprintf(
|
||||
"AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
||||
c.AccessKey, credentialScope, signedHeadersStr, signature,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *Client) httpClient() *http.Client {
|
||||
if c.HTTPClient != nil {
|
||||
return c.HTTPClient
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
// objectKey prepends the configured prefix to the given key.
|
||||
func (c *Client) objectKey(key string) string {
|
||||
if c.Prefix != "" {
|
||||
return c.Prefix + "/" + key
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// prefixForList returns the prefix to use in ListObjectsV2 requests,
|
||||
// with a trailing slash so that only objects under the prefix directory are returned.
|
||||
func (c *Client) prefixForList() string {
|
||||
if c.Prefix != "" {
|
||||
return c.Prefix + "/"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// bucketURL returns the base URL for bucket-level operations.
|
||||
func (c *Client) bucketURL() string {
|
||||
if c.PathStyle {
|
||||
return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket)
|
||||
}
|
||||
return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint)
|
||||
}
|
||||
|
||||
// objectURL returns the full URL for an object (key should already include the prefix).
|
||||
// Each path segment is URI-encoded to handle special characters in keys.
|
||||
func (c *Client) objectURL(key string) string {
|
||||
segments := strings.Split(key, "/")
|
||||
for i, seg := range segments {
|
||||
segments[i] = uriEncode(seg)
|
||||
}
|
||||
return c.bucketURL() + "/" + strings.Join(segments, "/")
|
||||
}
|
||||
|
||||
// hostHeader returns the value for the Host header.
|
||||
func (c *Client) hostHeader() string {
|
||||
if c.PathStyle {
|
||||
return c.Endpoint
|
||||
}
|
||||
return c.Bucket + "." + c.Endpoint
|
||||
}
|
||||
860
s3/client_test.go
Normal file
860
s3/client_test.go
Normal file
@@ -0,0 +1,860 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Mock S3 server ---
|
||||
//
|
||||
// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and
|
||||
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
|
||||
|
||||
type mockS3Server struct {
|
||||
objects map[string][]byte // full key (bucket/key) -> body
|
||||
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
|
||||
nextID int // counter for generating upload IDs
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockS3Server() (*httptest.Server, *mockS3Server) {
|
||||
m := &mockS3Server{
|
||||
objects: make(map[string][]byte),
|
||||
uploads: make(map[string]map[int][]byte),
|
||||
}
|
||||
return httptest.NewTLSServer(m), m
|
||||
}
|
||||
|
||||
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Path is /{bucket}[/{key...}]
|
||||
path := strings.TrimPrefix(r.URL.Path, "/")
|
||||
q := r.URL.Query()
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodPut && q.Has("partNumber"):
|
||||
m.handleUploadPart(w, r, path)
|
||||
case r.Method == http.MethodPut:
|
||||
m.handlePut(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("uploads"):
|
||||
m.handleInitiateMultipart(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("uploadId"):
|
||||
m.handleCompleteMultipart(w, r, path)
|
||||
case r.Method == http.MethodDelete && q.Has("uploadId"):
|
||||
m.handleAbortMultipart(w, r, path)
|
||||
case r.Method == http.MethodGet && q.Get("list-type") == "2":
|
||||
m.handleList(w, r, path)
|
||||
case r.Method == http.MethodGet:
|
||||
m.handleGet(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("delete"):
|
||||
m.handleDelete(w, r, path)
|
||||
default:
|
||||
http.Error(w, "not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.objects[path] = body
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
m.mu.Lock()
|
||||
m.nextID++
|
||||
uploadID := fmt.Sprintf("upload-%d", m.nextID)
|
||||
m.uploads[uploadID] = make(map[int][]byte)
|
||||
m.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
var partNumber int
|
||||
fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
parts, ok := m.uploads[uploadID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
http.Error(w, "NoSuchUpload", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
parts[partNumber] = body
|
||||
m.mu.Unlock()
|
||||
|
||||
etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
|
||||
w.Header().Set("ETag", etag)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
|
||||
m.mu.Lock()
|
||||
parts, ok := m.uploads[uploadID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
http.Error(w, "NoSuchUpload", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Assemble parts in order
|
||||
var assembled []byte
|
||||
for i := 1; i <= len(parts); i++ {
|
||||
assembled = append(assembled, parts[i]...)
|
||||
}
|
||||
m.objects[path] = assembled
|
||||
delete(m.uploads, uploadID)
|
||||
m.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
m.mu.Lock()
|
||||
delete(m.uploads, uploadID)
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
|
||||
m.mu.RLock()
|
||||
body, ok := m.objects[path]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(body)
|
||||
}
|
||||
|
||||
type listObjectsResponse struct {
|
||||
XMLName xml.Name `xml:"ListBucketResult"`
|
||||
Contents []listObject `xml:"Contents"`
|
||||
// Pagination support
|
||||
IsTruncated bool `xml:"IsTruncated"`
|
||||
NextContinuationToken string `xml:"NextContinuationToken"`
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) {
|
||||
// bucketPath is just the bucket name
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Objects []struct {
|
||||
Key string `xml:"Key"`
|
||||
} `xml:"Object"`
|
||||
}
|
||||
if err := xml.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
for _, obj := range req.Objects {
|
||||
delete(m.objects, bucketPath+"/"+obj.Key)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></DeleteResult>`))
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) {
|
||||
prefix := r.URL.Query().Get("prefix")
|
||||
contToken := r.URL.Query().Get("continuation-token")
|
||||
|
||||
m.mu.RLock()
|
||||
var allKeys []string
|
||||
for key := range m.objects {
|
||||
objKey := strings.TrimPrefix(key, bucketPath+"/")
|
||||
if objKey == key {
|
||||
continue // different bucket
|
||||
}
|
||||
if prefix == "" || strings.HasPrefix(objKey, prefix) {
|
||||
allKeys = append(allKeys, objKey)
|
||||
}
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
sort.Strings(allKeys)
|
||||
|
||||
// Simple continuation token: it's the key to start after
|
||||
startIdx := 0
|
||||
if contToken != "" {
|
||||
for i, k := range allKeys {
|
||||
if k == contToken {
|
||||
startIdx = i + 1
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
maxKeys := 1000
|
||||
if mk := r.URL.Query().Get("max-keys"); mk != "" {
|
||||
fmt.Sscanf(mk, "%d", &maxKeys)
|
||||
}
|
||||
|
||||
endIdx := startIdx + maxKeys
|
||||
truncated := false
|
||||
nextToken := ""
|
||||
if endIdx < len(allKeys) {
|
||||
truncated = true
|
||||
nextToken = allKeys[endIdx-1]
|
||||
allKeys = allKeys[startIdx:endIdx]
|
||||
} else {
|
||||
allKeys = allKeys[startIdx:]
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
var contents []listObject
|
||||
for _, objKey := range allKeys {
|
||||
body := m.objects[bucketPath+"/"+objKey]
|
||||
contents = append(contents, listObject{Key: objKey, Size: int64(len(body))})
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
resp := listObjectsResponse{
|
||||
Contents: contents,
|
||||
IsTruncated: truncated,
|
||||
NextContinuationToken: nextToken,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
xml.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) objectCount() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.objects)
|
||||
}
|
||||
|
||||
// --- Helper to create a test client pointing at mock server ---
|
||||
|
||||
func newTestClient(server *httptest.Server, bucket, prefix string) *Client {
|
||||
// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
|
||||
host := strings.TrimPrefix(server.URL, "https://")
|
||||
return &Client{
|
||||
AccessKey: "AKID",
|
||||
SecretKey: "SECRET",
|
||||
Region: "us-east-1",
|
||||
Endpoint: host,
|
||||
Bucket: bucket,
|
||||
Prefix: prefix,
|
||||
PathStyle: true,
|
||||
HTTPClient: server.Client(),
|
||||
}
|
||||
}
|
||||
|
||||
// --- URL parsing tests ---
|
||||
|
||||
func TestParseURL_Success(t *testing.T) {
|
||||
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/attachments?region=us-east-1")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||
require.Equal(t, "attachments", cfg.Prefix)
|
||||
require.Equal(t, "us-east-1", cfg.Region)
|
||||
require.Equal(t, "AKID", cfg.AccessKey)
|
||||
require.Equal(t, "SECRET", cfg.SecretKey)
|
||||
require.Equal(t, "s3.us-east-1.amazonaws.com", cfg.Endpoint)
|
||||
require.False(t, cfg.PathStyle)
|
||||
}
|
||||
|
||||
func TestParseURL_NoPrefix(t *testing.T) {
|
||||
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||
require.Equal(t, "", cfg.Prefix)
|
||||
}
|
||||
|
||||
func TestParseURL_WithEndpoint(t *testing.T) {
|
||||
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/prefix?region=us-east-1&endpoint=https://s3.example.com")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||
require.Equal(t, "prefix", cfg.Prefix)
|
||||
require.Equal(t, "s3.example.com", cfg.Endpoint)
|
||||
require.True(t, cfg.PathStyle)
|
||||
}
|
||||
|
||||
func TestParseURL_EndpointHTTP(t *testing.T) {
|
||||
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1&endpoint=http://localhost:9000")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "localhost:9000", cfg.Endpoint)
|
||||
require.True(t, cfg.PathStyle)
|
||||
}
|
||||
|
||||
func TestParseURL_EndpointTrailingSlash(t *testing.T) {
|
||||
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1&endpoint=https://s3.example.com/")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "s3.example.com", cfg.Endpoint)
|
||||
}
|
||||
|
||||
func TestParseURL_NestedPrefix(t *testing.T) {
|
||||
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/a/b/c?region=us-east-1")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||
require.Equal(t, "a/b/c", cfg.Prefix)
|
||||
}
|
||||
|
||||
func TestParseURL_MissingRegion(t *testing.T) {
|
||||
_, err := ParseURL("s3://AKID:SECRET@my-bucket")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "region")
|
||||
}
|
||||
|
||||
func TestParseURL_MissingCredentials(t *testing.T) {
|
||||
_, err := ParseURL("s3://my-bucket?region=us-east-1")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access key")
|
||||
}
|
||||
|
||||
func TestParseURL_MissingSecretKey(t *testing.T) {
|
||||
_, err := ParseURL("s3://AKID@my-bucket?region=us-east-1")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "secret key")
|
||||
}
|
||||
|
||||
func TestParseURL_WrongScheme(t *testing.T) {
|
||||
_, err := ParseURL("http://AKID:SECRET@my-bucket?region=us-east-1")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "scheme")
|
||||
}
|
||||
|
||||
func TestParseURL_EmptyBucket(t *testing.T) {
|
||||
_, err := ParseURL("s3://AKID:SECRET@?region=us-east-1")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "bucket")
|
||||
}
|
||||
|
||||
// --- Unit tests: URL construction ---
|
||||
|
||||
func TestClient_BucketURL_PathStyle(t *testing.T) {
|
||||
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
|
||||
require.Equal(t, "https://s3.example.com/my-bucket", c.bucketURL())
|
||||
}
|
||||
|
||||
func TestClient_BucketURL_VirtualHosted(t *testing.T) {
|
||||
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
|
||||
require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.bucketURL())
|
||||
}
|
||||
|
||||
func TestClient_ObjectURL_PathStyle(t *testing.T) {
|
||||
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
|
||||
require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.objectURL("prefix/obj"))
|
||||
}
|
||||
|
||||
func TestClient_ObjectURL_VirtualHosted(t *testing.T) {
|
||||
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
|
||||
require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.objectURL("prefix/obj"))
|
||||
}
|
||||
|
||||
func TestClient_HostHeader_PathStyle(t *testing.T) {
|
||||
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
|
||||
require.Equal(t, "s3.example.com", c.hostHeader())
|
||||
}
|
||||
|
||||
func TestClient_HostHeader_VirtualHosted(t *testing.T) {
|
||||
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
|
||||
require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.hostHeader())
|
||||
}
|
||||
|
||||
func TestClient_ObjectKey(t *testing.T) {
|
||||
c := &Client{Prefix: "attachments"}
|
||||
require.Equal(t, "attachments/file123", c.objectKey("file123"))
|
||||
|
||||
c2 := &Client{Prefix: ""}
|
||||
require.Equal(t, "file123", c2.objectKey("file123"))
|
||||
}
|
||||
|
||||
func TestClient_PrefixForList(t *testing.T) {
|
||||
c := &Client{Prefix: "attachments"}
|
||||
require.Equal(t, "attachments/", c.prefixForList())
|
||||
|
||||
c2 := &Client{Prefix: ""}
|
||||
require.Equal(t, "", c2.prefixForList())
|
||||
}
|
||||
|
||||
// --- Integration tests using mock S3 server ---
|
||||
|
||||
func TestClient_PutGetObject(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Put
|
||||
err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"))
|
||||
require.Nil(t, err)
|
||||
|
||||
// Get
|
||||
reader, size, err := client.GetObject(ctx, "test-key")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
data, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "hello world", string(data))
|
||||
}
|
||||
|
||||
func TestClient_PutGetObject_WithPrefix(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "pfx")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err := client.PutObject(ctx, "test-key", strings.NewReader("hello"))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, _, err := client.GetObject(ctx, "test-key")
|
||||
require.Nil(t, err)
|
||||
data, _ := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Equal(t, "hello", string(data))
|
||||
}
|
||||
|
||||
func TestClient_GetObject_NotFound(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
_, _, err := client.GetObject(context.Background(), "nonexistent")
|
||||
require.Error(t, err)
|
||||
var errResp *ErrorResponse
|
||||
require.ErrorAs(t, err, &errResp)
|
||||
require.Equal(t, 404, errResp.StatusCode)
|
||||
require.Equal(t, "NoSuchKey", errResp.Code)
|
||||
}
|
||||
|
||||
func TestClient_DeleteObjects(t *testing.T) {
|
||||
server, mock := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Put several objects
|
||||
for i := 0; i < 5; i++ {
|
||||
err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Equal(t, 5, mock.objectCount())
|
||||
|
||||
// Delete some
|
||||
err := client.DeleteObjects(ctx, []string{"key-1", "key-3"})
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, mock.objectCount())
|
||||
|
||||
// Verify deleted ones are gone
|
||||
_, _, err = client.GetObject(ctx, "key-1")
|
||||
require.Error(t, err)
|
||||
_, _, err = client.GetObject(ctx, "key-3")
|
||||
require.Error(t, err)
|
||||
|
||||
// Verify remaining ones are still there
|
||||
reader, _, err := client.GetObject(ctx, "key-0")
|
||||
require.Nil(t, err)
|
||||
reader.Close()
|
||||
}
|
||||
|
||||
func TestClient_ListObjects(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Client with prefix "pfx": list should only return objects under pfx/
|
||||
client := newTestClient(server, "my-bucket", "pfx")
|
||||
for i := 0; i < 3; i++ {
|
||||
err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
// Also put an object outside the prefix using a no-prefix client
|
||||
clientNoPrefix := newTestClient(server, "my-bucket", "")
|
||||
err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")))
|
||||
require.Nil(t, err)
|
||||
|
||||
// List with prefix client: should only see 3
|
||||
result, err := client.ListObjects(ctx, "", 0)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, result.Objects, 3)
|
||||
require.False(t, result.IsTruncated)
|
||||
|
||||
// List with no-prefix client: should see all 4
|
||||
result, err = clientNoPrefix.ListObjects(ctx, "", 0)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, result.Objects, 4)
|
||||
}
|
||||
|
||||
func TestClient_ListObjects_Pagination(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Put 5 objects
|
||||
for i := 0; i < 5; i++ {
|
||||
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
// List with max-keys=2
|
||||
result, err := client.ListObjects(ctx, "", 2)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, result.Objects, 2)
|
||||
require.True(t, result.IsTruncated)
|
||||
require.NotEmpty(t, result.NextContinuationToken)
|
||||
|
||||
// Get next page
|
||||
result2, err := client.ListObjects(ctx, result.NextContinuationToken, 2)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, result2.Objects, 2)
|
||||
require.True(t, result2.IsTruncated)
|
||||
|
||||
// Get last page
|
||||
result3, err := client.ListObjects(ctx, result2.NextContinuationToken, 2)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, result3.Objects, 1)
|
||||
require.False(t, result3.IsTruncated)
|
||||
}
|
||||
|
||||
func TestClient_ListAllObjects(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "pfx")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
objects, err := client.ListAllObjects(ctx)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, objects, 10)
|
||||
}
|
||||
|
||||
func TestClient_PutObject_LargeBody(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 1 MB object
|
||||
data := make([]byte, 1024*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
err := client.PutObject(ctx, "large", bytes.NewReader(data))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, size, err := client.GetObject(ctx, "large")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(1024*1024), size)
|
||||
got, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, data, got)
|
||||
}
|
||||
|
||||
func TestClient_PutObject_ChunkedUpload(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 12 MB object, exceeds 5 MB partSize, triggers multipart upload path
|
||||
data := make([]byte, 12*1024*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
err := client.PutObject(ctx, "multipart", bytes.NewReader(data))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, size, err := client.GetObject(ctx, "multipart")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(12*1024*1024), size)
|
||||
got, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, data, got)
|
||||
}
|
||||
|
||||
func TestClient_PutObject_ExactPartSize(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Exactly 5 MB (partSize), should use the simple put path (ReadFull succeeds fully)
|
||||
data := make([]byte, 5*1024*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
err := client.PutObject(ctx, "exact", bytes.NewReader(data))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, size, err := client.GetObject(ctx, "exact")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(5*1024*1024), size)
|
||||
got, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, data, got)
|
||||
}
|
||||
|
||||
func TestClient_PutObject_NestedKey(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt")
|
||||
require.Nil(t, err)
|
||||
data, _ := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Equal(t, "nested", string(data))
|
||||
}
|
||||
|
||||
// --- Scale test: 20k objects (ntfy-adjacent) ---
|
||||
|
||||
func TestClient_ListAllObjects_20k(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping 20k object test in short mode")
|
||||
}
|
||||
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "attachments")
|
||||
|
||||
ctx := context.Background()
|
||||
const numObjects = 20000
|
||||
const batchSize = 500
|
||||
|
||||
// Insert 20k objects in batches to keep it fast
|
||||
for batch := 0; batch < numObjects/batchSize; batch++ {
|
||||
for i := 0; i < batchSize; i++ {
|
||||
idx := batch*batchSize + i
|
||||
key := fmt.Sprintf("%08d", idx)
|
||||
err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// List all 20k objects with pagination
|
||||
objects, err := client.ListAllObjects(ctx)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, objects, numObjects)
|
||||
|
||||
// Verify total size
|
||||
var totalSize int64
|
||||
for _, obj := range objects {
|
||||
totalSize += obj.Size
|
||||
}
|
||||
require.Equal(t, int64(numObjects), totalSize)
|
||||
|
||||
// Delete 1000 objects (simulating attachment expiry cleanup)
|
||||
keys := make([]string, 1000)
|
||||
for i := range keys {
|
||||
keys[i] = fmt.Sprintf("%08d", i)
|
||||
}
|
||||
err = client.DeleteObjects(ctx, keys)
|
||||
require.Nil(t, err)
|
||||
|
||||
// List again: should have 19000
|
||||
objects, err = client.ListAllObjects(ctx)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, objects, numObjects-1000)
|
||||
}
|
||||
|
||||
// --- Real S3 integration test ---
|
||||
//
|
||||
// Set the following environment variables to run this test against a real S3 bucket:
|
||||
//
|
||||
// S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET
|
||||
//
|
||||
// Optional:
|
||||
//
|
||||
// S3_ENDPOINT: host[:port] for S3-compatible providers (e.g. "nyc3.digitaloceanspaces.com")
|
||||
// S3_PATH_STYLE: set to "true" for path-style addressing
|
||||
// S3_PREFIX: key prefix to use (default: "ntfy-s3-test")
|
||||
func TestClient_RealBucket(t *testing.T) {
|
||||
accessKey := os.Getenv("S3_ACCESS_KEY")
|
||||
secretKey := os.Getenv("S3_SECRET_KEY")
|
||||
region := os.Getenv("S3_REGION")
|
||||
bucket := os.Getenv("S3_BUCKET")
|
||||
|
||||
if accessKey == "" || secretKey == "" || region == "" || bucket == "" {
|
||||
t.Skip("skipping real S3 test: set S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET")
|
||||
}
|
||||
|
||||
endpoint := os.Getenv("S3_ENDPOINT")
|
||||
if endpoint == "" {
|
||||
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region)
|
||||
}
|
||||
pathStyle := os.Getenv("S3_PATH_STYLE") == "true"
|
||||
prefix := os.Getenv("S3_PREFIX")
|
||||
if prefix == "" {
|
||||
prefix = "ntfy-s3-test"
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
AccessKey: accessKey,
|
||||
SecretKey: secretKey,
|
||||
Region: region,
|
||||
Endpoint: endpoint,
|
||||
Bucket: bucket,
|
||||
Prefix: prefix,
|
||||
PathStyle: pathStyle,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Clean up any leftover objects from previous runs
|
||||
existing, err := client.ListAllObjects(ctx)
|
||||
require.Nil(t, err)
|
||||
if len(existing) > 0 {
|
||||
keys := make([]string, len(existing))
|
||||
for i, obj := range existing {
|
||||
// Strip the prefix since DeleteObjects will re-add it
|
||||
keys[i] = strings.TrimPrefix(obj.Key, prefix+"/")
|
||||
}
|
||||
// Batch delete in groups of 1000
|
||||
for i := 0; i < len(keys); i += 1000 {
|
||||
end := i + 1000
|
||||
if end > len(keys) {
|
||||
end = len(keys)
|
||||
}
|
||||
err := client.DeleteObjects(ctx, keys[i:end])
|
||||
require.Nil(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("PutGetDelete", func(t *testing.T) {
|
||||
key := "test-object"
|
||||
content := "hello from ntfy s3 test"
|
||||
|
||||
// Put
|
||||
err := client.PutObject(ctx, key, strings.NewReader(content))
|
||||
require.Nil(t, err)
|
||||
|
||||
// Get
|
||||
reader, size, err := client.GetObject(ctx, key)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(len(content)), size)
|
||||
data, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, content, string(data))
|
||||
|
||||
// Delete
|
||||
err = client.DeleteObjects(ctx, []string{key})
|
||||
require.Nil(t, err)
|
||||
|
||||
// Get after delete should fail
|
||||
_, _, err = client.GetObject(ctx, key)
|
||||
require.Error(t, err)
|
||||
var errResp *ErrorResponse
|
||||
require.ErrorAs(t, err, &errResp)
|
||||
require.Equal(t, 404, errResp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("ListObjects", func(t *testing.T) {
|
||||
// Use a sub-prefix client for isolation
|
||||
listClient := &Client{
|
||||
AccessKey: accessKey,
|
||||
SecretKey: secretKey,
|
||||
Region: region,
|
||||
Endpoint: endpoint,
|
||||
Bucket: bucket,
|
||||
Prefix: prefix + "/list-test",
|
||||
PathStyle: pathStyle,
|
||||
}
|
||||
|
||||
// Put 10 objects
|
||||
for i := 0; i < 10; i++ {
|
||||
err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
// List
|
||||
objects, err := listClient.ListAllObjects(ctx)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, objects, 10)
|
||||
|
||||
// Clean up
|
||||
keys := make([]string, 10)
|
||||
for i := range keys {
|
||||
keys[i] = fmt.Sprintf("%d", i)
|
||||
}
|
||||
err = listClient.DeleteObjects(ctx, keys)
|
||||
require.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("LargeObject", func(t *testing.T) {
|
||||
key := "large-object"
|
||||
data := make([]byte, 5*1024*1024) // 5 MB
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
err := client.PutObject(ctx, key, bytes.NewReader(data))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, size, err := client.GetObject(ctx, key)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(len(data)), size)
|
||||
got, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, data, got)
|
||||
|
||||
err = client.DeleteObjects(ctx, []string{key})
|
||||
require.Nil(t, err)
|
||||
})
|
||||
}
|
||||
76
s3/types.go
Normal file
76
s3/types.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package s3
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string.
|
||||
type Config struct {
|
||||
Endpoint string // host[:port] only, e.g. "s3.us-east-1.amazonaws.com"
|
||||
PathStyle bool
|
||||
Bucket string
|
||||
Prefix string
|
||||
Region string
|
||||
AccessKey string
|
||||
SecretKey string
|
||||
}
|
||||
|
||||
// Object represents an S3 object returned by list operations.
|
||||
type Object struct {
|
||||
Key string
|
||||
Size int64
|
||||
}
|
||||
|
||||
// ListResult holds the response from a ListObjectsV2 call.
|
||||
type ListResult struct {
|
||||
Objects []Object
|
||||
IsTruncated bool
|
||||
NextContinuationToken string
|
||||
}
|
||||
|
||||
// ErrorResponse is returned when S3 responds with a non-2xx status code.
|
||||
type ErrorResponse struct {
|
||||
StatusCode int
|
||||
Code string `xml:"Code"`
|
||||
Message string `xml:"Message"`
|
||||
Body string `xml:"-"` // raw response body
|
||||
}
|
||||
|
||||
func (e *ErrorResponse) Error() string {
|
||||
if e.Code != "" {
|
||||
return fmt.Sprintf("s3: %s (HTTP %d): %s", e.Code, e.StatusCode, e.Message)
|
||||
}
|
||||
return fmt.Sprintf("s3: HTTP %d: %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
// listObjectsV2Response is the XML response from S3 ListObjectsV2
|
||||
type listObjectsV2Response struct {
|
||||
Contents []listObject `xml:"Contents"`
|
||||
IsTruncated bool `xml:"IsTruncated"`
|
||||
NextContinuationToken string `xml:"NextContinuationToken"`
|
||||
}
|
||||
|
||||
type listObject struct {
|
||||
Key string `xml:"Key"`
|
||||
Size int64 `xml:"Size"`
|
||||
}
|
||||
|
||||
// deleteResult is the XML response from S3 DeleteObjects
|
||||
type deleteResult struct {
|
||||
Errors []deleteError `xml:"Error"`
|
||||
}
|
||||
|
||||
type deleteError struct {
|
||||
Key string `xml:"Key"`
|
||||
Code string `xml:"Code"`
|
||||
Message string `xml:"Message"`
|
||||
}
|
||||
|
||||
// initiateMultipartUploadResult is the XML response from S3 InitiateMultipartUpload
|
||||
type initiateMultipartUploadResult struct {
|
||||
UploadID string `xml:"UploadId"`
|
||||
}
|
||||
|
||||
// completedPart represents a successfully uploaded part for CompleteMultipartUpload
|
||||
type completedPart struct {
|
||||
PartNumber int
|
||||
ETag string
|
||||
}
|
||||
166
s3/util.go
Normal file
166
s3/util.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// SHA-256 hash of the empty string, used as the payload hash for bodiless requests
|
||||
emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
|
||||
|
||||
// Sent as the payload hash for streaming uploads where the body is not buffered in memory
|
||||
unsignedPayload = "UNSIGNED-PAYLOAD"
|
||||
|
||||
// maxResponseBytes caps the size of S3 response bodies we read into memory (10 MB)
|
||||
maxResponseBytes = 10 * 1024 * 1024
|
||||
|
||||
// partSize is the size of each part for multipart uploads (5 MB). This is also the threshold
|
||||
// above which PutObject switches from a simple PUT to multipart upload. S3 requires a minimum
|
||||
// part size of 5 MB for all parts except the last.
|
||||
partSize = 5 * 1024 * 1024
|
||||
)
|
||||
|
||||
// ParseURL parses an S3 URL of the form:
|
||||
//
|
||||
// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
|
||||
//
|
||||
// When endpoint is specified, path-style addressing is enabled automatically.
|
||||
func ParseURL(s3URL string) (*Config, error) {
|
||||
u, err := url.Parse(s3URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s3: invalid URL: %w", err)
|
||||
}
|
||||
if u.Scheme != "s3" {
|
||||
return nil, fmt.Errorf("s3: URL scheme must be 's3', got '%s'", u.Scheme)
|
||||
}
|
||||
if u.Host == "" {
|
||||
return nil, fmt.Errorf("s3: bucket name must be specified as host")
|
||||
}
|
||||
bucket := u.Host
|
||||
prefix := strings.TrimPrefix(u.Path, "/")
|
||||
accessKey := u.User.Username()
|
||||
secretKey, _ := u.User.Password()
|
||||
if accessKey == "" || secretKey == "" {
|
||||
return nil, fmt.Errorf("s3: access key and secret key must be specified in URL")
|
||||
}
|
||||
region := u.Query().Get("region")
|
||||
if region == "" {
|
||||
return nil, fmt.Errorf("s3: region query parameter is required")
|
||||
}
|
||||
endpointParam := u.Query().Get("endpoint")
|
||||
var endpoint string
|
||||
var pathStyle bool
|
||||
if endpointParam != "" {
|
||||
// Custom endpoint: strip scheme prefix to extract host[:port]
|
||||
ep := strings.TrimRight(endpointParam, "/")
|
||||
ep = strings.TrimPrefix(ep, "https://")
|
||||
ep = strings.TrimPrefix(ep, "http://")
|
||||
endpoint = ep
|
||||
pathStyle = true
|
||||
} else {
|
||||
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region)
|
||||
pathStyle = false
|
||||
}
|
||||
return &Config{
|
||||
Endpoint: endpoint,
|
||||
PathStyle: pathStyle,
|
||||
Bucket: bucket,
|
||||
Prefix: prefix,
|
||||
Region: region,
|
||||
AccessKey: accessKey,
|
||||
SecretKey: secretKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseError reads an S3 error response and returns an *ErrorResponse.
|
||||
func parseError(resp *http.Response) error {
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: reading error response: %w", err)
|
||||
}
|
||||
return parseErrorFromBytes(resp.StatusCode, body)
|
||||
}
|
||||
|
||||
func parseErrorFromBytes(statusCode int, body []byte) error {
|
||||
errResp := &ErrorResponse{
|
||||
StatusCode: statusCode,
|
||||
Body: string(body),
|
||||
}
|
||||
// Try to parse XML error; if it fails, we still have StatusCode and Body
|
||||
_ = xml.Unmarshal(body, errResp)
|
||||
return errResp
|
||||
}
|
||||
|
||||
// canonicalURI returns the URI-encoded path for the canonical request. Each path segment is
|
||||
// percent-encoded per RFC 3986; forward slashes are preserved.
|
||||
func canonicalURI(u *url.URL) string {
|
||||
p := u.Path
|
||||
if p == "" {
|
||||
return "/"
|
||||
}
|
||||
segments := strings.Split(p, "/")
|
||||
for i, seg := range segments {
|
||||
segments[i] = uriEncode(seg)
|
||||
}
|
||||
return strings.Join(segments, "/")
|
||||
}
|
||||
|
||||
// canonicalQueryString builds the query string for the canonical request. Keys and values
|
||||
// are URI-encoded per RFC 3986 (using %20, not +) and sorted lexically by key.
|
||||
func canonicalQueryString(values url.Values) string {
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
keys := make([]string, 0, len(values))
|
||||
for k := range values {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
var pairs []string
|
||||
for _, k := range keys {
|
||||
ek := uriEncode(k)
|
||||
vs := make([]string, len(values[k]))
|
||||
copy(vs, values[k])
|
||||
sort.Strings(vs)
|
||||
for _, v := range vs {
|
||||
pairs = append(pairs, ek+"="+uriEncode(v))
|
||||
}
|
||||
}
|
||||
return strings.Join(pairs, "&")
|
||||
}
|
||||
|
||||
// uriEncode percent-encodes a string per RFC 3986, encoding everything except unreserved
|
||||
// characters (A-Z a-z 0-9 - _ . ~).
|
||||
func uriEncode(s string) string {
|
||||
var buf strings.Builder
|
||||
for i := 0; i < len(s); i++ {
|
||||
b := s[i]
|
||||
if (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9') ||
|
||||
b == '-' || b == '_' || b == '.' || b == '~' {
|
||||
buf.WriteByte(b)
|
||||
} else {
|
||||
fmt.Fprintf(&buf, "%%%02X", b)
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func sha256Hex(data []byte) string {
|
||||
h := sha256.Sum256(data)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
func hmacSHA256(key, data []byte) []byte {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(data)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
181
s3/util_test.go
Normal file
181
s3/util_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestURIEncode(t *testing.T) {
|
||||
// Unreserved characters are not encoded
|
||||
require.Equal(t, "abcdefghijklmnopqrstuvwxyz", uriEncode("abcdefghijklmnopqrstuvwxyz"))
|
||||
require.Equal(t, "ABCDEFGHIJKLMNOPQRSTUVWXYZ", uriEncode("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
|
||||
require.Equal(t, "0123456789", uriEncode("0123456789"))
|
||||
require.Equal(t, "-_.~", uriEncode("-_.~"))
|
||||
|
||||
// Spaces use %20, not +
|
||||
require.Equal(t, "hello%20world", uriEncode("hello world"))
|
||||
|
||||
// Slashes are encoded (canonicalURI handles slash splitting separately)
|
||||
require.Equal(t, "a%2Fb", uriEncode("a/b"))
|
||||
|
||||
// Special characters
|
||||
require.Equal(t, "%2B", uriEncode("+"))
|
||||
require.Equal(t, "%3D", uriEncode("="))
|
||||
require.Equal(t, "%26", uriEncode("&"))
|
||||
require.Equal(t, "%40", uriEncode("@"))
|
||||
require.Equal(t, "%23", uriEncode("#"))
|
||||
|
||||
// Mixed
|
||||
require.Equal(t, "test~file-name_1.txt", uriEncode("test~file-name_1.txt"))
|
||||
require.Equal(t, "key%20with%20spaces%2Fand%2Fslashes", uriEncode("key with spaces/and/slashes"))
|
||||
|
||||
// Empty string
|
||||
require.Equal(t, "", uriEncode(""))
|
||||
}
|
||||
|
||||
func TestCanonicalURI(t *testing.T) {
|
||||
// Simple path
|
||||
u, _ := url.Parse("https://example.com/bucket/key")
|
||||
require.Equal(t, "/bucket/key", canonicalURI(u))
|
||||
|
||||
// Root path
|
||||
u, _ = url.Parse("https://example.com/")
|
||||
require.Equal(t, "/", canonicalURI(u))
|
||||
|
||||
// Empty path
|
||||
u, _ = url.Parse("https://example.com")
|
||||
require.Equal(t, "/", canonicalURI(u))
|
||||
|
||||
// Path with special characters
|
||||
u, _ = url.Parse("https://example.com/bucket/key%20with%20spaces")
|
||||
require.Equal(t, "/bucket/key%20with%20spaces", canonicalURI(u))
|
||||
|
||||
// Nested path
|
||||
u, _ = url.Parse("https://example.com/bucket/a/b/c/file.txt")
|
||||
require.Equal(t, "/bucket/a/b/c/file.txt", canonicalURI(u))
|
||||
}
|
||||
|
||||
func TestCanonicalQueryString(t *testing.T) {
|
||||
// Multiple keys sorted alphabetically
|
||||
vals := url.Values{
|
||||
"prefix": {"test/"},
|
||||
"list-type": {"2"},
|
||||
}
|
||||
require.Equal(t, "list-type=2&prefix=test%2F", canonicalQueryString(vals))
|
||||
|
||||
// Empty values
|
||||
require.Equal(t, "", canonicalQueryString(url.Values{}))
|
||||
|
||||
// Single key
|
||||
require.Equal(t, "key=value", canonicalQueryString(url.Values{"key": {"value"}}))
|
||||
|
||||
// Key with multiple values (sorted)
|
||||
vals = url.Values{"key": {"b", "a"}}
|
||||
require.Equal(t, "key=a&key=b", canonicalQueryString(vals))
|
||||
|
||||
// Keys requiring encoding
|
||||
vals = url.Values{"continuation-token": {"abc+def"}}
|
||||
require.Equal(t, "continuation-token=abc%2Bdef", canonicalQueryString(vals))
|
||||
}
|
||||
|
||||
func TestSHA256Hex(t *testing.T) {
|
||||
// SHA-256 of empty string
|
||||
require.Equal(t, emptyPayloadHash, sha256Hex([]byte("")))
|
||||
|
||||
// SHA-256 of known value
|
||||
require.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", sha256Hex([]byte("hello")))
|
||||
}
|
||||
|
||||
func TestHmacSHA256(t *testing.T) {
|
||||
// Known test vector: HMAC-SHA256("key", "message")
|
||||
result := hmacSHA256([]byte("key"), []byte("message"))
|
||||
require.Len(t, result, 32) // SHA-256 produces 32 bytes
|
||||
require.NotEqual(t, make([]byte, 32), result)
|
||||
|
||||
// Same inputs should produce same output
|
||||
result2 := hmacSHA256([]byte("key"), []byte("message"))
|
||||
require.Equal(t, result, result2)
|
||||
|
||||
// Different inputs should produce different output
|
||||
result3 := hmacSHA256([]byte("different-key"), []byte("message"))
|
||||
require.NotEqual(t, result, result3)
|
||||
}
|
||||
|
||||
func TestSignV4_SetsRequiredHeaders(t *testing.T) {
|
||||
c := &Client{
|
||||
AccessKey: "AKID",
|
||||
SecretKey: "SECRET",
|
||||
Region: "us-east-1",
|
||||
Endpoint: "s3.us-east-1.amazonaws.com",
|
||||
Bucket: "my-bucket",
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
|
||||
c.signV4(req, emptyPayloadHash)
|
||||
|
||||
// All required SigV4 headers must be set
|
||||
require.NotEmpty(t, req.Header.Get("Host"))
|
||||
require.NotEmpty(t, req.Header.Get("X-Amz-Date"))
|
||||
require.Equal(t, emptyPayloadHash, req.Header.Get("X-Amz-Content-Sha256"))
|
||||
|
||||
// Authorization header must have correct format
|
||||
auth := req.Header.Get("Authorization")
|
||||
require.Contains(t, auth, "AWS4-HMAC-SHA256")
|
||||
require.Contains(t, auth, "Credential=AKID/")
|
||||
require.Contains(t, auth, "/us-east-1/s3/aws4_request")
|
||||
require.Contains(t, auth, "SignedHeaders=")
|
||||
require.Contains(t, auth, "Signature=")
|
||||
}
|
||||
|
||||
func TestSignV4_UnsignedPayload(t *testing.T) {
|
||||
c := &Client{
|
||||
AccessKey: "AKID",
|
||||
SecretKey: "SECRET",
|
||||
Region: "us-east-1",
|
||||
Endpoint: "s3.us-east-1.amazonaws.com",
|
||||
Bucket: "my-bucket",
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPut, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
|
||||
c.signV4(req, unsignedPayload)
|
||||
|
||||
require.Equal(t, unsignedPayload, req.Header.Get("X-Amz-Content-Sha256"))
|
||||
}
|
||||
|
||||
func TestSignV4_DifferentRegions(t *testing.T) {
|
||||
c1 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"}
|
||||
c2 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", Bucket: "b"}
|
||||
|
||||
req1, _ := http.NewRequest(http.MethodGet, "https://b.s3.us-east-1.amazonaws.com/key", nil)
|
||||
c1.signV4(req1, emptyPayloadHash)
|
||||
|
||||
req2, _ := http.NewRequest(http.MethodGet, "https://b.s3.eu-west-1.amazonaws.com/key", nil)
|
||||
c2.signV4(req2, emptyPayloadHash)
|
||||
|
||||
// Different regions should produce different signatures
|
||||
require.NotEqual(t, req1.Header.Get("Authorization"), req2.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
func TestParseError_XMLResponse(t *testing.T) {
|
||||
xmlBody := []byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`)
|
||||
err := parseErrorFromBytes(404, xmlBody)
|
||||
|
||||
var errResp *ErrorResponse
|
||||
require.ErrorAs(t, err, &errResp)
|
||||
require.Equal(t, 404, errResp.StatusCode)
|
||||
require.Equal(t, "NoSuchKey", errResp.Code)
|
||||
require.Equal(t, "The specified key does not exist.", errResp.Message)
|
||||
}
|
||||
|
||||
func TestParseError_NonXMLResponse(t *testing.T) {
|
||||
err := parseErrorFromBytes(500, []byte("internal server error"))
|
||||
|
||||
var errResp *ErrorResponse
|
||||
require.ErrorAs(t, err, &errResp)
|
||||
require.Equal(t, 500, errResp.StatusCode)
|
||||
require.Equal(t, "", errResp.Code) // XML parsing failed, no code
|
||||
require.Contains(t, errResp.Body, "internal server error")
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
@@ -39,7 +40,7 @@ type actionParser struct {
|
||||
// parseActions parses the actions string as described in https://ntfy.sh/docs/publish/#action-buttons.
|
||||
// It supports both a JSON representation (if the string begins with "[", see parseActionsFromJSON),
|
||||
// and the "simple" format, which is more human-readable, but harder to parse (see parseActionsFromSimple).
|
||||
func parseActions(s string) (actions []*action, err error) {
|
||||
func parseActions(s string) (actions []*model.Action, err error) {
|
||||
// Parse JSON or simple format
|
||||
s = strings.TrimSpace(s)
|
||||
if strings.HasPrefix(s, "[") {
|
||||
@@ -80,8 +81,8 @@ func parseActions(s string) (actions []*action, err error) {
|
||||
}
|
||||
|
||||
// parseActionsFromJSON converts a JSON array into an array of actions
|
||||
func parseActionsFromJSON(s string) ([]*action, error) {
|
||||
actions := make([]*action, 0)
|
||||
func parseActionsFromJSON(s string) ([]*model.Action, error) {
|
||||
actions := make([]*model.Action, 0)
|
||||
if err := json.Unmarshal([]byte(s), &actions); err != nil {
|
||||
return nil, fmt.Errorf("JSON error: %w", err)
|
||||
}
|
||||
@@ -107,7 +108,7 @@ func parseActionsFromJSON(s string) ([]*action, error) {
|
||||
// https://github.com/adampresley/sample-ini-parser/blob/master/services/lexer/lexer/Lexer.go
|
||||
// https://github.com/benbjohnson/sql-parser/blob/master/scanner.go
|
||||
// https://blog.gopheracademy.com/advent-2014/parsers-lexers/
|
||||
func parseActionsFromSimple(s string) ([]*action, error) {
|
||||
func parseActionsFromSimple(s string) ([]*model.Action, error) {
|
||||
if !utf8.ValidString(s) {
|
||||
return nil, errors.New("invalid utf-8 string")
|
||||
}
|
||||
@@ -119,8 +120,8 @@ func parseActionsFromSimple(s string) ([]*action, error) {
|
||||
}
|
||||
|
||||
// Parse loops trough parseAction() until the end of the string is reached
|
||||
func (p *actionParser) Parse() ([]*action, error) {
|
||||
actions := make([]*action, 0)
|
||||
func (p *actionParser) Parse() ([]*model.Action, error) {
|
||||
actions := make([]*model.Action, 0)
|
||||
for !p.eof() {
|
||||
a, err := p.parseAction()
|
||||
if err != nil {
|
||||
@@ -134,8 +135,8 @@ func (p *actionParser) Parse() ([]*action, error) {
|
||||
// parseAction parses the individual sections of an action using parseSection into key/value pairs,
|
||||
// and then uses populateAction to interpret the keys/values. The function terminates
|
||||
// when EOF or ";" is reached.
|
||||
func (p *actionParser) parseAction() (*action, error) {
|
||||
a := newAction()
|
||||
func (p *actionParser) parseAction() (*model.Action, error) {
|
||||
a := model.NewAction()
|
||||
section := 0
|
||||
for {
|
||||
key, value, last, err := p.parseSection()
|
||||
@@ -155,7 +156,7 @@ func (p *actionParser) parseAction() (*action, error) {
|
||||
|
||||
// populateAction is the "business logic" of the parser. It applies the key/value
|
||||
// pair to the action instance.
|
||||
func populateAction(newAction *action, section int, key, value string) error {
|
||||
func populateAction(newAction *model.Action, section int, key, value string) error {
|
||||
// Auto-expand keys based on their index
|
||||
if key == "" && section == 0 {
|
||||
key = "action"
|
||||
|
||||
@@ -95,6 +95,8 @@ type Config struct {
|
||||
ListenUnixMode fs.FileMode
|
||||
KeyFile string
|
||||
CertFile string
|
||||
DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy")
|
||||
DatabaseReplicaURLs []string // PostgreSQL read replica connection strings
|
||||
FirebaseKeyFile string
|
||||
CacheFile string
|
||||
CacheDuration time.Duration
|
||||
@@ -110,6 +112,7 @@ type Config struct {
|
||||
AuthBcryptCost int
|
||||
AuthStatsQueueWriterInterval time.Duration
|
||||
AttachmentCacheDir string
|
||||
AttachmentS3URL string
|
||||
AttachmentTotalSizeLimit int64
|
||||
AttachmentFileSizeLimit int64
|
||||
AttachmentExpiryDuration time.Duration
|
||||
@@ -199,6 +202,7 @@ func NewConfig() *Config {
|
||||
ListenUnixMode: 0,
|
||||
KeyFile: "",
|
||||
CertFile: "",
|
||||
DatabaseURL: "",
|
||||
FirebaseKeyFile: "",
|
||||
CacheFile: "",
|
||||
CacheDuration: DefaultCacheDuration,
|
||||
|
||||
@@ -142,6 +142,7 @@ var (
|
||||
errHTTPBadRequestTemplateFileNotFound = &errHTTP{40047, http.StatusBadRequest, "invalid request: template file not found", "https://ntfy.sh/docs/publish/#message-templating", nil}
|
||||
errHTTPBadRequestTemplateFileInvalid = &errHTTP{40048, http.StatusBadRequest, "invalid request: template file invalid", "https://ntfy.sh/docs/publish/#message-templating", nil}
|
||||
errHTTPBadRequestSequenceIDInvalid = &errHTTP{40049, http.StatusBadRequest, "invalid request: sequence ID invalid", "https://ntfy.sh/docs/publish/#updating-deleting-notifications", nil}
|
||||
errHTTPBadRequestEmailAddressInvalid = &errHTTP{40050, http.StatusBadRequest, "invalid request: invalid e-mail address", "https://ntfy.sh/docs/publish/#e-mail-notifications", nil}
|
||||
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", "", nil}
|
||||
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication", nil}
|
||||
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication", nil}
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
oneKilobyteArray = make([]byte, 1024)
|
||||
)
|
||||
|
||||
func TestFileCache_Write_Success(t *testing.T) {
|
||||
dir, c := newTestFileCache(t)
|
||||
size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
|
||||
require.Equal(t, int64(11), c.Size())
|
||||
require.Equal(t, int64(10229), c.Remaining())
|
||||
}
|
||||
|
||||
func TestFileCache_Write_Remove_Success(t *testing.T) {
|
||||
dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024)
|
||||
for i := 0; i < 10; i++ { // 10x999 = 9990
|
||||
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(999), size)
|
||||
}
|
||||
require.Equal(t, int64(9990), c.Size())
|
||||
require.Equal(t, int64(250), c.Remaining())
|
||||
require.FileExists(t, dir+"/abcdefghijk1")
|
||||
require.FileExists(t, dir+"/abcdefghijk5")
|
||||
|
||||
require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5"))
|
||||
require.NoFileExists(t, dir+"/abcdefghijk1")
|
||||
require.NoFileExists(t, dir+"/abcdefghijk5")
|
||||
require.Equal(t, int64(7992), c.Size())
|
||||
require.Equal(t, int64(2248), c.Remaining())
|
||||
}
|
||||
|
||||
func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) {
|
||||
dir, c := newTestFileCache(t)
|
||||
for i := 0; i < 10; i++ {
|
||||
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(1024), size)
|
||||
}
|
||||
_, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
require.NoFileExists(t, dir+"/abcdefghijkX")
|
||||
}
|
||||
|
||||
func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
|
||||
dir, c := newTestFileCache(t)
|
||||
_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
require.NoFileExists(t, dir+"/abcdefghijkl")
|
||||
}
|
||||
|
||||
func newTestFileCache(t *testing.T) (dir string, cache *fileCache) {
|
||||
dir = t.TempDir()
|
||||
cache, err := newFileCache(dir, 10*1024)
|
||||
require.Nil(t, err)
|
||||
return dir, cache
|
||||
}
|
||||
|
||||
func readFile(t *testing.T, f string) string {
|
||||
b, err := os.ReadFile(f)
|
||||
require.Nil(t, err)
|
||||
return string(b)
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/gorilla/websocket"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
@@ -23,7 +24,6 @@ const (
|
||||
tagSMTP = "smtp" // Receive email
|
||||
tagEmail = "email" // Send email
|
||||
tagTwilio = "twilio"
|
||||
tagFileCache = "file_cache"
|
||||
tagMessageCache = "message_cache"
|
||||
tagStripe = "stripe"
|
||||
tagAccount = "account"
|
||||
@@ -35,7 +35,7 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
normalErrorCodes = []int{http.StatusNotFound, http.StatusBadRequest, http.StatusTooManyRequests, http.StatusUnauthorized, http.StatusForbidden, http.StatusInsufficientStorage}
|
||||
normalErrorCodes = []int{http.StatusNotFound, http.StatusBadRequest, http.StatusTooManyRequests, http.StatusUnauthorized, http.StatusForbidden, http.StatusInsufficientStorage, http.StatusRequestEntityTooLarge}
|
||||
rateLimitingErrorCodes = []int{http.StatusTooManyRequests, http.StatusRequestEntityTooLarge}
|
||||
)
|
||||
|
||||
@@ -55,12 +55,12 @@ func logvr(v *visitor, r *http.Request) *log.Event {
|
||||
}
|
||||
|
||||
// logvrm creates a new log event with HTTP request, visitor fields and message fields
|
||||
func logvrm(v *visitor, r *http.Request, m *message) *log.Event {
|
||||
func logvrm(v *visitor, r *http.Request, m *model.Message) *log.Event {
|
||||
return logvr(v, r).With(m)
|
||||
}
|
||||
|
||||
// logvrm creates a new log event with visitor fields and message fields
|
||||
func logvm(v *visitor, m *message) *log.Event {
|
||||
func logvm(v *visitor, m *model.Message) *log.Event {
|
||||
return logv(v).With(m)
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,825 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSqliteCache_Messages(t *testing.T) {
|
||||
testCacheMessages(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_Messages(t *testing.T) {
|
||||
testCacheMessages(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessages(t *testing.T, c *messageCache) {
|
||||
m1 := newDefaultMessage("mytopic", "my message")
|
||||
m1.Time = 1
|
||||
|
||||
m2 := newDefaultMessage("mytopic", "my other message")
|
||||
m2.Time = 2
|
||||
|
||||
require.Nil(t, c.AddMessage(m1))
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("example", "my example message")))
|
||||
require.Nil(t, c.AddMessage(m2))
|
||||
|
||||
// Adding invalid
|
||||
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newKeepaliveMessage("mytopic"))) // These should not be added!
|
||||
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added!
|
||||
|
||||
// mytopic: count
|
||||
counts, err := c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, counts["mytopic"])
|
||||
|
||||
// mytopic: since all
|
||||
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, "my message", messages[0].Message)
|
||||
require.Equal(t, "mytopic", messages[0].Topic)
|
||||
require.Equal(t, messageEvent, messages[0].Event)
|
||||
require.Equal(t, "", messages[0].Title)
|
||||
require.Equal(t, 0, messages[0].Priority)
|
||||
require.Nil(t, messages[0].Tags)
|
||||
require.Equal(t, "my other message", messages[1].Message)
|
||||
|
||||
// mytopic: since none
|
||||
messages, _ = c.Messages("mytopic", sinceNoMessages, false)
|
||||
require.Empty(t, messages)
|
||||
|
||||
// mytopic: since m1 (by ID)
|
||||
messages, _ = c.Messages("mytopic", newSinceID(m1.ID), false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, m2.ID, messages[0].ID)
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
require.Equal(t, "mytopic", messages[0].Topic)
|
||||
|
||||
// mytopic: since 2
|
||||
messages, _ = c.Messages("mytopic", newSinceTime(2), false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
|
||||
// mytopic: latest
|
||||
messages, _ = c.Messages("mytopic", sinceLatestMessage, false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
|
||||
// example: count
|
||||
counts, err = c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, counts["example"])
|
||||
|
||||
// example: since all
|
||||
messages, _ = c.Messages("example", sinceAllMessages, false)
|
||||
require.Equal(t, "my example message", messages[0].Message)
|
||||
|
||||
// non-existing: count
|
||||
counts, err = c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, counts["doesnotexist"])
|
||||
|
||||
// non-existing: since all
|
||||
messages, _ = c.Messages("doesnotexist", sinceAllMessages, false)
|
||||
require.Empty(t, messages)
|
||||
}
|
||||
|
||||
func TestSqliteCache_MessagesLock(t *testing.T) {
|
||||
testCacheMessagesLock(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_MessagesLock(t *testing.T) {
|
||||
testCacheMessagesLock(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessagesLock(t *testing.T, c *messageCache) {
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5000; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "test message")))
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSqliteCache_MessagesScheduled(t *testing.T) {
|
||||
testCacheMessagesScheduled(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_MessagesScheduled(t *testing.T) {
|
||||
testCacheMessagesScheduled(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessagesScheduled(t *testing.T, c *messageCache) {
|
||||
m1 := newDefaultMessage("mytopic", "message 1")
|
||||
m2 := newDefaultMessage("mytopic", "message 2")
|
||||
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||
m3 := newDefaultMessage("mytopic", "message 3")
|
||||
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
|
||||
m4 := newDefaultMessage("mytopic2", "message 4")
|
||||
m4.Time = time.Now().Add(time.Minute).Unix()
|
||||
require.Nil(t, c.AddMessage(m1))
|
||||
require.Nil(t, c.AddMessage(m2))
|
||||
require.Nil(t, c.AddMessage(m3))
|
||||
|
||||
messages, _ := c.Messages("mytopic", sinceAllMessages, false) // exclude scheduled
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
|
||||
messages, _ = c.Messages("mytopic", sinceAllMessages, true) // include scheduled
|
||||
require.Equal(t, 3, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
require.Equal(t, "message 3", messages[1].Message) // Order!
|
||||
require.Equal(t, "message 2", messages[2].Message)
|
||||
|
||||
messages, _ = c.MessagesDue()
|
||||
require.Empty(t, messages)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Topics(t *testing.T) {
|
||||
testCacheTopics(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_Topics(t *testing.T) {
|
||||
testCacheTopics(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheTopics(t *testing.T, c *messageCache) {
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message")))
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1")))
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2")))
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 3")))
|
||||
|
||||
topics, err := c.Topics()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.Equal(t, 2, len(topics))
|
||||
require.Equal(t, "topic1", topics["topic1"].ID)
|
||||
require.Equal(t, "topic2", topics["topic2"].ID)
|
||||
}
|
||||
|
||||
func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
testCacheMessagesTagsPrioAndTitle(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c *messageCache) {
|
||||
m := newDefaultMessage("mytopic", "some message")
|
||||
m.Tags = []string{"tag1", "tag2"}
|
||||
m.Priority = 5
|
||||
m.Title = "some title"
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
|
||||
require.Equal(t, 5, messages[0].Priority)
|
||||
require.Equal(t, "some title", messages[0].Title)
|
||||
}
|
||||
|
||||
func TestSqliteCache_MessagesSinceID(t *testing.T) {
|
||||
testCacheMessagesSinceID(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_MessagesSinceID(t *testing.T) {
|
||||
testCacheMessagesSinceID(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessagesSinceID(t *testing.T, c *messageCache) {
|
||||
m1 := newDefaultMessage("mytopic", "message 1")
|
||||
m1.Time = 100
|
||||
m2 := newDefaultMessage("mytopic", "message 2")
|
||||
m2.Time = 200
|
||||
m3 := newDefaultMessage("mytopic", "message 3")
|
||||
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
|
||||
m4 := newDefaultMessage("mytopic", "message 4")
|
||||
m4.Time = 400
|
||||
m5 := newDefaultMessage("mytopic", "message 5")
|
||||
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
|
||||
m6 := newDefaultMessage("mytopic", "message 6")
|
||||
m6.Time = 600
|
||||
m7 := newDefaultMessage("mytopic", "message 7")
|
||||
m7.Time = 700
|
||||
|
||||
require.Nil(t, c.AddMessage(m1))
|
||||
require.Nil(t, c.AddMessage(m2))
|
||||
require.Nil(t, c.AddMessage(m3))
|
||||
require.Nil(t, c.AddMessage(m4))
|
||||
require.Nil(t, c.AddMessage(m5))
|
||||
require.Nil(t, c.AddMessage(m6))
|
||||
require.Nil(t, c.AddMessage(m7))
|
||||
|
||||
// Case 1: Since ID exists, exclude scheduled
|
||||
messages, _ := c.Messages("mytopic", newSinceID(m2.ID), false)
|
||||
require.Equal(t, 3, len(messages))
|
||||
require.Equal(t, "message 4", messages[0].Message)
|
||||
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
|
||||
require.Equal(t, "message 7", messages[2].Message)
|
||||
|
||||
// Case 2: Since ID exists, include scheduled
|
||||
messages, _ = c.Messages("mytopic", newSinceID(m2.ID), true)
|
||||
require.Equal(t, 5, len(messages))
|
||||
require.Equal(t, "message 4", messages[0].Message)
|
||||
require.Equal(t, "message 6", messages[1].Message)
|
||||
require.Equal(t, "message 7", messages[2].Message)
|
||||
require.Equal(t, "message 5", messages[3].Message) // Order!
|
||||
require.Equal(t, "message 3", messages[4].Message) // Order!
|
||||
|
||||
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
|
||||
messages, _ = c.Messages("mytopic", newSinceID("doesntexist"), true)
|
||||
require.Equal(t, 7, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
require.Equal(t, "message 2", messages[1].Message)
|
||||
require.Equal(t, "message 4", messages[2].Message)
|
||||
require.Equal(t, "message 6", messages[3].Message)
|
||||
require.Equal(t, "message 7", messages[4].Message)
|
||||
require.Equal(t, "message 5", messages[5].Message) // Order!
|
||||
require.Equal(t, "message 3", messages[6].Message) // Order!
|
||||
|
||||
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
|
||||
messages, _ = c.Messages("mytopic", newSinceID(m7.ID), false)
|
||||
require.Equal(t, 0, len(messages))
|
||||
|
||||
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
|
||||
messages, _ = c.Messages("mytopic", newSinceID(m7.ID), true)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, "message 5", messages[0].Message)
|
||||
require.Equal(t, "message 3", messages[1].Message)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Prune(t *testing.T) {
|
||||
testCachePrune(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_Prune(t *testing.T) {
|
||||
testCachePrune(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCachePrune(t *testing.T, c *messageCache) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
m1 := newDefaultMessage("mytopic", "my message")
|
||||
m1.Time = now - 10
|
||||
m1.Expires = now - 5
|
||||
|
||||
m2 := newDefaultMessage("mytopic", "my other message")
|
||||
m2.Time = now - 5
|
||||
m2.Expires = now + 5 // In the future
|
||||
|
||||
m3 := newDefaultMessage("another_topic", "and another one")
|
||||
m3.Time = now - 12
|
||||
m3.Expires = now - 2
|
||||
|
||||
require.Nil(t, c.AddMessage(m1))
|
||||
require.Nil(t, c.AddMessage(m2))
|
||||
require.Nil(t, c.AddMessage(m3))
|
||||
|
||||
counts, err := c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, counts["mytopic"])
|
||||
require.Equal(t, 1, counts["another_topic"])
|
||||
|
||||
expiredMessageIDs, err := c.MessagesExpired()
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, c.DeleteMessages(expiredMessageIDs...))
|
||||
|
||||
counts, err = c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, counts["mytopic"])
|
||||
require.Equal(t, 0, counts["another_topic"])
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Attachments(t *testing.T) {
|
||||
testCacheAttachments(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_Attachments(t *testing.T) {
|
||||
testCacheAttachments(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheAttachments(t *testing.T, c *messageCache) {
|
||||
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
|
||||
m := newDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
m.SequenceID = "m1"
|
||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
m.Attachment = &attachment{
|
||||
Name: "flower.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 5000,
|
||||
Expires: expires1,
|
||||
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
|
||||
m = newDefaultMessage("mytopic", "sending you a car")
|
||||
m.ID = "m2"
|
||||
m.SequenceID = "m2"
|
||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
m.Attachment = &attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 10000,
|
||||
Expires: expires2,
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
|
||||
m = newDefaultMessage("another-topic", "sending you another car")
|
||||
m.ID = "m3"
|
||||
m.SequenceID = "m3"
|
||||
m.User = "u_BAsbaAa"
|
||||
m.Sender = netip.MustParseAddr("5.6.7.8")
|
||||
m.Attachment = &attachment{
|
||||
Name: "another-car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 20000,
|
||||
Expires: expires3,
|
||||
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
|
||||
require.Equal(t, "flower for you", messages[0].Message)
|
||||
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
|
||||
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
|
||||
require.Equal(t, int64(5000), messages[0].Attachment.Size)
|
||||
require.Equal(t, expires1, messages[0].Attachment.Expires)
|
||||
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
|
||||
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
|
||||
|
||||
require.Equal(t, "sending you a car", messages[1].Message)
|
||||
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
|
||||
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
|
||||
require.Equal(t, int64(10000), messages[1].Attachment.Size)
|
||||
require.Equal(t, expires2, messages[1].Attachment.Expires)
|
||||
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
|
||||
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
|
||||
|
||||
size, err := c.AttachmentBytesUsedBySender("1.2.3.4")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(10000), size)
|
||||
|
||||
size, err = c.AttachmentBytesUsedBySender("5.6.7.8")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
|
||||
|
||||
size, err = c.AttachmentBytesUsedByUser("u_BAsbaAa")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(20000), size)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Attachments_Expired(t *testing.T) {
|
||||
testCacheAttachmentsExpired(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_Attachments_Expired(t *testing.T) {
|
||||
testCacheAttachmentsExpired(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheAttachmentsExpired(t *testing.T, c *messageCache) {
|
||||
m := newDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
m.SequenceID = "m1"
|
||||
m.Expires = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
m = newDefaultMessage("mytopic", "message with attachment")
|
||||
m.ID = "m2"
|
||||
m.SequenceID = "m2"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 10000,
|
||||
Expires: time.Now().Add(2 * time.Hour).Unix(),
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
m = newDefaultMessage("mytopic", "message with external attachment")
|
||||
m.ID = "m3"
|
||||
m.SequenceID = "m3"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Expires: 0, // Unknown!
|
||||
URL: "https://somedomain.com/car.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
m = newDefaultMessage("mytopic2", "message with expired attachment")
|
||||
m.ID = "m4"
|
||||
m.SequenceID = "m4"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &attachment{
|
||||
Name: "expired-car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 20000,
|
||||
Expires: time.Now().Add(-1 * time.Hour).Unix(),
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
ids, err := c.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(ids))
|
||||
require.Equal(t, "m4", ids[0])
|
||||
}
|
||||
|
||||
func TestSqliteCache_Migration_From0(t *testing.T) {
|
||||
filename := newSqliteTestCacheFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 0" schema
|
||||
_, err = db.Exec(`
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id VARCHAR(20) PRIMARY KEY,
|
||||
time INT NOT NULL,
|
||||
topic VARCHAR(64) NOT NULL,
|
||||
message VARCHAR(1024) NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
COMMIT;
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`,
|
||||
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Nil(t, db.Close())
|
||||
|
||||
// Create cache to trigger migration
|
||||
c := newSqliteTestCacheFromFile(t, filename, "")
|
||||
checkSchemaVersion(t, c.db)
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
require.Equal(t, "some message 5", messages[5].Message)
|
||||
require.Equal(t, "", messages[5].Title)
|
||||
require.Nil(t, messages[5].Tags)
|
||||
require.Equal(t, 0, messages[5].Priority)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Migration_From1(t *testing.T) {
|
||||
filename := newSqliteTestCacheFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 1" schema
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id VARCHAR(20) PRIMARY KEY,
|
||||
time INT NOT NULL,
|
||||
topic VARCHAR(64) NOT NULL,
|
||||
message VARCHAR(512) NOT NULL,
|
||||
title VARCHAR(256) NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags VARCHAR(256) NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO schemaVersion (id, version) VALUES (1, 1);
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "")
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Nil(t, db.Close())
|
||||
|
||||
// Create cache to trigger migration
|
||||
c := newSqliteTestCacheFromFile(t, filename, "")
|
||||
checkSchemaVersion(t, c.db)
|
||||
|
||||
// Add delayed message
|
||||
delayedMessage := newDefaultMessage("mytopic", "some delayed message")
|
||||
delayedMessage.Time = time.Now().Add(time.Minute).Unix()
|
||||
require.Nil(t, c.AddMessage(delayedMessage))
|
||||
|
||||
// 10, not 11!
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
|
||||
// 11!
|
||||
messages, err = c.Messages("mytopic", sinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 11, len(messages))
|
||||
|
||||
// Check that index "idx_topic" exists
|
||||
rows, err := c.db.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var indexName string
|
||||
require.Nil(t, rows.Scan(&indexName))
|
||||
require.Equal(t, "idx_topic", indexName)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Migration_From9(t *testing.T) {
|
||||
// This primarily tests the awkward migration that introduces the "expires" column.
|
||||
// The migration logic has to update the column, using the existing "cache-duration" value.
|
||||
|
||||
filename := newSqliteTestCacheFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 8" schema
|
||||
_, err = db.Exec(`
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO schemaVersion (id, version) VALUES (1, 9);
|
||||
COMMIT;
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
insertQuery := `
|
||||
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(
|
||||
insertQuery,
|
||||
fmt.Sprintf("abcd%d", i),
|
||||
time.Now().Unix(),
|
||||
"mytopic",
|
||||
fmt.Sprintf("some message %d", i),
|
||||
"", // title
|
||||
0, // priority
|
||||
"", // tags
|
||||
"", // click
|
||||
"", // icon
|
||||
"", // actions
|
||||
"", // attachment_name
|
||||
"", // attachment_type
|
||||
0, // attachment_size
|
||||
0, // attachment_type
|
||||
"", // attachment_url
|
||||
"9.9.9.9", // sender
|
||||
"", // encoding
|
||||
1, // published
|
||||
)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
// Create cache to trigger migration
|
||||
cacheDuration := 17 * time.Hour
|
||||
c, err := newSqliteCache(filename, "", cacheDuration, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
checkSchemaVersion(t, c.db)
|
||||
|
||||
// Check version
|
||||
rows, err := db.Query(`SELECT version FROM main.schemaVersion WHERE id = 1`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var version int
|
||||
require.Nil(t, rows.Scan(&version))
|
||||
require.Equal(t, currentSchemaVersion, version)
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
for _, m := range messages {
|
||||
require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix())
|
||||
require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqliteCache_StartupQueries_WAL(t *testing.T) {
|
||||
filename := newSqliteTestCacheFile(t)
|
||||
startupQueries := `pragma journal_mode = WAL;
|
||||
pragma synchronous = normal;
|
||||
pragma temp_store = memory;`
|
||||
db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message")))
|
||||
require.FileExists(t, filename)
|
||||
require.FileExists(t, filename+"-wal")
|
||||
require.FileExists(t, filename+"-shm")
|
||||
}
|
||||
|
||||
func TestSqliteCache_StartupQueries_None(t *testing.T) {
|
||||
filename := newSqliteTestCacheFile(t)
|
||||
startupQueries := ""
|
||||
db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message")))
|
||||
require.FileExists(t, filename)
|
||||
require.NoFileExists(t, filename+"-wal")
|
||||
require.NoFileExists(t, filename+"-shm")
|
||||
}
|
||||
|
||||
func TestSqliteCache_StartupQueries_Fail(t *testing.T) {
|
||||
filename := newSqliteTestCacheFile(t)
|
||||
startupQueries := `xx error`
|
||||
_, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Sender(t *testing.T) {
|
||||
testSender(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_Sender(t *testing.T) {
|
||||
testSender(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testSender(t *testing.T, c *messageCache) {
|
||||
m1 := newDefaultMessage("mytopic", "mymessage")
|
||||
m1.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
require.Nil(t, c.AddMessage(m1))
|
||||
|
||||
m2 := newDefaultMessage("mytopic", "mymessage without sender")
|
||||
require.Nil(t, c.AddMessage(m2))
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
|
||||
require.Equal(t, messages[1].Sender, netip.Addr{})
|
||||
}
|
||||
|
||||
func TestSqliteCache_DeleteScheduledBySequenceID(t *testing.T) {
|
||||
testDeleteScheduledBySequenceID(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_DeleteScheduledBySequenceID(t *testing.T) {
|
||||
testDeleteScheduledBySequenceID(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testDeleteScheduledBySequenceID(t *testing.T, c *messageCache) {
|
||||
// Create a scheduled (unpublished) message
|
||||
scheduledMsg := newDefaultMessage("mytopic", "scheduled message")
|
||||
scheduledMsg.ID = "scheduled1"
|
||||
scheduledMsg.SequenceID = "seq123"
|
||||
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
|
||||
require.Nil(t, c.AddMessage(scheduledMsg))
|
||||
|
||||
// Create a published message with different sequence ID
|
||||
publishedMsg := newDefaultMessage("mytopic", "published message")
|
||||
publishedMsg.ID = "published1"
|
||||
publishedMsg.SequenceID = "seq456"
|
||||
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
|
||||
require.Nil(t, c.AddMessage(publishedMsg))
|
||||
|
||||
// Create a scheduled message in a different topic
|
||||
otherTopicMsg := newDefaultMessage("othertopic", "other scheduled")
|
||||
otherTopicMsg.ID = "other1"
|
||||
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
|
||||
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, c.AddMessage(otherTopicMsg))
|
||||
|
||||
// Verify all messages exist (including scheduled)
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
|
||||
messages, err = c.Messages("othertopic", sinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
|
||||
// Delete scheduled message by sequence ID and verify returned IDs
|
||||
deletedIDs, err := c.DeleteScheduledBySequenceID("mytopic", "seq123")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(deletedIDs))
|
||||
require.Equal(t, "scheduled1", deletedIDs[0])
|
||||
|
||||
// Verify scheduled message is deleted
|
||||
messages, err = c.Messages("mytopic", sinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "published message", messages[0].Message)
|
||||
|
||||
// Verify other topic's message still exists (topic-scoped deletion)
|
||||
messages, err = c.Messages("othertopic", sinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "other scheduled", messages[0].Message)
|
||||
|
||||
// Deleting non-existent sequence ID should return empty list
|
||||
deletedIDs, err = c.DeleteScheduledBySequenceID("mytopic", "nonexistent")
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, deletedIDs)
|
||||
|
||||
// Deleting published message should not affect it (only deletes unpublished)
|
||||
deletedIDs, err = c.DeleteScheduledBySequenceID("mytopic", "seq456")
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, deletedIDs)
|
||||
|
||||
messages, err = c.Messages("mytopic", sinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "published message", messages[0].Message)
|
||||
}
|
||||
|
||||
func checkSchemaVersion(t *testing.T, db *sql.DB) {
|
||||
rows, err := db.Query(`SELECT version FROM schemaVersion`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
|
||||
var schemaVersion int
|
||||
require.Nil(t, rows.Scan(&schemaVersion))
|
||||
require.Equal(t, currentSchemaVersion, schemaVersion)
|
||||
require.Nil(t, rows.Close())
|
||||
}
|
||||
|
||||
func TestMemCache_NopCache(t *testing.T) {
|
||||
c, _ := newNopCache()
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message")))
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, messages)
|
||||
|
||||
topics, err := c.Topics()
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, topics)
|
||||
}
|
||||
|
||||
func newSqliteTestCache(t *testing.T) *messageCache {
|
||||
c, err := newSqliteCache(newSqliteTestCacheFile(t), "", time.Hour, 0, 0, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func newSqliteTestCacheFile(t *testing.T) string {
|
||||
return filepath.Join(t.TempDir(), "cache.db")
|
||||
}
|
||||
|
||||
func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) *messageCache {
|
||||
c, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
return c
|
||||
}
|
||||
|
||||
func newMemTestCache(t *testing.T) *messageCache {
|
||||
c, err := newMemCache()
|
||||
require.Nil(t, err)
|
||||
return c
|
||||
}
|
||||
239
server/server.go
239
server/server.go
@@ -32,16 +32,23 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gopkg.in/yaml.v2"
|
||||
"heckel.io/ntfy/v2/attachment"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/payments"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"heckel.io/ntfy/v2/util/sprig"
|
||||
"heckel.io/ntfy/v2/webpush"
|
||||
)
|
||||
|
||||
// Server is the main server, providing the UI and API for ntfy
|
||||
type Server struct {
|
||||
config *Config
|
||||
db *db.DB // Shared PostgreSQL connection pool (with optional replicas), nil when using SQLite
|
||||
httpServer *http.Server
|
||||
httpsServer *http.Server
|
||||
httpMetricsServer *http.Server
|
||||
@@ -56,9 +63,9 @@ type Server struct {
|
||||
messages int64 // Total number of messages (persisted if messageCache enabled)
|
||||
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
|
||||
userManager *user.Manager // Might be nil!
|
||||
messageCache *messageCache // Database that stores the messages
|
||||
webPush *webPushStore // Database that stores web push subscriptions
|
||||
fileCache *fileCache // File system based cache that stores attachments
|
||||
messageCache *message.Cache // Database that stores the messages
|
||||
webPush *webpush.Store // Database that stores web push subscriptions
|
||||
fileCache attachment.Store // Attachment store (file system or S3)
|
||||
stripe stripeAPI // Stripe API, can be replaced with a mock
|
||||
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
|
||||
metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set
|
||||
@@ -90,6 +97,7 @@ var (
|
||||
matrixPushPath = "/_matrix/push/v1/notify"
|
||||
metricsPath = "/metrics"
|
||||
apiHealthPath = "/v1/health"
|
||||
apiVersionPath = "/v1/version"
|
||||
apiConfigPath = "/v1/config"
|
||||
apiStatsPath = "/v1/stats"
|
||||
apiWebPushPath = "/v1/webpush"
|
||||
@@ -115,6 +123,7 @@ var (
|
||||
fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
|
||||
urlRegex = regexp.MustCompile(`^https?://`)
|
||||
phoneNumberRegex = regexp.MustCompile(`^\+\d{1,100}$`)
|
||||
emailAddressRegex = regexp.MustCompile(`^[^\s,;]+@[^\s,;]+$`)
|
||||
|
||||
//go:embed site
|
||||
webFs embed.FS
|
||||
@@ -171,36 +180,64 @@ func New(conf *Config) (*Server, error) {
|
||||
if payments.Available && conf.StripeSecretKey != "" {
|
||||
stripe = newStripeAPI()
|
||||
}
|
||||
messageCache, err := createMessageCache(conf)
|
||||
// Open shared PostgreSQL connection pool if configured
|
||||
var pool *db.DB
|
||||
if conf.DatabaseURL != "" {
|
||||
primary, err := pg.Open(conf.DatabaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var replicas []*db.Host
|
||||
for _, replicaURL := range conf.DatabaseReplicaURLs {
|
||||
r, err := pg.OpenReplica(replicaURL)
|
||||
if err != nil {
|
||||
// Close already-opened replicas before returning
|
||||
for _, opened := range replicas {
|
||||
opened.DB.Close()
|
||||
}
|
||||
primary.DB.Close()
|
||||
return nil, fmt.Errorf("failed to open database replica: %w", err)
|
||||
}
|
||||
replicas = append(replicas, r)
|
||||
}
|
||||
pool = db.New(primary, replicas)
|
||||
}
|
||||
messageCache, err := createMessageCache(conf, pool)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var webPush *webPushStore
|
||||
var wp *webpush.Store
|
||||
if conf.WebPushPublicKey != "" {
|
||||
webPush, err = newWebPushStore(conf.WebPushFile, conf.WebPushStartupQueries)
|
||||
if pool != nil {
|
||||
wp, err = webpush.NewPostgresStore(pool)
|
||||
} else {
|
||||
wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
topics, err := messageCache.Topics()
|
||||
topicIDs, err := messageCache.Topics()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topics := make(map[string]*topic, len(topicIDs))
|
||||
for _, id := range topicIDs {
|
||||
topics[id] = newTopic(id)
|
||||
}
|
||||
messages, err := messageCache.Stats()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var fileCache *fileCache
|
||||
if conf.AttachmentCacheDir != "" {
|
||||
fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fileCache, err := createAttachmentStore(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var userManager *user.Manager
|
||||
if conf.AuthFile != "" {
|
||||
if conf.AuthFile != "" || pool != nil {
|
||||
authConfig := &user.Config{
|
||||
Filename: conf.AuthFile,
|
||||
DatabaseURL: conf.DatabaseURL,
|
||||
StartupQueries: conf.AuthStartupQueries,
|
||||
DefaultAccess: conf.AuthDefault,
|
||||
ProvisionEnabled: true, // Enable provisioning of users and access
|
||||
@@ -210,7 +247,11 @@ func New(conf *Config) (*Server, error) {
|
||||
BcryptCost: conf.AuthBcryptCost,
|
||||
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
|
||||
}
|
||||
userManager, err = user.NewManager(authConfig)
|
||||
if pool != nil {
|
||||
userManager, err = user.NewPostgresManager(pool, authConfig)
|
||||
} else {
|
||||
userManager, err = user.NewSQLiteManager(conf.AuthFile, conf.AuthStartupQueries, authConfig)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -231,8 +272,9 @@ func New(conf *Config) (*Server, error) {
|
||||
}
|
||||
s := &Server{
|
||||
config: conf,
|
||||
db: pool,
|
||||
messageCache: messageCache,
|
||||
webPush: webPush,
|
||||
webPush: wp,
|
||||
fileCache: fileCache,
|
||||
firebaseClient: firebaseClient,
|
||||
smtpSender: mailer,
|
||||
@@ -247,13 +289,24 @@ func New(conf *Config) (*Server, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func createMessageCache(conf *Config) (*messageCache, error) {
|
||||
func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) {
|
||||
if conf.CacheDuration == 0 {
|
||||
return newNopCache()
|
||||
return message.NewNopStore()
|
||||
} else if pool != nil {
|
||||
return message.NewPostgresStore(pool, conf.CacheBatchSize, conf.CacheBatchTimeout)
|
||||
} else if conf.CacheFile != "" {
|
||||
return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
|
||||
return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
|
||||
}
|
||||
return newMemCache()
|
||||
return message.NewMemStore()
|
||||
}
|
||||
|
||||
func createAttachmentStore(conf *Config) (attachment.Store, error) {
|
||||
if conf.AttachmentS3URL != "" {
|
||||
return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit)
|
||||
} else if conf.AttachmentCacheDir != "" {
|
||||
return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
|
||||
@@ -388,6 +441,9 @@ func (s *Server) closeDatabases() {
|
||||
if s.webPush != nil {
|
||||
s.webPush.Close()
|
||||
}
|
||||
if s.db != nil {
|
||||
s.db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// handle is the main entry point for all HTTP requests
|
||||
@@ -467,6 +523,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
||||
return s.ensureWebEnabled(s.handleEmpty)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath {
|
||||
return s.handleHealth(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiVersionPath {
|
||||
return s.ensureAdmin(s.handleVersion)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiConfigPath {
|
||||
return s.handleConfig(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
|
||||
@@ -702,7 +760,7 @@ func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request, _ *visitor)
|
||||
// Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it
|
||||
// can associate the download bandwidth with the uploader.
|
||||
func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if s.config.AttachmentCacheDir == "" {
|
||||
if s.fileCache == nil {
|
||||
return errHTTPInternalError
|
||||
}
|
||||
matches := fileRegex.FindStringSubmatch(r.URL.Path)
|
||||
@@ -710,16 +768,16 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
||||
return errHTTPInternalErrorInvalidPath
|
||||
}
|
||||
messageID := matches[1]
|
||||
file := filepath.Join(s.config.AttachmentCacheDir, messageID)
|
||||
stat, err := os.Stat(file)
|
||||
reader, size, err := s.fileCache.Read(messageID)
|
||||
if err != nil {
|
||||
return errHTTPNotFound.Fields(log.Context{
|
||||
"message_id": messageID,
|
||||
"error_context": "filesystem",
|
||||
"error_context": "attachment_store",
|
||||
})
|
||||
}
|
||||
defer reader.Close()
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", size))
|
||||
if r.Method == http.MethodHead {
|
||||
return nil
|
||||
}
|
||||
@@ -728,11 +786,11 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
||||
// - avoid abuse (e.g. 1 uploader, 1k downloaders)
|
||||
// - and also uses the higher bandwidth limits of a paying user
|
||||
m, err := s.messageCache.Message(messageID)
|
||||
if errors.Is(err, errMessageNotFound) {
|
||||
if errors.Is(err, model.ErrMessageNotFound) {
|
||||
if s.config.CacheBatchTimeout > 0 {
|
||||
// Strange edge case: If we immediately after upload request the file (the web app does this for images),
|
||||
// and messages are persisted asynchronously, retry fetching from the database
|
||||
m, err = util.Retry(func() (*message, error) {
|
||||
m, err = util.Retry(func() (*model.Message, error) {
|
||||
return s.messageCache.Message(messageID)
|
||||
}, s.config.CacheBatchTimeout, 100*time.Millisecond, 300*time.Millisecond, 600*time.Millisecond)
|
||||
}
|
||||
@@ -755,19 +813,14 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
||||
} else if m.Sender.IsValid() {
|
||||
bandwidthVisitor = s.visitor(m.Sender, nil)
|
||||
}
|
||||
if !bandwidthVisitor.BandwidthAllowed(stat.Size()) {
|
||||
if !bandwidthVisitor.BandwidthAllowed(size) {
|
||||
return errHTTPTooManyRequestsLimitAttachmentBandwidth.With(m)
|
||||
}
|
||||
// Actually send file
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if m.Attachment.Name != "" {
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(m.Attachment.Name))
|
||||
}
|
||||
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
|
||||
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), reader)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -778,7 +831,7 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
|
||||
return writeMatrixDiscoveryResponse(w)
|
||||
}
|
||||
|
||||
func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, error) {
|
||||
func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*model.Message, error) {
|
||||
start := time.Now()
|
||||
t, err := fromContext[*topic](r, contextTopic)
|
||||
if err != nil {
|
||||
@@ -792,7 +845,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := newDefaultMessage(t.ID, "")
|
||||
m := model.NewDefaultMessage(t.ID, "")
|
||||
cache, firebase, email, call, template, unifiedpush, priorityStr, e := s.parsePublishParams(r, m)
|
||||
if e != nil {
|
||||
return nil, e.With(t)
|
||||
@@ -817,7 +870,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
|
||||
}
|
||||
}
|
||||
if m.PollID != "" {
|
||||
m = newPollRequestMessage(t.ID, m.PollID)
|
||||
m = model.NewPollRequestMessage(t.ID, m.PollID)
|
||||
}
|
||||
m.Sender = v.IP()
|
||||
m.User = v.MaybeUserID()
|
||||
@@ -830,6 +883,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
|
||||
if m.Message == "" {
|
||||
m.Message = emptyMessageBody
|
||||
}
|
||||
m.SanitizeUTF8()
|
||||
delayed := m.Time > time.Now().Unix()
|
||||
ev := logvrm(v, r, m).
|
||||
Tag(tagPublish).
|
||||
@@ -906,7 +960,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
||||
return err
|
||||
}
|
||||
minc(metricMessagesPublishedSuccess)
|
||||
return s.writeJSON(w, m.forJSON())
|
||||
return s.writeJSON(w, m.ForJSON())
|
||||
}
|
||||
|
||||
func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
@@ -935,11 +989,11 @@ func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *
|
||||
}
|
||||
|
||||
func (s *Server) handleDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
return s.handleActionMessage(w, r, v, messageDeleteEvent)
|
||||
return s.handleActionMessage(w, r, v, model.MessageDeleteEvent)
|
||||
}
|
||||
|
||||
func (s *Server) handleClear(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
return s.handleActionMessage(w, r, v, messageClearEvent)
|
||||
return s.handleActionMessage(w, r, v, model.MessageClearEvent)
|
||||
}
|
||||
|
||||
func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *visitor, event string) error {
|
||||
@@ -959,7 +1013,7 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *
|
||||
return e.With(t)
|
||||
}
|
||||
// Create an action message with the given event type
|
||||
m := newActionMessage(event, t.ID, sequenceID)
|
||||
m := model.NewActionMessage(event, t.ID, sequenceID)
|
||||
m.Sender = v.IP()
|
||||
m.User = v.MaybeUserID()
|
||||
m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
|
||||
@@ -975,7 +1029,7 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *
|
||||
if s.config.WebPushPublicKey != "" {
|
||||
go s.publishToWebPushEndpoints(v, m)
|
||||
}
|
||||
if event == messageDeleteEvent {
|
||||
if event == model.MessageDeleteEvent {
|
||||
// Delete any existing scheduled message with the same sequence ID
|
||||
deletedIDs, err := s.messageCache.DeleteScheduledBySequenceID(t.ID, sequenceID)
|
||||
if err != nil {
|
||||
@@ -996,10 +1050,10 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *
|
||||
s.mu.Lock()
|
||||
s.messages++
|
||||
s.mu.Unlock()
|
||||
return s.writeJSON(w, m.forJSON())
|
||||
return s.writeJSON(w, m.ForJSON())
|
||||
}
|
||||
|
||||
func (s *Server) sendToFirebase(v *visitor, m *message) {
|
||||
func (s *Server) sendToFirebase(v *visitor, m *model.Message) {
|
||||
logvm(v, m).Tag(tagFirebase).Debug("Publishing to Firebase")
|
||||
if err := s.firebaseClient.Send(v, m); err != nil {
|
||||
minc(metricFirebasePublishedFailure)
|
||||
@@ -1013,7 +1067,7 @@ func (s *Server) sendToFirebase(v *visitor, m *message) {
|
||||
minc(metricFirebasePublishedSuccess)
|
||||
}
|
||||
|
||||
func (s *Server) sendEmail(v *visitor, m *message, email string) {
|
||||
func (s *Server) sendEmail(v *visitor, m *model.Message, email string) {
|
||||
logvm(v, m).Tag(tagEmail).Field("email", email).Debug("Sending email to %s", email)
|
||||
if err := s.smtpSender.Send(v, m, email); err != nil {
|
||||
logvm(v, m).Tag(tagEmail).Field("email", email).Err(err).Warn("Unable to send email to %s: %v", email, err.Error())
|
||||
@@ -1023,7 +1077,7 @@ func (s *Server) sendEmail(v *visitor, m *message, email string) {
|
||||
minc(metricEmailsPublishedSuccess)
|
||||
}
|
||||
|
||||
func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
||||
func (s *Server) forwardPollRequest(v *visitor, m *model.Message) {
|
||||
topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
|
||||
topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL)))
|
||||
forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash)
|
||||
@@ -1055,7 +1109,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email, call string, template templateMode, unifiedpush bool, priorityStr string, err *errHTTP) {
|
||||
func (s *Server) parsePublishParams(r *http.Request, m *model.Message) (cache bool, firebase bool, email, call string, template templateMode, unifiedpush bool, priorityStr string, err *errHTTP) {
|
||||
if r.Method != http.MethodGet && updatePathRegex.MatchString(r.URL.Path) {
|
||||
pathSequenceID, err := s.sequenceIDFromPath(r.URL.Path)
|
||||
if err != nil {
|
||||
@@ -1082,7 +1136,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
|
||||
filename := readParam(r, "x-filename", "filename", "file", "f")
|
||||
attach := readParam(r, "x-attach", "attach", "a")
|
||||
if attach != "" || filename != "" {
|
||||
m.Attachment = &attachment{}
|
||||
m.Attachment = &model.Attachment{}
|
||||
}
|
||||
if filename != "" {
|
||||
m.Attachment.Name = filename
|
||||
@@ -1112,6 +1166,9 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
|
||||
m.Icon = icon
|
||||
}
|
||||
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
|
||||
if email != "" && !emailAddressRegex.MatchString(email) {
|
||||
return false, false, "", "", "", false, "", errHTTPBadRequestEmailAddressInvalid
|
||||
}
|
||||
if s.smtpSender == nil && email != "" {
|
||||
return false, false, "", "", "", false, "", errHTTPBadRequestEmailDisabled
|
||||
}
|
||||
@@ -1203,8 +1260,8 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
|
||||
// If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
|
||||
// 7. curl -T file.txt ntfy.sh/mytopic
|
||||
// In all other cases, mostly if file.txt is > message limit, treat it as an attachment
|
||||
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error {
|
||||
if m.Event == pollRequestEvent { // Case 1
|
||||
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error {
|
||||
if m.Event == model.PollRequestEvent { // Case 1
|
||||
return s.handleBodyDiscard(body)
|
||||
} else if unifiedpush {
|
||||
return s.handleBodyAsMessageAutoDetect(m, body) // Case 2
|
||||
@@ -1226,7 +1283,7 @@ func (s *Server) handleBodyDiscard(body *util.PeekedReadCloser) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error {
|
||||
func (s *Server) handleBodyAsMessageAutoDetect(m *model.Message, body *util.PeekedReadCloser) error {
|
||||
if utf8.Valid(body.PeekedBytes) {
|
||||
m.Message = string(body.PeekedBytes) // Do not trim
|
||||
} else {
|
||||
@@ -1236,7 +1293,7 @@ func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedRead
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
|
||||
func (s *Server) handleBodyAsTextMessage(m *model.Message, body *util.PeekedReadCloser) error {
|
||||
if !utf8.Valid(body.PeekedBytes) {
|
||||
return errHTTPBadRequestMessageNotUTF8.With(m)
|
||||
}
|
||||
@@ -1249,7 +1306,7 @@ func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleBodyAsTemplatedTextMessage(m *message, template templateMode, body *util.PeekedReadCloser, priorityStr string) error {
|
||||
func (s *Server) handleBodyAsTemplatedTextMessage(m *model.Message, template templateMode, body *util.PeekedReadCloser, priorityStr string) error {
|
||||
body, err := util.Peek(body, max(s.config.MessageSizeLimit, jsonBodyBytesLimit))
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1274,7 +1331,7 @@ func (s *Server) handleBodyAsTemplatedTextMessage(m *message, template templateM
|
||||
|
||||
// renderTemplateFromFile transforms the JSON message body according to a template from the filesystem.
|
||||
// The template file must be in the templates directory, or in the configured template directory.
|
||||
func (s *Server) renderTemplateFromFile(m *message, templateName, peekedBody string) error {
|
||||
func (s *Server) renderTemplateFromFile(m *model.Message, templateName, peekedBody string) error {
|
||||
if !templateNameRegex.MatchString(templateName) {
|
||||
return errHTTPBadRequestTemplateFileNotFound
|
||||
}
|
||||
@@ -1316,7 +1373,7 @@ func (s *Server) renderTemplateFromFile(m *message, templateName, peekedBody str
|
||||
|
||||
// renderTemplateFromParams transforms the JSON message body according to the inline template in the
|
||||
// message, title, and priority parameters.
|
||||
func (s *Server) renderTemplateFromParams(m *message, peekedBody string, priorityStr string) error {
|
||||
func (s *Server) renderTemplateFromParams(m *model.Message, peekedBody string, priorityStr string) error {
|
||||
var err error
|
||||
if m.Message, err = s.renderTemplate("priority query parameter", m.Message, peekedBody); err != nil {
|
||||
return err
|
||||
@@ -1357,8 +1414,8 @@ func (s *Server) renderTemplate(name, tpl, source string) (string, error) {
|
||||
return strings.TrimSpace(strings.ReplaceAll(buf.String(), "\\n", "\n")), nil // replace any remaining "\n" (those outside of template curly braces) with newlines
|
||||
}
|
||||
|
||||
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
|
||||
if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
|
||||
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser) error {
|
||||
if s.fileCache == nil || s.config.BaseURL == "" {
|
||||
return errHTTPBadRequestAttachmentsDisallowed.With(m)
|
||||
}
|
||||
vinfo, err := v.Info()
|
||||
@@ -1381,7 +1438,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
|
||||
}
|
||||
}
|
||||
if m.Attachment == nil {
|
||||
m.Attachment = &attachment{}
|
||||
m.Attachment = &model.Attachment{}
|
||||
}
|
||||
var ext string
|
||||
m.Attachment.Expires = attachmentExpiry
|
||||
@@ -1408,9 +1465,9 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
|
||||
}
|
||||
|
||||
func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
encoder := func(msg *message) (string, error) {
|
||||
encoder := func(msg *model.Message) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil {
|
||||
if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
@@ -1419,12 +1476,12 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *
|
||||
}
|
||||
|
||||
func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
encoder := func(msg *message) (string, error) {
|
||||
encoder := func(msg *model.Message) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil {
|
||||
if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
|
||||
if msg.Event != model.MessageEvent && msg.Event != model.MessageDeleteEvent && msg.Event != model.MessageClearEvent {
|
||||
return fmt.Sprintf("event: %s\ndata: %s\n", msg.Event, buf.String()), nil // Browser's .onmessage() does not fire on this!
|
||||
}
|
||||
return fmt.Sprintf("data: %s\n", buf.String()), nil
|
||||
@@ -1433,8 +1490,8 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v
|
||||
}
|
||||
|
||||
func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
encoder := func(msg *message) (string, error) {
|
||||
if msg.Event == messageEvent { // only handle default events
|
||||
encoder := func(msg *model.Message) (string, error) {
|
||||
if msg.Event == model.MessageEvent { // only handle default events
|
||||
return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
|
||||
}
|
||||
return "\n", nil // "keepalive" and "open" events just send an empty line
|
||||
@@ -1469,7 +1526,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||
closed = true
|
||||
wlock.Unlock()
|
||||
}()
|
||||
sub := func(v *visitor, msg *message) error {
|
||||
sub := func(v *visitor, msg *model.Message) error {
|
||||
if !filters.Pass(msg) {
|
||||
return nil
|
||||
}
|
||||
@@ -1512,7 +1569,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||
topics[i].Unsubscribe(subscriberID) // Order!
|
||||
}
|
||||
}()
|
||||
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||
if err := sub(v, model.NewOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||
return err
|
||||
}
|
||||
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
|
||||
@@ -1535,7 +1592,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||
for _, t := range topics {
|
||||
t.Keepalive()
|
||||
}
|
||||
if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
|
||||
if err := sub(v, model.NewKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -1631,7 +1688,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||
}
|
||||
}
|
||||
})
|
||||
sub := func(v *visitor, msg *message) error {
|
||||
sub := func(v *visitor, msg *model.Message) error {
|
||||
if !filters.Pass(msg) {
|
||||
return nil
|
||||
}
|
||||
@@ -1661,7 +1718,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||
topics[i].Unsubscribe(subscriberID) // Order!
|
||||
}
|
||||
}()
|
||||
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||
if err := sub(v, model.NewOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||
return err
|
||||
}
|
||||
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
|
||||
@@ -1678,7 +1735,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
|
||||
func parseSubscribeParams(r *http.Request) (poll bool, since model.SinceMarker, scheduled bool, filters *queryFilter, err error) {
|
||||
poll = readBoolParam(r, false, "x-poll", "poll", "po")
|
||||
scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
|
||||
since, err = parseSince(r, poll)
|
||||
@@ -1759,11 +1816,11 @@ func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topi
|
||||
|
||||
// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
|
||||
// marker, returning only messages that are newer than the marker.
|
||||
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
|
||||
func (s *Server) sendOldMessages(topics []*topic, since model.SinceMarker, scheduled bool, v *visitor, sub subscriber) error {
|
||||
if since.IsNone() {
|
||||
return nil
|
||||
}
|
||||
messages := make([]*message, 0)
|
||||
messages := make([]*model.Message, 0)
|
||||
for _, t := range topics {
|
||||
topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled)
|
||||
if err != nil {
|
||||
@@ -1786,32 +1843,32 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
|
||||
//
|
||||
// Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h),
|
||||
// "all" for all messages, or "latest" for the most recent message for a topic
|
||||
func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
|
||||
func parseSince(r *http.Request, poll bool) (model.SinceMarker, error) {
|
||||
since := readParam(r, "x-since", "since", "si")
|
||||
|
||||
// Easy cases (empty, all, none)
|
||||
if since == "" {
|
||||
if poll {
|
||||
return sinceAllMessages, nil
|
||||
return model.SinceAllMessages, nil
|
||||
}
|
||||
return sinceNoMessages, nil
|
||||
return model.SinceNoMessages, nil
|
||||
} else if since == "all" {
|
||||
return sinceAllMessages, nil
|
||||
return model.SinceAllMessages, nil
|
||||
} else if since == "latest" {
|
||||
return sinceLatestMessage, nil
|
||||
return model.SinceLatestMessage, nil
|
||||
} else if since == "none" {
|
||||
return sinceNoMessages, nil
|
||||
return model.SinceNoMessages, nil
|
||||
}
|
||||
|
||||
// ID, timestamp, duration
|
||||
if validMessageID(since) {
|
||||
return newSinceID(since), nil
|
||||
if model.ValidMessageID(since) {
|
||||
return model.NewSinceID(since), nil
|
||||
} else if s, err := strconv.ParseInt(since, 10, 64); err == nil {
|
||||
return newSinceTime(s), nil
|
||||
return model.NewSinceTime(s), nil
|
||||
} else if d, err := time.ParseDuration(since); err == nil {
|
||||
return newSinceTime(time.Now().Add(-1 * d).Unix()), nil
|
||||
return model.NewSinceTime(time.Now().Add(-1 * d).Unix()), nil
|
||||
}
|
||||
return sinceNoMessages, errHTTPBadRequestSinceInvalid
|
||||
return model.SinceNoMessages, errHTTPBadRequestSinceInvalid
|
||||
}
|
||||
|
||||
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||
@@ -1967,14 +2024,14 @@ func (s *Server) runFirebaseKeepaliver() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
||||
s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
|
||||
s.sendToFirebase(v, model.NewKeepaliveMessage(firebaseControlTopic))
|
||||
/*
|
||||
FIXME: Disable iOS polling entirely for now due to thundering herd problem (see #677)
|
||||
To solve this, we'd have to shard the iOS poll topics to spread out the polling evenly.
|
||||
Given that it's not really necessary to poll, turning it off for now should not have any impact.
|
||||
|
||||
case <-time.After(s.config.FirebasePollInterval):
|
||||
s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
|
||||
s.sendToFirebase(v, model.NewKeepaliveMessage(firebasePollTopic))
|
||||
*/
|
||||
case <-s.closeChan:
|
||||
return
|
||||
@@ -2017,7 +2074,7 @@ func (s *Server) sendDelayedMessages() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
||||
func (s *Server) sendDelayedMessage(v *visitor, m *model.Message) error {
|
||||
logvm(v, m).Debug("Sending delayed message")
|
||||
s.mu.RLock()
|
||||
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
|
||||
@@ -2290,9 +2347,7 @@ func (s *Server) updateAndWriteStats(messagesCount int64) {
|
||||
s.messagesHistory = s.messagesHistory[1:]
|
||||
}
|
||||
s.mu.Unlock()
|
||||
go func() {
|
||||
if err := s.messageCache.UpdateStats(messagesCount); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Cannot write messages stats")
|
||||
}
|
||||
}()
|
||||
if err := s.messageCache.UpdateStats(messagesCount); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Cannot write messages stats")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,8 +38,32 @@
|
||||
#
|
||||
# firebase-key-file: <filename>
|
||||
|
||||
# If "database-url" is set, ntfy will use PostgreSQL for all database-backed stores (message cache,
|
||||
# user manager, and web push subscriptions) instead of SQLite. When set, the "cache-file",
|
||||
# "auth-file", and "web-push-file" options must not be set.
|
||||
#
|
||||
# Note: Setting "database-url" implicitly enables authentication and access control.
|
||||
# The default access is "read-write" (see "auth-default-access").
|
||||
#
|
||||
# The URL supports standard PostgreSQL parameters (sslmode, connect_timeout, sslcert, etc.),
|
||||
# as well as ntfy-specific connection pool parameters:
|
||||
# pool_max_conns=10 - Maximum number of open connections (default: 10)
|
||||
# pool_max_idle_conns=N - Maximum number of idle connections
|
||||
# pool_conn_max_lifetime=5m - Maximum lifetime of a connection (Go duration)
|
||||
# pool_conn_max_idle_time=1m - Maximum idle time of a connection (Go duration)
|
||||
#
|
||||
# See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
|
||||
# for the full list of supported PostgreSQL connection parameters.
|
||||
#
|
||||
# Examples:
|
||||
# database-url: "postgres://user:pass@host:5432/ntfy"
|
||||
# database-url: "postgres://user:pass@host:5432/ntfy?sslmode=require&pool_max_conns=50"
|
||||
#
|
||||
# database-url: <connection-string>
|
||||
|
||||
# If "cache-file" is set, messages are cached in a local SQLite database instead of only in-memory.
|
||||
# This allows for service restarts without losing messages in support of the since= parameter.
|
||||
# Not required if "database-url" is set (messages are stored in PostgreSQL instead).
|
||||
#
|
||||
# The "cache-duration" parameter defines the duration for which messages will be buffered
|
||||
# before they are deleted. This is required to support the "since=..." and "poll=1" parameter.
|
||||
@@ -77,6 +101,8 @@
|
||||
# If set, access to the ntfy server and API can be controlled on a granular level using
|
||||
# the 'ntfy user' and 'ntfy access' commands. See the --help pages for details, or check the docs.
|
||||
#
|
||||
# Note: If "database-url" is set, auth is implicitly enabled and "auth-file" must not be set.
|
||||
#
|
||||
# - auth-file is the SQLite user/access database; it is created automatically if it doesn't already exist
|
||||
# - auth-default-access defines the default/fallback access if no access control entry is found; it can be
|
||||
# set to "read-write" (default), "read-only", "write-only" or "deny-all".
|
||||
@@ -133,6 +159,7 @@
|
||||
# - attachment-expiry-duration is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h)
|
||||
#
|
||||
# attachment-cache-dir:
|
||||
# attachment-s3-url: "s3://ACCESS_KEY:SECRET_KEY@bucket/prefix?region=us-east-1"
|
||||
# attachment-total-size-limit: "5G"
|
||||
# attachment-file-size-limit: "15M"
|
||||
# attachment-expiry-duration: "3h"
|
||||
@@ -197,6 +224,7 @@
|
||||
# - web-push-public-key is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890
|
||||
# - web-push-private-key is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890
|
||||
# - web-push-file is a database file to keep track of browser subscription endpoints, e.g. /var/cache/ntfy/webpush.db
|
||||
# Not required if "database-url" is set (subscriptions are stored in PostgreSQL instead).
|
||||
# - web-push-email-address is the admin email address send to the push provider, e.g. sysadmin@example.com
|
||||
# - web-push-startup-queries is an optional list of queries to run on startup
|
||||
# - web-push-expiry-warning-duration defines the duration after which unused subscriptions are sent a warning (default is 55d)
|
||||
|
||||
@@ -3,13 +3,15 @@ package server
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -454,21 +456,8 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
||||
return errHTTPUnauthorized
|
||||
} else if err := s.userManager.AllowReservation(u.Name, req.Topic); err != nil {
|
||||
return errHTTPConflictTopicReserved
|
||||
} else if u.IsUser() {
|
||||
hasReservation, err := s.userManager.HasReservation(u.Name, req.Topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasReservation {
|
||||
reservations, err := s.userManager.ReservationsCount(u.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if reservations >= u.Tier.ReservationLimit {
|
||||
return errHTTPTooManyRequestsLimitReservations
|
||||
}
|
||||
}
|
||||
}
|
||||
// Actually add the reservation
|
||||
// Actually add the reservation (with limit check inside the transaction to avoid races)
|
||||
logvr(v, r).
|
||||
Tag(tagAccount).
|
||||
Fields(log.Context{
|
||||
@@ -476,7 +465,14 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
||||
"everyone": everyone.String(),
|
||||
}).
|
||||
Debug("Adding topic reservation")
|
||||
if err := s.userManager.AddReservation(u.Name, req.Topic, everyone); err != nil {
|
||||
var limit int64
|
||||
if u.IsUser() && u.Tier != nil {
|
||||
limit = u.Tier.ReservationLimit
|
||||
}
|
||||
if err := s.userManager.AddReservation(u.Name, req.Topic, everyone, limit); err != nil {
|
||||
if errors.Is(err, user.ErrTooManyReservations) {
|
||||
return errHTTPTooManyRequestsLimitReservations
|
||||
}
|
||||
return err
|
||||
}
|
||||
// Kill existing subscribers
|
||||
@@ -529,22 +525,15 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
|
||||
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
|
||||
// The process relies on the manager to perform the actual deletions (see runManager).
|
||||
func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *visitor, u *user.User, reservationsLimit int64) error {
|
||||
reservations, err := s.userManager.Reservations(u.Name)
|
||||
removedTopics, err := s.userManager.RemoveExcessReservations(u.Name, reservationsLimit)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if int64(len(reservations)) <= reservationsLimit {
|
||||
} else if len(removedTopics) == 0 {
|
||||
logvr(v, r).Tag(tagAccount).Debug("No excess reservations to remove")
|
||||
return nil
|
||||
}
|
||||
topics := make([]string, 0)
|
||||
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
|
||||
topics = append(topics, reservations[i].Topic)
|
||||
}
|
||||
logvr(v, r).Tag(tagAccount).Info("Removing excess reservations for topics %s", strings.Join(topics, ", "))
|
||||
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.messageCache.ExpireMessages(topics...); err != nil {
|
||||
logvr(v, r).Tag(tagAccount).Info("Removed excess topic reservations, now removing messages for topics %s", strings.Join(removedTopics, ", "))
|
||||
if err := s.messageCache.ExpireMessages(removedTopics...); err != nil {
|
||||
return err
|
||||
}
|
||||
go s.pruneMessages()
|
||||
@@ -641,7 +630,7 @@ func (s *Server) publishSyncEvent(v *visitor) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m := newDefaultMessage(syncTopic.ID, string(messageBytes))
|
||||
m := model.NewDefaultMessage(syncTopic.ID, string(messageBytes))
|
||||
if err := syncTopic.Publish(v, m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,14 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (s *Server) handleVersion(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
return s.writeJSON(w, &apiVersionResponse{
|
||||
Version: s.config.BuildVersion,
|
||||
Commit: s.config.BuildCommit,
|
||||
Date: s.config.BuildDate,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleUsersGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
users, err := s.userManager.Users()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
@@ -9,393 +10,452 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestVersion_Admin(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||
c.BuildVersion = "1.2.3"
|
||||
c.BuildCommit = "abcdef0"
|
||||
c.BuildDate = "2026-02-08T00:00:00Z"
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin and regular user
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Admin can access /v1/version
|
||||
rr := request(t, s, "GET", "/v1/version", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
var versionResponse apiVersionResponse
|
||||
require.Nil(t, json.NewDecoder(rr.Body).Decode(&versionResponse))
|
||||
require.Equal(t, "1.2.3", versionResponse.Version)
|
||||
require.Equal(t, "abcdef0", versionResponse.Commit)
|
||||
require.Equal(t, "2026-02-08T00:00:00Z", versionResponse.Date)
|
||||
|
||||
// Non-admin user cannot access /v1/version
|
||||
rr = request(t, s, "GET", "/v1/version", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Unauthenticated user cannot access /v1/version
|
||||
rr = request(t, s, "GET", "/v1/version", "", nil)
|
||||
require.Equal(t, 401, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_AddRemove(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin, tier
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier1",
|
||||
}))
|
||||
// Create admin, tier
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier1",
|
||||
}))
|
||||
|
||||
// Create user via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
// Create user via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Create user with tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "tier1"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 4, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
require.Nil(t, users[1].Tier)
|
||||
require.Equal(t, "emma", users[2].Name)
|
||||
require.Equal(t, user.RoleUser, users[2].Role)
|
||||
require.Equal(t, "tier1", users[2].Tier.Code)
|
||||
require.Equal(t, user.Everyone, users[3].Name)
|
||||
|
||||
// Delete user via API
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check user was deleted
|
||||
users, err = s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "emma", users[1].Name)
|
||||
require.Equal(t, user.Everyone, users[2].Name)
|
||||
|
||||
// Reject invalid user change
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 400, rr.Code)
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Create user with tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "tier1"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 4, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
require.Nil(t, users[1].Tier)
|
||||
require.Equal(t, "emma", users[2].Name)
|
||||
require.Equal(t, user.RoleUser, users[2].Role)
|
||||
require.Equal(t, "tier1", users[2].Tier.Code)
|
||||
require.Equal(t, user.Everyone, users[3].Name)
|
||||
|
||||
// Delete user via API
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check user was deleted
|
||||
users, err = s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "emma", users[1].Name)
|
||||
require.Equal(t, user.Everyone, users[2].Name)
|
||||
|
||||
// Reject invalid user change
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 400, rr.Code)
|
||||
}
|
||||
|
||||
func TestUser_AddWithPasswordHash(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
|
||||
// Create user via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
// Create user via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check that user can login with password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, user.RoleAdmin, users[0].Role)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserPassword(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
|
||||
// Create user via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Try to login with first password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Change password via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password": "ben-two"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Make sure first password fails
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Try to login with second password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben-two"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserTier(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin, tier
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier1",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier2",
|
||||
}))
|
||||
|
||||
// Create user with tier via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
require.Equal(t, "tier1", users[1].Tier.Code)
|
||||
|
||||
// Change user tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "tier": "tier2"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users again
|
||||
users, err = s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tier2", users[1].Tier.Code)
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin, tier
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier1",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier2",
|
||||
}))
|
||||
|
||||
// Create user with tier via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
require.Equal(t, "tier1", users[1].Tier.Code)
|
||||
|
||||
// Change user password and tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben-two", "tier": "tier2"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Make sure first password fails
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Try to login with second password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben-two"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check new tier
|
||||
users, err = s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tier2", users[1].Tier.Code)
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
|
||||
// Create user with tier via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"not-ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Try to login with first password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "not-ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Change user password and tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Try to login with second password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
}
|
||||
|
||||
func TestUser_DontChangeAdminPassword(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("admin", "admin", user.RoleAdmin, false))
|
||||
|
||||
// Try to change password via API
|
||||
rr := request(t, s, "PUT", "/v1/users", `{"username": "admin", "password": "admin-new"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
}
|
||||
|
||||
func TestUser_AddRemove_Failures(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Cannot create user with invalid username
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 400, rr.Code)
|
||||
|
||||
// Cannot create user if user already exists
|
||||
rr = request(t, s, "POST", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
|
||||
|
||||
// Cannot create user with invalid tier
|
||||
rr = request(t, s, "POST", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
|
||||
|
||||
// Cannot delete user as non-admin
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Delete user via API
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User and admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Subscribing not allowed
|
||||
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
|
||||
// Grant access
|
||||
rr = request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Now subscribing is allowed
|
||||
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Reset access
|
||||
rr = request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Subscribing not allowed (again)
|
||||
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Grant access fails, because non-admin
|
||||
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset_KillConnection(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User and admin, grant access to "gol*" topics
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
|
||||
|
||||
start, timeTaken := time.Now(), atomic.Int64{}
|
||||
go func() {
|
||||
rr := request(t, s, "GET", "/gold/json", "", map[string]string{
|
||||
// Check that user can login with password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
timeTaken.Store(time.Since(start).Milliseconds())
|
||||
}()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Reset access
|
||||
rr := request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Wait for connection to be killed; this will fail if the connection is never killed
|
||||
waitFor(t, func() bool {
|
||||
return timeTaken.Load() >= 500
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, user.RoleAdmin, users[0].Role)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserPassword(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
|
||||
// Create user via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Try to login with first password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Change password via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password": "ben-two"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Make sure first password fails
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Try to login with second password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben-two"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserTier(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin, tier
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier1",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier2",
|
||||
}))
|
||||
|
||||
// Create user with tier via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
require.Equal(t, "tier1", users[1].Tier.Code)
|
||||
|
||||
// Change user tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "tier": "tier2"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users again
|
||||
users, err = s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tier2", users[1].Tier.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin, tier
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier1",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier2",
|
||||
}))
|
||||
|
||||
// Create user with tier via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
require.Equal(t, "tier1", users[1].Tier.Code)
|
||||
|
||||
// Change user password and tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben-two", "tier": "tier2"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Make sure first password fails
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Try to login with second password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben-two"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check new tier
|
||||
users, err = s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tier2", users[1].Tier.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
|
||||
// Create user with tier via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"not-ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Try to login with first password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "not-ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Change user password and tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Try to login with second password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_DontChangeAdminPassword(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("admin", "admin", user.RoleAdmin, false))
|
||||
|
||||
// Try to change password via API
|
||||
rr := request(t, s, "PUT", "/v1/users", `{"username": "admin", "password": "admin-new"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_AddRemove_Failures(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Cannot create user with invalid username
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 400, rr.Code)
|
||||
|
||||
// Cannot create user if user already exists
|
||||
rr = request(t, s, "POST", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
|
||||
|
||||
// Cannot create user with invalid tier
|
||||
rr = request(t, s, "POST", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
|
||||
|
||||
// Cannot delete user as non-admin
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Delete user via API
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User and admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Subscribing not allowed
|
||||
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
|
||||
// Grant access
|
||||
rr = request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Now subscribing is allowed
|
||||
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Reset access
|
||||
rr = request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Subscribing not allowed (again)
|
||||
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Grant access fails, because non-admin
|
||||
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset_KillConnection(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User and admin, grant access to "gol*" topics
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
|
||||
|
||||
start, timeTaken := time.Now(), atomic.Int64{}
|
||||
go func() {
|
||||
rr := request(t, s, "GET", "/gold/json", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
timeTaken.Store(time.Since(start).Milliseconds())
|
||||
}()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Reset access
|
||||
rr := request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Wait for connection to be killed; this will fail if the connection is never killed
|
||||
waitFor(t, func() bool {
|
||||
return timeTaken.Load() >= 500
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"firebase.google.com/go/v4/messaging"
|
||||
"fmt"
|
||||
"google.golang.org/api/option"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"strings"
|
||||
@@ -43,7 +44,7 @@ func newFirebaseClient(sender firebaseSender, auther user.Auther) *firebaseClien
|
||||
}
|
||||
}
|
||||
|
||||
func (c *firebaseClient) Send(v *visitor, m *message) error {
|
||||
func (c *firebaseClient) Send(v *visitor, m *model.Message) error {
|
||||
if !v.FirebaseAllowed() {
|
||||
return errFirebaseTemporarilyBanned
|
||||
}
|
||||
@@ -121,11 +122,11 @@ func (c *firebaseSenderImpl) Send(m *messaging.Message) error {
|
||||
// On Android, this will trigger the app to poll the topic and thereby displaying new messages.
|
||||
// - If UpstreamBaseURL is set, messages are forwarded as poll requests to an upstream server and then forwarded
|
||||
// to Firebase here. This is mainly for iOS to support self-hosted servers.
|
||||
func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, error) {
|
||||
func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message, error) {
|
||||
var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format
|
||||
var apnsConfig *messaging.APNSConfig
|
||||
switch m.Event {
|
||||
case keepaliveEvent, openEvent:
|
||||
case model.KeepaliveEvent, model.OpenEvent:
|
||||
data = map[string]string{
|
||||
"id": m.ID,
|
||||
"time": fmt.Sprintf("%d", m.Time),
|
||||
@@ -133,7 +134,7 @@ func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, erro
|
||||
"topic": m.Topic,
|
||||
}
|
||||
apnsConfig = createAPNSBackgroundConfig(data)
|
||||
case pollRequestEvent:
|
||||
case model.PollRequestEvent:
|
||||
data = map[string]string{
|
||||
"id": m.ID,
|
||||
"time": fmt.Sprintf("%d", m.Time),
|
||||
@@ -143,7 +144,7 @@ func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, erro
|
||||
"poll_id": m.PollID,
|
||||
}
|
||||
apnsConfig = createAPNSAlertConfig(m, data)
|
||||
case messageDeleteEvent, messageClearEvent:
|
||||
case model.MessageDeleteEvent, model.MessageClearEvent:
|
||||
data = map[string]string{
|
||||
"id": m.ID,
|
||||
"time": fmt.Sprintf("%d", m.Time),
|
||||
@@ -152,7 +153,7 @@ func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, erro
|
||||
"sequence_id": m.SequenceID,
|
||||
}
|
||||
apnsConfig = createAPNSBackgroundConfig(data)
|
||||
case messageEvent:
|
||||
case model.MessageEvent:
|
||||
if auther != nil {
|
||||
// If "anonymous read" for a topic is not allowed, we cannot send the message along
|
||||
// via Firebase. Instead, we send a "poll_request" message, asking the client to poll.
|
||||
@@ -235,7 +236,7 @@ func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message {
|
||||
// createAPNSAlertConfig creates an APNS config for iOS notifications that show up as an alert (only relevant for iOS).
|
||||
// We must set the Alert struct ("alert"), and we need to set MutableContent ("mutable-content"), so the Notification Service
|
||||
// Extension in iOS can modify the message.
|
||||
func createAPNSAlertConfig(m *message, data map[string]string) *messaging.APNSConfig {
|
||||
func createAPNSAlertConfig(m *model.Message, data map[string]string) *messaging.APNSConfig {
|
||||
apnsData := make(map[string]any)
|
||||
for k, v := range data {
|
||||
apnsData[k] = v
|
||||
@@ -296,8 +297,8 @@ func maybeTruncateAPNSBodyMessage(s string) string {
|
||||
//
|
||||
// This empties all the fields that are not needed for a poll request and just sets the required fields,
|
||||
// most importantly, the PollID.
|
||||
func toPollRequest(m *message) *message {
|
||||
pr := newPollRequestMessage(m.Topic, m.ID)
|
||||
func toPollRequest(m *model.Message) *model.Message {
|
||||
pr := model.NewPollRequestMessage(m.Topic, m.ID)
|
||||
pr.ID = m.ID
|
||||
pr.Time = m.Time
|
||||
pr.Priority = m.Priority // Keep priority
|
||||
|
||||
@@ -4,6 +4,7 @@ package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
)
|
||||
|
||||
@@ -21,7 +22,7 @@ var (
|
||||
type firebaseClient struct {
|
||||
}
|
||||
|
||||
func (c *firebaseClient) Send(v *visitor, m *message) error {
|
||||
func (c *firebaseClient) Send(v *visitor, m *model.Message) error {
|
||||
return errFirebaseNotAvailable
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"net/netip"
|
||||
"strings"
|
||||
@@ -63,7 +64,7 @@ func (s *testFirebaseSender) Messages() []*messaging.Message {
|
||||
}
|
||||
|
||||
func TestToFirebaseMessage_Keepalive(t *testing.T) {
|
||||
m := newKeepaliveMessage("mytopic")
|
||||
m := model.NewKeepaliveMessage("mytopic")
|
||||
fbm, err := toFirebaseMessage(m, nil)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "mytopic", fbm.Topic)
|
||||
@@ -94,7 +95,7 @@ func TestToFirebaseMessage_Keepalive(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestToFirebaseMessage_Open(t *testing.T) {
|
||||
m := newOpenMessage("mytopic")
|
||||
m := model.NewOpenMessage("mytopic")
|
||||
fbm, err := toFirebaseMessage(m, nil)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "mytopic", fbm.Topic)
|
||||
@@ -125,13 +126,13 @@ func TestToFirebaseMessage_Open(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
|
||||
m := newDefaultMessage("mytopic", "this is a message")
|
||||
m := model.NewDefaultMessage("mytopic", "this is a message")
|
||||
m.Priority = 4
|
||||
m.Tags = []string{"tag 1", "tag2"}
|
||||
m.Click = "https://google.com"
|
||||
m.Icon = "https://ntfy.sh/static/img/ntfy.png"
|
||||
m.Title = "some title"
|
||||
m.Actions = []*action{
|
||||
m.Actions = []*model.Action{
|
||||
{
|
||||
ID: "123",
|
||||
Action: "view",
|
||||
@@ -150,7 +151,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
m.Attachment = &attachment{
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "some file.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 12345,
|
||||
@@ -219,7 +220,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestToFirebaseMessage_Message_Normal_Not_Allowed(t *testing.T) {
|
||||
m := newDefaultMessage("mytopic", "this is a message")
|
||||
m := model.NewDefaultMessage("mytopic", "this is a message")
|
||||
m.Priority = 5
|
||||
fbm, err := toFirebaseMessage(m, &testAuther{Allow: false}) // Not allowed!
|
||||
require.Nil(t, err)
|
||||
@@ -250,7 +251,7 @@ func TestToFirebaseMessage_Message_Normal_Not_Allowed(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestToFirebaseMessage_PollRequest(t *testing.T) {
|
||||
m := newPollRequestMessage("mytopic", "fOv6k1QbCzo6")
|
||||
m := model.NewPollRequestMessage("mytopic", "fOv6k1QbCzo6")
|
||||
fbm, err := toFirebaseMessage(m, nil)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "mytopic", fbm.Topic)
|
||||
@@ -344,18 +345,18 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
|
||||
func TestToFirebaseSender_Abuse(t *testing.T) {
|
||||
sender := &testFirebaseSender{allowed: 2}
|
||||
client := newFirebaseClient(sender, &testAuther{})
|
||||
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), nil, netip.MustParseAddr("1.2.3.4"), nil)
|
||||
visitor := newVisitor(newTestConfig(t, ""), newMemTestCache(t), nil, netip.MustParseAddr("1.2.3.4"), nil)
|
||||
|
||||
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||
require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
||||
require.Equal(t, 1, len(sender.Messages()))
|
||||
|
||||
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||
require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
||||
require.Equal(t, 2, len(sender.Messages()))
|
||||
|
||||
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
||||
require.Equal(t, 2, len(sender.Messages()))
|
||||
|
||||
sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working
|
||||
require.Equal(t, errFirebaseTemporarilyBanned, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||
require.Equal(t, errFirebaseTemporarilyBanned, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
||||
require.Equal(t, 0, len(sender.Messages()))
|
||||
}
|
||||
|
||||
@@ -17,15 +17,10 @@ func (s *Server) execManager() {
|
||||
s.pruneMessages()
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
|
||||
// Message count per topic
|
||||
var messagesCached int
|
||||
messageCounts, err := s.messageCache.MessageCounts()
|
||||
// Message count
|
||||
messagesCached, err := s.messageCache.MessagesCount()
|
||||
if err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Cannot get message counts")
|
||||
messageCounts = make(map[string]int) // Empty, so we can continue
|
||||
}
|
||||
for _, count := range messageCounts {
|
||||
messagesCached += count
|
||||
log.Tag(tagManager).Err(err).Warn("Cannot get messages count")
|
||||
}
|
||||
|
||||
// Remove subscriptions without subscribers
|
||||
@@ -104,6 +99,9 @@ func (s *Server) execManager() {
|
||||
mset(metricUsers, usersCount)
|
||||
mset(metricSubscribers, subscribers)
|
||||
mset(metricTopics, topicsCount)
|
||||
if s.fileCache != nil {
|
||||
mset(metricAttachmentsTotalSize, s.fileCache.Size())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) pruneVisitors() {
|
||||
|
||||
@@ -2,27 +2,30 @@ package server
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestServer_Manager_Prune_Messages_Without_Attachments_DoesNotPanic(t *testing.T) {
|
||||
// Tests that the manager runs without attachment-cache-dir set, see #617
|
||||
c := newTestConfig(t)
|
||||
c.AttachmentCacheDir = ""
|
||||
s := newTestServer(t, c)
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
// Tests that the manager runs without attachment-cache-dir set, see #617
|
||||
c := newTestConfig(t, databaseURL)
|
||||
c.AttachmentCacheDir = ""
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Publish a message
|
||||
rr := request(t, s, "POST", "/mytopic", "hi", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
m := toMessage(t, rr.Body.String())
|
||||
// Publish a message
|
||||
rr := request(t, s, "POST", "/mytopic", "hi", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
m := toMessage(t, rr.Body.String())
|
||||
|
||||
// Expire message
|
||||
require.Nil(t, s.messageCache.ExpireMessages("mytopic"))
|
||||
// Expire message
|
||||
require.Nil(t, s.messageCache.ExpireMessages("mytopic"))
|
||||
|
||||
// Does not panic
|
||||
s.pruneMessages()
|
||||
// Does not panic
|
||||
s.pruneMessages()
|
||||
|
||||
// Actually deleted
|
||||
_, err := s.messageCache.Message(m.ID)
|
||||
require.Equal(t, errMessageNotFound, err)
|
||||
// Actually deleted
|
||||
_, err := s.messageCache.Message(m.ID)
|
||||
require.Equal(t, model.ErrMessageNotFound, err)
|
||||
})
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
5
server/server_race_off_test.go
Normal file
5
server/server_race_off_test.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !race
|
||||
|
||||
package server
|
||||
|
||||
const raceEnabled = false
|
||||
5
server/server_race_on_test.go
Normal file
5
server/server_race_on_test.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build race
|
||||
|
||||
package server
|
||||
|
||||
const raceEnabled = true
|
||||
File diff suppressed because one or more lines are too long
@@ -11,6 +11,7 @@ import (
|
||||
"text/template"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
@@ -76,7 +77,7 @@ func (s *Server) convertPhoneNumber(u *user.User, phoneNumber string) (string, *
|
||||
|
||||
// callPhone calls the Twilio API to make a phone call to the given phone number, using the given message.
|
||||
// Failures will be logged, but not returned to the caller.
|
||||
func (s *Server) callPhone(v *visitor, r *http.Request, m *message, to string) {
|
||||
func (s *Server) callPhone(v *visitor, r *http.Request, m *model.Message, to string) {
|
||||
u, sender := v.User(), m.Sender.String()
|
||||
if u != nil {
|
||||
sender = u.Name
|
||||
|
||||
@@ -14,217 +14,224 @@ import (
|
||||
)
|
||||
|
||||
func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
|
||||
var called, verified atomic.Bool
|
||||
var code atomic.Pointer[string]
|
||||
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
if r.URL.Path == "/v2/Services/VA1234567890/Verifications" {
|
||||
if code.Load() != nil {
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
var called, verified atomic.Bool
|
||||
var code atomic.Pointer[string]
|
||||
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
if r.URL.Path == "/v2/Services/VA1234567890/Verifications" {
|
||||
if code.Load() != nil {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
require.Equal(t, "Channel=sms&To=%2B12223334444", string(body))
|
||||
code.Store(util.String("123456"))
|
||||
} else if r.URL.Path == "/v2/Services/VA1234567890/VerificationCheck" {
|
||||
if verified.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
require.Equal(t, "Code=123456&To=%2B12223334444", string(body))
|
||||
verified.Store(true)
|
||||
} else {
|
||||
t.Fatal("Unexpected path:", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer twilioVerifyServer.Close()
|
||||
twilioCallsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
require.Equal(t, "Channel=sms&To=%2B12223334444", string(body))
|
||||
code.Store(util.String("123456"))
|
||||
} else if r.URL.Path == "/v2/Services/VA1234567890/VerificationCheck" {
|
||||
if verified.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
require.Equal(t, "Code=123456&To=%2B12223334444", string(body))
|
||||
verified.Store(true)
|
||||
} else {
|
||||
t.Fatal("Unexpected path:", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer twilioVerifyServer.Close()
|
||||
twilioCallsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B12223334444&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioCallsServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
|
||||
c.TwilioCallsBaseURL = twilioCallsServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
c.TwilioVerifyService = "VA1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B12223334444&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioCallsServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
|
||||
c.TwilioCallsBaseURL = twilioCallsServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
c.TwilioVerifyService = "VA1234567890"
|
||||
s := newTestServer(t, c)
|
||||
// Send verification code for phone number
|
||||
response := request(t, s, "PUT", "/v1/account/phone/verify", `{"number":"+12223334444","channel":"sms"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
waitFor(t, func() bool {
|
||||
return *code.Load() == "123456"
|
||||
})
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
// Add phone number with code
|
||||
response = request(t, s, "PUT", "/v1/account/phone", `{"number":"+12223334444","code":"123456"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
waitFor(t, func() bool {
|
||||
return verified.Load()
|
||||
})
|
||||
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(phoneNumbers))
|
||||
require.Equal(t, "+12223334444", phoneNumbers[0])
|
||||
|
||||
// Send verification code for phone number
|
||||
response := request(t, s, "PUT", "/v1/account/phone/verify", `{"number":"+12223334444","channel":"sms"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
waitFor(t, func() bool {
|
||||
return *code.Load() == "123456"
|
||||
})
|
||||
// Do the thing
|
||||
response = request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "yes",
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
})
|
||||
|
||||
// Add phone number with code
|
||||
response = request(t, s, "PUT", "/v1/account/phone", `{"number":"+12223334444","code":"123456"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
waitFor(t, func() bool {
|
||||
return verified.Load()
|
||||
})
|
||||
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(phoneNumbers))
|
||||
require.Equal(t, "+12223334444", phoneNumbers[0])
|
||||
// Remove the phone number
|
||||
response = request(t, s, "DELETE", "/v1/account/phone", `{"number":"+12223334444"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
||||
// Do the thing
|
||||
response = request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "yes",
|
||||
// Verify the phone number is gone from the DB
|
||||
phoneNumbers, err = s.userManager.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(phoneNumbers))
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
})
|
||||
|
||||
// Remove the phone number
|
||||
response = request(t, s, "DELETE", "/v1/account/phone", `{"number":"+12223334444"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
||||
// Verify the phone number is gone from the DB
|
||||
phoneNumbers, err = s.userManager.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(phoneNumbers))
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Success(t *testing.T) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "yes", // <<<------
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "yes", // <<<------
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+language%3D%22de-DE%22+loop%3D%223%22%3E%0A%09%09Du+hast+eine+Nachricht+von+notify+im+Thema+mytopic.+Nachricht%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Ende+der+Nachricht.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Diese+Nachricht+wurde+von+Benutzer+phil+gesendet.+Sie+wird+drei+Mal+wiederholt.%0A%09%09Um+dich+von+Anrufen+wie+diesen+abzumelden%2C+entferne+deine+Telefonnummer+in+der+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay+language%3D%22de-DE%22%3EAuf+Wiederh%C3%B6ren.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+language%3D%22de-DE%22+loop%3D%223%22%3E%0A%09%09Du+hast+eine+Nachricht+von+notify+im+Thema+mytopic.+Nachricht%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Ende+der+Nachricht.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Diese+Nachricht+wurde+von+Benutzer+phil+gesendet.+Sie+wird+drei+Mal+wiederholt.%0A%09%09Um+dich+von+Anrufen+wie+diesen+abzumelden%2C+entferne+deine+Telefonnummer+in+der+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay+language%3D%22de-DE%22%3EAuf+Wiederh%C3%B6ren.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
c.TwilioCallFormat = template.Must(template.New("twiml").Parse(`
|
||||
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
c.TwilioCallFormat = template.Must(template.New("twiml").Parse(`
|
||||
<Response>
|
||||
<Pause length="1"/>
|
||||
<Say language="de-DE" loop="3">
|
||||
@@ -240,88 +247,97 @@ func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
|
||||
</Say>
|
||||
<Say language="de-DE">Auf Wiederhören.</Say>
|
||||
</Response>`))
|
||||
s := newTestServer(t, c)
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = "http://dummy.invalid"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||
c.TwilioCallsBaseURL = "http://dummy.invalid"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
})
|
||||
require.Equal(t, 40034, toHTTPError(t, response.Body.String()).Code)
|
||||
})
|
||||
require.Equal(t, 40034, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_InvalidNumber(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+invalid",
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+invalid",
|
||||
})
|
||||
require.Equal(t, 40033, toHTTPError(t, response.Body.String()).Code)
|
||||
})
|
||||
require.Equal(t, 40033, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Anonymous(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+123123",
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+123123",
|
||||
})
|
||||
require.Equal(t, 40035, toHTTPError(t, response.Body.String()).Code)
|
||||
})
|
||||
require.Equal(t, 40035, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Unconfigured(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+1234",
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfig(t, databaseURL))
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+1234",
|
||||
})
|
||||
require.Equal(t, 40032, toHTTPError(t, response.Body.String()).Code)
|
||||
})
|
||||
require.Equal(t, 40032, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,9 @@ import (
|
||||
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
wpush "heckel.io/ntfy/v2/webpush"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -82,14 +84,14 @@ func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
|
||||
func (s *Server) publishToWebPushEndpoints(v *visitor, m *model.Message) {
|
||||
subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic)
|
||||
if err != nil {
|
||||
logvm(v, m).Err(err).With(v, m).Warn("Unable to publish web push messages")
|
||||
return
|
||||
}
|
||||
log.Tag(tagWebPush).With(v, m).Debug("Publishing web push message to %d subscribers", len(subscriptions))
|
||||
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m.forJSON()))
|
||||
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m.ForJSON()))
|
||||
if err != nil {
|
||||
log.Tag(tagWebPush).Err(err).With(v, m).Warn("Unable to marshal expiring payload")
|
||||
return
|
||||
@@ -128,7 +130,7 @@ func (s *Server) pruneAndNotifyWebPushSubscriptionsInternal() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
warningSent := make([]*webPushSubscription, 0)
|
||||
warningSent := make([]*wpush.Subscription, 0)
|
||||
for _, subscription := range subscriptions {
|
||||
if err := s.sendWebPushNotification(subscription, payload); err != nil {
|
||||
log.Tag(tagWebPush).Err(err).With(subscription).Warn("Unable to publish expiry imminent warning")
|
||||
@@ -143,7 +145,7 @@ func (s *Server) pruneAndNotifyWebPushSubscriptionsInternal() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) sendWebPushNotification(sub *webPushSubscription, message []byte, contexters ...log.Contexter) error {
|
||||
func (s *Server) sendWebPushNotification(sub *wpush.Subscription, message []byte, contexters ...log.Contexter) error {
|
||||
log.Tag(tagWebPush).With(sub).With(contexters...).Debug("Sending web push message")
|
||||
payload := &webpush.Subscription{
|
||||
Endpoint: sub.Endpoint,
|
||||
|
||||
@@ -4,6 +4,8 @@ package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -20,7 +22,7 @@ func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *
|
||||
return errHTTPNotFound
|
||||
}
|
||||
|
||||
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
|
||||
func (s *Server) publishToWebPushEndpoints(v *visitor, m *model.Message) {
|
||||
// Nothing to see here
|
||||
}
|
||||
|
||||
|
||||
@@ -5,10 +5,6 @@ package server
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -18,6 +14,11 @@ import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -25,237 +26,262 @@ const (
|
||||
)
|
||||
|
||||
func TestServer_WebPush_Enabled(t *testing.T) {
|
||||
conf := newTestConfig(t)
|
||||
conf.WebRoot = "" // Disable web app
|
||||
s := newTestServer(t, conf)
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
conf := newTestConfig(t, databaseURL)
|
||||
conf.WebRoot = "" // Disable web app
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
conf2 := newTestConfig(t)
|
||||
s2 := newTestServer(t, conf2)
|
||||
conf2 := newTestConfig(t, databaseURL)
|
||||
s2 := newTestServer(t, conf2)
|
||||
|
||||
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
conf3 := newTestConfigWithWebPush(t)
|
||||
s3 := newTestServer(t, conf3)
|
||||
conf3 := newTestConfigWithWebPush(t, databaseURL)
|
||||
s3 := newTestServer(t, conf3)
|
||||
|
||||
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
|
||||
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
|
||||
|
||||
})
|
||||
}
|
||||
func TestServer_WebPush_Disabled(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfig(t, databaseURL))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 404, response.Code)
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 404, response.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicAdd(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
||||
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
||||
require.Equal(t, subs[0].Auth, "auth-key")
|
||||
require.Equal(t, subs[0].UserID, "")
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
||||
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
||||
require.Equal(t, subs[0].Auth, "auth-key")
|
||||
require.Equal(t, subs[0].UserID, "")
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
|
||||
require.Equal(t, 400, response.Code)
|
||||
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
|
||||
require.Equal(t, 400, response.Code)
|
||||
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||
|
||||
topicList := make([]string, 51)
|
||||
for i := range topicList {
|
||||
topicList[i] = util.RandomString(5)
|
||||
}
|
||||
topicList := make([]string, 51)
|
||||
for i := range topicList {
|
||||
topicList[i] = util.RandomString(5)
|
||||
}
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 400, response.Code)
|
||||
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 400, response.Code)
|
||||
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||
|
||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Delete(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||
|
||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
config.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, config)
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t, databaseURL))
|
||||
config.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, config)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
config.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, config)
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t, databaseURL))
|
||||
config.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, config)
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 403, response.Code)
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 403, response.Code)
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
s := newTestServer(t, config)
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t, databaseURL))
|
||||
s := newTestServer(t, config)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
// should've been deleted with the account
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
})
|
||||
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
// should've been deleted with the account
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Publish(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||
|
||||
var received atomic.Bool
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/push-receive", r.URL.Path)
|
||||
require.Equal(t, "high", r.Header.Get("Urgency"))
|
||||
require.Equal(t, "", r.Header.Get("Topic"))
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
var received atomic.Bool
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/push-receive", r.URL.Path)
|
||||
require.Equal(t, "high", r.Header.Get("Urgency"))
|
||||
require.Equal(t, "", r.Header.Get("Topic"))
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
|
||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
|
||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||
|
||||
var received atomic.Bool
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
w.WriteHeader(http.StatusGone)
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
var received atomic.Bool
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
w.WriteHeader(http.StatusGone)
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
requireSubscriptionCount(t, s, "test-topic-abc", 1)
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
requireSubscriptionCount(t, s, "test-topic-abc", 1)
|
||||
|
||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
|
||||
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
requireSubscriptionCount(t, s, "test-topic-abc", 0)
|
||||
})
|
||||
|
||||
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
requireSubscriptionCount(t, s, "test-topic-abc", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Expiry(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||
|
||||
var received atomic.Bool
|
||||
var received atomic.Bool
|
||||
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(``))
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(``))
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
endpoint := pushService.URL + "/push-receive"
|
||||
addSubscription(t, s, endpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
_, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-55*24*time.Hour).Unix())
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-55*24*time.Hour).Unix()))
|
||||
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
|
||||
_, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-60*24*time.Hour).Unix())
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-60*24*time.Hour).Unix()))
|
||||
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
waitFor(t, func() bool {
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
return len(subs) == 0
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
waitFor(t, func() bool {
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
return len(subs) == 0
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -281,11 +307,13 @@ func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLen
|
||||
require.Len(t, subs, expectedLength)
|
||||
}
|
||||
|
||||
func newTestConfigWithWebPush(t *testing.T) *Config {
|
||||
conf := newTestConfig(t)
|
||||
func newTestConfigWithWebPush(t *testing.T, databaseURL string) *Config {
|
||||
conf := newTestConfig(t, databaseURL)
|
||||
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
|
||||
require.Nil(t, err)
|
||||
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
|
||||
if conf.DatabaseURL == "" {
|
||||
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
|
||||
}
|
||||
conf.WebPushEmailAddress = "testing@example.com"
|
||||
conf.WebPushPrivateKey = privateKey
|
||||
conf.WebPushPublicKey = publicKey
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
type mailer interface {
|
||||
Send(v *visitor, m *message, to string) error
|
||||
Send(v *visitor, m *model.Message, to string) error
|
||||
Counts() (total int64, success int64, failure int64)
|
||||
}
|
||||
|
||||
@@ -27,7 +28,7 @@ type smtpSender struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (s *smtpSender) Send(v *visitor, m *message, to string) error {
|
||||
func (s *smtpSender) Send(v *visitor, m *model.Message, to string) error {
|
||||
return s.withCount(v, m, func() error {
|
||||
host, _, err := net.SplitHostPort(s.config.SMTPSenderAddr)
|
||||
if err != nil {
|
||||
@@ -63,7 +64,7 @@ func (s *smtpSender) Counts() (total int64, success int64, failure int64) {
|
||||
return s.success + s.failure, s.success, s.failure
|
||||
}
|
||||
|
||||
func (s *smtpSender) withCount(v *visitor, m *message, fn func() error) error {
|
||||
func (s *smtpSender) withCount(v *visitor, m *model.Message, fn func() error) error {
|
||||
err := fn()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -76,7 +77,7 @@ func (s *smtpSender) withCount(v *visitor, m *message, fn func() error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func formatMail(baseURL, senderIP, from, to string, m *message) (string, error) {
|
||||
func formatMail(baseURL, senderIP, from, to string, m *model.Message) (string, error) {
|
||||
topicURL := baseURL + "/" + m.Topic
|
||||
subject := m.Title
|
||||
if subject == "" {
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
func TestFormatMail_Basic(t *testing.T) {
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||
ID: "abc",
|
||||
Time: 1640382204,
|
||||
Event: "message",
|
||||
@@ -27,7 +29,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
||||
}
|
||||
|
||||
func TestFormatMail_JustEmojis(t *testing.T) {
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||
ID: "abc",
|
||||
Time: 1640382204,
|
||||
Event: "message",
|
||||
@@ -49,7 +51,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
||||
}
|
||||
|
||||
func TestFormatMail_JustOtherTags(t *testing.T) {
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||
ID: "abc",
|
||||
Time: 1640382204,
|
||||
Event: "message",
|
||||
@@ -73,7 +75,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
||||
}
|
||||
|
||||
func TestFormatMail_JustPriority(t *testing.T) {
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||
ID: "abc",
|
||||
Time: 1640382204,
|
||||
Event: "message",
|
||||
@@ -97,7 +99,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
||||
}
|
||||
|
||||
func TestFormatMail_UTF8Subject(t *testing.T) {
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||
ID: "abc",
|
||||
Time: 1640382204,
|
||||
Event: "message",
|
||||
@@ -119,7 +121,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
||||
}
|
||||
|
||||
func TestFormatMail_WithAllTheThings(t *testing.T) {
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||
ID: "abc",
|
||||
Time: 1640382204,
|
||||
Event: "message",
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/microcosm-cc/bluemonday"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -33,6 +34,7 @@ var (
|
||||
var (
|
||||
onlySpacesRegex = regexp.MustCompile(`(?m)^\s+$`)
|
||||
consecutiveNewLinesRegex = regexp.MustCompile(`\n{3,}`)
|
||||
htmlLineBreakRegex = regexp.MustCompile(`(?i)<br\s*/?>`)
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -158,7 +160,7 @@ func (s *smtpSession) Data(r io.Reader) error {
|
||||
if len(body) > conf.MessageSizeLimit {
|
||||
body = body[:conf.MessageSizeLimit]
|
||||
}
|
||||
m := newDefaultMessage(s.topic, body)
|
||||
m := model.NewDefaultMessage(s.topic, body)
|
||||
subject := strings.TrimSpace(msg.Header.Get("Subject"))
|
||||
if subject != "" {
|
||||
dec := mime.WordDecoder{}
|
||||
@@ -183,7 +185,7 @@ func (s *smtpSession) Data(r io.Reader) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *smtpSession) publishMessage(m *message) error {
|
||||
func (s *smtpSession) publishMessage(m *model.Message) error {
|
||||
// Extract remote address (for rate limiting)
|
||||
remoteAddr, _, err := net.SplitHostPort(s.conn.Conn().RemoteAddr().String())
|
||||
if err != nil {
|
||||
@@ -327,6 +329,9 @@ func readHTMLMailBody(reader io.Reader, transferEncoding string) (string, error)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Convert <br> tags to newlines before stripping HTML, so that line breaks
|
||||
// in HTML emails (e.g. from Synology DSM, and other appliances) are preserved.
|
||||
body = htmlLineBreakRegex.ReplaceAllString(body, "\n")
|
||||
stripped := bluemonday.
|
||||
StrictPolicy().
|
||||
AddSpaceWhenStrippingTag(true).
|
||||
|
||||
@@ -694,7 +694,8 @@ home automation setup
|
||||
Now the light is on
|
||||
|
||||
If you don't want to receive this message anymore, stop the push
|
||||
services in your FRITZ!Box .
|
||||
services in your FRITZ!Box .
|
||||
|
||||
Here you can see the active push services: "System > Push Service".
|
||||
|
||||
This mail has ben sent by your FRITZ!Box automatically.`
|
||||
@@ -1354,9 +1355,11 @@ Congratulations! You have successfully set up the email notification on Synology
|
||||
s, c, conf, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/synology", r.URL.Path)
|
||||
require.Equal(t, "[Synology NAS] Test Message from Litts_NAS", r.Header.Get("Title"))
|
||||
actual := readAll(t, r.Body)
|
||||
expected := `Congratulations! You have successfully set up the email notification on Synology_NAS. For further system configurations, please visit http://192.168.1.28:5000/, http://172.16.60.5:5000/. (If you cannot connect to the server, please contact the administrator.) From Synology_NAS`
|
||||
require.Equal(t, expected, actual)
|
||||
expected := "Congratulations! You have successfully set up the email notification on Synology_NAS.\n" +
|
||||
"For further system configurations, please visit http://192.168.1.28:5000/, http://172.16.60.5:5000/.\n" +
|
||||
"(If you cannot connect to the server, please contact the administrator.)\n\n" +
|
||||
"From Synology_NAS"
|
||||
require.Equal(t, expected, readAll(t, r.Body))
|
||||
})
|
||||
conf.SMTPServerDomain = "mydomain.me"
|
||||
conf.SMTPServerAddrPrefix = ""
|
||||
@@ -1365,6 +1368,36 @@ Congratulations! You have successfully set up the email notification on Synology
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_HTMLEmail_BrTagsPreserved(t *testing.T) {
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: nas@example.com
|
||||
RCPT TO: ntfy-alerts@ntfy.sh
|
||||
DATA
|
||||
Content-Type: text/html; charset=utf-8
|
||||
Content-Transfer-Encoding: 8bit
|
||||
Subject: Task Scheduler: daily-backup
|
||||
|
||||
Task Scheduler has completed a scheduled task.<BR><BR>Task: daily-backup<BR>Start time: Mon, 01 Jan 2026 02:00:00 +0000<BR>Stop time: Mon, 01 Jan 2024 02:03:00 +0000<BR>Current status: 0 (Normal)<BR>Standard output/error:<BR>OK<BR><BR>From MyNAS
|
||||
.
|
||||
`
|
||||
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/alerts", r.URL.Path)
|
||||
require.Equal(t, "Task Scheduler: daily-backup", r.Header.Get("Title"))
|
||||
expected := "Task Scheduler has completed a scheduled task.\n\n" +
|
||||
"Task: daily-backup\n" +
|
||||
"Start time: Mon, 01 Jan 2026 02:00:00 +0000\n" +
|
||||
"Stop time: Mon, 01 Jan 2024 02:03:00 +0000\n" +
|
||||
"Current status: 0 (Normal)\n" +
|
||||
"Standard output/error:\n" +
|
||||
"OK\n\n" +
|
||||
"From MyNAS"
|
||||
require.Equal(t, expected, readAll(t, r.Body))
|
||||
})
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_PlaintextWithToken(t *testing.T) {
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: phil@example.com
|
||||
@@ -1411,7 +1444,7 @@ what's up
|
||||
type smtpHandlerFunc func(http.ResponseWriter, *http.Request)
|
||||
|
||||
func newTestSMTPServer(t *testing.T, handler smtpHandlerFunc) (s *smtp.Server, c net.Conn, conf *Config, scanner *bufio.Scanner) {
|
||||
conf = newTestConfig(t)
|
||||
conf = newTestConfig(t, "")
|
||||
conf.SMTPServerListen = ":25"
|
||||
conf.SMTPServerDomain = "ntfy.sh"
|
||||
conf.SMTPServerAddrPrefix = "ntfy-"
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
@@ -33,7 +34,7 @@ type topicSubscriber struct {
|
||||
}
|
||||
|
||||
// subscriber is a function that is called for every new message on a topic
|
||||
type subscriber func(v *visitor, msg *message) error
|
||||
type subscriber func(v *visitor, msg *model.Message) error
|
||||
|
||||
// newTopic creates a new topic
|
||||
func newTopic(id string) *topic {
|
||||
@@ -103,7 +104,7 @@ func (t *topic) Unsubscribe(id int) {
|
||||
}
|
||||
|
||||
// Publish asynchronously publishes to all subscribers
|
||||
func (t *topic) Publish(v *visitor, m *message) error {
|
||||
func (t *topic) Publish(v *visitor, m *model.Message) error {
|
||||
go func() {
|
||||
// We want to lock the topic as short as possible, so we make a shallow copy of the
|
||||
// subscribers map here. Actually sending out the messages then doesn't have to lock.
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
|
||||
subFn := func(v *visitor, msg *message) error {
|
||||
subFn := func(v *visitor, msg *model.Message) error {
|
||||
return nil
|
||||
}
|
||||
canceled1 := atomic.Bool{}
|
||||
@@ -33,7 +34,7 @@ func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
|
||||
func TestTopic_CancelSubscribersUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
subFn := func(v *visitor, msg *message) error {
|
||||
subFn := func(v *visitor, msg *model.Message) error {
|
||||
return nil
|
||||
}
|
||||
canceled1 := atomic.Bool{}
|
||||
@@ -76,7 +77,7 @@ func TestTopic_Subscribe_DuplicateID(t *testing.T) {
|
||||
cancel: func() {},
|
||||
}
|
||||
|
||||
subFn := func(v *visitor, msg *message) error {
|
||||
subFn := func(v *visitor, msg *model.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
256
server/types.go
256
server/types.go
@@ -2,219 +2,35 @@ package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// List of possible events
|
||||
const (
|
||||
openEvent = "open"
|
||||
keepaliveEvent = "keepalive"
|
||||
messageEvent = "message"
|
||||
messageDeleteEvent = "message_delete"
|
||||
messageClearEvent = "message_clear"
|
||||
pollRequestEvent = "poll_request"
|
||||
)
|
||||
|
||||
const (
|
||||
messageIDLength = 12
|
||||
)
|
||||
|
||||
// message represents a message published to a topic
|
||||
type message struct {
|
||||
ID string `json:"id"` // Random message ID
|
||||
SequenceID string `json:"sequence_id,omitempty"` // Message sequence ID for updating message contents (omitted if same as ID)
|
||||
Time int64 `json:"time"` // Unix time in seconds
|
||||
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
|
||||
Event string `json:"event"` // One of the above
|
||||
Topic string `json:"topic"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Priority int `json:"priority,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Click string `json:"click,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
Actions []*action `json:"actions,omitempty"`
|
||||
Attachment *attachment `json:"attachment,omitempty"`
|
||||
PollID string `json:"poll_id,omitempty"`
|
||||
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
|
||||
Encoding string `json:"encoding,omitempty"` // Empty for raw UTF-8, or "base64" for encoded bytes
|
||||
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
|
||||
User string `json:"-"` // UserID of the uploader, used to associated attachments
|
||||
}
|
||||
|
||||
func (m *message) Context() log.Context {
|
||||
fields := map[string]any{
|
||||
"topic": m.Topic,
|
||||
"message_id": m.ID,
|
||||
"message_sequence_id": m.SequenceID,
|
||||
"message_time": m.Time,
|
||||
"message_event": m.Event,
|
||||
"message_body_size": len(m.Message),
|
||||
}
|
||||
if m.Sender.IsValid() {
|
||||
fields["message_sender"] = m.Sender.String()
|
||||
}
|
||||
if m.User != "" {
|
||||
fields["message_user"] = m.User
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// forJSON returns a copy of the message suitable for JSON output.
|
||||
// It clears the SequenceID if it equals the ID to reduce redundancy.
|
||||
func (m *message) forJSON() *message {
|
||||
if m.SequenceID == m.ID {
|
||||
clone := *m
|
||||
clone.SequenceID = ""
|
||||
return &clone
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
type attachment struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Size int64 `json:"size,omitempty"`
|
||||
Expires int64 `json:"expires,omitempty"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type action struct {
|
||||
ID string `json:"id"`
|
||||
Action string `json:"action"` // "view", "broadcast", "http", or "copy"
|
||||
Label string `json:"label"` // action button label
|
||||
Clear bool `json:"clear"` // clear notification after successful execution
|
||||
URL string `json:"url,omitempty"` // used in "view" and "http" actions
|
||||
Method string `json:"method,omitempty"` // used in "http" action, default is POST (!)
|
||||
Headers map[string]string `json:"headers,omitempty"` // used in "http" action
|
||||
Body string `json:"body,omitempty"` // used in "http" action
|
||||
Intent string `json:"intent,omitempty"` // used in "broadcast" action
|
||||
Extras map[string]string `json:"extras,omitempty"` // used in "broadcast" action
|
||||
Value string `json:"value,omitempty"` // used in "copy" action
|
||||
}
|
||||
|
||||
func newAction() *action {
|
||||
return &action{
|
||||
Headers: make(map[string]string),
|
||||
Extras: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// publishMessage is used as input when publishing as JSON
|
||||
type publishMessage struct {
|
||||
Topic string `json:"topic"`
|
||||
SequenceID string `json:"sequence_id"`
|
||||
Title string `json:"title"`
|
||||
Message string `json:"message"`
|
||||
Priority int `json:"priority"`
|
||||
Tags []string `json:"tags"`
|
||||
Click string `json:"click"`
|
||||
Icon string `json:"icon"`
|
||||
Actions []action `json:"actions"`
|
||||
Attach string `json:"attach"`
|
||||
Markdown bool `json:"markdown"`
|
||||
Filename string `json:"filename"`
|
||||
Email string `json:"email"`
|
||||
Call string `json:"call"`
|
||||
Cache string `json:"cache"` // use string as it defaults to true (or use &bool instead)
|
||||
Firebase string `json:"firebase"` // use string as it defaults to true (or use &bool instead)
|
||||
Delay string `json:"delay"`
|
||||
Topic string `json:"topic"`
|
||||
SequenceID string `json:"sequence_id"`
|
||||
Title string `json:"title"`
|
||||
Message string `json:"message"`
|
||||
Priority int `json:"priority"`
|
||||
Tags []string `json:"tags"`
|
||||
Click string `json:"click"`
|
||||
Icon string `json:"icon"`
|
||||
Actions []model.Action `json:"actions"`
|
||||
Attach string `json:"attach"`
|
||||
Markdown bool `json:"markdown"`
|
||||
Filename string `json:"filename"`
|
||||
Email string `json:"email"`
|
||||
Call string `json:"call"`
|
||||
Cache string `json:"cache"` // use string as it defaults to true (or use &bool instead)
|
||||
Firebase string `json:"firebase"` // use string as it defaults to true (or use &bool instead)
|
||||
Delay string `json:"delay"`
|
||||
}
|
||||
|
||||
// messageEncoder is a function that knows how to encode a message
|
||||
type messageEncoder func(msg *message) (string, error)
|
||||
|
||||
// newMessage creates a new message with the current timestamp
|
||||
func newMessage(event, topic, msg string) *message {
|
||||
return &message{
|
||||
ID: util.RandomString(messageIDLength),
|
||||
Time: time.Now().Unix(),
|
||||
Event: event,
|
||||
Topic: topic,
|
||||
Message: msg,
|
||||
}
|
||||
}
|
||||
|
||||
// newOpenMessage is a convenience method to create an open message
|
||||
func newOpenMessage(topic string) *message {
|
||||
return newMessage(openEvent, topic, "")
|
||||
}
|
||||
|
||||
// newKeepaliveMessage is a convenience method to create a keepalive message
|
||||
func newKeepaliveMessage(topic string) *message {
|
||||
return newMessage(keepaliveEvent, topic, "")
|
||||
}
|
||||
|
||||
// newDefaultMessage is a convenience method to create a notification message
|
||||
func newDefaultMessage(topic, msg string) *message {
|
||||
return newMessage(messageEvent, topic, msg)
|
||||
}
|
||||
|
||||
// newPollRequestMessage is a convenience method to create a poll request message
|
||||
func newPollRequestMessage(topic, pollID string) *message {
|
||||
m := newMessage(pollRequestEvent, topic, newMessageBody)
|
||||
m.PollID = pollID
|
||||
return m
|
||||
}
|
||||
|
||||
// newActionMessage creates a new action message (message_delete or message_clear)
|
||||
func newActionMessage(event, topic, sequenceID string) *message {
|
||||
m := newMessage(event, topic, "")
|
||||
m.SequenceID = sequenceID
|
||||
return m
|
||||
}
|
||||
|
||||
func validMessageID(s string) bool {
|
||||
return util.ValidRandomString(s, messageIDLength)
|
||||
}
|
||||
|
||||
type sinceMarker struct {
|
||||
time time.Time
|
||||
id string
|
||||
}
|
||||
|
||||
func newSinceTime(timestamp int64) sinceMarker {
|
||||
return sinceMarker{time.Unix(timestamp, 0), ""}
|
||||
}
|
||||
|
||||
func newSinceID(id string) sinceMarker {
|
||||
return sinceMarker{time.Unix(0, 0), id}
|
||||
}
|
||||
|
||||
func (t sinceMarker) IsAll() bool {
|
||||
return t == sinceAllMessages
|
||||
}
|
||||
|
||||
func (t sinceMarker) IsNone() bool {
|
||||
return t == sinceNoMessages
|
||||
}
|
||||
|
||||
func (t sinceMarker) IsLatest() bool {
|
||||
return t == sinceLatestMessage
|
||||
}
|
||||
|
||||
func (t sinceMarker) IsID() bool {
|
||||
return t.id != "" && t.id != "latest"
|
||||
}
|
||||
|
||||
func (t sinceMarker) Time() time.Time {
|
||||
return t.time
|
||||
}
|
||||
|
||||
func (t sinceMarker) ID() string {
|
||||
return t.id
|
||||
}
|
||||
|
||||
var (
|
||||
sinceAllMessages = sinceMarker{time.Unix(0, 0), ""}
|
||||
sinceNoMessages = sinceMarker{time.Unix(1, 0), ""}
|
||||
sinceLatestMessage = sinceMarker{time.Unix(0, 0), "latest"}
|
||||
)
|
||||
type messageEncoder func(msg *model.Message) (string, error)
|
||||
|
||||
type queryFilter struct {
|
||||
ID string
|
||||
@@ -246,8 +62,8 @@ func parseQueryFilters(r *http.Request) (*queryFilter, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (q *queryFilter) Pass(msg *message) bool {
|
||||
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
|
||||
func (q *queryFilter) Pass(msg *model.Message) bool {
|
||||
if msg.Event != model.MessageEvent && msg.Event != model.MessageDeleteEvent && msg.Event != model.MessageClearEvent {
|
||||
return true // filters only apply to messages
|
||||
} else if q.ID != "" && msg.ID != q.ID {
|
||||
return false
|
||||
@@ -320,6 +136,12 @@ type apiHealthResponse struct {
|
||||
Healthy bool `json:"healthy"`
|
||||
}
|
||||
|
||||
type apiVersionResponse struct {
|
||||
Version string `json:"version"`
|
||||
Commit string `json:"commit"`
|
||||
Date string `json:"date"`
|
||||
}
|
||||
|
||||
type apiStatsResponse struct {
|
||||
Messages int64 `json:"messages"`
|
||||
MessagesRate float64 `json:"messages_rate"` // Average number of messages per second
|
||||
@@ -564,12 +386,12 @@ const (
|
||||
)
|
||||
|
||||
type webPushPayload struct {
|
||||
Event string `json:"event"`
|
||||
SubscriptionID string `json:"subscription_id"`
|
||||
Message *message `json:"message"`
|
||||
Event string `json:"event"`
|
||||
SubscriptionID string `json:"subscription_id"`
|
||||
Message *model.Message `json:"message"`
|
||||
}
|
||||
|
||||
func newWebPushPayload(subscriptionID string, message *message) *webPushPayload {
|
||||
func newWebPushPayload(subscriptionID string, message *model.Message) *webPushPayload {
|
||||
return &webPushPayload{
|
||||
Event: webPushMessageEvent,
|
||||
SubscriptionID: subscriptionID,
|
||||
@@ -587,22 +409,6 @@ func newWebPushSubscriptionExpiringPayload() *webPushControlMessagePayload {
|
||||
}
|
||||
}
|
||||
|
||||
type webPushSubscription struct {
|
||||
ID string
|
||||
Endpoint string
|
||||
Auth string
|
||||
P256dh string
|
||||
UserID string
|
||||
}
|
||||
|
||||
func (w *webPushSubscription) Context() log.Context {
|
||||
return map[string]any{
|
||||
"web_push_subscription_id": w.ID,
|
||||
"web_push_subscription_user_id": w.UserID,
|
||||
"web_push_subscription_endpoint": w.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// https://developer.mozilla.org/en-US/docs/Web/Manifest
|
||||
type webManifestResponse struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
@@ -53,7 +54,7 @@ const (
|
||||
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
|
||||
type visitor struct {
|
||||
config *Config
|
||||
messageCache *messageCache
|
||||
messageCache *message.Cache
|
||||
userManager *user.Manager // May be nil
|
||||
ip netip.Addr // Visitor IP address
|
||||
user *user.User // Only set if authenticated user, otherwise nil
|
||||
@@ -114,7 +115,7 @@ const (
|
||||
visitorLimitBasisTier = visitorLimitBasis("tier")
|
||||
)
|
||||
|
||||
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
||||
func newVisitor(conf *Config, messageCache *message.Cache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
||||
var messages, emails, calls int64
|
||||
if user != nil {
|
||||
messages = user.Stats.Messages
|
||||
|
||||
@@ -1,285 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
const (
|
||||
subscriptionIDPrefix = "wps_"
|
||||
subscriptionIDLength = 10
|
||||
subscriptionEndpointLimitPerSubscriberIP = 10
|
||||
)
|
||||
|
||||
var (
|
||||
errWebPushNoRows = errors.New("no rows found")
|
||||
errWebPushTooManySubscriptions = errors.New("too many subscriptions")
|
||||
errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
|
||||
)
|
||||
|
||||
const (
|
||||
createWebPushSubscriptionsTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS subscription (
|
||||
id TEXT PRIMARY KEY,
|
||||
endpoint TEXT NOT NULL,
|
||||
key_auth TEXT NOT NULL,
|
||||
key_p256dh TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
subscriber_ip TEXT NOT NULL,
|
||||
updated_at INT NOT NULL,
|
||||
warned_at INT NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
|
||||
CREATE TABLE IF NOT EXISTS subscription_topic (
|
||||
subscription_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
PRIMARY KEY (subscription_id, topic),
|
||||
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
COMMIT;
|
||||
`
|
||||
builtinStartupQueries = `
|
||||
PRAGMA foreign_keys = ON;
|
||||
`
|
||||
|
||||
selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
|
||||
selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
|
||||
selectWebPushSubscriptionsForTopicQuery = `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription_topic st
|
||||
JOIN subscription s ON s.id = st.subscription_id
|
||||
WHERE st.topic = ?
|
||||
ORDER BY endpoint
|
||||
`
|
||||
selectWebPushSubscriptionsExpiringSoonQuery = `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription
|
||||
WHERE warned_at = 0 AND updated_at <= ?
|
||||
`
|
||||
insertWebPushSubscriptionQuery = `
|
||||
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (endpoint)
|
||||
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
||||
`
|
||||
updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
|
||||
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
|
||||
deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
|
||||
deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
|
||||
|
||||
insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
|
||||
deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
|
||||
deleteWebPushSubscriptionTopicWithoutSubscription = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
const (
|
||||
currentWebPushSchemaVersion = 1
|
||||
insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
type webPushStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func newWebPushStore(filename, startupQueries string) (*webPushStore, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupWebPushDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runWebPushStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &webPushStore{
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupWebPushDB(db *sql.DB) error {
|
||||
// If 'schemaVersion' table does not exist, this must be a new database
|
||||
rows, err := db.Query(selectWebPushSchemaVersionQuery)
|
||||
if err != nil {
|
||||
return setupNewWebPushDB(db)
|
||||
}
|
||||
return rows.Close()
|
||||
}
|
||||
|
||||
func setupNewWebPushDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runWebPushStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(builtinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
|
||||
// existing entries for a given endpoint.
|
||||
func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Read number of subscriptions for subscriber IP address
|
||||
rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rowsCount.Close()
|
||||
var subscriptionCount int
|
||||
if !rowsCount.Next() {
|
||||
return errWebPushNoRows
|
||||
}
|
||||
if err := rowsCount.Scan(&subscriptionCount); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rowsCount.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Read existing subscription ID for endpoint (or create new ID)
|
||||
rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
var subscriptionID string
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
|
||||
return errWebPushTooManySubscriptions
|
||||
}
|
||||
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert or update subscription
|
||||
updatedAt, warnedAt := time.Now().Unix(), 0
|
||||
if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
// Replace all subscription topics
|
||||
if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, topic := range topics {
|
||||
if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// SubscriptionsForTopic returns all subscriptions for the given topic
|
||||
func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
|
||||
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return c.subscriptionsFromRows(rows)
|
||||
}
|
||||
|
||||
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
|
||||
func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
|
||||
rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return c.subscriptionsFromRows(rows)
|
||||
}
|
||||
|
||||
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
|
||||
func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, subscription := range subscriptions {
|
||||
if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) {
|
||||
subscriptions := make([]*webPushSubscription, 0)
|
||||
for rows.Next() {
|
||||
var id, endpoint, auth, p256dh, userID string
|
||||
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subscriptions = append(subscriptions, &webPushSubscription{
|
||||
ID: id,
|
||||
Endpoint: endpoint,
|
||||
Auth: auth,
|
||||
P256dh: p256dh,
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
return subscriptions, nil
|
||||
}
|
||||
|
||||
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
|
||||
func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
|
||||
func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
|
||||
if userID == "" {
|
||||
return errWebPushUserIDCannotBeEmpty
|
||||
}
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
|
||||
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = c.db.Exec(deleteWebPushSubscriptionTopicWithoutSubscription)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection
|
||||
func (c *webPushStore) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
@@ -1,199 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWebPushStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
subs, err := webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
||||
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
||||
require.Equal(t, subs[0].Auth, "auth-key")
|
||||
require.Equal(t, subs[0].UserID, "u_1234")
|
||||
|
||||
subs2, err := webPush.SubscriptionsForTopic("mytopic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs2, 1)
|
||||
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
|
||||
}
|
||||
|
||||
func TestWebPushStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert 10 subscriptions with the same IP address
|
||||
for i := 0; i < 10; i++ {
|
||||
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
|
||||
require.Nil(t, webPush.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
}
|
||||
|
||||
// Another one for the same endpoint should be fine
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
// But with a different endpoint it should fail
|
||||
require.Equal(t, errWebPushTooManySubscriptions, webPush.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
// But with a different IP address it should be fine again
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
|
||||
}
|
||||
|
||||
func TestWebPushStore_UpsertSubscription_UpdateTopics(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics, and another with one topic
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
|
||||
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 2)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint)
|
||||
|
||||
subs, err = webPush.SubscriptionsForTopic("topic2")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
|
||||
// Update the first subscription to have only one topic
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 2)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
|
||||
subs, err = webPush.SubscriptionsForTopic("topic2")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveSubscriptionsByEndpoint(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// And remove it again
|
||||
require.Nil(t, webPush.RemoveSubscriptionsByEndpoint(testWebPushEndpoint))
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveSubscriptionsByUserID(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// And remove it again
|
||||
require.Nil(t, webPush.RemoveSubscriptionsByUserID("u_1234"))
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveSubscriptionsByUserID_Empty(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
require.Equal(t, errWebPushUserIDCannotBeEmpty, webPush.RemoveSubscriptionsByUserID(""))
|
||||
}
|
||||
|
||||
func TestWebPushStore_MarkExpiryWarningSent(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Mark them as warning sent
|
||||
require.Nil(t, webPush.MarkExpiryWarningSent(subs))
|
||||
|
||||
rows, err := webPush.db.Query("SELECT endpoint FROM subscription WHERE warned_at > 0")
|
||||
require.Nil(t, err)
|
||||
defer rows.Close()
|
||||
var endpoint string
|
||||
require.True(t, rows.Next())
|
||||
require.Nil(t, rows.Scan(&endpoint))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, testWebPushEndpoint, endpoint)
|
||||
require.False(t, rows.Next())
|
||||
}
|
||||
|
||||
func TestWebPushStore_SubscriptionsExpiring(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Fake-mark them as soon-to-expire
|
||||
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Should not be cleaned up yet
|
||||
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
|
||||
|
||||
// Run expiration
|
||||
subs, err = webPush.SubscriptionsExpiring(7 * 24 * time.Hour)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Fake-mark them as expired
|
||||
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Run expiration
|
||||
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
|
||||
|
||||
// List again, should be 0
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func newTestWebPushStore(t *testing.T) *webPushStore {
|
||||
webPush, err := newWebPushStore(filepath.Join(t.TempDir(), "webpush.db"), "")
|
||||
require.Nil(t, err)
|
||||
return webPush
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package test
|
||||
import (
|
||||
"fmt"
|
||||
"heckel.io/ntfy/v2/server"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -16,7 +16,7 @@ func StartServer(t *testing.T) (*server.Server, int) {
|
||||
|
||||
// StartServerWithConfig starts a server.Server with a random port and waits for the server to be up
|
||||
func StartServerWithConfig(t *testing.T, conf *server.Config) (*server.Server, int) {
|
||||
port := 10000 + rand.Intn(30000)
|
||||
port := findAvailablePort(t)
|
||||
conf.ListenHTTP = fmt.Sprintf(":%d", port)
|
||||
conf.AttachmentCacheDir = t.TempDir()
|
||||
conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")
|
||||
@@ -33,6 +33,17 @@ func StartServerWithConfig(t *testing.T, conf *server.Config) (*server.Server, i
|
||||
return s, port
|
||||
}
|
||||
|
||||
// findAvailablePort asks the OS for a free port by binding to :0
|
||||
func findAvailablePort(t *testing.T) int {
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
return port
|
||||
}
|
||||
|
||||
// StopServer stops the test server and waits for the port to be down
|
||||
func StopServer(t *testing.T, s *server.Server, port int) {
|
||||
s.Stop()
|
||||
|
||||
5
tools/loadtest/go.mod
Normal file
5
tools/loadtest/go.mod
Normal file
@@ -0,0 +1,5 @@
|
||||
module loadtest
|
||||
|
||||
go 1.25.2
|
||||
|
||||
require github.com/gorilla/websocket v1.5.3
|
||||
2
tools/loadtest/go.sum
Normal file
2
tools/loadtest/go.sum
Normal file
@@ -0,0 +1,2 @@
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
543
tools/loadtest/main.go
Normal file
543
tools/loadtest/main.go
Normal file
@@ -0,0 +1,543 @@
|
||||
// Load test program for ntfy staging server.
|
||||
// Replicates production traffic patterns derived from access.log analysis.
|
||||
//
|
||||
// Traffic profile (from ~5M requests over 20 hours):
|
||||
// ~71 req/sec average, ~4,300 req/min
|
||||
// 49.6% poll requests (GET /TOPIC/json?poll=1&since=ID)
|
||||
// 21.4% publish POST (POST /TOPIC with small body)
|
||||
// 6.2% subscribe stream (GET /TOPIC/json?since=X, long-lived)
|
||||
// 4.1% config check (GET /v1/config)
|
||||
// 2.3% other topic GET (GET /TOPIC)
|
||||
// 2.2% account check (GET /v1/account)
|
||||
// 1.9% websocket sub (GET /TOPIC/ws?since=X)
|
||||
// 1.5% publish PUT (PUT /TOPIC with small body)
|
||||
// 1.5% raw subscribe (GET /TOPIC/raw?since=X)
|
||||
// 1.1% json subscribe (GET /TOPIC/json, no since)
|
||||
// 0.7% SSE subscribe (GET /TOPIC/sse?since=X)
|
||||
// remaining: static, PATCH, OPTIONS, etc. (omitted)
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"math/big"
|
||||
mrand "math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
baseURL string
|
||||
username string
|
||||
password string
|
||||
rps float64
|
||||
scale float64
|
||||
numTopics int
|
||||
subStreams int
|
||||
wsStreams int
|
||||
sseStreams int
|
||||
rawStreams int
|
||||
duration time.Duration
|
||||
|
||||
totalRequests atomic.Int64
|
||||
totalErrors atomic.Int64
|
||||
activeStreams atomic.Int64
|
||||
|
||||
// Error tracking by category
|
||||
errMu sync.Mutex
|
||||
recentErrors []string // last N unique error messages
|
||||
errorCounts = make(map[string]int64)
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.StringVar(&baseURL, "url", "https://staging.ntfy.sh", "Base URL of ntfy server")
|
||||
flag.StringVar(&username, "user", "", "Username for authentication")
|
||||
flag.StringVar(&password, "pass", "", "Password for authentication")
|
||||
flag.Float64Var(&rps, "rps", 71, "Target requests per second (default: prod average)")
|
||||
flag.Float64Var(&scale, "scale", 1.0, "Scale factor for all load (0.5 = half load, 2.0 = double)")
|
||||
flag.IntVar(&numTopics, "topics", 500, "Number of unique topics to use")
|
||||
flag.IntVar(&subStreams, "sub-streams", 200, "Number of concurrent JSON streaming subscriptions")
|
||||
flag.IntVar(&wsStreams, "ws-streams", 50, "Number of concurrent WebSocket subscriptions")
|
||||
flag.IntVar(&sseStreams, "sse-streams", 20, "Number of concurrent SSE subscriptions")
|
||||
flag.IntVar(&rawStreams, "raw-streams", 30, "Number of concurrent raw subscriptions")
|
||||
flag.DurationVar(&duration, "duration", 10*time.Minute, "Test duration")
|
||||
flag.Parse()
|
||||
|
||||
rps *= scale
|
||||
subStreams = int(float64(subStreams) * scale)
|
||||
wsStreams = int(float64(wsStreams) * scale)
|
||||
sseStreams = int(float64(sseStreams) * scale)
|
||||
rawStreams = int(float64(rawStreams) * scale)
|
||||
|
||||
topics := generateTopics(numTopics)
|
||||
|
||||
fmt.Printf("ntfy load test\n")
|
||||
fmt.Printf(" Target: %s\n", baseURL)
|
||||
fmt.Printf(" RPS: %.1f\n", rps)
|
||||
fmt.Printf(" Scale: %.1fx\n", scale)
|
||||
fmt.Printf(" Topics: %d\n", numTopics)
|
||||
fmt.Printf(" Sub streams: %d json, %d ws, %d sse, %d raw\n", subStreams, wsStreams, sseStreams, rawStreams)
|
||||
fmt.Printf(" Duration: %s\n", duration)
|
||||
fmt.Println()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), duration)
|
||||
defer cancel()
|
||||
|
||||
// Also handle Ctrl+C
|
||||
sigCtx, sigCancel := signal.NotifyContext(ctx, os.Interrupt)
|
||||
defer sigCancel()
|
||||
ctx = sigCtx
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 1000,
|
||||
MaxIdleConnsPerHost: 1000,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
// Long-lived streaming client (no timeout)
|
||||
streamClient := &http.Client{
|
||||
Timeout: 0,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 500,
|
||||
MaxIdleConnsPerHost: 500,
|
||||
IdleConnTimeout: 0,
|
||||
},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Start long-lived streaming subscriptions
|
||||
for i := 0; i < subStreams; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
streamSubscription(ctx, streamClient, topics, "json")
|
||||
}()
|
||||
}
|
||||
for i := 0; i < wsStreams; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wsSubscription(ctx, topics)
|
||||
}()
|
||||
}
|
||||
for i := 0; i < sseStreams; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
streamSubscription(ctx, streamClient, topics, "sse")
|
||||
}()
|
||||
}
|
||||
for i := 0; i < rawStreams; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
streamSubscription(ctx, streamClient, topics, "raw")
|
||||
}()
|
||||
}
|
||||
|
||||
// Start request generators based on traffic weights
|
||||
// Weights from log analysis (normalized to sum ~100):
|
||||
// poll=49.6, publish_post=21.4, config=4.1, other_get=2.3, account=2.2, publish_put=1.5
|
||||
// Total short-lived weight ≈ 81.1
|
||||
type requestType struct {
|
||||
name string
|
||||
weight float64
|
||||
fn func(ctx context.Context, client *http.Client, topics []string)
|
||||
}
|
||||
|
||||
types := []requestType{
|
||||
{"poll", 49.6, doPoll},
|
||||
{"publish_post", 21.4, doPublishPost},
|
||||
{"config", 4.1, doConfig},
|
||||
{"other_get", 2.3, doOtherGet},
|
||||
{"account", 2.2, doAccountCheck},
|
||||
{"publish_put", 1.5, doPublishPut},
|
||||
}
|
||||
|
||||
totalWeight := 0.0
|
||||
for _, t := range types {
|
||||
totalWeight += t.weight
|
||||
}
|
||||
|
||||
for _, t := range types {
|
||||
t := t
|
||||
typeRPS := rps * (t.weight / totalWeight)
|
||||
if typeRPS < 0.1 {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
runAtRate(ctx, typeRPS, func() {
|
||||
t.fn(ctx, client, topics)
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
// Stats reporter
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
reportStats(ctx)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
fmt.Printf("\nDone. Total requests: %d, errors: %d\n", totalRequests.Load(), totalErrors.Load())
|
||||
}
|
||||
|
||||
func trackError(category string, err error) {
|
||||
totalErrors.Add(1)
|
||||
key := fmt.Sprintf("%s: %s", category, truncateErr(err))
|
||||
errMu.Lock()
|
||||
errorCounts[key]++
|
||||
errMu.Unlock()
|
||||
}
|
||||
|
||||
func trackErrorMsg(category string, msg string) {
|
||||
totalErrors.Add(1)
|
||||
key := fmt.Sprintf("%s: %s", category, msg)
|
||||
errMu.Lock()
|
||||
errorCounts[key]++
|
||||
errMu.Unlock()
|
||||
}
|
||||
|
||||
func truncateErr(err error) string {
|
||||
s := err.Error()
|
||||
if len(s) > 120 {
|
||||
s = s[:120] + "..."
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func setAuth(req *http.Request) {
|
||||
if username != "" && password != "" {
|
||||
req.SetBasicAuth(username, password)
|
||||
}
|
||||
}
|
||||
|
||||
func generateTopics(n int) []string {
|
||||
topics := make([]string, n)
|
||||
for i := 0; i < n; i++ {
|
||||
b := make([]byte, 8)
|
||||
rand.Read(b)
|
||||
topics[i] = "loadtest-" + hex.EncodeToString(b)
|
||||
}
|
||||
return topics
|
||||
}
|
||||
|
||||
func pickTopic(topics []string) string {
|
||||
n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(topics))))
|
||||
return topics[n.Int64()]
|
||||
}
|
||||
|
||||
func randomSince() string {
|
||||
b := make([]byte, 6)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func randomMessage() string {
|
||||
messages := []string{
|
||||
"Test notification",
|
||||
"Server backup completed successfully",
|
||||
"Deployment finished",
|
||||
"Alert: disk usage above 80%",
|
||||
"Build #1234 passed",
|
||||
"New order received",
|
||||
"Temperature sensor reading: 72F",
|
||||
"Cron job completed",
|
||||
}
|
||||
return messages[mrand.Intn(len(messages))]
|
||||
}
|
||||
|
||||
// runAtRate executes fn at approximately the given rate per second
|
||||
func runAtRate(ctx context.Context, rate float64, fn func()) {
|
||||
interval := time.Duration(float64(time.Second) / rate)
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
go fn()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Short-lived request types ---
|
||||
|
||||
func doPoll(ctx context.Context, client *http.Client, topics []string) {
|
||||
topic := pickTopic(topics)
|
||||
url := fmt.Sprintf("%s/%s/json?poll=1&since=%s", baseURL, topic, randomSince())
|
||||
doGet(ctx, client, url)
|
||||
}
|
||||
|
||||
func doPublishPost(ctx context.Context, client *http.Client, topics []string) {
|
||||
topic := pickTopic(topics)
|
||||
url := fmt.Sprintf("%s/%s", baseURL, topic)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(randomMessage()))
|
||||
if err != nil {
|
||||
trackError("publish_post_req", err)
|
||||
return
|
||||
}
|
||||
setAuth(req)
|
||||
// Some messages have titles/priorities like real traffic
|
||||
if mrand.Float32() < 0.3 {
|
||||
req.Header.Set("X-Title", "Load Test")
|
||||
}
|
||||
if mrand.Float32() < 0.1 {
|
||||
req.Header.Set("X-Priority", fmt.Sprintf("%d", mrand.Intn(5)+1))
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
totalRequests.Add(1)
|
||||
if err != nil {
|
||||
trackError("publish_post", err)
|
||||
return
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
trackErrorMsg("publish_post_http", fmt.Sprintf("status %d", resp.StatusCode))
|
||||
}
|
||||
}
|
||||
|
||||
func doPublishPut(ctx context.Context, client *http.Client, topics []string) {
|
||||
topic := pickTopic(topics)
|
||||
url := fmt.Sprintf("%s/%s", baseURL, topic)
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", url, strings.NewReader(randomMessage()))
|
||||
if err != nil {
|
||||
trackError("publish_put_req", err)
|
||||
return
|
||||
}
|
||||
setAuth(req)
|
||||
resp, err := client.Do(req)
|
||||
totalRequests.Add(1)
|
||||
if err != nil {
|
||||
trackError("publish_put", err)
|
||||
return
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
trackErrorMsg("publish_put_http", fmt.Sprintf("status %d", resp.StatusCode))
|
||||
}
|
||||
}
|
||||
|
||||
func doConfig(ctx context.Context, client *http.Client, topics []string) {
|
||||
url := fmt.Sprintf("%s/v1/config", baseURL)
|
||||
doGet(ctx, client, url)
|
||||
}
|
||||
|
||||
func doAccountCheck(ctx context.Context, client *http.Client, topics []string) {
|
||||
url := fmt.Sprintf("%s/v1/account", baseURL)
|
||||
doGet(ctx, client, url)
|
||||
}
|
||||
|
||||
func doOtherGet(ctx context.Context, client *http.Client, topics []string) {
|
||||
topic := pickTopic(topics)
|
||||
url := fmt.Sprintf("%s/%s", baseURL, topic)
|
||||
doGet(ctx, client, url)
|
||||
}
|
||||
|
||||
func doGet(ctx context.Context, client *http.Client, url string) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
trackError("get_req", err)
|
||||
return
|
||||
}
|
||||
setAuth(req)
|
||||
resp, err := client.Do(req)
|
||||
totalRequests.Add(1)
|
||||
if err != nil {
|
||||
trackError("get", err)
|
||||
return
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
trackErrorMsg("get_http", fmt.Sprintf("status %d for %s", resp.StatusCode, url))
|
||||
}
|
||||
}
|
||||
|
||||
// --- Long-lived streaming subscriptions ---
|
||||
|
||||
func streamSubscription(ctx context.Context, client *http.Client, topics []string, format string) {
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
topic := pickTopic(topics)
|
||||
url := fmt.Sprintf("%s/%s/%s?since=all", baseURL, topic, format)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
setAuth(req)
|
||||
activeStreams.Add(1)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
activeStreams.Add(-1)
|
||||
if ctx.Err() == nil {
|
||||
trackError("stream_"+format+"_connect", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
trackErrorMsg("stream_"+format+"_http", fmt.Sprintf("status %d", resp.StatusCode))
|
||||
resp.Body.Close()
|
||||
activeStreams.Add(-1)
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
// Read from stream until context cancelled or connection drops
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
_, err := resp.Body.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() == nil {
|
||||
trackError("stream_"+format+"_read", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
resp.Body.Close()
|
||||
activeStreams.Add(-1)
|
||||
// Reconnect with small delay (like real clients do)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(time.Duration(mrand.Intn(3000)) * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func wsSubscription(ctx context.Context, topics []string) {
|
||||
wsURL := strings.Replace(baseURL, "https://", "wss://", 1)
|
||||
wsURL = strings.Replace(wsURL, "http://", "ws://", 1)
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
topic := pickTopic(topics)
|
||||
url := fmt.Sprintf("%s/%s/ws?since=all", wsURL, topic)
|
||||
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
var wsHeader http.Header
|
||||
if username != "" && password != "" {
|
||||
wsHeader = http.Header{}
|
||||
req, _ := http.NewRequest("GET", url, nil)
|
||||
req.SetBasicAuth(username, password)
|
||||
wsHeader.Set("Authorization", req.Header.Get("Authorization"))
|
||||
}
|
||||
activeStreams.Add(1)
|
||||
conn, _, err := dialer.DialContext(ctx, url, wsHeader)
|
||||
if err != nil {
|
||||
activeStreams.Add(-1)
|
||||
if ctx.Err() == nil {
|
||||
trackError("ws_connect", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
// Read messages until context cancelled or error
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
conn.Close()
|
||||
activeStreams.Add(-1)
|
||||
return
|
||||
case <-done:
|
||||
conn.Close()
|
||||
activeStreams.Add(-1)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(time.Duration(mrand.Intn(3000)) * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func reportStats(ctx context.Context) {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastRequests, lastErrors int64
|
||||
lastTime := time.Now()
|
||||
reportCount := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
currentRequests := totalRequests.Load()
|
||||
currentErrors := totalErrors.Load()
|
||||
elapsed := now.Sub(lastTime).Seconds()
|
||||
currentRPS := float64(currentRequests-lastRequests) / elapsed
|
||||
errorRate := float64(currentErrors-lastErrors) / elapsed
|
||||
|
||||
fmt.Printf("[%s] rps=%.1f err/s=%.1f total=%d errors=%d streams=%d\n",
|
||||
now.Format("15:04:05"),
|
||||
currentRPS,
|
||||
errorRate,
|
||||
currentRequests,
|
||||
currentErrors,
|
||||
activeStreams.Load(),
|
||||
)
|
||||
|
||||
// Print error breakdown every 30 seconds
|
||||
reportCount++
|
||||
if reportCount%6 == 0 && currentErrors > 0 {
|
||||
errMu.Lock()
|
||||
fmt.Printf(" Error breakdown:\n")
|
||||
for k, v := range errorCounts {
|
||||
fmt.Printf(" %s: %d\n", k, v)
|
||||
}
|
||||
errMu.Unlock()
|
||||
}
|
||||
|
||||
lastRequests = currentRequests
|
||||
lastErrors = currentErrors
|
||||
lastTime = now
|
||||
}
|
||||
}
|
||||
}
|
||||
48
tools/pgimport/README.md
Normal file
48
tools/pgimport/README.md
Normal file
@@ -0,0 +1,48 @@
|
||||
# pgimport
|
||||
|
||||
One-off migration script to import ntfy data from SQLite to PostgreSQL.
|
||||
|
||||
This is **not** a generic migration tool. It only works with specific SQLite schema versions
|
||||
(message cache v14, user db v6, web push v1) and their corresponding PostgreSQL schemas.
|
||||
If your database versions differ, this tool will refuse to run.
|
||||
|
||||
## Build
|
||||
|
||||
```bash
|
||||
go build -o pgimport ./tools/pgimport/
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Using CLI flags
|
||||
pgimport \
|
||||
--database-url "postgres://user:pass@host:5432/ntfy?sslmode=require" \
|
||||
--cache-file /var/cache/ntfy/cache.db \
|
||||
--auth-file /var/lib/ntfy/user.db \
|
||||
--web-push-file /var/lib/ntfy/webpush.db
|
||||
|
||||
# Using --create-schema to set up PostgreSQL schema automatically
|
||||
pgimport \
|
||||
--create-schema \
|
||||
--database-url "postgres://user:pass@host:5432/ntfy?sslmode=require" \
|
||||
--cache-file /var/cache/ntfy/cache.db \
|
||||
--auth-file /var/lib/ntfy/user.db \
|
||||
--web-push-file /var/lib/ntfy/webpush.db
|
||||
|
||||
# Using server.yml (flags override config values)
|
||||
pgimport --config /etc/ntfy/server.yml
|
||||
```
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- PostgreSQL schema must already be set up, either by running ntfy with `database-url` once,
|
||||
or by passing `--create-schema` to pgimport to create the initial schema automatically
|
||||
- ntfy must not be running during the import
|
||||
- All three SQLite files are optional; only the ones specified will be imported
|
||||
|
||||
## Notes
|
||||
|
||||
- The tool is idempotent and safe to re-run
|
||||
- After importing, row counts and content are verified against the SQLite sources
|
||||
- Invalid UTF-8 in messages is replaced with the Unicode replacement character
|
||||
1164
tools/pgimport/main.go
Normal file
1164
tools/pgimport/main.go
Normal file
File diff suppressed because it is too large
Load Diff
141
tools/s3cli/main.go
Normal file
141
tools/s3cli/main.go
Normal file
@@ -0,0 +1,141 @@
|
||||
// Command s3cli is a minimal CLI for testing the s3 package. It supports put, get, rm, and ls.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// export S3_URL="s3://ACCESS_KEY:SECRET_KEY@BUCKET/PREFIX?region=REGION&endpoint=ENDPOINT"
|
||||
//
|
||||
// s3cli put <key> <file> Upload a file
|
||||
// s3cli put <key> - Upload from stdin
|
||||
// s3cli get <key> Download to stdout
|
||||
// s3cli rm <key> [<key>...] Delete one or more objects
|
||||
// s3cli ls List all objects
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"text/tabwriter"
|
||||
|
||||
"heckel.io/ntfy/v2/s3"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
usage()
|
||||
}
|
||||
s3URL := os.Getenv("S3_URL")
|
||||
if s3URL == "" {
|
||||
fail("S3_URL environment variable is required")
|
||||
}
|
||||
cfg, err := s3.ParseURL(s3URL)
|
||||
if err != nil {
|
||||
fail("invalid S3_URL: %s", err)
|
||||
}
|
||||
client := s3.New(cfg)
|
||||
ctx := context.Background()
|
||||
|
||||
switch os.Args[1] {
|
||||
case "put":
|
||||
cmdPut(ctx, client)
|
||||
case "get":
|
||||
cmdGet(ctx, client)
|
||||
case "rm":
|
||||
cmdRm(ctx, client)
|
||||
case "ls":
|
||||
cmdLs(ctx, client)
|
||||
default:
|
||||
usage()
|
||||
}
|
||||
}
|
||||
|
||||
func cmdPut(ctx context.Context, client *s3.Client) {
|
||||
if len(os.Args) != 4 {
|
||||
fail("usage: s3cli put <key> <file|->\n")
|
||||
}
|
||||
key := os.Args[2]
|
||||
path := os.Args[3]
|
||||
|
||||
var r io.Reader
|
||||
if path == "-" {
|
||||
r = os.Stdin
|
||||
} else {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
fail("open %s: %s", path, err)
|
||||
}
|
||||
defer f.Close()
|
||||
r = f
|
||||
}
|
||||
|
||||
if err := client.PutObject(ctx, key, r); err != nil {
|
||||
fail("put: %s", err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "uploaded %s\n", key)
|
||||
}
|
||||
|
||||
func cmdGet(ctx context.Context, client *s3.Client) {
|
||||
if len(os.Args) != 3 {
|
||||
fail("usage: s3cli get <key>\n")
|
||||
}
|
||||
key := os.Args[2]
|
||||
|
||||
reader, size, err := client.GetObject(ctx, key)
|
||||
if err != nil {
|
||||
fail("get: %s", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
n, err := io.Copy(os.Stdout, reader)
|
||||
if err != nil {
|
||||
fail("read: %s", err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "downloaded %s (%d bytes, content-length: %d)\n", key, n, size)
|
||||
}
|
||||
|
||||
func cmdRm(ctx context.Context, client *s3.Client) {
|
||||
if len(os.Args) < 3 {
|
||||
fail("usage: s3cli rm <key> [<key>...]\n")
|
||||
}
|
||||
keys := os.Args[2:]
|
||||
if err := client.DeleteObjects(ctx, keys); err != nil {
|
||||
fail("rm: %s", err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "deleted %d object(s)\n", len(keys))
|
||||
}
|
||||
|
||||
func cmdLs(ctx context.Context, client *s3.Client) {
|
||||
objects, err := client.ListAllObjects(ctx)
|
||||
if err != nil {
|
||||
fail("ls: %s", err)
|
||||
}
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
var totalSize int64
|
||||
for _, obj := range objects {
|
||||
fmt.Fprintf(w, "%d\t%s\n", obj.Size, obj.Key)
|
||||
totalSize += obj.Size
|
||||
}
|
||||
w.Flush()
|
||||
fmt.Fprintf(os.Stderr, "%d object(s), %d bytes total\n", len(objects), totalSize)
|
||||
}
|
||||
|
||||
func usage() {
|
||||
fmt.Fprintf(os.Stderr, `Usage: s3cli <command> [args...]
|
||||
|
||||
Commands:
|
||||
put <key> <file|-> Upload a file (use - for stdin)
|
||||
get <key> Download to stdout
|
||||
rm <key> [keys...] Delete objects
|
||||
ls List all objects
|
||||
|
||||
Environment:
|
||||
S3_URL S3 connection URL (required)
|
||||
s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
|
||||
`)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func fail(format string, args ...any) {
|
||||
fmt.Fprintf(os.Stderr, format+"\n", args...)
|
||||
os.Exit(1)
|
||||
}
|
||||
2316
user/manager.go
2316
user/manager.go
File diff suppressed because it is too large
Load Diff
286
user/manager_postgres.go
Normal file
286
user/manager_postgres.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"heckel.io/ntfy/v2/db"
|
||||
)
|
||||
|
||||
// PostgreSQL queries
|
||||
const (
|
||||
// User queries
|
||||
postgresSelectUsersQuery = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
ORDER BY
|
||||
CASE u.role
|
||||
WHEN 'admin' THEN 1
|
||||
WHEN 'anonymous' THEN 3
|
||||
ELSE 2
|
||||
END, u.user_name
|
||||
`
|
||||
postgresSelectUserByIDQuery = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.id = $1
|
||||
`
|
||||
postgresSelectUserByNameQuery = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE user_name = $1
|
||||
`
|
||||
postgresSelectUserByTokenQuery = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
JOIN user_token tk on u.id = tk.user_id
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE tk.token = $1 AND (tk.expires = 0 OR tk.expires >= $2)
|
||||
`
|
||||
postgresSelectUserByStripeCustomerIDQuery = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.stripe_customer_id = $1
|
||||
`
|
||||
postgresSelectUsernamesQuery = `
|
||||
SELECT user_name
|
||||
FROM "user"
|
||||
ORDER BY
|
||||
CASE role
|
||||
WHEN 'admin' THEN 1
|
||||
WHEN 'anonymous' THEN 3
|
||||
ELSE 2
|
||||
END, user_name
|
||||
`
|
||||
postgresSelectUserCountQuery = `SELECT COUNT(*) FROM "user"`
|
||||
postgresSelectUserIDFromUsernameQuery = `SELECT id FROM "user" WHERE user_name = $1`
|
||||
postgresInsertUserQuery = `INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) VALUES ($1, $2, $3, $4, $5, $6, $7)`
|
||||
postgresUpdateUserPassQuery = `UPDATE "user" SET pass = $1 WHERE user_name = $2`
|
||||
postgresUpdateUserRoleQuery = `UPDATE "user" SET role = $1 WHERE user_name = $2`
|
||||
postgresUpdateUserProvisionedQuery = `UPDATE "user" SET provisioned = $1 WHERE user_name = $2`
|
||||
postgresUpdateUserPrefsQuery = `UPDATE "user" SET prefs = $1 WHERE id = $2`
|
||||
postgresUpdateUserStatsQuery = `UPDATE "user" SET stats_messages = $1, stats_emails = $2, stats_calls = $3 WHERE id = $4`
|
||||
postgresUpdateUserStatsResetAllQuery = `UPDATE "user" SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||
postgresUpdateUserTierQuery = `UPDATE "user" SET tier_id = (SELECT id FROM tier WHERE code = $1) WHERE user_name = $2`
|
||||
postgresUpdateUserDeletedQuery = `UPDATE "user" SET deleted = $1 WHERE id = $2`
|
||||
postgresDeleteUserQuery = `DELETE FROM "user" WHERE user_name = $1`
|
||||
postgresDeleteUserTierQuery = `UPDATE "user" SET tier_id = null WHERE user_name = $1`
|
||||
postgresDeleteUsersMarkedQuery = `DELETE FROM "user" WHERE deleted < $1`
|
||||
postgresDeleteUsersProvisionedQuery = `DELETE FROM "user" WHERE provisioned = true`
|
||||
|
||||
// Access queries
|
||||
postgresSelectTopicPermsQuery = `
|
||||
SELECT read, write
|
||||
FROM user_access a
|
||||
JOIN "user" u ON u.id = a.user_id
|
||||
WHERE (u.user_name = $1 OR u.user_name = $2) AND $3 LIKE a.topic ESCAPE '\'
|
||||
ORDER BY u.user_name DESC, LENGTH(a.topic) DESC, CASE WHEN a.write THEN 1 ELSE 0 END DESC
|
||||
`
|
||||
postgresSelectUserAllAccessQuery = `
|
||||
SELECT user_id, topic, read, write, provisioned
|
||||
FROM user_access
|
||||
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
|
||||
`
|
||||
postgresSelectUserAccessQuery = `
|
||||
SELECT topic, read, write, provisioned
|
||||
FROM user_access
|
||||
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
|
||||
`
|
||||
postgresSelectUserReservationsQuery = `
|
||||
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
|
||||
FROM user_access a_user
|
||||
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
WHERE a_user.user_id = a_user.owner_user_id
|
||||
AND a_user.owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
|
||||
ORDER BY a_user.topic
|
||||
`
|
||||
postgresSelectUserReservationsCountQuery = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
`
|
||||
postgresSelectUserReservationsOwnerQuery = `
|
||||
SELECT owner_user_id
|
||||
FROM user_access
|
||||
WHERE topic = $1
|
||||
AND user_id = owner_user_id
|
||||
`
|
||||
postgresSelectUserHasReservationQuery = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
AND topic = $2
|
||||
`
|
||||
postgresSelectOtherAccessCountQuery = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE (topic = $1 OR $2 LIKE topic ESCAPE '\')
|
||||
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM "user" WHERE user_name = $3))
|
||||
`
|
||||
postgresUpsertUserAccessQuery = `
|
||||
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
||||
VALUES (
|
||||
(SELECT id FROM "user" WHERE user_name = $1),
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
CASE WHEN $5 = '' THEN NULL ELSE (SELECT id FROM "user" WHERE user_name = $6) END,
|
||||
$7
|
||||
)
|
||||
ON CONFLICT (user_id, topic)
|
||||
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
|
||||
`
|
||||
postgresDeleteUserAccessQuery = `
|
||||
DELETE FROM user_access
|
||||
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
|
||||
`
|
||||
postgresDeleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = true`
|
||||
postgresDeleteTopicAccessQuery = `
|
||||
DELETE FROM user_access
|
||||
WHERE (user_id = (SELECT id FROM "user" WHERE user_name = $1) OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2))
|
||||
AND topic = $3
|
||||
`
|
||||
postgresDeleteAllAccessQuery = `DELETE FROM user_access`
|
||||
|
||||
// Token queries
|
||||
postgresSelectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1 AND token = $2`
|
||||
postgresSelectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1`
|
||||
postgresSelectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = $1`
|
||||
postgresSelectAllProvisionedTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = true`
|
||||
postgresUpsertTokenQuery = `
|
||||
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (user_id, token)
|
||||
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned
|
||||
`
|
||||
postgresUpdateTokenQuery = `UPDATE user_token SET label = $1, expires = $2 WHERE user_id = $3 AND token = $4`
|
||||
postgresUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
|
||||
postgresDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
|
||||
postgresDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = $1`
|
||||
postgresDeleteAllProvisionedTokensQuery = `DELETE FROM user_token WHERE provisioned = true`
|
||||
postgresDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = $1`
|
||||
postgresDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
|
||||
postgresDeleteExcessTokensQuery = `
|
||||
DELETE FROM user_token
|
||||
WHERE user_id = $1
|
||||
AND (user_id, token) NOT IN (
|
||||
SELECT user_id, token
|
||||
FROM user_token
|
||||
WHERE user_id = $2
|
||||
ORDER BY expires DESC
|
||||
LIMIT $3
|
||||
)
|
||||
`
|
||||
|
||||
// Tier queries
|
||||
postgresInsertTierQuery = `
|
||||
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
`
|
||||
postgresUpdateTierQuery = `
|
||||
UPDATE tier
|
||||
SET name = $1, messages_limit = $2, messages_expiry_duration = $3, emails_limit = $4, calls_limit = $5, reservations_limit = $6, attachment_file_size_limit = $7, attachment_total_size_limit = $8, attachment_expiry_duration = $9, attachment_bandwidth_limit = $10, stripe_monthly_price_id = $11, stripe_yearly_price_id = $12
|
||||
WHERE code = $13
|
||||
`
|
||||
postgresSelectTiersQuery = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
`
|
||||
postgresSelectTierByCodeQuery = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE code = $1
|
||||
`
|
||||
postgresSelectTierByPriceIDQuery = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE (stripe_monthly_price_id = $1 OR stripe_yearly_price_id = $2)
|
||||
`
|
||||
postgresDeleteTierQuery = `DELETE FROM tier WHERE code = $1`
|
||||
|
||||
// Phone queries
|
||||
postgresSelectPhoneNumbersQuery = `SELECT phone_number FROM user_phone WHERE user_id = $1`
|
||||
postgresInsertPhoneNumberQuery = `INSERT INTO user_phone (user_id, phone_number) VALUES ($1, $2)`
|
||||
postgresDeletePhoneNumberQuery = `DELETE FROM user_phone WHERE user_id = $1 AND phone_number = $2`
|
||||
|
||||
// Billing queries
|
||||
postgresUpdateBillingQuery = `
|
||||
UPDATE "user"
|
||||
SET stripe_customer_id = $1, stripe_subscription_id = $2, stripe_subscription_status = $3, stripe_subscription_interval = $4, stripe_subscription_paid_until = $5, stripe_subscription_cancel_at = $6
|
||||
WHERE user_name = $7
|
||||
`
|
||||
)
|
||||
|
||||
// NewPostgresManager creates a new Manager backed by a PostgreSQL database using an existing connection pool.
|
||||
var postgresQueries = queries{
|
||||
selectUserByID: postgresSelectUserByIDQuery,
|
||||
selectUserByName: postgresSelectUserByNameQuery,
|
||||
selectUserByToken: postgresSelectUserByTokenQuery,
|
||||
selectUserByStripeCustomerID: postgresSelectUserByStripeCustomerIDQuery,
|
||||
selectUsernames: postgresSelectUsernamesQuery,
|
||||
selectUsers: postgresSelectUsersQuery,
|
||||
selectUserCount: postgresSelectUserCountQuery,
|
||||
selectUserIDFromUsername: postgresSelectUserIDFromUsernameQuery,
|
||||
insertUser: postgresInsertUserQuery,
|
||||
updateUserPass: postgresUpdateUserPassQuery,
|
||||
updateUserRole: postgresUpdateUserRoleQuery,
|
||||
updateUserProvisioned: postgresUpdateUserProvisionedQuery,
|
||||
updateUserPrefs: postgresUpdateUserPrefsQuery,
|
||||
updateUserStats: postgresUpdateUserStatsQuery,
|
||||
updateUserStatsResetAll: postgresUpdateUserStatsResetAllQuery,
|
||||
updateUserTier: postgresUpdateUserTierQuery,
|
||||
updateUserDeleted: postgresUpdateUserDeletedQuery,
|
||||
deleteUser: postgresDeleteUserQuery,
|
||||
deleteUserTier: postgresDeleteUserTierQuery,
|
||||
deleteUsersMarked: postgresDeleteUsersMarkedQuery,
|
||||
deleteUsersProvisioned: postgresDeleteUsersProvisionedQuery,
|
||||
selectTopicPerms: postgresSelectTopicPermsQuery,
|
||||
selectUserAllAccess: postgresSelectUserAllAccessQuery,
|
||||
selectUserAccess: postgresSelectUserAccessQuery,
|
||||
selectUserReservations: postgresSelectUserReservationsQuery,
|
||||
selectUserReservationsCount: postgresSelectUserReservationsCountQuery,
|
||||
selectUserReservationsOwner: postgresSelectUserReservationsOwnerQuery,
|
||||
selectUserHasReservation: postgresSelectUserHasReservationQuery,
|
||||
selectOtherAccessCount: postgresSelectOtherAccessCountQuery,
|
||||
upsertUserAccess: postgresUpsertUserAccessQuery,
|
||||
deleteUserAccess: postgresDeleteUserAccessQuery,
|
||||
deleteUserAccessProvisioned: postgresDeleteUserAccessProvisionedQuery,
|
||||
deleteTopicAccess: postgresDeleteTopicAccessQuery,
|
||||
deleteAllAccess: postgresDeleteAllAccessQuery,
|
||||
selectToken: postgresSelectTokenQuery,
|
||||
selectTokens: postgresSelectTokensQuery,
|
||||
selectTokenCount: postgresSelectTokenCountQuery,
|
||||
selectAllProvisionedTokens: postgresSelectAllProvisionedTokensQuery,
|
||||
upsertToken: postgresUpsertTokenQuery,
|
||||
updateToken: postgresUpdateTokenQuery,
|
||||
updateTokenLastAccess: postgresUpdateTokenLastAccessQuery,
|
||||
deleteToken: postgresDeleteTokenQuery,
|
||||
deleteProvisionedToken: postgresDeleteProvisionedTokenQuery,
|
||||
deleteAllProvisionedTokens: postgresDeleteAllProvisionedTokensQuery,
|
||||
deleteAllToken: postgresDeleteAllTokenQuery,
|
||||
deleteExpiredTokens: postgresDeleteExpiredTokensQuery,
|
||||
deleteExcessTokens: postgresDeleteExcessTokensQuery,
|
||||
insertTier: postgresInsertTierQuery,
|
||||
selectTiers: postgresSelectTiersQuery,
|
||||
selectTierByCode: postgresSelectTierByCodeQuery,
|
||||
selectTierByPriceID: postgresSelectTierByPriceIDQuery,
|
||||
updateTier: postgresUpdateTierQuery,
|
||||
deleteTier: postgresDeleteTierQuery,
|
||||
selectPhoneNumbers: postgresSelectPhoneNumbersQuery,
|
||||
insertPhoneNumber: postgresInsertPhoneNumberQuery,
|
||||
deletePhoneNumber: postgresDeletePhoneNumberQuery,
|
||||
updateBilling: postgresUpdateBillingQuery,
|
||||
}
|
||||
|
||||
// NewPostgresManager creates a new Manager backed by a PostgreSQL database
|
||||
func NewPostgresManager(d *db.DB, config *Config) (*Manager, error) {
|
||||
if err := setupPostgres(d.Primary()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newManager(d, postgresQueries, config)
|
||||
}
|
||||
113
user/manager_postgres_schema.go
Normal file
113
user/manager_postgres_schema.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Initial PostgreSQL schema
|
||||
const (
|
||||
postgresCreateTablesQueries = `
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit BIGINT NOT NULL,
|
||||
messages_expiry_duration BIGINT NOT NULL,
|
||||
emails_limit BIGINT NOT NULL,
|
||||
calls_limit BIGINT NOT NULL,
|
||||
reservations_limit BIGINT NOT NULL,
|
||||
attachment_file_size_limit BIGINT NOT NULL,
|
||||
attachment_total_size_limit BIGINT NOT NULL,
|
||||
attachment_expiry_duration BIGINT NOT NULL,
|
||||
attachment_bandwidth_limit BIGINT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT,
|
||||
UNIQUE(code),
|
||||
UNIQUE(stripe_monthly_price_id),
|
||||
UNIQUE(stripe_yearly_price_id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "user" (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT REFERENCES tier(id),
|
||||
user_name TEXT NOT NULL UNIQUE,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')),
|
||||
prefs JSONB NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
stats_messages BIGINT NOT NULL DEFAULT 0,
|
||||
stats_emails BIGINT NOT NULL DEFAULT 0,
|
||||
stats_calls BIGINT NOT NULL DEFAULT 0,
|
||||
stripe_customer_id TEXT UNIQUE,
|
||||
stripe_subscription_id TEXT UNIQUE,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until BIGINT,
|
||||
stripe_subscription_cancel_at BIGINT,
|
||||
created BIGINT NOT NULL,
|
||||
deleted BIGINT
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
topic TEXT NOT NULL,
|
||||
read BOOLEAN NOT NULL,
|
||||
write BOOLEAN NOT NULL,
|
||||
owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (user_id, topic)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
label TEXT NOT NULL,
|
||||
last_access BIGINT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires BIGINT NOT NULL,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (user_id, token)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
)
|
||||
|
||||
// Schema table management queries for Postgres
|
||||
const (
|
||||
postgresCurrentSchemaVersion = 6
|
||||
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'user'`
|
||||
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('user', $1)`
|
||||
)
|
||||
|
||||
func setupPostgres(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewPostgres(db)
|
||||
}
|
||||
if schemaVersion > postgresCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
|
||||
}
|
||||
// Note: PostgreSQL migrations will be added when needed
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPostgres(db *sql.DB) error {
|
||||
if _, err := db.Exec(postgresCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(postgresInsertSchemaVersionQuery, postgresCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
295
user/manager_sqlite.go
Normal file
295
user/manager_sqlite.go
Normal file
@@ -0,0 +1,295 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// User queries
|
||||
sqliteSelectUsersQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
ORDER BY
|
||||
CASE u.role
|
||||
WHEN 'admin' THEN 1
|
||||
WHEN 'anonymous' THEN 3
|
||||
ELSE 2
|
||||
END, u.user
|
||||
`
|
||||
sqliteSelectUserByIDQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.id = ?
|
||||
`
|
||||
sqliteSelectUserByNameQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE user = ?
|
||||
`
|
||||
sqliteSelectUserByTokenQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
JOIN user_token tk on u.id = tk.user_id
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
|
||||
`
|
||||
sqliteSelectUserByStripeCustomerIDQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.stripe_customer_id = ?
|
||||
`
|
||||
sqliteSelectUsernamesQuery = `
|
||||
SELECT user
|
||||
FROM user
|
||||
ORDER BY
|
||||
CASE role
|
||||
WHEN 'admin' THEN 1
|
||||
WHEN 'anonymous' THEN 3
|
||||
ELSE 2
|
||||
END, user
|
||||
`
|
||||
sqliteSelectUserCountQuery = `SELECT COUNT(*) FROM user`
|
||||
sqliteSelectUserIDFromUsernameQuery = `SELECT id FROM user WHERE user = ?`
|
||||
sqliteInsertUserQuery = `INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) VALUES (?, ?, ?, ?, ?, ?, ?)`
|
||||
sqliteUpdateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
|
||||
sqliteUpdateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
|
||||
sqliteUpdateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?`
|
||||
sqliteUpdateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
|
||||
sqliteUpdateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
|
||||
sqliteUpdateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||
sqliteUpdateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
|
||||
sqliteUpdateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
|
||||
sqliteDeleteUserQuery = `DELETE FROM user WHERE user = ?`
|
||||
sqliteDeleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
|
||||
sqliteDeleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
|
||||
sqliteDeleteUsersProvisionedQuery = `DELETE FROM user WHERE provisioned = 1`
|
||||
|
||||
// Access queries
|
||||
sqliteSelectTopicPermsQuery = `
|
||||
SELECT read, write
|
||||
FROM user_access a
|
||||
JOIN user u ON u.id = a.user_id
|
||||
WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\'
|
||||
ORDER BY u.user DESC, LENGTH(a.topic) DESC, a.write DESC
|
||||
`
|
||||
sqliteSelectUserAllAccessQuery = `
|
||||
SELECT user_id, topic, read, write, provisioned
|
||||
FROM user_access
|
||||
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
||||
`
|
||||
sqliteSelectUserAccessQuery = `
|
||||
SELECT topic, read, write, provisioned
|
||||
FROM user_access
|
||||
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
||||
`
|
||||
sqliteSelectUserReservationsQuery = `
|
||||
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
|
||||
FROM user_access a_user
|
||||
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
|
||||
WHERE a_user.user_id = a_user.owner_user_id
|
||||
AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
ORDER BY a_user.topic
|
||||
`
|
||||
sqliteSelectUserReservationsCountQuery = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
`
|
||||
sqliteSelectUserReservationsOwnerQuery = `
|
||||
SELECT owner_user_id
|
||||
FROM user_access
|
||||
WHERE topic = ?
|
||||
AND user_id = owner_user_id
|
||||
`
|
||||
sqliteSelectUserHasReservationQuery = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
AND topic = ?
|
||||
`
|
||||
sqliteSelectOtherAccessCountQuery = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE (topic = ? OR ? LIKE topic ESCAPE '\')
|
||||
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
|
||||
`
|
||||
sqliteUpsertUserAccessQuery = `
|
||||
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
||||
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?)
|
||||
ON CONFLICT (user_id, topic)
|
||||
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
|
||||
`
|
||||
sqliteDeleteUserAccessQuery = `
|
||||
DELETE FROM user_access
|
||||
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
`
|
||||
sqliteDeleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = 1`
|
||||
sqliteDeleteTopicAccessQuery = `
|
||||
DELETE FROM user_access
|
||||
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
|
||||
AND topic = ?
|
||||
`
|
||||
sqliteDeleteAllAccessQuery = `DELETE FROM user_access`
|
||||
|
||||
// Token queries
|
||||
sqliteSelectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
|
||||
sqliteSelectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
|
||||
sqliteSelectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||
sqliteSelectAllProvisionedTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1`
|
||||
sqliteUpsertTokenQuery = `
|
||||
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (user_id, token)
|
||||
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned
|
||||
`
|
||||
sqliteUpdateTokenQuery = `UPDATE user_token SET label = ?, expires = ? WHERE user_id = ? AND token = ?`
|
||||
sqliteUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
||||
sqliteDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||
sqliteDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?`
|
||||
sqliteDeleteAllProvisionedTokensQuery = `DELETE FROM user_token WHERE provisioned = 1`
|
||||
sqliteDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
||||
sqliteDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||
sqliteDeleteExcessTokensQuery = `
|
||||
DELETE FROM user_token
|
||||
WHERE user_id = ?
|
||||
AND (user_id, token) NOT IN (
|
||||
SELECT user_id, token
|
||||
FROM user_token
|
||||
WHERE user_id = ?
|
||||
ORDER BY expires DESC
|
||||
LIMIT ?
|
||||
)
|
||||
`
|
||||
|
||||
// Tier queries
|
||||
sqliteInsertTierQuery = `
|
||||
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
sqliteUpdateTierQuery = `
|
||||
UPDATE tier
|
||||
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
|
||||
WHERE code = ?
|
||||
`
|
||||
sqliteSelectTiersQuery = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
`
|
||||
sqliteSelectTierByCodeQuery = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE code = ?
|
||||
`
|
||||
sqliteSelectTierByPriceIDQuery = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
|
||||
`
|
||||
sqliteDeleteTierQuery = `DELETE FROM tier WHERE code = ?`
|
||||
|
||||
// Phone queries
|
||||
sqliteSelectPhoneNumbersQuery = `SELECT phone_number FROM user_phone WHERE user_id = ?`
|
||||
sqliteInsertPhoneNumberQuery = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)`
|
||||
sqliteDeletePhoneNumberQuery = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?`
|
||||
|
||||
// Billing queries
|
||||
sqliteUpdateBillingQuery = `
|
||||
UPDATE user
|
||||
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
|
||||
WHERE user = ?
|
||||
`
|
||||
)
|
||||
|
||||
var sqliteQueries = queries{
|
||||
selectUserByID: sqliteSelectUserByIDQuery,
|
||||
selectUserByName: sqliteSelectUserByNameQuery,
|
||||
selectUserByToken: sqliteSelectUserByTokenQuery,
|
||||
selectUserByStripeCustomerID: sqliteSelectUserByStripeCustomerIDQuery,
|
||||
selectUsernames: sqliteSelectUsernamesQuery,
|
||||
selectUsers: sqliteSelectUsersQuery,
|
||||
selectUserCount: sqliteSelectUserCountQuery,
|
||||
selectUserIDFromUsername: sqliteSelectUserIDFromUsernameQuery,
|
||||
insertUser: sqliteInsertUserQuery,
|
||||
updateUserPass: sqliteUpdateUserPassQuery,
|
||||
updateUserRole: sqliteUpdateUserRoleQuery,
|
||||
updateUserProvisioned: sqliteUpdateUserProvisionedQuery,
|
||||
updateUserPrefs: sqliteUpdateUserPrefsQuery,
|
||||
updateUserStats: sqliteUpdateUserStatsQuery,
|
||||
updateUserStatsResetAll: sqliteUpdateUserStatsResetAllQuery,
|
||||
updateUserTier: sqliteUpdateUserTierQuery,
|
||||
updateUserDeleted: sqliteUpdateUserDeletedQuery,
|
||||
deleteUser: sqliteDeleteUserQuery,
|
||||
deleteUserTier: sqliteDeleteUserTierQuery,
|
||||
deleteUsersMarked: sqliteDeleteUsersMarkedQuery,
|
||||
deleteUsersProvisioned: sqliteDeleteUsersProvisionedQuery,
|
||||
selectTopicPerms: sqliteSelectTopicPermsQuery,
|
||||
selectUserAllAccess: sqliteSelectUserAllAccessQuery,
|
||||
selectUserAccess: sqliteSelectUserAccessQuery,
|
||||
selectUserReservations: sqliteSelectUserReservationsQuery,
|
||||
selectUserReservationsCount: sqliteSelectUserReservationsCountQuery,
|
||||
selectUserReservationsOwner: sqliteSelectUserReservationsOwnerQuery,
|
||||
selectUserHasReservation: sqliteSelectUserHasReservationQuery,
|
||||
selectOtherAccessCount: sqliteSelectOtherAccessCountQuery,
|
||||
upsertUserAccess: sqliteUpsertUserAccessQuery,
|
||||
deleteUserAccess: sqliteDeleteUserAccessQuery,
|
||||
deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisionedQuery,
|
||||
deleteTopicAccess: sqliteDeleteTopicAccessQuery,
|
||||
deleteAllAccess: sqliteDeleteAllAccessQuery,
|
||||
selectToken: sqliteSelectTokenQuery,
|
||||
selectTokens: sqliteSelectTokensQuery,
|
||||
selectTokenCount: sqliteSelectTokenCountQuery,
|
||||
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokensQuery,
|
||||
upsertToken: sqliteUpsertTokenQuery,
|
||||
updateToken: sqliteUpdateTokenQuery,
|
||||
updateTokenLastAccess: sqliteUpdateTokenLastAccessQuery,
|
||||
deleteToken: sqliteDeleteTokenQuery,
|
||||
deleteProvisionedToken: sqliteDeleteProvisionedTokenQuery,
|
||||
deleteAllProvisionedTokens: sqliteDeleteAllProvisionedTokensQuery,
|
||||
deleteAllToken: sqliteDeleteAllTokenQuery,
|
||||
deleteExpiredTokens: sqliteDeleteExpiredTokensQuery,
|
||||
deleteExcessTokens: sqliteDeleteExcessTokensQuery,
|
||||
insertTier: sqliteInsertTierQuery,
|
||||
selectTiers: sqliteSelectTiersQuery,
|
||||
selectTierByCode: sqliteSelectTierByCodeQuery,
|
||||
selectTierByPriceID: sqliteSelectTierByPriceIDQuery,
|
||||
updateTier: sqliteUpdateTierQuery,
|
||||
deleteTier: sqliteDeleteTierQuery,
|
||||
selectPhoneNumbers: sqliteSelectPhoneNumbersQuery,
|
||||
insertPhoneNumber: sqliteInsertPhoneNumberQuery,
|
||||
deletePhoneNumber: sqliteDeletePhoneNumberQuery,
|
||||
updateBilling: sqliteUpdateBillingQuery,
|
||||
}
|
||||
|
||||
// NewSQLiteManager creates a new Manager backed by a SQLite database
|
||||
func NewSQLiteManager(filename, startupQueries string, config *Config) (*Manager, error) {
|
||||
parentDir := filepath.Dir(filename)
|
||||
if !util.FileExists(parentDir) {
|
||||
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
|
||||
}
|
||||
d, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLite(d); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runSQLiteStartupQueries(d, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newManager(db.New(&db.Host{DB: d}, nil), sqliteQueries, config)
|
||||
}
|
||||
465
user/manager_sqlite_schema.go
Normal file
465
user/manager_sqlite_schema.go
Normal file
@@ -0,0 +1,465 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// Initial SQLite schema
|
||||
const (
|
||||
sqliteCreateTablesQueries = `
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit INT NOT NULL,
|
||||
messages_expiry_duration INT NOT NULL,
|
||||
emails_limit INT NOT NULL,
|
||||
calls_limit INT NOT NULL,
|
||||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stats_calls INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
)
|
||||
|
||||
const (
|
||||
sqliteBuiltinStartupQueries = `PRAGMA foreign_keys = ON;`
|
||||
)
|
||||
|
||||
// Schema version table management for SQLite
|
||||
const (
|
||||
sqliteCurrentSchemaVersion = 6
|
||||
sqliteInsertSchemaVersionQuery = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
sqliteUpdateSchemaVersionQuery = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
// Schema migrations for SQLite
|
||||
const (
|
||||
// 1 -> 2 (complex migration!)
|
||||
sqliteMigrate1To2CreateTablesQueries = `
|
||||
ALTER TABLE user RENAME TO user_old;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit INT NOT NULL,
|
||||
messages_expiry_duration INT NOT NULL,
|
||||
emails_limit INT NOT NULL,
|
||||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
sqliteMigrate1To2SelectAllOldUsernamesNoTxQuery = `SELECT user FROM user_old`
|
||||
sqliteMigrate1To2InsertUserNoTxQuery = `
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
|
||||
`
|
||||
sqliteMigrate1To2InsertFromOldTablesAndDropNoTxQuery = `
|
||||
INSERT INTO user_access (user_id, topic, read, write)
|
||||
SELECT u.id, a.topic, a.read, a.write
|
||||
FROM user u
|
||||
JOIN access a ON u.user = a.user;
|
||||
|
||||
DROP TABLE access;
|
||||
DROP TABLE user_old;
|
||||
`
|
||||
|
||||
// 2 -> 3
|
||||
sqliteMigrate2To3UpdateQueries = `
|
||||
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
|
||||
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
|
||||
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
|
||||
DROP INDEX IF EXISTS idx_tier_price_id;
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||
`
|
||||
|
||||
// 3 -> 4
|
||||
sqliteMigrate3To4UpdateQueries = `
|
||||
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
|
||||
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
`
|
||||
|
||||
// 4 -> 5
|
||||
sqliteMigrate4To5UpdateQueries = `
|
||||
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
|
||||
`
|
||||
|
||||
// 5 -> 6
|
||||
sqliteMigrate5To6UpdateQueries = `
|
||||
PRAGMA foreign_keys=off;
|
||||
|
||||
-- Alter user table: Add provisioned column
|
||||
ALTER TABLE user RENAME TO user_old;
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stats_calls INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
INSERT INTO user
|
||||
SELECT
|
||||
id,
|
||||
tier_id,
|
||||
user,
|
||||
pass,
|
||||
role,
|
||||
prefs,
|
||||
sync_topic,
|
||||
0, -- provisioned
|
||||
stats_messages,
|
||||
stats_emails,
|
||||
stats_calls,
|
||||
stripe_customer_id,
|
||||
stripe_subscription_id,
|
||||
stripe_subscription_status,
|
||||
stripe_subscription_interval,
|
||||
stripe_subscription_paid_until,
|
||||
stripe_subscription_cancel_at,
|
||||
created,
|
||||
deleted
|
||||
FROM user_old;
|
||||
DROP TABLE user_old;
|
||||
|
||||
-- Alter user_access table: Add provisioned column
|
||||
ALTER TABLE user_access RENAME TO user_access_old;
|
||||
CREATE TABLE user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
provisioned INTEGER NOT NULL,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
|
||||
DROP TABLE user_access_old;
|
||||
|
||||
-- Alter user_token table: Add provisioned column
|
||||
ALTER TABLE user_token RENAME TO user_token_old;
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
INSERT INTO user_token SELECT *, 0 FROM user_token_old;
|
||||
DROP TABLE user_token_old;
|
||||
|
||||
-- Recreate indices
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
|
||||
-- Re-enable foreign keys
|
||||
PRAGMA foreign_keys=on;
|
||||
`
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteMigrations = map[int]func(db *sql.DB) error{
|
||||
1: sqliteMigrateFrom1,
|
||||
2: sqliteMigrateFrom2,
|
||||
3: sqliteMigrateFrom3,
|
||||
4: sqliteMigrateFrom4,
|
||||
5: sqliteMigrateFrom5,
|
||||
}
|
||||
)
|
||||
|
||||
func setupSQLite(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
if err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
|
||||
return setupNewSQLite(db)
|
||||
}
|
||||
if schemaVersion == sqliteCurrentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > sqliteCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
|
||||
}
|
||||
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
|
||||
fn, ok := sqliteMigrations[i]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||
} else if err := fn(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSQLite(sqlDB *sql.DB) error {
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom1(sqlDB *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
// Rename user -> user_old, and create new tables
|
||||
if _, err := tx.Exec(sqliteMigrate1To2CreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert users from user_old into new user table, with ID and sync_topic
|
||||
rows, err := tx.Query(sqliteMigrate1To2SelectAllOldUsernamesNoTxQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
usernames := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var username string
|
||||
if err := rows.Scan(&username); err != nil {
|
||||
return err
|
||||
}
|
||||
usernames = append(usernames, username)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, username := range usernames {
|
||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||
if _, err := tx.Exec(sqliteMigrate1To2InsertUserNoTxQuery, userID, syncTopic, username); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
|
||||
if _, err := tx.Exec(sqliteMigrate1To2InsertFromOldTablesAndDropNoTxQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom2(sqlDB *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate2To3UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom3(sqlDB *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate3To4UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom4(sqlDB *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate4To5UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 5); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom5(sqlDB *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
|
||||
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(sqliteMigrate5To6UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 6); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
2992
user/manager_test.go
2992
user/manager_test.go
File diff suppressed because it is too large
Load Diff
@@ -2,11 +2,12 @@ package user
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/payments"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/payments"
|
||||
)
|
||||
|
||||
// User is a struct that represents a user
|
||||
@@ -242,6 +243,20 @@ const (
|
||||
everyoneID = "u_everyone"
|
||||
)
|
||||
|
||||
// Config holds the configuration for the user Manager
|
||||
type Config struct {
|
||||
Filename string // Database filename, e.g. "/var/lib/ntfy/user.db" (SQLite)
|
||||
DatabaseURL string // Database connection string (PostgreSQL)
|
||||
StartupQueries string // Queries to run on startup, e.g. to create initial users or tiers (SQLite only)
|
||||
DefaultAccess Permission // Default permission if no ACL matches
|
||||
ProvisionEnabled bool // Hack: Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands
|
||||
Users []*User // Predefined users to create on startup
|
||||
Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant)
|
||||
Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token)
|
||||
QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
|
||||
BcryptCost int // Cost of generated passwords; lowering makes testing faster
|
||||
}
|
||||
|
||||
// Error constants used by the package
|
||||
var (
|
||||
ErrUnauthenticated = errors.New("unauthenticated")
|
||||
@@ -259,3 +274,75 @@ var (
|
||||
ErrProvisionedUserChange = errors.New("cannot change or delete provisioned user")
|
||||
ErrProvisionedTokenChange = errors.New("cannot change or delete provisioned token")
|
||||
)
|
||||
|
||||
// queries holds the database-specific SQL queries
|
||||
type queries struct {
|
||||
// User queries
|
||||
selectUserByID string
|
||||
selectUserByName string
|
||||
selectUserByToken string
|
||||
selectUserByStripeCustomerID string
|
||||
selectUsernames string
|
||||
selectUsers string
|
||||
selectUserCount string
|
||||
selectUserIDFromUsername string
|
||||
insertUser string
|
||||
updateUserPass string
|
||||
updateUserRole string
|
||||
updateUserProvisioned string
|
||||
updateUserPrefs string
|
||||
updateUserStats string
|
||||
updateUserStatsResetAll string
|
||||
updateUserTier string
|
||||
updateUserDeleted string
|
||||
deleteUser string
|
||||
deleteUserTier string
|
||||
deleteUsersMarked string
|
||||
deleteUsersProvisioned string
|
||||
|
||||
// Access queries
|
||||
selectTopicPerms string
|
||||
selectUserAllAccess string
|
||||
selectUserAccess string
|
||||
selectUserReservations string
|
||||
selectUserReservationsCount string
|
||||
selectUserReservationsOwner string
|
||||
selectUserHasReservation string
|
||||
selectOtherAccessCount string
|
||||
upsertUserAccess string
|
||||
deleteUserAccess string
|
||||
deleteUserAccessProvisioned string
|
||||
deleteTopicAccess string
|
||||
deleteAllAccess string
|
||||
|
||||
// Token queries
|
||||
selectToken string
|
||||
selectTokens string
|
||||
selectTokenCount string
|
||||
selectAllProvisionedTokens string
|
||||
upsertToken string
|
||||
updateToken string
|
||||
updateTokenLastAccess string
|
||||
deleteToken string
|
||||
deleteProvisionedToken string
|
||||
deleteAllProvisionedTokens string
|
||||
deleteAllToken string
|
||||
deleteExpiredTokens string
|
||||
deleteExcessTokens string
|
||||
|
||||
// Tier queries
|
||||
insertTier string
|
||||
selectTiers string
|
||||
selectTierByCode string
|
||||
selectTierByPriceID string
|
||||
updateTier string
|
||||
deleteTier string
|
||||
|
||||
// Phone queries
|
||||
selectPhoneNumbers string
|
||||
insertPhoneNumber string
|
||||
deletePhoneNumber string
|
||||
|
||||
// Billing queries
|
||||
updateBilling string
|
||||
}
|
||||
|
||||
40
user/util.go
40
user/util.go
@@ -1,10 +1,12 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"database/sql"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -77,3 +79,37 @@ func hashPassword(password string, cost int) (string, error) {
|
||||
}
|
||||
return string(hash), nil
|
||||
}
|
||||
|
||||
func nullString(s string) sql.NullString {
|
||||
if s == "" {
|
||||
return sql.NullString{}
|
||||
}
|
||||
return sql.NullString{String: s, Valid: true}
|
||||
}
|
||||
|
||||
func nullInt64(v int64) sql.NullInt64 {
|
||||
if v == 0 {
|
||||
return sql.NullInt64{}
|
||||
}
|
||||
return sql.NullInt64{Int64: v, Valid: true}
|
||||
}
|
||||
|
||||
// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards,
|
||||
// and escapes '_', assuming '\' as escape character.
|
||||
func toSQLWildcard(s string) string {
|
||||
return escapeUnderscore(strings.ReplaceAll(s, "*", "%"))
|
||||
}
|
||||
|
||||
// fromSQLWildcard converts a SQL wildcard string to a wildcard string. It converts '%' to '*',
|
||||
// and removes the '\_' escape character.
|
||||
func fromSQLWildcard(s string) string {
|
||||
return strings.ReplaceAll(unescapeUnderscore(s), "%", "*")
|
||||
}
|
||||
|
||||
func escapeUnderscore(s string) string {
|
||||
return strings.ReplaceAll(s, "_", "\\_")
|
||||
}
|
||||
|
||||
func unescapeUnderscore(s string) string {
|
||||
return strings.ReplaceAll(s, "\\_", "_")
|
||||
}
|
||||
|
||||
281
user/util_test.go
Normal file
281
user/util_test.go
Normal file
@@ -0,0 +1,281 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAllowedRole(t *testing.T) {
|
||||
require.True(t, AllowedRole(RoleUser))
|
||||
require.True(t, AllowedRole(RoleAdmin))
|
||||
require.False(t, AllowedRole(RoleAnonymous))
|
||||
require.False(t, AllowedRole(Role("invalid")))
|
||||
require.False(t, AllowedRole(Role("")))
|
||||
require.False(t, AllowedRole(Role("superadmin")))
|
||||
}
|
||||
|
||||
func TestAllowedTopic(t *testing.T) {
|
||||
// Valid topics
|
||||
require.True(t, AllowedTopic("test"))
|
||||
require.True(t, AllowedTopic("mytopic"))
|
||||
require.True(t, AllowedTopic("topic123"))
|
||||
require.True(t, AllowedTopic("my-topic"))
|
||||
require.True(t, AllowedTopic("my_topic"))
|
||||
require.True(t, AllowedTopic("Topic123"))
|
||||
require.True(t, AllowedTopic("a"))
|
||||
require.True(t, AllowedTopic(strings.Repeat("a", 64))) // Max length
|
||||
|
||||
// Invalid topics - wildcards not allowed
|
||||
require.False(t, AllowedTopic("topic*"))
|
||||
require.False(t, AllowedTopic("*"))
|
||||
require.False(t, AllowedTopic("my*topic"))
|
||||
|
||||
// Invalid topics - special characters
|
||||
require.False(t, AllowedTopic("my topic")) // Space
|
||||
require.False(t, AllowedTopic("my.topic")) // Dot
|
||||
require.False(t, AllowedTopic("my/topic")) // Slash
|
||||
require.False(t, AllowedTopic("my@topic")) // At sign
|
||||
require.False(t, AllowedTopic("my+topic")) // Plus
|
||||
require.False(t, AllowedTopic("topic!")) // Exclamation
|
||||
require.False(t, AllowedTopic("topic#")) // Hash
|
||||
require.False(t, AllowedTopic("topic$")) // Dollar
|
||||
require.False(t, AllowedTopic("topic%")) // Percent
|
||||
require.False(t, AllowedTopic("topic&")) // Ampersand
|
||||
require.False(t, AllowedTopic("my\\topic")) // Backslash
|
||||
|
||||
// Invalid topics - length
|
||||
require.False(t, AllowedTopic("")) // Empty
|
||||
require.False(t, AllowedTopic(strings.Repeat("a", 65))) // Too long
|
||||
}
|
||||
|
||||
func TestAllowedTopicPattern(t *testing.T) {
|
||||
// Valid patterns - same as AllowedTopic
|
||||
require.True(t, AllowedTopicPattern("test"))
|
||||
require.True(t, AllowedTopicPattern("mytopic"))
|
||||
require.True(t, AllowedTopicPattern("topic123"))
|
||||
require.True(t, AllowedTopicPattern("my-topic"))
|
||||
require.True(t, AllowedTopicPattern("my_topic"))
|
||||
require.True(t, AllowedTopicPattern("a"))
|
||||
require.True(t, AllowedTopicPattern(strings.Repeat("a", 64))) // Max length
|
||||
|
||||
// Valid patterns - with wildcards
|
||||
require.True(t, AllowedTopicPattern("*"))
|
||||
require.True(t, AllowedTopicPattern("topic*"))
|
||||
require.True(t, AllowedTopicPattern("*topic"))
|
||||
require.True(t, AllowedTopicPattern("my*topic"))
|
||||
require.True(t, AllowedTopicPattern("***"))
|
||||
require.True(t, AllowedTopicPattern("test_*"))
|
||||
require.True(t, AllowedTopicPattern("my-*-topic"))
|
||||
require.True(t, AllowedTopicPattern(strings.Repeat("*", 64))) // Max length with wildcards
|
||||
|
||||
// Invalid patterns - special characters (other than wildcard)
|
||||
require.False(t, AllowedTopicPattern("my topic")) // Space
|
||||
require.False(t, AllowedTopicPattern("my.topic")) // Dot
|
||||
require.False(t, AllowedTopicPattern("my/topic")) // Slash
|
||||
require.False(t, AllowedTopicPattern("my@topic")) // At sign
|
||||
require.False(t, AllowedTopicPattern("my+topic")) // Plus
|
||||
require.False(t, AllowedTopicPattern("topic!")) // Exclamation
|
||||
require.False(t, AllowedTopicPattern("topic#")) // Hash
|
||||
require.False(t, AllowedTopicPattern("topic$")) // Dollar
|
||||
require.False(t, AllowedTopicPattern("topic%")) // Percent
|
||||
require.False(t, AllowedTopicPattern("topic&")) // Ampersand
|
||||
require.False(t, AllowedTopicPattern("my\\topic")) // Backslash
|
||||
|
||||
// Invalid patterns - length
|
||||
require.False(t, AllowedTopicPattern("")) // Empty
|
||||
require.False(t, AllowedTopicPattern(strings.Repeat("a", 65))) // Too long
|
||||
}
|
||||
|
||||
func TestValidPasswordHash(t *testing.T) {
|
||||
// Valid bcrypt hashes with different versions
|
||||
require.Nil(t, ValidPasswordHash("$2a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
|
||||
require.Nil(t, ValidPasswordHash("$2b$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", 10))
|
||||
require.Nil(t, ValidPasswordHash("$2y$12$1234567890123456789012u1234567890123456789012345678901", 10))
|
||||
|
||||
// Valid hash with minimum cost
|
||||
require.Nil(t, ValidPasswordHash("$2a$04$1234567890123456789012u1234567890123456789012345678901", 4))
|
||||
|
||||
// Invalid - wrong prefix
|
||||
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("$2c$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
|
||||
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("$3a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
|
||||
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("bcrypt$10$hash", 10))
|
||||
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("nothash", 10))
|
||||
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("", 10))
|
||||
|
||||
// Invalid - malformed hash
|
||||
require.NotNil(t, ValidPasswordHash("$2a$10$tooshort", 10))
|
||||
require.NotNil(t, ValidPasswordHash("$2a$10", 10))
|
||||
require.NotNil(t, ValidPasswordHash("$2a$", 10))
|
||||
|
||||
// Invalid - cost too low
|
||||
require.Equal(t, ErrPasswordHashWeak, ValidPasswordHash("$2a$04$1234567890123456789012u1234567890123456789012345678901", 10))
|
||||
require.Equal(t, ErrPasswordHashWeak, ValidPasswordHash("$2a$09$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
|
||||
|
||||
// Edge case - cost exactly at minimum
|
||||
require.Nil(t, ValidPasswordHash("$2a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
|
||||
}
|
||||
|
||||
func TestValidToken(t *testing.T) {
|
||||
// Valid tokens
|
||||
require.True(t, ValidToken("tk_1234567890123456789012345678x"))
|
||||
require.True(t, ValidToken("tk_abcdefghijklmnopqrstuvwxyzabc"))
|
||||
require.True(t, ValidToken("tk_ABCDEFGHIJKLMNOPQRSTUVWXYZABC"))
|
||||
require.True(t, ValidToken("tk_012345678901234567890123456ab"))
|
||||
require.True(t, ValidToken("tk_-----------------------------"))
|
||||
require.True(t, ValidToken("tk______________________________"))
|
||||
|
||||
// Invalid tokens - wrong prefix
|
||||
require.False(t, ValidToken("tx_1234567890123456789012345678x"))
|
||||
require.False(t, ValidToken("tk1234567890123456789012345678xy"))
|
||||
require.False(t, ValidToken("token_1234567890123456789012345"))
|
||||
|
||||
// Invalid tokens - wrong length
|
||||
require.False(t, ValidToken("tk_")) // Too short
|
||||
require.False(t, ValidToken("tk_123")) // Too short
|
||||
require.False(t, ValidToken("tk_123456789012345678901234567890")) // Too long (30 chars after prefix)
|
||||
require.False(t, ValidToken("tk_123456789012345678901234567")) // Too short (28 chars)
|
||||
|
||||
// Invalid tokens - invalid characters
|
||||
require.False(t, ValidToken("tk_123456789012345678901234567!@"))
|
||||
require.False(t, ValidToken("tk_12345678901234567890123456 8x"))
|
||||
require.False(t, ValidToken("tk_123456789012345678901234567.x"))
|
||||
require.False(t, ValidToken("tk_123456789012345678901234567*x"))
|
||||
|
||||
// Invalid tokens - no prefix
|
||||
require.False(t, ValidToken("1234567890123456789012345678901x"))
|
||||
require.False(t, ValidToken(""))
|
||||
}
|
||||
|
||||
func TestGenerateToken(t *testing.T) {
|
||||
// Generate multiple tokens
|
||||
tokens := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
token := GenerateToken()
|
||||
|
||||
// Check format
|
||||
require.True(t, strings.HasPrefix(token, "tk_"), "Token should start with tk_")
|
||||
require.Equal(t, 32, len(token), "Token should be 32 characters long")
|
||||
|
||||
// Check it's valid
|
||||
require.True(t, ValidToken(token), "Generated token should be valid")
|
||||
|
||||
// Check it's lowercase
|
||||
require.Equal(t, strings.ToLower(token), token, "Token should be lowercase")
|
||||
|
||||
// Check uniqueness
|
||||
require.False(t, tokens[token], "Token should be unique")
|
||||
tokens[token] = true
|
||||
}
|
||||
|
||||
// Verify we got 100 unique tokens
|
||||
require.Equal(t, 100, len(tokens))
|
||||
}
|
||||
|
||||
func TestHashPassword(t *testing.T) {
|
||||
password := "test-password-123"
|
||||
|
||||
// Hash the password
|
||||
hash, err := HashPassword(password)
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// Check it's a valid bcrypt hash
|
||||
require.Nil(t, ValidPasswordHash(hash, DefaultUserPasswordBcryptCost))
|
||||
|
||||
// Check it starts with correct prefix
|
||||
require.True(t, strings.HasPrefix(hash, "$2a$"))
|
||||
|
||||
// Hash the same password again - should produce different hash
|
||||
hash2, err := HashPassword(password)
|
||||
require.Nil(t, err)
|
||||
require.NotEqual(t, hash, hash2, "Same password should produce different hashes (salt)")
|
||||
|
||||
// Empty password should still work
|
||||
emptyHash, err := HashPassword("")
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, emptyHash)
|
||||
require.Nil(t, ValidPasswordHash(emptyHash, DefaultUserPasswordBcryptCost))
|
||||
}
|
||||
|
||||
func TestHashPassword_WithCost(t *testing.T) {
|
||||
password := "test-password"
|
||||
|
||||
// Test with different costs
|
||||
hash4, err := hashPassword(password, 4)
|
||||
require.Nil(t, err)
|
||||
require.True(t, strings.HasPrefix(hash4, "$2a$04$"))
|
||||
|
||||
hash10, err := hashPassword(password, 10)
|
||||
require.Nil(t, err)
|
||||
require.True(t, strings.HasPrefix(hash10, "$2a$10$"))
|
||||
|
||||
hash12, err := hashPassword(password, 12)
|
||||
require.Nil(t, err)
|
||||
require.True(t, strings.HasPrefix(hash12, "$2a$12$"))
|
||||
|
||||
// All should be valid
|
||||
require.Nil(t, ValidPasswordHash(hash4, 4))
|
||||
require.Nil(t, ValidPasswordHash(hash10, 10))
|
||||
require.Nil(t, ValidPasswordHash(hash12, 12))
|
||||
}
|
||||
|
||||
func TestUser_TierID(t *testing.T) {
|
||||
// User with tier
|
||||
u := &User{
|
||||
Tier: &Tier{
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
},
|
||||
}
|
||||
require.Equal(t, "ti_123", u.TierID())
|
||||
|
||||
// User without tier
|
||||
u2 := &User{
|
||||
Tier: nil,
|
||||
}
|
||||
require.Equal(t, "", u2.TierID())
|
||||
|
||||
// Nil user
|
||||
var u3 *User
|
||||
require.Equal(t, "", u3.TierID())
|
||||
}
|
||||
|
||||
func TestUser_IsAdmin(t *testing.T) {
|
||||
admin := &User{Role: RoleAdmin}
|
||||
require.True(t, admin.IsAdmin())
|
||||
require.False(t, admin.IsUser())
|
||||
|
||||
user := &User{Role: RoleUser}
|
||||
require.False(t, user.IsAdmin())
|
||||
|
||||
anonymous := &User{Role: RoleAnonymous}
|
||||
require.False(t, anonymous.IsAdmin())
|
||||
|
||||
// Nil user
|
||||
var nilUser *User
|
||||
require.False(t, nilUser.IsAdmin())
|
||||
}
|
||||
|
||||
func TestUser_IsUser(t *testing.T) {
|
||||
user := &User{Role: RoleUser}
|
||||
require.True(t, user.IsUser())
|
||||
require.False(t, user.IsAdmin())
|
||||
|
||||
admin := &User{Role: RoleAdmin}
|
||||
require.False(t, admin.IsUser())
|
||||
|
||||
anonymous := &User{Role: RoleAnonymous}
|
||||
require.False(t, anonymous.IsUser())
|
||||
|
||||
// Nil user
|
||||
var nilUser *User
|
||||
require.False(t, nilUser.IsUser())
|
||||
}
|
||||
|
||||
func TestPermission_String(t *testing.T) {
|
||||
require.Equal(t, "read-write", PermissionReadWrite.String())
|
||||
require.Equal(t, "read-only", PermissionRead.String())
|
||||
require.Equal(t, "write-only", PermissionWrite.String())
|
||||
require.Equal(t, "deny-all", PermissionDenyAll.String())
|
||||
}
|
||||
20
util/util.go
20
util/util.go
@@ -17,6 +17,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"golang.org/x/term"
|
||||
@@ -434,3 +435,22 @@ func Int(v int) *int {
|
||||
func Time(v time.Time) *time.Time {
|
||||
return &v
|
||||
}
|
||||
|
||||
// SanitizeUTF8 ensures a string is safe to store in PostgreSQL by handling two cases:
|
||||
//
|
||||
// 1. Invalid UTF-8 sequences: Some clients send Latin-1/ISO-8859-1 encoded text (e.g. accented
|
||||
// characters like é, ñ, ß) in HTTP headers or SMTP messages. Go treats these as raw bytes in
|
||||
// strings, but PostgreSQL rejects them. Any invalid UTF-8 byte is replaced with the Unicode
|
||||
// replacement character (U+FFFD, "<22>") so the message is still delivered rather than lost.
|
||||
//
|
||||
// 2. NUL bytes (0x00): These are valid in UTF-8 but PostgreSQL TEXT columns reject them.
|
||||
// They are stripped entirely.
|
||||
func SanitizeUTF8(s string) string {
|
||||
if !utf8.ValidString(s) {
|
||||
s = strings.ToValidUTF8(s, "\xef\xbf\xbd") // U+FFFD
|
||||
}
|
||||
if strings.ContainsRune(s, 0) {
|
||||
s = strings.ReplaceAll(s, "\x00", "")
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user