mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-18 21:30:44 +01:00
Compare commits
25 Commits
3296d158c5
...
attachment
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b11bc7468 | ||
|
|
d9efe50848 | ||
|
|
2ad78edca1 | ||
|
|
86015e100c | ||
|
|
458fbad770 | ||
|
|
9b1a32ec56 | ||
|
|
3d9ce69042 | ||
|
|
59ce581ba2 | ||
|
|
df82fdf44c | ||
|
|
3a37ea32f7 | ||
|
|
790ba243c7 | ||
|
|
4487299a80 | ||
|
|
6b38acb23a | ||
|
|
f5c255c53c | ||
|
|
fd0a49244e | ||
|
|
4699ed3ffd | ||
|
|
1afb99db67 | ||
|
|
66208e6f88 | ||
|
|
ce24594c32 | ||
|
|
888850d8bc | ||
|
|
be09acd411 | ||
|
|
bf19a5be2d | ||
|
|
b4ec6fa8df | ||
|
|
d517ce4a2a | ||
|
|
fd8f356d1f |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -9,6 +9,7 @@ server/site/
|
||||
tools/fbsend/fbsend
|
||||
tools/pgimport/pgimport
|
||||
tools/loadtest/loadtest
|
||||
tools/s3cli/s3cli
|
||||
playground/
|
||||
secrets/
|
||||
*.iml
|
||||
|
||||
25
attachment/store.go
Normal file
25
attachment/store.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package attachment
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// Store is an interface for storing and retrieving attachment files
|
||||
type Store interface {
|
||||
Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error)
|
||||
Read(id string) (io.ReadCloser, int64, error)
|
||||
Remove(ids ...string) error
|
||||
Size() int64
|
||||
Remaining() int64
|
||||
}
|
||||
|
||||
var (
|
||||
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
|
||||
errInvalidFileID = errors.New("invalid file ID")
|
||||
)
|
||||
@@ -1,32 +1,29 @@
|
||||
package server
|
||||
package attachment
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sync"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
var (
|
||||
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
|
||||
errInvalidFileID = errors.New("invalid file ID")
|
||||
errFileExists = errors.New("file exists")
|
||||
)
|
||||
const tagFileStore = "file_store"
|
||||
|
||||
type fileCache struct {
|
||||
var errFileExists = errors.New("file exists")
|
||||
|
||||
type fileStore struct {
|
||||
dir string
|
||||
totalSizeCurrent int64
|
||||
totalSizeLimit int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) {
|
||||
// NewFileStore creates a new file-system backed attachment store
|
||||
func NewFileStore(dir string, totalSizeLimit int64) (Store, error) {
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -34,18 +31,18 @@ func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fileCache{
|
||||
return &fileStore{
|
||||
dir: dir,
|
||||
totalSizeCurrent: size,
|
||||
totalSizeLimit: totalSizeLimit,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
|
||||
func (c *fileStore) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return 0, errInvalidFileID
|
||||
}
|
||||
log.Tag(tagFileCache).Field("message_id", id).Debug("Writing attachment")
|
||||
log.Tag(tagFileStore).Field("message_id", id).Debug("Writing attachment")
|
||||
file := filepath.Join(c.dir, id)
|
||||
if _, err := os.Stat(file); err == nil {
|
||||
return 0, errFileExists
|
||||
@@ -68,20 +65,35 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (in
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent += size
|
||||
mset(metricAttachmentsTotalSize, c.totalSizeCurrent)
|
||||
c.mu.Unlock()
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func (c *fileCache) Remove(ids ...string) error {
|
||||
func (c *fileStore) Read(id string) (io.ReadCloser, int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return nil, 0, errInvalidFileID
|
||||
}
|
||||
file := filepath.Join(c.dir, id)
|
||||
stat, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return f, stat.Size(), nil
|
||||
}
|
||||
|
||||
func (c *fileStore) Remove(ids ...string) error {
|
||||
for _, id := range ids {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return errInvalidFileID
|
||||
}
|
||||
log.Tag(tagFileCache).Field("message_id", id).Debug("Deleting attachment")
|
||||
log.Tag(tagFileStore).Field("message_id", id).Debug("Deleting attachment")
|
||||
file := filepath.Join(c.dir, id)
|
||||
if err := os.Remove(file); err != nil {
|
||||
log.Tag(tagFileCache).Field("message_id", id).Err(err).Debug("Error deleting attachment")
|
||||
log.Tag(tagFileStore).Field("message_id", id).Err(err).Debug("Error deleting attachment")
|
||||
}
|
||||
}
|
||||
size, err := dirSize(c.dir)
|
||||
@@ -91,17 +103,16 @@ func (c *fileCache) Remove(ids ...string) error {
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent = size
|
||||
c.mu.Unlock()
|
||||
mset(metricAttachmentsTotalSize, size)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fileCache) Size() int64 {
|
||||
func (c *fileStore) Size() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.totalSizeCurrent
|
||||
}
|
||||
|
||||
func (c *fileCache) Remaining() int64 {
|
||||
func (c *fileStore) Remaining() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
remaining := c.totalSizeLimit - c.totalSizeCurrent
|
||||
99
attachment/store_file_test.go
Normal file
99
attachment/store_file_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package attachment
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
var (
|
||||
oneKilobyteArray = make([]byte, 1024)
|
||||
)
|
||||
|
||||
func TestFileStore_Write_Success(t *testing.T) {
|
||||
dir, s := newTestFileStore(t)
|
||||
size, err := s.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
|
||||
require.Equal(t, int64(11), s.Size())
|
||||
require.Equal(t, int64(10229), s.Remaining())
|
||||
}
|
||||
|
||||
func TestFileStore_Write_Read_Success(t *testing.T) {
|
||||
_, s := newTestFileStore(t)
|
||||
size, err := s.Write("abcdefghijkl", strings.NewReader("hello world"))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
|
||||
reader, readSize, err := s.Read("abcdefghijkl")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), readSize)
|
||||
defer reader.Close()
|
||||
data, err := io.ReadAll(reader)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "hello world", string(data))
|
||||
}
|
||||
|
||||
func TestFileStore_Write_Remove_Success(t *testing.T) {
|
||||
dir, s := newTestFileStore(t) // max = 10k (10240), each = 1k (1024)
|
||||
for i := 0; i < 10; i++ { // 10x999 = 9990
|
||||
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(999), size)
|
||||
}
|
||||
require.Equal(t, int64(9990), s.Size())
|
||||
require.Equal(t, int64(250), s.Remaining())
|
||||
require.FileExists(t, dir+"/abcdefghijk1")
|
||||
require.FileExists(t, dir+"/abcdefghijk5")
|
||||
|
||||
require.Nil(t, s.Remove("abcdefghijk1", "abcdefghijk5"))
|
||||
require.NoFileExists(t, dir+"/abcdefghijk1")
|
||||
require.NoFileExists(t, dir+"/abcdefghijk5")
|
||||
require.Equal(t, int64(7992), s.Size())
|
||||
require.Equal(t, int64(2248), s.Remaining())
|
||||
}
|
||||
|
||||
func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) {
|
||||
dir, s := newTestFileStore(t)
|
||||
for i := 0; i < 10; i++ {
|
||||
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(1024), size)
|
||||
}
|
||||
_, err := s.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
require.NoFileExists(t, dir+"/abcdefghijkX")
|
||||
}
|
||||
|
||||
func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) {
|
||||
dir, s := newTestFileStore(t)
|
||||
_, err := s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
require.NoFileExists(t, dir+"/abcdefghijkl")
|
||||
}
|
||||
|
||||
func TestFileStore_Read_NotFound(t *testing.T) {
|
||||
_, s := newTestFileStore(t)
|
||||
_, _, err := s.Read("abcdefghijkl")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func newTestFileStore(t *testing.T) (dir string, store Store) {
|
||||
dir = t.TempDir()
|
||||
store, err := NewFileStore(dir, 10*1024)
|
||||
require.Nil(t, err)
|
||||
return dir, store
|
||||
}
|
||||
|
||||
func readFile(t *testing.T, f string) string {
|
||||
b, err := os.ReadFile(f)
|
||||
require.Nil(t, err)
|
||||
return string(b)
|
||||
}
|
||||
150
attachment/store_s3.go
Normal file
150
attachment/store_s3.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package attachment
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/s3"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
tagS3Store = "s3_store"
|
||||
)
|
||||
|
||||
type s3Store struct {
|
||||
client *s3.Client
|
||||
totalSizeCurrent int64
|
||||
totalSizeLimit int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewS3Store creates a new S3-backed attachment store. The s3URL must be in the format:
|
||||
//
|
||||
// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
|
||||
func NewS3Store(s3URL string, totalSizeLimit int64) (Store, error) {
|
||||
cfg, err := s3.ParseURL(s3URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
store := &s3Store{
|
||||
client: s3.New(cfg),
|
||||
totalSizeLimit: totalSizeLimit,
|
||||
}
|
||||
if totalSizeLimit > 0 {
|
||||
size, err := store.computeSize()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s3 store: failed to compute initial size: %w", err)
|
||||
}
|
||||
store.totalSizeCurrent = size
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return 0, errInvalidFileID
|
||||
}
|
||||
log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3")
|
||||
|
||||
// Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked
|
||||
// uploads, so no temp file or Content-Length is needed.
|
||||
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
|
||||
pr, pw := io.Pipe()
|
||||
lw := util.NewLimitWriter(pw, limiters...)
|
||||
var size int64
|
||||
var copyErr error
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
size, copyErr = io.Copy(lw, in)
|
||||
if copyErr != nil {
|
||||
pw.CloseWithError(copyErr)
|
||||
} else {
|
||||
pw.Close()
|
||||
}
|
||||
}()
|
||||
putErr := c.client.PutObject(context.Background(), id, pr)
|
||||
pr.Close()
|
||||
<-done
|
||||
if copyErr != nil {
|
||||
return 0, copyErr
|
||||
}
|
||||
if putErr != nil {
|
||||
return 0, putErr
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent += size
|
||||
c.mu.Unlock()
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func (c *s3Store) Read(id string) (io.ReadCloser, int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return nil, 0, errInvalidFileID
|
||||
}
|
||||
return c.client.GetObject(context.Background(), id)
|
||||
}
|
||||
|
||||
func (c *s3Store) Remove(ids ...string) error {
|
||||
for _, id := range ids {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return errInvalidFileID
|
||||
}
|
||||
}
|
||||
// S3 DeleteObjects supports up to 1000 keys per call
|
||||
for i := 0; i < len(ids); i += 1000 {
|
||||
end := i + 1000
|
||||
if end > len(ids) {
|
||||
end = len(ids)
|
||||
}
|
||||
batch := ids[i:end]
|
||||
for _, id := range batch {
|
||||
log.Tag(tagS3Store).Field("message_id", id).Debug("Deleting attachment from S3")
|
||||
}
|
||||
if err := c.client.DeleteObjects(context.Background(), batch); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Recalculate totalSizeCurrent via ListObjectsV2 (matches fileStore's dirSize rescan pattern)
|
||||
size, err := c.computeSize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3 store: failed to compute size after remove: %w", err)
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent = size
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *s3Store) Size() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.totalSizeCurrent
|
||||
}
|
||||
|
||||
func (c *s3Store) Remaining() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
remaining := c.totalSizeLimit - c.totalSizeCurrent
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
// computeSize uses ListAllObjects to sum up the total size of all objects with our prefix.
|
||||
func (c *s3Store) computeSize() (int64, error) {
|
||||
objects, err := c.client.ListAllObjects(context.Background())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var totalSize int64
|
||||
for _, obj := range objects {
|
||||
totalSize += obj.Size
|
||||
}
|
||||
return totalSize, nil
|
||||
}
|
||||
367
attachment/store_s3_test.go
Normal file
367
attachment/store_s3_test.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package attachment
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/s3"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// --- Integration tests using a mock S3 server ---
|
||||
|
||||
func TestS3Store_WriteReadRemove(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
// Write
|
||||
size, err := store.Write("abcdefghijkl", strings.NewReader("hello world"))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
require.Equal(t, int64(11), store.Size())
|
||||
|
||||
// Read back
|
||||
reader, readSize, err := store.Read("abcdefghijkl")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), readSize)
|
||||
data, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "hello world", string(data))
|
||||
|
||||
// Remove
|
||||
require.Nil(t, store.Remove("abcdefghijkl"))
|
||||
require.Equal(t, int64(0), store.Size())
|
||||
|
||||
// Read after remove should fail
|
||||
_, _, err = store.Read("abcdefghijkl")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestS3Store_WriteNoPrefix(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "", 10*1024)
|
||||
|
||||
size, err := store.Write("abcdefghijkl", strings.NewReader("test"))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(4), size)
|
||||
|
||||
reader, _, err := store.Read("abcdefghijkl")
|
||||
require.Nil(t, err)
|
||||
data, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "test", string(data))
|
||||
}
|
||||
|
||||
func TestS3Store_WriteTotalSizeLimit(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 100)
|
||||
|
||||
// First write fits
|
||||
_, err := store.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(80), store.Size())
|
||||
require.Equal(t, int64(20), store.Remaining())
|
||||
|
||||
// Second write exceeds total limit
|
||||
_, err = store.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
}
|
||||
|
||||
func TestS3Store_WriteFileSizeLimit(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
_, err := store.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
}
|
||||
|
||||
func TestS3Store_WriteRemoveMultiple(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err := store.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Equal(t, int64(500), store.Size())
|
||||
|
||||
require.Nil(t, store.Remove("abcdefghijk1", "abcdefghijk3"))
|
||||
require.Equal(t, int64(300), store.Size())
|
||||
}
|
||||
|
||||
func TestS3Store_ReadNotFound(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
_, _, err := store.Read("abcdefghijkl")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestS3Store_InvalidID(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
_, err := store.Write("bad", strings.NewReader("x"))
|
||||
require.Equal(t, errInvalidFileID, err)
|
||||
|
||||
_, _, err = store.Read("bad")
|
||||
require.Equal(t, errInvalidFileID, err)
|
||||
|
||||
err = store.Remove("bad")
|
||||
require.Equal(t, errInvalidFileID, err)
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) Store {
|
||||
t.Helper()
|
||||
// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
|
||||
host := strings.TrimPrefix(server.URL, "https://")
|
||||
s := &s3Store{
|
||||
client: &s3.Client{
|
||||
AccessKey: "AKID",
|
||||
SecretKey: "SECRET",
|
||||
Region: "us-east-1",
|
||||
Endpoint: host,
|
||||
Bucket: bucket,
|
||||
Prefix: prefix,
|
||||
PathStyle: true,
|
||||
HTTPClient: server.Client(),
|
||||
},
|
||||
totalSizeLimit: totalSizeLimit,
|
||||
}
|
||||
// Compute initial size (should be 0 for fresh mock)
|
||||
size, err := s.computeSize()
|
||||
require.Nil(t, err)
|
||||
s.totalSizeCurrent = size
|
||||
return s
|
||||
}
|
||||
|
||||
// --- Mock S3 server ---
|
||||
//
|
||||
// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and
|
||||
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
|
||||
|
||||
type mockS3Server struct {
|
||||
objects map[string][]byte // full key (bucket/key) -> body
|
||||
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
|
||||
nextID int // counter for generating upload IDs
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockS3Server() *httptest.Server {
|
||||
m := &mockS3Server{
|
||||
objects: make(map[string][]byte),
|
||||
uploads: make(map[string]map[int][]byte),
|
||||
}
|
||||
return httptest.NewTLSServer(m)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Path is /{bucket}[/{key...}]
|
||||
path := strings.TrimPrefix(r.URL.Path, "/")
|
||||
q := r.URL.Query()
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodPut && q.Has("partNumber"):
|
||||
m.handleUploadPart(w, r, path)
|
||||
case r.Method == http.MethodPut:
|
||||
m.handlePut(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("uploads"):
|
||||
m.handleInitiateMultipart(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("uploadId"):
|
||||
m.handleCompleteMultipart(w, r, path)
|
||||
case r.Method == http.MethodDelete && q.Has("uploadId"):
|
||||
m.handleAbortMultipart(w, r, path)
|
||||
case r.Method == http.MethodGet && q.Get("list-type") == "2":
|
||||
m.handleList(w, r, path)
|
||||
case r.Method == http.MethodGet:
|
||||
m.handleGet(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("delete"):
|
||||
m.handleDelete(w, r, path)
|
||||
default:
|
||||
http.Error(w, "not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.objects[path] = body
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
m.mu.Lock()
|
||||
m.nextID++
|
||||
uploadID := fmt.Sprintf("upload-%d", m.nextID)
|
||||
m.uploads[uploadID] = make(map[int][]byte)
|
||||
m.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
var partNumber int
|
||||
fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
parts, ok := m.uploads[uploadID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
http.Error(w, "NoSuchUpload", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
parts[partNumber] = body
|
||||
m.mu.Unlock()
|
||||
|
||||
etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
|
||||
w.Header().Set("ETag", etag)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
|
||||
m.mu.Lock()
|
||||
parts, ok := m.uploads[uploadID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
http.Error(w, "NoSuchUpload", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Assemble parts in order
|
||||
var assembled []byte
|
||||
for i := 1; i <= len(parts); i++ {
|
||||
assembled = append(assembled, parts[i]...)
|
||||
}
|
||||
m.objects[path] = assembled
|
||||
delete(m.uploads, uploadID)
|
||||
m.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
m.mu.Lock()
|
||||
delete(m.uploads, uploadID)
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
|
||||
m.mu.RLock()
|
||||
body, ok := m.objects[path]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(body)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) {
|
||||
// bucketPath is just the bucket name
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Objects []struct {
|
||||
Key string `xml:"Key"`
|
||||
} `xml:"Object"`
|
||||
}
|
||||
if err := xml.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
for _, obj := range req.Objects {
|
||||
delete(m.objects, bucketPath+"/"+obj.Key)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></DeleteResult>`))
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) {
|
||||
prefix := r.URL.Query().Get("prefix")
|
||||
m.mu.RLock()
|
||||
var contents []s3ListObject
|
||||
for key, body := range m.objects {
|
||||
// key is "bucket/objectkey", strip bucket prefix
|
||||
objKey := strings.TrimPrefix(key, bucketPath+"/")
|
||||
if objKey == key {
|
||||
continue // different bucket
|
||||
}
|
||||
if prefix == "" || strings.HasPrefix(objKey, prefix) {
|
||||
contents = append(contents, s3ListObject{Key: objKey, Size: int64(len(body))})
|
||||
}
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
resp := s3ListResponse{
|
||||
Contents: contents,
|
||||
IsTruncated: false,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
xml.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
type s3ListResponse struct {
|
||||
XMLName xml.Name `xml:"ListBucketResult"`
|
||||
Contents []s3ListObject `xml:"Contents"`
|
||||
IsTruncated bool `xml:"IsTruncated"`
|
||||
}
|
||||
|
||||
type s3ListObject struct {
|
||||
Key string `xml:"Key"`
|
||||
Size int64 `xml:"Size"`
|
||||
}
|
||||
@@ -2,9 +2,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/test"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -14,9 +11,14 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/test"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
func TestCLI_Publish_Subscribe_Poll_Real_Server(t *testing.T) {
|
||||
t.Skip("temporarily disabled") // FIXME
|
||||
testMessage := util.RandomString(10)
|
||||
app, _, _, _ := newTestApp()
|
||||
require.Nil(t, app.Run([]string{"ntfy", "publish", "ntfytest", "ntfy unit test " + testMessage}))
|
||||
|
||||
11
cmd/serve.go
11
cmd/serve.go
@@ -53,6 +53,7 @@ var flagsServe = append(
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-access", Aliases: []string{"auth_access"}, EnvVars: []string{"NTFY_AUTH_ACCESS"}, Usage: "pre-provisioned declarative access control entries"}),
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-tokens", Aliases: []string{"auth_tokens"}, EnvVars: []string{"NTFY_AUTH_TOKENS"}, Usage: "pre-provisioned declarative access tokens"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-s3-url", Aliases: []string{"attachment_s3_url"}, EnvVars: []string{"NTFY_ATTACHMENT_S3_URL"}, Usage: "S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION)"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-total-size-limit", Aliases: []string{"attachment_total_size_limit", "A"}, EnvVars: []string{"NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentTotalSizeLimit), Usage: "limit of the on-disk attachment cache"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-file-size-limit", Aliases: []string{"attachment_file_size_limit", "Y"}, EnvVars: []string{"NTFY_ATTACHMENT_FILE_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentFileSizeLimit), Usage: "per-file attachment size limit (e.g. 300k, 2M, 100M)"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-expiry-duration", Aliases: []string{"attachment_expiry_duration", "X"}, EnvVars: []string{"NTFY_ATTACHMENT_EXPIRY_DURATION"}, Value: util.FormatDuration(server.DefaultAttachmentExpiryDuration), Usage: "duration after which uploaded attachments will be deleted (e.g. 3h, 20h)"}),
|
||||
@@ -166,6 +167,7 @@ func execServe(c *cli.Context) error {
|
||||
authAccessRaw := c.StringSlice("auth-access")
|
||||
authTokensRaw := c.StringSlice("auth-tokens")
|
||||
attachmentCacheDir := c.String("attachment-cache-dir")
|
||||
attachmentS3URL := c.String("attachment-s3-url")
|
||||
attachmentTotalSizeLimitStr := c.String("attachment-total-size-limit")
|
||||
attachmentFileSizeLimitStr := c.String("attachment-file-size-limit")
|
||||
attachmentExpiryDurationStr := c.String("attachment-expiry-duration")
|
||||
@@ -284,8 +286,8 @@ func execServe(c *cli.Context) error {
|
||||
}
|
||||
|
||||
// Check values
|
||||
if databaseURL != "" && !strings.HasPrefix(databaseURL, "postgres://") {
|
||||
return errors.New("if database-url is set, it must start with postgres://")
|
||||
if databaseURL != "" && !strings.HasPrefix(databaseURL, "postgres://") && !strings.HasPrefix(databaseURL, "postgresql://") {
|
||||
return errors.New("if database-url is set, it must start with postgres:// or postgresql://")
|
||||
} else if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") {
|
||||
return errors.New("if database-url is set, auth-file, cache-file, and web-push-file must not be set")
|
||||
} else if len(databaseReplicaURLs) > 0 && databaseURL == "" {
|
||||
@@ -314,6 +316,10 @@ func execServe(c *cli.Context) error {
|
||||
return errors.New("if smtp-server-listen is set, smtp-server-domain must also be set")
|
||||
} else if attachmentCacheDir != "" && baseURL == "" {
|
||||
return errors.New("if attachment-cache-dir is set, base-url must also be set")
|
||||
} else if attachmentS3URL != "" && baseURL == "" {
|
||||
return errors.New("if attachment-s3-url is set, base-url must also be set")
|
||||
} else if attachmentS3URL != "" && attachmentCacheDir != "" {
|
||||
return errors.New("attachment-cache-dir and attachment-s3-url are mutually exclusive")
|
||||
} else if baseURL != "" {
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
@@ -457,6 +463,7 @@ func execServe(c *cli.Context) error {
|
||||
conf.AuthAccess = authAccess
|
||||
conf.AuthTokens = authTokens
|
||||
conf.AttachmentCacheDir = attachmentCacheDir
|
||||
conf.AttachmentS3URL = attachmentS3URL
|
||||
conf.AttachmentTotalSizeLimit = attachmentTotalSizeLimit
|
||||
conf.AttachmentFileSizeLimit = attachmentFileSizeLimit
|
||||
conf.AttachmentExpiryDuration = attachmentExpiryDuration
|
||||
|
||||
@@ -11,6 +11,12 @@ type Beginner interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
}
|
||||
|
||||
// Querier is an interface for types that can execute SQL queries.
|
||||
// *sql.DB, *sql.Tx, and *DB all implement this.
|
||||
type Querier interface {
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
// Host pairs a *sql.DB with the host:port it was opened against.
|
||||
type Host struct {
|
||||
Addr string // "host:port"
|
||||
|
||||
@@ -489,20 +489,23 @@ Subscribers can retrieve cached messaging using the [`poll=1` parameter](subscri
|
||||
|
||||
## Attachments
|
||||
If desired, you may allow users to upload and [attach files to notifications](publish.md#attachments). To enable
|
||||
this feature, you have to simply configure an attachment cache directory and a base URL (`attachment-cache-dir`, `base-url`).
|
||||
Once these options are set and the directory is writable by the server user, you can upload attachments via PUT.
|
||||
this feature, you have to configure an attachment storage backend and a base URL (`base-url`). Attachments can be stored
|
||||
either on the local filesystem (`attachment-cache-dir`) or in an S3-compatible object store (`attachment-s3-url`).
|
||||
Once configured, you can upload attachments via PUT.
|
||||
|
||||
By default, attachments are stored in the disk-cache **for only 3 hours**. The main reason for this is to avoid legal issues
|
||||
and such when hosting user controlled content. Typically, this is more than enough time for the user (or the auto download
|
||||
By default, attachments are stored **for only 3 hours**. The main reason for this is to avoid legal issues
|
||||
and such when hosting user controlled content. Typically, this is more than enough time for the user (or the auto download
|
||||
feature) to download the file. The following config options are relevant to attachments:
|
||||
|
||||
* `base-url` is the root URL for the ntfy server; this is needed for the generated attachment URLs
|
||||
* `attachment-cache-dir` is the cache directory for attached files
|
||||
* `attachment-total-size-limit` is the size limit of the on-disk attachment cache (default: 5G)
|
||||
* `attachment-cache-dir` is the cache directory for attached files (mutually exclusive with `attachment-s3-url`)
|
||||
* `attachment-s3-url` is the S3 URL for attachment storage (mutually exclusive with `attachment-cache-dir`)
|
||||
* `attachment-total-size-limit` is the size limit of the attachment storage (default: 5G)
|
||||
* `attachment-file-size-limit` is the per-file attachment size limit (e.g. 300k, 2M, 100M, default: 15M)
|
||||
* `attachment-expiry-duration` is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h, default: 3h)
|
||||
|
||||
Here's an example config using mostly the defaults (except for the cache directory, which is empty by default):
|
||||
### Filesystem storage
|
||||
Here's an example config using the local filesystem for attachment storage:
|
||||
|
||||
=== "/etc/ntfy/server.yml (minimal)"
|
||||
``` yaml
|
||||
@@ -521,6 +524,30 @@ Here's an example config using mostly the defaults (except for the cache directo
|
||||
visitor-attachment-daily-bandwidth-limit: "500M"
|
||||
```
|
||||
|
||||
### S3 storage
|
||||
As an alternative to the local filesystem, you can store attachments in an S3-compatible object store (e.g. AWS S3,
|
||||
MinIO, DigitalOcean Spaces). This is useful for HA/cloud deployments where you don't want to rely on local disk storage.
|
||||
|
||||
The `attachment-s3-url` option uses the following format:
|
||||
|
||||
```
|
||||
s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
|
||||
```
|
||||
|
||||
When `endpoint` is specified, path-style addressing is enabled automatically (useful for MinIO and other S3-compatible stores).
|
||||
|
||||
=== "/etc/ntfy/server.yml (AWS S3)"
|
||||
``` yaml
|
||||
base-url: "https://ntfy.sh"
|
||||
attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1"
|
||||
```
|
||||
|
||||
=== "/etc/ntfy/server.yml (MinIO/custom endpoint)"
|
||||
``` yaml
|
||||
base-url: "https://ntfy.sh"
|
||||
attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com"
|
||||
```
|
||||
|
||||
Please also refer to the [rate limiting](#rate-limiting) settings below, specifically `visitor-attachment-total-size-limit`
|
||||
and `visitor-attachment-daily-bandwidth-limit`. Setting these conservatively is necessary to avoid abuse.
|
||||
|
||||
@@ -2116,7 +2143,8 @@ variable before running the `ntfy` command (e.g. `export NTFY_LISTEN_HTTP=:80`).
|
||||
| `behind-proxy` | `NTFY_BEHIND_PROXY` | *bool* | false | If set, use forwarded header (e.g. X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting) |
|
||||
| `proxy-forwarded-header` | `NTFY_PROXY_FORWARDED_HEADER` | *string* | `X-Forwarded-For` | Use specified header to determine visitor IP address (for rate limiting) |
|
||||
| `proxy-trusted-hosts` | `NTFY_PROXY_TRUSTED_HOSTS` | *comma-separated host/IP/CIDR list* | - | Comma-separated list of trusted IP addresses, hosts, or CIDRs to remove from forwarded header |
|
||||
| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory* | - | Cache directory for attached files. To enable attachments, this has to be set. |
|
||||
| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory* | - | Cache directory for attached files. Mutually exclusive with `attachment-s3-url`. |
|
||||
| `attachment-s3-url` | `NTFY_ATTACHMENT_S3_URL` | *URL* | - | S3 URL for attachment storage (format: `s3://KEY:SECRET@BUCKET[/PREFIX]?region=REGION`). Mutually exclusive with `attachment-cache-dir`. |
|
||||
| `attachment-total-size-limit` | `NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT` | *size* | 5G | Limit of the on-disk attachment cache directory. If the limits is exceeded, new attachments will be rejected. |
|
||||
| `attachment-file-size-limit` | `NTFY_ATTACHMENT_FILE_SIZE_LIMIT` | *size* | 15M | Per-file attachment size limit (e.g. 300k, 2M, 100M). Larger attachment will be rejected. |
|
||||
| `attachment-expiry-duration` | `NTFY_ATTACHMENT_EXPIRY_DURATION` | *duration* | 3h | Duration after which uploaded attachments will be deleted (e.g. 3h, 20h). Strongly affects `visitor-attachment-total-size-limit`. |
|
||||
@@ -2219,6 +2247,7 @@ OPTIONS:
|
||||
--auth-startup-queries value, --auth_startup_queries value queries run when the auth database is initialized [$NTFY_AUTH_STARTUP_QUERIES]
|
||||
--auth-default-access value, --auth_default_access value, -p value default permissions if no matching entries in the auth database are found (default: "read-write") [$NTFY_AUTH_DEFAULT_ACCESS]
|
||||
--attachment-cache-dir value, --attachment_cache_dir value cache directory for attached files [$NTFY_ATTACHMENT_CACHE_DIR]
|
||||
--attachment-s3-url value, --attachment_s3_url value S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION) [$NTFY_ATTACHMENT_S3_URL]
|
||||
--attachment-total-size-limit value, --attachment_total_size_limit value, -A value limit of the on-disk attachment cache (default: "5G") [$NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT]
|
||||
--attachment-file-size-limit value, --attachment_file_size_limit value, -Y value per-file attachment size limit (e.g. 300k, 2M, 100M) (default: "15M") [$NTFY_ATTACHMENT_FILE_SIZE_LIMIT]
|
||||
--attachment-expiry-duration value, --attachment_expiry_duration value, -X value duration after which uploaded attachments will be deleted (e.g. 3h, 20h) (default: "3h") [$NTFY_ATTACHMENT_EXPIRY_DURATION]
|
||||
|
||||
@@ -30,37 +30,37 @@ deb/rpm packages.
|
||||
|
||||
=== "x86_64/amd64"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_amd64.tar.gz
|
||||
tar zxvf ntfy_2.18.0_linux_amd64.tar.gz
|
||||
sudo cp -a ntfy_2.18.0_linux_amd64/ntfy /usr/local/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.18.0_linux_amd64/{client,server}/*.yml /etc/ntfy
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_amd64.tar.gz
|
||||
tar zxvf ntfy_2.19.2_linux_amd64.tar.gz
|
||||
sudo cp -a ntfy_2.19.2_linux_amd64/ntfy /usr/local/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_amd64/{client,server}/*.yml /etc/ntfy
|
||||
sudo ntfy serve
|
||||
```
|
||||
|
||||
=== "armv6"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_armv6.tar.gz
|
||||
tar zxvf ntfy_2.18.0_linux_armv6.tar.gz
|
||||
sudo cp -a ntfy_2.18.0_linux_armv6/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.18.0_linux_armv6/{client,server}/*.yml /etc/ntfy
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv6.tar.gz
|
||||
tar zxvf ntfy_2.19.2_linux_armv6.tar.gz
|
||||
sudo cp -a ntfy_2.19.2_linux_armv6/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_armv6/{client,server}/*.yml /etc/ntfy
|
||||
sudo ntfy serve
|
||||
```
|
||||
|
||||
=== "armv7/armhf"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_armv7.tar.gz
|
||||
tar zxvf ntfy_2.18.0_linux_armv7.tar.gz
|
||||
sudo cp -a ntfy_2.18.0_linux_armv7/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.18.0_linux_armv7/{client,server}/*.yml /etc/ntfy
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv7.tar.gz
|
||||
tar zxvf ntfy_2.19.2_linux_armv7.tar.gz
|
||||
sudo cp -a ntfy_2.19.2_linux_armv7/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_armv7/{client,server}/*.yml /etc/ntfy
|
||||
sudo ntfy serve
|
||||
```
|
||||
|
||||
=== "arm64"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_arm64.tar.gz
|
||||
tar zxvf ntfy_2.18.0_linux_arm64.tar.gz
|
||||
sudo cp -a ntfy_2.18.0_linux_arm64/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.18.0_linux_arm64/{client,server}/*.yml /etc/ntfy
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_arm64.tar.gz
|
||||
tar zxvf ntfy_2.19.2_linux_arm64.tar.gz
|
||||
sudo cp -a ntfy_2.19.2_linux_arm64/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.19.2_linux_arm64/{client,server}/*.yml /etc/ntfy
|
||||
sudo ntfy serve
|
||||
```
|
||||
|
||||
@@ -116,7 +116,7 @@ Manually installing the .deb file:
|
||||
|
||||
=== "x86_64/amd64"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_amd64.deb
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_amd64.deb
|
||||
sudo dpkg -i ntfy_*.deb
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
@@ -124,7 +124,7 @@ Manually installing the .deb file:
|
||||
|
||||
=== "armv6"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_armv6.deb
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv6.deb
|
||||
sudo dpkg -i ntfy_*.deb
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
@@ -132,7 +132,7 @@ Manually installing the .deb file:
|
||||
|
||||
=== "armv7/armhf"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_armv7.deb
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv7.deb
|
||||
sudo dpkg -i ntfy_*.deb
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
@@ -140,7 +140,7 @@ Manually installing the .deb file:
|
||||
|
||||
=== "arm64"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_arm64.deb
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_arm64.deb
|
||||
sudo dpkg -i ntfy_*.deb
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
@@ -150,28 +150,28 @@ Manually installing the .deb file:
|
||||
|
||||
=== "x86_64/amd64"
|
||||
```bash
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_amd64.rpm
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_amd64.rpm
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
```
|
||||
|
||||
=== "armv6"
|
||||
```bash
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_armv6.rpm
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv6.rpm
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
```
|
||||
|
||||
=== "armv7/armhf"
|
||||
```bash
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_armv7.rpm
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_armv7.rpm
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
```
|
||||
|
||||
=== "arm64"
|
||||
```bash
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_linux_arm64.rpm
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_linux_arm64.rpm
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
```
|
||||
@@ -213,18 +213,18 @@ pkg install go-ntfy
|
||||
|
||||
## macOS
|
||||
The [ntfy CLI](subscribe/cli.md) (`ntfy publish` and `ntfy subscribe` only) is supported on macOS as well.
|
||||
To install, please [download the tarball](https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_darwin_all.tar.gz),
|
||||
To install, please [download the tarball](https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_darwin_all.tar.gz),
|
||||
extract it and place it somewhere in your `PATH` (e.g. `/usr/local/bin/ntfy`).
|
||||
|
||||
If run as `root`, ntfy will look for its config at `/etc/ntfy/client.yml`. For all other users, it'll look for it at
|
||||
`~/Library/Application Support/ntfy/client.yml` (sample included in the tarball).
|
||||
|
||||
```bash
|
||||
curl -L https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_darwin_all.tar.gz > ntfy_2.18.0_darwin_all.tar.gz
|
||||
tar zxvf ntfy_2.18.0_darwin_all.tar.gz
|
||||
sudo cp -a ntfy_2.18.0_darwin_all/ntfy /usr/local/bin/ntfy
|
||||
curl -L https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_darwin_all.tar.gz > ntfy_2.19.2_darwin_all.tar.gz
|
||||
tar zxvf ntfy_2.19.2_darwin_all.tar.gz
|
||||
sudo cp -a ntfy_2.19.2_darwin_all/ntfy /usr/local/bin/ntfy
|
||||
mkdir ~/Library/Application\ Support/ntfy
|
||||
cp ntfy_2.18.0_darwin_all/client/client.yml ~/Library/Application\ Support/ntfy/client.yml
|
||||
cp ntfy_2.19.2_darwin_all/client/client.yml ~/Library/Application\ Support/ntfy/client.yml
|
||||
ntfy --help
|
||||
```
|
||||
|
||||
@@ -245,7 +245,7 @@ brew install ntfy
|
||||
The ntfy server and CLI are fully supported on Windows. You can run the ntfy server directly or as a Windows service.
|
||||
To install, you can either
|
||||
|
||||
* [Download the latest ZIP](https://github.com/binwiederhier/ntfy/releases/download/v2.18.0/ntfy_2.18.0_windows_amd64.zip),
|
||||
* [Download the latest ZIP](https://github.com/binwiederhier/ntfy/releases/download/v2.19.2/ntfy_2.19.2_windows_amd64.zip),
|
||||
extract it and place the `ntfy.exe` binary somewhere in your `%Path%`.
|
||||
* Or install ntfy from the [Scoop](https://scoop.sh) main repository via `scoop install ntfy`
|
||||
|
||||
|
||||
@@ -6,12 +6,55 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
|
||||
|
||||
| Component | Version | Release date |
|
||||
|------------------|---------|--------------|
|
||||
| ntfy server | v2.18.0 | Mar 7, 2026 |
|
||||
| ntfy server | v2.19.2 | Mar 16, 2026 |
|
||||
| ntfy Android app | v1.24.0 | Mar 5, 2026 |
|
||||
| ntfy iOS app | v1.3 | Nov 26, 2023 |
|
||||
|
||||
Please check out the release notes for [upcoming releases](#not-released-yet) below.
|
||||
|
||||
### ntfy server v2.19.2
|
||||
Released March 16, 2026
|
||||
|
||||
This is another small bugfix release for PostgreSQL, avoiding races between primary and read replica, as well as to
|
||||
further reduce primary load.
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Fix race condition in web push subscription causing FK constraint violation when concurrent requests hit the same endpoint
|
||||
* Route authorization query to read-only database replica to reduce primary database load
|
||||
|
||||
## ntfy server v2.19.1
|
||||
Released March 15, 2026
|
||||
|
||||
This is a bugfix release to avoid PostgreSQL insert failures due to invalid UTF-8 messages. It also fixes `database-url`
|
||||
validation incorrectly rejecting `postgresql://` connection strings.
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Fix invalid UTF-8 in HTTP headers (e.g. Latin-1 encoded text) causing PostgreSQL insert failures and dropping entire message batches
|
||||
* Fix `database-url` validation rejecting `postgresql://` connection strings ([#1657](https://github.com/binwiederhier/ntfy/issues/1657)/[#1658](https://github.com/binwiederhier/ntfy/pull/1658))
|
||||
|
||||
## ntfy server v2.19.0
|
||||
Released March 15, 2026
|
||||
|
||||
This is a fast-follow release that enables Postgres read replica support.
|
||||
|
||||
To offload read-heavy queries from the primary database, you can optionally configure one or more read replicas
|
||||
using the `database-replica-urls` option. When configured, non-critical read-only queries (e.g. fetching messages,
|
||||
checking access permissions, etc) are distributed across the replicas using round-robin, while all writes and
|
||||
correctness-critical reads continue to go to the primary. If a replica becomes unhealthy, ntfy automatically falls back
|
||||
to the primary until the replica recovers.
|
||||
|
||||
**Features:**
|
||||
|
||||
* Support [PostgreSQL read replicas](config.md#postgresql-experimental) for offloading non-critical read queries via `database-replica-urls` config option ([#1648](https://github.com/binwiederhier/ntfy/pull/1648))
|
||||
* Add interactive [config generator](config.md#config-generator) to the documentation to help create server configuration files ([#1654](https://github.com/binwiederhier/ntfy/pull/1654))
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Web: Throttle notification sound in web app to play at most once every 2 seconds (similar to [#1550](https://github.com/binwiederhier/ntfy/issues/1550), thanks to [@jlaffaye](https://github.com/jlaffaye) for reporting)
|
||||
* Web: Add hover tooltips to icon buttons in web app account and preferences pages ([#1565](https://github.com/binwiederhier/ntfy/issues/1565), thanks to [@jermanuts](https://github.com/jermanuts) for reporting)
|
||||
|
||||
## ntfy server v2.18.0
|
||||
Released March 7, 2026
|
||||
|
||||
@@ -1755,14 +1798,12 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
|
||||
|
||||
## Not released yet
|
||||
|
||||
### ntfy server v2.19.x (UNRELEASED)
|
||||
### ntfy server v2.20.x (UNRELEASED)
|
||||
|
||||
**Features:**
|
||||
|
||||
* Support PostgreSQL read replicas for offloading non-critical read queries via `database-replica-urls` config option
|
||||
* Add interactive [config generator](config.md#config-generator) to the documentation to help create server configuration files
|
||||
* Add S3-compatible object storage as an alternative attachment backend via `attachment-s3-url` config option
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Web: Throttle notification sound in web app to play at most once every 2 seconds (similar to [#1550](https://github.com/binwiederhier/ntfy/issues/1550), thanks to [@jlaffaye](https://github.com/jlaffaye) for reporting)
|
||||
* Web: Add hover tooltips to icon buttons in web app account and preferences pages ([#1565](https://github.com/binwiederhier/ntfy/issues/1565), thanks to [@jermanuts](https://github.com/jermanuts) for reporting)
|
||||
* Reject invalid e-mail addresses (e.g. multiple comma-separated recipients) with HTTP 400
|
||||
|
||||
2
go.mod
2
go.mod
@@ -4,7 +4,7 @@ go 1.25.0
|
||||
|
||||
require (
|
||||
cloud.google.com/go/firestore v1.21.0 // indirect
|
||||
cloud.google.com/go/storage v1.61.1 // indirect
|
||||
cloud.google.com/go/storage v1.61.3 // indirect
|
||||
github.com/BurntSushi/toml v1.6.0 // indirect
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect
|
||||
github.com/emersion/go-smtp v0.18.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -18,8 +18,8 @@ cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7
|
||||
cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk=
|
||||
cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE=
|
||||
cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI=
|
||||
cloud.google.com/go/storage v1.61.1 h1:VELCSvZKiSw0AS1k3so5mKGy3CB7bTCYD8EHhTF42bY=
|
||||
cloud.google.com/go/storage v1.61.1/go.mod h1:k30/hwYfd0M8aULYbPkQLgNf+SFcdjlRHvLMXggw18E=
|
||||
cloud.google.com/go/storage v1.61.3 h1:VS//ZfBuPGDvakfD9xyPW1RGF1Vy3BWUoVZXgW1KMOg=
|
||||
cloud.google.com/go/storage v1.61.3/go.mod h1:JtqK8BBB7TWv0HVGHubtUdzYYrakOQIsMLffZ2Z/HWk=
|
||||
cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U=
|
||||
cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s=
|
||||
firebase.google.com/go/v4 v4.19.0 h1:f5NMlC2YHFsncz00c2+ecBr+ZYlRMhKIhj1z8Iz0lD8=
|
||||
|
||||
@@ -125,16 +125,16 @@ func (c *Cache) addMessages(ms []*model.Message) error {
|
||||
return model.ErrUnexpectedMessageType
|
||||
}
|
||||
published := m.Time <= time.Now().Unix()
|
||||
tags := strings.Join(m.Tags, ",")
|
||||
tags := util.SanitizeUTF8(strings.Join(m.Tags, ","))
|
||||
var attachmentName, attachmentType, attachmentURL string
|
||||
var attachmentSize, attachmentExpires int64
|
||||
var attachmentDeleted bool
|
||||
if m.Attachment != nil {
|
||||
attachmentName = m.Attachment.Name
|
||||
attachmentType = m.Attachment.Type
|
||||
attachmentName = util.SanitizeUTF8(m.Attachment.Name)
|
||||
attachmentType = util.SanitizeUTF8(m.Attachment.Type)
|
||||
attachmentSize = m.Attachment.Size
|
||||
attachmentExpires = m.Attachment.Expires
|
||||
attachmentURL = m.Attachment.URL
|
||||
attachmentURL = util.SanitizeUTF8(m.Attachment.URL)
|
||||
}
|
||||
var actionsStr string
|
||||
if len(m.Actions) > 0 {
|
||||
@@ -154,13 +154,13 @@ func (c *Cache) addMessages(ms []*model.Message) error {
|
||||
m.Time,
|
||||
m.Event,
|
||||
m.Expires,
|
||||
m.Topic,
|
||||
m.Message,
|
||||
m.Title,
|
||||
util.SanitizeUTF8(m.Topic),
|
||||
util.SanitizeUTF8(m.Message),
|
||||
util.SanitizeUTF8(m.Title),
|
||||
m.Priority,
|
||||
tags,
|
||||
m.Click,
|
||||
m.Icon,
|
||||
util.SanitizeUTF8(m.Click),
|
||||
util.SanitizeUTF8(m.Icon),
|
||||
actionsStr,
|
||||
attachmentName,
|
||||
attachmentType,
|
||||
@@ -170,7 +170,7 @@ func (c *Cache) addMessages(ms []*model.Message) error {
|
||||
attachmentDeleted, // Always zero
|
||||
sender,
|
||||
m.User,
|
||||
m.ContentType,
|
||||
util.SanitizeUTF8(m.ContentType),
|
||||
m.Encoding,
|
||||
published,
|
||||
)
|
||||
|
||||
@@ -827,3 +827,141 @@ func TestStore_MessageFieldRoundTrip(t *testing.T) {
|
||||
require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_AddMessage_InvalidUTF8(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// 0xc9 0x43: Latin-1 "ÉC" — 0xc9 starts a 2-byte UTF-8 sequence but 0x43 ('C') is not a continuation byte
|
||||
m := model.NewDefaultMessage("mytopic", "\xc9Cas du serveur")
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "\uFFFDCas du serveur", messages[0].Message)
|
||||
|
||||
// 0xae: Latin-1 "®" — isolated byte above 0x7F, not a valid UTF-8 start for single byte
|
||||
m2 := model.NewDefaultMessage("mytopic", "Product\xae Pro")
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "Product\uFFFD Pro", messages[1].Message)
|
||||
|
||||
// 0xe8 0x6d 0x65: Latin-1 "ème" — 0xe8 starts a 3-byte UTF-8 sequence but 0x6d ('m') is not a continuation byte
|
||||
m3 := model.NewDefaultMessage("mytopic", "probl\xe8me critique")
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "probl\uFFFDme critique", messages[2].Message)
|
||||
|
||||
// 0xb2: Latin-1 "²" — isolated byte in 0x80-0xBF range (UTF-8 continuation byte without lead)
|
||||
m4 := model.NewDefaultMessage("mytopic", "CO\xb2 level high")
|
||||
require.Nil(t, s.AddMessage(m4))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "CO\uFFFD level high", messages[3].Message)
|
||||
|
||||
// 0xe9 0x6d 0x61: Latin-1 "éma" — 0xe9 starts a 3-byte UTF-8 sequence but 0x6d ('m') is not a continuation byte
|
||||
m5 := model.NewDefaultMessage("mytopic", "th\xe9matique")
|
||||
require.Nil(t, s.AddMessage(m5))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "th\uFFFDmatique", messages[4].Message)
|
||||
|
||||
// 0xed 0x64 0x65: Latin-1 "íde" — 0xed starts a 3-byte UTF-8 sequence but 0x64 ('d') is not a continuation byte
|
||||
m6 := model.NewDefaultMessage("mytopic", "vid\xed\x64eo surveillance")
|
||||
require.Nil(t, s.AddMessage(m6))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "vid\uFFFDdeo surveillance", messages[5].Message)
|
||||
|
||||
// 0xf3 0x6e 0x3a 0x20: Latin-1 "ón: " — 0xf3 starts a 4-byte UTF-8 sequence but 0x6e ('n') is not a continuation byte
|
||||
m7 := model.NewDefaultMessage("mytopic", "notificaci\xf3n: alerta")
|
||||
require.Nil(t, s.AddMessage(m7))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "notificaci\uFFFDn: alerta", messages[6].Message)
|
||||
|
||||
// 0xb7: Latin-1 "·" — isolated continuation byte
|
||||
m8 := model.NewDefaultMessage("mytopic", "item\xb7value")
|
||||
require.Nil(t, s.AddMessage(m8))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "item\uFFFDvalue", messages[7].Message)
|
||||
|
||||
// 0xa8: Latin-1 "¨" — isolated continuation byte
|
||||
m9 := model.NewDefaultMessage("mytopic", "na\xa8ve")
|
||||
require.Nil(t, s.AddMessage(m9))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "na\uFFFDve", messages[8].Message)
|
||||
|
||||
// 0xdf 0x64: Latin-1 "ßd" — 0xdf starts a 2-byte UTF-8 sequence but 0x64 ('d') is not a continuation byte
|
||||
m10 := model.NewDefaultMessage("mytopic", "gro\xdf\x64ruck")
|
||||
require.Nil(t, s.AddMessage(m10))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "gro\uFFFDdruck", messages[9].Message)
|
||||
|
||||
// 0xe4 0x67 0x74: Latin-1 "ägt" — 0xe4 starts a 3-byte UTF-8 sequence but 0x67 ('g') is not a continuation byte
|
||||
m11 := model.NewDefaultMessage("mytopic", "tr\xe4gt Last")
|
||||
require.Nil(t, s.AddMessage(m11))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tr\uFFFDgt Last", messages[10].Message)
|
||||
|
||||
// 0xe9 0x65 0x20: Latin-1 "ée " — 0xe9 starts a 3-byte UTF-8 sequence but 0x65 ('e') is not a continuation byte
|
||||
m12 := model.NewDefaultMessage("mytopic", "journ\xe9\x65 termin\xe9\x65")
|
||||
require.Nil(t, s.AddMessage(m12))
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "journ\uFFFDe termin\uFFFDe", messages[11].Message)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_AddMessage_NullByte(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// 0x00: NUL byte — valid UTF-8 but rejected by PostgreSQL
|
||||
m := model.NewDefaultMessage("mytopic", "hello\x00world")
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "helloworld", messages[0].Message)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_AddMessage_InvalidUTF8InTitleAndTags(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// Invalid UTF-8 can arrive via HTTP headers (Title, Tags) which bypass body validation
|
||||
m := model.NewDefaultMessage("mytopic", "valid message")
|
||||
m.Title = "\xc9clipse du syst\xe8me"
|
||||
m.Tags = []string{"probl\xe8me", "syst\xe9me"}
|
||||
m.Click = "https://example.com/\xae"
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "\uFFFDclipse du syst\uFFFDme", messages[0].Title)
|
||||
require.Equal(t, "probl\uFFFDme", messages[0].Tags[0])
|
||||
require.Equal(t, "syst\uFFFDme", messages[0].Tags[1])
|
||||
require.Equal(t, "https://example.com/\uFFFD", messages[0].Click)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_AddMessage_InvalidUTF8BatchDoesNotDropValidMessages(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||
// Previously, a single invalid message would roll back the entire batch transaction.
|
||||
// Sanitization ensures all messages in a batch are written successfully.
|
||||
msgs := []*model.Message{
|
||||
model.NewDefaultMessage("mytopic", "valid message 1"),
|
||||
model.NewDefaultMessage("mytopic", "notificaci\xf3n: alerta"),
|
||||
model.NewDefaultMessage("mytopic", "valid message 3"),
|
||||
}
|
||||
require.Nil(t, s.AddMessages(msgs))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(messages))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -70,6 +70,26 @@ func (m *Message) Context() log.Context {
|
||||
return fields
|
||||
}
|
||||
|
||||
// SanitizeUTF8 replaces invalid UTF-8 sequences and strips NUL bytes from all user-supplied
|
||||
// string fields. This is called early in the publish path so that all downstream consumers
|
||||
// (Firebase, WebPush, SMTP, cache) receive clean UTF-8 strings.
|
||||
func (m *Message) SanitizeUTF8() {
|
||||
m.Topic = util.SanitizeUTF8(m.Topic)
|
||||
m.Message = util.SanitizeUTF8(m.Message)
|
||||
m.Title = util.SanitizeUTF8(m.Title)
|
||||
m.Click = util.SanitizeUTF8(m.Click)
|
||||
m.Icon = util.SanitizeUTF8(m.Icon)
|
||||
m.ContentType = util.SanitizeUTF8(m.ContentType)
|
||||
for i, tag := range m.Tags {
|
||||
m.Tags[i] = util.SanitizeUTF8(tag)
|
||||
}
|
||||
if m.Attachment != nil {
|
||||
m.Attachment.Name = util.SanitizeUTF8(m.Attachment.Name)
|
||||
m.Attachment.Type = util.SanitizeUTF8(m.Attachment.Type)
|
||||
m.Attachment.URL = util.SanitizeUTF8(m.Attachment.URL)
|
||||
}
|
||||
}
|
||||
|
||||
// ForJSON returns a copy of the message suitable for JSON output.
|
||||
// It clears the SequenceID if it equals the ID to reduce redundancy.
|
||||
func (m *Message) ForJSON() *Message {
|
||||
|
||||
488
s3/client.go
Normal file
488
s3/client.go
Normal file
@@ -0,0 +1,488 @@
|
||||
// Package s3 provides a minimal S3-compatible client that works with AWS S3, DigitalOcean Spaces,
|
||||
// GCP Cloud Storage, MinIO, Backblaze B2, and other S3-compatible providers. It uses raw HTTP
|
||||
// requests with AWS Signature V4 signing, no AWS SDK dependency required.
|
||||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5" //nolint:gosec // MD5 is required by the S3 protocol for Content-MD5 headers
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client is a minimal S3-compatible client. It supports PutObject, GetObject, DeleteObjects,
|
||||
// and ListObjectsV2 operations using AWS Signature V4 signing. The bucket and optional key prefix
|
||||
// are fixed at construction time. All operations target the same bucket and prefix.
|
||||
//
|
||||
// Fields must not be modified after the Client is passed to any method or goroutine.
|
||||
type Client struct {
|
||||
AccessKey string // AWS access key ID
|
||||
SecretKey string // AWS secret access key
|
||||
Region string // e.g. "us-east-1"
|
||||
Endpoint string // host[:port] only, e.g. "s3.amazonaws.com" or "nyc3.digitaloceanspaces.com"
|
||||
Bucket string // S3 bucket name
|
||||
Prefix string // optional key prefix (e.g. "attachments"); prepended to all keys automatically
|
||||
PathStyle bool // if true, use path-style addressing; otherwise virtual-hosted-style
|
||||
HTTPClient *http.Client // if nil, http.DefaultClient is used
|
||||
}
|
||||
|
||||
// New creates a new S3 client from the given Config.
|
||||
func New(config *Config) *Client {
|
||||
return &Client{
|
||||
AccessKey: config.AccessKey,
|
||||
SecretKey: config.SecretKey,
|
||||
Region: config.Region,
|
||||
Endpoint: config.Endpoint,
|
||||
Bucket: config.Bucket,
|
||||
Prefix: config.Prefix,
|
||||
PathStyle: config.PathStyle,
|
||||
}
|
||||
}
|
||||
|
||||
// PutObject uploads body to the given key. The key is automatically prefixed with the client's
|
||||
// configured prefix. The body size does not need to be known in advance.
|
||||
//
|
||||
// If the entire body fits in a single part (5 MB), it is uploaded with a simple PUT request.
|
||||
// Otherwise, the body is uploaded using S3 multipart upload, reading one part at a time
|
||||
// into memory.
|
||||
func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error {
|
||||
first := make([]byte, partSize)
|
||||
n, err := io.ReadFull(body, first)
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF {
|
||||
return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n))
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: PutObject read: %w", err)
|
||||
}
|
||||
combined := io.MultiReader(bytes.NewReader(first), body)
|
||||
return c.putObjectMultipart(ctx, key, combined)
|
||||
}
|
||||
|
||||
// GetObject downloads an object. The key is automatically prefixed with the client's configured
|
||||
// prefix. The caller must close the returned ReadCloser.
|
||||
func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int64, error) {
|
||||
fullKey := c.objectKey(key)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.objectURL(fullKey), nil)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("s3: GetObject request: %w", err)
|
||||
}
|
||||
c.signV4(req, emptyPayloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("s3: GetObject: %w", err)
|
||||
}
|
||||
if resp.StatusCode/100 != 2 {
|
||||
err := parseError(resp)
|
||||
resp.Body.Close()
|
||||
return nil, 0, err
|
||||
}
|
||||
return resp.Body, resp.ContentLength, nil
|
||||
}
|
||||
|
||||
// DeleteObjects removes multiple objects in a single batch request. Keys are automatically
|
||||
// prefixed with the client's configured prefix. S3 supports up to 1000 keys per call; the
|
||||
// caller is responsible for batching if needed.
|
||||
//
|
||||
// Even when S3 returns HTTP 200, individual keys may fail. If any per-key errors are present
|
||||
// in the response, they are returned as a combined error.
|
||||
func (c *Client) DeleteObjects(ctx context.Context, keys []string) error {
|
||||
var body bytes.Buffer
|
||||
body.WriteString("<Delete><Quiet>true</Quiet>")
|
||||
for _, key := range keys {
|
||||
body.WriteString("<Object><Key>")
|
||||
xml.EscapeText(&body, []byte(c.objectKey(key)))
|
||||
body.WriteString("</Key></Object>")
|
||||
}
|
||||
body.WriteString("</Delete>")
|
||||
bodyBytes := body.Bytes()
|
||||
payloadHash := sha256Hex(bodyBytes)
|
||||
|
||||
// Content-MD5 is required by the S3 protocol for DeleteObjects requests.
|
||||
md5Sum := md5.Sum(bodyBytes) //nolint:gosec
|
||||
contentMD5 := base64.StdEncoding.EncodeToString(md5Sum[:])
|
||||
|
||||
reqURL := c.bucketURL() + "?delete="
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: DeleteObjects request: %w", err)
|
||||
}
|
||||
req.ContentLength = int64(len(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/xml")
|
||||
req.Header.Set("Content-MD5", contentMD5)
|
||||
c.signV4(req, payloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: DeleteObjects: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return parseError(resp)
|
||||
}
|
||||
|
||||
// S3 may return HTTP 200 with per-key errors in the response body
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: DeleteObjects read response: %w", err)
|
||||
}
|
||||
var result deleteResult
|
||||
if err := xml.Unmarshal(respBody, &result); err != nil {
|
||||
return nil // If we can't parse, assume success (Quiet mode returns empty body on success)
|
||||
}
|
||||
if len(result.Errors) > 0 {
|
||||
var msgs []string
|
||||
for _, e := range result.Errors {
|
||||
msgs = append(msgs, fmt.Sprintf("%s: %s", e.Key, e.Message))
|
||||
}
|
||||
return fmt.Errorf("s3: DeleteObjects partial failure: %s", strings.Join(msgs, "; "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListObjects performs a single ListObjectsV2 request using the client's configured prefix.
|
||||
// Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000).
|
||||
func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxKeys int) (*ListResult, error) {
|
||||
query := url.Values{"list-type": {"2"}}
|
||||
if prefix := c.prefixForList(); prefix != "" {
|
||||
query.Set("prefix", prefix)
|
||||
}
|
||||
if continuationToken != "" {
|
||||
query.Set("continuation-token", continuationToken)
|
||||
}
|
||||
if maxKeys > 0 {
|
||||
query.Set("max-keys", strconv.Itoa(maxKeys))
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s3: ListObjects request: %w", err)
|
||||
}
|
||||
c.signV4(req, emptyPayloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s3: ListObjects: %w", err)
|
||||
}
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s3: ListObjects read: %w", err)
|
||||
}
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return nil, parseErrorFromBytes(resp.StatusCode, respBody)
|
||||
}
|
||||
var result listObjectsV2Response
|
||||
if err := xml.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("s3: ListObjects XML: %w", err)
|
||||
}
|
||||
objects := make([]Object, len(result.Contents))
|
||||
for i, obj := range result.Contents {
|
||||
objects[i] = Object(obj)
|
||||
}
|
||||
return &ListResult{
|
||||
Objects: objects,
|
||||
IsTruncated: result.IsTruncated,
|
||||
NextContinuationToken: result.NextContinuationToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListAllObjects returns all objects under the client's configured prefix by paginating through
|
||||
// ListObjectsV2 results automatically. It stops after 10,000 pages as a safety valve.
|
||||
func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) {
|
||||
const maxPages = 10000
|
||||
var all []Object
|
||||
var token string
|
||||
for page := 0; page < maxPages; page++ {
|
||||
result, err := c.ListObjects(ctx, token, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
all = append(all, result.Objects...)
|
||||
if !result.IsTruncated {
|
||||
return all, nil
|
||||
}
|
||||
token = result.NextContinuationToken
|
||||
}
|
||||
return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages)
|
||||
}
|
||||
|
||||
// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD.
|
||||
func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error {
|
||||
fullKey := c.objectKey(key)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: PutObject request: %w", err)
|
||||
}
|
||||
req.ContentLength = size
|
||||
c.signV4(req, unsignedPayload)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: PutObject: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return parseError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize
|
||||
// chunks, uploading each as a separate part. This allows uploading without knowing the total
|
||||
// body size in advance.
|
||||
func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error {
|
||||
fullKey := c.objectKey(key)
|
||||
|
||||
// Step 1: Initiate multipart upload
|
||||
uploadID, err := c.initiateMultipartUpload(ctx, fullKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Step 2: Upload parts
|
||||
var parts []completedPart
|
||||
buf := make([]byte, partSize)
|
||||
partNumber := 1
|
||||
for {
|
||||
n, err := io.ReadFull(body, buf)
|
||||
if n > 0 {
|
||||
etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n])
|
||||
if uploadErr != nil {
|
||||
c.abortMultipartUpload(ctx, fullKey, uploadID)
|
||||
return uploadErr
|
||||
}
|
||||
parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag})
|
||||
partNumber++
|
||||
}
|
||||
if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
c.abortMultipartUpload(ctx, fullKey, uploadID)
|
||||
return fmt.Errorf("s3: PutObject read: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Complete multipart upload
|
||||
return c.completeMultipartUpload(ctx, fullKey, uploadID, parts)
|
||||
}
|
||||
|
||||
// initiateMultipartUpload starts a new multipart upload and returns the upload ID.
|
||||
func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) {
|
||||
reqURL := c.objectURL(fullKey) + "?uploads"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err)
|
||||
}
|
||||
req.ContentLength = 0
|
||||
c.signV4(req, emptyPayloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return "", parseError(resp)
|
||||
}
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err)
|
||||
}
|
||||
var result initiateMultipartUploadResult
|
||||
if err := xml.Unmarshal(respBody, &result); err != nil {
|
||||
return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err)
|
||||
}
|
||||
return result.UploadID, nil
|
||||
}
|
||||
|
||||
// uploadPart uploads a single part of a multipart upload and returns the ETag.
|
||||
func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) {
|
||||
reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: UploadPart request: %w", err)
|
||||
}
|
||||
req.ContentLength = int64(len(data))
|
||||
c.signV4(req, unsignedPayload)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: UploadPart: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return "", parseError(resp)
|
||||
}
|
||||
etag := resp.Header.Get("ETag")
|
||||
return etag, nil
|
||||
}
|
||||
|
||||
// completeMultipartUpload finalizes a multipart upload with the given parts.
|
||||
func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error {
|
||||
var body bytes.Buffer
|
||||
body.WriteString("<CompleteMultipartUpload>")
|
||||
for _, p := range parts {
|
||||
fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag)
|
||||
}
|
||||
body.WriteString("</CompleteMultipartUpload>")
|
||||
bodyBytes := body.Bytes()
|
||||
payloadHash := sha256Hex(bodyBytes)
|
||||
|
||||
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err)
|
||||
}
|
||||
req.ContentLength = int64(len(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/xml")
|
||||
c.signV4(req, payloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: CompleteMultipartUpload: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return parseError(resp)
|
||||
}
|
||||
// Read response body to check for errors (S3 can return 200 with an error body)
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err)
|
||||
}
|
||||
// Check if the response contains an error
|
||||
var errResp ErrorResponse
|
||||
if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
|
||||
errResp.StatusCode = resp.StatusCode
|
||||
return &errResp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
|
||||
func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) {
|
||||
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.signV4(req, emptyPayloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256
|
||||
// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads.
|
||||
func (c *Client) signV4(req *http.Request, payloadHash string) {
|
||||
now := time.Now().UTC()
|
||||
datestamp := now.Format("20060102")
|
||||
amzDate := now.Format("20060102T150405Z")
|
||||
|
||||
// Required headers
|
||||
req.Header.Set("Host", c.hostHeader())
|
||||
req.Header.Set("X-Amz-Date", amzDate)
|
||||
req.Header.Set("X-Amz-Content-Sha256", payloadHash)
|
||||
|
||||
// Canonical headers (all headers we set, sorted by lowercase key)
|
||||
signedKeys := make([]string, 0, len(req.Header))
|
||||
canonHeaders := make(map[string]string, len(req.Header))
|
||||
for k := range req.Header {
|
||||
lk := strings.ToLower(k)
|
||||
signedKeys = append(signedKeys, lk)
|
||||
canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k))
|
||||
}
|
||||
sort.Strings(signedKeys)
|
||||
signedHeadersStr := strings.Join(signedKeys, ";")
|
||||
var chBuf strings.Builder
|
||||
for _, k := range signedKeys {
|
||||
chBuf.WriteString(k)
|
||||
chBuf.WriteByte(':')
|
||||
chBuf.WriteString(canonHeaders[k])
|
||||
chBuf.WriteByte('\n')
|
||||
}
|
||||
|
||||
// Canonical request
|
||||
canonicalRequest := strings.Join([]string{
|
||||
req.Method,
|
||||
canonicalURI(req.URL),
|
||||
canonicalQueryString(req.URL.Query()),
|
||||
chBuf.String(),
|
||||
signedHeadersStr,
|
||||
payloadHash,
|
||||
}, "\n")
|
||||
|
||||
// String to sign
|
||||
credentialScope := datestamp + "/" + c.Region + "/s3/aws4_request"
|
||||
stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest))
|
||||
|
||||
// Signing key
|
||||
signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256(
|
||||
[]byte("AWS4"+c.SecretKey), []byte(datestamp)),
|
||||
[]byte(c.Region)),
|
||||
[]byte("s3")),
|
||||
[]byte("aws4_request"))
|
||||
|
||||
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
|
||||
req.Header.Set("Authorization", fmt.Sprintf(
|
||||
"AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
||||
c.AccessKey, credentialScope, signedHeadersStr, signature,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *Client) httpClient() *http.Client {
|
||||
if c.HTTPClient != nil {
|
||||
return c.HTTPClient
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
// objectKey prepends the configured prefix to the given key.
|
||||
func (c *Client) objectKey(key string) string {
|
||||
if c.Prefix != "" {
|
||||
return c.Prefix + "/" + key
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// prefixForList returns the prefix to use in ListObjectsV2 requests,
|
||||
// with a trailing slash so that only objects under the prefix directory are returned.
|
||||
func (c *Client) prefixForList() string {
|
||||
if c.Prefix != "" {
|
||||
return c.Prefix + "/"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// bucketURL returns the base URL for bucket-level operations.
|
||||
func (c *Client) bucketURL() string {
|
||||
if c.PathStyle {
|
||||
return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket)
|
||||
}
|
||||
return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint)
|
||||
}
|
||||
|
||||
// objectURL returns the full URL for an object (key should already include the prefix).
|
||||
// Each path segment is URI-encoded to handle special characters in keys.
|
||||
func (c *Client) objectURL(key string) string {
|
||||
segments := strings.Split(key, "/")
|
||||
for i, seg := range segments {
|
||||
segments[i] = uriEncode(seg)
|
||||
}
|
||||
return c.bucketURL() + "/" + strings.Join(segments, "/")
|
||||
}
|
||||
|
||||
// hostHeader returns the value for the Host header.
|
||||
func (c *Client) hostHeader() string {
|
||||
if c.PathStyle {
|
||||
return c.Endpoint
|
||||
}
|
||||
return c.Bucket + "." + c.Endpoint
|
||||
}
|
||||
860
s3/client_test.go
Normal file
860
s3/client_test.go
Normal file
@@ -0,0 +1,860 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Mock S3 server ---
|
||||
//
|
||||
// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and
|
||||
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
|
||||
|
||||
type mockS3Server struct {
|
||||
objects map[string][]byte // full key (bucket/key) -> body
|
||||
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
|
||||
nextID int // counter for generating upload IDs
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockS3Server() (*httptest.Server, *mockS3Server) {
|
||||
m := &mockS3Server{
|
||||
objects: make(map[string][]byte),
|
||||
uploads: make(map[string]map[int][]byte),
|
||||
}
|
||||
return httptest.NewTLSServer(m), m
|
||||
}
|
||||
|
||||
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Path is /{bucket}[/{key...}]
|
||||
path := strings.TrimPrefix(r.URL.Path, "/")
|
||||
q := r.URL.Query()
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodPut && q.Has("partNumber"):
|
||||
m.handleUploadPart(w, r, path)
|
||||
case r.Method == http.MethodPut:
|
||||
m.handlePut(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("uploads"):
|
||||
m.handleInitiateMultipart(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("uploadId"):
|
||||
m.handleCompleteMultipart(w, r, path)
|
||||
case r.Method == http.MethodDelete && q.Has("uploadId"):
|
||||
m.handleAbortMultipart(w, r, path)
|
||||
case r.Method == http.MethodGet && q.Get("list-type") == "2":
|
||||
m.handleList(w, r, path)
|
||||
case r.Method == http.MethodGet:
|
||||
m.handleGet(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("delete"):
|
||||
m.handleDelete(w, r, path)
|
||||
default:
|
||||
http.Error(w, "not implemented", http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.objects[path] = body
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
m.mu.Lock()
|
||||
m.nextID++
|
||||
uploadID := fmt.Sprintf("upload-%d", m.nextID)
|
||||
m.uploads[uploadID] = make(map[int][]byte)
|
||||
m.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
var partNumber int
|
||||
fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
parts, ok := m.uploads[uploadID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
http.Error(w, "NoSuchUpload", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
parts[partNumber] = body
|
||||
m.mu.Unlock()
|
||||
|
||||
etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
|
||||
w.Header().Set("ETag", etag)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
|
||||
m.mu.Lock()
|
||||
parts, ok := m.uploads[uploadID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
http.Error(w, "NoSuchUpload", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Assemble parts in order
|
||||
var assembled []byte
|
||||
for i := 1; i <= len(parts); i++ {
|
||||
assembled = append(assembled, parts[i]...)
|
||||
}
|
||||
m.objects[path] = assembled
|
||||
delete(m.uploads, uploadID)
|
||||
m.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
m.mu.Lock()
|
||||
delete(m.uploads, uploadID)
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
|
||||
m.mu.RLock()
|
||||
body, ok := m.objects[path]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(body)
|
||||
}
|
||||
|
||||
type listObjectsResponse struct {
|
||||
XMLName xml.Name `xml:"ListBucketResult"`
|
||||
Contents []listObject `xml:"Contents"`
|
||||
// Pagination support
|
||||
IsTruncated bool `xml:"IsTruncated"`
|
||||
NextContinuationToken string `xml:"NextContinuationToken"`
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) {
|
||||
// bucketPath is just the bucket name
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Objects []struct {
|
||||
Key string `xml:"Key"`
|
||||
} `xml:"Object"`
|
||||
}
|
||||
if err := xml.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
for _, obj := range req.Objects {
|
||||
delete(m.objects, bucketPath+"/"+obj.Key)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></DeleteResult>`))
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) {
|
||||
prefix := r.URL.Query().Get("prefix")
|
||||
contToken := r.URL.Query().Get("continuation-token")
|
||||
|
||||
m.mu.RLock()
|
||||
var allKeys []string
|
||||
for key := range m.objects {
|
||||
objKey := strings.TrimPrefix(key, bucketPath+"/")
|
||||
if objKey == key {
|
||||
continue // different bucket
|
||||
}
|
||||
if prefix == "" || strings.HasPrefix(objKey, prefix) {
|
||||
allKeys = append(allKeys, objKey)
|
||||
}
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
sort.Strings(allKeys)
|
||||
|
||||
// Simple continuation token: it's the key to start after
|
||||
startIdx := 0
|
||||
if contToken != "" {
|
||||
for i, k := range allKeys {
|
||||
if k == contToken {
|
||||
startIdx = i + 1
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
maxKeys := 1000
|
||||
if mk := r.URL.Query().Get("max-keys"); mk != "" {
|
||||
fmt.Sscanf(mk, "%d", &maxKeys)
|
||||
}
|
||||
|
||||
endIdx := startIdx + maxKeys
|
||||
truncated := false
|
||||
nextToken := ""
|
||||
if endIdx < len(allKeys) {
|
||||
truncated = true
|
||||
nextToken = allKeys[endIdx-1]
|
||||
allKeys = allKeys[startIdx:endIdx]
|
||||
} else {
|
||||
allKeys = allKeys[startIdx:]
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
var contents []listObject
|
||||
for _, objKey := range allKeys {
|
||||
body := m.objects[bucketPath+"/"+objKey]
|
||||
contents = append(contents, listObject{Key: objKey, Size: int64(len(body))})
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
resp := listObjectsResponse{
|
||||
Contents: contents,
|
||||
IsTruncated: truncated,
|
||||
NextContinuationToken: nextToken,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
xml.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) objectCount() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.objects)
|
||||
}
|
||||
|
||||
// --- Helper to create a test client pointing at mock server ---
|
||||
|
||||
func newTestClient(server *httptest.Server, bucket, prefix string) *Client {
|
||||
// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
|
||||
host := strings.TrimPrefix(server.URL, "https://")
|
||||
return &Client{
|
||||
AccessKey: "AKID",
|
||||
SecretKey: "SECRET",
|
||||
Region: "us-east-1",
|
||||
Endpoint: host,
|
||||
Bucket: bucket,
|
||||
Prefix: prefix,
|
||||
PathStyle: true,
|
||||
HTTPClient: server.Client(),
|
||||
}
|
||||
}
|
||||
|
||||
// --- URL parsing tests ---
|
||||
|
||||
func TestParseURL_Success(t *testing.T) {
|
||||
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/attachments?region=us-east-1")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||
require.Equal(t, "attachments", cfg.Prefix)
|
||||
require.Equal(t, "us-east-1", cfg.Region)
|
||||
require.Equal(t, "AKID", cfg.AccessKey)
|
||||
require.Equal(t, "SECRET", cfg.SecretKey)
|
||||
require.Equal(t, "s3.us-east-1.amazonaws.com", cfg.Endpoint)
|
||||
require.False(t, cfg.PathStyle)
|
||||
}
|
||||
|
||||
func TestParseURL_NoPrefix(t *testing.T) {
|
||||
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||
require.Equal(t, "", cfg.Prefix)
|
||||
}
|
||||
|
||||
func TestParseURL_WithEndpoint(t *testing.T) {
|
||||
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/prefix?region=us-east-1&endpoint=https://s3.example.com")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||
require.Equal(t, "prefix", cfg.Prefix)
|
||||
require.Equal(t, "s3.example.com", cfg.Endpoint)
|
||||
require.True(t, cfg.PathStyle)
|
||||
}
|
||||
|
||||
func TestParseURL_EndpointHTTP(t *testing.T) {
|
||||
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1&endpoint=http://localhost:9000")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "localhost:9000", cfg.Endpoint)
|
||||
require.True(t, cfg.PathStyle)
|
||||
}
|
||||
|
||||
func TestParseURL_EndpointTrailingSlash(t *testing.T) {
|
||||
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1&endpoint=https://s3.example.com/")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "s3.example.com", cfg.Endpoint)
|
||||
}
|
||||
|
||||
func TestParseURL_NestedPrefix(t *testing.T) {
|
||||
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/a/b/c?region=us-east-1")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||
require.Equal(t, "a/b/c", cfg.Prefix)
|
||||
}
|
||||
|
||||
func TestParseURL_MissingRegion(t *testing.T) {
|
||||
_, err := ParseURL("s3://AKID:SECRET@my-bucket")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "region")
|
||||
}
|
||||
|
||||
func TestParseURL_MissingCredentials(t *testing.T) {
|
||||
_, err := ParseURL("s3://my-bucket?region=us-east-1")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access key")
|
||||
}
|
||||
|
||||
func TestParseURL_MissingSecretKey(t *testing.T) {
|
||||
_, err := ParseURL("s3://AKID@my-bucket?region=us-east-1")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "secret key")
|
||||
}
|
||||
|
||||
func TestParseURL_WrongScheme(t *testing.T) {
|
||||
_, err := ParseURL("http://AKID:SECRET@my-bucket?region=us-east-1")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "scheme")
|
||||
}
|
||||
|
||||
func TestParseURL_EmptyBucket(t *testing.T) {
|
||||
_, err := ParseURL("s3://AKID:SECRET@?region=us-east-1")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "bucket")
|
||||
}
|
||||
|
||||
// --- Unit tests: URL construction ---
|
||||
|
||||
func TestClient_BucketURL_PathStyle(t *testing.T) {
|
||||
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
|
||||
require.Equal(t, "https://s3.example.com/my-bucket", c.bucketURL())
|
||||
}
|
||||
|
||||
func TestClient_BucketURL_VirtualHosted(t *testing.T) {
|
||||
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
|
||||
require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.bucketURL())
|
||||
}
|
||||
|
||||
func TestClient_ObjectURL_PathStyle(t *testing.T) {
|
||||
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
|
||||
require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.objectURL("prefix/obj"))
|
||||
}
|
||||
|
||||
func TestClient_ObjectURL_VirtualHosted(t *testing.T) {
|
||||
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
|
||||
require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.objectURL("prefix/obj"))
|
||||
}
|
||||
|
||||
func TestClient_HostHeader_PathStyle(t *testing.T) {
|
||||
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
|
||||
require.Equal(t, "s3.example.com", c.hostHeader())
|
||||
}
|
||||
|
||||
func TestClient_HostHeader_VirtualHosted(t *testing.T) {
|
||||
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
|
||||
require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.hostHeader())
|
||||
}
|
||||
|
||||
func TestClient_ObjectKey(t *testing.T) {
|
||||
c := &Client{Prefix: "attachments"}
|
||||
require.Equal(t, "attachments/file123", c.objectKey("file123"))
|
||||
|
||||
c2 := &Client{Prefix: ""}
|
||||
require.Equal(t, "file123", c2.objectKey("file123"))
|
||||
}
|
||||
|
||||
func TestClient_PrefixForList(t *testing.T) {
|
||||
c := &Client{Prefix: "attachments"}
|
||||
require.Equal(t, "attachments/", c.prefixForList())
|
||||
|
||||
c2 := &Client{Prefix: ""}
|
||||
require.Equal(t, "", c2.prefixForList())
|
||||
}
|
||||
|
||||
// --- Integration tests using mock S3 server ---
|
||||
|
||||
func TestClient_PutGetObject(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Put
|
||||
err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"))
|
||||
require.Nil(t, err)
|
||||
|
||||
// Get
|
||||
reader, size, err := client.GetObject(ctx, "test-key")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
data, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "hello world", string(data))
|
||||
}
|
||||
|
||||
func TestClient_PutGetObject_WithPrefix(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "pfx")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err := client.PutObject(ctx, "test-key", strings.NewReader("hello"))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, _, err := client.GetObject(ctx, "test-key")
|
||||
require.Nil(t, err)
|
||||
data, _ := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Equal(t, "hello", string(data))
|
||||
}
|
||||
|
||||
func TestClient_GetObject_NotFound(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
_, _, err := client.GetObject(context.Background(), "nonexistent")
|
||||
require.Error(t, err)
|
||||
var errResp *ErrorResponse
|
||||
require.ErrorAs(t, err, &errResp)
|
||||
require.Equal(t, 404, errResp.StatusCode)
|
||||
require.Equal(t, "NoSuchKey", errResp.Code)
|
||||
}
|
||||
|
||||
func TestClient_DeleteObjects(t *testing.T) {
|
||||
server, mock := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Put several objects
|
||||
for i := 0; i < 5; i++ {
|
||||
err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Equal(t, 5, mock.objectCount())
|
||||
|
||||
// Delete some
|
||||
err := client.DeleteObjects(ctx, []string{"key-1", "key-3"})
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, mock.objectCount())
|
||||
|
||||
// Verify deleted ones are gone
|
||||
_, _, err = client.GetObject(ctx, "key-1")
|
||||
require.Error(t, err)
|
||||
_, _, err = client.GetObject(ctx, "key-3")
|
||||
require.Error(t, err)
|
||||
|
||||
// Verify remaining ones are still there
|
||||
reader, _, err := client.GetObject(ctx, "key-0")
|
||||
require.Nil(t, err)
|
||||
reader.Close()
|
||||
}
|
||||
|
||||
func TestClient_ListObjects(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Client with prefix "pfx": list should only return objects under pfx/
|
||||
client := newTestClient(server, "my-bucket", "pfx")
|
||||
for i := 0; i < 3; i++ {
|
||||
err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
// Also put an object outside the prefix using a no-prefix client
|
||||
clientNoPrefix := newTestClient(server, "my-bucket", "")
|
||||
err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")))
|
||||
require.Nil(t, err)
|
||||
|
||||
// List with prefix client: should only see 3
|
||||
result, err := client.ListObjects(ctx, "", 0)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, result.Objects, 3)
|
||||
require.False(t, result.IsTruncated)
|
||||
|
||||
// List with no-prefix client: should see all 4
|
||||
result, err = clientNoPrefix.ListObjects(ctx, "", 0)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, result.Objects, 4)
|
||||
}
|
||||
|
||||
func TestClient_ListObjects_Pagination(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Put 5 objects
|
||||
for i := 0; i < 5; i++ {
|
||||
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
// List with max-keys=2
|
||||
result, err := client.ListObjects(ctx, "", 2)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, result.Objects, 2)
|
||||
require.True(t, result.IsTruncated)
|
||||
require.NotEmpty(t, result.NextContinuationToken)
|
||||
|
||||
// Get next page
|
||||
result2, err := client.ListObjects(ctx, result.NextContinuationToken, 2)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, result2.Objects, 2)
|
||||
require.True(t, result2.IsTruncated)
|
||||
|
||||
// Get last page
|
||||
result3, err := client.ListObjects(ctx, result2.NextContinuationToken, 2)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, result3.Objects, 1)
|
||||
require.False(t, result3.IsTruncated)
|
||||
}
|
||||
|
||||
func TestClient_ListAllObjects(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "pfx")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
objects, err := client.ListAllObjects(ctx)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, objects, 10)
|
||||
}
|
||||
|
||||
func TestClient_PutObject_LargeBody(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 1 MB object
|
||||
data := make([]byte, 1024*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
err := client.PutObject(ctx, "large", bytes.NewReader(data))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, size, err := client.GetObject(ctx, "large")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(1024*1024), size)
|
||||
got, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, data, got)
|
||||
}
|
||||
|
||||
func TestClient_PutObject_ChunkedUpload(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 12 MB object, exceeds 5 MB partSize, triggers multipart upload path
|
||||
data := make([]byte, 12*1024*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
err := client.PutObject(ctx, "multipart", bytes.NewReader(data))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, size, err := client.GetObject(ctx, "multipart")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(12*1024*1024), size)
|
||||
got, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, data, got)
|
||||
}
|
||||
|
||||
func TestClient_PutObject_ExactPartSize(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Exactly 5 MB (partSize), should use the simple put path (ReadFull succeeds fully)
|
||||
data := make([]byte, 5*1024*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
err := client.PutObject(ctx, "exact", bytes.NewReader(data))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, size, err := client.GetObject(ctx, "exact")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(5*1024*1024), size)
|
||||
got, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, data, got)
|
||||
}
|
||||
|
||||
func TestClient_PutObject_NestedKey(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt")
|
||||
require.Nil(t, err)
|
||||
data, _ := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Equal(t, "nested", string(data))
|
||||
}
|
||||
|
||||
// --- Scale test: 20k objects (ntfy-adjacent) ---
|
||||
|
||||
func TestClient_ListAllObjects_20k(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping 20k object test in short mode")
|
||||
}
|
||||
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "attachments")
|
||||
|
||||
ctx := context.Background()
|
||||
const numObjects = 20000
|
||||
const batchSize = 500
|
||||
|
||||
// Insert 20k objects in batches to keep it fast
|
||||
for batch := 0; batch < numObjects/batchSize; batch++ {
|
||||
for i := 0; i < batchSize; i++ {
|
||||
idx := batch*batchSize + i
|
||||
key := fmt.Sprintf("%08d", idx)
|
||||
err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// List all 20k objects with pagination
|
||||
objects, err := client.ListAllObjects(ctx)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, objects, numObjects)
|
||||
|
||||
// Verify total size
|
||||
var totalSize int64
|
||||
for _, obj := range objects {
|
||||
totalSize += obj.Size
|
||||
}
|
||||
require.Equal(t, int64(numObjects), totalSize)
|
||||
|
||||
// Delete 1000 objects (simulating attachment expiry cleanup)
|
||||
keys := make([]string, 1000)
|
||||
for i := range keys {
|
||||
keys[i] = fmt.Sprintf("%08d", i)
|
||||
}
|
||||
err = client.DeleteObjects(ctx, keys)
|
||||
require.Nil(t, err)
|
||||
|
||||
// List again: should have 19000
|
||||
objects, err = client.ListAllObjects(ctx)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, objects, numObjects-1000)
|
||||
}
|
||||
|
||||
// --- Real S3 integration test ---
|
||||
//
|
||||
// Set the following environment variables to run this test against a real S3 bucket:
|
||||
//
|
||||
// S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET
|
||||
//
|
||||
// Optional:
|
||||
//
|
||||
// S3_ENDPOINT: host[:port] for S3-compatible providers (e.g. "nyc3.digitaloceanspaces.com")
|
||||
// S3_PATH_STYLE: set to "true" for path-style addressing
|
||||
// S3_PREFIX: key prefix to use (default: "ntfy-s3-test")
|
||||
func TestClient_RealBucket(t *testing.T) {
|
||||
accessKey := os.Getenv("S3_ACCESS_KEY")
|
||||
secretKey := os.Getenv("S3_SECRET_KEY")
|
||||
region := os.Getenv("S3_REGION")
|
||||
bucket := os.Getenv("S3_BUCKET")
|
||||
|
||||
if accessKey == "" || secretKey == "" || region == "" || bucket == "" {
|
||||
t.Skip("skipping real S3 test: set S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET")
|
||||
}
|
||||
|
||||
endpoint := os.Getenv("S3_ENDPOINT")
|
||||
if endpoint == "" {
|
||||
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region)
|
||||
}
|
||||
pathStyle := os.Getenv("S3_PATH_STYLE") == "true"
|
||||
prefix := os.Getenv("S3_PREFIX")
|
||||
if prefix == "" {
|
||||
prefix = "ntfy-s3-test"
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
AccessKey: accessKey,
|
||||
SecretKey: secretKey,
|
||||
Region: region,
|
||||
Endpoint: endpoint,
|
||||
Bucket: bucket,
|
||||
Prefix: prefix,
|
||||
PathStyle: pathStyle,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Clean up any leftover objects from previous runs
|
||||
existing, err := client.ListAllObjects(ctx)
|
||||
require.Nil(t, err)
|
||||
if len(existing) > 0 {
|
||||
keys := make([]string, len(existing))
|
||||
for i, obj := range existing {
|
||||
// Strip the prefix since DeleteObjects will re-add it
|
||||
keys[i] = strings.TrimPrefix(obj.Key, prefix+"/")
|
||||
}
|
||||
// Batch delete in groups of 1000
|
||||
for i := 0; i < len(keys); i += 1000 {
|
||||
end := i + 1000
|
||||
if end > len(keys) {
|
||||
end = len(keys)
|
||||
}
|
||||
err := client.DeleteObjects(ctx, keys[i:end])
|
||||
require.Nil(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("PutGetDelete", func(t *testing.T) {
|
||||
key := "test-object"
|
||||
content := "hello from ntfy s3 test"
|
||||
|
||||
// Put
|
||||
err := client.PutObject(ctx, key, strings.NewReader(content))
|
||||
require.Nil(t, err)
|
||||
|
||||
// Get
|
||||
reader, size, err := client.GetObject(ctx, key)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(len(content)), size)
|
||||
data, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, content, string(data))
|
||||
|
||||
// Delete
|
||||
err = client.DeleteObjects(ctx, []string{key})
|
||||
require.Nil(t, err)
|
||||
|
||||
// Get after delete should fail
|
||||
_, _, err = client.GetObject(ctx, key)
|
||||
require.Error(t, err)
|
||||
var errResp *ErrorResponse
|
||||
require.ErrorAs(t, err, &errResp)
|
||||
require.Equal(t, 404, errResp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("ListObjects", func(t *testing.T) {
|
||||
// Use a sub-prefix client for isolation
|
||||
listClient := &Client{
|
||||
AccessKey: accessKey,
|
||||
SecretKey: secretKey,
|
||||
Region: region,
|
||||
Endpoint: endpoint,
|
||||
Bucket: bucket,
|
||||
Prefix: prefix + "/list-test",
|
||||
PathStyle: pathStyle,
|
||||
}
|
||||
|
||||
// Put 10 objects
|
||||
for i := 0; i < 10; i++ {
|
||||
err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
// List
|
||||
objects, err := listClient.ListAllObjects(ctx)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, objects, 10)
|
||||
|
||||
// Clean up
|
||||
keys := make([]string, 10)
|
||||
for i := range keys {
|
||||
keys[i] = fmt.Sprintf("%d", i)
|
||||
}
|
||||
err = listClient.DeleteObjects(ctx, keys)
|
||||
require.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("LargeObject", func(t *testing.T) {
|
||||
key := "large-object"
|
||||
data := make([]byte, 5*1024*1024) // 5 MB
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
err := client.PutObject(ctx, key, bytes.NewReader(data))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, size, err := client.GetObject(ctx, key)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(len(data)), size)
|
||||
got, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, data, got)
|
||||
|
||||
err = client.DeleteObjects(ctx, []string{key})
|
||||
require.Nil(t, err)
|
||||
})
|
||||
}
|
||||
76
s3/types.go
Normal file
76
s3/types.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package s3
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string.
|
||||
type Config struct {
|
||||
Endpoint string // host[:port] only, e.g. "s3.us-east-1.amazonaws.com"
|
||||
PathStyle bool
|
||||
Bucket string
|
||||
Prefix string
|
||||
Region string
|
||||
AccessKey string
|
||||
SecretKey string
|
||||
}
|
||||
|
||||
// Object represents an S3 object returned by list operations.
|
||||
type Object struct {
|
||||
Key string
|
||||
Size int64
|
||||
}
|
||||
|
||||
// ListResult holds the response from a ListObjectsV2 call.
|
||||
type ListResult struct {
|
||||
Objects []Object
|
||||
IsTruncated bool
|
||||
NextContinuationToken string
|
||||
}
|
||||
|
||||
// ErrorResponse is returned when S3 responds with a non-2xx status code.
|
||||
type ErrorResponse struct {
|
||||
StatusCode int
|
||||
Code string `xml:"Code"`
|
||||
Message string `xml:"Message"`
|
||||
Body string `xml:"-"` // raw response body
|
||||
}
|
||||
|
||||
func (e *ErrorResponse) Error() string {
|
||||
if e.Code != "" {
|
||||
return fmt.Sprintf("s3: %s (HTTP %d): %s", e.Code, e.StatusCode, e.Message)
|
||||
}
|
||||
return fmt.Sprintf("s3: HTTP %d: %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
// listObjectsV2Response is the XML response from S3 ListObjectsV2
|
||||
type listObjectsV2Response struct {
|
||||
Contents []listObject `xml:"Contents"`
|
||||
IsTruncated bool `xml:"IsTruncated"`
|
||||
NextContinuationToken string `xml:"NextContinuationToken"`
|
||||
}
|
||||
|
||||
type listObject struct {
|
||||
Key string `xml:"Key"`
|
||||
Size int64 `xml:"Size"`
|
||||
}
|
||||
|
||||
// deleteResult is the XML response from S3 DeleteObjects
|
||||
type deleteResult struct {
|
||||
Errors []deleteError `xml:"Error"`
|
||||
}
|
||||
|
||||
type deleteError struct {
|
||||
Key string `xml:"Key"`
|
||||
Code string `xml:"Code"`
|
||||
Message string `xml:"Message"`
|
||||
}
|
||||
|
||||
// initiateMultipartUploadResult is the XML response from S3 InitiateMultipartUpload
|
||||
type initiateMultipartUploadResult struct {
|
||||
UploadID string `xml:"UploadId"`
|
||||
}
|
||||
|
||||
// completedPart represents a successfully uploaded part for CompleteMultipartUpload
|
||||
type completedPart struct {
|
||||
PartNumber int
|
||||
ETag string
|
||||
}
|
||||
166
s3/util.go
Normal file
166
s3/util.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// SHA-256 hash of the empty string, used as the payload hash for bodiless requests
|
||||
emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
|
||||
|
||||
// Sent as the payload hash for streaming uploads where the body is not buffered in memory
|
||||
unsignedPayload = "UNSIGNED-PAYLOAD"
|
||||
|
||||
// maxResponseBytes caps the size of S3 response bodies we read into memory (10 MB)
|
||||
maxResponseBytes = 10 * 1024 * 1024
|
||||
|
||||
// partSize is the size of each part for multipart uploads (5 MB). This is also the threshold
|
||||
// above which PutObject switches from a simple PUT to multipart upload. S3 requires a minimum
|
||||
// part size of 5 MB for all parts except the last.
|
||||
partSize = 5 * 1024 * 1024
|
||||
)
|
||||
|
||||
// ParseURL parses an S3 URL of the form:
|
||||
//
|
||||
// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
|
||||
//
|
||||
// When endpoint is specified, path-style addressing is enabled automatically.
|
||||
func ParseURL(s3URL string) (*Config, error) {
|
||||
u, err := url.Parse(s3URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s3: invalid URL: %w", err)
|
||||
}
|
||||
if u.Scheme != "s3" {
|
||||
return nil, fmt.Errorf("s3: URL scheme must be 's3', got '%s'", u.Scheme)
|
||||
}
|
||||
if u.Host == "" {
|
||||
return nil, fmt.Errorf("s3: bucket name must be specified as host")
|
||||
}
|
||||
bucket := u.Host
|
||||
prefix := strings.TrimPrefix(u.Path, "/")
|
||||
accessKey := u.User.Username()
|
||||
secretKey, _ := u.User.Password()
|
||||
if accessKey == "" || secretKey == "" {
|
||||
return nil, fmt.Errorf("s3: access key and secret key must be specified in URL")
|
||||
}
|
||||
region := u.Query().Get("region")
|
||||
if region == "" {
|
||||
return nil, fmt.Errorf("s3: region query parameter is required")
|
||||
}
|
||||
endpointParam := u.Query().Get("endpoint")
|
||||
var endpoint string
|
||||
var pathStyle bool
|
||||
if endpointParam != "" {
|
||||
// Custom endpoint: strip scheme prefix to extract host[:port]
|
||||
ep := strings.TrimRight(endpointParam, "/")
|
||||
ep = strings.TrimPrefix(ep, "https://")
|
||||
ep = strings.TrimPrefix(ep, "http://")
|
||||
endpoint = ep
|
||||
pathStyle = true
|
||||
} else {
|
||||
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region)
|
||||
pathStyle = false
|
||||
}
|
||||
return &Config{
|
||||
Endpoint: endpoint,
|
||||
PathStyle: pathStyle,
|
||||
Bucket: bucket,
|
||||
Prefix: prefix,
|
||||
Region: region,
|
||||
AccessKey: accessKey,
|
||||
SecretKey: secretKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseError reads an S3 error response and returns an *ErrorResponse.
|
||||
func parseError(resp *http.Response) error {
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: reading error response: %w", err)
|
||||
}
|
||||
return parseErrorFromBytes(resp.StatusCode, body)
|
||||
}
|
||||
|
||||
func parseErrorFromBytes(statusCode int, body []byte) error {
|
||||
errResp := &ErrorResponse{
|
||||
StatusCode: statusCode,
|
||||
Body: string(body),
|
||||
}
|
||||
// Try to parse XML error; if it fails, we still have StatusCode and Body
|
||||
_ = xml.Unmarshal(body, errResp)
|
||||
return errResp
|
||||
}
|
||||
|
||||
// canonicalURI returns the URI-encoded path for the canonical request. Each path segment is
|
||||
// percent-encoded per RFC 3986; forward slashes are preserved.
|
||||
func canonicalURI(u *url.URL) string {
|
||||
p := u.Path
|
||||
if p == "" {
|
||||
return "/"
|
||||
}
|
||||
segments := strings.Split(p, "/")
|
||||
for i, seg := range segments {
|
||||
segments[i] = uriEncode(seg)
|
||||
}
|
||||
return strings.Join(segments, "/")
|
||||
}
|
||||
|
||||
// canonicalQueryString builds the query string for the canonical request. Keys and values
|
||||
// are URI-encoded per RFC 3986 (using %20, not +) and sorted lexically by key.
|
||||
func canonicalQueryString(values url.Values) string {
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
keys := make([]string, 0, len(values))
|
||||
for k := range values {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
var pairs []string
|
||||
for _, k := range keys {
|
||||
ek := uriEncode(k)
|
||||
vs := make([]string, len(values[k]))
|
||||
copy(vs, values[k])
|
||||
sort.Strings(vs)
|
||||
for _, v := range vs {
|
||||
pairs = append(pairs, ek+"="+uriEncode(v))
|
||||
}
|
||||
}
|
||||
return strings.Join(pairs, "&")
|
||||
}
|
||||
|
||||
// uriEncode percent-encodes a string per RFC 3986, encoding everything except unreserved
|
||||
// characters (A-Z a-z 0-9 - _ . ~).
|
||||
func uriEncode(s string) string {
|
||||
var buf strings.Builder
|
||||
for i := 0; i < len(s); i++ {
|
||||
b := s[i]
|
||||
if (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9') ||
|
||||
b == '-' || b == '_' || b == '.' || b == '~' {
|
||||
buf.WriteByte(b)
|
||||
} else {
|
||||
fmt.Fprintf(&buf, "%%%02X", b)
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func sha256Hex(data []byte) string {
|
||||
h := sha256.Sum256(data)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
func hmacSHA256(key, data []byte) []byte {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(data)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
181
s3/util_test.go
Normal file
181
s3/util_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestURIEncode(t *testing.T) {
|
||||
// Unreserved characters are not encoded
|
||||
require.Equal(t, "abcdefghijklmnopqrstuvwxyz", uriEncode("abcdefghijklmnopqrstuvwxyz"))
|
||||
require.Equal(t, "ABCDEFGHIJKLMNOPQRSTUVWXYZ", uriEncode("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
|
||||
require.Equal(t, "0123456789", uriEncode("0123456789"))
|
||||
require.Equal(t, "-_.~", uriEncode("-_.~"))
|
||||
|
||||
// Spaces use %20, not +
|
||||
require.Equal(t, "hello%20world", uriEncode("hello world"))
|
||||
|
||||
// Slashes are encoded (canonicalURI handles slash splitting separately)
|
||||
require.Equal(t, "a%2Fb", uriEncode("a/b"))
|
||||
|
||||
// Special characters
|
||||
require.Equal(t, "%2B", uriEncode("+"))
|
||||
require.Equal(t, "%3D", uriEncode("="))
|
||||
require.Equal(t, "%26", uriEncode("&"))
|
||||
require.Equal(t, "%40", uriEncode("@"))
|
||||
require.Equal(t, "%23", uriEncode("#"))
|
||||
|
||||
// Mixed
|
||||
require.Equal(t, "test~file-name_1.txt", uriEncode("test~file-name_1.txt"))
|
||||
require.Equal(t, "key%20with%20spaces%2Fand%2Fslashes", uriEncode("key with spaces/and/slashes"))
|
||||
|
||||
// Empty string
|
||||
require.Equal(t, "", uriEncode(""))
|
||||
}
|
||||
|
||||
func TestCanonicalURI(t *testing.T) {
|
||||
// Simple path
|
||||
u, _ := url.Parse("https://example.com/bucket/key")
|
||||
require.Equal(t, "/bucket/key", canonicalURI(u))
|
||||
|
||||
// Root path
|
||||
u, _ = url.Parse("https://example.com/")
|
||||
require.Equal(t, "/", canonicalURI(u))
|
||||
|
||||
// Empty path
|
||||
u, _ = url.Parse("https://example.com")
|
||||
require.Equal(t, "/", canonicalURI(u))
|
||||
|
||||
// Path with special characters
|
||||
u, _ = url.Parse("https://example.com/bucket/key%20with%20spaces")
|
||||
require.Equal(t, "/bucket/key%20with%20spaces", canonicalURI(u))
|
||||
|
||||
// Nested path
|
||||
u, _ = url.Parse("https://example.com/bucket/a/b/c/file.txt")
|
||||
require.Equal(t, "/bucket/a/b/c/file.txt", canonicalURI(u))
|
||||
}
|
||||
|
||||
func TestCanonicalQueryString(t *testing.T) {
|
||||
// Multiple keys sorted alphabetically
|
||||
vals := url.Values{
|
||||
"prefix": {"test/"},
|
||||
"list-type": {"2"},
|
||||
}
|
||||
require.Equal(t, "list-type=2&prefix=test%2F", canonicalQueryString(vals))
|
||||
|
||||
// Empty values
|
||||
require.Equal(t, "", canonicalQueryString(url.Values{}))
|
||||
|
||||
// Single key
|
||||
require.Equal(t, "key=value", canonicalQueryString(url.Values{"key": {"value"}}))
|
||||
|
||||
// Key with multiple values (sorted)
|
||||
vals = url.Values{"key": {"b", "a"}}
|
||||
require.Equal(t, "key=a&key=b", canonicalQueryString(vals))
|
||||
|
||||
// Keys requiring encoding
|
||||
vals = url.Values{"continuation-token": {"abc+def"}}
|
||||
require.Equal(t, "continuation-token=abc%2Bdef", canonicalQueryString(vals))
|
||||
}
|
||||
|
||||
func TestSHA256Hex(t *testing.T) {
|
||||
// SHA-256 of empty string
|
||||
require.Equal(t, emptyPayloadHash, sha256Hex([]byte("")))
|
||||
|
||||
// SHA-256 of known value
|
||||
require.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", sha256Hex([]byte("hello")))
|
||||
}
|
||||
|
||||
func TestHmacSHA256(t *testing.T) {
|
||||
// Known test vector: HMAC-SHA256("key", "message")
|
||||
result := hmacSHA256([]byte("key"), []byte("message"))
|
||||
require.Len(t, result, 32) // SHA-256 produces 32 bytes
|
||||
require.NotEqual(t, make([]byte, 32), result)
|
||||
|
||||
// Same inputs should produce same output
|
||||
result2 := hmacSHA256([]byte("key"), []byte("message"))
|
||||
require.Equal(t, result, result2)
|
||||
|
||||
// Different inputs should produce different output
|
||||
result3 := hmacSHA256([]byte("different-key"), []byte("message"))
|
||||
require.NotEqual(t, result, result3)
|
||||
}
|
||||
|
||||
func TestSignV4_SetsRequiredHeaders(t *testing.T) {
|
||||
c := &Client{
|
||||
AccessKey: "AKID",
|
||||
SecretKey: "SECRET",
|
||||
Region: "us-east-1",
|
||||
Endpoint: "s3.us-east-1.amazonaws.com",
|
||||
Bucket: "my-bucket",
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
|
||||
c.signV4(req, emptyPayloadHash)
|
||||
|
||||
// All required SigV4 headers must be set
|
||||
require.NotEmpty(t, req.Header.Get("Host"))
|
||||
require.NotEmpty(t, req.Header.Get("X-Amz-Date"))
|
||||
require.Equal(t, emptyPayloadHash, req.Header.Get("X-Amz-Content-Sha256"))
|
||||
|
||||
// Authorization header must have correct format
|
||||
auth := req.Header.Get("Authorization")
|
||||
require.Contains(t, auth, "AWS4-HMAC-SHA256")
|
||||
require.Contains(t, auth, "Credential=AKID/")
|
||||
require.Contains(t, auth, "/us-east-1/s3/aws4_request")
|
||||
require.Contains(t, auth, "SignedHeaders=")
|
||||
require.Contains(t, auth, "Signature=")
|
||||
}
|
||||
|
||||
func TestSignV4_UnsignedPayload(t *testing.T) {
|
||||
c := &Client{
|
||||
AccessKey: "AKID",
|
||||
SecretKey: "SECRET",
|
||||
Region: "us-east-1",
|
||||
Endpoint: "s3.us-east-1.amazonaws.com",
|
||||
Bucket: "my-bucket",
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPut, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
|
||||
c.signV4(req, unsignedPayload)
|
||||
|
||||
require.Equal(t, unsignedPayload, req.Header.Get("X-Amz-Content-Sha256"))
|
||||
}
|
||||
|
||||
func TestSignV4_DifferentRegions(t *testing.T) {
|
||||
c1 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"}
|
||||
c2 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", Bucket: "b"}
|
||||
|
||||
req1, _ := http.NewRequest(http.MethodGet, "https://b.s3.us-east-1.amazonaws.com/key", nil)
|
||||
c1.signV4(req1, emptyPayloadHash)
|
||||
|
||||
req2, _ := http.NewRequest(http.MethodGet, "https://b.s3.eu-west-1.amazonaws.com/key", nil)
|
||||
c2.signV4(req2, emptyPayloadHash)
|
||||
|
||||
// Different regions should produce different signatures
|
||||
require.NotEqual(t, req1.Header.Get("Authorization"), req2.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
func TestParseError_XMLResponse(t *testing.T) {
|
||||
xmlBody := []byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`)
|
||||
err := parseErrorFromBytes(404, xmlBody)
|
||||
|
||||
var errResp *ErrorResponse
|
||||
require.ErrorAs(t, err, &errResp)
|
||||
require.Equal(t, 404, errResp.StatusCode)
|
||||
require.Equal(t, "NoSuchKey", errResp.Code)
|
||||
require.Equal(t, "The specified key does not exist.", errResp.Message)
|
||||
}
|
||||
|
||||
func TestParseError_NonXMLResponse(t *testing.T) {
|
||||
err := parseErrorFromBytes(500, []byte("internal server error"))
|
||||
|
||||
var errResp *ErrorResponse
|
||||
require.ErrorAs(t, err, &errResp)
|
||||
require.Equal(t, 500, errResp.StatusCode)
|
||||
require.Equal(t, "", errResp.Code) // XML parsing failed, no code
|
||||
require.Contains(t, errResp.Body, "internal server error")
|
||||
}
|
||||
@@ -112,6 +112,7 @@ type Config struct {
|
||||
AuthBcryptCost int
|
||||
AuthStatsQueueWriterInterval time.Duration
|
||||
AttachmentCacheDir string
|
||||
AttachmentS3URL string
|
||||
AttachmentTotalSizeLimit int64
|
||||
AttachmentFileSizeLimit int64
|
||||
AttachmentExpiryDuration time.Duration
|
||||
|
||||
@@ -142,6 +142,7 @@ var (
|
||||
errHTTPBadRequestTemplateFileNotFound = &errHTTP{40047, http.StatusBadRequest, "invalid request: template file not found", "https://ntfy.sh/docs/publish/#message-templating", nil}
|
||||
errHTTPBadRequestTemplateFileInvalid = &errHTTP{40048, http.StatusBadRequest, "invalid request: template file invalid", "https://ntfy.sh/docs/publish/#message-templating", nil}
|
||||
errHTTPBadRequestSequenceIDInvalid = &errHTTP{40049, http.StatusBadRequest, "invalid request: sequence ID invalid", "https://ntfy.sh/docs/publish/#updating-deleting-notifications", nil}
|
||||
errHTTPBadRequestEmailAddressInvalid = &errHTTP{40050, http.StatusBadRequest, "invalid request: invalid e-mail address", "https://ntfy.sh/docs/publish/#e-mail-notifications", nil}
|
||||
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", "", nil}
|
||||
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication", nil}
|
||||
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication", nil}
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
oneKilobyteArray = make([]byte, 1024)
|
||||
)
|
||||
|
||||
func TestFileCache_Write_Success(t *testing.T) {
|
||||
dir, c := newTestFileCache(t)
|
||||
size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
|
||||
require.Equal(t, int64(11), c.Size())
|
||||
require.Equal(t, int64(10229), c.Remaining())
|
||||
}
|
||||
|
||||
func TestFileCache_Write_Remove_Success(t *testing.T) {
|
||||
dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024)
|
||||
for i := 0; i < 10; i++ { // 10x999 = 9990
|
||||
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(999), size)
|
||||
}
|
||||
require.Equal(t, int64(9990), c.Size())
|
||||
require.Equal(t, int64(250), c.Remaining())
|
||||
require.FileExists(t, dir+"/abcdefghijk1")
|
||||
require.FileExists(t, dir+"/abcdefghijk5")
|
||||
|
||||
require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5"))
|
||||
require.NoFileExists(t, dir+"/abcdefghijk1")
|
||||
require.NoFileExists(t, dir+"/abcdefghijk5")
|
||||
require.Equal(t, int64(7992), c.Size())
|
||||
require.Equal(t, int64(2248), c.Remaining())
|
||||
}
|
||||
|
||||
func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) {
|
||||
dir, c := newTestFileCache(t)
|
||||
for i := 0; i < 10; i++ {
|
||||
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(1024), size)
|
||||
}
|
||||
_, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
require.NoFileExists(t, dir+"/abcdefghijkX")
|
||||
}
|
||||
|
||||
func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
|
||||
dir, c := newTestFileCache(t)
|
||||
_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
require.NoFileExists(t, dir+"/abcdefghijkl")
|
||||
}
|
||||
|
||||
func newTestFileCache(t *testing.T) (dir string, cache *fileCache) {
|
||||
dir = t.TempDir()
|
||||
cache, err := newFileCache(dir, 10*1024)
|
||||
require.Nil(t, err)
|
||||
return dir, cache
|
||||
}
|
||||
|
||||
func readFile(t *testing.T, f string) string {
|
||||
b, err := os.ReadFile(f)
|
||||
require.Nil(t, err)
|
||||
return string(b)
|
||||
}
|
||||
@@ -24,7 +24,6 @@ const (
|
||||
tagSMTP = "smtp" // Receive email
|
||||
tagEmail = "email" // Send email
|
||||
tagTwilio = "twilio"
|
||||
tagFileCache = "file_cache"
|
||||
tagMessageCache = "message_cache"
|
||||
tagStripe = "stripe"
|
||||
tagAccount = "account"
|
||||
@@ -36,7 +35,7 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
normalErrorCodes = []int{http.StatusNotFound, http.StatusBadRequest, http.StatusTooManyRequests, http.StatusUnauthorized, http.StatusForbidden, http.StatusInsufficientStorage}
|
||||
normalErrorCodes = []int{http.StatusNotFound, http.StatusBadRequest, http.StatusTooManyRequests, http.StatusUnauthorized, http.StatusForbidden, http.StatusInsufficientStorage, http.StatusRequestEntityTooLarge}
|
||||
rateLimitingErrorCodes = []int{http.StatusTooManyRequests, http.StatusRequestEntityTooLarge}
|
||||
)
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gopkg.in/yaml.v2"
|
||||
"heckel.io/ntfy/v2/attachment"
|
||||
"heckel.io/ntfy/v2/db"
|
||||
"heckel.io/ntfy/v2/db/pg"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
@@ -64,7 +65,7 @@ type Server struct {
|
||||
userManager *user.Manager // Might be nil!
|
||||
messageCache *message.Cache // Database that stores the messages
|
||||
webPush *webpush.Store // Database that stores web push subscriptions
|
||||
fileCache *fileCache // File system based cache that stores attachments
|
||||
fileCache attachment.Store // Attachment store (file system or S3)
|
||||
stripe stripeAPI // Stripe API, can be replaced with a mock
|
||||
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
|
||||
metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set
|
||||
@@ -122,6 +123,7 @@ var (
|
||||
fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
|
||||
urlRegex = regexp.MustCompile(`^https?://`)
|
||||
phoneNumberRegex = regexp.MustCompile(`^\+\d{1,100}$`)
|
||||
emailAddressRegex = regexp.MustCompile(`^[^\s,;]+@[^\s,;]+$`)
|
||||
|
||||
//go:embed site
|
||||
webFs embed.FS
|
||||
@@ -227,12 +229,9 @@ func New(conf *Config) (*Server, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var fileCache *fileCache
|
||||
if conf.AttachmentCacheDir != "" {
|
||||
fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fileCache, err := createAttachmentStore(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var userManager *user.Manager
|
||||
if conf.AuthFile != "" || pool != nil {
|
||||
@@ -301,6 +300,15 @@ func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) {
|
||||
return message.NewMemStore()
|
||||
}
|
||||
|
||||
func createAttachmentStore(conf *Config) (attachment.Store, error) {
|
||||
if conf.AttachmentS3URL != "" {
|
||||
return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit)
|
||||
} else if conf.AttachmentCacheDir != "" {
|
||||
return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
|
||||
// a manager go routine to print stats and prune messages.
|
||||
func (s *Server) Run() error {
|
||||
@@ -752,7 +760,7 @@ func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request, _ *visitor)
|
||||
// Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it
|
||||
// can associate the download bandwidth with the uploader.
|
||||
func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if s.config.AttachmentCacheDir == "" {
|
||||
if s.fileCache == nil {
|
||||
return errHTTPInternalError
|
||||
}
|
||||
matches := fileRegex.FindStringSubmatch(r.URL.Path)
|
||||
@@ -760,16 +768,16 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
||||
return errHTTPInternalErrorInvalidPath
|
||||
}
|
||||
messageID := matches[1]
|
||||
file := filepath.Join(s.config.AttachmentCacheDir, messageID)
|
||||
stat, err := os.Stat(file)
|
||||
reader, size, err := s.fileCache.Read(messageID)
|
||||
if err != nil {
|
||||
return errHTTPNotFound.Fields(log.Context{
|
||||
"message_id": messageID,
|
||||
"error_context": "filesystem",
|
||||
"error_context": "attachment_store",
|
||||
})
|
||||
}
|
||||
defer reader.Close()
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", size))
|
||||
if r.Method == http.MethodHead {
|
||||
return nil
|
||||
}
|
||||
@@ -805,19 +813,14 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
||||
} else if m.Sender.IsValid() {
|
||||
bandwidthVisitor = s.visitor(m.Sender, nil)
|
||||
}
|
||||
if !bandwidthVisitor.BandwidthAllowed(stat.Size()) {
|
||||
if !bandwidthVisitor.BandwidthAllowed(size) {
|
||||
return errHTTPTooManyRequestsLimitAttachmentBandwidth.With(m)
|
||||
}
|
||||
// Actually send file
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if m.Attachment.Name != "" {
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(m.Attachment.Name))
|
||||
}
|
||||
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
|
||||
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), reader)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -880,6 +883,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*model.Mess
|
||||
if m.Message == "" {
|
||||
m.Message = emptyMessageBody
|
||||
}
|
||||
m.SanitizeUTF8()
|
||||
delayed := m.Time > time.Now().Unix()
|
||||
ev := logvrm(v, r, m).
|
||||
Tag(tagPublish).
|
||||
@@ -1162,6 +1166,9 @@ func (s *Server) parsePublishParams(r *http.Request, m *model.Message) (cache bo
|
||||
m.Icon = icon
|
||||
}
|
||||
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
|
||||
if email != "" && !emailAddressRegex.MatchString(email) {
|
||||
return false, false, "", "", "", false, "", errHTTPBadRequestEmailAddressInvalid
|
||||
}
|
||||
if s.smtpSender == nil && email != "" {
|
||||
return false, false, "", "", "", false, "", errHTTPBadRequestEmailDisabled
|
||||
}
|
||||
@@ -1408,7 +1415,7 @@ func (s *Server) renderTemplate(name, tpl, source string) (string, error) {
|
||||
}
|
||||
|
||||
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser) error {
|
||||
if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
|
||||
if s.fileCache == nil || s.config.BaseURL == "" {
|
||||
return errHTTPBadRequestAttachmentsDisallowed.With(m)
|
||||
}
|
||||
vinfo, err := v.Info()
|
||||
|
||||
@@ -159,6 +159,7 @@
|
||||
# - attachment-expiry-duration is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h)
|
||||
#
|
||||
# attachment-cache-dir:
|
||||
# attachment-s3-url: "s3://ACCESS_KEY:SECRET_KEY@bucket/prefix?region=us-east-1"
|
||||
# attachment-total-size-limit: "5G"
|
||||
# attachment-file-size-limit: "15M"
|
||||
# attachment-expiry-duration: "3h"
|
||||
|
||||
@@ -3,14 +3,15 @@ package server
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -455,21 +456,8 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
||||
return errHTTPUnauthorized
|
||||
} else if err := s.userManager.AllowReservation(u.Name, req.Topic); err != nil {
|
||||
return errHTTPConflictTopicReserved
|
||||
} else if u.IsUser() {
|
||||
hasReservation, err := s.userManager.HasReservation(u.Name, req.Topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasReservation {
|
||||
reservations, err := s.userManager.ReservationsCount(u.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if reservations >= u.Tier.ReservationLimit {
|
||||
return errHTTPTooManyRequestsLimitReservations
|
||||
}
|
||||
}
|
||||
}
|
||||
// Actually add the reservation
|
||||
// Actually add the reservation (with limit check inside the transaction to avoid races)
|
||||
logvr(v, r).
|
||||
Tag(tagAccount).
|
||||
Fields(log.Context{
|
||||
@@ -477,7 +465,14 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
||||
"everyone": everyone.String(),
|
||||
}).
|
||||
Debug("Adding topic reservation")
|
||||
if err := s.userManager.AddReservation(u.Name, req.Topic, everyone); err != nil {
|
||||
var limit int64
|
||||
if u.IsUser() && u.Tier != nil {
|
||||
limit = u.Tier.ReservationLimit
|
||||
}
|
||||
if err := s.userManager.AddReservation(u.Name, req.Topic, everyone, limit); err != nil {
|
||||
if errors.Is(err, user.ErrTooManyReservations) {
|
||||
return errHTTPTooManyRequestsLimitReservations
|
||||
}
|
||||
return err
|
||||
}
|
||||
// Kill existing subscribers
|
||||
@@ -530,22 +525,15 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
|
||||
// and marks associated messages for the topics as deleted. This also eventually deletes attachments.
|
||||
// The process relies on the manager to perform the actual deletions (see runManager).
|
||||
func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *visitor, u *user.User, reservationsLimit int64) error {
|
||||
reservations, err := s.userManager.Reservations(u.Name)
|
||||
removedTopics, err := s.userManager.RemoveExcessReservations(u.Name, reservationsLimit)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if int64(len(reservations)) <= reservationsLimit {
|
||||
} else if len(removedTopics) == 0 {
|
||||
logvr(v, r).Tag(tagAccount).Debug("No excess reservations to remove")
|
||||
return nil
|
||||
}
|
||||
topics := make([]string, 0)
|
||||
for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
|
||||
topics = append(topics, reservations[i].Topic)
|
||||
}
|
||||
logvr(v, r).Tag(tagAccount).Info("Removing excess reservations for topics %s", strings.Join(topics, ", "))
|
||||
if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.messageCache.ExpireMessages(topics...); err != nil {
|
||||
logvr(v, r).Tag(tagAccount).Info("Removed excess topic reservations, now removing messages for topics %s", strings.Join(removedTopics, ", "))
|
||||
if err := s.messageCache.ExpireMessages(removedTopics...); err != nil {
|
||||
return err
|
||||
}
|
||||
go s.pruneMessages()
|
||||
|
||||
@@ -503,7 +503,7 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("noadmin1", "pass", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("noadmin1", "pro"))
|
||||
require.Nil(t, s.userManager.AddReservation("noadmin1", "mytopic", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.AddReservation("noadmin1", "mytopic", user.PermissionDenyAll, 0))
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("noadmin2", "pass", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("noadmin2", "pro"))
|
||||
|
||||
@@ -99,6 +99,9 @@ func (s *Server) execManager() {
|
||||
mset(metricUsers, usersCount)
|
||||
mset(metricSubscribers, subscribers)
|
||||
mset(metricTopics, topicsCount)
|
||||
if s.fileCache != nil {
|
||||
mset(metricAttachmentsTotalSize, s.fileCache.Size())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) pruneVisitors() {
|
||||
|
||||
@@ -478,8 +478,8 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll, 0))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll, 0))
|
||||
|
||||
// Add billing details
|
||||
u, err := s.userManager.User("phil")
|
||||
@@ -589,7 +589,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
|
||||
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll, 0))
|
||||
|
||||
// Add billing details
|
||||
u, err := s.userManager.User("phil")
|
||||
|
||||
@@ -1543,6 +1543,30 @@ func TestServer_PublishEmailNoMailer_Fail(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_PublishEmailAddressInvalid(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
s := newTestServer(t, newTestConfig(t, databaseURL))
|
||||
s.smtpSender = &testMailer{}
|
||||
addresses := []string{
|
||||
"test@example.com, other@example.com",
|
||||
"invalidaddress",
|
||||
"@nope",
|
||||
"nope@",
|
||||
}
|
||||
for _, email := range addresses {
|
||||
response := request(t, s, "PUT", "/mytopic", "fail", map[string]string{
|
||||
"E-Mail": email,
|
||||
})
|
||||
require.Equal(t, 400, response.Code, "expected 400 for email: %s", email)
|
||||
}
|
||||
// Valid address should succeed
|
||||
response := request(t, s, "PUT", "/mytopic", "success", map[string]string{
|
||||
"E-Mail": "test@example.com",
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_PublishAndExpungeTopicAfter16Hours(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||
t.Parallel()
|
||||
@@ -4441,3 +4465,88 @@ func TestServer_HandleError_SkipsWriteHeaderOnHijackedConnection(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_Publish_InvalidUTF8InBody(t *testing.T) {
|
||||
// All byte sequences from production logs, sent as message body
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
message string
|
||||
}{
|
||||
{"0xc9_0x43", "\xc9Cas du serveur", "\uFFFDCas du serveur"}, // Latin-1 "ÉC"
|
||||
{"0xae", "Product\xae Pro", "Product\uFFFD Pro"}, // Latin-1 "®"
|
||||
{"0xe8_0x6d_0x65", "probl\xe8me critique", "probl\uFFFDme critique"}, // Latin-1 "ème"
|
||||
{"0xb2", "CO\xb2 level high", "CO\uFFFD level high"}, // Latin-1 "²"
|
||||
{"0xe9_0x6d_0x61", "th\xe9matique", "th\uFFFDmatique"}, // Latin-1 "éma"
|
||||
{"0xed_0x64_0x65", "vid\xed\x64eo surveillance", "vid\uFFFDdeo surveillance"}, // Latin-1 "íde"
|
||||
{"0xf3_0x6e_0x3a_0x20", "notificaci\xf3n: alerta", "notificaci\uFFFDn: alerta"}, // Latin-1 "ón: "
|
||||
{"0xb7", "item\xb7value", "item\uFFFDvalue"}, // Latin-1 "·"
|
||||
{"0xa8", "na\xa8ve", "na\uFFFDve"}, // Latin-1 "¨"
|
||||
{"0x00", "hello\x00world", "helloworld"}, // NUL byte
|
||||
{"0xdf_0x64", "gro\xdf\x64ruck", "gro\uFFFDdruck"}, // Latin-1 "ßd"
|
||||
{"0xe4_0x67_0x74", "tr\xe4gt Last", "tr\uFFFDgt Last"}, // Latin-1 "ägt"
|
||||
{"0xe9_0x65_0x20", "journ\xe9\x65 termin\xe9\x65", "journ\uFFFDe termin\uFFFDe"}, // Latin-1 "ée"
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t, ""))
|
||||
|
||||
// Publish via x-message header (the most common path for invalid UTF-8 from HTTP headers)
|
||||
response := request(t, s, "PUT", "/mytopic", "", map[string]string{
|
||||
"X-Message": tc.body,
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
msg := toMessage(t, response.Body.String())
|
||||
require.Equal(t, tc.message, msg.Message)
|
||||
|
||||
// Verify it was stored in the cache correctly
|
||||
response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
msg = toMessage(t, response.Body.String())
|
||||
require.Equal(t, tc.message, msg.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Publish_InvalidUTF8InTitle(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t, ""))
|
||||
response := request(t, s, "PUT", "/mytopic", "valid body", map[string]string{
|
||||
"Title": "\xc9clipse du syst\xe8me",
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
msg := toMessage(t, response.Body.String())
|
||||
require.Equal(t, "\uFFFDclipse du syst\uFFFDme", msg.Title)
|
||||
require.Equal(t, "valid body", msg.Message)
|
||||
}
|
||||
|
||||
func TestServer_Publish_InvalidUTF8InTags(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t, ""))
|
||||
response := request(t, s, "PUT", "/mytopic", "valid body", map[string]string{
|
||||
"Tags": "probl\xe8me,syst\xe9me",
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
msg := toMessage(t, response.Body.String())
|
||||
require.Equal(t, "probl\uFFFDme", msg.Tags[0])
|
||||
require.Equal(t, "syst\uFFFDme", msg.Tags[1])
|
||||
}
|
||||
|
||||
func TestServer_Publish_InvalidUTF8WithFirebase(t *testing.T) {
|
||||
// Verify that sanitization happens before Firebase dispatch, so Firebase
|
||||
// receives clean UTF-8 strings rather than invalid byte sequences
|
||||
sender := newTestFirebaseSender(10)
|
||||
s := newTestServer(t, newTestConfig(t, ""))
|
||||
s.firebaseClient = newFirebaseClient(sender, &testAuther{Allow: true})
|
||||
|
||||
response := request(t, s, "PUT", "/mytopic", "", map[string]string{
|
||||
"X-Message": "notificaci\xf3n: alerta",
|
||||
"Title": "\xc9clipse",
|
||||
"Tags": "probl\xe8me",
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
||||
time.Sleep(100 * time.Millisecond) // Firebase publishing happens asynchronously
|
||||
require.Equal(t, 1, len(sender.Messages()))
|
||||
require.Equal(t, "notificaci\uFFFDn: alerta", sender.Messages()[0].Data["message"])
|
||||
require.Equal(t, "\uFFFDclipse", sender.Messages()[0].Data["title"])
|
||||
require.Equal(t, "probl\uFFFDme", sender.Messages()[0].Data["tags"])
|
||||
}
|
||||
|
||||
@@ -65,12 +65,12 @@ const (
|
||||
key TEXT PRIMARY KEY,
|
||||
value BIGINT
|
||||
);
|
||||
INSERT INTO message_stats (key, value) VALUES ('messages', 0);
|
||||
INSERT INTO message_stats (key, value) VALUES ('messages', 0) ON CONFLICT (key) DO NOTHING;
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO schema_version (store, version) VALUES ('message', 14);
|
||||
INSERT INTO schema_version (store, version) VALUES ('message', 14) ON CONFLICT (store) DO NOTHING;
|
||||
`
|
||||
|
||||
// Initial PostgreSQL schema for user store (from user/manager_postgres_schema.go)
|
||||
@@ -146,7 +146,7 @@ const (
|
||||
INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
INSERT INTO schema_version (store, version) VALUES ('user', 6);
|
||||
INSERT INTO schema_version (store, version) VALUES ('user', 6) ON CONFLICT (store) DO NOTHING;
|
||||
`
|
||||
|
||||
// Initial PostgreSQL schema for web push store (from webpush/store_postgres.go)
|
||||
@@ -174,7 +174,7 @@ const (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO schema_version (store, version) VALUES ('webpush', 1);
|
||||
INSERT INTO schema_version (store, version) VALUES ('webpush', 1) ON CONFLICT (store) DO NOTHING;
|
||||
`
|
||||
)
|
||||
|
||||
@@ -185,6 +185,7 @@ var flags = []cli.Flag{
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file"}, Usage: "SQLite user/auth database file path"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "web-push-file", Aliases: []string{"web_push_file"}, Usage: "SQLite web push database file path"}),
|
||||
&cli.BoolFlag{Name: "create-schema", Usage: "create initial PostgreSQL schema before importing"},
|
||||
&cli.BoolFlag{Name: "pre-import", Usage: "pre-import messages while ntfy is still running (only imports messages)"},
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -207,10 +208,17 @@ func execImport(c *cli.Context) error {
|
||||
cacheFile := c.String("cache-file")
|
||||
authFile := c.String("auth-file")
|
||||
webPushFile := c.String("web-push-file")
|
||||
preImport := c.Bool("pre-import")
|
||||
|
||||
if databaseURL == "" {
|
||||
return fmt.Errorf("database-url must be set (via --database-url or config file)")
|
||||
}
|
||||
if preImport {
|
||||
if cacheFile == "" {
|
||||
return fmt.Errorf("--cache-file must be set when using --pre-import")
|
||||
}
|
||||
return execPreImport(c, databaseURL, cacheFile)
|
||||
}
|
||||
if cacheFile == "" && authFile == "" && webPushFile == "" {
|
||||
return fmt.Errorf("at least one of --cache-file, --auth-file, or --web-push-file must be set")
|
||||
}
|
||||
@@ -261,7 +269,8 @@ func execImport(c *cli.Context) error {
|
||||
if err := verifySchemaVersion(pgDB, "message", expectedMessageSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := importMessages(cacheFile, pgDB); err != nil {
|
||||
sinceTime := maxMessageTime(pgDB)
|
||||
if err := importMessages(cacheFile, pgDB, sinceTime); err != nil {
|
||||
return fmt.Errorf("cannot import messages: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -300,6 +309,54 @@ func execImport(c *cli.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func execPreImport(c *cli.Context, databaseURL, cacheFile string) error {
|
||||
fmt.Println("pgimport - PRE-IMPORT mode (ntfy can keep running)")
|
||||
fmt.Println()
|
||||
fmt.Println("Source:")
|
||||
printSource(" Cache file: ", cacheFile)
|
||||
fmt.Println()
|
||||
fmt.Println("Target:")
|
||||
fmt.Printf(" Database URL: %s\n", maskPassword(databaseURL))
|
||||
fmt.Println()
|
||||
fmt.Println("This will pre-import messages into PostgreSQL while ntfy is still running.")
|
||||
fmt.Println("After this completes, stop ntfy and run pgimport again without --pre-import")
|
||||
fmt.Println("to import remaining messages, users, and web push subscriptions.")
|
||||
fmt.Print("Continue? (y/n): ")
|
||||
|
||||
var answer string
|
||||
fmt.Scanln(&answer)
|
||||
if strings.TrimSpace(strings.ToLower(answer)) != "y" {
|
||||
fmt.Println("Aborted.")
|
||||
return nil
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
pgHost, err := pg.Open(databaseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot connect to PostgreSQL: %w", err)
|
||||
}
|
||||
pgDB := pgHost.DB
|
||||
defer pgDB.Close()
|
||||
|
||||
if c.Bool("create-schema") {
|
||||
if err := createSchema(pgDB, cacheFile, "", ""); err != nil {
|
||||
return fmt.Errorf("cannot create schema: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := verifySchemaVersion(pgDB, "message", expectedMessageSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := importMessages(cacheFile, pgDB, 0); err != nil {
|
||||
return fmt.Errorf("cannot import messages: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("Pre-import complete. Now stop ntfy and run pgimport again without --pre-import")
|
||||
fmt.Println("to import any remaining messages, users, and web push subscriptions.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func createSchema(pgDB *sql.DB, cacheFile, authFile, webPushFile string) error {
|
||||
fmt.Println("Creating initial PostgreSQL schema ...")
|
||||
// User schema must be created before message schema, because message_stats and
|
||||
@@ -645,16 +702,41 @@ func importUserPhones(sqlDB, pgDB *sql.DB) (int, error) {
|
||||
|
||||
// Message import
|
||||
|
||||
func importMessages(sqliteFile string, pgDB *sql.DB) error {
|
||||
const preImportTimeDelta = 30 // seconds to subtract from max time to account for in-flight messages
|
||||
|
||||
// maxMessageTime returns the maximum message time in PostgreSQL minus a small buffer,
|
||||
// or 0 if there are no messages yet. This is used after a --pre-import run to only
|
||||
// import messages that arrived since the pre-import.
|
||||
func maxMessageTime(pgDB *sql.DB) int64 {
|
||||
var maxTime sql.NullInt64
|
||||
if err := pgDB.QueryRow(`SELECT MAX(time) FROM message`).Scan(&maxTime); err != nil || !maxTime.Valid || maxTime.Int64 == 0 {
|
||||
return 0
|
||||
}
|
||||
sinceTime := maxTime.Int64 - preImportTimeDelta
|
||||
if sinceTime < 0 {
|
||||
return 0
|
||||
}
|
||||
fmt.Printf("Pre-imported messages detected (max time: %d), importing delta (since time %d) ...\n", maxTime.Int64, sinceTime)
|
||||
return sinceTime
|
||||
}
|
||||
|
||||
func importMessages(sqliteFile string, pgDB *sql.DB, sinceTime int64) error {
|
||||
sqlDB, err := openSQLite(sqliteFile)
|
||||
if err != nil {
|
||||
fmt.Printf("Skipping message import: %s\n", err)
|
||||
return nil
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
fmt.Printf("Importing messages from %s ...\n", sqliteFile)
|
||||
|
||||
rows, err := sqlDB.Query(`SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published FROM messages`)
|
||||
query := `SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published FROM messages`
|
||||
var rows *sql.Rows
|
||||
if sinceTime > 0 {
|
||||
fmt.Printf("Importing messages from %s (since time %d) ...\n", sqliteFile, sinceTime)
|
||||
rows, err = sqlDB.Query(query+` WHERE time >= ?`, sinceTime)
|
||||
} else {
|
||||
fmt.Printf("Importing messages from %s ...\n", sqliteFile)
|
||||
rows, err = sqlDB.Query(query)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("querying messages: %w", err)
|
||||
}
|
||||
@@ -837,7 +919,9 @@ func importWebPush(sqliteFile string, pgDB *sql.DB) error {
|
||||
}
|
||||
|
||||
func toUTF8(s string) string {
|
||||
return strings.ToValidUTF8(s, "\uFFFD")
|
||||
s = strings.ToValidUTF8(s, "\uFFFD")
|
||||
s = strings.ReplaceAll(s, "\x00", "")
|
||||
return s
|
||||
}
|
||||
|
||||
// Verification
|
||||
|
||||
141
tools/s3cli/main.go
Normal file
141
tools/s3cli/main.go
Normal file
@@ -0,0 +1,141 @@
|
||||
// Command s3cli is a minimal CLI for testing the s3 package. It supports put, get, rm, and ls.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// export S3_URL="s3://ACCESS_KEY:SECRET_KEY@BUCKET/PREFIX?region=REGION&endpoint=ENDPOINT"
|
||||
//
|
||||
// s3cli put <key> <file> Upload a file
|
||||
// s3cli put <key> - Upload from stdin
|
||||
// s3cli get <key> Download to stdout
|
||||
// s3cli rm <key> [<key>...] Delete one or more objects
|
||||
// s3cli ls List all objects
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"text/tabwriter"
|
||||
|
||||
"heckel.io/ntfy/v2/s3"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
usage()
|
||||
}
|
||||
s3URL := os.Getenv("S3_URL")
|
||||
if s3URL == "" {
|
||||
fail("S3_URL environment variable is required")
|
||||
}
|
||||
cfg, err := s3.ParseURL(s3URL)
|
||||
if err != nil {
|
||||
fail("invalid S3_URL: %s", err)
|
||||
}
|
||||
client := s3.New(cfg)
|
||||
ctx := context.Background()
|
||||
|
||||
switch os.Args[1] {
|
||||
case "put":
|
||||
cmdPut(ctx, client)
|
||||
case "get":
|
||||
cmdGet(ctx, client)
|
||||
case "rm":
|
||||
cmdRm(ctx, client)
|
||||
case "ls":
|
||||
cmdLs(ctx, client)
|
||||
default:
|
||||
usage()
|
||||
}
|
||||
}
|
||||
|
||||
func cmdPut(ctx context.Context, client *s3.Client) {
|
||||
if len(os.Args) != 4 {
|
||||
fail("usage: s3cli put <key> <file|->\n")
|
||||
}
|
||||
key := os.Args[2]
|
||||
path := os.Args[3]
|
||||
|
||||
var r io.Reader
|
||||
if path == "-" {
|
||||
r = os.Stdin
|
||||
} else {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
fail("open %s: %s", path, err)
|
||||
}
|
||||
defer f.Close()
|
||||
r = f
|
||||
}
|
||||
|
||||
if err := client.PutObject(ctx, key, r); err != nil {
|
||||
fail("put: %s", err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "uploaded %s\n", key)
|
||||
}
|
||||
|
||||
func cmdGet(ctx context.Context, client *s3.Client) {
|
||||
if len(os.Args) != 3 {
|
||||
fail("usage: s3cli get <key>\n")
|
||||
}
|
||||
key := os.Args[2]
|
||||
|
||||
reader, size, err := client.GetObject(ctx, key)
|
||||
if err != nil {
|
||||
fail("get: %s", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
n, err := io.Copy(os.Stdout, reader)
|
||||
if err != nil {
|
||||
fail("read: %s", err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "downloaded %s (%d bytes, content-length: %d)\n", key, n, size)
|
||||
}
|
||||
|
||||
func cmdRm(ctx context.Context, client *s3.Client) {
|
||||
if len(os.Args) < 3 {
|
||||
fail("usage: s3cli rm <key> [<key>...]\n")
|
||||
}
|
||||
keys := os.Args[2:]
|
||||
if err := client.DeleteObjects(ctx, keys); err != nil {
|
||||
fail("rm: %s", err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "deleted %d object(s)\n", len(keys))
|
||||
}
|
||||
|
||||
func cmdLs(ctx context.Context, client *s3.Client) {
|
||||
objects, err := client.ListAllObjects(ctx)
|
||||
if err != nil {
|
||||
fail("ls: %s", err)
|
||||
}
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
var totalSize int64
|
||||
for _, obj := range objects {
|
||||
fmt.Fprintf(w, "%d\t%s\n", obj.Size, obj.Key)
|
||||
totalSize += obj.Size
|
||||
}
|
||||
w.Flush()
|
||||
fmt.Fprintf(os.Stderr, "%d object(s), %d bytes total\n", len(objects), totalSize)
|
||||
}
|
||||
|
||||
func usage() {
|
||||
fmt.Fprintf(os.Stderr, `Usage: s3cli <command> [args...]
|
||||
|
||||
Commands:
|
||||
put <key> <file|-> Upload a file (use - for stdin)
|
||||
get <key> Download to stdout
|
||||
rm <key> [keys...] Delete objects
|
||||
ls List all objects
|
||||
|
||||
Environment:
|
||||
S3_URL S3 connection URL (required)
|
||||
s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
|
||||
`)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func fail(format string, args ...any) {
|
||||
fmt.Fprintf(os.Stderr, format+"\n", args...)
|
||||
os.Exit(1)
|
||||
}
|
||||
133
user/manager.go
133
user/manager.go
@@ -288,33 +288,41 @@ func (a *Manager) ChangeTier(username, tier string) error {
|
||||
t, err := a.Tier(tier)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := a.db.Exec(a.queries.updateUserTier, tier, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
if err := a.checkReservationsLimitTx(tx, username, t.ReservationLimit); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(a.queries.updateUserTier, tier, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ResetTier removes the tier from the given user
|
||||
func (a *Manager) ResetTier(username string) error {
|
||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||
return ErrInvalidArgument
|
||||
} else if err := a.checkReservationsLimit(username, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := a.db.Exec(a.queries.deleteUserTier, username)
|
||||
return err
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
if err := a.checkReservationsLimitTx(tx, username, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(a.queries.deleteUserTier, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
|
||||
u, err := a.User(username)
|
||||
func (a *Manager) checkReservationsLimitTx(tx *sql.Tx, username string, reservationsLimit int64) error {
|
||||
u, err := a.userTx(tx, username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
|
||||
reservations, err := a.Reservations(username)
|
||||
reservations, err := a.reservationsTx(tx, username)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if int64(len(reservations)) > reservationsLimit {
|
||||
@@ -388,7 +396,11 @@ func (a *Manager) writeUserStatsQueue() error {
|
||||
|
||||
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
|
||||
func (a *Manager) User(username string) (*User, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserByName, username)
|
||||
return a.userTx(a.db, username)
|
||||
}
|
||||
|
||||
func (a *Manager) userTx(tx db.Querier, username string) (*User, error) {
|
||||
rows, err := tx.Query(a.queries.selectUserByName, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -415,7 +427,7 @@ func (a *Manager) userByToken(token string) (*User, error) {
|
||||
|
||||
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
|
||||
func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserByStripeCustomerID, customerID)
|
||||
rows, err := a.db.Query(a.queries.selectUserByStripeCustomerID, customerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -642,7 +654,7 @@ func (a *Manager) AllowReservation(username string, topic string) error {
|
||||
// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
|
||||
// - It also prioritizes write permissions over read permissions
|
||||
func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
|
||||
rows, err := a.db.Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
|
||||
if err != nil {
|
||||
return false, false, false, err
|
||||
}
|
||||
@@ -713,16 +725,35 @@ func (a *Manager) Grants(username string) ([]Grant, error) {
|
||||
|
||||
// AddReservation creates two access control entries for the given topic: one with full read/write
|
||||
// access for the given user, and one for Everyone with the given permission. Both entries are
|
||||
// created atomically in a single transaction.
|
||||
func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
|
||||
// created atomically in a single transaction. If limit is > 0, the reservation count is checked
|
||||
// inside the transaction and ErrTooManyReservations is returned if the limit would be exceeded.
|
||||
func (a *Manager) AddReservation(username string, topic string, everyone Permission, limit int64) error {
|
||||
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil {
|
||||
if limit > 0 {
|
||||
hasReservation, err := a.hasReservationTx(tx, username, topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasReservation {
|
||||
count, err := a.reservationsCountTx(tx, username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count >= limit {
|
||||
return ErrTooManyReservations
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), true, true, username, username, false); err != nil {
|
||||
return err
|
||||
}
|
||||
return a.addReservationAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username)
|
||||
if _, err := tx.Exec(a.queries.upsertUserAccess, Everyone, toSQLWildcard(topic), everyone.IsRead(), everyone.IsWrite(), username, username, false); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -740,10 +771,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
|
||||
}
|
||||
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||
for _, topic := range topics {
|
||||
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.resetTopicAccessTx(tx, Everyone, topic); err != nil {
|
||||
if err := a.removeReservationAccessTx(tx, username, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -753,7 +781,11 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
|
||||
|
||||
// Reservations returns all user-owned topics, and the associated everyone-access
|
||||
func (a *Manager) Reservations(username string) ([]Reservation, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservations, Everyone, username)
|
||||
return a.reservationsTx(a.db.ReadOnly(), username)
|
||||
}
|
||||
|
||||
func (a *Manager) reservationsTx(tx db.Querier, username string) ([]Reservation, error) {
|
||||
rows, err := tx.Query(a.queries.selectUserReservations, Everyone, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -779,7 +811,11 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) {
|
||||
|
||||
// HasReservation returns true if the given topic access is owned by the user
|
||||
func (a *Manager) HasReservation(username, topic string) (bool, error) {
|
||||
rows, err := a.db.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||
return a.hasReservationTx(a.db, username, topic)
|
||||
}
|
||||
|
||||
func (a *Manager) hasReservationTx(tx db.Querier, username, topic string) (bool, error) {
|
||||
rows, err := tx.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -796,7 +832,11 @@ func (a *Manager) HasReservation(username, topic string) (bool, error) {
|
||||
|
||||
// ReservationsCount returns the number of reservations owned by this user
|
||||
func (a *Manager) ReservationsCount(username string) (int64, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectUserReservationsCount, username)
|
||||
return a.reservationsCountTx(a.db, username)
|
||||
}
|
||||
|
||||
func (a *Manager) reservationsCountTx(tx db.Querier, username string) (int64, error) {
|
||||
rows, err := tx.Query(a.queries.selectUserReservationsCount, username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -828,6 +868,30 @@ func (a *Manager) ReservationOwner(topic string) (string, error) {
|
||||
return ownerUserID, nil
|
||||
}
|
||||
|
||||
// RemoveExcessReservations removes reservations that exceed the given limit for the user.
|
||||
// It returns the list of topics whose reservations were removed. The read and removal are
|
||||
// performed atomically in a single transaction to avoid issues with stale replica data.
|
||||
func (a *Manager) RemoveExcessReservations(username string, limit int64) ([]string, error) {
|
||||
return db.QueryTx(a.db, func(tx *sql.Tx) ([]string, error) {
|
||||
reservations, err := a.reservationsTx(tx, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(reservations)) <= limit {
|
||||
return []string{}, nil
|
||||
}
|
||||
removedTopics := make([]string, 0)
|
||||
for i := int64(len(reservations)) - 1; i >= limit; i-- {
|
||||
topic := reservations[i].Topic
|
||||
if err := a.removeReservationAccessTx(tx, username, topic); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
removedTopics = append(removedTopics, topic)
|
||||
}
|
||||
return removedTopics, nil
|
||||
})
|
||||
}
|
||||
|
||||
// otherAccessCount returns the number of access entries for the given topic that are not owned by the user
|
||||
func (a *Manager) otherAccessCount(username, topic string) (int, error) {
|
||||
rows, err := a.db.Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
|
||||
@@ -845,14 +909,11 @@ func (a *Manager) otherAccessCount(username, topic string) (int, error) {
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (a *Manager) addReservationAccessTx(tx *sql.Tx, username, topic string, read, write bool, ownerUsername string) error {
|
||||
if !AllowedUsername(username) && username != Everyone {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topic) {
|
||||
return ErrInvalidArgument
|
||||
func (a *Manager) removeReservationAccessTx(tx *sql.Tx, username, topic string) error {
|
||||
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), read, write, ownerUsername, ownerUsername, false)
|
||||
return err
|
||||
return a.resetTopicAccessTx(tx, Everyone, topic)
|
||||
}
|
||||
|
||||
func (a *Manager) resetUserAccessTx(tx *sql.Tx, username string) error {
|
||||
@@ -1134,7 +1195,7 @@ func (a *Manager) Tiers() ([]*Tier, error) {
|
||||
|
||||
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
|
||||
func (a *Manager) Tier(code string) (*Tier, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByCode, code)
|
||||
rows, err := a.db.Query(a.queries.selectTierByCode, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1144,7 +1205,7 @@ func (a *Manager) Tier(code string) (*Tier, error) {
|
||||
|
||||
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
|
||||
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
|
||||
rows, err := a.db.ReadOnly().Query(a.queries.selectTierByPriceID, priceID, priceID)
|
||||
rows, err := a.db.Query(a.queries.selectTierByPriceID, priceID, priceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -226,7 +226,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
|
||||
|
||||
// Create user, add reservations and token
|
||||
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false))
|
||||
require.Nil(t, a.AddReservation("user", "mytopic", PermissionRead))
|
||||
require.Nil(t, a.AddReservation("user", "mytopic", PermissionRead, 0))
|
||||
|
||||
u, err := a.User("user")
|
||||
require.Nil(t, err)
|
||||
@@ -439,8 +439,8 @@ func TestManager_Reservations(t *testing.T) {
|
||||
a := newTestManager(t, newManager, PermissionDenyAll)
|
||||
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||
require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll))
|
||||
require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
|
||||
require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll, 0))
|
||||
require.Nil(t, a.AddReservation("ben", "readme", PermissionRead, 0))
|
||||
require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
|
||||
|
||||
reservations, err := a.Reservations("ben")
|
||||
@@ -523,7 +523,7 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
||||
}))
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||
require.Nil(t, a.ChangeTier("ben", "pro"))
|
||||
require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll))
|
||||
require.Nil(t, a.AddReservation("ben", "mytopic", PermissionDenyAll, 0))
|
||||
|
||||
ben, err := a.User("ben")
|
||||
require.Nil(t, err)
|
||||
@@ -1076,7 +1076,7 @@ func TestManager_Tier_Change_And_Reset(t *testing.T) {
|
||||
|
||||
// Add 10 reservations (pro tier allows that)
|
||||
for i := 0; i < 4; i++ {
|
||||
require.Nil(t, a.AddReservation("phil", fmt.Sprintf("topic%d", i), PermissionWrite))
|
||||
require.Nil(t, a.AddReservation("phil", fmt.Sprintf("topic%d", i), PermissionWrite, 0))
|
||||
}
|
||||
|
||||
// Downgrading will not work (too many reservations)
|
||||
@@ -2118,7 +2118,7 @@ func TestStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
|
||||
func TestStoreReservations(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
|
||||
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionRead))
|
||||
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionRead, 0))
|
||||
|
||||
reservations, err := manager.Reservations("phil")
|
||||
require.Nil(t, err)
|
||||
@@ -2133,8 +2133,8 @@ func TestStoreReservations(t *testing.T) {
|
||||
func TestStoreReservationsCount(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
|
||||
require.Nil(t, manager.AddReservation("phil", "topic1", PermissionReadWrite))
|
||||
require.Nil(t, manager.AddReservation("phil", "topic2", PermissionReadWrite))
|
||||
require.Nil(t, manager.AddReservation("phil", "topic1", PermissionReadWrite, 0))
|
||||
require.Nil(t, manager.AddReservation("phil", "topic2", PermissionReadWrite, 0))
|
||||
|
||||
count, err := manager.ReservationsCount("phil")
|
||||
require.Nil(t, err)
|
||||
@@ -2145,7 +2145,7 @@ func TestStoreReservationsCount(t *testing.T) {
|
||||
func TestStoreHasReservation(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
|
||||
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite))
|
||||
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite, 0))
|
||||
|
||||
has, err := manager.HasReservation("phil", "mytopic")
|
||||
require.Nil(t, err)
|
||||
@@ -2160,7 +2160,7 @@ func TestStoreHasReservation(t *testing.T) {
|
||||
func TestStoreReservationOwner(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
|
||||
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite))
|
||||
require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite, 0))
|
||||
|
||||
owner, err := manager.ReservationOwner("mytopic")
|
||||
require.Nil(t, err)
|
||||
@@ -2172,6 +2172,26 @@ func TestStoreReservationOwner(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreAddReservationWithLimit(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
|
||||
|
||||
// Adding reservations within limit succeeds
|
||||
require.Nil(t, manager.AddReservation("phil", "topic1", PermissionReadWrite, 2))
|
||||
require.Nil(t, manager.AddReservation("phil", "topic2", PermissionRead, 2))
|
||||
|
||||
// Adding a third reservation exceeds the limit
|
||||
require.Equal(t, ErrTooManyReservations, manager.AddReservation("phil", "topic3", PermissionRead, 2))
|
||||
|
||||
// Updating an existing reservation within the limit succeeds
|
||||
require.Nil(t, manager.AddReservation("phil", "topic1", PermissionRead, 2))
|
||||
|
||||
reservations, err := manager.Reservations("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, reservations, 2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreTiers(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
tier := &Tier{
|
||||
@@ -2431,7 +2451,7 @@ func TestStoreOtherAccessCount(t *testing.T) {
|
||||
forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
|
||||
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
|
||||
require.Nil(t, manager.AddUser("ben", "benpass", RoleUser, false))
|
||||
require.Nil(t, manager.AddReservation("ben", "mytopic", PermissionReadWrite))
|
||||
require.Nil(t, manager.AddReservation("ben", "mytopic", PermissionReadWrite, 0))
|
||||
|
||||
count, err := manager.otherAccessCount("phil", "mytopic")
|
||||
require.Nil(t, err)
|
||||
|
||||
20
util/util.go
20
util/util.go
@@ -17,6 +17,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"golang.org/x/term"
|
||||
@@ -434,3 +435,22 @@ func Int(v int) *int {
|
||||
func Time(v time.Time) *time.Time {
|
||||
return &v
|
||||
}
|
||||
|
||||
// SanitizeUTF8 ensures a string is safe to store in PostgreSQL by handling two cases:
|
||||
//
|
||||
// 1. Invalid UTF-8 sequences: Some clients send Latin-1/ISO-8859-1 encoded text (e.g. accented
|
||||
// characters like é, ñ, ß) in HTTP headers or SMTP messages. Go treats these as raw bytes in
|
||||
// strings, but PostgreSQL rejects them. Any invalid UTF-8 byte is replaced with the Unicode
|
||||
// replacement character (U+FFFD, "<22>") so the message is still delivered rather than lost.
|
||||
//
|
||||
// 2. NUL bytes (0x00): These are valid in UTF-8 but PostgreSQL TEXT columns reject them.
|
||||
// They are stripped entirely.
|
||||
func SanitizeUTF8(s string) string {
|
||||
if !utf8.ValidString(s) {
|
||||
s = strings.ToValidUTF8(s, "\xef\xbf\xbd") // U+FFFD
|
||||
}
|
||||
if strings.ContainsRune(s, 0) {
|
||||
s = strings.ReplaceAll(s, "\x00", "")
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
36
web/package-lock.json
generated
36
web/package-lock.json
generated
@@ -3642,9 +3642,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/baseline-browser-mapping": {
|
||||
"version": "2.10.0",
|
||||
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.0.tgz",
|
||||
"integrity": "sha512-lIyg0szRfYbiy67j9KN8IyeD7q7hcmqnJ1ddWmNt19ItGpNN64mnllmxUNFIOdOm6by97jlL6wfpTTJrmnjWAA==",
|
||||
"version": "2.10.8",
|
||||
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.8.tgz",
|
||||
"integrity": "sha512-PCLz/LXGBsNTErbtB6i5u4eLpHeMfi93aUv5duMmj6caNu6IphS4q6UevDnL36sZQv9lrP11dbPKGMaXPwMKfQ==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"bin": {
|
||||
@@ -3766,9 +3766,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/caniuse-lite": {
|
||||
"version": "1.0.30001777",
|
||||
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001777.tgz",
|
||||
"integrity": "sha512-tmN+fJxroPndC74efCdp12j+0rk0RHwV5Jwa1zWaFVyw2ZxAuPeG8ZgWC3Wz7uSjT3qMRQ5XHZ4COgQmsCMJAQ==",
|
||||
"version": "1.0.30001779",
|
||||
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001779.tgz",
|
||||
"integrity": "sha512-U5og2PN7V4DMgF50YPNtnZJGWVLFjjsN3zb6uMT5VGYIewieDj1upwfuVNXf4Kor+89c3iCRJnSzMD5LmTvsfA==",
|
||||
"dev": true,
|
||||
"funding": [
|
||||
{
|
||||
@@ -4203,9 +4203,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/electron-to-chromium": {
|
||||
"version": "1.5.307",
|
||||
"resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.307.tgz",
|
||||
"integrity": "sha512-5z3uFKBWjiNR44nFcYdkcXjKMbg5KXNdciu7mhTPo9tB7NbqSNP2sSnGR+fqknZSCwKkBN+oxiiajWs4dT6ORg==",
|
||||
"version": "1.5.313",
|
||||
"resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.313.tgz",
|
||||
"integrity": "sha512-QBMrTWEf00GXZmJyx2lbYD45jpI3TUFnNIzJ5BBc8piGUDwMPa1GV6HJWTZVvY/eiN3fSopl7NRbgGp9sZ9LTA==",
|
||||
"dev": true,
|
||||
"license": "ISC"
|
||||
},
|
||||
@@ -4324,9 +4324,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/es-iterator-helpers": {
|
||||
"version": "1.3.0",
|
||||
"resolved": "https://registry.npmjs.org/es-iterator-helpers/-/es-iterator-helpers-1.3.0.tgz",
|
||||
"integrity": "sha512-04cg8iJFDOxWcYlu0GFFWgs7vtaEPCmr5w1nrj9V3z3axu/48HCMwK6VMp45Zh3ZB+xLP1ifbJfrq86+1ypKKQ==",
|
||||
"version": "1.3.1",
|
||||
"resolved": "https://registry.npmjs.org/es-iterator-helpers/-/es-iterator-helpers-1.3.1.tgz",
|
||||
"integrity": "sha512-zWwRvqWiuBPr0muUG/78cW3aHROFCNIQ3zpmYDpwdbnt2m+xlNyRWpHBpa2lJjSBit7BQ+RXA1iwbSmu5yJ/EQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -7043,9 +7043,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/path-scurry/node_modules/lru-cache": {
|
||||
"version": "11.2.6",
|
||||
"resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.6.tgz",
|
||||
"integrity": "sha512-ESL2CrkS/2wTPfuend7Zhkzo2u0daGJ/A2VucJOgQ/C48S/zB8MMeMHSGKYpXhIjbPxfuezITkaBH1wqv00DDQ==",
|
||||
"version": "11.2.7",
|
||||
"resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.2.7.tgz",
|
||||
"integrity": "sha512-aY/R+aEsRelme17KGQa/1ZSIpLpNYYrhcrepKTZgE+W3WM16YMCaPwOHLHsmopZHELU0Ojin1lPVxKR0MihncA==",
|
||||
"dev": true,
|
||||
"license": "BlueOak-1.0.0",
|
||||
"engines": {
|
||||
@@ -8307,9 +8307,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/terser": {
|
||||
"version": "5.46.0",
|
||||
"resolved": "https://registry.npmjs.org/terser/-/terser-5.46.0.tgz",
|
||||
"integrity": "sha512-jTwoImyr/QbOWFFso3YoU3ik0jBBDJ6JTOQiy/J2YxVJdZCc+5u7skhNwiOR3FQIygFqVUPHl7qbbxtjW2K3Qg==",
|
||||
"version": "5.46.1",
|
||||
"resolved": "https://registry.npmjs.org/terser/-/terser-5.46.1.tgz",
|
||||
"integrity": "sha512-vzCjQO/rgUuK9sf8VJZvjqiqiHFaZLnOiimmUuOKODxWL8mm/xua7viT7aqX7dgPY60otQjUotzFMmCB4VdmqQ==",
|
||||
"dev": true,
|
||||
"license": "BSD-2-Clause",
|
||||
"dependencies": {
|
||||
|
||||
@@ -120,7 +120,7 @@
|
||||
"publish_dialog_priority_low": "Prioridad baja",
|
||||
"publish_dialog_priority_high": "Prioridad alta",
|
||||
"publish_dialog_delay_label": "Retraso",
|
||||
"publish_dialog_title_placeholder": "Título de la notificación, por ejemplo, Alerta de espacio en disco",
|
||||
"publish_dialog_title_placeholder": "Título de la notificación, ej. Alerta de espacio en disco",
|
||||
"publish_dialog_details_examples_description": "Para ver ejemplos y una descripción detallada de todas las funciones de envío, consulte la <docsLink>documentación</docsLink>.",
|
||||
"publish_dialog_attach_placeholder": "Adjuntar un archivo por URL, por ejemplo, https://f-droid.org/F-Droid.apk",
|
||||
"publish_dialog_filename_placeholder": "Nombre del archivo adjunto",
|
||||
|
||||
@@ -63,9 +63,10 @@ func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string,
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert or update subscription
|
||||
// Insert or update subscription, and read back the actual ID (which may differ from
|
||||
// the generated one if another request for the same endpoint raced us and inserted first)
|
||||
updatedAt, warnedAt := time.Now().Unix(), 0
|
||||
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
||||
if err := tx.QueryRow(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt).Scan(&subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
// Replace all subscription topics
|
||||
|
||||
@@ -53,6 +53,7 @@ const (
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
ON CONFLICT (endpoint)
|
||||
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
||||
RETURNING id
|
||||
`
|
||||
postgresUpdateSubscriptionWarningSentQuery = `UPDATE webpush_subscription SET warned_at = $1 WHERE id = $2`
|
||||
postgresUpdateSubscriptionUpdatedAtQuery = `UPDATE webpush_subscription SET updated_at = $1 WHERE endpoint = $2`
|
||||
|
||||
@@ -56,8 +56,9 @@ const (
|
||||
sqliteUpsertSubscriptionQuery = `
|
||||
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (endpoint)
|
||||
ON CONFLICT (endpoint)
|
||||
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
||||
RETURNING id
|
||||
`
|
||||
sqliteUpdateSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
|
||||
sqliteUpdateSubscriptionUpdatedAtQuery = `UPDATE subscription SET updated_at = ? WHERE endpoint = ?`
|
||||
|
||||
Reference in New Issue
Block a user