Compare commits

...

8 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| binwiederhier | 6b11bc7468 | Merge branch 'main' into attachment-s3 | 2026-03-16 21:11:38 -04:00 |
| binwiederhier | d9efe50848 | Email validation | 2026-03-16 21:03:33 -04:00 |
| binwiederhier | 86015e100c | Multipart upload | 2026-03-16 20:00:19 -04:00 |
| binwiederhier | 458fbad770 | Merge branch 'main' into attachment-s3 | 2026-03-16 15:53:21 -04:00 |
| binwiederhier | 790ba243c7 | S3 WIP | 2026-03-16 09:48:26 -04:00 |
| binwiederhier | 4487299a80 | Merge branch 'main' into attachment-s3 | 2026-03-16 05:43:20 -04:00 |
| binwiederhier | b4ec6fa8df | AWS deps.. | 2026-03-15 10:12:23 -04:00 |
| binwiederhier | d517ce4a2a | WIP: S3 | 2026-03-14 21:10:46 -04:00 |
23 changed files with 2697 additions and 129 deletions

.gitignore (vendored) · 1 addition

@@ -9,6 +9,7 @@ server/site/
 tools/fbsend/fbsend
 tools/pgimport/pgimport
 tools/loadtest/loadtest
+tools/s3cli/s3cli
 playground/
 secrets/
 *.iml

attachment/store.go · new file · 25 additions

@@ -0,0 +1,25 @@
package attachment

import (
	"errors"
	"fmt"
	"io"
	"regexp"

	"heckel.io/ntfy/v2/model"
	"heckel.io/ntfy/v2/util"
)

// Store is an interface for storing and retrieving attachment files
type Store interface {
	Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error)
	Read(id string) (io.ReadCloser, int64, error)
	Remove(ids ...string) error
	Size() int64
	Remaining() int64
}

var (
	fileIDRegex      = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
	errInvalidFileID = errors.New("invalid file ID")
)

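To make the interface contract concrete, here is a minimal sketch (editorial, not part of this patch) of how a consumer would exercise a `Store`. `NewFileStore` and `util.NewFixedLimiter` appear in the files below; the directory, capacity, and ID are made-up values:

``` go
package main

import (
	"fmt"
	"io"
	"log"
	"strings"

	"heckel.io/ntfy/v2/attachment"
	"heckel.io/ntfy/v2/util"
)

func main() {
	// 10 KB total capacity; Write additionally enforces any per-call limiters
	store, err := attachment.NewFileStore("/tmp/ntfy-attachments", 10*1024)
	if err != nil {
		log.Fatal(err)
	}
	// IDs must match fileIDRegex: exactly model.MessageIDLength characters
	// of [-_A-Za-z0-9] (12 characters, judging by the tests below)
	n, err := store.Write("abcdefghijkl", strings.NewReader("hello world"), util.NewFixedLimiter(1024))
	if err != nil {
		log.Fatal(err)
	}
	r, size, err := store.Read("abcdefghijkl")
	if err != nil {
		log.Fatal(err)
	}
	defer r.Close()
	body, _ := io.ReadAll(r)
	fmt.Println(n, size, string(body), store.Remaining())
	// Remove accepts multiple IDs and rescans the store size afterwards
	if err := store.Remove("abcdefghijkl"); err != nil {
		log.Fatal(err)
	}
}
```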

@@ -1,32 +1,29 @@
-package server
+package attachment
 
 import (
 	"errors"
-	"fmt"
-	"heckel.io/ntfy/v2/log"
-	"heckel.io/ntfy/v2/model"
-	"heckel.io/ntfy/v2/util"
 	"io"
 	"os"
 	"path/filepath"
-	"regexp"
 	"sync"
+
+	"heckel.io/ntfy/v2/log"
+	"heckel.io/ntfy/v2/util"
 )
 
-var (
-	fileIDRegex      = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
-	errInvalidFileID = errors.New("invalid file ID")
-	errFileExists    = errors.New("file exists")
-)
+const tagFileStore = "file_store"
 
-type fileCache struct {
+var errFileExists = errors.New("file exists")
+
+type fileStore struct {
 	dir              string
 	totalSizeCurrent int64
 	totalSizeLimit   int64
 	mu               sync.Mutex
 }
 
-func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) {
+// NewFileStore creates a new file-system backed attachment store
+func NewFileStore(dir string, totalSizeLimit int64) (Store, error) {
 	if err := os.MkdirAll(dir, 0700); err != nil {
 		return nil, err
 	}
@@ -34,18 +31,18 @@ func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) {
 	if err != nil {
 		return nil, err
 	}
-	return &fileCache{
+	return &fileStore{
 		dir:              dir,
 		totalSizeCurrent: size,
 		totalSizeLimit:   totalSizeLimit,
 	}, nil
 }
 
-func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
+func (c *fileStore) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
 	if !fileIDRegex.MatchString(id) {
 		return 0, errInvalidFileID
 	}
-	log.Tag(tagFileCache).Field("message_id", id).Debug("Writing attachment")
+	log.Tag(tagFileStore).Field("message_id", id).Debug("Writing attachment")
 	file := filepath.Join(c.dir, id)
 	if _, err := os.Stat(file); err == nil {
 		return 0, errFileExists
@@ -68,20 +65,35 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (in
 	}
 	c.mu.Lock()
 	c.totalSizeCurrent += size
-	mset(metricAttachmentsTotalSize, c.totalSizeCurrent)
 	c.mu.Unlock()
 	return size, nil
 }
 
-func (c *fileCache) Remove(ids ...string) error {
+func (c *fileStore) Read(id string) (io.ReadCloser, int64, error) {
+	if !fileIDRegex.MatchString(id) {
+		return nil, 0, errInvalidFileID
+	}
+	file := filepath.Join(c.dir, id)
+	stat, err := os.Stat(file)
+	if err != nil {
+		return nil, 0, err
+	}
+	f, err := os.Open(file)
+	if err != nil {
+		return nil, 0, err
+	}
+	return f, stat.Size(), nil
+}
+
+func (c *fileStore) Remove(ids ...string) error {
 	for _, id := range ids {
 		if !fileIDRegex.MatchString(id) {
 			return errInvalidFileID
 		}
-		log.Tag(tagFileCache).Field("message_id", id).Debug("Deleting attachment")
+		log.Tag(tagFileStore).Field("message_id", id).Debug("Deleting attachment")
 		file := filepath.Join(c.dir, id)
 		if err := os.Remove(file); err != nil {
-			log.Tag(tagFileCache).Field("message_id", id).Err(err).Debug("Error deleting attachment")
+			log.Tag(tagFileStore).Field("message_id", id).Err(err).Debug("Error deleting attachment")
 		}
 	}
 	size, err := dirSize(c.dir)
@@ -91,17 +103,16 @@ func (c *fileCache) Remove(ids ...string) error {
 	c.mu.Lock()
 	c.totalSizeCurrent = size
 	c.mu.Unlock()
-	mset(metricAttachmentsTotalSize, size)
 	return nil
 }
 
-func (c *fileCache) Size() int64 {
+func (c *fileStore) Size() int64 {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	return c.totalSizeCurrent
 }
 
-func (c *fileCache) Remaining() int64 {
+func (c *fileStore) Remaining() int64 {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	remaining := c.totalSizeLimit - c.totalSizeCurrent

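Both `Write` and `Remove` above re-derive the cache size via `dirSize`, which is defined elsewhere in the repository and not shown in this compare view. A plausible stand-in (hypothetical, for reference only) sums the regular files directly in the directory:

``` go
package attachment

import "os"

// dirSizeSketch is a hypothetical equivalent of the dirSize helper called
// above; the real implementation lives outside this compare view.
func dirSizeSketch(dir string) (int64, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return 0, err
	}
	var size int64
	for _, entry := range entries {
		info, err := entry.Info()
		if err != nil {
			return 0, err
		}
		if info.Mode().IsRegular() {
			size += info.Size()
		}
	}
	return size, nil
}
```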

@@ -0,0 +1,99 @@
package attachment
import (
"bytes"
"fmt"
"io"
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/util"
)
var (
oneKilobyteArray = make([]byte, 1024)
)
func TestFileStore_Write_Success(t *testing.T) {
dir, s := newTestFileStore(t)
size, err := s.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
require.Equal(t, int64(11), s.Size())
require.Equal(t, int64(10229), s.Remaining())
}
func TestFileStore_Write_Read_Success(t *testing.T) {
_, s := newTestFileStore(t)
size, err := s.Write("abcdefghijkl", strings.NewReader("hello world"))
require.Nil(t, err)
require.Equal(t, int64(11), size)
reader, readSize, err := s.Read("abcdefghijkl")
require.Nil(t, err)
require.Equal(t, int64(11), readSize)
defer reader.Close()
data, err := io.ReadAll(reader)
require.Nil(t, err)
require.Equal(t, "hello world", string(data))
}
func TestFileStore_Write_Remove_Success(t *testing.T) {
dir, s := newTestFileStore(t) // max = 10k (10240), each = 1k (1024)
for i := 0; i < 10; i++ { // 10x999 = 9990
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
require.Nil(t, err)
require.Equal(t, int64(999), size)
}
require.Equal(t, int64(9990), s.Size())
require.Equal(t, int64(250), s.Remaining())
require.FileExists(t, dir+"/abcdefghijk1")
require.FileExists(t, dir+"/abcdefghijk5")
require.Nil(t, s.Remove("abcdefghijk1", "abcdefghijk5"))
require.NoFileExists(t, dir+"/abcdefghijk1")
require.NoFileExists(t, dir+"/abcdefghijk5")
require.Equal(t, int64(7992), s.Size())
require.Equal(t, int64(2248), s.Remaining())
}
func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) {
dir, s := newTestFileStore(t)
for i := 0; i < 10; i++ {
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
require.Nil(t, err)
require.Equal(t, int64(1024), size)
}
_, err := s.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkX")
}
func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) {
dir, s := newTestFileStore(t)
_, err := s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkl")
}
func TestFileStore_Read_NotFound(t *testing.T) {
_, s := newTestFileStore(t)
_, _, err := s.Read("abcdefghijkl")
require.Error(t, err)
}
func newTestFileStore(t *testing.T) (dir string, store Store) {
dir = t.TempDir()
store, err := NewFileStore(dir, 10*1024)
require.Nil(t, err)
return dir, store
}
func readFile(t *testing.T, f string) string {
b, err := os.ReadFile(f)
require.Nil(t, err)
return string(b)
}

attachment/store_s3.go · new file · 150 additions

@@ -0,0 +1,150 @@
package attachment

import (
	"context"
	"fmt"
	"io"
	"sync"

	"heckel.io/ntfy/v2/log"
	"heckel.io/ntfy/v2/s3"
	"heckel.io/ntfy/v2/util"
)

const (
	tagS3Store = "s3_store"
)

type s3Store struct {
	client           *s3.Client
	totalSizeCurrent int64
	totalSizeLimit   int64
	mu               sync.Mutex
}

// NewS3Store creates a new S3-backed attachment store. The s3URL must be in the format:
//
//	s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
func NewS3Store(s3URL string, totalSizeLimit int64) (Store, error) {
	cfg, err := s3.ParseURL(s3URL)
	if err != nil {
		return nil, err
	}
	store := &s3Store{
		client:         s3.New(cfg),
		totalSizeLimit: totalSizeLimit,
	}
	if totalSizeLimit > 0 {
		size, err := store.computeSize()
		if err != nil {
			return nil, fmt.Errorf("s3 store: failed to compute initial size: %w", err)
		}
		store.totalSizeCurrent = size
	}
	return store, nil
}

func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
	if !fileIDRegex.MatchString(id) {
		return 0, errInvalidFileID
	}
	log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3")
	// Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked
	// uploads, so no temp file or Content-Length is needed.
	limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
	pr, pw := io.Pipe()
	lw := util.NewLimitWriter(pw, limiters...)
	var size int64
	var copyErr error
	done := make(chan struct{})
	go func() {
		defer close(done)
		size, copyErr = io.Copy(lw, in)
		if copyErr != nil {
			pw.CloseWithError(copyErr)
		} else {
			pw.Close()
		}
	}()
	putErr := c.client.PutObject(context.Background(), id, pr)
	pr.Close()
	<-done
	if copyErr != nil {
		return 0, copyErr
	}
	if putErr != nil {
		return 0, putErr
	}
	c.mu.Lock()
	c.totalSizeCurrent += size
	c.mu.Unlock()
	return size, nil
}

func (c *s3Store) Read(id string) (io.ReadCloser, int64, error) {
	if !fileIDRegex.MatchString(id) {
		return nil, 0, errInvalidFileID
	}
	return c.client.GetObject(context.Background(), id)
}

func (c *s3Store) Remove(ids ...string) error {
	for _, id := range ids {
		if !fileIDRegex.MatchString(id) {
			return errInvalidFileID
		}
	}
	// S3 DeleteObjects supports up to 1000 keys per call
	for i := 0; i < len(ids); i += 1000 {
		end := i + 1000
		if end > len(ids) {
			end = len(ids)
		}
		batch := ids[i:end]
		for _, id := range batch {
			log.Tag(tagS3Store).Field("message_id", id).Debug("Deleting attachment from S3")
		}
		if err := c.client.DeleteObjects(context.Background(), batch); err != nil {
			return err
		}
	}
	// Recalculate totalSizeCurrent via ListObjectsV2 (matches fileStore's dirSize rescan pattern)
	size, err := c.computeSize()
	if err != nil {
		return fmt.Errorf("s3 store: failed to compute size after remove: %w", err)
	}
	c.mu.Lock()
	c.totalSizeCurrent = size
	c.mu.Unlock()
	return nil
}

func (c *s3Store) Size() int64 {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.totalSizeCurrent
}

func (c *s3Store) Remaining() int64 {
	c.mu.Lock()
	defer c.mu.Unlock()
	remaining := c.totalSizeLimit - c.totalSizeCurrent
	if remaining < 0 {
		return 0
	}
	return remaining
}

// computeSize uses ListAllObjects to sum up the total size of all objects with our prefix.
func (c *s3Store) computeSize() (int64, error) {
	objects, err := c.client.ListAllObjects(context.Background())
	if err != nil {
		return 0, err
	}
	var totalSize int64
	for _, obj := range objects {
		totalSize += obj.Size
	}
	return totalSize, nil
}

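The `Write` method above pipes the incoming reader through the limiters into `PutObject`. The control flow (writer goroutine, `CloseWithError`, and checking `copyErr` before the consumer's error) is easier to see in isolation. A stdlib-only sketch (editorial; the S3 client and `util.NewLimitWriter` are replaced by stand-ins):

``` go
package main

import (
	"fmt"
	"io"
	"strings"
)

// streamThrough mirrors the s3Store.Write pattern: copy `in` into a pipe on
// a goroutine while a consumer (PutObject in the real code) reads from the
// other end. A write-side error is surfaced via CloseWithError so the
// consumer fails too, and it takes precedence over the consumer's error.
func streamThrough(in io.Reader, consume func(io.Reader) error) (int64, error) {
	pr, pw := io.Pipe()
	var size int64
	var copyErr error
	done := make(chan struct{})
	go func() {
		defer close(done)
		size, copyErr = io.Copy(pw, in) // the real code wraps pw in util.NewLimitWriter
		if copyErr != nil {
			pw.CloseWithError(copyErr) // reader sees copyErr instead of io.EOF
		} else {
			pw.Close()
		}
	}()
	consumeErr := consume(pr)
	pr.Close() // unblocks the writer if the consumer bailed out early
	<-done
	if copyErr != nil {
		return 0, copyErr
	}
	return size, consumeErr
}

func main() {
	n, err := streamThrough(strings.NewReader("hello world"), func(r io.Reader) error {
		_, err := io.Copy(io.Discard, r)
		return err
	})
	fmt.Println(n, err) // 11 <nil>
}
```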
attachment/store_s3_test.go · new file · 367 additions

@@ -0,0 +1,367 @@
package attachment
import (
"bytes"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/s3"
"heckel.io/ntfy/v2/util"
)
// --- Integration tests using a mock S3 server ---
func TestS3Store_WriteReadRemove(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
// Write
size, err := store.Write("abcdefghijkl", strings.NewReader("hello world"))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, int64(11), store.Size())
// Read back
reader, readSize, err := store.Read("abcdefghijkl")
require.Nil(t, err)
require.Equal(t, int64(11), readSize)
data, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, "hello world", string(data))
// Remove
require.Nil(t, store.Remove("abcdefghijkl"))
require.Equal(t, int64(0), store.Size())
// Read after remove should fail
_, _, err = store.Read("abcdefghijkl")
require.Error(t, err)
}
func TestS3Store_WriteNoPrefix(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "", 10*1024)
size, err := store.Write("abcdefghijkl", strings.NewReader("test"))
require.Nil(t, err)
require.Equal(t, int64(4), size)
reader, _, err := store.Read("abcdefghijkl")
require.Nil(t, err)
data, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, "test", string(data))
}
func TestS3Store_WriteTotalSizeLimit(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 100)
// First write fits
_, err := store.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)))
require.Nil(t, err)
require.Equal(t, int64(80), store.Size())
require.Equal(t, int64(20), store.Remaining())
// Second write exceeds total limit
_, err = store.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)))
require.Equal(t, util.ErrLimitReached, err)
}
func TestS3Store_WriteFileSizeLimit(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, err := store.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100))
require.Equal(t, util.ErrLimitReached, err)
}
func TestS3Store_WriteRemoveMultiple(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
for i := 0; i < 5; i++ {
_, err := store.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)))
require.Nil(t, err)
}
require.Equal(t, int64(500), store.Size())
require.Nil(t, store.Remove("abcdefghijk1", "abcdefghijk3"))
require.Equal(t, int64(300), store.Size())
}
func TestS3Store_ReadNotFound(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, _, err := store.Read("abcdefghijkl")
require.Error(t, err)
}
func TestS3Store_InvalidID(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, err := store.Write("bad", strings.NewReader("x"))
require.Equal(t, errInvalidFileID, err)
_, _, err = store.Read("bad")
require.Equal(t, errInvalidFileID, err)
err = store.Remove("bad")
require.Equal(t, errInvalidFileID, err)
}
// --- Helpers ---
func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) Store {
t.Helper()
// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
host := strings.TrimPrefix(server.URL, "https://")
s := &s3Store{
client: &s3.Client{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: host,
Bucket: bucket,
Prefix: prefix,
PathStyle: true,
HTTPClient: server.Client(),
},
totalSizeLimit: totalSizeLimit,
}
// Compute initial size (should be 0 for fresh mock)
size, err := s.computeSize()
require.Nil(t, err)
s.totalSizeCurrent = size
return s
}
// --- Mock S3 server ---
//
// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
type mockS3Server struct {
objects map[string][]byte // full key (bucket/key) -> body
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
nextID int // counter for generating upload IDs
mu sync.RWMutex
}
func newMockS3Server() *httptest.Server {
m := &mockS3Server{
objects: make(map[string][]byte),
uploads: make(map[string]map[int][]byte),
}
return httptest.NewTLSServer(m)
}
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Path is /{bucket}[/{key...}]
path := strings.TrimPrefix(r.URL.Path, "/")
q := r.URL.Query()
switch {
case r.Method == http.MethodPut && q.Has("partNumber"):
m.handleUploadPart(w, r, path)
case r.Method == http.MethodPut:
m.handlePut(w, r, path)
case r.Method == http.MethodPost && q.Has("uploads"):
m.handleInitiateMultipart(w, r, path)
case r.Method == http.MethodPost && q.Has("uploadId"):
m.handleCompleteMultipart(w, r, path)
case r.Method == http.MethodDelete && q.Has("uploadId"):
m.handleAbortMultipart(w, r, path)
case r.Method == http.MethodGet && q.Get("list-type") == "2":
m.handleList(w, r, path)
case r.Method == http.MethodGet:
m.handleGet(w, r, path)
case r.Method == http.MethodPost && q.Has("delete"):
m.handleDelete(w, r, path)
default:
http.Error(w, "not implemented", http.StatusNotImplemented)
}
}
func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
m.mu.Lock()
m.objects[path] = body
m.mu.Unlock()
w.WriteHeader(http.StatusOK)
}
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
m.mu.Lock()
m.nextID++
uploadID := fmt.Sprintf("upload-%d", m.nextID)
m.uploads[uploadID] = make(map[int][]byte)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
}
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
var partNumber int
fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
parts[partNumber] = body
m.mu.Unlock()
etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
w.Header().Set("ETag", etag)
w.WriteHeader(http.StatusOK)
}
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
// Assemble parts in order
var assembled []byte
for i := 1; i <= len(parts); i++ {
assembled = append(assembled, parts[i]...)
}
m.objects[path] = assembled
delete(m.uploads, uploadID)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
}
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
delete(m.uploads, uploadID)
m.mu.Unlock()
w.WriteHeader(http.StatusNoContent)
}
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
m.mu.RLock()
body, ok := m.objects[path]
m.mu.RUnlock()
if !ok {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`))
return
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
w.WriteHeader(http.StatusOK)
w.Write(body)
}
func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) {
// bucketPath is just the bucket name
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var req struct {
Objects []struct {
Key string `xml:"Key"`
} `xml:"Object"`
}
if err := xml.Unmarshal(body, &req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
m.mu.Lock()
for _, obj := range req.Objects {
delete(m.objects, bucketPath+"/"+obj.Key)
}
m.mu.Unlock()
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></DeleteResult>`))
}
func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) {
prefix := r.URL.Query().Get("prefix")
m.mu.RLock()
var contents []s3ListObject
for key, body := range m.objects {
// key is "bucket/objectkey", strip bucket prefix
objKey := strings.TrimPrefix(key, bucketPath+"/")
if objKey == key {
continue // different bucket
}
if prefix == "" || strings.HasPrefix(objKey, prefix) {
contents = append(contents, s3ListObject{Key: objKey, Size: int64(len(body))})
}
}
m.mu.RUnlock()
resp := s3ListResponse{
Contents: contents,
IsTruncated: false,
}
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
xml.NewEncoder(w).Encode(resp)
}
type s3ListResponse struct {
XMLName xml.Name `xml:"ListBucketResult"`
Contents []s3ListObject `xml:"Contents"`
IsTruncated bool `xml:"IsTruncated"`
}
type s3ListObject struct {
Key string `xml:"Key"`
Size int64 `xml:"Size"`
}


@@ -53,6 +53,7 @@ var flagsServe = append(
 	altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-access", Aliases: []string{"auth_access"}, EnvVars: []string{"NTFY_AUTH_ACCESS"}, Usage: "pre-provisioned declarative access control entries"}),
 	altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-tokens", Aliases: []string{"auth_tokens"}, EnvVars: []string{"NTFY_AUTH_TOKENS"}, Usage: "pre-provisioned declarative access tokens"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}),
+	altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-s3-url", Aliases: []string{"attachment_s3_url"}, EnvVars: []string{"NTFY_ATTACHMENT_S3_URL"}, Usage: "S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION)"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-total-size-limit", Aliases: []string{"attachment_total_size_limit", "A"}, EnvVars: []string{"NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentTotalSizeLimit), Usage: "limit of the on-disk attachment cache"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-file-size-limit", Aliases: []string{"attachment_file_size_limit", "Y"}, EnvVars: []string{"NTFY_ATTACHMENT_FILE_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentFileSizeLimit), Usage: "per-file attachment size limit (e.g. 300k, 2M, 100M)"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-expiry-duration", Aliases: []string{"attachment_expiry_duration", "X"}, EnvVars: []string{"NTFY_ATTACHMENT_EXPIRY_DURATION"}, Value: util.FormatDuration(server.DefaultAttachmentExpiryDuration), Usage: "duration after which uploaded attachments will be deleted (e.g. 3h, 20h)"}),
@@ -166,6 +167,7 @@ func execServe(c *cli.Context) error {
 	authAccessRaw := c.StringSlice("auth-access")
 	authTokensRaw := c.StringSlice("auth-tokens")
 	attachmentCacheDir := c.String("attachment-cache-dir")
+	attachmentS3URL := c.String("attachment-s3-url")
 	attachmentTotalSizeLimitStr := c.String("attachment-total-size-limit")
 	attachmentFileSizeLimitStr := c.String("attachment-file-size-limit")
 	attachmentExpiryDurationStr := c.String("attachment-expiry-duration")
@@ -314,6 +316,10 @@ func execServe(c *cli.Context) error {
return errors.New("if smtp-server-listen is set, smtp-server-domain must also be set")
} else if attachmentCacheDir != "" && baseURL == "" {
return errors.New("if attachment-cache-dir is set, base-url must also be set")
} else if attachmentS3URL != "" && baseURL == "" {
return errors.New("if attachment-s3-url is set, base-url must also be set")
} else if attachmentS3URL != "" && attachmentCacheDir != "" {
return errors.New("attachment-cache-dir and attachment-s3-url are mutually exclusive")
} else if baseURL != "" {
u, err := url.Parse(baseURL)
if err != nil {
@@ -457,6 +463,7 @@ func execServe(c *cli.Context) error {
 	conf.AuthAccess = authAccess
 	conf.AuthTokens = authTokens
 	conf.AttachmentCacheDir = attachmentCacheDir
+	conf.AttachmentS3URL = attachmentS3URL
 	conf.AttachmentTotalSizeLimit = attachmentTotalSizeLimit
 	conf.AttachmentFileSizeLimit = attachmentFileSizeLimit
 	conf.AttachmentExpiryDuration = attachmentExpiryDuration


@@ -489,20 +489,23 @@ Subscribers can retrieve cached messaging using the [`poll=1` parameter](subscri
 ## Attachments
 
 If desired, you may allow users to upload and [attach files to notifications](publish.md#attachments). To enable
-this feature, you have to simply configure an attachment cache directory and a base URL (`attachment-cache-dir`, `base-url`).
-Once these options are set and the directory is writable by the server user, you can upload attachments via PUT.
+this feature, you have to configure an attachment storage backend and a base URL (`base-url`). Attachments can be stored
+either on the local filesystem (`attachment-cache-dir`) or in an S3-compatible object store (`attachment-s3-url`).
+Once configured, you can upload attachments via PUT.
 
-By default, attachments are stored in the disk-cache **for only 3 hours**. The main reason for this is to avoid legal issues
-and such when hosting user controlled content. Typically, this is more than enough time for the user (or the auto download
+By default, attachments are stored **for only 3 hours**. The main reason for this is to avoid legal issues
+and such when hosting user controlled content. Typically, this is more than enough time for the user (or the auto download
 feature) to download the file. The following config options are relevant to attachments:
 
 * `base-url` is the root URL for the ntfy server; this is needed for the generated attachment URLs
-* `attachment-cache-dir` is the cache directory for attached files
-* `attachment-total-size-limit` is the size limit of the on-disk attachment cache (default: 5G)
+* `attachment-cache-dir` is the cache directory for attached files (mutually exclusive with `attachment-s3-url`)
+* `attachment-s3-url` is the S3 URL for attachment storage (mutually exclusive with `attachment-cache-dir`)
+* `attachment-total-size-limit` is the size limit of the attachment storage (default: 5G)
 * `attachment-file-size-limit` is the per-file attachment size limit (e.g. 300k, 2M, 100M, default: 15M)
 * `attachment-expiry-duration` is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h, default: 3h)
 
-Here's an example config using mostly the defaults (except for the cache directory, which is empty by default):
+### Filesystem storage
+Here's an example config using the local filesystem for attachment storage:
 
 === "/etc/ntfy/server.yml (minimal)"
     ``` yaml
@@ -521,6 +524,30 @@ Here's an example config using mostly the defaults (except for the cache directo
         visitor-attachment-daily-bandwidth-limit: "500M"
     ```
 
+### S3 storage
+
+As an alternative to the local filesystem, you can store attachments in an S3-compatible object store (e.g. AWS S3,
+MinIO, DigitalOcean Spaces). This is useful for HA/cloud deployments where you don't want to rely on local disk storage.
+The `attachment-s3-url` option uses the following format:
+
+```
+s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
+```
+
+When `endpoint` is specified, path-style addressing is enabled automatically (useful for MinIO and other S3-compatible stores).
+
+=== "/etc/ntfy/server.yml (AWS S3)"
+    ``` yaml
+    base-url: "https://ntfy.sh"
+    attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1"
+    ```
+
+=== "/etc/ntfy/server.yml (MinIO/custom endpoint)"
+    ``` yaml
+    base-url: "https://ntfy.sh"
+    attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com"
+    ```
+
 Please also refer to the [rate limiting](#rate-limiting) settings below, specifically `visitor-attachment-total-size-limit`
 and `visitor-attachment-daily-bandwidth-limit`. Setting these conservatively is necessary to avoid abuse.
@@ -2116,7 +2143,8 @@ variable before running the `ntfy` command (e.g. `export NTFY_LISTEN_HTTP=:80`).
 | `behind-proxy` | `NTFY_BEHIND_PROXY` | *bool* | false | If set, use forwarded header (e.g. X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting) |
 | `proxy-forwarded-header` | `NTFY_PROXY_FORWARDED_HEADER` | *string* | `X-Forwarded-For` | Use specified header to determine visitor IP address (for rate limiting) |
 | `proxy-trusted-hosts` | `NTFY_PROXY_TRUSTED_HOSTS` | *comma-separated host/IP/CIDR list* | - | Comma-separated list of trusted IP addresses, hosts, or CIDRs to remove from forwarded header |
-| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory* | - | Cache directory for attached files. To enable attachments, this has to be set. |
+| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory* | - | Cache directory for attached files. Mutually exclusive with `attachment-s3-url`. |
+| `attachment-s3-url` | `NTFY_ATTACHMENT_S3_URL` | *URL* | - | S3 URL for attachment storage (format: `s3://KEY:SECRET@BUCKET[/PREFIX]?region=REGION`). Mutually exclusive with `attachment-cache-dir`. |
 | `attachment-total-size-limit` | `NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT` | *size* | 5G | Limit of the on-disk attachment cache directory. If the limit is exceeded, new attachments will be rejected. |
 | `attachment-file-size-limit` | `NTFY_ATTACHMENT_FILE_SIZE_LIMIT` | *size* | 15M | Per-file attachment size limit (e.g. 300k, 2M, 100M). Larger attachments will be rejected. |
 | `attachment-expiry-duration` | `NTFY_ATTACHMENT_EXPIRY_DURATION` | *duration* | 3h | Duration after which uploaded attachments will be deleted (e.g. 3h, 20h). Strongly affects `visitor-attachment-total-size-limit`. |
@@ -2219,6 +2247,7 @@ OPTIONS:
 --auth-startup-queries value, --auth_startup_queries value  queries run when the auth database is initialized [$NTFY_AUTH_STARTUP_QUERIES]
 --auth-default-access value, --auth_default_access value, -p value  default permissions if no matching entries in the auth database are found (default: "read-write") [$NTFY_AUTH_DEFAULT_ACCESS]
 --attachment-cache-dir value, --attachment_cache_dir value  cache directory for attached files [$NTFY_ATTACHMENT_CACHE_DIR]
+--attachment-s3-url value, --attachment_s3_url value  S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION) [$NTFY_ATTACHMENT_S3_URL]
 --attachment-total-size-limit value, --attachment_total_size_limit value, -A value  limit of the on-disk attachment cache (default: "5G") [$NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT]
 --attachment-file-size-limit value, --attachment_file_size_limit value, -Y value  per-file attachment size limit (e.g. 300k, 2M, 100M) (default: "15M") [$NTFY_ATTACHMENT_FILE_SIZE_LIMIT]
 --attachment-expiry-duration value, --attachment_expiry_duration value, -X value  duration after which uploaded attachments will be deleted (e.g. 3h, 20h) (default: "3h") [$NTFY_ATTACHMENT_EXPIRY_DURATION]

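Once either backend is configured, clients publish attachments with a plain HTTP PUT, as the docs above note. For illustration, a hedged Go sketch (the server URL, topic, and filename are placeholders; the `Filename` header follows ntfy's publish API, not anything introduced in this patch):

``` go
package main

import (
	"fmt"
	"log"
	"net/http"
	"os"
)

func main() {
	f, err := os.Open("flower.jpg")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()
	// The PUT body becomes the attachment; the Filename header sets its name
	req, err := http.NewRequest(http.MethodPut, "https://ntfy.example.com/mytopic", f)
	if err != nil {
		log.Fatal(err)
	}
	req.Header.Set("Filename", "flower.jpg")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status) // 200 OK if within the size and bandwidth limits
}
```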

@@ -1798,4 +1798,12 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
 ## Not released yet
 
-Nothing.
+### ntfy server v2.20.x (UNRELEASED)
+
+**Features:**
+
+* Add S3-compatible object storage as an alternative attachment backend via `attachment-s3-url` config option
+
+**Bug fixes + maintenance:**
+
+* Reject invalid e-mail addresses (e.g. multiple comma-separated recipients) with HTTP 400

s3/client.go · new file · 488 additions

@@ -0,0 +1,488 @@
// Package s3 provides a minimal S3-compatible client that works with AWS S3, DigitalOcean Spaces,
// GCP Cloud Storage, MinIO, Backblaze B2, and other S3-compatible providers. It uses raw HTTP
// requests with AWS Signature V4 signing, no AWS SDK dependency required.
package s3
import (
"bytes"
"context"
"crypto/md5" //nolint:gosec // MD5 is required by the S3 protocol for Content-MD5 headers
"encoding/base64"
"encoding/hex"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
)
// Client is a minimal S3-compatible client. It supports PutObject, GetObject, DeleteObjects,
// and ListObjectsV2 operations using AWS Signature V4 signing. The bucket and optional key prefix
// are fixed at construction time. All operations target the same bucket and prefix.
//
// Fields must not be modified after the Client is passed to any method or goroutine.
type Client struct {
AccessKey string // AWS access key ID
SecretKey string // AWS secret access key
Region string // e.g. "us-east-1"
Endpoint string // host[:port] only, e.g. "s3.amazonaws.com" or "nyc3.digitaloceanspaces.com"
Bucket string // S3 bucket name
Prefix string // optional key prefix (e.g. "attachments"); prepended to all keys automatically
PathStyle bool // if true, use path-style addressing; otherwise virtual-hosted-style
HTTPClient *http.Client // if nil, http.DefaultClient is used
}
// New creates a new S3 client from the given Config.
func New(config *Config) *Client {
return &Client{
AccessKey: config.AccessKey,
SecretKey: config.SecretKey,
Region: config.Region,
Endpoint: config.Endpoint,
Bucket: config.Bucket,
Prefix: config.Prefix,
PathStyle: config.PathStyle,
}
}
// PutObject uploads body to the given key. The key is automatically prefixed with the client's
// configured prefix. The body size does not need to be known in advance.
//
// If the entire body fits in a single part (5 MB), it is uploaded with a simple PUT request.
// Otherwise, the body is uploaded using S3 multipart upload, reading one part at a time
// into memory.
func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error {
first := make([]byte, partSize)
n, err := io.ReadFull(body, first)
if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF {
return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n))
}
if err != nil {
return fmt.Errorf("s3: PutObject read: %w", err)
}
combined := io.MultiReader(bytes.NewReader(first), body)
return c.putObjectMultipart(ctx, key, combined)
}
// GetObject downloads an object. The key is automatically prefixed with the client's configured
// prefix. The caller must close the returned ReadCloser.
func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int64, error) {
fullKey := c.objectKey(key)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.objectURL(fullKey), nil)
if err != nil {
return nil, 0, fmt.Errorf("s3: GetObject request: %w", err)
}
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return nil, 0, fmt.Errorf("s3: GetObject: %w", err)
}
if resp.StatusCode/100 != 2 {
err := parseError(resp)
resp.Body.Close()
return nil, 0, err
}
return resp.Body, resp.ContentLength, nil
}
// DeleteObjects removes multiple objects in a single batch request. Keys are automatically
// prefixed with the client's configured prefix. S3 supports up to 1000 keys per call; the
// caller is responsible for batching if needed.
//
// Even when S3 returns HTTP 200, individual keys may fail. If any per-key errors are present
// in the response, they are returned as a combined error.
func (c *Client) DeleteObjects(ctx context.Context, keys []string) error {
var body bytes.Buffer
body.WriteString("<Delete><Quiet>true</Quiet>")
for _, key := range keys {
body.WriteString("<Object><Key>")
xml.EscapeText(&body, []byte(c.objectKey(key)))
body.WriteString("</Key></Object>")
}
body.WriteString("</Delete>")
bodyBytes := body.Bytes()
payloadHash := sha256Hex(bodyBytes)
// Content-MD5 is required by the S3 protocol for DeleteObjects requests.
md5Sum := md5.Sum(bodyBytes) //nolint:gosec
contentMD5 := base64.StdEncoding.EncodeToString(md5Sum[:])
reqURL := c.bucketURL() + "?delete="
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
if err != nil {
return fmt.Errorf("s3: DeleteObjects request: %w", err)
}
req.ContentLength = int64(len(bodyBytes))
req.Header.Set("Content-Type", "application/xml")
req.Header.Set("Content-MD5", contentMD5)
c.signV4(req, payloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return fmt.Errorf("s3: DeleteObjects: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return parseError(resp)
}
// S3 may return HTTP 200 with per-key errors in the response body
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return fmt.Errorf("s3: DeleteObjects read response: %w", err)
}
var result deleteResult
if err := xml.Unmarshal(respBody, &result); err != nil {
return nil // If we can't parse, assume success (Quiet mode returns empty body on success)
}
if len(result.Errors) > 0 {
var msgs []string
for _, e := range result.Errors {
msgs = append(msgs, fmt.Sprintf("%s: %s", e.Key, e.Message))
}
return fmt.Errorf("s3: DeleteObjects partial failure: %s", strings.Join(msgs, "; "))
}
return nil
}
// ListObjects performs a single ListObjectsV2 request using the client's configured prefix.
// Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000).
func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxKeys int) (*ListResult, error) {
query := url.Values{"list-type": {"2"}}
if prefix := c.prefixForList(); prefix != "" {
query.Set("prefix", prefix)
}
if continuationToken != "" {
query.Set("continuation-token", continuationToken)
}
if maxKeys > 0 {
query.Set("max-keys", strconv.Itoa(maxKeys))
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil)
if err != nil {
return nil, fmt.Errorf("s3: ListObjects request: %w", err)
}
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("s3: ListObjects: %w", err)
}
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
resp.Body.Close()
if err != nil {
return nil, fmt.Errorf("s3: ListObjects read: %w", err)
}
if resp.StatusCode/100 != 2 {
return nil, parseErrorFromBytes(resp.StatusCode, respBody)
}
var result listObjectsV2Response
if err := xml.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("s3: ListObjects XML: %w", err)
}
objects := make([]Object, len(result.Contents))
for i, obj := range result.Contents {
objects[i] = Object(obj)
}
return &ListResult{
Objects: objects,
IsTruncated: result.IsTruncated,
NextContinuationToken: result.NextContinuationToken,
}, nil
}
// ListAllObjects returns all objects under the client's configured prefix by paginating through
// ListObjectsV2 results automatically. It stops after 10,000 pages as a safety valve.
func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) {
const maxPages = 10000
var all []Object
var token string
for page := 0; page < maxPages; page++ {
result, err := c.ListObjects(ctx, token, 0)
if err != nil {
return nil, err
}
all = append(all, result.Objects...)
if !result.IsTruncated {
return all, nil
}
token = result.NextContinuationToken
}
return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages)
}
// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD.
func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error {
fullKey := c.objectKey(key)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body)
if err != nil {
return fmt.Errorf("s3: PutObject request: %w", err)
}
req.ContentLength = size
c.signV4(req, unsignedPayload)
resp, err := c.httpClient().Do(req)
if err != nil {
return fmt.Errorf("s3: PutObject: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return parseError(resp)
}
return nil
}
// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize
// chunks, uploading each as a separate part. This allows uploading without knowing the total
// body size in advance.
func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error {
fullKey := c.objectKey(key)
// Step 1: Initiate multipart upload
uploadID, err := c.initiateMultipartUpload(ctx, fullKey)
if err != nil {
return err
}
// Step 2: Upload parts
var parts []completedPart
buf := make([]byte, partSize)
partNumber := 1
for {
n, err := io.ReadFull(body, buf)
if n > 0 {
etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n])
if uploadErr != nil {
c.abortMultipartUpload(ctx, fullKey, uploadID)
return uploadErr
}
parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag})
partNumber++
}
if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
break
}
if err != nil {
c.abortMultipartUpload(ctx, fullKey, uploadID)
return fmt.Errorf("s3: PutObject read: %w", err)
}
}
// Step 3: Complete multipart upload
return c.completeMultipartUpload(ctx, fullKey, uploadID, parts)
}
// initiateMultipartUpload starts a new multipart upload and returns the upload ID.
func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) {
reqURL := c.objectURL(fullKey) + "?uploads"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil)
if err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err)
}
req.ContentLength = 0
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return "", parseError(resp)
}
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err)
}
var result initiateMultipartUploadResult
if err := xml.Unmarshal(respBody, &result); err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err)
}
return result.UploadID, nil
}
// uploadPart uploads a single part of a multipart upload and returns the ETag.
func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) {
reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data))
if err != nil {
return "", fmt.Errorf("s3: UploadPart request: %w", err)
}
req.ContentLength = int64(len(data))
c.signV4(req, unsignedPayload)
resp, err := c.httpClient().Do(req)
if err != nil {
return "", fmt.Errorf("s3: UploadPart: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return "", parseError(resp)
}
etag := resp.Header.Get("ETag")
return etag, nil
}
// completeMultipartUpload finalizes a multipart upload with the given parts.
func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error {
var body bytes.Buffer
body.WriteString("<CompleteMultipartUpload>")
for _, p := range parts {
fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag)
}
body.WriteString("</CompleteMultipartUpload>")
bodyBytes := body.Bytes()
payloadHash := sha256Hex(bodyBytes)
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
if err != nil {
return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err)
}
req.ContentLength = int64(len(bodyBytes))
req.Header.Set("Content-Type", "application/xml")
c.signV4(req, payloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return fmt.Errorf("s3: CompleteMultipartUpload: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return parseError(resp)
}
// Read response body to check for errors (S3 can return 200 with an error body)
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err)
}
// Check if the response contains an error
var errResp ErrorResponse
if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
errResp.StatusCode = resp.StatusCode
return &errResp
}
return nil
}
// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) {
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
if err != nil {
return
}
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return
}
resp.Body.Close()
}
// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256
// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads.
func (c *Client) signV4(req *http.Request, payloadHash string) {
now := time.Now().UTC()
datestamp := now.Format("20060102")
amzDate := now.Format("20060102T150405Z")
// Required headers
req.Header.Set("Host", c.hostHeader())
req.Header.Set("X-Amz-Date", amzDate)
req.Header.Set("X-Amz-Content-Sha256", payloadHash)
// Canonical headers (all headers we set, sorted by lowercase key)
signedKeys := make([]string, 0, len(req.Header))
canonHeaders := make(map[string]string, len(req.Header))
for k := range req.Header {
lk := strings.ToLower(k)
signedKeys = append(signedKeys, lk)
canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k))
}
sort.Strings(signedKeys)
signedHeadersStr := strings.Join(signedKeys, ";")
var chBuf strings.Builder
for _, k := range signedKeys {
chBuf.WriteString(k)
chBuf.WriteByte(':')
chBuf.WriteString(canonHeaders[k])
chBuf.WriteByte('\n')
}
// Canonical request
canonicalRequest := strings.Join([]string{
req.Method,
canonicalURI(req.URL),
canonicalQueryString(req.URL.Query()),
chBuf.String(),
signedHeadersStr,
payloadHash,
}, "\n")
// String to sign
credentialScope := datestamp + "/" + c.Region + "/s3/aws4_request"
stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest))
// Signing key
signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256(
[]byte("AWS4"+c.SecretKey), []byte(datestamp)),
[]byte(c.Region)),
[]byte("s3")),
[]byte("aws4_request"))
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
req.Header.Set("Authorization", fmt.Sprintf(
"AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
c.AccessKey, credentialScope, signedHeadersStr, signature,
))
}
func (c *Client) httpClient() *http.Client {
if c.HTTPClient != nil {
return c.HTTPClient
}
return http.DefaultClient
}
// objectKey prepends the configured prefix to the given key.
func (c *Client) objectKey(key string) string {
if c.Prefix != "" {
return c.Prefix + "/" + key
}
return key
}
// prefixForList returns the prefix to use in ListObjectsV2 requests,
// with a trailing slash so that only objects under the prefix directory are returned.
func (c *Client) prefixForList() string {
if c.Prefix != "" {
return c.Prefix + "/"
}
return ""
}
// bucketURL returns the base URL for bucket-level operations.
func (c *Client) bucketURL() string {
if c.PathStyle {
return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket)
}
return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint)
}
// objectURL returns the full URL for an object (key should already include the prefix).
// Each path segment is URI-encoded to handle special characters in keys.
func (c *Client) objectURL(key string) string {
segments := strings.Split(key, "/")
for i, seg := range segments {
segments[i] = uriEncode(seg)
}
return c.bucketURL() + "/" + strings.Join(segments, "/")
}
// hostHeader returns the value for the Host header.
func (c *Client) hostHeader() string {
if c.PathStyle {
return c.Endpoint
}
return c.Bucket + "." + c.Endpoint
}

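To show how the pieces above fit together, here is a short sketch (editorial, not part of this patch) that builds a `Client` via `ParseURL` and does a put/get/delete round trip. All functions are the ones defined in `s3/client.go` above; the credentials, bucket, and endpoint are made up. Because an `endpoint` is given, `ParseURL` enables path-style addressing, so the same snippet works against MinIO or another S3-compatible store:

``` go
package main

import (
	"context"
	"fmt"
	"io"
	"log"
	"strings"

	"heckel.io/ntfy/v2/s3"
)

func main() {
	cfg, err := s3.ParseURL("s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com")
	if err != nil {
		log.Fatal(err)
	}
	client := s3.New(cfg)
	ctx := context.Background()
	// Small bodies go through a single PUT; larger ones use multipart upload
	if err := client.PutObject(ctx, "hello-key", strings.NewReader("hello world")); err != nil {
		log.Fatal(err)
	}
	r, size, err := client.GetObject(ctx, "hello-key")
	if err != nil {
		log.Fatal(err)
	}
	defer r.Close()
	body, _ := io.ReadAll(r)
	fmt.Println(size, string(body))
	// DeleteObjects batches up to 1000 keys per call
	if err := client.DeleteObjects(ctx, []string{"hello-key"}); err != nil {
		log.Fatal(err)
	}
}
```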
s3/client_test.go · new file · 860 additions

@@ -0,0 +1,860 @@
package s3
import (
"bytes"
"context"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"sort"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/require"
)
// --- Mock S3 server ---
//
// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
type mockS3Server struct {
objects map[string][]byte // full key (bucket/key) -> body
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
nextID int // counter for generating upload IDs
mu sync.RWMutex
}
func newMockS3Server() (*httptest.Server, *mockS3Server) {
m := &mockS3Server{
objects: make(map[string][]byte),
uploads: make(map[string]map[int][]byte),
}
return httptest.NewTLSServer(m), m
}
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Path is /{bucket}[/{key...}]
path := strings.TrimPrefix(r.URL.Path, "/")
q := r.URL.Query()
switch {
case r.Method == http.MethodPut && q.Has("partNumber"):
m.handleUploadPart(w, r, path)
case r.Method == http.MethodPut:
m.handlePut(w, r, path)
case r.Method == http.MethodPost && q.Has("uploads"):
m.handleInitiateMultipart(w, r, path)
case r.Method == http.MethodPost && q.Has("uploadId"):
m.handleCompleteMultipart(w, r, path)
case r.Method == http.MethodDelete && q.Has("uploadId"):
m.handleAbortMultipart(w, r, path)
case r.Method == http.MethodGet && q.Get("list-type") == "2":
m.handleList(w, r, path)
case r.Method == http.MethodGet:
m.handleGet(w, r, path)
case r.Method == http.MethodPost && q.Has("delete"):
m.handleDelete(w, r, path)
default:
http.Error(w, "not implemented", http.StatusNotImplemented)
}
}
func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
m.mu.Lock()
m.objects[path] = body
m.mu.Unlock()
w.WriteHeader(http.StatusOK)
}
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
m.mu.Lock()
m.nextID++
uploadID := fmt.Sprintf("upload-%d", m.nextID)
m.uploads[uploadID] = make(map[int][]byte)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
}
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
var partNumber int
fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
parts[partNumber] = body
m.mu.Unlock()
etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
w.Header().Set("ETag", etag)
w.WriteHeader(http.StatusOK)
}
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
parts, ok := m.uploads[uploadID]
if !ok {
m.mu.Unlock()
http.Error(w, "NoSuchUpload", http.StatusNotFound)
return
}
// Assemble parts in order
var assembled []byte
for i := 1; i <= len(parts); i++ {
assembled = append(assembled, parts[i]...)
}
m.objects[path] = assembled
delete(m.uploads, uploadID)
m.mu.Unlock()
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
}
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
uploadID := r.URL.Query().Get("uploadId")
m.mu.Lock()
delete(m.uploads, uploadID)
m.mu.Unlock()
w.WriteHeader(http.StatusNoContent)
}
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
m.mu.RLock()
body, ok := m.objects[path]
m.mu.RUnlock()
if !ok {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`))
return
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
w.WriteHeader(http.StatusOK)
w.Write(body)
}
type listObjectsResponse struct {
XMLName xml.Name `xml:"ListBucketResult"`
Contents []listObject `xml:"Contents"`
// Pagination support
IsTruncated bool `xml:"IsTruncated"`
NextContinuationToken string `xml:"NextContinuationToken"`
}
func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) {
// bucketPath is just the bucket name
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var req struct {
Objects []struct {
Key string `xml:"Key"`
} `xml:"Object"`
}
if err := xml.Unmarshal(body, &req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
m.mu.Lock()
for _, obj := range req.Objects {
delete(m.objects, bucketPath+"/"+obj.Key)
}
m.mu.Unlock()
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?><DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></DeleteResult>`))
}
func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) {
prefix := r.URL.Query().Get("prefix")
contToken := r.URL.Query().Get("continuation-token")
m.mu.RLock()
var allKeys []string
for key := range m.objects {
objKey := strings.TrimPrefix(key, bucketPath+"/")
if objKey == key {
continue // different bucket
}
if prefix == "" || strings.HasPrefix(objKey, prefix) {
allKeys = append(allKeys, objKey)
}
}
m.mu.RUnlock()
sort.Strings(allKeys)
// Simple continuation token: it's the key to start after
startIdx := 0
if contToken != "" {
for i, k := range allKeys {
if k == contToken {
startIdx = i + 1
break
}
}
}
maxKeys := 1000
if mk := r.URL.Query().Get("max-keys"); mk != "" {
fmt.Sscanf(mk, "%d", &maxKeys)
}
endIdx := startIdx + maxKeys
truncated := false
nextToken := ""
if endIdx < len(allKeys) {
truncated = true
nextToken = allKeys[endIdx-1]
allKeys = allKeys[startIdx:endIdx]
} else {
allKeys = allKeys[startIdx:]
}
m.mu.RLock()
var contents []listObject
for _, objKey := range allKeys {
body := m.objects[bucketPath+"/"+objKey]
contents = append(contents, listObject{Key: objKey, Size: int64(len(body))})
}
m.mu.RUnlock()
resp := listObjectsResponse{
Contents: contents,
IsTruncated: truncated,
NextContinuationToken: nextToken,
}
w.Header().Set("Content-Type", "application/xml")
w.WriteHeader(http.StatusOK)
xml.NewEncoder(w).Encode(resp)
}
func (m *mockS3Server) objectCount() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.objects)
}
// --- Helper to create a test client pointing at mock server ---
func newTestClient(server *httptest.Server, bucket, prefix string) *Client {
// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
host := strings.TrimPrefix(server.URL, "https://")
return &Client{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: host,
Bucket: bucket,
Prefix: prefix,
PathStyle: true,
HTTPClient: server.Client(),
}
}
// --- URL parsing tests ---
func TestParseURL_Success(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/attachments?region=us-east-1")
require.Nil(t, err)
require.Equal(t, "my-bucket", cfg.Bucket)
require.Equal(t, "attachments", cfg.Prefix)
require.Equal(t, "us-east-1", cfg.Region)
require.Equal(t, "AKID", cfg.AccessKey)
require.Equal(t, "SECRET", cfg.SecretKey)
require.Equal(t, "s3.us-east-1.amazonaws.com", cfg.Endpoint)
require.False(t, cfg.PathStyle)
}
func TestParseURL_NoPrefix(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1")
require.Nil(t, err)
require.Equal(t, "my-bucket", cfg.Bucket)
require.Equal(t, "", cfg.Prefix)
}
func TestParseURL_WithEndpoint(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/prefix?region=us-east-1&endpoint=https://s3.example.com")
require.Nil(t, err)
require.Equal(t, "my-bucket", cfg.Bucket)
require.Equal(t, "prefix", cfg.Prefix)
require.Equal(t, "s3.example.com", cfg.Endpoint)
require.True(t, cfg.PathStyle)
}
func TestParseURL_EndpointHTTP(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1&endpoint=http://localhost:9000")
require.Nil(t, err)
require.Equal(t, "localhost:9000", cfg.Endpoint)
require.True(t, cfg.PathStyle)
}
func TestParseURL_EndpointTrailingSlash(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1&endpoint=https://s3.example.com/")
require.Nil(t, err)
require.Equal(t, "s3.example.com", cfg.Endpoint)
}
func TestParseURL_NestedPrefix(t *testing.T) {
cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/a/b/c?region=us-east-1")
require.Nil(t, err)
require.Equal(t, "my-bucket", cfg.Bucket)
require.Equal(t, "a/b/c", cfg.Prefix)
}
func TestParseURL_MissingRegion(t *testing.T) {
_, err := ParseURL("s3://AKID:SECRET@my-bucket")
require.Error(t, err)
require.Contains(t, err.Error(), "region")
}
func TestParseURL_MissingCredentials(t *testing.T) {
_, err := ParseURL("s3://my-bucket?region=us-east-1")
require.Error(t, err)
require.Contains(t, err.Error(), "access key")
}
func TestParseURL_MissingSecretKey(t *testing.T) {
_, err := ParseURL("s3://AKID@my-bucket?region=us-east-1")
require.Error(t, err)
require.Contains(t, err.Error(), "secret key")
}
func TestParseURL_WrongScheme(t *testing.T) {
_, err := ParseURL("http://AKID:SECRET@my-bucket?region=us-east-1")
require.Error(t, err)
require.Contains(t, err.Error(), "scheme")
}
func TestParseURL_EmptyBucket(t *testing.T) {
_, err := ParseURL("s3://AKID:SECRET@?region=us-east-1")
require.Error(t, err)
require.Contains(t, err.Error(), "bucket")
}
// --- Unit tests: URL construction ---
func TestClient_BucketURL_PathStyle(t *testing.T) {
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
require.Equal(t, "https://s3.example.com/my-bucket", c.bucketURL())
}
func TestClient_BucketURL_VirtualHosted(t *testing.T) {
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.bucketURL())
}
func TestClient_ObjectURL_PathStyle(t *testing.T) {
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.objectURL("prefix/obj"))
}
func TestClient_ObjectURL_VirtualHosted(t *testing.T) {
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.objectURL("prefix/obj"))
}
func TestClient_HostHeader_PathStyle(t *testing.T) {
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
require.Equal(t, "s3.example.com", c.hostHeader())
}
func TestClient_HostHeader_VirtualHosted(t *testing.T) {
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.hostHeader())
}
func TestClient_ObjectKey(t *testing.T) {
c := &Client{Prefix: "attachments"}
require.Equal(t, "attachments/file123", c.objectKey("file123"))
c2 := &Client{Prefix: ""}
require.Equal(t, "file123", c2.objectKey("file123"))
}
func TestClient_PrefixForList(t *testing.T) {
c := &Client{Prefix: "attachments"}
require.Equal(t, "attachments/", c.prefixForList())
c2 := &Client{Prefix: ""}
require.Equal(t, "", c2.prefixForList())
}
// --- Integration tests using mock S3 server ---
func TestClient_PutGetObject(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// Put
err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"))
require.Nil(t, err)
// Get
reader, size, err := client.GetObject(ctx, "test-key")
require.Nil(t, err)
require.Equal(t, int64(11), size)
data, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, "hello world", string(data))
}
func TestClient_PutGetObject_WithPrefix(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "pfx")
ctx := context.Background()
err := client.PutObject(ctx, "test-key", strings.NewReader("hello"))
require.Nil(t, err)
reader, _, err := client.GetObject(ctx, "test-key")
require.Nil(t, err)
data, _ := io.ReadAll(reader)
reader.Close()
require.Equal(t, "hello", string(data))
}
func TestClient_GetObject_NotFound(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
_, _, err := client.GetObject(context.Background(), "nonexistent")
require.Error(t, err)
var errResp *ErrorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 404, errResp.StatusCode)
require.Equal(t, "NoSuchKey", errResp.Code)
}
func TestClient_DeleteObjects(t *testing.T) {
server, mock := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// Put several objects
for i := 0; i < 5; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")))
require.Nil(t, err)
}
require.Equal(t, 5, mock.objectCount())
// Delete some
err := client.DeleteObjects(ctx, []string{"key-1", "key-3"})
require.Nil(t, err)
require.Equal(t, 3, mock.objectCount())
// Verify deleted ones are gone
_, _, err = client.GetObject(ctx, "key-1")
require.Error(t, err)
_, _, err = client.GetObject(ctx, "key-3")
require.Error(t, err)
// Verify remaining ones are still there
reader, _, err := client.GetObject(ctx, "key-0")
require.Nil(t, err)
reader.Close()
}
func TestClient_ListObjects(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
ctx := context.Background()
// Client with prefix "pfx": list should only return objects under pfx/
client := newTestClient(server, "my-bucket", "pfx")
for i := 0; i < 3; i++ {
err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
// Also put an object outside the prefix using a no-prefix client
clientNoPrefix := newTestClient(server, "my-bucket", "")
err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")))
require.Nil(t, err)
// List with prefix client: should only see 3
result, err := client.ListObjects(ctx, "", 0)
require.Nil(t, err)
require.Len(t, result.Objects, 3)
require.False(t, result.IsTruncated)
// List with no-prefix client: should see all 4
result, err = clientNoPrefix.ListObjects(ctx, "", 0)
require.Nil(t, err)
require.Len(t, result.Objects, 4)
}
func TestClient_ListObjects_Pagination(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// Put 5 objects
for i := 0; i < 5; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
// List with max-keys=2
result, err := client.ListObjects(ctx, "", 2)
require.Nil(t, err)
require.Len(t, result.Objects, 2)
require.True(t, result.IsTruncated)
require.NotEmpty(t, result.NextContinuationToken)
// Get next page
result2, err := client.ListObjects(ctx, result.NextContinuationToken, 2)
require.Nil(t, err)
require.Len(t, result2.Objects, 2)
require.True(t, result2.IsTruncated)
// Get last page
result3, err := client.ListObjects(ctx, result2.NextContinuationToken, 2)
require.Nil(t, err)
require.Len(t, result3.Objects, 1)
require.False(t, result3.IsTruncated)
}
func TestClient_ListAllObjects(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "pfx")
ctx := context.Background()
for i := 0; i < 10; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
objects, err := client.ListAllObjects(ctx)
require.Nil(t, err)
require.Len(t, objects, 10)
}
func TestClient_PutObject_LargeBody(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// 1 MB object
data := make([]byte, 1024*1024)
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, "large", bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "large")
require.Nil(t, err)
require.Equal(t, int64(1024*1024), size)
got, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, data, got)
}
func TestClient_PutObject_ChunkedUpload(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
// 12 MB object, exceeds 5 MB partSize, triggers multipart upload path
data := make([]byte, 12*1024*1024)
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, "multipart", bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "multipart")
require.Nil(t, err)
require.Equal(t, int64(12*1024*1024), size)
got, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, data, got)
}
func TestClient_PutObject_ExactPartSize(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
	// Exactly 5 MB (partSize): still takes the simple PUT path, since nothing remains to read beyond the first part
data := make([]byte, 5*1024*1024)
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, "exact", bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "exact")
require.Nil(t, err)
require.Equal(t, int64(5*1024*1024), size)
got, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, data, got)
}
func TestClient_PutObject_NestedKey(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "")
ctx := context.Background()
err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"))
require.Nil(t, err)
reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt")
require.Nil(t, err)
data, _ := io.ReadAll(reader)
reader.Close()
require.Equal(t, "nested", string(data))
}
// --- Scale test: 20k objects, approximating a large ntfy attachment store ---
func TestClient_ListAllObjects_20k(t *testing.T) {
if testing.Short() {
t.Skip("skipping 20k object test in short mode")
}
server, _ := newMockS3Server()
defer server.Close()
client := newTestClient(server, "my-bucket", "attachments")
ctx := context.Background()
const numObjects = 20000
	// Insert 20k one-byte objects
	for i := 0; i < numObjects; i++ {
		key := fmt.Sprintf("%08d", i)
		err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")))
		require.Nil(t, err)
	}
// List all 20k objects with pagination
objects, err := client.ListAllObjects(ctx)
require.Nil(t, err)
require.Len(t, objects, numObjects)
	// Verify the total size (each object body is 1 byte)
var totalSize int64
for _, obj := range objects {
totalSize += obj.Size
}
require.Equal(t, int64(numObjects), totalSize)
// Delete 1000 objects (simulating attachment expiry cleanup)
keys := make([]string, 1000)
for i := range keys {
keys[i] = fmt.Sprintf("%08d", i)
}
err = client.DeleteObjects(ctx, keys)
require.Nil(t, err)
// List again: should have 19000
objects, err = client.ListAllObjects(ctx)
require.Nil(t, err)
require.Len(t, objects, numObjects-1000)
}
// --- Real S3 integration test ---
//
// Set the following environment variables to run this test against a real S3 bucket:
//
// S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET
//
// Optional:
//
// S3_ENDPOINT: host[:port] for S3-compatible providers (e.g. "nyc3.digitaloceanspaces.com")
// S3_PATH_STYLE: set to "true" for path-style addressing
// S3_PREFIX: key prefix to use (default: "ntfy-s3-test")
func TestClient_RealBucket(t *testing.T) {
accessKey := os.Getenv("S3_ACCESS_KEY")
secretKey := os.Getenv("S3_SECRET_KEY")
region := os.Getenv("S3_REGION")
bucket := os.Getenv("S3_BUCKET")
if accessKey == "" || secretKey == "" || region == "" || bucket == "" {
t.Skip("skipping real S3 test: set S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET")
}
endpoint := os.Getenv("S3_ENDPOINT")
if endpoint == "" {
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region)
}
pathStyle := os.Getenv("S3_PATH_STYLE") == "true"
prefix := os.Getenv("S3_PREFIX")
if prefix == "" {
prefix = "ntfy-s3-test"
}
client := &Client{
AccessKey: accessKey,
SecretKey: secretKey,
Region: region,
Endpoint: endpoint,
Bucket: bucket,
Prefix: prefix,
PathStyle: pathStyle,
}
ctx := context.Background()
// Clean up any leftover objects from previous runs
existing, err := client.ListAllObjects(ctx)
require.Nil(t, err)
if len(existing) > 0 {
keys := make([]string, len(existing))
for i, obj := range existing {
// Strip the prefix since DeleteObjects will re-add it
keys[i] = strings.TrimPrefix(obj.Key, prefix+"/")
}
// Batch delete in groups of 1000
for i := 0; i < len(keys); i += 1000 {
end := i + 1000
if end > len(keys) {
end = len(keys)
}
err := client.DeleteObjects(ctx, keys[i:end])
require.Nil(t, err)
}
}
t.Run("PutGetDelete", func(t *testing.T) {
key := "test-object"
content := "hello from ntfy s3 test"
// Put
err := client.PutObject(ctx, key, strings.NewReader(content))
require.Nil(t, err)
// Get
reader, size, err := client.GetObject(ctx, key)
require.Nil(t, err)
require.Equal(t, int64(len(content)), size)
data, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, content, string(data))
// Delete
err = client.DeleteObjects(ctx, []string{key})
require.Nil(t, err)
// Get after delete should fail
_, _, err = client.GetObject(ctx, key)
require.Error(t, err)
var errResp *ErrorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 404, errResp.StatusCode)
})
t.Run("ListObjects", func(t *testing.T) {
// Use a sub-prefix client for isolation
listClient := &Client{
AccessKey: accessKey,
SecretKey: secretKey,
Region: region,
Endpoint: endpoint,
Bucket: bucket,
Prefix: prefix + "/list-test",
PathStyle: pathStyle,
}
// Put 10 objects
for i := 0; i < 10; i++ {
err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"))
require.Nil(t, err)
}
// List
objects, err := listClient.ListAllObjects(ctx)
require.Nil(t, err)
require.Len(t, objects, 10)
// Clean up
keys := make([]string, 10)
for i := range keys {
keys[i] = fmt.Sprintf("%d", i)
}
err = listClient.DeleteObjects(ctx, keys)
require.Nil(t, err)
})
t.Run("LargeObject", func(t *testing.T) {
key := "large-object"
data := make([]byte, 5*1024*1024) // 5 MB
for i := range data {
data[i] = byte(i % 256)
}
err := client.PutObject(ctx, key, bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, key)
require.Nil(t, err)
require.Equal(t, int64(len(data)), size)
got, err := io.ReadAll(reader)
reader.Close()
require.Nil(t, err)
require.Equal(t, data, got)
err = client.DeleteObjects(ctx, []string{key})
require.Nil(t, err)
})
}

76
s3/types.go Normal file
View File

@@ -0,0 +1,76 @@
package s3
import "fmt"
// Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string.
type Config struct {
Endpoint string // host[:port] only, e.g. "s3.us-east-1.amazonaws.com"
PathStyle bool
Bucket string
Prefix string
Region string
AccessKey string
SecretKey string
}
// Object represents an S3 object returned by list operations.
type Object struct {
Key string
Size int64
}
// ListResult holds the response from a ListObjectsV2 call.
type ListResult struct {
Objects []Object
IsTruncated bool
NextContinuationToken string
}
// ErrorResponse is returned when S3 responds with a non-2xx status code.
type ErrorResponse struct {
StatusCode int
Code string `xml:"Code"`
Message string `xml:"Message"`
Body string `xml:"-"` // raw response body
}
func (e *ErrorResponse) Error() string {
if e.Code != "" {
return fmt.Sprintf("s3: %s (HTTP %d): %s", e.Code, e.StatusCode, e.Message)
}
return fmt.Sprintf("s3: HTTP %d: %s", e.StatusCode, e.Body)
}
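Callers can branch on the typed error with errors.As, just as the tests in this PR do. A minimal in-package sketch (notFound is a hypothetical helper; the errors import is assumed):
// Sketch: treat NoSuchKey as "object missing", everything else as a real failure.
func notFound(err error) bool {
	var errResp *ErrorResponse
	return errors.As(err, &errResp) && errResp.Code == "NoSuchKey"
}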
// listObjectsV2Response is the XML response from S3 ListObjectsV2
type listObjectsV2Response struct {
Contents []listObject `xml:"Contents"`
IsTruncated bool `xml:"IsTruncated"`
NextContinuationToken string `xml:"NextContinuationToken"`
}
type listObject struct {
Key string `xml:"Key"`
Size int64 `xml:"Size"`
}
// deleteResult is the XML response from S3 DeleteObjects
type deleteResult struct {
Errors []deleteError `xml:"Error"`
}
type deleteError struct {
Key string `xml:"Key"`
Code string `xml:"Code"`
Message string `xml:"Message"`
}
// initiateMultipartUploadResult is the XML response from S3 InitiateMultipartUpload
type initiateMultipartUploadResult struct {
UploadID string `xml:"UploadId"`
}
// completedPart represents a successfully uploaded part for CompleteMultipartUpload
type completedPart struct {
PartNumber int
ETag string
}
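The xml struct tags above line up with the element names in S3's ListObjectsV2 response body. A hypothetical in-package test sketch showing how a trimmed response decodes:
// Since listObjectsV2Response has no XMLName field, encoding/xml does not
// check the root element name; only the tagged children matter.
func TestListObjectsV2Response_Decode(t *testing.T) {
	body := []byte(`<ListBucketResult>` +
		`<IsTruncated>true</IsTruncated>` +
		`<NextContinuationToken>token-123</NextContinuationToken>` +
		`<Contents><Key>pfx/0001</Key><Size>11</Size></Contents>` +
		`</ListBucketResult>`)
	var resp listObjectsV2Response
	require.Nil(t, xml.Unmarshal(body, &resp))
	require.True(t, resp.IsTruncated)
	require.Equal(t, "token-123", resp.NextContinuationToken)
	require.Equal(t, "pfx/0001", resp.Contents[0].Key)
	require.Equal(t, int64(11), resp.Contents[0].Size)
}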

166
s3/util.go Normal file
View File

@@ -0,0 +1,166 @@
package s3
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strings"
)
const (
// SHA-256 hash of the empty string, used as the payload hash for bodiless requests
emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
// Sent as the payload hash for streaming uploads where the body is not buffered in memory
unsignedPayload = "UNSIGNED-PAYLOAD"
// maxResponseBytes caps the size of S3 response bodies we read into memory (10 MB)
maxResponseBytes = 10 * 1024 * 1024
// partSize is the size of each part for multipart uploads (5 MB). This is also the threshold
// above which PutObject switches from a simple PUT to multipart upload. S3 requires a minimum
// part size of 5 MB for all parts except the last.
partSize = 5 * 1024 * 1024
)
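One common way to implement that threshold, consistent with the exactly-5-MB test earlier in this diff, is to read the first part plus one byte before deciding. A sketch under that assumption; simplePut and multipartUpload are hypothetical helper names, not necessarily the package's actual internals:
func (c *Client) putSketch(ctx context.Context, key string, body io.Reader) error {
	buf := make([]byte, partSize+1)
	n, err := io.ReadFull(body, buf)
	if err == io.EOF || err == io.ErrUnexpectedEOF {
		return c.simplePut(ctx, key, buf[:n]) // <= partSize: a single PUT suffices
	} else if err != nil {
		return err
	}
	// More than partSize: replay the buffered bytes, then stream the rest in parts.
	return c.multipartUpload(ctx, key, io.MultiReader(bytes.NewReader(buf), body))
}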
// ParseURL parses an S3 URL of the form:
//
// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
//
// When endpoint is specified, path-style addressing is enabled automatically.
func ParseURL(s3URL string) (*Config, error) {
u, err := url.Parse(s3URL)
if err != nil {
return nil, fmt.Errorf("s3: invalid URL: %w", err)
}
if u.Scheme != "s3" {
return nil, fmt.Errorf("s3: URL scheme must be 's3', got '%s'", u.Scheme)
}
if u.Host == "" {
return nil, fmt.Errorf("s3: bucket name must be specified as host")
}
bucket := u.Host
prefix := strings.TrimPrefix(u.Path, "/")
accessKey := u.User.Username()
secretKey, _ := u.User.Password()
if accessKey == "" || secretKey == "" {
return nil, fmt.Errorf("s3: access key and secret key must be specified in URL")
}
region := u.Query().Get("region")
if region == "" {
return nil, fmt.Errorf("s3: region query parameter is required")
}
endpointParam := u.Query().Get("endpoint")
var endpoint string
var pathStyle bool
if endpointParam != "" {
// Custom endpoint: strip scheme prefix to extract host[:port]
ep := strings.TrimRight(endpointParam, "/")
ep = strings.TrimPrefix(ep, "https://")
ep = strings.TrimPrefix(ep, "http://")
endpoint = ep
pathStyle = true
} else {
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region)
pathStyle = false
}
return &Config{
Endpoint: endpoint,
PathStyle: pathStyle,
Bucket: bucket,
Prefix: prefix,
Region: region,
AccessKey: accessKey,
SecretKey: secretKey,
}, nil
}
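Traced through the code above, a URL pointing at a local MinIO comes out with the scheme stripped from the endpoint and path-style addressing enabled. A short sketch tying ParseURL to the client, using the same New and PutObject calls as tools/s3cli below:
func exampleMinIO() error {
	cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=http://localhost:9000")
	if err != nil {
		return err
	}
	// cfg.Endpoint == "localhost:9000", cfg.PathStyle == true, cfg.Prefix == "attachments"
	client := New(cfg)
	return client.PutObject(context.Background(), "hello.txt", strings.NewReader("hi"))
}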
// parseError reads an S3 error response and returns an *ErrorResponse.
func parseError(resp *http.Response) error {
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return fmt.Errorf("s3: reading error response: %w", err)
}
return parseErrorFromBytes(resp.StatusCode, body)
}
func parseErrorFromBytes(statusCode int, body []byte) error {
errResp := &ErrorResponse{
StatusCode: statusCode,
Body: string(body),
}
// Try to parse XML error; if it fails, we still have StatusCode and Body
_ = xml.Unmarshal(body, errResp)
return errResp
}
// canonicalURI returns the URI-encoded path for the canonical request. Each path segment is
// percent-encoded per RFC 3986; forward slashes are preserved.
func canonicalURI(u *url.URL) string {
p := u.Path
if p == "" {
return "/"
}
segments := strings.Split(p, "/")
for i, seg := range segments {
segments[i] = uriEncode(seg)
}
return strings.Join(segments, "/")
}
// canonicalQueryString builds the query string for the canonical request. Keys and values
// are URI-encoded per RFC 3986 (using %20, not +) and sorted lexically by key.
func canonicalQueryString(values url.Values) string {
if len(values) == 0 {
return ""
}
keys := make([]string, 0, len(values))
for k := range values {
keys = append(keys, k)
}
sort.Strings(keys)
var pairs []string
for _, k := range keys {
ek := uriEncode(k)
vs := make([]string, len(values[k]))
copy(vs, values[k])
sort.Strings(vs)
for _, v := range vs {
pairs = append(pairs, ek+"="+uriEncode(v))
}
}
return strings.Join(pairs, "&")
}
// uriEncode percent-encodes a string per RFC 3986, encoding everything except unreserved
// characters (A-Z a-z 0-9 - _ . ~).
func uriEncode(s string) string {
var buf strings.Builder
for i := 0; i < len(s); i++ {
b := s[i]
if (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9') ||
b == '-' || b == '_' || b == '.' || b == '~' {
buf.WriteByte(b)
} else {
fmt.Fprintf(&buf, "%%%02X", b)
}
}
return buf.String()
}
func sha256Hex(data []byte) string {
h := sha256.Sum256(data)
return hex.EncodeToString(h[:])
}
func hmacSHA256(key, data []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(data)
return h.Sum(nil)
}
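hmacSHA256 above is the primitive behind the SigV4 signing key that signV4 presumably derives; the chain itself is fixed by the AWS Signature Version 4 spec (dateStamp is the YYYYMMDD portion of X-Amz-Date):
// Reference sketch of the SigV4 signing-key derivation; the actual signV4
// implementation may inline these steps.
func signingKey(secretKey, dateStamp, region string) []byte {
	kDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(dateStamp))
	kRegion := hmacSHA256(kDate, []byte(region))
	kService := hmacSHA256(kRegion, []byte("s3"))
	return hmacSHA256(kService, []byte("aws4_request"))
}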

181
s3/util_test.go Normal file
View File

@@ -0,0 +1,181 @@
package s3
import (
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/require"
)
func TestURIEncode(t *testing.T) {
// Unreserved characters are not encoded
require.Equal(t, "abcdefghijklmnopqrstuvwxyz", uriEncode("abcdefghijklmnopqrstuvwxyz"))
require.Equal(t, "ABCDEFGHIJKLMNOPQRSTUVWXYZ", uriEncode("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
require.Equal(t, "0123456789", uriEncode("0123456789"))
require.Equal(t, "-_.~", uriEncode("-_.~"))
// Spaces use %20, not +
require.Equal(t, "hello%20world", uriEncode("hello world"))
// Slashes are encoded (canonicalURI handles slash splitting separately)
require.Equal(t, "a%2Fb", uriEncode("a/b"))
// Special characters
require.Equal(t, "%2B", uriEncode("+"))
require.Equal(t, "%3D", uriEncode("="))
require.Equal(t, "%26", uriEncode("&"))
require.Equal(t, "%40", uriEncode("@"))
require.Equal(t, "%23", uriEncode("#"))
// Mixed
require.Equal(t, "test~file-name_1.txt", uriEncode("test~file-name_1.txt"))
require.Equal(t, "key%20with%20spaces%2Fand%2Fslashes", uriEncode("key with spaces/and/slashes"))
// Empty string
require.Equal(t, "", uriEncode(""))
}
func TestCanonicalURI(t *testing.T) {
// Simple path
u, _ := url.Parse("https://example.com/bucket/key")
require.Equal(t, "/bucket/key", canonicalURI(u))
// Root path
u, _ = url.Parse("https://example.com/")
require.Equal(t, "/", canonicalURI(u))
// Empty path
u, _ = url.Parse("https://example.com")
require.Equal(t, "/", canonicalURI(u))
// Path with special characters
u, _ = url.Parse("https://example.com/bucket/key%20with%20spaces")
require.Equal(t, "/bucket/key%20with%20spaces", canonicalURI(u))
// Nested path
u, _ = url.Parse("https://example.com/bucket/a/b/c/file.txt")
require.Equal(t, "/bucket/a/b/c/file.txt", canonicalURI(u))
}
func TestCanonicalQueryString(t *testing.T) {
// Multiple keys sorted alphabetically
vals := url.Values{
"prefix": {"test/"},
"list-type": {"2"},
}
require.Equal(t, "list-type=2&prefix=test%2F", canonicalQueryString(vals))
// Empty values
require.Equal(t, "", canonicalQueryString(url.Values{}))
// Single key
require.Equal(t, "key=value", canonicalQueryString(url.Values{"key": {"value"}}))
// Key with multiple values (sorted)
vals = url.Values{"key": {"b", "a"}}
require.Equal(t, "key=a&key=b", canonicalQueryString(vals))
// Keys requiring encoding
vals = url.Values{"continuation-token": {"abc+def"}}
require.Equal(t, "continuation-token=abc%2Bdef", canonicalQueryString(vals))
}
func TestSHA256Hex(t *testing.T) {
// SHA-256 of empty string
require.Equal(t, emptyPayloadHash, sha256Hex([]byte("")))
// SHA-256 of known value
require.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", sha256Hex([]byte("hello")))
}
func TestHmacSHA256(t *testing.T) {
// Known test vector: HMAC-SHA256("key", "message")
result := hmacSHA256([]byte("key"), []byte("message"))
require.Len(t, result, 32) // SHA-256 produces 32 bytes
require.NotEqual(t, make([]byte, 32), result)
// Same inputs should produce same output
result2 := hmacSHA256([]byte("key"), []byte("message"))
require.Equal(t, result, result2)
// Different inputs should produce different output
result3 := hmacSHA256([]byte("different-key"), []byte("message"))
require.NotEqual(t, result, result3)
}
func TestSignV4_SetsRequiredHeaders(t *testing.T) {
c := &Client{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: "s3.us-east-1.amazonaws.com",
Bucket: "my-bucket",
}
req, _ := http.NewRequest(http.MethodGet, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
c.signV4(req, emptyPayloadHash)
// All required SigV4 headers must be set
require.NotEmpty(t, req.Header.Get("Host"))
require.NotEmpty(t, req.Header.Get("X-Amz-Date"))
require.Equal(t, emptyPayloadHash, req.Header.Get("X-Amz-Content-Sha256"))
// Authorization header must have correct format
auth := req.Header.Get("Authorization")
require.Contains(t, auth, "AWS4-HMAC-SHA256")
require.Contains(t, auth, "Credential=AKID/")
require.Contains(t, auth, "/us-east-1/s3/aws4_request")
require.Contains(t, auth, "SignedHeaders=")
require.Contains(t, auth, "Signature=")
}
func TestSignV4_UnsignedPayload(t *testing.T) {
c := &Client{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: "s3.us-east-1.amazonaws.com",
Bucket: "my-bucket",
}
req, _ := http.NewRequest(http.MethodPut, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
c.signV4(req, unsignedPayload)
require.Equal(t, unsignedPayload, req.Header.Get("X-Amz-Content-Sha256"))
}
func TestSignV4_DifferentRegions(t *testing.T) {
c1 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"}
c2 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", Bucket: "b"}
req1, _ := http.NewRequest(http.MethodGet, "https://b.s3.us-east-1.amazonaws.com/key", nil)
c1.signV4(req1, emptyPayloadHash)
req2, _ := http.NewRequest(http.MethodGet, "https://b.s3.eu-west-1.amazonaws.com/key", nil)
c2.signV4(req2, emptyPayloadHash)
// Different regions should produce different signatures
require.NotEqual(t, req1.Header.Get("Authorization"), req2.Header.Get("Authorization"))
}
func TestParseError_XMLResponse(t *testing.T) {
xmlBody := []byte(`<?xml version="1.0" encoding="UTF-8"?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message></Error>`)
err := parseErrorFromBytes(404, xmlBody)
var errResp *ErrorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 404, errResp.StatusCode)
require.Equal(t, "NoSuchKey", errResp.Code)
require.Equal(t, "The specified key does not exist.", errResp.Message)
}
func TestParseError_NonXMLResponse(t *testing.T) {
err := parseErrorFromBytes(500, []byte("internal server error"))
var errResp *ErrorResponse
require.ErrorAs(t, err, &errResp)
require.Equal(t, 500, errResp.StatusCode)
require.Equal(t, "", errResp.Code) // XML parsing failed, no code
require.Contains(t, errResp.Body, "internal server error")
}

View File

@@ -112,6 +112,7 @@ type Config struct {
AuthBcryptCost int
AuthStatsQueueWriterInterval time.Duration
AttachmentCacheDir string
AttachmentS3URL string
AttachmentTotalSizeLimit int64
AttachmentFileSizeLimit int64
AttachmentExpiryDuration time.Duration

View File

@@ -142,6 +142,7 @@ var (
errHTTPBadRequestTemplateFileNotFound = &errHTTP{40047, http.StatusBadRequest, "invalid request: template file not found", "https://ntfy.sh/docs/publish/#message-templating", nil}
errHTTPBadRequestTemplateFileInvalid = &errHTTP{40048, http.StatusBadRequest, "invalid request: template file invalid", "https://ntfy.sh/docs/publish/#message-templating", nil}
errHTTPBadRequestSequenceIDInvalid = &errHTTP{40049, http.StatusBadRequest, "invalid request: sequence ID invalid", "https://ntfy.sh/docs/publish/#updating-deleting-notifications", nil}
errHTTPBadRequestEmailAddressInvalid = &errHTTP{40050, http.StatusBadRequest, "invalid request: invalid e-mail address", "https://ntfy.sh/docs/publish/#e-mail-notifications", nil}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", "", nil}
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication", nil}
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication", nil}

View File

@@ -1,76 +0,0 @@
package server
import (
"bytes"
"fmt"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/util"
"os"
"strings"
"testing"
)
var (
oneKilobyteArray = make([]byte, 1024)
)
func TestFileCache_Write_Success(t *testing.T) {
dir, c := newTestFileCache(t)
size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
require.Equal(t, int64(11), c.Size())
require.Equal(t, int64(10229), c.Remaining())
}
func TestFileCache_Write_Remove_Success(t *testing.T) {
dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024)
for i := 0; i < 10; i++ { // 10x999 = 9990
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
require.Nil(t, err)
require.Equal(t, int64(999), size)
}
require.Equal(t, int64(9990), c.Size())
require.Equal(t, int64(250), c.Remaining())
require.FileExists(t, dir+"/abcdefghijk1")
require.FileExists(t, dir+"/abcdefghijk5")
require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5"))
require.NoFileExists(t, dir+"/abcdefghijk1")
require.NoFileExists(t, dir+"/abcdefghijk5")
require.Equal(t, int64(7992), c.Size())
require.Equal(t, int64(2248), c.Remaining())
}
func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) {
dir, c := newTestFileCache(t)
for i := 0; i < 10; i++ {
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
require.Nil(t, err)
require.Equal(t, int64(1024), size)
}
_, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkX")
}
func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) {
dir, c := newTestFileCache(t)
_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkl")
}
func newTestFileCache(t *testing.T) (dir string, cache *fileCache) {
dir = t.TempDir()
cache, err := newFileCache(dir, 10*1024)
require.Nil(t, err)
return dir, cache
}
func readFile(t *testing.T, f string) string {
b, err := os.ReadFile(f)
require.Nil(t, err)
return string(b)
}

View File

@@ -24,7 +24,6 @@ const (
tagSMTP = "smtp" // Receive email
tagEmail = "email" // Send email
tagTwilio = "twilio"
tagFileCache = "file_cache"
tagMessageCache = "message_cache"
tagStripe = "stripe"
tagAccount = "account"

View File

@@ -32,6 +32,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/attachment"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/db/pg"
"heckel.io/ntfy/v2/log"
@@ -64,7 +65,7 @@ type Server struct {
userManager *user.Manager // Might be nil!
messageCache *message.Cache // Database that stores the messages
webPush *webpush.Store // Database that stores web push subscriptions
fileCache *fileCache // File system based cache that stores attachments
fileCache attachment.Store // Attachment store (file system or S3)
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set
@@ -122,6 +123,7 @@ var (
fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
urlRegex = regexp.MustCompile(`^https?://`)
phoneNumberRegex = regexp.MustCompile(`^\+\d{1,100}$`)
emailAddressRegex = regexp.MustCompile(`^[^\s,;]+@[^\s,;]+$`)
//go:embed site
webFs embed.FS
@@ -227,12 +229,9 @@ func New(conf *Config) (*Server, error) {
if err != nil {
return nil, err
}
var fileCache *fileCache
if conf.AttachmentCacheDir != "" {
fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
if err != nil {
return nil, err
}
fileCache, err := createAttachmentStore(conf)
if err != nil {
return nil, err
}
var userManager *user.Manager
if conf.AuthFile != "" || pool != nil {
@@ -301,6 +300,15 @@ func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) {
return message.NewMemStore()
}
func createAttachmentStore(conf *Config) (attachment.Store, error) {
if conf.AttachmentS3URL != "" {
return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit)
} else if conf.AttachmentCacheDir != "" {
return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
}
return nil, nil
}
// Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
// a manager go routine to print stats and prune messages.
func (s *Server) Run() error {
@@ -752,7 +760,7 @@ func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request, _ *visitor)
// Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it
// can associate the download bandwidth with the uploader.
func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.config.AttachmentCacheDir == "" {
if s.fileCache == nil {
return errHTTPInternalError
}
matches := fileRegex.FindStringSubmatch(r.URL.Path)
@@ -760,16 +768,16 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
return errHTTPInternalErrorInvalidPath
}
messageID := matches[1]
file := filepath.Join(s.config.AttachmentCacheDir, messageID)
stat, err := os.Stat(file)
reader, size, err := s.fileCache.Read(messageID)
if err != nil {
return errHTTPNotFound.Fields(log.Context{
"message_id": messageID,
"error_context": "filesystem",
"error_context": "attachment_store",
})
}
defer reader.Close()
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
w.Header().Set("Content-Length", fmt.Sprintf("%d", size))
if r.Method == http.MethodHead {
return nil
}
@@ -805,19 +813,14 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
} else if m.Sender.IsValid() {
bandwidthVisitor = s.visitor(m.Sender, nil)
}
if !bandwidthVisitor.BandwidthAllowed(stat.Size()) {
if !bandwidthVisitor.BandwidthAllowed(size) {
return errHTTPTooManyRequestsLimitAttachmentBandwidth.With(m)
}
// Actually send file
f, err := os.Open(file)
if err != nil {
return err
}
defer f.Close()
if m.Attachment.Name != "" {
w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(m.Attachment.Name))
}
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), reader)
return err
}
@@ -1163,6 +1166,9 @@ func (s *Server) parsePublishParams(r *http.Request, m *model.Message) (cache bo
m.Icon = icon
}
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
if email != "" && !emailAddressRegex.MatchString(email) {
return false, false, "", "", "", false, "", errHTTPBadRequestEmailAddressInvalid
}
if s.smtpSender == nil && email != "" {
return false, false, "", "", "", false, "", errHTTPBadRequestEmailDisabled
}
@@ -1409,7 +1415,7 @@ func (s *Server) renderTemplate(name, tpl, source string) (string, error) {
}
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser) error {
if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
if s.fileCache == nil || s.config.BaseURL == "" {
return errHTTPBadRequestAttachmentsDisallowed.With(m)
}
vinfo, err := v.Info()

View File

@@ -159,6 +159,7 @@
# - attachment-expiry-duration is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h)
#
# attachment-cache-dir:
# attachment-s3-url: "s3://ACCESS_KEY:SECRET_KEY@bucket/prefix?region=us-east-1"
# attachment-total-size-limit: "5G"
# attachment-file-size-limit: "15M"
# attachment-expiry-duration: "3h"

View File

@@ -99,6 +99,9 @@ func (s *Server) execManager() {
mset(metricUsers, usersCount)
mset(metricSubscribers, subscribers)
mset(metricTopics, topicsCount)
if s.fileCache != nil {
mset(metricAttachmentsTotalSize, s.fileCache.Size())
}
}
func (s *Server) pruneVisitors() {

View File

@@ -1543,6 +1543,30 @@ func TestServer_PublishEmailNoMailer_Fail(t *testing.T) {
})
}
func TestServer_PublishEmailAddressInvalid(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfig(t, databaseURL))
s.smtpSender = &testMailer{}
		addresses := []string{
			"test@example.com, other@example.com", // multiple addresses are not allowed
			"invalidaddress",
			"@nope",
			"nope@",
		}
for _, email := range addresses {
response := request(t, s, "PUT", "/mytopic", "fail", map[string]string{
"E-Mail": email,
})
require.Equal(t, 400, response.Code, "expected 400 for email: %s", email)
}
// Valid address should succeed
response := request(t, s, "PUT", "/mytopic", "success", map[string]string{
"E-Mail": "test@example.com",
})
require.Equal(t, 200, response.Code)
})
}
func TestServer_PublishAndExpungeTopicAfter16Hours(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
t.Parallel()

141
tools/s3cli/main.go Normal file
View File

@@ -0,0 +1,141 @@
// Command s3cli is a minimal CLI for testing the s3 package. It supports put, get, rm, and ls.
//
// Usage:
//
// export S3_URL="s3://ACCESS_KEY:SECRET_KEY@BUCKET/PREFIX?region=REGION&endpoint=ENDPOINT"
//
// s3cli put <key> <file> Upload a file
// s3cli put <key> - Upload from stdin
// s3cli get <key> Download to stdout
// s3cli rm <key> [<key>...] Delete one or more objects
// s3cli ls List all objects
package main
import (
"context"
"fmt"
"io"
"os"
"text/tabwriter"
"heckel.io/ntfy/v2/s3"
)
func main() {
if len(os.Args) < 2 {
usage()
}
s3URL := os.Getenv("S3_URL")
if s3URL == "" {
fail("S3_URL environment variable is required")
}
cfg, err := s3.ParseURL(s3URL)
if err != nil {
fail("invalid S3_URL: %s", err)
}
client := s3.New(cfg)
ctx := context.Background()
switch os.Args[1] {
case "put":
cmdPut(ctx, client)
case "get":
cmdGet(ctx, client)
case "rm":
cmdRm(ctx, client)
case "ls":
cmdLs(ctx, client)
default:
usage()
}
}
func cmdPut(ctx context.Context, client *s3.Client) {
if len(os.Args) != 4 {
fail("usage: s3cli put <key> <file|->\n")
}
key := os.Args[2]
path := os.Args[3]
var r io.Reader
if path == "-" {
r = os.Stdin
} else {
f, err := os.Open(path)
if err != nil {
fail("open %s: %s", path, err)
}
defer f.Close()
r = f
}
if err := client.PutObject(ctx, key, r); err != nil {
fail("put: %s", err)
}
fmt.Fprintf(os.Stderr, "uploaded %s\n", key)
}
func cmdGet(ctx context.Context, client *s3.Client) {
if len(os.Args) != 3 {
fail("usage: s3cli get <key>\n")
}
key := os.Args[2]
reader, size, err := client.GetObject(ctx, key)
if err != nil {
fail("get: %s", err)
}
defer reader.Close()
n, err := io.Copy(os.Stdout, reader)
if err != nil {
fail("read: %s", err)
}
fmt.Fprintf(os.Stderr, "downloaded %s (%d bytes, content-length: %d)\n", key, n, size)
}
func cmdRm(ctx context.Context, client *s3.Client) {
if len(os.Args) < 3 {
fail("usage: s3cli rm <key> [<key>...]\n")
}
keys := os.Args[2:]
if err := client.DeleteObjects(ctx, keys); err != nil {
fail("rm: %s", err)
}
fmt.Fprintf(os.Stderr, "deleted %d object(s)\n", len(keys))
}
func cmdLs(ctx context.Context, client *s3.Client) {
objects, err := client.ListAllObjects(ctx)
if err != nil {
fail("ls: %s", err)
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
var totalSize int64
for _, obj := range objects {
fmt.Fprintf(w, "%d\t%s\n", obj.Size, obj.Key)
totalSize += obj.Size
}
w.Flush()
fmt.Fprintf(os.Stderr, "%d object(s), %d bytes total\n", len(objects), totalSize)
}
func usage() {
fmt.Fprintf(os.Stderr, `Usage: s3cli <command> [args...]
Commands:
put <key> <file|-> Upload a file (use - for stdin)
get <key> Download to stdout
rm <key> [keys...] Delete objects
ls List all objects
Environment:
S3_URL S3 connection URL (required)
s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
`)
os.Exit(1)
}
func fail(format string, args ...any) {
fmt.Fprintf(os.Stderr, format+"\n", args...)
os.Exit(1)
}