Compare commits

...

25 Commits

Author SHA1 Message Date
binwiederhier
6b11bc7468 Merge branch 'main' into attachment-s3 2026-03-16 21:11:38 -04:00
binwiederhier
d9efe50848 Email validation 2026-03-16 21:03:33 -04:00
binwiederhier
2ad78edca1 Release notes 2026-03-16 20:13:39 -04:00
binwiederhier
86015e100c Multipart upload 2026-03-16 20:00:19 -04:00
binwiederhier
458fbad770 Merge branch 'main' into attachment-s3 2026-03-16 15:53:21 -04:00
binwiederhier
9b1a32ec56 Refine 2026-03-16 15:52:57 -04:00
binwiederhier
3d9ce69042 PG races 2026-03-16 15:48:36 -04:00
binwiederhier
59ce581ba2 Fix postgres primary/replica races 2026-03-16 11:21:21 -04:00
binwiederhier
df82fdf44c Add HTTP 413 to normal errors to not log 2026-03-16 10:27:23 -04:00
binwiederhier
3a37ea32f7 Webpush: Fix FK issue with Postgres 2026-03-16 10:24:16 -04:00
binwiederhier
790ba243c7 S3 WIP 2026-03-16 09:48:26 -04:00
binwiederhier
4487299a80 Merge branch 'main' into attachment-s3 2026-03-16 05:43:20 -04:00
binwiederhier
6b38acb23a Route authorization query to read-only database replica to reduce primary database load 2026-03-15 22:01:19 -04:00
binwiederhier
f5c255c53c Grr 2026-03-15 21:17:58 -04:00
binwiederhier
fd0a49244e Disable test temporarily 2026-03-15 21:13:12 -04:00
binwiederhier
4699ed3ffd Fix UTF-8 insert failures in Postgres 2026-03-15 21:03:18 -04:00
Philipp C. Heckel
1afb99db67 Merge pull request #1658 from BradStaton/1657-postgresql-urls
Support `postgresql://` and `postgres://` URLs
2026-03-15 20:45:08 -04:00
binwiederhier
66208e6f88 Pre-import 2026-03-15 20:25:22 -04:00
BradStaton
ce24594c32 Update serve.go
Support multiple postgres connection URL formats
2026-03-15 16:22:22 -04:00
binwiederhier
888850d8bc Add blurp 2026-03-15 10:29:07 -04:00
binwiederhier
be09acd411 Bump 2026-03-15 10:26:03 -04:00
binwiederhier
bf19a5be2d Merge branch 'main' of https://hosted.weblate.org/git/ntfy/web 2026-03-15 10:12:54 -04:00
binwiederhier
b4ec6fa8df AWS deps.. 2026-03-15 10:12:23 -04:00
binwiederhier
d517ce4a2a WIP: S3 2026-03-14 21:10:46 -04:00
BonifacioCalindoro
fd8f356d1f Translated using Weblate (Spanish)
Currently translated at 100.0% (407 of 407 strings)

Translation: ntfy/Web app
Translate-URL: https://hosted.weblate.org/projects/ntfy/web/es/
2026-03-15 01:09:47 +01:00
43 changed files with 3324 additions and 295 deletions

1
.gitignore vendored
View File

@@ -9,6 +9,7 @@ server/site/
tools/fbsend/fbsend
tools/pgimport/pgimport
tools/loadtest/loadtest
tools/s3cli/s3cli
playground/
secrets/
*.iml

25
attachment/store.go Normal file
View File

@@ -0,0 +1,25 @@
package attachment
import (
"errors"
"fmt"
"io"
"regexp"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
)
// Store is an interface for storing and retrieving attachment files
type Store interface {
Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error)
Read(id string) (io.ReadCloser, int64, error)
Remove(ids ...string) error
Size() int64
Remaining() int64
}
var (
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
errInvalidFileID = errors.New("invalid file ID")
)
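For orientation, a minimal sketch of how a caller might exercise this interface, using `NewFileStore` and `util.NewFixedLimiter` from elsewhere in this changeset (the directory, size cap, and ID are illustrative):

```go
package main

import (
	"fmt"
	"strings"

	"heckel.io/ntfy/v2/attachment"
	"heckel.io/ntfy/v2/util"
)

func main() {
	// Illustrative values: 10 KB store-wide cap, 12-character message ID
	store, err := attachment.NewFileStore("/tmp/ntfy-attachments", 10*1024)
	if err != nil {
		panic(err)
	}
	// Write enforces the store-wide cap plus any per-call limiters (here, 1 KB per file)
	n, err := store.Write("abcdefghijkl", strings.NewReader("hello world"), util.NewFixedLimiter(1024))
	fmt.Println(n, err) // 11 <nil>
}
```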

View File

@@ -1,32 +1,29 @@
package server
package attachment
import (
"errors"
"fmt"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
"io"
"os"
"path/filepath"
"regexp"
"sync"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
)
var (
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
errInvalidFileID = errors.New("invalid file ID")
errFileExists = errors.New("file exists")
)
const tagFileStore = "file_store"
type fileCache struct {
var errFileExists = errors.New("file exists")
type fileStore struct {
dir string
totalSizeCurrent int64
totalSizeLimit int64
mu sync.Mutex
}
func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) {
// NewFileStore creates a new file-system backed attachment store
func NewFileStore(dir string, totalSizeLimit int64) (Store, error) {
if err := os.MkdirAll(dir, 0700); err != nil {
return nil, err
}
@@ -34,18 +31,18 @@ func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) {
if err != nil {
return nil, err
}
return &fileCache{
return &fileStore{
dir: dir,
totalSizeCurrent: size,
totalSizeLimit: totalSizeLimit,
}, nil
}
func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
func (c *fileStore) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
if !fileIDRegex.MatchString(id) {
return 0, errInvalidFileID
}
log.Tag(tagFileCache).Field("message_id", id).Debug("Writing attachment")
log.Tag(tagFileStore).Field("message_id", id).Debug("Writing attachment")
file := filepath.Join(c.dir, id)
if _, err := os.Stat(file); err == nil {
return 0, errFileExists
@@ -68,20 +65,35 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (in
}
c.mu.Lock()
c.totalSizeCurrent += size
mset(metricAttachmentsTotalSize, c.totalSizeCurrent)
c.mu.Unlock()
return size, nil
}
func (c *fileCache) Remove(ids ...string) error {
func (c *fileStore) Read(id string) (io.ReadCloser, int64, error) {
if !fileIDRegex.MatchString(id) {
return nil, 0, errInvalidFileID
}
file := filepath.Join(c.dir, id)
stat, err := os.Stat(file)
if err != nil {
return nil, 0, err
}
f, err := os.Open(file)
if err != nil {
return nil, 0, err
}
return f, stat.Size(), nil
}
func (c *fileStore) Remove(ids ...string) error {
for _, id := range ids {
if !fileIDRegex.MatchString(id) {
return errInvalidFileID
}
log.Tag(tagFileCache).Field("message_id", id).Debug("Deleting attachment")
log.Tag(tagFileStore).Field("message_id", id).Debug("Deleting attachment")
file := filepath.Join(c.dir, id)
if err := os.Remove(file); err != nil {
log.Tag(tagFileCache).Field("message_id", id).Err(err).Debug("Error deleting attachment")
log.Tag(tagFileStore).Field("message_id", id).Err(err).Debug("Error deleting attachment")
}
}
size, err := dirSize(c.dir)
@@ -91,17 +103,16 @@ func (c *fileCache) Remove(ids ...string) error {
c.mu.Lock()
c.totalSizeCurrent = size
c.mu.Unlock()
mset(metricAttachmentsTotalSize, size)
return nil
}
func (c *fileCache) Size() int64 {
func (c *fileStore) Size() int64 {
c.mu.Lock()
defer c.mu.Unlock()
return c.totalSizeCurrent
}
func (c *fileCache) Remaining() int64 {
func (c *fileStore) Remaining() int64 {
c.mu.Lock()
defer c.mu.Unlock()
remaining := c.totalSizeLimit - c.totalSizeCurrent

View File

@@ -0,0 +1,99 @@
package attachment
import (
"bytes"
"fmt"
"io"
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/util"
)
var (
oneKilobyteArray = make([]byte, 1024)
)
func TestFileStore_Write_Success(t *testing.T) {
dir, s := newTestFileStore(t)
size, err := s.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
require.Equal(t, int64(11), s.Size())
require.Equal(t, int64(10229), s.Remaining())
}
func TestFileStore_Write_Read_Success(t *testing.T) {
_, s := newTestFileStore(t)
size, err := s.Write("abcdefghijkl", strings.NewReader("hello world"))
require.Nil(t, err)
require.Equal(t, int64(11), size)
reader, readSize, err := s.Read("abcdefghijkl")
require.Nil(t, err)
require.Equal(t, int64(11), readSize)
defer reader.Close()
data, err := io.ReadAll(reader)
require.Nil(t, err)
require.Equal(t, "hello world", string(data))
}
func TestFileStore_Write_Remove_Success(t *testing.T) {
dir, s := newTestFileStore(t) // max = 10k (10240), each = 1k (1024)
for i := 0; i < 10; i++ { // 10x999 = 9990
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
require.Nil(t, err)
require.Equal(t, int64(999), size)
}
require.Equal(t, int64(9990), s.Size())
require.Equal(t, int64(250), s.Remaining())
require.FileExists(t, dir+"/abcdefghijk1")
require.FileExists(t, dir+"/abcdefghijk5")
require.Nil(t, s.Remove("abcdefghijk1", "abcdefghijk5"))
require.NoFileExists(t, dir+"/abcdefghijk1")
require.NoFileExists(t, dir+"/abcdefghijk5")
require.Equal(t, int64(7992), s.Size())
require.Equal(t, int64(2248), s.Remaining())
}
func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) {
dir, s := newTestFileStore(t)
for i := 0; i < 10; i++ {
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
require.Nil(t, err)
require.Equal(t, int64(1024), size)
}
_, err := s.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkX")
}
func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) {
dir, s := newTestFileStore(t)
_, err := s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkl")
}
func TestFileStore_Read_NotFound(t *testing.T) {
_, s := newTestFileStore(t)
_, _, err := s.Read("abcdefghijkl")
require.Error(t, err)
}
func newTestFileStore(t *testing.T) (dir string, store Store) {
dir = t.TempDir()
store, err := NewFileStore(dir, 10*1024)
require.Nil(t, err)
return dir, store
}
func readFile(t *testing.T, f string) string {
b, err := os.ReadFile(f)
require.Nil(t, err)
return string(b)
}

150
attachment/store_s3.go Normal file
View File

@@ -0,0 +1,150 @@
package attachment
import (
"context"
"fmt"
"io"
"sync"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/s3"
"heckel.io/ntfy/v2/util"
)
const (
tagS3Store = "s3_store"
)
type s3Store struct {
client *s3.Client
totalSizeCurrent int64
totalSizeLimit int64
mu sync.Mutex
}
// NewS3Store creates a new S3-backed attachment store. The s3URL must be in the format:
//
// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
func NewS3Store(s3URL string, totalSizeLimit int64) (Store, error) {
cfg, err := s3.ParseURL(s3URL)
if err != nil {
return nil, err
}
store := &s3Store{
client: s3.New(cfg),
totalSizeLimit: totalSizeLimit,
}
if totalSizeLimit > 0 {
size, err := store.computeSize()
if err != nil {
return nil, fmt.Errorf("s3 store: failed to compute initial size: %w", err)
}
store.totalSizeCurrent = size
}
return store, nil
}
func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
if !fileIDRegex.MatchString(id) {
return 0, errInvalidFileID
}
log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3")
// Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked
// uploads, so no temp file or Content-Length is needed.
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
pr, pw := io.Pipe()
lw := util.NewLimitWriter(pw, limiters...)
var size int64
var copyErr error
done := make(chan struct{})
go func() {
defer close(done)
size, copyErr = io.Copy(lw, in)
if copyErr != nil {
pw.CloseWithError(copyErr)
} else {
pw.Close()
}
}()
putErr := c.client.PutObject(context.Background(), id, pr)
pr.Close()
<-done
if copyErr != nil {
return 0, copyErr
}
if putErr != nil {
return 0, putErr
}
c.mu.Lock()
c.totalSizeCurrent += size
c.mu.Unlock()
return size, nil
}
func (c *s3Store) Read(id string) (io.ReadCloser, int64, error) {
if !fileIDRegex.MatchString(id) {
return nil, 0, errInvalidFileID
}
return c.client.GetObject(context.Background(), id)
}
func (c *s3Store) Remove(ids ...string) error {
for _, id := range ids {
if !fileIDRegex.MatchString(id) {
return errInvalidFileID
}
}
// S3 DeleteObjects supports up to 1000 keys per call
for i := 0; i < len(ids); i += 1000 {
end := i + 1000
if end > len(ids) {
end = len(ids)
}
batch := ids[i:end]
for _, id := range batch {
log.Tag(tagS3Store).Field("message_id", id).Debug("Deleting attachment from S3")
}
if err := c.client.DeleteObjects(context.Background(), batch); err != nil {
return err
}
}
// Recalculate totalSizeCurrent via ListObjectsV2 (matches fileStore's dirSize rescan pattern)
size, err := c.computeSize()
if err != nil {
return fmt.Errorf("s3 store: failed to compute size after remove: %w", err)
}
c.mu.Lock()
c.totalSizeCurrent = size
c.mu.Unlock()
return nil
}
func (c *s3Store) Size() int64 {
c.mu.Lock()
defer c.mu.Unlock()
return c.totalSizeCurrent
}
func (c *s3Store) Remaining() int64 {
c.mu.Lock()
defer c.mu.Unlock()
remaining := c.totalSizeLimit - c.totalSizeCurrent
if remaining < 0 {
return 0
}
return remaining
}
// computeSize uses ListAllObjects to sum up the total size of all objects with our prefix.
func (c *s3Store) computeSize() (int64, error) {
objects, err := c.client.ListAllObjects(context.Background())
if err != nil {
return 0, err
}
var totalSize int64
for _, obj := range objects {
totalSize += obj.Size
}
return totalSize, nil
}
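As a usage sketch (credentials, bucket, and endpoint are placeholders), constructing the S3-backed store mirrors the filesystem variant:

```go
package main

import (
	"log"
	"strings"

	"heckel.io/ntfy/v2/attachment"
)

func main() {
	// Placeholder credentials and endpoint; URL format per NewS3Store's doc comment
	store, err := attachment.NewS3Store(
		"s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com",
		5*1024*1024*1024, // 5 GB cap; a positive limit triggers the initial ListObjectsV2 size scan
	)
	if err != nil {
		log.Fatal(err)
	}
	if _, err := store.Write("abcdefghijkl", strings.NewReader("hello world")); err != nil {
		log.Fatal(err)
	}
}
```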

367
attachment/store_s3_test.go Normal file
View File

@@ -0,0 +1,367 @@
package attachment
import (
"bytes"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/s3"
"heckel.io/ntfy/v2/util"
)
// --- Integration tests using a mock S3 server ---
func TestS3Store_WriteReadRemove(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
// Write
size, err := store.Write("abcdefghijkl", strings.NewReader("hello world"))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, int64(11), store.Size())
// Read back
reader, readSize, err := store.Read("abcdefghijkl")
require.Nil(t, err)
require.Equal(t, int64(11), readSize)
data, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, "hello world", string(data))
// Remove
require.Nil(t, store.Remove("abcdefghijkl"))
require.Equal(t, int64(0), store.Size())
// Read after remove should fail
_, _, err = store.Read("abcdefghijkl")
require.Error(t, err)
}
func TestS3Store_WriteNoPrefix(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "", 10*1024)
size, err := store.Write("abcdefghijkl", strings.NewReader("test"))
require.Nil(t, err)
require.Equal(t, int64(4), size)
reader, _, err := store.Read("abcdefghijkl")
require.Nil(t, err)
data, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, "test", string(data))
}
func TestS3Store_WriteTotalSizeLimit(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 100)
// First write fits
_, err := store.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)))
require.Nil(t, err)
require.Equal(t, int64(80), store.Size())
require.Equal(t, int64(20), store.Remaining())
// Second write exceeds total limit
_, err = store.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)))
require.Equal(t, util.ErrLimitReached, err)
}
func TestS3Store_WriteFileSizeLimit(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, err := store.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100))
require.Equal(t, util.ErrLimitReached, err)
}
func TestS3Store_WriteRemoveMultiple(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
for i := 0; i < 5; i++ {
_, err := store.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)))
require.Nil(t, err)
}
require.Equal(t, int64(500), store.Size())
require.Nil(t, store.Remove("abcdefghijk1", "abcdefghijk3"))
require.Equal(t, int64(300), store.Size())
}
func TestS3Store_ReadNotFound(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, _, err := store.Read("abcdefghijkl")
require.Error(t, err)
}
func TestS3Store_InvalidID(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, err := store.Write("bad", strings.NewReader("x"))
require.Equal(t, errInvalidFileID, err)
_, _, err = store.Read("bad")
require.Equal(t, errInvalidFileID, err)
err = store.Remove("bad")
require.Equal(t, errInvalidFileID, err)
}
// --- Helpers ---
func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) Store {
t.Helper()
// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
host := strings.TrimPrefix(server.URL, "https://")
s := &s3Store{
client: &s3.Client{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: host,
Bucket: bucket,
Prefix: prefix,
PathStyle: true,
HTTPClient: server.Client(),
},
totalSizeLimit: totalSizeLimit,
}
// Compute initial size (should be 0 for fresh mock)
size, err := s.computeSize()
require.Nil(t, err)
s.totalSizeCurrent = size
return s
}
// --- Mock S3 server ---
//
// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
type mockS3Server struct {
objects map[string][]byte // full key (bucket/key) -> body
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
nextID int // counter for generating upload IDs
mu sync.RWMutex
}
func newMockS3Server() *httptest.Server {
m := &mockS3Server{
objects: make(map[string][]byte),
uploads: make(map[string]map[int][]byte),
}
return httptest.NewTLSServer(m)
}
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Path is /{bucket}[/{key...}]
path := strings.TrimPrefix(r.URL.Path, "/")
q := r.URL.Query()
switch {
case r.Method == http.MethodPut && q.Has("partNumber"):
m.handleUploadPart(w, r, path)
case r.Method == http.MethodPut:
m.handlePut(w, r, path)
case r.Method == http.MethodPost && q.Has("uploads"):
m.handleInitiateMultipart(w, r, path)
case r.Method == http.MethodPost && q.Has("uploadId"):
m.handleCompleteMultipart(w, r, path)
case r.Method == http.MethodDelete && q.Has("uploadId"):
m.handleAbortMultipart(w, r, path)
case r.Method == http.MethodGet && q.Get("list-type") == "2":
m.handleList(w, r, path)
case r.Method == http.MethodGet:
m.handleGet(w, r, path)
case r.Method == http.MethodPost && q.Has("delete"):
m.handleDelete(w, r, path)
default:
http.Error(w, "not implemented", http.StatusNotImplemented)
}
}
func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
m.mu.Lock()
m.objects[path] = body
m.mu.Unlock()
w.WriteHeader(http.StatusOK)
}
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
m.mu.Lock()
m.nextID++
uploadID := fmt.Sprintf("upload-%d", m.nextID)
m.uploads[uploadID] = make(map[int][]byte)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
}
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
var partNumber int
fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
parts[partNumber] = body
m.mu.Unlock()
etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
w.Header().Set("ETag", etag)
w.WriteHeader(http.StatusOK)
}
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
// Assemble parts in order
var assembled []byte
for i := 1; i <= len(parts); i++ {
assembled = append(assembled, parts[i]...)
}
m.objects[path] = assembled
delete(m.uploads, uploadID)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
}
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
delete(m.uploads, uploadID)
m.mu.Unlock()
w.WriteHeader(http.StatusNoContent)
}
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
m.mu.RLock()
body, ok := m.objects[path]
m.mu.RUnlock()
if !ok {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`))
return
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
w.WriteHeader(http.StatusOK)
w.Write(body)
}
func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) {
// bucketPath is just the bucket name
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var req struct {
Objects []struct {
Key string `xml:"Key"`
} `xml:"Object"`
}
if err := xml.Unmarshal(body, &req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
m.mu.Lock()
for _, obj := range req.Objects {
delete(m.objects, bucketPath+"/"+obj.Key)
}
m.mu.Unlock()
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></DeleteResult>`))
}
func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) {
prefix := r.URL.Query().Get("prefix")
m.mu.RLock()
var contents []s3ListObject
for key, body := range m.objects {
// key is "bucket/objectkey", strip bucket prefix
objKey := strings.TrimPrefix(key, bucketPath+"/")
if objKey == key {
continue // different bucket
}
if prefix == "" || strings.HasPrefix(objKey, prefix) {
contents = append(contents, s3ListObject{Key: objKey, Size: int64(len(body))})
}
}
m.mu.RUnlock()
resp := s3ListResponse{
Contents: contents,
IsTruncated: false,
}
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
xml.NewEncoder(w).Encode(resp)
}
type s3ListResponse struct {
XMLName xml.Name `xml:"ListBucketResult"`
Contents []s3ListObject `xml:"Contents"`
IsTruncated bool `xml:"IsTruncated"`
}
type s3ListObject struct {
Key string `xml:"Key"`
Size int64 `xml:"Size"`
}

View File

@@ -2,9 +2,6 @@ package cmd
import (
"fmt"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/test"
"heckel.io/ntfy/v2/util"
"net/http"
"net/http/httptest"
"os"
@@ -14,9 +11,14 @@ import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/test"
"heckel.io/ntfy/v2/util"
)
func TestCLI_Publish_Subscribe_Poll_Real_Server(t *testing.T) {
t.Skip("temporarily disabled") // FIXME
testMessage := util.RandomString(10)
app, _, _, _ := newTestApp()
require.Nil(t, app.Run([]string{"ntfy", "publish", "ntfytest", "ntfy unit test " + testMessage}))

View File

@@ -53,6 +53,7 @@ var flagsServe = append(
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-access", Aliases: []string{"auth_access"}, EnvVars: []string{"NTFY_AUTH_ACCESS"}, Usage: "pre-provisioned declarative access control entries"}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-tokens", Aliases: []string{"auth_tokens"}, EnvVars: []string{"NTFY_AUTH_TOKENS"}, Usage: "pre-provisioned declarative access tokens"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-s3-url", Aliases: []string{"attachment_s3_url"}, EnvVars: []string{"NTFY_ATTACHMENT_S3_URL"}, Usage: "S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION)"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-total-size-limit", Aliases: []string{"attachment_total_size_limit", "A"}, EnvVars: []string{"NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentTotalSizeLimit), Usage: "limit of the on-disk attachment cache"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-file-size-limit", Aliases: []string{"attachment_file_size_limit", "Y"}, EnvVars: []string{"NTFY_ATTACHMENT_FILE_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentFileSizeLimit), Usage: "per-file attachment size limit (e.g. 300k, 2M, 100M)"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-expiry-duration", Aliases: []string{"attachment_expiry_duration", "X"}, EnvVars: []string{"NTFY_ATTACHMENT_EXPIRY_DURATION"}, Value: util.FormatDuration(server.DefaultAttachmentExpiryDuration), Usage: "duration after which uploaded attachments will be deleted (e.g. 3h, 20h)"}),
@@ -166,6 +167,7 @@ func execServe(c *cli.Context) error {
authAccessRaw := c.StringSlice("auth-access")
authTokensRaw := c.StringSlice("auth-tokens")
attachmentCacheDir := c.String("attachment-cache-dir")
attachmentS3URL := c.String("attachment-s3-url")
attachmentTotalSizeLimitStr := c.String("attachment-total-size-limit")
attachmentFileSizeLimitStr := c.String("attachment-file-size-limit")
attachmentExpiryDurationStr := c.String("attachment-expiry-duration")
@@ -284,8 +286,8 @@ func execServe(c *cli.Context) error {
}
// Check values
if databaseURL != "" && !strings.HasPrefix(databaseURL, "postgres://") {
return errors.New("if database-url is set, it must start with postgres://")
if databaseURL != "" && !strings.HasPrefix(databaseURL, "postgres://") && !strings.HasPrefix(databaseURL, "postgresql://") {
return errors.New("if database-url is set, it must start with postgres:// or postgresql://")
} else if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") {
return errors.New("if database-url is set, auth-file, cache-file, and web-push-file must not be set")
} else if len(databaseReplicaURLs) > 0 && databaseURL == "" {
@@ -314,6 +316,10 @@ func execServe(c *cli.Context) error {
return errors.New("if smtp-server-listen is set, smtp-server-domain must also be set")
} else if attachmentCacheDir != "" && baseURL == "" {
return errors.New("if attachment-cache-dir is set, base-url must also be set")
} else if attachmentS3URL != "" && baseURL == "" {
return errors.New("if attachment-s3-url is set, base-url must also be set")
} else if attachmentS3URL != "" && attachmentCacheDir != "" {
return errors.New("attachment-cache-dir and attachment-s3-url are mutually exclusive")
} else if baseURL != "" {
u, err := url.Parse(baseURL)
if err != nil {
@@ -457,6 +463,7 @@ func execServe(c *cli.Context) error {
conf.AuthAccess = authAccess
conf.AuthTokens = authTokens
conf.AttachmentCacheDir = attachmentCacheDir
conf.AttachmentS3URL = attachmentS3URL
conf.AttachmentTotalSizeLimit = attachmentTotalSizeLimit
conf.AttachmentFileSizeLimit = attachmentFileSizeLimit
conf.AttachmentExpiryDuration = attachmentExpiryDuration

View File

@@ -11,6 +11,12 @@ type Beginner interface {
Begin() (*sql.Tx, error)
}
// Querier is an interface for types that can execute SQL queries.
// *sql.DB, *sql.Tx, and *DB all implement this.
type Querier interface {
Query(query string, args ...any) (*sql.Rows, error)
}
// Host pairs a *sql.DB with the host:port it was opened against.
type Host struct {
Addr string // "host:port"
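To illustrate why `Querier` is useful, here is a hypothetical helper (not part of this changeset) that works unchanged whether it is handed a plain connection, a transaction, or the pooled wrapper:

```go
// countMessages is a hypothetical helper in the same package: because it
// accepts the Querier interface, callers can pass a *sql.DB, a *sql.Tx,
// or the pooled *DB without any code changes.
func countMessages(q Querier) (int, error) {
	rows, err := q.Query("SELECT COUNT(*) FROM messages")
	if err != nil {
		return 0, err
	}
	defer rows.Close()
	var count int
	if rows.Next() {
		if err := rows.Scan(&count); err != nil {
			return 0, err
		}
	}
	return count, rows.Err()
}
```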

View File

@@ -489,20 +489,23 @@ Subscribers can retrieve cached messages using the [`poll=1` parameter](subscri
## Attachments
If desired, you may allow users to upload and [attach files to notifications](publish.md#attachments). To enable
this feature, you have to simply configure an attachment cache directory and a base URL (`attachment-cache-dir`, `base-url`).
Once these options are set and the directory is writable by the server user, you can upload attachments via PUT.
this feature, you have to configure an attachment storage backend and a base URL (`base-url`). Attachments can be stored
either on the local filesystem (`attachment-cache-dir`) or in an S3-compatible object store (`attachment-s3-url`).
Once configured, you can upload attachments via PUT.
By default, attachments are stored in the disk-cache **for only 3 hours**. The main reason for this is to avoid legal issues
and such when hosting user controlled content. Typically, this is more than enough time for the user (or the auto download
By default, attachments are stored **for only 3 hours**. The main reason for this is to avoid legal issues
and such when hosting user controlled content. Typically, this is more than enough time for the user (or the auto download
feature) to download the file. The following config options are relevant to attachments:
* `base-url` is the root URL for the ntfy server; this is needed for the generated attachment URLs
* `attachment-cache-dir` is the cache directory for attached files
* `attachment-total-size-limit` is the size limit of the on-disk attachment cache (default: 5G)
* `attachment-cache-dir` is the cache directory for attached files (mutually exclusive with `attachment-s3-url`)
* `attachment-s3-url` is the S3 URL for attachment storage (mutually exclusive with `attachment-cache-dir`)
* `attachment-total-size-limit` is the size limit of the attachment storage (default: 5G)
* `attachment-file-size-limit` is the per-file attachment size limit (e.g. 300k, 2M, 100M, default: 15M)
* `attachment-expiry-duration` is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h, default: 3h)
Here's an example config using mostly the defaults (except for the cache directory, which is empty by default):
### Filesystem storage
Here's an example config using the local filesystem for attachment storage:
=== "/etc/ntfy/server.yml (minimal)"
``` yaml
@@ -521,6 +524,30 @@ Here's an example config using mostly the defaults (except for the cache directo
visitor-attachment-daily-bandwidth-limit: "500M"
```
### S3 storage
As an alternative to the local filesystem, you can store attachments in an S3-compatible object store (e.g. AWS S3,
MinIO, DigitalOcean Spaces). This is useful for HA/cloud deployments where you don't want to rely on local disk storage.
The `attachment-s3-url` option uses the following format:
```
s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
```
When `endpoint` is specified, path-style addressing is enabled automatically (useful for MinIO and other S3-compatible stores).
=== "/etc/ntfy/server.yml (AWS S3)"
``` yaml
base-url: "https://ntfy.sh"
attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1"
```
=== "/etc/ntfy/server.yml (MinIO/custom endpoint)"
``` yaml
base-url: "https://ntfy.sh"
attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com"
```
Please also refer to the [rate limiting](#rate-limiting) settings below, specifically `visitor-attachment-total-size-limit`
and `visitor-attachment-daily-bandwidth-limit`. Setting these conservatively is necessary to avoid abuse.
@@ -2116,7 +2143,8 @@ variable before running the `ntfy` command (e.g. `export NTFY_LISTEN_HTTP=:80`).
| `behind-proxy` | `NTFY_BEHIND_PROXY` | *bool* | false | If set, use forwarded header (e.g. X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting) |
| `proxy-forwarded-header` | `NTFY_PROXY_FORWARDED_HEADER` | *string* | `X-Forwarded-For` | Use specified header to determine visitor IP address (for rate limiting) |
| `proxy-trusted-hosts` | `NTFY_PROXY_TRUSTED_HOSTS` | *comma-separated host/IP/CIDR list* | - | Comma-separated list of trusted IP addresses, hosts, or CIDRs to remove from forwarded header |
| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory* | - | Cache directory for attached files. To enable attachments, this has to be set. |
| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory* | - | Cache directory for attached files. Mutually exclusive with `attachment-s3-url`. |
| `attachment-s3-url` | `NTFY_ATTACHMENT_S3_URL` | *URL* | - | S3 URL for attachment storage (format: `s3://KEY:SECRET@BUCKET[/PREFIX]?region=REGION`). Mutually exclusive with `attachment-cache-dir`. |
| `attachment-total-size-limit` | `NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT` | *size* | 5G | Limit of the on-disk attachment cache directory. If the limit is exceeded, new attachments will be rejected. |
| `attachment-file-size-limit` | `NTFY_ATTACHMENT_FILE_SIZE_LIMIT` | *size* | 15M | Per-file attachment size limit (e.g. 300k, 2M, 100M). Larger attachments will be rejected. |
| `attachment-expiry-duration` | `NTFY_ATTACHMENT_EXPIRY_DURATION` | *duration* | 3h | Duration after which uploaded attachments will be deleted (e.g. 3h, 20h). Strongly affects `visitor-attachment-total-size-limit`. |
@@ -2219,6 +2247,7 @@ OPTIONS:
--auth-startup-queries value, --auth_startup_queries value queries run when the auth database is initialized [$NTFY_AUTH_STARTUP_QUERIES]
--auth-default-access value, --auth_default_access value, -p value default permissions if no matching entries in the auth database are found (default: "read-write") [$NTFY_AUTH_DEFAULT_ACCESS]
--attachment-cache-dir value, --attachment_cache_dir value cache directory for attached files [$NTFY_ATTACHMENT_CACHE_DIR]
--attachment-s3-url value, --attachment_s3_url value S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION) [$NTFY_ATTACHMENT_S3_URL]
--attachment-total-size-limit value, --attachment_total_size_limit value, -A value limit of the on-disk attachment cache (default: "5G") [$NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT]
--attachment-file-size-limit value, --attachment_file_size_limit value, -Y value per-file attachment size limit (e.g. 300k, 2M, 100M) (default: "15M") [$NTFY_ATTACHMENT_FILE_SIZE_LIMIT]
--attachment-expiry-duration value, --attachment_expiry_duration value, -X value duration after which uploaded attachments will be deleted (e.g. 3h, 20h) (default: "3h") [$NTFY_ATTACHMENT_EXPIRY_DURATION]

View File

@@ -30,37 +30,37 @@ deb/rpm packages.
=== "x86_64/amd64"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_amd64.tar.gz
tar zxvf ntfy_2.18.0_linux_amd64.tar.gz
sudo cp -a ntfy_2.18.0_linux_amd64/ntfy /usr/local/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.18.0_linux_amd64/{client,server}/*.yml /etc/ntfy
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_amd64.tar.gz
tar zxvf ntfy_2.19.2_linux_amd64.tar.gz
sudo cp -a ntfy_2.19.2_linux_amd64/ntfy /usr/local/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_amd64/{client,server}/*.yml /etc/ntfy
sudo ntfy serve
```
=== "armv6"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_armv6.tar.gz
tar zxvf ntfy_2.18.0_linux_armv6.tar.gz
sudo cp -a ntfy_2.18.0_linux_armv6/ntfy /usr/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.18.0_linux_armv6/{client,server}/*.yml /etc/ntfy
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv6.tar.gz
tar zxvf ntfy_2.19.2_linux_armv6.tar.gz
sudo cp -a ntfy_2.19.2_linux_armv6/ntfy /usr/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_armv6/{client,server}/*.yml /etc/ntfy
sudo ntfy serve
```
=== "armv7/armhf"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_armv7.tar.gz
tar zxvf ntfy_2.18.0_linux_armv7.tar.gz
sudo cp -a ntfy_2.18.0_linux_armv7/ntfy /usr/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.18.0_linux_armv7/{client,server}/*.yml /etc/ntfy
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv7.tar.gz
tar zxvf ntfy_2.19.2_linux_armv7.tar.gz
sudo cp -a ntfy_2.19.2_linux_armv7/ntfy /usr/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_armv7/{client,server}/*.yml /etc/ntfy
sudo ntfy serve
```
=== "arm64"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_arm64.tar.gz
tar zxvf ntfy_2.18.0_linux_arm64.tar.gz
sudo cp -a ntfy_2.18.0_linux_arm64/ntfy /usr/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.18.0_linux_arm64/{client,server}/*.yml /etc/ntfy
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_arm64.tar.gz
tar zxvf ntfy_2.19.2_linux_arm64.tar.gz
sudo cp -a ntfy_2.19.2_linux_arm64/ntfy /usr/bin/ntfy
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_arm64/{client,server}/*.yml /etc/ntfy
sudo ntfy serve
```
@@ -116,7 +116,7 @@ Manually installing the .deb file:
=== "x86_64/amd64"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_amd64.deb
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_amd64.deb
sudo dpkg -i ntfy_*.deb
sudo systemctl enable ntfy
sudo systemctl start ntfy
@@ -124,7 +124,7 @@ Manually installing the .deb file:
=== "armv6"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_armv6.deb
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv6.deb
sudo dpkg -i ntfy_*.deb
sudo systemctl enable ntfy
sudo systemctl start ntfy
@@ -132,7 +132,7 @@ Manually installing the .deb file:
=== "armv7/armhf"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_armv7.deb
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv7.deb
sudo dpkg -i ntfy_*.deb
sudo systemctl enable ntfy
sudo systemctl start ntfy
@@ -140,7 +140,7 @@ Manually installing the .deb file:
=== "arm64"
```bash
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_arm64.deb
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_arm64.deb
sudo dpkg -i ntfy_*.deb
sudo systemctl enable ntfy
sudo systemctl start ntfy
@@ -150,28 +150,28 @@ Manually installing the .deb file:
=== "x86_64/amd64"
```bash
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_amd64.rpm
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_amd64.rpm
sudo systemctl enable ntfy
sudo systemctl start ntfy
```
=== "armv6"
```bash
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_armv6.rpm
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv6.rpm
sudo systemctl enable ntfy
sudo systemctl start ntfy
```
=== "armv7/armhf"
```bash
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_armv7.rpm
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv7.rpm
sudo systemctl enable ntfy
sudo systemctl start ntfy
```
=== "arm64"
```bash
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_arm64.rpm
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_arm64.rpm
sudo systemctl enable ntfy
sudo systemctl start ntfy
```
@@ -213,18 +213,18 @@ pkg install go-ntfy
## macOS
The [ntfy CLI](subscribe/cli.md) (`ntfy publish` and `ntfy subscribe` only) is supported on macOS as well.
To install, please [download the tarball](https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_darwin_all.tar.gz),
To install, please [download the tarball](https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_darwin_all.tar.gz),
extract it and place it somewhere in your `PATH` (e.g. `/usr/local/bin/ntfy`).
If run as `root`, ntfy will look for its config at `/etc/ntfy/client.yml`. For all other users, it'll look for it at
`~/Library/Application Support/ntfy/client.yml` (sample included in the tarball).
```bash
curl -L https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_darwin_all.tar.gz > ntfy_2.18.0_darwin_all.tar.gz
tar zxvf ntfy_2.18.0_darwin_all.tar.gz
sudo cp -a ntfy_2.18.0_darwin_all/ntfy /usr/local/bin/ntfy
curl -L https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_darwin_all.tar.gz > ntfy_2.19.2_darwin_all.tar.gz
tar zxvf ntfy_2.19.2_darwin_all.tar.gz
sudo cp -a ntfy_2.19.2_darwin_all/ntfy /usr/local/bin/ntfy
mkdir ~/Library/Application\ Support/ntfy
cp ntfy_2.18.0_darwin_all/client/client.yml ~/Library/Application\ Support/ntfy/client.yml
cp ntfy_2.19.2_darwin_all/client/client.yml ~/Library/Application\ Support/ntfy/client.yml
ntfy --help
```
@@ -245,7 +245,7 @@ brew install ntfy
The ntfy server and CLI are fully supported on Windows. You can run the ntfy server directly or as a Windows service.
To install, you can either
* [Download the latest ZIP](https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_windows_amd64.zip),
* [Download the latest ZIP](https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_windows_amd64.zip),
extract it and place the `ntfy.exe` binary somewhere in your `%Path%`.
* Or install ntfy from the [Scoop](https://scoop.sh) main repository via `scoop install ntfy`

View File

@@ -6,12 +6,55 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
| Component | Version | Release date |
|------------------|---------|--------------|
| ntfy server | v2.18.0 | Mar 7, 2026 |
| ntfy server | v2.19.2 | Mar 16, 2026 |
| ntfy Android app | v1.24.0 | Mar 5, 2026 |
| ntfy iOS app | v1.3 | Nov 26, 2023 |
Please check out the release notes for [upcoming releases](#not-released-yet) below.
## ntfy server v2.19.2
Released March 16, 2026
This is another small bugfix release for PostgreSQL, avoiding races between the primary and read replicas, and
further reducing primary load.
**Bug fixes + maintenance:**
* Fix race condition in web push subscription causing FK constraint violation when concurrent requests hit the same endpoint
* Route authorization query to read-only database replica to reduce primary database load
## ntfy server v2.19.1
Released March 15, 2026
This is a bugfix release to avoid PostgreSQL insert failures due to invalid UTF-8 messages. It also fixes `database-url`
validation incorrectly rejecting `postgresql://` connection strings.
**Bug fixes + maintenance:**
* Fix invalid UTF-8 in HTTP headers (e.g. Latin-1 encoded text) causing PostgreSQL insert failures and dropping entire message batches
* Fix `database-url` validation rejecting `postgresql://` connection strings ([#1657](https://github.com/binwiederhier/ntfy/issues/1657)/[#1658](https://github.com/binwiederhier/ntfy/pull/1658))
## ntfy server v2.19.0
Released March 15, 2026
This is a fast-follow release that enables Postgres read replica support.
To offload read-heavy queries from the primary database, you can optionally configure one or more read replicas
using the `database-replica-urls` option. When configured, non-critical read-only queries (e.g. fetching messages,
checking access permissions, etc) are distributed across the replicas using round-robin, while all writes and
correctness-critical reads continue to go to the primary. If a replica becomes unhealthy, ntfy automatically falls back
to the primary until the replica recovers.
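A hypothetical config sketch (hostnames are placeholders, and the YAML list form of `database-replica-urls` is an assumption based on the option name; see the [config docs](config.md#postgresql-experimental) for the authoritative syntax):

``` yaml
database-url: "postgres://user:pass@db-primary:5432/ntfy"
database-replica-urls:
  - "postgres://user:pass@db-replica-1:5432/ntfy"
  - "postgres://user:pass@db-replica-2:5432/ntfy"
```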
**Features:**
* Support [PostgreSQL read replicas](config.md#postgresql-experimental) for offloading non-critical read queries via `database-replica-urls` config option ([#1648](https://github.com/binwiederhier/ntfy/pull/1648))
* Add interactive [config generator](config.md#config-generator) to the documentation to help create server configuration files ([#1654](https://github.com/binwiederhier/ntfy/pull/1654))
**Bug fixes + maintenance:**
* Web: Throttle notification sound in web app to play at most once every 2 seconds (similar to [#1550](https://github.com/binwiederhier/ntfy/issues/1550), thanks to [@jlaffaye](https://github.com/jlaffaye) for reporting)
* Web: Add hover tooltips to icon buttons in web app account and preferences pages ([#1565](https://github.com/binwiederhier/ntfy/issues/1565), thanks to [@jermanuts](https://github.com/jermanuts) for reporting)
## ntfy server v2.18.0
Released March 7, 2026
@@ -1755,14 +1798,12 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
## Not released yet
### ntfy server v2.19.x (UNRELEASED)
### ntfy server v2.20.x (UNRELEASED)
**Features:**
* Support PostgreSQL read replicas for offloading non-critical read queries via `database-replica-urls` config option
* Add interactive [config generator](config.md#config-generator) to the documentation to help create server configuration files
* Add S3-compatible object storage as an alternative attachment backend via `attachment-s3-url` config option
**Bug fixes + maintenance:**
* Web: Throttle notification sound in web app to play at most once every 2 seconds (similar to [#1550](https://github.com/binwiederhier/ntfy/issues/1550), thanks to [@jlaffaye](https://github.com/jlaffaye) for reporting)
* Web: Add hover tooltips to icon buttons in web app account and preferences pages ([#1565](https://github.com/binwiederhier/ntfy/issues/1565), thanks to [@jermanuts](https://github.com/jermanuts) for reporting)
* Reject invalid e-mail addresses (e.g. multiple comma-separated recipients) with HTTP 400

2
go.mod
View File

@@ -4,7 +4,7 @@ go 1.25.0
require (
cloud.google.com/go/firestore v1.21.0 // indirect
cloud.google.com/go/storage v1.61.1 // indirect
cloud.google.com/go/storage v1.61.3 // indirect
github.com/BurntSushi/toml v1.6.0 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect
github.com/emersion/go-smtp v0.18.0

4
go.sum
View File

@@ -18,8 +18,8 @@ cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7
cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk=
cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE=
cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI=
cloud.google.com/go/storage v1.61.1 h1:VELCSvZKiSw0AS1k3so5mKGy3CB7bTCYD8EHhTF42bY=
cloud.google.com/go/storage v1.61.1/go.mod h1:k30/hwYfd0M8aULYbPkQLgNf+SFcdjlRHvLMXggw18E=
cloud.google.com/go/storage v1.61.3 h1:VS//ZfBuPGDvakfD9xyPW1RGF1Vy3BWUoVZXgW1KMOg=
cloud.google.com/go/storage v1.61.3/go.mod h1:JtqK8BBB7TWv0HVGHubtUdzYYrakOQIsMLffZ2Z/HWk=
cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U=
cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s=
firebase.google.com/go/v4 v4.19.0 h1:f5NMlC2YHFsncz00c2+ecBr+ZYlRMhKIhj1z8Iz0lD8=

View File

@@ -125,16 +125,16 @@ func (c *Cache) addMessages(ms []*model.Message) error {
return model.ErrUnexpectedMessageType
}
published := m.Time <= time.Now().Unix()
tags := strings.Join(m.Tags, ",")
tags := util.SanitizeUTF8(strings.Join(m.Tags, ","))
var attachmentName, attachmentType, attachmentURL string
var attachmentSize, attachmentExpires int64
var attachmentDeleted bool
if m.Attachment != nil {
attachmentName = m.Attachment.Name
attachmentType = m.Attachment.Type
attachmentName = util.SanitizeUTF8(m.Attachment.Name)
attachmentType = util.SanitizeUTF8(m.Attachment.Type)
attachmentSize = m.Attachment.Size
attachmentExpires = m.Attachment.Expires
attachmentURL = m.Attachment.URL
attachmentURL = util.SanitizeUTF8(m.Attachment.URL)
}
var actionsStr string
if len(m.Actions) > 0 {
@@ -154,13 +154,13 @@ func (c *Cache) addMessages(ms []*model.Message) error {
m.Time,
m.Event,
m.Expires,
m.Topic,
m.Message,
m.Title,
util.SanitizeUTF8(m.Topic),
util.SanitizeUTF8(m.Message),
util.SanitizeUTF8(m.Title),
m.Priority,
tags,
m.Click,
m.Icon,
util.SanitizeUTF8(m.Click),
util.SanitizeUTF8(m.Icon),
actionsStr,
attachmentName,
attachmentType,
@@ -170,7 +170,7 @@ func (c *Cache) addMessages(ms []*model.Message) error {
attachmentDeleted, // Always zero
sender,
m.User,
m.ContentType,
util.SanitizeUTF8(m.ContentType),
m.Encoding,
published,
)

View File

@@ -827,3 +827,141 @@ func TestStore_MessageFieldRoundTrip(t *testing.T) {
require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body)
})
}
func TestStore_AddMessage_InvalidUTF8(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// 0xc9 0x43: Latin-1 "ÉC" — 0xc9 starts a 2-byte UTF-8 sequence but 0x43 ('C') is not a continuation byte
m := model.NewDefaultMessage("mytopic", "\xc9Cas du serveur")
require.Nil(t, s.AddMessage(m))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "\uFFFDCas du serveur", messages[0].Message)
// 0xae: Latin-1 "®" — isolated byte in the 0x80-0xBF range (continuation byte without a lead), not valid UTF-8 on its own
m2 := model.NewDefaultMessage("mytopic", "Product\xae Pro")
require.Nil(t, s.AddMessage(m2))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "Product\uFFFD Pro", messages[1].Message)
// 0xe8 0x6d 0x65: Latin-1 "ème" — 0xe8 starts a 3-byte UTF-8 sequence but 0x6d ('m') is not a continuation byte
m3 := model.NewDefaultMessage("mytopic", "probl\xe8me critique")
require.Nil(t, s.AddMessage(m3))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "probl\uFFFDme critique", messages[2].Message)
// 0xb2: Latin-1 "²" — isolated byte in 0x80-0xBF range (UTF-8 continuation byte without lead)
m4 := model.NewDefaultMessage("mytopic", "CO\xb2 level high")
require.Nil(t, s.AddMessage(m4))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "CO\uFFFD level high", messages[3].Message)
// 0xe9 0x6d 0x61: Latin-1 "éma" — 0xe9 starts a 3-byte UTF-8 sequence but 0x6d ('m') is not a continuation byte
m5 := model.NewDefaultMessage("mytopic", "th\xe9matique")
require.Nil(t, s.AddMessage(m5))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "th\uFFFDmatique", messages[4].Message)
// 0xed 0x64 0x65: Latin-1 "íde" — 0xed starts a 3-byte UTF-8 sequence but 0x64 ('d') is not a continuation byte
m6 := model.NewDefaultMessage("mytopic", "vid\xed\x64eo surveillance")
require.Nil(t, s.AddMessage(m6))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "vid\uFFFDdeo surveillance", messages[5].Message)
// 0xf3 0x6e 0x3a 0x20: Latin-1 "ón: " — 0xf3 starts a 4-byte UTF-8 sequence but 0x6e ('n') is not a continuation byte
m7 := model.NewDefaultMessage("mytopic", "notificaci\xf3n: alerta")
require.Nil(t, s.AddMessage(m7))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "notificaci\uFFFDn: alerta", messages[6].Message)
// 0xb7: Latin-1 "·" — isolated continuation byte
m8 := model.NewDefaultMessage("mytopic", "item\xb7value")
require.Nil(t, s.AddMessage(m8))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "item\uFFFDvalue", messages[7].Message)
// 0xa8: Latin-1 "¨" — isolated continuation byte
m9 := model.NewDefaultMessage("mytopic", "na\xa8ve")
require.Nil(t, s.AddMessage(m9))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "na\uFFFDve", messages[8].Message)
// 0xdf 0x64: Latin-1 "ßd" — 0xdf starts a 2-byte UTF-8 sequence but 0x64 ('d') is not a continuation byte
m10 := model.NewDefaultMessage("mytopic", "gro\xdf\x64ruck")
require.Nil(t, s.AddMessage(m10))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "gro\uFFFDdruck", messages[9].Message)
// 0xe4 0x67 0x74: Latin-1 "ägt" — 0xe4 starts a 3-byte UTF-8 sequence but 0x67 ('g') is not a continuation byte
m11 := model.NewDefaultMessage("mytopic", "tr\xe4gt Last")
require.Nil(t, s.AddMessage(m11))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "tr\uFFFDgt Last", messages[10].Message)
// 0xe9 0x65 0x20: Latin-1 "ée " — 0xe9 starts a 3-byte UTF-8 sequence but 0x65 ('e') is not a continuation byte
m12 := model.NewDefaultMessage("mytopic", "journ\xe9\x65 termin\xe9\x65")
require.Nil(t, s.AddMessage(m12))
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, "journ\uFFFDe termin\uFFFDe", messages[11].Message)
})
}
func TestStore_AddMessage_NullByte(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// 0x00: NUL byte — valid UTF-8 but rejected by PostgreSQL
m := model.NewDefaultMessage("mytopic", "hello\x00world")
require.Nil(t, s.AddMessage(m))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "helloworld", messages[0].Message)
})
}
func TestStore_AddMessage_InvalidUTF8InTitleAndTags(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Invalid UTF-8 can arrive via HTTP headers (Title, Tags) which bypass body validation
m := model.NewDefaultMessage("mytopic", "valid message")
m.Title = "\xc9clipse du syst\xe8me"
m.Tags = []string{"probl\xe8me", "syst\xe9me"}
m.Click = "https://example.com/\xae"
require.Nil(t, s.AddMessage(m))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "\uFFFDclipse du syst\uFFFDme", messages[0].Title)
require.Equal(t, "probl\uFFFDme", messages[0].Tags[0])
require.Equal(t, "syst\uFFFDme", messages[0].Tags[1])
require.Equal(t, "https://example.com/\uFFFD", messages[0].Click)
})
}
func TestStore_AddMessage_InvalidUTF8BatchDoesNotDropValidMessages(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Previously, a single invalid message would roll back the entire batch transaction.
// Sanitization ensures all messages in a batch are written successfully.
msgs := []*model.Message{
model.NewDefaultMessage("mytopic", "valid message 1"),
model.NewDefaultMessage("mytopic", "notificaci\xf3n: alerta"),
model.NewDefaultMessage("mytopic", "valid message 3"),
}
require.Nil(t, s.AddMessages(msgs))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 3, len(messages))
})
}

View File

@@ -70,6 +70,26 @@ func (m *Message) Context() log.Context {
return fields
}
// SanitizeUTF8 replaces invalid UTF-8 sequences and strips NUL bytes from all user-supplied
// string fields. This is called early in the publish path so that all downstream consumers
// (Firebase, WebPush, SMTP, cache) receive clean UTF-8 strings.
func (m *Message) SanitizeUTF8() {
m.Topic = util.SanitizeUTF8(m.Topic)
m.Message = util.SanitizeUTF8(m.Message)
m.Title = util.SanitizeUTF8(m.Title)
m.Click = util.SanitizeUTF8(m.Click)
m.Icon = util.SanitizeUTF8(m.Icon)
m.ContentType = util.SanitizeUTF8(m.ContentType)
for i, tag := range m.Tags {
m.Tags[i] = util.SanitizeUTF8(tag)
}
if m.Attachment != nil {
m.Attachment.Name = util.SanitizeUTF8(m.Attachment.Name)
m.Attachment.Type = util.SanitizeUTF8(m.Attachment.Type)
m.Attachment.URL = util.SanitizeUTF8(m.Attachment.URL)
}
}
// ForJSON returns a copy of the message suitable for JSON output.
// It clears the SequenceID if it equals the ID to reduce redundancy.
func (m *Message) ForJSON() *Message {
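The diff does not include `util.SanitizeUTF8` itself. Based on the behavior the cache tests elsewhere in this changeset exercise (each run of invalid bytes becomes U+FFFD, NUL bytes are stripped), a minimal standard-library sketch could look like this (an assumption, not the actual implementation):

```go
package util

import "strings"

// SanitizeUTF8 sketch (assumed implementation): replace each run of invalid
// UTF-8 bytes with U+FFFD, then strip NUL bytes, which PostgreSQL rejects
// even though they are valid UTF-8.
func SanitizeUTF8(s string) string {
	s = strings.ToValidUTF8(s, "\uFFFD")
	return strings.ReplaceAll(s, "\x00", "")
}
```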

488
s3/client.go Normal file
View File

@@ -0,0 +1,488 @@
// Package s3 provides a minimal S3-compatible client that works with AWS S3, DigitalOcean Spaces,
Google Cloud Storage, MinIO, Backblaze B2, and other S3-compatible providers. It uses raw HTTP
// requests with AWS Signature V4 signing, no AWS SDK dependency required.
package s3
import (
"bytes"
"context"
"crypto/md5" //nolint:gosec // MD5 is required by the S3 protocol for Content-MD5 headers
"encoding/base64"
"encoding/hex"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
)
// Client is a minimal S3-compatible client. It supports PutObject, GetObject, DeleteObjects,
// and ListObjectsV2 operations using AWS Signature V4 signing. The bucket and optional key prefix
// are fixed at construction time. All operations target the same bucket and prefix.
//
// Fields must not be modified after the Client is passed to any method or goroutine.
type Client struct {
AccessKey string // AWS access key ID
SecretKey string // AWS secret access key
Region string // e.g. "us-east-1"
Endpoint string // host[:port] only, e.g. "s3.amazonaws.com" or "nyc3.digitaloceanspaces.com"
Bucket string // S3 bucket name
Prefix string // optional key prefix (e.g. "attachments"); prepended to all keys automatically
PathStyle bool // if true, use path-style addressing; otherwise virtual-hosted-style
HTTPClient *http.Client // if nil, http.DefaultClient is used
}
// New creates a new S3 client from the given Config.
func New(config *Config) *Client {
return &Client{
AccessKey: config.AccessKey,
SecretKey: config.SecretKey,
Region: config.Region,
Endpoint: config.Endpoint,
Bucket: config.Bucket,
Prefix: config.Prefix,
PathStyle: config.PathStyle,
}
}
// PutObject uploads body to the given key. The key is automatically prefixed with the client's
// configured prefix. The body size does not need to be known in advance.
//
// If the body is smaller than the 5 MB part size, it is uploaded with a simple PUT request.
// Bodies of one part size or more are uploaded using S3 multipart upload, reading one part
// at a time into memory.
func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error {
first := make([]byte, partSize)
n, err := io.ReadFull(body, first)
	if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) {
return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n))
}
if err != nil {
return fmt.Errorf("s3: PutObject read: %w", err)
}
combined := io.MultiReader(bytes.NewReader(first), body)
return c.putObjectMultipart(ctx, key, combined)
}
// GetObject downloads an object. The key is automatically prefixed with the client's configured
// prefix. The caller must close the returned ReadCloser.
func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int64, error) {
fullKey := c.objectKey(key)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.objectURL(fullKey), nil)
if err != nil {
return nil, 0, fmt.Errorf("s3: GetObject request: %w", err)
}
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return nil, 0, fmt.Errorf("s3: GetObject: %w", err)
}
if resp.StatusCode/100 != 2 {
err := parseError(resp)
resp.Body.Close()
return nil, 0, err
}
return resp.Body, resp.ContentLength, nil
}
// DeleteObjects removes multiple objects in a single batch request. Keys are automatically
// prefixed with the client's configured prefix. S3 supports up to 1000 keys per call; the
// caller is responsible for batching if needed.
//
// Even when S3 returns HTTP 200, individual keys may fail. If any per-key errors are present
// in the response, they are returned as a combined error.
func (c *Client) DeleteObjects(ctx context.Context, keys []string) error {
var body bytes.Buffer
body.WriteString("<Delete><Quiet>true</Quiet>")
for _, key := range keys {
body.WriteString("<Object><Key>")
xml.EscapeText(&body, []byte(c.objectKey(key)))
body.WriteString("</Key></Object>")
}
body.WriteString("</Delete>")
bodyBytes := body.Bytes()
payloadHash := sha256Hex(bodyBytes)
// Content-MD5 is required by the S3 protocol for DeleteObjects requests.
md5Sum := md5.Sum(bodyBytes) //nolint:gosec
contentMD5 := base64.StdEncoding.EncodeToString(md5Sum[:])
reqURL := c.bucketURL() + "?delete="
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
if err != nil {
return fmt.Errorf("s3: DeleteObjects request: %w", err)
}
req.ContentLength = int64(len(bodyBytes))
req.Header.Set("Content-Type", "application/xml")
req.Header.Set("Content-MD5", contentMD5)
c.signV4(req, payloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return fmt.Errorf("s3: DeleteObjects: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return parseError(resp)
}
// S3 may return HTTP 200 with per-key errors in the response body
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return fmt.Errorf("s3: DeleteObjects read response: %w", err)
}
var result deleteResult
if err := xml.Unmarshal(respBody, &result); err != nil {
return nil // If we can't parse, assume success (Quiet mode returns empty body on success)
}
if len(result.Errors) > 0 {
var msgs []string
for _, e := range result.Errors {
msgs = append(msgs, fmt.Sprintf("%s: %s", e.Key, e.Message))
}
return fmt.Errorf("s3: DeleteObjects partial failure: %s", strings.Join(msgs, "; "))
}
return nil
}
// ListObjects performs a single ListObjectsV2 request using the client's configured prefix.
// Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000).
func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxKeys int) (*ListResult, error) {
query := url.Values{"list-type": {"2"}}
if prefix := c.prefixForList(); prefix != "" {
query.Set("prefix", prefix)
}
if continuationToken != "" {
query.Set("continuation-token", continuationToken)
}
if maxKeys > 0 {
query.Set("max-keys", strconv.Itoa(maxKeys))
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil)
if err != nil {
return nil, fmt.Errorf("s3: ListObjects request: %w", err)
}
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("s3: ListObjects: %w", err)
}
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
resp.Body.Close()
if err != nil {
return nil, fmt.Errorf("s3: ListObjects read: %w", err)
}
if resp.StatusCode/100 != 2 {
return nil, parseErrorFromBytes(resp.StatusCode, respBody)
}
var result listObjectsV2Response
if err := xml.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("s3: ListObjects XML: %w", err)
}
objects := make([]Object, len(result.Contents))
for i, obj := range result.Contents {
objects[i] = Object(obj)
}
return &ListResult{
Objects: objects,
IsTruncated: result.IsTruncated,
NextContinuationToken: result.NextContinuationToken,
}, nil
}
// ListAllObjects returns all objects under the client's configured prefix by paginating through
// ListObjectsV2 results automatically. It stops after 10,000 pages as a safety valve.
func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) {
const maxPages = 10000
var all []Object
var token string
for page := 0; page < maxPages; page++ {
result, err := c.ListObjects(ctx, token, 0)
if err != nil {
return nil, err
}
all = append(all, result.Objects...)
if !result.IsTruncated {
return all, nil
}
token = result.NextContinuationToken
}
return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages)
}
// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD.
func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error {
fullKey := c.objectKey(key)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body)
if err != nil {
return fmt.Errorf("s3: PutObject request: %w", err)
}
req.ContentLength = size
c.signV4(req, unsignedPayload)
resp, err := c.httpClient().Do(req)
if err != nil {
return fmt.Errorf("s3: PutObject: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return parseError(resp)
}
return nil
}
// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize
// chunks, uploading each as a separate part. This allows uploading without knowing the total
// body size in advance.
func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error {
fullKey := c.objectKey(key)
// Step 1: Initiate multipart upload
uploadID, err := c.initiateMultipartUpload(ctx, fullKey)
if err != nil {
return err
}
// Step 2: Upload parts
var parts []completedPart
buf := make([]byte, partSize)
partNumber := 1
for {
n, err := io.ReadFull(body, buf)
if n > 0 {
etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n])
if uploadErr != nil {
c.abortMultipartUpload(ctx, fullKey, uploadID)
return uploadErr
}
parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag})
partNumber++
}
		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break
}
if err != nil {
c.abortMultipartUpload(ctx, fullKey, uploadID)
return fmt.Errorf("s3: PutObject read: %w", err)
}
}
// Step 3: Complete multipart upload
return c.completeMultipartUpload(ctx, fullKey, uploadID, parts)
}
// initiateMultipartUpload starts a new multipart upload and returns the upload ID.
func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) {
reqURL := c.objectURL(fullKey) + "?uploads"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil)
if err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err)
}
req.ContentLength = 0
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return "", parseError(resp)
}
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err)
}
var result initiateMultipartUploadResult
if err := xml.Unmarshal(respBody, &result); err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err)
}
return result.UploadID, nil
}
// uploadPart uploads a single part of a multipart upload and returns the ETag.
func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) {
reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data))
if err != nil {
return "", fmt.Errorf("s3: UploadPart request: %w", err)
}
req.ContentLength = int64(len(data))
c.signV4(req, unsignedPayload)
resp, err := c.httpClient().Do(req)
if err != nil {
return "", fmt.Errorf("s3: UploadPart: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return "", parseError(resp)
}
etag := resp.Header.Get("ETag")
return etag, nil
}
// completeMultipartUpload finalizes a multipart upload with the given parts.
func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error {
var body bytes.Buffer
body.WriteString("<CompleteMultipartUpload>")
for _, p := range parts {
fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag)
}
body.WriteString("</CompleteMultipartUpload>")
bodyBytes := body.Bytes()
payloadHash := sha256Hex(bodyBytes)
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
if err != nil {
return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err)
}
req.ContentLength = int64(len(bodyBytes))
req.Header.Set("Content-Type", "application/xml")
c.signV4(req, payloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return fmt.Errorf("s3: CompleteMultipartUpload: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return parseError(resp)
}
// Read response body to check for errors (S3 can return 200 with an error body)
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err)
}
// Check if the response contains an error
var errResp ErrorResponse
if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
errResp.StatusCode = resp.StatusCode
return &errResp
}
return nil
}
// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) {
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
if err != nil {
return
}
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return
}
resp.Body.Close()
}
// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256
// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads.
func (c *Client) signV4(req *http.Request, payloadHash string) {
now := time.Now().UTC()
datestamp := now.Format("20060102")
amzDate := now.Format("20060102T150405Z")
// Required headers
req.Header.Set("Host", c.hostHeader())
req.Header.Set("X-Amz-Date", amzDate)
req.Header.Set("X-Amz-Content-Sha256", payloadHash)
// Canonical headers (all headers we set, sorted by lowercase key)
signedKeys := make([]string, 0, len(req.Header))
canonHeaders := make(map[string]string, len(req.Header))
for k := range req.Header {
lk := strings.ToLower(k)
signedKeys = append(signedKeys, lk)
canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k))
}
sort.Strings(signedKeys)
signedHeadersStr := strings.Join(signedKeys, ";")
var chBuf strings.Builder
for _, k := range signedKeys {
chBuf.WriteString(k)
chBuf.WriteByte(':')
chBuf.WriteString(canonHeaders[k])
chBuf.WriteByte('\n')
}
// Canonical request
canonicalRequest := strings.Join([]string{
req.Method,
canonicalURI(req.URL),
canonicalQueryString(req.URL.Query()),
chBuf.String(),
signedHeadersStr,
payloadHash,
}, "\n")
// String to sign
credentialScope := datestamp + "/" + c.Region + "/s3/aws4_request"
stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest))
// Signing key
signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256(
[]byte("AWS4"+c.SecretKey), []byte(datestamp)),
[]byte(c.Region)),
[]byte("s3")),
[]byte("aws4_request"))
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
req.Header.Set("Authorization", fmt.Sprintf(
"AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
c.AccessKey, credentialScope, signedHeadersStr, signature,
))
}
func (c *Client) httpClient() *http.Client {
if c.HTTPClient != nil {
return c.HTTPClient
}
return http.DefaultClient
}
// objectKey prepends the configured prefix to the given key.
func (c *Client) objectKey(key string) string {
if c.Prefix != "" {
return c.Prefix + "/" + key
}
return key
}
// prefixForList returns the prefix to use in ListObjectsV2 requests,
// with a trailing slash so that only objects under the prefix directory are returned.
func (c *Client) prefixForList() string {
if c.Prefix != "" {
return c.Prefix + "/"
}
return ""
}
// bucketURL returns the base URL for bucket-level operations.
func (c *Client) bucketURL() string {
if c.PathStyle {
return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket)
}
return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint)
}
// objectURL returns the full URL for an object (key should already include the prefix).
// Each path segment is URI-encoded to handle special characters in keys.
func (c *Client) objectURL(key string) string {
segments := strings.Split(key, "/")
for i, seg := range segments {
segments[i] = uriEncode(seg)
}
return c.bucketURL() + "/" + strings.Join(segments, "/")
}
// hostHeader returns the value for the Host header.
func (c *Client) hostHeader() string {
if c.PathStyle {
return c.Endpoint
}
return c.Bucket + "." + c.Endpoint
}

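To show how the client above fits together end to end, here is a minimal, hedged usage sketch. The import path heckel.io/ntfy/v2/s3 is assumed from the module's other imports, and the credentials, bucket, and region are placeholders; the s3:// URL format is documented on ParseURL below.

package main

import (
	"context"
	"fmt"
	"io"
	"log"
	"strings"

	"heckel.io/ntfy/v2/s3"
)

func main() {
	// Placeholder credentials, bucket, and region.
	cfg, err := s3.ParseURL("s3://ACCESS_KEY:SECRET_KEY@my-bucket/attachments?region=us-east-1")
	if err != nil {
		log.Fatal(err)
	}
	client := s3.New(cfg)
	ctx := context.Background()
	// Bodies under the 5 MB part size go out as a single PUT; larger bodies
	// switch to multipart upload transparently.
	if err := client.PutObject(ctx, "hello.txt", strings.NewReader("hello world")); err != nil {
		log.Fatal(err)
	}
	body, size, err := client.GetObject(ctx, "hello.txt")
	if err != nil {
		log.Fatal(err)
	}
	defer body.Close()
	data, err := io.ReadAll(body)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("read %d bytes: %s\n", size, data)
	// DeleteObjects takes up to 1000 keys per call.
	if err := client.DeleteObjects(ctx, []string{"hello.txt"}); err != nil {
		log.Fatal(err)
	}
}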
860
s3/client_test.go Normal file
View File

@@ -0,0 +1,860 @@
package s3
import (
"bytes"
"context"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"sort"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/require"
)
// --- Mock S3 server ---
//
// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
type mockS3Server struct {
objects map[string][]byte // full key (bucket/key) -> body
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
nextID int // counter for generating upload IDs
mu sync.RWMutex
}
func newMockS3Server() (*httptest.Server, *mockS3Server) {
m := &mockS3Server{
objects: make(map[string][]byte),
uploads: make(map[string]map[int][]byte),
}
return httptest.NewTLSServer(m), m
}
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Path is /{bucket}[/{key...}]
path := strings.TrimPrefix(r.URL.Path, "/")
q := r.URL.Query()
switch {
case r.Method == http.MethodPut && q.Has("partNumber"):
m.handleUploadPart(w, r, path)
case r.Method == http.MethodPut:
m.handlePut(w, r, path)
case r.Method == http.MethodPost && q.Has("uploads"):
m.handleInitiateMultipart(w, r, path)
case r.Method == http.MethodPost && q.Has("uploadId"):
m.handleCompleteMultipart(w, r, path)
case r.Method == http.MethodDelete && q.Has("uploadId"):
m.handleAbortMultipart(w, r, path)
case r.Method == http.MethodGet && q.Get("list-type") == "2":
m.handleList(w, r, path)
case r.Method == http.MethodGet:
m.handleGet(w, r, path)
case r.Method == http.MethodPost && q.Has("delete"):
m.handleDelete(w, r, path)
default:
http.Error(w, "not implemented", http.StatusNotImplemented)
}
}
func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
m.mu.Lock()
m.objects[path] = body
m.mu.Unlock()
w.WriteHeader(http.StatusOK)
}
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
m.mu.Lock()
m.nextID++
uploadID := fmt.Sprintf("upload-%d", m.nextID)
m.uploads[uploadID] = make(map[int][]byte)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
}
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
var partNumber int
fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
parts[partNumber] = body
m.mu.Unlock()
etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
w.Header().Set("ETag", etag)
w.WriteHeader(http.StatusOK)
}
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
// Assemble parts in order
var assembled []byte
for i := 1; i <= len(parts); i++ {
assembled = append(assembled, parts[i]...)
}
m.objects[path] = assembled
delete(m.uploads, uploadID)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
}
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
delete(m.uploads, uploadID)
m.mu.Unlock()
w.WriteHeader(http.StatusNoContent)
}
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
m.mu.RLock()
body, ok := m.objects[path]
m.mu.RUnlock()
if !ok {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`))
return
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
w.WriteHeader(http.StatusOK)
w.Write(body)
}
type listObjectsResponse struct {
XMLName xml.Name `xml:"ListBucketResult"`
Contents []listObject `xml:"Contents"`
// Pagination support
IsTruncated bool `xml:"IsTruncated"`
NextContinuationToken string `xml:"NextContinuationToken"`
}
func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) {
// bucketPath is just the bucket name
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var req struct {
Objects []struct {
Key string `xml:"Key"`
} `xml:"Object"`
}
if err := xml.Unmarshal(body, &req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
m.mu.Lock()
for _, obj := range req.Objects {
delete(m.objects, bucketPath+"/"+obj.Key)
}
m.mu.Unlock()
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></DeleteResult>`))
}
func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) {
prefix := r.URL.Query().Get("prefix")
contToken := r.URL.Query().Get("continuation-token")
m.mu.RLock()
var allKeys []string
for key := range m.objects {
objKey := strings.TrimPrefix(key, bucketPath+"/")
if objKey == key {
continue // different bucket
}
if prefix == "" || strings.HasPrefix(objKey, prefix) {
allKeys = append(allKeys, objKey)
}
}
m.mu.RUnlock()
sort.Strings(allKeys)
// Simple continuation token: it's the key to start after
startIdx := 0
if contToken != "" {
for i, k := range allKeys {
if k == contToken {
startIdx = i + 1
break
}
}
}
maxKeys := 1000
if mk := r.URL.Query().Get("max-keys"); mk != "" {
fmt.Sscanf(mk, "%d", &maxKeys)
}
endIdx := startIdx + maxKeys
truncated := false
nextToken := ""
if endIdx < len(allKeys) {
truncated = true
nextToken = allKeys[endIdx-1]
allKeys = allKeys[startIdx:endIdx]
} else {
allKeys = allKeys[startIdx:]
}
m.mu.RLock()
var contents []listObject
for _, objKey := range allKeys {
body := m.objects[bucketPath+"/"+objKey]
contents = append(contents, listObject{Key: objKey, Size: int64(len(body))})
}
m.mu.RUnlock()
resp := listObjectsResponse{
Contents: contents,
IsTruncated: truncated,
NextContinuationToken: nextToken,
}
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
xml.NewEncoder(w).Encode(resp)
}
func (m *mockS3Server) objectCount() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.objects)
}
// --- Helper to create a test client pointing at mock server ---
func newTestClient(server *httptest.Server, bucket, prefix string) *Client {
// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
host := strings.TrimPrefix(server.URL, "https://")
return &Client{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: host,
Bucket: bucket,
Prefix: prefix,
PathStyle: true,
HTTPClient: server.Client(),
}
}
// --- URL parsing tests ---
func TestParseURL_Success(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/attachments?region=us-east-1")
require.Nil(t, err)
require.Equal(t, "my-bucket", cfg.Bucket)
require.Equal(t, "attachments", cfg.Prefix)
require.Equal(t, "us-east-1", cfg.Region)
require.Equal(t, "AKID", cfg.AccessKey)
require.Equal(t, "SECRET", cfg.SecretKey)
require.Equal(t, "s3.us-east-1.amazonaws.com", cfg.Endpoint)
require.False(t, cfg.PathStyle)
}
func TestParseURL_NoPrefix(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1")
require.Nil(t, err)
require.Equal(t, "my-bucket", cfg.Bucket)
require.Equal(t, "", cfg.Prefix)
}
func TestParseURL_WithEndpoint(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/prefix?region=us-east-1&endpoint=https://s3.example.com")
require.Nil(t, err)
require.Equal(t, "my-bucket", cfg.Bucket)
require.Equal(t, "prefix", cfg.Prefix)
require.Equal(t, "s3.example.com", cfg.Endpoint)
require.True(t, cfg.PathStyle)
}
func TestParseURL_EndpointHTTP(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1&endpoint=http://localhost:9000")
require.Nil(t, err)
require.Equal(t, "localhost:9000", cfg.Endpoint)
require.True(t, cfg.PathStyle)
}
func TestParseURL_EndpointTrailingSlash(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1&endpoint=https://s3.example.com/")
require.Nil(t, err)
require.Equal(t, "s3.example.com", cfg.Endpoint)
}
func TestParseURL_NestedPrefix(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/a/b/c?region=us-east-1")
require.Nil(t, err)
require.Equal(t, "my-bucket", cfg.Bucket)
require.Equal(t, "a/b/c", cfg.Prefix)
}
func TestParseURL_MissingRegion(t *testing.T) {
_, err := ParseURL("s3://AKID:SECRET@my-bucket")
require.Error(t, err)
require.Contains(t, err.Error(), "region")
}
func TestParseURL_MissingCredentials(t *testing.T) {
_, err := ParseURL("s3://my-bucket?region=us-east-1")
require.Error(t, err)
require.Contains(t, err.Error(), "access key")
}
func TestParseURL_MissingSecretKey(t *testing.T) {
_, err := ParseURL("s3://AKID@my-bucket?region=us-east-1")
require.Error(t, err)
require.Contains(t, err.Error(), "secret key")
}
func TestParseURL_WrongScheme(t *testing.T) {
_, err := ParseURL("http://AKID:SECRET@my-bucket?region=us-east-1")
require.Error(t, err)
require.Contains(t, err.Error(), "scheme")
}
func TestParseURL_EmptyBucket(t *testing.T) {
_, err := ParseURL("s3://AKID:SECRET@?region=us-east-1")
require.Error(t, err)
require.Contains(t, err.Error(), "bucket")
}
// --- Unit tests: URL construction ---
func TestClient_BucketURL_PathStyle(t *testing.T) {
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
require.Equal(t, "https://s3.example.com/my-bucket", c.bucketURL())
}
func TestClient_BucketURL_VirtualHosted(t *testing.T) {
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.bucketURL())
}
func TestClient_ObjectURL_PathStyle(t *testing.T) {
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.objectURL("prefix/obj"))
}
func TestClient_ObjectURL_VirtualHosted(t *testing.T) {
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.objectURL("prefix/obj"))
}
func TestClient_HostHeader_PathStyle(t *testing.T) {
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
require.Equal(t, "s3.example.com", c.hostHeader())
}
func TestClient_HostHeader_VirtualHosted(t *testing.T) {
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.hostHeader())
}
func TestClient_ObjectKey(t *testing.T) {
c := &Client{Prefix: "attachments"}
require.Equal(t, "attachments/file123", c.objectKey("file123"))
c2 := &Client{Prefix: ""}
require.Equal(t, "file123", c2.objectKey("file123"))
}
func TestClient_PrefixForList(t *testing.T) {
c := &Client{Prefix: "attachments"}
require.Equal(t, "attachments/", c.prefixForList())
c2 := &Client{Prefix: ""}
require.Equal(t, "", c2.prefixForList())
}
// --- Integration tests using mock S3 server ---
func TestClient_PutGetObject(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// Put
err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"))
require.Nil(t, err)
// Get
reader, size, err := client.GetObject(ctx, "test-key")
require.Nil(t, err)
require.Equal(t, int64(11), size)
data, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, "hello world", string(data))
}
func TestClient_PutGetObject_WithPrefix(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "pfx")
ctx := context.Background()
err := client.PutObject(ctx, "test-key", strings.NewReader("hello"))
require.Nil(t, err)
reader, _, err := client.GetObject(ctx, "test-key")
require.Nil(t, err)
data, _ := io.ReadAll(reader)
reader.Close()
require.Equal(t, "hello", string(data))
}
func TestClient_GetObject_NotFound(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
_, _, err := client.GetObject(context.Background(), "nonexistent")
require.Error(t, err)
var errResp *ErrorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 404, errResp.StatusCode)
require.Equal(t, "NoSuchKey", errResp.Code)
}
func TestClient_DeleteObjects(t *testing.T) {
server, mock := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// Put several objects
for i := 0; i < 5; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")))
require.Nil(t, err)
}
require.Equal(t, 5, mock.objectCount())
// Delete some
err := client.DeleteObjects(ctx, []string{"key-1", "key-3"})
require.Nil(t, err)
require.Equal(t, 3, mock.objectCount())
// Verify deleted ones are gone
_, _, err = client.GetObject(ctx, "key-1")
require.Error(t, err)
_, _, err = client.GetObject(ctx, "key-3")
require.Error(t, err)
// Verify remaining ones are still there
reader, _, err := client.GetObject(ctx, "key-0")
require.Nil(t, err)
reader.Close()
}
func TestClient_ListObjects(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
ctx := context.Background()
// Client with prefix "pfx": list should only return objects under pfx/
client := newTestClient(server, "my-bucket", "pfx")
for i := 0; i < 3; i++ {
err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
// Also put an object outside the prefix using a no-prefix client
clientNoPrefix := newTestClient(server, "my-bucket", "")
err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")))
require.Nil(t, err)
// List with prefix client: should only see 3
result, err := client.ListObjects(ctx, "", 0)
require.Nil(t, err)
require.Len(t, result.Objects, 3)
require.False(t, result.IsTruncated)
// List with no-prefix client: should see all 4
result, err = clientNoPrefix.ListObjects(ctx, "", 0)
require.Nil(t, err)
require.Len(t, result.Objects, 4)
}
func TestClient_ListObjects_Pagination(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// Put 5 objects
for i := 0; i < 5; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
// List with max-keys=2
result, err := client.ListObjects(ctx, "", 2)
require.Nil(t, err)
require.Len(t, result.Objects, 2)
require.True(t, result.IsTruncated)
require.NotEmpty(t, result.NextContinuationToken)
// Get next page
result2, err := client.ListObjects(ctx, result.NextContinuationToken, 2)
require.Nil(t, err)
require.Len(t, result2.Objects, 2)
require.True(t, result2.IsTruncated)
// Get last page
result3, err := client.ListObjects(ctx, result2.NextContinuationToken, 2)
require.Nil(t, err)
require.Len(t, result3.Objects, 1)
require.False(t, result3.IsTruncated)
}
func TestClient_ListAllObjects(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "pfx")
ctx := context.Background()
for i := 0; i < 10; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
objects, err := client.ListAllObjects(ctx)
require.Nil(t, err)
require.Len(t, objects, 10)
}
func TestClient_PutObject_LargeBody(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// 1 MB object
data := make([]byte, 1024*1024)
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, "large", bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "large")
require.Nil(t, err)
require.Equal(t, int64(1024*1024), size)
got, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, data, got)
}
func TestClient_PutObject_ChunkedUpload(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// 12 MB object, exceeds 5 MB partSize, triggers multipart upload path
data := make([]byte, 12*1024*1024)
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, "multipart", bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "multipart")
require.Nil(t, err)
require.Equal(t, int64(12*1024*1024), size)
got, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, data, got)
}
func TestClient_PutObject_ExactPartSize(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
	// Exactly 5 MB (partSize): ReadFull fills the buffer with no error, so this takes the
	// multipart path and uploads the body as a single part
data := make([]byte, 5*1024*1024)
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, "exact", bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "exact")
require.Nil(t, err)
require.Equal(t, int64(5*1024*1024), size)
got, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, data, got)
}
func TestClient_PutObject_NestedKey(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"))
require.Nil(t, err)
reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt")
require.Nil(t, err)
data, _ := io.ReadAll(reader)
reader.Close()
require.Equal(t, "nested", string(data))
}
// --- Scale test: 20k objects (ntfy-adjacent) ---
func TestClient_ListAllObjects_20k(t *testing.T) {
if testing.Short() {
t.Skip("skipping 20k object test in short mode")
}
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "attachments")
ctx := context.Background()
const numObjects = 20000
const batchSize = 500
	// Insert 20k one-byte objects (each PutObject is still one request; the nested loops only structure the work)
for batch := 0; batch < numObjects/batchSize; batch++ {
for i := 0; i < batchSize; i++ {
idx := batch*batchSize + i
key := fmt.Sprintf("%08d", idx)
err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
}
// List all 20k objects with pagination
objects, err := client.ListAllObjects(ctx)
require.Nil(t, err)
require.Len(t, objects, numObjects)
// Verify total size
var totalSize int64
for _, obj := range objects {
totalSize += obj.Size
}
require.Equal(t, int64(numObjects), totalSize)
// Delete 1000 objects (simulating attachment expiry cleanup)
keys := make([]string, 1000)
for i := range keys {
keys[i] = fmt.Sprintf("%08d", i)
}
err = client.DeleteObjects(ctx, keys)
require.Nil(t, err)
// List again: should have 19000
objects, err = client.ListAllObjects(ctx)
require.Nil(t, err)
require.Len(t, objects, numObjects-1000)
}
// --- Real S3 integration test ---
//
// Set the following environment variables to run this test against a real S3 bucket:
//
// S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET
//
// Optional:
//
// S3_ENDPOINT: host[:port] for S3-compatible providers (e.g. "nyc3.digitaloceanspaces.com")
// S3_PATH_STYLE: set to "true" for path-style addressing
// S3_PREFIX: key prefix to use (default: "ntfy-s3-test")
func TestClient_RealBucket(t *testing.T) {
accessKey := os.Getenv("S3_ACCESS_KEY")
secretKey := os.Getenv("S3_SECRET_KEY")
region := os.Getenv("S3_REGION")
bucket := os.Getenv("S3_BUCKET")
if accessKey == "" || secretKey == "" || region == "" || bucket == "" {
t.Skip("skipping real S3 test: set S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET")
}
endpoint := os.Getenv("S3_ENDPOINT")
if endpoint == "" {
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region)
}
pathStyle := os.Getenv("S3_PATH_STYLE") == "true"
prefix := os.Getenv("S3_PREFIX")
if prefix == "" {
prefix = "ntfy-s3-test"
}
client := &Client{
AccessKey: accessKey,
SecretKey: secretKey,
Region: region,
Endpoint: endpoint,
Bucket: bucket,
Prefix: prefix,
PathStyle: pathStyle,
}
ctx := context.Background()
// Clean up any leftover objects from previous runs
existing, err := client.ListAllObjects(ctx)
require.Nil(t, err)
if len(existing) > 0 {
keys := make([]string, len(existing))
for i, obj := range existing {
// Strip the prefix since DeleteObjects will re-add it
keys[i] = strings.TrimPrefix(obj.Key, prefix+"/")
}
// Batch delete in groups of 1000
for i := 0; i < len(keys); i += 1000 {
end := i + 1000
if end > len(keys) {
end = len(keys)
}
err := client.DeleteObjects(ctx, keys[i:end])
require.Nil(t, err)
}
}
t.Run("PutGetDelete", func(t *testing.T) {
key := "test-object"
content := "hello from ntfy s3 test"
// Put
err := client.PutObject(ctx, key, strings.NewReader(content))
require.Nil(t, err)
// Get
reader, size, err := client.GetObject(ctx, key)
require.Nil(t, err)
require.Equal(t, int64(len(content)), size)
data, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, content, string(data))
// Delete
err = client.DeleteObjects(ctx, []string{key})
require.Nil(t, err)
// Get after delete should fail
_, _, err = client.GetObject(ctx, key)
require.Error(t, err)
var errResp *ErrorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 404, errResp.StatusCode)
})
t.Run("ListObjects", func(t *testing.T) {
// Use a sub-prefix client for isolation
listClient := &Client{
AccessKey: accessKey,
SecretKey: secretKey,
Region: region,
Endpoint: endpoint,
Bucket: bucket,
Prefix: prefix + "/list-test",
PathStyle: pathStyle,
}
// Put 10 objects
for i := 0; i < 10; i++ {
err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"))
require.Nil(t, err)
}
// List
objects, err := listClient.ListAllObjects(ctx)
require.Nil(t, err)
require.Len(t, objects, 10)
// Clean up
keys := make([]string, 10)
for i := range keys {
keys[i] = fmt.Sprintf("%d", i)
}
err = listClient.DeleteObjects(ctx, keys)
require.Nil(t, err)
})
t.Run("LargeObject", func(t *testing.T) {
key := "large-object"
data := make([]byte, 5*1024*1024) // 5 MB
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, key, bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, key)
require.Nil(t, err)
require.Equal(t, int64(len(data)), size)
got, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, data, got)
err = client.DeleteObjects(ctx, []string{key})
require.Nil(t, err)
})
}

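DeleteObjects accepts at most 1000 keys per request and leaves batching to the caller, as the cleanup loop in TestClient_RealBucket above demonstrates. Factored into a helper — a sketch only, assuming it sits in the s3 package with "context" imported — that pattern might look like:

// deleteAllKeys removes keys in batches of 1000, the maximum that a single
// DeleteObjects request allows.
func deleteAllKeys(ctx context.Context, c *Client, keys []string) error {
	const batchSize = 1000
	for i := 0; i < len(keys); i += batchSize {
		end := i + batchSize
		if end > len(keys) {
			end = len(keys)
		}
		if err := c.DeleteObjects(ctx, keys[i:end]); err != nil {
			return err
		}
	}
	return nil
}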
76
s3/types.go Normal file
View File

@@ -0,0 +1,76 @@
package s3
import "fmt"
// Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string.
type Config struct {
Endpoint string // host[:port] only, e.g. "s3.us-east-1.amazonaws.com"
PathStyle bool
Bucket string
Prefix string
Region string
AccessKey string
SecretKey string
}
// Object represents an S3 object returned by list operations.
type Object struct {
Key string
Size int64
}
// ListResult holds the response from a ListObjectsV2 call.
type ListResult struct {
Objects []Object
IsTruncated bool
NextContinuationToken string
}
// ErrorResponse is returned when S3 responds with a non-2xx status code.
type ErrorResponse struct {
StatusCode int
Code string `xml:"Code"`
Message string `xml:"Message"`
Body string `xml:"-"` // raw response body
}
func (e *ErrorResponse) Error() string {
if e.Code != "" {
return fmt.Sprintf("s3: %s (HTTP %d): %s", e.Code, e.StatusCode, e.Message)
}
return fmt.Sprintf("s3: HTTP %d: %s", e.StatusCode, e.Body)
}
// listObjectsV2Response is the XML response from S3 ListObjectsV2
type listObjectsV2Response struct {
Contents []listObject `xml:"Contents"`
IsTruncated bool `xml:"IsTruncated"`
NextContinuationToken string `xml:"NextContinuationToken"`
}
type listObject struct {
Key string `xml:"Key"`
Size int64 `xml:"Size"`
}
// deleteResult is the XML response from S3 DeleteObjects
type deleteResult struct {
Errors []deleteError `xml:"Error"`
}
type deleteError struct {
Key string `xml:"Key"`
Code string `xml:"Code"`
Message string `xml:"Message"`
}
// initiateMultipartUploadResult is the XML response from S3 InitiateMultipartUpload
type initiateMultipartUploadResult struct {
UploadID string `xml:"UploadId"`
}
// completedPart represents a successfully uploaded part for CompleteMultipartUpload
type completedPart struct {
PartNumber int
ETag string
}

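Because every non-2xx response surfaces as an *ErrorResponse, callers can branch on the S3 error code with errors.As, exactly as TestClient_GetObject_NotFound does. A short caller-side sketch, assuming the s3 package with "errors" imported (not part of this diff):

// isNotFound reports whether err is an S3 NoSuchKey error, so callers can
// treat a missing object differently from other failures.
func isNotFound(err error) bool {
	var errResp *ErrorResponse
	return errors.As(err, &errResp) && errResp.Code == "NoSuchKey"
}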
166
s3/util.go Normal file
View File

@@ -0,0 +1,166 @@
package s3
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strings"
)
const (
// SHA-256 hash of the empty string, used as the payload hash for bodiless requests
emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
// Sent as the payload hash for streaming uploads where the body is not buffered in memory
unsignedPayload = "UNSIGNED-PAYLOAD"
// maxResponseBytes caps the size of S3 response bodies we read into memory (10 MB)
maxResponseBytes = 10 * 1024 * 1024
// partSize is the size of each part for multipart uploads (5 MB). This is also the threshold
// above which PutObject switches from a simple PUT to multipart upload. S3 requires a minimum
// part size of 5 MB for all parts except the last.
partSize = 5 * 1024 * 1024
)
// ParseURL parses an S3 URL of the form:
//
// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
//
// When endpoint is specified, path-style addressing is enabled automatically.
func ParseURL(s3URL string) (*Config, error) {
u, err := url.Parse(s3URL)
if err != nil {
return nil, fmt.Errorf("s3: invalid URL: %w", err)
}
if u.Scheme != "s3" {
return nil, fmt.Errorf("s3: URL scheme must be 's3', got '%s'", u.Scheme)
}
if u.Host == "" {
return nil, fmt.Errorf("s3: bucket name must be specified as host")
}
bucket := u.Host
prefix := strings.TrimPrefix(u.Path, "/")
accessKey := u.User.Username()
secretKey, _ := u.User.Password()
if accessKey == "" || secretKey == "" {
return nil, fmt.Errorf("s3: access key and secret key must be specified in URL")
}
region := u.Query().Get("region")
if region == "" {
return nil, fmt.Errorf("s3: region query parameter is required")
}
endpointParam := u.Query().Get("endpoint")
var endpoint string
var pathStyle bool
if endpointParam != "" {
// Custom endpoint: strip scheme prefix to extract host[:port]
ep := strings.TrimRight(endpointParam, "/")
ep = strings.TrimPrefix(ep, "https://")
ep = strings.TrimPrefix(ep, "http://")
endpoint = ep
pathStyle = true
} else {
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region)
pathStyle = false
}
return &Config{
Endpoint: endpoint,
PathStyle: pathStyle,
Bucket: bucket,
Prefix: prefix,
Region: region,
AccessKey: accessKey,
SecretKey: secretKey,
}, nil
}
// parseError reads an S3 error response and returns an *ErrorResponse.
func parseError(resp *http.Response) error {
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return fmt.Errorf("s3: reading error response: %w", err)
}
return parseErrorFromBytes(resp.StatusCode, body)
}
func parseErrorFromBytes(statusCode int, body []byte) error {
errResp := &ErrorResponse{
StatusCode: statusCode,
Body: string(body),
}
// Try to parse XML error; if it fails, we still have StatusCode and Body
_ = xml.Unmarshal(body, errResp)
return errResp
}
// canonicalURI returns the URI-encoded path for the canonical request. Each path segment is
// percent-encoded per RFC 3986; forward slashes are preserved.
func canonicalURI(u *url.URL) string {
p := u.Path
if p == "" {
return "/"
}
segments := strings.Split(p, "/")
for i, seg := range segments {
segments[i] = uriEncode(seg)
}
return strings.Join(segments, "/")
}
// canonicalQueryString builds the query string for the canonical request. Keys and values
// are URI-encoded per RFC 3986 (using %20, not +) and sorted lexically by key.
func canonicalQueryString(values url.Values) string {
if len(values) == 0 {
return ""
}
keys := make([]string, 0, len(values))
for k := range values {
keys = append(keys, k)
}
sort.Strings(keys)
var pairs []string
for _, k := range keys {
ek := uriEncode(k)
vs := make([]string, len(values[k]))
copy(vs, values[k])
sort.Strings(vs)
for _, v := range vs {
pairs = append(pairs, ek+"="+uriEncode(v))
}
}
return strings.Join(pairs, "&")
}
// uriEncode percent-encodes a string per RFC 3986, encoding everything except unreserved
// characters (A-Z a-z 0-9 - _ . ~).
func uriEncode(s string) string {
var buf strings.Builder
for i := 0; i < len(s); i++ {
b := s[i]
if (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9') ||
b == '-' || b == '_' || b == '.' || b == '~' {
buf.WriteByte(b)
} else {
fmt.Fprintf(&buf, "%%%02X", b)
}
}
return buf.String()
}
func sha256Hex(data []byte) string {
h := sha256.Sum256(data)
return hex.EncodeToString(h[:])
}
func hmacSHA256(key, data []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(data)
return h.Sum(nil)
}

181
s3/util_test.go Normal file
View File

@@ -0,0 +1,181 @@
package s3
import (
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/require"
)
func TestURIEncode(t *testing.T) {
// Unreserved characters are not encoded
require.Equal(t, "abcdefghijklmnopqrstuvwxyz", uriEncode("abcdefghijklmnopqrstuvwxyz"))
require.Equal(t, "ABCDEFGHIJKLMNOPQRSTUVWXYZ", uriEncode("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
require.Equal(t, "0123456789", uriEncode("0123456789"))
require.Equal(t, "-_.~", uriEncode("-_.~"))
// Spaces use %20, not +
require.Equal(t, "hello%20world", uriEncode("hello world"))
// Slashes are encoded (canonicalURI handles slash splitting separately)
require.Equal(t, "a%2Fb", uriEncode("a/b"))
// Special characters
require.Equal(t, "%2B", uriEncode("+"))
require.Equal(t, "%3D", uriEncode("="))
require.Equal(t, "%26", uriEncode("&"))
require.Equal(t, "%40", uriEncode("@"))
require.Equal(t, "%23", uriEncode("#"))
// Mixed
require.Equal(t, "test~file-name_1.txt", uriEncode("test~file-name_1.txt"))
require.Equal(t, "key%20with%20spaces%2Fand%2Fslashes", uriEncode("key with spaces/and/slashes"))
// Empty string
require.Equal(t, "", uriEncode(""))
}
func TestCanonicalURI(t *testing.T) {
// Simple path
u, _ := url.Parse("https://example.com/bucket/key")
require.Equal(t, "/bucket/key", canonicalURI(u))
// Root path
u, _ = url.Parse("https://example.com/")
require.Equal(t, "/", canonicalURI(u))
// Empty path
u, _ = url.Parse("https://example.com")
require.Equal(t, "/", canonicalURI(u))
// Path with special characters
u, _ = url.Parse("https://example.com/bucket/key%20with%20spaces")
require.Equal(t, "/bucket/key%20with%20spaces", canonicalURI(u))
// Nested path
u, _ = url.Parse("https://example.com/bucket/a/b/c/file.txt")
require.Equal(t, "/bucket/a/b/c/file.txt", canonicalURI(u))
}
func TestCanonicalQueryString(t *testing.T) {
// Multiple keys sorted alphabetically
vals := url.Values{
"prefix": {"test/"},
"list-type": {"2"},
}
require.Equal(t, "list-type=2&prefix=test%2F", canonicalQueryString(vals))
// Empty values
require.Equal(t, "", canonicalQueryString(url.Values{}))
// Single key
require.Equal(t, "key=value", canonicalQueryString(url.Values{"key": {"value"}}))
// Key with multiple values (sorted)
vals = url.Values{"key": {"b", "a"}}
require.Equal(t, "key=a&key=b", canonicalQueryString(vals))
// Keys requiring encoding
vals = url.Values{"continuation-token": {"abc+def"}}
require.Equal(t, "continuation-token=abc%2Bdef", canonicalQueryString(vals))
}
func TestSHA256Hex(t *testing.T) {
// SHA-256 of empty string
require.Equal(t, emptyPayloadHash, sha256Hex([]byte("")))
// SHA-256 of known value
require.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", sha256Hex([]byte("hello")))
}
func TestHmacSHA256(t *testing.T) {
	// HMAC-SHA256("key", "message"): check output size, determinism, and key sensitivity
result := hmacSHA256([]byte("key"), []byte("message"))
require.Len(t, result, 32) // SHA-256 produces 32 bytes
require.NotEqual(t, make([]byte, 32), result)
// Same inputs should produce same output
result2 := hmacSHA256([]byte("key"), []byte("message"))
require.Equal(t, result, result2)
// Different inputs should produce different output
result3 := hmacSHA256([]byte("different-key"), []byte("message"))
require.NotEqual(t, result, result3)
}
func TestSignV4_SetsRequiredHeaders(t *testing.T) {
c := &Client{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: "s3.us-east-1.amazonaws.com",
Bucket: "my-bucket",
}
req, _ := http.NewRequest(http.MethodGet, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
c.signV4(req, emptyPayloadHash)
// All required SigV4 headers must be set
require.NotEmpty(t, req.Header.Get("Host"))
require.NotEmpty(t, req.Header.Get("X-Amz-Date"))
require.Equal(t, emptyPayloadHash, req.Header.Get("X-Amz-Content-Sha256"))
// Authorization header must have correct format
auth := req.Header.Get("Authorization")
require.Contains(t, auth, "AWS4-HMAC-SHA256")
require.Contains(t, auth, "Credential=AKID/")
require.Contains(t, auth, "/us-east-1/s3/aws4_request")
require.Contains(t, auth, "SignedHeaders=")
require.Contains(t, auth, "Signature=")
}
func TestSignV4_UnsignedPayload(t *testing.T) {
c := &Client{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: "s3.us-east-1.amazonaws.com",
Bucket: "my-bucket",
}
req, _ := http.NewRequest(http.MethodPut, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
c.signV4(req, unsignedPayload)
require.Equal(t, unsignedPayload, req.Header.Get("X-Amz-Content-Sha256"))
}
func TestSignV4_DifferentRegions(t *testing.T) {
c1 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"}
c2 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", Bucket: "b"}
req1, _ := http.NewRequest(http.MethodGet, "https://b.s3.us-east-1.amazonaws.com/key", nil)
c1.signV4(req1, emptyPayloadHash)
req2, _ := http.NewRequest(http.MethodGet, "https://b.s3.eu-west-1.amazonaws.com/key", nil)
c2.signV4(req2, emptyPayloadHash)
// Different regions should produce different signatures
require.NotEqual(t, req1.Header.Get("Authorization"), req2.Header.Get("Authorization"))
}
func TestParseError_XMLResponse(t *testing.T) {
xmlBody := []byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`)
err := parseErrorFromBytes(404, xmlBody)
var errResp *ErrorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 404, errResp.StatusCode)
require.Equal(t, "NoSuchKey", errResp.Code)
require.Equal(t, "The specified key does not exist.", errResp.Message)
}
func TestParseError_NonXMLResponse(t *testing.T) {
err := parseErrorFromBytes(500, []byte("internal server error"))
var errResp *ErrorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 500, errResp.StatusCode)
require.Equal(t, "", errResp.Code) // XML parsing failed, no code
require.Contains(t, errResp.Body, "internal server error")
}

View File

@@ -112,6 +112,7 @@ type Config struct {
AuthBcryptCost int
AuthStatsQueueWriterInterval time.Duration
AttachmentCacheDir string
AttachmentS3URL string
AttachmentTotalSizeLimit int64
AttachmentFileSizeLimit int64
AttachmentExpiryDuration time.Duration

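Assuming ntfy's usual kebab-case mapping of Config fields to server.yml keys (AttachmentCacheDir → attachment-cache-dir), the new field would presumably surface as attachment-s3-url, taking the URL format documented on ParseURL, e.g. s3://ACCESS_KEY:SECRET_KEY@my-bucket/attachments?region=us-east-1. The key name is an assumption; only the Config field itself appears in this diff.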
View File

@@ -142,6 +142,7 @@ var (
errHTTPBadRequestTemplateFileNotFound = &errHTTP{40047, http.StatusBadRequest, "invalid request: template file not found", "https://ntfy.sh/docs/publish/#message-templating", nil}
errHTTPBadRequestTemplateFileInvalid = &errHTTP{40048, http.StatusBadRequest, "invalid request: template file invalid", "https://ntfy.sh/docs/publish/#message-templating", nil}
errHTTPBadRequestSequenceIDInvalid = &errHTTP{40049, http.StatusBadRequest, "invalid request: sequence ID invalid", "https://ntfy.sh/docs/publish/#updating-deleting-notifications", nil}
errHTTPBadRequestEmailAddressInvalid = &errHTTP{40050, http.StatusBadRequest, "invalid request: invalid e-mail address", "https://ntfy.sh/docs/publish/#e-mail-notifications", nil}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", "", nil}
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication", nil}
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication", nil}

View File

@@ -1,76 +0,0 @@
package server
import (
"bytes"
"fmt"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/util"
"os"
"strings"
"testing"
)
var (
oneKilobyteArray = make([]byte, 1024)
)
func TestFileCache_Write_Success(t *testing.T) {
dir, c := newTestFileCache(t)
size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
require.Equal(t, int64(11), c.Size())
require.Equal(t, int64(10229), c.Remaining())
}
func TestFileCache_Write_Remove_Success(t *testing.T) {
dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024)
for i := 0; i < 10; i++ { // 10x999 = 9990
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
require.Nil(t, err)
require.Equal(t, int64(999), size)
}
require.Equal(t, int64(9990), c.Size())
require.Equal(t, int64(250), c.Remaining())
require.FileExists(t, dir+"/abcdefghijk1")
require.FileExists(t, dir+"/abcdefghijk5")
require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5"))
require.NoFileExists(t, dir+"/abcdefghijk1")
require.NoFileExists(t, dir+"/abcdefghijk5")
require.Equal(t, int64(7992), c.Size())
require.Equal(t, int64(2248), c.Remaining())
}
func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) {
dir, c := newTestFileCache(t)
for i := 0; i < 10; i++ {
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
require.Nil(t, err)
require.Equal(t, int64(1024), size)
}
_, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkX")
}
func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
dir, c := newTestFileCache(t)
_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkl")
}
func newTestFileCache(t *testing.T) (dir string, cache *fileCache) {
dir = t.TempDir()
cache, err := newFileCache(dir, 10*1024)
require.Nil(t, err)
return dir, cache
}
func readFile(t *testing.T, f string) string {
b, err := os.ReadFile(f)
require.Nil(t, err)
return string(b)
}

View File

@@ -24,7 +24,6 @@ const (
tagSMTP = "smtp" // Receive email
tagEmail = "email" // Send email
tagTwilio = "twilio"
tagFileCache = "file_cache"
tagMessageCache = "message_cache"
tagStripe = "stripe"
tagAccount = "account"
@@ -36,7 +35,7 @@ const (
)
var (
normalErrorCodes = []int{http.StatusNotFound, http.StatusBadRequest, http.StatusTooManyRequests, http.StatusUnauthorized, http.StatusForbidden, http.StatusInsufficientStorage}
normalErrorCodes = []int{http.StatusNotFound, http.StatusBadRequest, http.StatusTooManyRequests, http.StatusUnauthorized, http.StatusForbidden, http.StatusInsufficientStorage, http.StatusRequestEntityTooLarge}
rateLimitingErrorCodes = []int{http.StatusTooManyRequests, http.StatusRequestEntityTooLarge}
)

View File

@@ -32,6 +32,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/attachment"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/log"
@@ -64,7 +65,7 @@ type Server struct {
userManager *user.Manager // Might be nil!
messageCache *message.Cache // Database that stores the messages
webPush *webpush.Store // Database that stores web push subscriptions
fileCache *fileCache // File system based cache that stores attachments
fileCache attachment.Store // Attachment store (file system or S3)
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set
@@ -122,6 +123,7 @@ var (
fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
urlRegex = regexp.MustCompile(`^https?://`)
phoneNumberRegex = regexp.MustCompile(`^\+\d{1,100}$`)
emailAddressRegex = regexp.MustCompile(`^[^\s,;]+@[^\s,;]+$`)
//go:embed site
webFs embed.FS
@@ -227,12 +229,9 @@ func New(conf *Config) (*Server, error) {
if err != nil {
return nil, err
}
var fileCache *fileCache
if conf.AttachmentCacheDir != "" {
fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
if err != nil {
return nil, err
}
fileCache, err := createAttachmentStore(conf)
if err != nil {
return nil, err
}
var userManager *user.Manager
if conf.AuthFile != "" || pool != nil {
@@ -301,6 +300,15 @@ func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) {
return message.NewMemStore()
}
func createAttachmentStore(conf *Config) (attachment.Store, error) {
if conf.AttachmentS3URL != "" {
return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit)
} else if conf.AttachmentCacheDir != "" {
return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
}
return nil, nil
}
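For orientation, a minimal sketch of how the selected store is exercised elsewhere in this changeset — Write with optional limiters, Read returning a reader plus size, Size for metrics. The path, limits, and ID below are illustrative only; returning nil, nil when neither backend is configured is what lets the handlers treat a nil fileCache as attachments-disabled:
// Sketch only (identifiers from this changeset); in is any io.Reader, values are examples
store, err := attachment.NewFileStore("/var/cache/ntfy/attachments", 5*1024*1024*1024) // 5G total
if err != nil {
	return err
}
if _, err := store.Write("abcdefghijkl", in, util.NewFixedLimiter(15*1024*1024)); err != nil {
	return err // util.ErrLimitReached if the per-file or total limit is hit
}
r, size, err := store.Read("abcdefghijkl") // caller must close r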
// Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
// a manager go routine to print stats and prune messages.
func (s *Server) Run() error {
@@ -752,7 +760,7 @@ func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request, _ *visitor)
// Before streaming the file to a client, it locates the uploader (m.Sender or m.User) in the message cache, so it
// can associate the download bandwidth with the uploader.
func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.config.AttachmentCacheDir == "" {
if s.fileCache == nil {
return errHTTPInternalError
}
matches := fileRegex.FindStringSubmatch(r.URL.Path)
@@ -760,16 +768,16 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
return errHTTPInternalErrorInvalidPath
}
messageID := matches[1]
file := filepath.Join(s.config.AttachmentCacheDir, messageID)
stat, err := os.Stat(file)
reader, size, err := s.fileCache.Read(messageID)
if err != nil {
return errHTTPNotFound.Fields(log.Context{
"message_id": messageID,
"error_context": "filesystem",
"error_context": "attachment_store",
})
}
defer reader.Close()
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
w.Header().Set("Content-Length", fmt.Sprintf("%d", size))
if r.Method == http.MethodHead {
return nil
}
@@ -805,19 +813,14 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
} else if m.Sender.IsValid() {
bandwidthVisitor = s.visitor(m.Sender, nil)
}
if !bandwidthVisitor.BandwidthAllowed(stat.Size()) {
if !bandwidthVisitor.BandwidthAllowed(size) {
return errHTTPTooManyRequestsLimitAttachmentBandwidth.With(m)
}
// Actually send file
f, err := os.Open(file)
if err != nil {
return err
}
defer f.Close()
if m.Attachment.Name != "" {
w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(m.Attachment.Name))
}
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), reader)
return err
}
@@ -880,6 +883,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*model.Mess
if m.Message == "" {
m.Message = emptyMessageBody
}
m.SanitizeUTF8()
delayed := m.Time > time.Now().Unix()
ev := logvrm(v, r, m).
Tag(tagPublish).
@@ -1162,6 +1166,9 @@ func (s *Server) parsePublishParams(r *http.Request, m *model.Message) (cache bo
m.Icon = icon
}
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
if email != "" && !emailAddressRegex.MatchString(email) {
return false, false, "", "", "", false, "", errHTTPBadRequestEmailAddressInvalid
}
if s.smtpSender == nil && email != "" {
return false, false, "", "", "", false, "", errHTTPBadRequestEmailDisabled
}
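As a quick illustration of what the new emailAddressRegex accepts (a sketch using the test addresses added further down; the pattern just requires one non-empty token on each side of a single @, with no whitespace, commas, or semicolons, so lists of addresses are rejected):
re := regexp.MustCompile(`^[^\s,;]+@[^\s,;]+$`)
fmt.Println(re.MatchString("test@example.com"))                    // true
fmt.Println(re.MatchString("test@example.com, other@example.com")) // false (comma, space)
fmt.Println(re.MatchString("invalidaddress"))                      // false (no @)
fmt.Println(re.MatchString("@nope"), re.MatchString("nope@"))      // false false (empty side)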
@@ -1408,7 +1415,7 @@ func (s *Server) renderTemplate(name, tpl, source string) (string, error) {
}
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser) error {
if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
if s.fileCache == nil || s.config.BaseURL == "" {
return errHTTPBadRequestAttachmentsDisallowed.With(m)
}
vinfo, err := v.Info()

View File

@@ -159,6 +159,7 @@
# - attachment-expiry-duration is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h)
#
# attachment-cache-dir:
# attachment-s3-url: "s3://ACCESS_KEY:SECRET_KEY@bucket/prefix?region=us-east-1"
# attachment-total-size-limit: "5G"
# attachment-file-size-limit: "15M"
# attachment-expiry-duration: "3h"
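The same URL format also carries an explicit endpoint for S3-compatible backends (see the s3cli notes further down). A hypothetical self-hosted example — bucket, prefix, and host are placeholders, and the exact endpoint syntax is whatever the s3 package's ParseURL accepts:
# attachment-s3-url: "s3://ACCESS_KEY:SECRET_KEY@my-bucket/attachments?region=us-east-1&endpoint=minio.example.com"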

View File

@@ -3,14 +3,15 @@ package server
import (
"encoding/json"
"errors"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
"net/http"
"net/netip"
"strings"
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
const (
@@ -455,21 +456,8 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
return errHTTPUnauthorized
} else if err := s.userManager.AllowReservation(u.Name, req.Topic); err != nil {
return errHTTPConflictTopicReserved
} else if u.IsUser() {
hasReservation, err := s.userManager.HasReservation(u.Name, req.Topic)
if err != nil {
return err
}
if !hasReservation {
reservations, err := s.userManager.ReservationsCount(u.Name)
if err != nil {
return err
} else if reservations >= u.Tier.ReservationLimit {
return errHTTPTooManyRequestsLimitReservations
}
}
}
// Actually add the reservation
// Actually add the reservation (with limit check inside the transaction to avoid races)
logvr(v, r).
Tag(tagAccount).
Fields(log.Context{
@@ -477,7 +465,14 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
"everyone": everyone.String(),
}).
Debug("Adding topic reservation")
if err := s.userManager.AddReservation(u.Name, req.Topic, everyone); err != nil {
var limit int64
if u.IsUser() && u.Tier != nil {
limit = u.Tier.ReservationLimit
}
if err := s.userManager.AddReservation(u.Name, req.Topic, everyone, limit); err != nil {
if errors.Is(err, user.ErrTooManyReservations) {
return errHTTPTooManyRequestsLimitReservations
}
return err
}
// Kill existing subscribers
@@ -530,22 +525,15 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
// The process relies on the manager to perform the actual deletions (see runManager).
func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *visitor, u *user.User, reservationsLimit int64) error {
reservations, err := s.userManager.Reservations(u.Name)
removedTopics, err := s.userManager.RemoveExcessReservations(u.Name, reservationsLimit)
if err != nil {
return err
} else if int64(len(reservations)) <= reservationsLimit {
} else if len(removedTopics) == 0 {
logvr(v, r).Tag(tagAccount).Debug("No excess reservations to remove")
return nil
}
topics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
topics = append(topics, reservations[i].Topic)
}
logvr(v, r).Tag(tagAccount).Info("Removing excess reservations for topics %s", strings.Join(topics, ", "))
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
return err
}
if err := s.messageCache.ExpireMessages(topics...); err != nil {
logvr(v, r).Tag(tagAccount).Info("Removed excess topic reservations, now removing messages for topics %s", strings.Join(removedTopics, ", "))
if err := s.messageCache.ExpireMessages(removedTopics...); err != nil {
return err
}
go s.pruneMessages()

View File

@@ -503,7 +503,7 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
}))
require.Nil(t, s.userManager.AddUser("noadmin1", "pass", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("noadmin1", "pro"))
require.Nil(t, s.userManager.AddReservation("noadmin1", "mytopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.AddReservation("noadmin1", "mytopic", user.PermissionDenyAll, 0))
require.Nil(t, s.userManager.AddUser("noadmin2", "pass", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("noadmin2", "pro"))

View File

@@ -99,6 +99,9 @@ func (s *Server) execManager() {
mset(metricUsers, usersCount)
mset(metricSubscribers, subscribers)
mset(metricTopics, topicsCount)
if s.fileCache != nil {
mset(metricAttachmentsTotalSize, s.fileCache.Size())
}
}
func (s *Server) pruneVisitors() {

View File

@@ -478,8 +478,8 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll, 0))
require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll, 0))
// Add billing details
u, err := s.userManager.User("phil")
@@ -589,7 +589,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll, 0))
// Add billing details
u, err := s.userManager.User("phil")

View File

@@ -1543,6 +1543,30 @@ func TestServer_PublishEmailNoMailer_Fail(t *testing.T) {
})
}
func TestServer_PublishEmailAddressInvalid(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfig(t, databaseURL))
s.smtpSender = &testMailer{}
addresses := []string{
"test@example.com, other@example.com",
"invalidaddress",
"@nope",
"nope@",
}
for _, email := range addresses {
response := request(t, s, "PUT", "/mytopic", "fail", map[string]string{
"E-Mail": email,
})
require.Equal(t, 400, response.Code, "expected 400 for email: %s", email)
}
// Valid address should succeed
response := request(t, s, "PUT", "/mytopic", "success", map[string]string{
"E-Mail": "test@example.com",
})
require.Equal(t, 200, response.Code)
})
}
func TestServer_PublishAndExpungeTopicAfter16Hours(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
t.Parallel()
@@ -4441,3 +4465,88 @@ func TestServer_HandleError_SkipsWriteHeaderOnHijackedConnection(t *testing.T) {
}
})
}
func TestServer_Publish_InvalidUTF8InBody(t *testing.T) {
// All byte sequences from production logs, sent as message body
tests := []struct {
name string
body string
message string
}{
{"0xc9_0x43", "\xc9Cas du serveur", "\uFFFDCas du serveur"}, // Latin-1 "ÉC"
{"0xae", "Product\xae Pro", "Product\uFFFD Pro"}, // Latin-1 "®"
{"0xe8_0x6d_0x65", "probl\xe8me critique", "probl\uFFFDme critique"}, // Latin-1 "ème"
{"0xb2", "CO\xb2 level high", "CO\uFFFD level high"}, // Latin-1 "²"
{"0xe9_0x6d_0x61", "th\xe9matique", "th\uFFFDmatique"}, // Latin-1 "éma"
{"0xed_0x64_0x65", "vid\xed\x64eo surveillance", "vid\uFFFDdeo surveillance"}, // Latin-1 "íde"
{"0xf3_0x6e_0x3a_0x20", "notificaci\xf3n: alerta", "notificaci\uFFFDn: alerta"}, // Latin-1 "ón: "
{"0xb7", "item\xb7value", "item\uFFFDvalue"}, // Latin-1 "·"
{"0xa8", "na\xa8ve", "na\uFFFDve"}, // Latin-1 "¨"
{"0x00", "hello\x00world", "helloworld"}, // NUL byte
{"0xdf_0x64", "gro\xdf\x64ruck", "gro\uFFFDdruck"}, // Latin-1 "ßd"
{"0xe4_0x67_0x74", "tr\xe4gt Last", "tr\uFFFDgt Last"}, // Latin-1 "ägt"
{"0xe9_0x65_0x20", "journ\xe9\x65 termin\xe9\x65", "journ\uFFFDe termin\uFFFDe"}, // Latin-1 "ée"
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
s := newTestServer(t, newTestConfig(t, ""))
// Publish via x-message header (the most common path for invalid UTF-8 from HTTP headers)
response := request(t, s, "PUT", "/mytopic", "", map[string]string{
"X-Message": tc.body,
})
require.Equal(t, 200, response.Code)
msg := toMessage(t, response.Body.String())
require.Equal(t, tc.message, msg.Message)
// Verify it was stored in the cache correctly
response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
require.Equal(t, 200, response.Code)
msg = toMessage(t, response.Body.String())
require.Equal(t, tc.message, msg.Message)
})
}
}
func TestServer_Publish_InvalidUTF8InTitle(t *testing.T) {
s := newTestServer(t, newTestConfig(t, ""))
response := request(t, s, "PUT", "/mytopic", "valid body", map[string]string{
"Title": "\xc9clipse du syst\xe8me",
})
require.Equal(t, 200, response.Code)
msg := toMessage(t, response.Body.String())
require.Equal(t, "\uFFFDclipse du syst\uFFFDme", msg.Title)
require.Equal(t, "valid body", msg.Message)
}
func TestServer_Publish_InvalidUTF8InTags(t *testing.T) {
s := newTestServer(t, newTestConfig(t, ""))
response := request(t, s, "PUT", "/mytopic", "valid body", map[string]string{
"Tags": "probl\xe8me,syst\xe9me",
})
require.Equal(t, 200, response.Code)
msg := toMessage(t, response.Body.String())
require.Equal(t, "probl\uFFFDme", msg.Tags[0])
require.Equal(t, "syst\uFFFDme", msg.Tags[1])
}
func TestServer_Publish_InvalidUTF8WithFirebase(t *testing.T) {
// Verify that sanitization happens before Firebase dispatch, so Firebase
// receives clean UTF-8 strings rather than invalid byte sequences
sender := newTestFirebaseSender(10)
s := newTestServer(t, newTestConfig(t, ""))
s.firebaseClient = newFirebaseClient(sender, &testAuther{Allow: true})
response := request(t, s, "PUT", "/mytopic", "", map[string]string{
"X-Message": "notificaci\xf3n: alerta",
"Title": "\xc9clipse",
"Tags": "probl\xe8me",
})
require.Equal(t, 200, response.Code)
time.Sleep(100 * time.Millisecond) // Firebase publishing happens asynchronously
require.Equal(t, 1, len(sender.Messages()))
require.Equal(t, "notificaci\uFFFDn: alerta", sender.Messages()[0].Data["message"])
require.Equal(t, "\uFFFDclipse", sender.Messages()[0].Data["title"])
require.Equal(t, "probl\uFFFDme", sender.Messages()[0].Data["tags"])
}

View File

@@ -65,12 +65,12 @@ const (
key TEXT PRIMARY KEY,
value BIGINT
);
INSERT INTO message_stats (key, value) VALUES ('messages', 0);
INSERT INTO message_stats (key, value) VALUES ('messages', 0) ON CONFLICT (key) DO NOTHING;
CREATE TABLE IF NOT EXISTS schema_version (
store TEXT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schema_version (store, version) VALUES ('message', 14);
INSERT INTO schema_version (store, version) VALUES ('message', 14) ON CONFLICT (store) DO NOTHING;
`
// Initial PostgreSQL schema for user store (from user/manager_postgres_schema.go)
@@ -146,7 +146,7 @@ const (
INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT)
ON CONFLICT (id) DO NOTHING;
INSERT INTO schema_version (store, version) VALUES ('user', 6);
INSERT INTO schema_version (store, version) VALUES ('user', 6) ON CONFLICT (store) DO NOTHING;
`
// Initial PostgreSQL schema for web push store (from webpush/store_postgres.go)
@@ -174,7 +174,7 @@ const (
store TEXT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schema_version (store, version) VALUES ('webpush', 1);
INSERT INTO schema_version (store, version) VALUES ('webpush', 1) ON CONFLICT (store) DO NOTHING;
`
)
@@ -185,6 +185,7 @@ var flags = []cli.Flag{
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file"}, Usage: "SQLite user/auth database file path"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "web-push-file", Aliases: []string{"web_push_file"}, Usage: "SQLite web push database file path"}),
&cli.BoolFlag{Name: "create-schema", Usage: "create initial PostgreSQL schema before importing"},
&cli.BoolFlag{Name: "pre-import", Usage: "pre-import messages while ntfy is still running (only imports messages)"},
}
func main() {
@@ -207,10 +208,17 @@ func execImport(c *cli.Context) error {
cacheFile := c.String("cache-file")
authFile := c.String("auth-file")
webPushFile := c.String("web-push-file")
preImport := c.Bool("pre-import")
if databaseURL == "" {
return fmt.Errorf("database-url must be set (via --database-url or config file)")
}
if preImport {
if cacheFile == "" {
return fmt.Errorf("--cache-file must be set when using --pre-import")
}
return execPreImport(c, databaseURL, cacheFile)
}
if cacheFile == "" && authFile == "" && webPushFile == "" {
return fmt.Errorf("at least one of --cache-file, --auth-file, or --web-push-file must be set")
}
@@ -261,7 +269,8 @@ func execImport(c *cli.Context) error {
if err := verifySchemaVersion(pgDB, "message", expectedMessageSchemaVersion); err != nil {
return err
}
if err := importMessages(cacheFile, pgDB); err != nil {
sinceTime := maxMessageTime(pgDB)
if err := importMessages(cacheFile, pgDB, sinceTime); err != nil {
return fmt.Errorf("cannot import messages: %w", err)
}
}
@@ -300,6 +309,54 @@ func execImport(c *cli.Context) error {
return nil
}
func execPreImport(c *cli.Context, databaseURL, cacheFile string) error {
fmt.Println("pgimport - PRE-IMPORT mode (ntfy can keep running)")
fmt.Println()
fmt.Println("Source:")
printSource(" Cache file: ", cacheFile)
fmt.Println()
fmt.Println("Target:")
fmt.Printf(" Database URL: %s\n", maskPassword(databaseURL))
fmt.Println()
fmt.Println("This will pre-import messages into PostgreSQL while ntfy is still running.")
fmt.Println("After this completes, stop ntfy and run pgimport again without --pre-import")
fmt.Println("to import remaining messages, users, and web push subscriptions.")
fmt.Print("Continue? (y/n): ")
var answer string
fmt.Scanln(&answer)
if strings.TrimSpace(strings.ToLower(answer)) != "y" {
fmt.Println("Aborted.")
return nil
}
fmt.Println()
pgHost, err := pg.Open(databaseURL)
if err != nil {
return fmt.Errorf("cannot connect to PostgreSQL: %w", err)
}
pgDB := pgHost.DB
defer pgDB.Close()
if c.Bool("create-schema") {
if err := createSchema(pgDB, cacheFile, "", ""); err != nil {
return fmt.Errorf("cannot create schema: %w", err)
}
}
if err := verifySchemaVersion(pgDB, "message", expectedMessageSchemaVersion); err != nil {
return err
}
if err := importMessages(cacheFile, pgDB, 0); err != nil {
return fmt.Errorf("cannot import messages: %w", err)
}
fmt.Println()
fmt.Println("Pre-import complete. Now stop ntfy and run pgimport again without --pre-import")
fmt.Println("to import any remaining messages, users, and web push subscriptions.")
return nil
}
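Taken together, the two-phase migration this enables looks roughly like the following (flags as defined in this file; URLs and paths are placeholders). The 30-second preImportTimeDelta buffer below means the second pass re-reads a short overlap window rather than risking a gap for in-flight messages:
# Phase 1: while ntfy is still running, bulk-import messages only
pgimport --database-url postgres://... --cache-file /var/cache/ntfy/cache.db --create-schema --pre-import
# Phase 2: stop ntfy, then run again to import the delta plus users and web push
pgimport --database-url postgres://... --cache-file /var/cache/ntfy/cache.db --auth-file /var/lib/ntfy/user.db --web-push-file /var/lib/ntfy/webpush.db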
func createSchema(pgDB *sql.DB, cacheFile, authFile, webPushFile string) error {
fmt.Println("Creating initial PostgreSQL schema ...")
// User schema must be created before message schema, because message_stats and
@@ -645,16 +702,41 @@ func importUserPhones(sqlDB, pgDB *sql.DB) (int, error) {
// Message import
func importMessages(sqliteFile string, pgDB *sql.DB) error {
const preImportTimeDelta = 30 // seconds to subtract from max time to account for in-flight messages
// maxMessageTime returns the maximum message time in PostgreSQL minus a small buffer,
// or 0 if there are no messages yet. This is used after a --pre-import run to only
// import messages that arrived since the pre-import.
func maxMessageTime(pgDB *sql.DB) int64 {
var maxTime sql.NullInt64
if err := pgDB.QueryRow(`SELECT MAX(time) FROM message`).Scan(&maxTime); err != nil || !maxTime.Valid || maxTime.Int64 == 0 {
return 0
}
sinceTime := maxTime.Int64 - preImportTimeDelta
if sinceTime < 0 {
return 0
}
fmt.Printf("Pre-imported messages detected (max time: %d), importing delta (since time %d) ...\n", maxTime.Int64, sinceTime)
return sinceTime
}
func importMessages(sqliteFile string, pgDB *sql.DB, sinceTime int64) error {
sqlDB, err := openSQLite(sqliteFile)
if err != nil {
fmt.Printf("Skipping message import: %s\n", err)
return nil
}
defer sqlDB.Close()
fmt.Printf("Importing messages from %s ...\n", sqliteFile)
rows, err := sqlDB.Query(`SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published FROM messages`)
query := `SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published FROM messages`
var rows *sql.Rows
if sinceTime > 0 {
fmt.Printf("Importing messages from %s (since time %d) ...\n", sqliteFile, sinceTime)
rows, err = sqlDB.Query(query+` WHERE time >= ?`, sinceTime)
} else {
fmt.Printf("Importing messages from %s ...\n", sqliteFile)
rows, err = sqlDB.Query(query)
}
if err != nil {
return fmt.Errorf("querying messages: %w", err)
}
@@ -837,7 +919,9 @@ func importWebPush(sqliteFile string, pgDB *sql.DB) error {
}
func toUTF8(s string) string {
return strings.ToValidUTF8(s, "\uFFFD")
s = strings.ToValidUTF8(s, "\uFFFD")
s = strings.ReplaceAll(s, "\x00", "")
return s
}
// Verification

141
tools/s3cli/main.go Normal file
View File

@@ -0,0 +1,141 @@
// Command s3cli is a minimal CLI for testing the s3 package. It supports put, get, rm, and ls.
//
// Usage:
//
// export S3_URL="s3://ACCESS_KEY:SECRET_KEY@BUCKET/PREFIX?region=REGION&endpoint=ENDPOINT"
//
// s3cli put <key> <file> Upload a file
// s3cli put <key> - Upload from stdin
// s3cli get <key> Download to stdout
// s3cli rm <key> [<key>...] Delete one or more objects
// s3cli ls List all objects
package main
import (
"context"
"fmt"
"io"
"os"
"text/tabwriter"
"heckel.io/ntfy/v2/s3"
)
func main() {
if len(os.Args) < 2 {
usage()
}
s3URL := os.Getenv("S3_URL")
if s3URL == "" {
fail("S3_URL environment variable is required")
}
cfg, err := s3.ParseURL(s3URL)
if err != nil {
fail("invalid S3_URL: %s", err)
}
client := s3.New(cfg)
ctx := context.Background()
switch os.Args[1] {
case "put":
cmdPut(ctx, client)
case "get":
cmdGet(ctx, client)
case "rm":
cmdRm(ctx, client)
case "ls":
cmdLs(ctx, client)
default:
usage()
}
}
func cmdPut(ctx context.Context, client *s3.Client) {
if len(os.Args) != 4 {
fail("usage: s3cli put <key> <file|->\n")
}
key := os.Args[2]
path := os.Args[3]
var r io.Reader
if path == "-" {
r = os.Stdin
} else {
f, err := os.Open(path)
if err != nil {
fail("open %s: %s", path, err)
}
defer f.Close()
r = f
}
if err := client.PutObject(ctx, key, r); err != nil {
fail("put: %s", err)
}
fmt.Fprintf(os.Stderr, "uploaded %s\n", key)
}
func cmdGet(ctx context.Context, client *s3.Client) {
if len(os.Args) != 3 {
fail("usage: s3cli get <key>\n")
}
key := os.Args[2]
reader, size, err := client.GetObject(ctx, key)
if err != nil {
fail("get: %s", err)
}
defer reader.Close()
n, err := io.Copy(os.Stdout, reader)
if err != nil {
fail("read: %s", err)
}
fmt.Fprintf(os.Stderr, "downloaded %s (%d bytes, content-length: %d)\n", key, n, size)
}
func cmdRm(ctx context.Context, client *s3.Client) {
if len(os.Args) < 3 {
fail("usage: s3cli rm <key> [<key>...]\n")
}
keys := os.Args[2:]
if err := client.DeleteObjects(ctx, keys); err != nil {
fail("rm: %s", err)
}
fmt.Fprintf(os.Stderr, "deleted %d object(s)\n", len(keys))
}
func cmdLs(ctx context.Context, client *s3.Client) {
objects, err := client.ListAllObjects(ctx)
if err != nil {
fail("ls: %s", err)
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
var totalSize int64
for _, obj := range objects {
fmt.Fprintf(w, "%d\t%s\n", obj.Size, obj.Key)
totalSize += obj.Size
}
w.Flush()
fmt.Fprintf(os.Stderr, "%d object(s), %d bytes total\n", len(objects), totalSize)
}
func usage() {
fmt.Fprintf(os.Stderr, `Usage: s3cli <command> [args...]
Commands:
put <key> <file|-> Upload a file (use - for stdin)
get <key> Download to stdout
rm <key> [keys...] Delete objects
ls List all objects
Environment:
S3_URL S3 connection URL (required)
s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
`)
os.Exit(1)
}
func fail(format string, args ...any) {
fmt.Fprintf(os.Stderr, format+"\n", args...)
os.Exit(1)
}

View File

@@ -288,33 +288,41 @@ func (a *Manager) ChangeTier(username, tier string) error {
t, err := a.Tier(tier)
if err != nil {
return err
} else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
return err
}
if _, err := a.db.Exec(a.queries.updateUserTier, tier, username); err != nil {
return err
}
return nil
return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.checkReservationsLimitTx(tx, username, t.ReservationLimit); err != nil {
return err
}
if _, err := tx.Exec(a.queries.updateUserTier, tier, username); err != nil {
return err
}
return nil
})
}
// ResetTier removes the tier from the given user
func (a *Manager) ResetTier(username string) error {
if !AllowedUsername(username) && username != Everyone && username != "" {
return ErrInvalidArgument
} else if err := a.checkReservationsLimit(username, 0); err != nil {
return err
}
_, err := a.db.Exec(a.queries.deleteUserTier, username)
return err
return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.checkReservationsLimitTx(tx, username, 0); err != nil {
return err
}
if _, err := tx.Exec(a.queries.deleteUserTier, username); err != nil {
return err
}
return nil
})
}
func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
u, err := a.User(username)
func (a *Manager) checkReservationsLimitTx(tx *sql.Tx, username string, reservationsLimit int64) error {
u, err := a.userTx(tx, username)
if err != nil {
return err
}
if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
reservations, err := a.Reservations(username)
reservations, err := a.reservationsTx(tx, username)
if err != nil {
return err
} else if int64(len(reservations)) > reservationsLimit {
@@ -388,7 +396,11 @@ func (a *Manager) writeUserStatsQueue() error {
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
func (a *Manager) User(username string) (*User, error) {
rows, err := a.db.Query(a.queries.selectUserByName, username)
return a.userTx(a.db, username)
}
func (a *Manager) userTx(tx db.Querier, username string) (*User, error) {
rows, err := tx.Query(a.queries.selectUserByName, username)
if err != nil {
return nil, err
}
@@ -415,7 +427,7 @@ func (a *Manager) userByToken(token string) (*User, error) {
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByStripeCustomerID, customerID)
rows, err := a.db.Query(a.queries.selectUserByStripeCustomerID, customerID)
if err != nil {
return nil, err
}
@@ -642,7 +654,7 @@ func (a *Manager) AllowReservation(username string, topic string) error {
// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
// - It also prioritizes write permissions over read permissions
func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
rows, err := a.db.Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
rows, err := a.db.ReadOnly().Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
if err != nil {
return false, false, false, err
}
@@ -713,16 +725,35 @@ func (a *Manager) Grants(username string) ([]Grant, error) {
// AddReservation creates two access control entries for the given topic: one with full read/write
// access for the given user, and one for Everyone with the given permission. Both entries are
// created atomically in a single transaction.
func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
// created atomically in a single transaction. If limit is > 0, the reservation count is checked
// inside the transaction and ErrTooManyReservations is returned if the limit would be exceeded.
func (a *Manager) AddReservation(username string, topic string, everyone Permission, limit int64) error {
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
return ErrInvalidArgument
}
return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil {
if limit > 0 {
hasReservation, err := a.hasReservationTx(tx, username, topic)
if err != nil {
return err
}
if !hasReservation {
count, err := a.reservationsCountTx(tx, username)
if err != nil {
return err
}
if count >= limit {
return ErrTooManyReservations
}
}
}
if _, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), true, true, username, username, false); err != nil {
return err
}
return a.addReservationAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username)
if _, err := tx.Exec(a.queries.upsertUserAccess, Everyone, toSQLWildcard(topic), everyone.IsRead(), everyone.IsWrite(), username, username, false); err != nil {
return err
}
return nil
})
}
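A condensed caller's view of the new contract (a sketch mirroring the server handler earlier in this changeset; limit 0 preserves the old unchecked behavior, which is why the updated tests pass 0):
// limit > 0: the count is checked inside the same transaction as the upsert
if err := m.AddReservation("phil", "mytopic", user.PermissionDenyAll, 2); errors.Is(err, user.ErrTooManyReservations) {
	// surface as HTTP 429 (errHTTPTooManyRequestsLimitReservations in the server)
}
// limit == 0: no limit check (admin paths, tests)
_ = m.AddReservation("phil", "other", user.PermissionRead, 0)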
@@ -740,10 +771,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
}
return db.ExecTx(a.db, func(tx *sql.Tx) error {
for _, topic := range topics {
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
return err
}
if err := a.resetTopicAccessTx(tx, Everyone, topic); err != nil {
if err := a.removeReservationAccessTx(tx, username, topic); err != nil {
return err
}
}
@@ -753,7 +781,11 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
// Reservations returns all user-owned topics, and the associated everyone-access
func (a *Manager) Reservations(username string) ([]Reservation, error) {
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservations, Everyone, username)
return a.reservationsTx(a.db.ReadOnly(), username)
}
func (a *Manager) reservationsTx(tx db.Querier, username string) ([]Reservation, error) {
rows, err := tx.Query(a.queries.selectUserReservations, Everyone, username)
if err != nil {
return nil, err
}
@@ -779,7 +811,11 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) {
// HasReservation returns true if the given topic access is owned by the user
func (a *Manager) HasReservation(username, topic string) (bool, error) {
rows, err := a.db.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
return a.hasReservationTx(a.db, username, topic)
}
func (a *Manager) hasReservationTx(tx db.Querier, username, topic string) (bool, error) {
rows, err := tx.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
if err != nil {
return false, err
}
@@ -796,7 +832,11 @@ func (a *Manager) HasReservation(username, topic string) (bool, error) {
// ReservationsCount returns the number of reservations owned by this user
func (a *Manager) ReservationsCount(username string) (int64, error) {
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsCount, username)
return a.reservationsCountTx(a.db, username)
}
func (a *Manager) reservationsCountTx(tx db.Querier, username string) (int64, error) {
rows, err := tx.Query(a.queries.selectUserReservationsCount, username)
if err != nil {
return 0, err
}
@@ -828,6 +868,30 @@ func (a *Manager) ReservationOwner(topic string) (string, error) {
return ownerUserID, nil
}
// RemoveExcessReservations removes reservations that exceed the given limit for the user.
// It returns the list of topics whose reservations were removed. The read and removal are
// performed atomically in a single transaction to avoid issues with stale replica data.
func (a *Manager) RemoveExcessReservations(username string, limit int64) ([]string, error) {
return db.QueryTx(a.db, func(tx *sql.Tx) ([]string, error) {
reservations, err := a.reservationsTx(tx, username)
if err != nil {
return nil, err
}
if int64(len(reservations)) <= limit {
return []string{}, nil
}
removedTopics := make([]string, 0)
for i := int64(len(reservations)) - 1; i >= limit; i-- {
topic := reservations[i].Topic
if err := a.removeReservationAccessTx(tx, username, topic); err != nil {
return nil, err
}
removedTopics = append(removedTopics, topic)
}
return removedTopics, nil
})
}
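In effect: a user holding five reservations who drops to a three-reservation tier gets the tail of the Reservations result removed, and the caller expires the corresponding messages — a sketch of the flow now used by maybeRemoveMessagesAndExcessReservations above:
removed, err := userManager.RemoveExcessReservations("phil", 3) // e.g. the 2 excess topics
if err != nil {
	return err
}
if len(removed) > 0 {
	if err := messageCache.ExpireMessages(removed...); err != nil { // mark messages for pruning
		return err
	}
}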
// otherAccessCount returns the number of access entries for the given topic that are not owned by the user
func (a *Manager) otherAccessCount(username, topic string) (int, error) {
rows, err := a.db.Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
@@ -845,14 +909,11 @@ func (a *Manager) otherAccessCount(username, topic string) (int, error) {
return count, nil
}
func (a *Manager) addReservationAccessTx(tx *sql.Tx, username, topic string, read, write bool, ownerUsername string) error {
if !AllowedUsername(username) && username != Everyone {
return ErrInvalidArgument
} else if !AllowedTopicPattern(topic) {
return ErrInvalidArgument
func (a *Manager) removeReservationAccessTx(tx *sql.Tx, username, topic string) error {
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
return err
}
_, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), read, write, ownerUsername, ownerUsername, false)
return err
return a.resetTopicAccessTx(tx, Everyone, topic)
}
func (a *Manager) resetUserAccessTx(tx *sql.Tx, username string) error {
@@ -1134,7 +1195,7 @@ func (a *Manager) Tiers() ([]*Tier, error) {
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
func (a *Manager) Tier(code string) (*Tier, error) {
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByCode, code)
rows, err := a.db.Query(a.queries.selectTierByCode, code)
if err != nil {
return nil, err
}
@@ -1144,7 +1205,7 @@ func (a *Manager) Tier(code string) (*Tier, error) {
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByPriceID, priceID, priceID)
rows, err := a.db.Query(a.queries.selectTierByPriceID, priceID, priceID)
if err != nil {
return nil, err
}

View File

@@ -226,7 +226,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
// Create user, add reservations and token
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false))
require.Nil(t, a.AddReservation("user", "mytopic", PermissionRead))
require.Nil(t, a.AddReservation("user", "mytopic", PermissionRead, 0))
u, err := a.User("user")
require.Nil(t, err)
@@ -439,8 +439,8 @@ func TestManager_Reservations(t *testing.T) {
a := newTestManager(t, newManager, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll))
require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll, 0))
require.Nil(t, a.AddReservation("ben", "readme", PermissionRead, 0))
require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
reservations, err := a.Reservations("ben")
@@ -523,7 +523,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
}))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
require.Nil(t, a.ChangeTier("ben", "pro"))
require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll))
require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll, 0))
ben, err := a.User("ben")
require.Nil(t, err)
@@ -1076,7 +1076,7 @@ func TestManager_Tier_Change_And_Reset(t *testing.T) {
// Add 10 reservations (pro tier allows that)
for i := 0; i < 4; i++ {
require.Nil(t, a.AddReservation("phil", fmt.Sprintf("topic%d", i), PermissionWrite))
require.Nil(t, a.AddReservation("phil", fmt.Sprintf("topic%d", i), PermissionWrite, 0))
}
// Downgrading will not work (too many reservations)
@@ -2118,7 +2118,7 @@ func TestStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
func TestStoreReservations(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionRead))
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionRead, 0))
reservations, err := manager.Reservations("phil")
require.Nil(t, err)
@@ -2133,8 +2133,8 @@ func TestStoreReservations(t *testing.T) {
func TestStoreReservationsCount(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
require.Nil(t, manager.AddReservation("phil", "topic1", PermissionReadWrite))
require.Nil(t, manager.AddReservation("phil", "topic2", PermissionReadWrite))
require.Nil(t, manager.AddReservation("phil", "topic1", PermissionReadWrite, 0))
require.Nil(t, manager.AddReservation("phil", "topic2", PermissionReadWrite, 0))
count, err := manager.ReservationsCount("phil")
require.Nil(t, err)
@@ -2145,7 +2145,7 @@ func TestStoreReservationsCount(t *testing.T) {
func TestStoreHasReservation(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite))
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite, 0))
has, err := manager.HasReservation("phil", "mytopic")
require.Nil(t, err)
@@ -2160,7 +2160,7 @@ func TestStoreHasReservation(t *testing.T) {
func TestStoreReservationOwner(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite))
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite, 0))
owner, err := manager.ReservationOwner("mytopic")
require.Nil(t, err)
@@ -2172,6 +2172,26 @@ func TestStoreReservationOwner(t *testing.T) {
})
}
func TestStoreAddReservationWithLimit(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
// Adding reservations within limit succeeds
require.Nil(t, manager.AddReservation("phil", "topic1", PermissionReadWrite, 2))
require.Nil(t, manager.AddReservation("phil", "topic2", PermissionRead, 2))
// Adding a third reservation exceeds the limit
require.Equal(t, ErrTooManyReservations, manager.AddReservation("phil", "topic3", PermissionRead, 2))
// Updating an existing reservation within the limit succeeds
require.Nil(t, manager.AddReservation("phil", "topic1", PermissionRead, 2))
reservations, err := manager.Reservations("phil")
require.Nil(t, err)
require.Len(t, reservations, 2)
})
}
func TestStoreTiers(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
tier := &Tier{
@@ -2431,7 +2451,7 @@ func TestStoreOtherAccessCount(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
require.Nil(t, manager.AddUser("ben", "benpass", RoleUser, false))
require.Nil(t, manager.AddReservation("ben", "mytopic", PermissionReadWrite))
require.Nil(t, manager.AddReservation("ben", "mytopic", PermissionReadWrite, 0))
count, err := manager.otherAccessCount("phil", "mytopic")
require.Nil(t, err)

View File

@@ -17,6 +17,7 @@ import (
"strings"
"sync"
"time"
"unicode/utf8"
"github.com/gabriel-vasile/mimetype"
"golang.org/x/term"
@@ -434,3 +435,22 @@ func Int(v int) *int {
func Time(v time.Time) *time.Time {
return &v
}
// SanitizeUTF8 ensures a string is safe to store in PostgreSQL by handling two cases:
//
// 1. Invalid UTF-8 sequences: Some clients send Latin-1/ISO-8859-1 encoded text (e.g. accented
// characters like é, ñ, ß) in HTTP headers or SMTP messages. Go treats these as raw bytes in
// strings, but PostgreSQL rejects them. Any invalid UTF-8 byte is replaced with the Unicode
// replacement character (U+FFFD, "�") so the message is still delivered rather than lost.
//
// 2. NUL bytes (0x00): These are valid in UTF-8 but PostgreSQL TEXT columns reject them.
// They are stripped entirely.
func SanitizeUTF8(s string) string {
if !utf8.ValidString(s) {
s = strings.ToValidUTF8(s, "\xef\xbf\xbd") // U+FFFD
}
if strings.ContainsRune(s, 0) {
s = strings.ReplaceAll(s, "\x00", "")
}
return s
}
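A quick sketch of the resulting behavior, using byte sequences from the tests in this changeset:
fmt.Println(util.SanitizeUTF8("notificaci\xf3n: alerta")) // "notificaci�n: alerta" (Latin-1 ó becomes U+FFFD)
fmt.Println(util.SanitizeUTF8("hello\x00world"))          // "helloworld" (NUL stripped)
fmt.Println(util.SanitizeUTF8("já válido"))               // unchanged: already valid UTF-8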

36
web/package-lock.json generated
View File

@@ -3642,9 +3642,9 @@
"license": "MIT"
},
"node_modules/baseline-browser-mapping": {
"version": "2.10.0",
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.0.tgz",
"integrity": "sha512-lIyg0szRfYbiy67j9KN8IyeD7q7hcmqnJ1ddWmNt19ItGpNN64mnllmxUNFIOdOm6by97jlL6wfpTTJrmnjWAA==",
"version": "2.10.8",
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.8.tgz",
"integrity": "sha512-PCLz/LXGBsNTErbtB6i5u4eLpHeMfi93aUv5duMmj6caNu6IphS4q6UevDnL36sZQv9lrP11dbPKGMaXPwMKfQ==",
"dev": true,
"license": "Apache-2.0",
"bin": {
@@ -3766,9 +3766,9 @@
}
},
"node_modules/caniuse-lite": {
"version": "1.0.30001777",
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001777.tgz",
"integrity": "sha512-tmN+fJxroPndC74efCdp12j+0rk0RHwV5Jwa1zWaFVyw2ZxAuPeG8ZgWC3Wz7uSjT3qMRQ5XHZ4COgQmsCMJAQ==",
"version": "1.0.30001779",
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001779.tgz",
"integrity": "sha512-U5og2PN7V4DMgF50YPNtnZJGWVLFjjsN3zb6uMT5VGYIewieDj1upwfuVNXf4Kor+89c3iCRJnSzMD5LmTvsfA==",
"dev": true,
"funding": [
{
@@ -4203,9 +4203,9 @@
}
},
"node_modules/electron-to-chromium": {
"version": "1.5.307",
"resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.307.tgz",
"integrity": "sha512-5z3uFKBWjiNR44nFcYdkcXjKMbg5KXNdciu7mhTPo9tB7NbqSNP2sSnGR+fqknZSCwKkBN+oxiiajWs4dT6ORg==",
"version": "1.5.313",
"resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.313.tgz",
"integrity": "sha512-QBMrTWEf00GXZmJyx2lbYD45jpI3TUFnNIzJ5BBc8piGUDwMPa1GV6HJWTZVvY/eiN3fSopl7NRbgGp9sZ9LTA==",
"dev": true,
"license": "ISC"
},
@@ -4324,9 +4324,9 @@
}
},
"node_modules/es-iterator-helpers": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/es-iterator-helpers/-/es-iterator-helpers-1.3.0.tgz",
"integrity": "sha512-04cg8iJFDOxWcYlu0GFFWgs7vtaEPCmr5w1nrj9V3z3axu/48HCMwK6VMp45Zh3ZB+xLP1ifbJfrq86+1ypKKQ==",
"version": "1.3.1",
"resolved": "https://registry.npmjs.org/es-iterator-helpers/-/es-iterator-helpers-1.3.1.tgz",
"integrity": "sha512-zWwRvqWiuBPr0muUG/78cW3aHROFCNIQ3zpmYDpwdbnt2m+xlNyRWpHBpa2lJjSBit7BQ+RXA1iwbSmu5yJ/EQ==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -7043,9 +7043,9 @@
}
},
"node_modules/path-scurry/node_modules/lru-cache": {
"version": "11.2.6",
"resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.6.tgz",
"integrity": "sha512-ESL2CrkS/2wTPfuend7Zhkzo2u0daGJ/A2VucJOgQ/C48S/zB8MMeMHSGKYpXhIjbPxfuezITkaBH1wqv00DDQ==",
"version": "11.2.7",
"resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.7.tgz",
"integrity": "sha512-aY/R+aEsRelme17KGQa/1ZSIpLpNYYrhcrepKTZgE+W3WM16YMCaPwOHLHsmopZHELU0Ojin1lPVxKR0MihncA==",
"dev": true,
"license": "BlueOak-1.0.0",
"engines": {
@@ -8307,9 +8307,9 @@
}
},
"node_modules/terser": {
"version": "5.46.0",
"resolved": "https://registry.npmjs.org/terser/-/terser-5.46.0.tgz",
"integrity": "sha512-jTwoImyr/QbOWFFso3YoU3ik0jBBDJ6JTOQiy/J2YxVJdZCc+5u7skhNwiOR3FQIygFqVUPHl7qbbxtjW2K3Qg==",
"version": "5.46.1",
"resolved": "https://registry.npmjs.org/terser/-/terser-5.46.1.tgz",
"integrity": "sha512-vzCjQO/rgUuK9sf8VJZvjqiqiHFaZLnOiimmUuOKODxWL8mm/xua7viT7aqX7dgPY60otQjUotzFMmCB4VdmqQ==",
"dev": true,
"license": "BSD-2-Clause",
"dependencies": {

View File

@@ -120,7 +120,7 @@
"publish_dialog_priority_low": "Prioridad baja",
"publish_dialog_priority_high": "Prioridad alta",
"publish_dialog_delay_label": "Retraso",
"publish_dialog_title_placeholder": "Título de la notificación, por ejemplo, Alerta de espacio en disco",
"publish_dialog_title_placeholder": "Título de la notificación, ej. Alerta de espacio en disco",
"publish_dialog_details_examples_description": "Para ver ejemplos y una descripción detallada de todas las funciones de envío, consulte la <docsLink>documentación</docsLink>.",
"publish_dialog_attach_placeholder": "Adjuntar un archivo por URL, por ejemplo, https://f-droid.org/F-Droid.apk",
"publish_dialog_filename_placeholder": "Nombre del archivo adjunto",

View File

@@ -63,9 +63,10 @@ func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string,
} else if err != nil {
return err
}
// Insert or update subscription
// Insert or update subscription, and read back the actual ID (which may differ from
// the generated one if another request for the same endpoint raced us and inserted first)
updatedAt, warnedAt := time.Now().Unix(), 0
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
if err := tx.QueryRow(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt).Scan(&subscriptionID); err != nil {
return err
}
// Replace all subscription topics

View File

@@ -53,6 +53,7 @@ const (
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
RETURNING id
`
postgresUpdateSubscriptionWarningSentQuery = `UPDATE webpush_subscription SET warned_at = $1 WHERE id = $2`
postgresUpdateSubscriptionUpdatedAtQuery = `UPDATE webpush_subscription SET updated_at = $1 WHERE endpoint = $2`

View File

@@ -56,8 +56,9 @@ const (
sqliteUpsertSubscriptionQuery = `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
RETURNING id
`
sqliteUpdateSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
sqliteUpdateSubscriptionUpdatedAtQuery = `UPDATE subscription SET updated_at = ? WHERE endpoint = ?`