diff --git a/.gitignore b/.gitignore
index ed17b2d4..6d5deb67 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,6 +9,7 @@ server/site/
tools/fbsend/fbsend
tools/pgimport/pgimport
tools/loadtest/loadtest
+tools/s3cli/s3cli
playground/
secrets/
*.iml
diff --git a/attachment/store_s3.go b/attachment/store_s3.go
index 5c47a81b..38f0353a 100644
--- a/attachment/store_s3.go
+++ b/attachment/store_s3.go
@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"io"
- "os"
"sync"
"heckel.io/ntfy/v2/log"
@@ -51,33 +50,31 @@ func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int6
}
log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3")
- // Write through limiters into a temp file. This avoids buffering the full attachment in
- // memory while still giving us the Content-Length that PutObject requires.
+ // Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked
+ // uploads, so no temp file or Content-Length is needed.
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
- tmpFile, err := os.CreateTemp("", "ntfy-s3-upload-*")
- if err != nil {
- return 0, fmt.Errorf("s3 store: failed to create temp file: %w", err)
+ pr, pw := io.Pipe()
+ lw := util.NewLimitWriter(pw, limiters...)
+ var size int64
+ var copyErr error
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ size, copyErr = io.Copy(lw, in)
+ if copyErr != nil {
+ pw.CloseWithError(copyErr)
+ } else {
+ pw.Close()
+ }
+ }()
+ putErr := c.client.PutObject(context.Background(), id, pr)
+ pr.Close()
+ <-done
+	if copyErr != nil && copyErr != io.ErrClosedPipe {
+ return 0, copyErr
}
- tmpPath := tmpFile.Name()
- defer os.Remove(tmpPath)
- limitWriter := util.NewLimitWriter(tmpFile, limiters...)
- size, err := io.Copy(limitWriter, in)
- if err != nil {
- tmpFile.Close()
- return 0, err
- }
- if err := tmpFile.Close(); err != nil {
- return 0, err
- }
-
- // Re-open the temp file for reading and stream it to S3
- f, err := os.Open(tmpPath)
- if err != nil {
- return 0, err
- }
- defer f.Close()
- if err := c.client.PutObject(context.Background(), id, f, size); err != nil {
- return 0, err
+ if putErr != nil {
+ return 0, putErr
}
c.mu.Lock()
c.totalSizeCurrent += size
diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go
index c898244d..872e8c23 100644
--- a/attachment/store_s3_test.go
+++ b/attachment/store_s3_test.go
@@ -167,27 +167,41 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
type mockS3Server struct {
- objects map[string][]byte // full key (bucket/key) -> body
+ objects map[string][]byte // full key (bucket/key) -> body
+ uploads map[string]map[int][]byte // uploadID -> partNumber -> data
+ nextID int // counter for generating upload IDs
mu sync.RWMutex
}
func newMockS3Server() *httptest.Server {
- m := &mockS3Server{objects: make(map[string][]byte)}
+ m := &mockS3Server{
+ objects: make(map[string][]byte),
+ uploads: make(map[string]map[int][]byte),
+ }
return httptest.NewTLSServer(m)
}
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Path is /{bucket}[/{key...}]
path := strings.TrimPrefix(r.URL.Path, "/")
+ q := r.URL.Query()
switch {
+ case r.Method == http.MethodPut && q.Has("partNumber"):
+ m.handleUploadPart(w, r, path)
case r.Method == http.MethodPut:
m.handlePut(w, r, path)
- case r.Method == http.MethodGet && r.URL.Query().Get("list-type") == "2":
+ case r.Method == http.MethodPost && q.Has("uploads"):
+ m.handleInitiateMultipart(w, r, path)
+ case r.Method == http.MethodPost && q.Has("uploadId"):
+ m.handleCompleteMultipart(w, r, path)
+ case r.Method == http.MethodDelete && q.Has("uploadId"):
+ m.handleAbortMultipart(w, r, path)
+ case r.Method == http.MethodGet && q.Get("list-type") == "2":
m.handleList(w, r, path)
case r.Method == http.MethodGet:
m.handleGet(w, r, path)
- case r.Method == http.MethodPost && r.URL.Query().Has("delete"):
+ case r.Method == http.MethodPost && q.Has("delete"):
m.handleDelete(w, r, path)
default:
http.Error(w, "not implemented", http.StatusNotImplemented)
@@ -206,6 +220,77 @@ func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path st
w.WriteHeader(http.StatusOK)
}
+func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
+ m.mu.Lock()
+ m.nextID++
+ uploadID := fmt.Sprintf("upload-%d", m.nextID)
+ m.uploads[uploadID] = make(map[int][]byte)
+ m.mu.Unlock()
+
+ w.Header().Set("Content-Type", "application/xml")
+ w.WriteHeader(http.StatusOK)
+	fmt.Fprintf(w, `<InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
+}
+
+func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
+ uploadID := r.URL.Query().Get("uploadId")
+ var partNumber int
+ fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
+
+ body, err := io.ReadAll(r.Body)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ m.mu.Lock()
+ parts, ok := m.uploads[uploadID]
+ if !ok {
+ m.mu.Unlock()
+ http.Error(w, "NoSuchUpload", http.StatusNotFound)
+ return
+ }
+ parts[partNumber] = body
+ m.mu.Unlock()
+
+ etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
+ w.Header().Set("ETag", etag)
+ w.WriteHeader(http.StatusOK)
+}
+
+func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
+ uploadID := r.URL.Query().Get("uploadId")
+
+ m.mu.Lock()
+ parts, ok := m.uploads[uploadID]
+ if !ok {
+ m.mu.Unlock()
+ http.Error(w, "NoSuchUpload", http.StatusNotFound)
+ return
+ }
+
+ // Assemble parts in order
+ var assembled []byte
+ for i := 1; i <= len(parts); i++ {
+ assembled = append(assembled, parts[i]...)
+ }
+ m.objects[path] = assembled
+ delete(m.uploads, uploadID)
+ m.mu.Unlock()
+
+ w.Header().Set("Content-Type", "application/xml")
+ w.WriteHeader(http.StatusOK)
+	fmt.Fprintf(w, `<CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
+}
+
+func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
+ uploadID := r.URL.Query().Get("uploadId")
+ m.mu.Lock()
+ delete(m.uploads, uploadID)
+ m.mu.Unlock()
+ w.WriteHeader(http.StatusNoContent)
+}
+
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
m.mu.RLock()
body, ok := m.objects[path]
diff --git a/s3/client.go b/s3/client.go
index 7fdd8093..29e10d2d 100644
--- a/s3/client.go
+++ b/s3/client.go
@@ -10,6 +10,7 @@ import (
"encoding/base64"
"encoding/hex"
"encoding/xml"
+ "errors"
"fmt"
"io"
"net/http"
@@ -50,25 +51,22 @@ func New(config *Config) *Client {
}
// PutObject uploads body to the given key. The key is automatically prefixed with the client's
-// configured prefix. The body size must be known in advance. The payload is sent as
-// UNSIGNED-PAYLOAD, which is supported by all major S3-compatible providers over HTTPS.
-func (c *Client) PutObject(ctx context.Context, key string, body io.Reader, size int64) error {
- fullKey := c.objectKey(key)
- req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body)
+// configured prefix. The body size does not need to be known in advance.
+//
+// If the body is smaller than one part (5 MB), it is uploaded with a simple PUT request.
+// Otherwise, the body is uploaded using S3 multipart upload, reading one part at a time
+// into memory.
+func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error {
+ first := make([]byte, partSize)
+ n, err := io.ReadFull(body, first)
+ if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF {
+ return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n))
+ }
if err != nil {
- return fmt.Errorf("s3: PutObject request: %w", err)
+ return fmt.Errorf("s3: PutObject read: %w", err)
}
- req.ContentLength = size
- c.signV4(req, unsignedPayload)
- resp, err := c.httpClient().Do(req)
- if err != nil {
- return fmt.Errorf("s3: PutObject: %w", err)
- }
- defer resp.Body.Close()
- if resp.StatusCode/100 != 2 {
- return parseError(resp)
- }
- return nil
+ combined := io.MultiReader(bytes.NewReader(first), body)
+ return c.putObjectMultipart(ctx, key, combined)
}
// GetObject downloads an object. The key is automatically prefixed with the client's configured
@@ -216,6 +214,171 @@ func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) {
return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages)
}
+// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD.
+func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error {
+ fullKey := c.objectKey(key)
+ req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body)
+ if err != nil {
+ return fmt.Errorf("s3: PutObject request: %w", err)
+ }
+ req.ContentLength = size
+ c.signV4(req, unsignedPayload)
+ resp, err := c.httpClient().Do(req)
+ if err != nil {
+ return fmt.Errorf("s3: PutObject: %w", err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode/100 != 2 {
+ return parseError(resp)
+ }
+ return nil
+}
+
+// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize
+// chunks, uploading each as a separate part. This allows uploading without knowing the total
+// body size in advance.
+func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error {
+ fullKey := c.objectKey(key)
+
+ // Step 1: Initiate multipart upload
+ uploadID, err := c.initiateMultipartUpload(ctx, fullKey)
+ if err != nil {
+ return err
+ }
+
+ // Step 2: Upload parts
+ var parts []completedPart
+ buf := make([]byte, partSize)
+ partNumber := 1
+ for {
+ n, err := io.ReadFull(body, buf)
+ if n > 0 {
+ etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n])
+ if uploadErr != nil {
+ c.abortMultipartUpload(ctx, fullKey, uploadID)
+ return uploadErr
+ }
+ parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag})
+ partNumber++
+ }
+ if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
+ break
+ }
+ if err != nil {
+ c.abortMultipartUpload(ctx, fullKey, uploadID)
+ return fmt.Errorf("s3: PutObject read: %w", err)
+ }
+ }
+
+ // Step 3: Complete multipart upload
+ return c.completeMultipartUpload(ctx, fullKey, uploadID, parts)
+}
+
+// initiateMultipartUpload starts a new multipart upload and returns the upload ID.
+func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) {
+ reqURL := c.objectURL(fullKey) + "?uploads"
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil)
+ if err != nil {
+ return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err)
+ }
+ req.ContentLength = 0
+ c.signV4(req, emptyPayloadHash)
+ resp, err := c.httpClient().Do(req)
+ if err != nil {
+ return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode/100 != 2 {
+ return "", parseError(resp)
+ }
+ respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
+ if err != nil {
+ return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err)
+ }
+ var result initiateMultipartUploadResult
+ if err := xml.Unmarshal(respBody, &result); err != nil {
+ return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err)
+ }
+ return result.UploadID, nil
+}
+
+// uploadPart uploads a single part of a multipart upload and returns the ETag.
+func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) {
+ reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID))
+ req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data))
+ if err != nil {
+ return "", fmt.Errorf("s3: UploadPart request: %w", err)
+ }
+ req.ContentLength = int64(len(data))
+ c.signV4(req, unsignedPayload)
+ resp, err := c.httpClient().Do(req)
+ if err != nil {
+ return "", fmt.Errorf("s3: UploadPart: %w", err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode/100 != 2 {
+ return "", parseError(resp)
+ }
+ etag := resp.Header.Get("ETag")
+ return etag, nil
+}
+
+// completeMultipartUpload finalizes a multipart upload with the given parts.
+func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error {
+ var body bytes.Buffer
+	body.WriteString("<CompleteMultipartUpload>")
+	for _, p := range parts {
+		fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag)
+	}
+	body.WriteString("</CompleteMultipartUpload>")
+ bodyBytes := body.Bytes()
+ payloadHash := sha256Hex(bodyBytes)
+
+ reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
+ if err != nil {
+ return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err)
+ }
+ req.ContentLength = int64(len(bodyBytes))
+ req.Header.Set("Content-Type", "application/xml")
+ c.signV4(req, payloadHash)
+ resp, err := c.httpClient().Do(req)
+ if err != nil {
+ return fmt.Errorf("s3: CompleteMultipartUpload: %w", err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode/100 != 2 {
+ return parseError(resp)
+ }
+ // Read response body to check for errors (S3 can return 200 with an error body)
+ respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
+ if err != nil {
+ return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err)
+ }
+ // Check if the response contains an error
+ var errResp ErrorResponse
+ if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
+ errResp.StatusCode = resp.StatusCode
+ return &errResp
+ }
+ return nil
+}
+
+// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
+func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) {
+ reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
+ req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
+ if err != nil {
+ return
+ }
+ c.signV4(req, emptyPayloadHash)
+ resp, err := c.httpClient().Do(req)
+ if err != nil {
+ return
+ }
+ resp.Body.Close()
+}
+
// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256
// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads.
func (c *Client) signV4(req *http.Request, payloadHash string) {
diff --git a/s3/client_test.go b/s3/client_test.go
index f4a10213..c3a8fe2c 100644
--- a/s3/client_test.go
+++ b/s3/client_test.go
@@ -23,27 +23,41 @@ import (
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
type mockS3Server struct {
- objects map[string][]byte // full key (bucket/key) -> body
+ objects map[string][]byte // full key (bucket/key) -> body
+ uploads map[string]map[int][]byte // uploadID -> partNumber -> data
+ nextID int // counter for generating upload IDs
mu sync.RWMutex
}
func newMockS3Server() (*httptest.Server, *mockS3Server) {
- m := &mockS3Server{objects: make(map[string][]byte)}
+ m := &mockS3Server{
+ objects: make(map[string][]byte),
+ uploads: make(map[string]map[int][]byte),
+ }
return httptest.NewTLSServer(m), m
}
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Path is /{bucket}[/{key...}]
path := strings.TrimPrefix(r.URL.Path, "/")
+ q := r.URL.Query()
switch {
+ case r.Method == http.MethodPut && q.Has("partNumber"):
+ m.handleUploadPart(w, r, path)
case r.Method == http.MethodPut:
m.handlePut(w, r, path)
- case r.Method == http.MethodGet && r.URL.Query().Get("list-type") == "2":
+ case r.Method == http.MethodPost && q.Has("uploads"):
+ m.handleInitiateMultipart(w, r, path)
+ case r.Method == http.MethodPost && q.Has("uploadId"):
+ m.handleCompleteMultipart(w, r, path)
+ case r.Method == http.MethodDelete && q.Has("uploadId"):
+ m.handleAbortMultipart(w, r, path)
+ case r.Method == http.MethodGet && q.Get("list-type") == "2":
m.handleList(w, r, path)
case r.Method == http.MethodGet:
m.handleGet(w, r, path)
- case r.Method == http.MethodPost && r.URL.Query().Has("delete"):
+ case r.Method == http.MethodPost && q.Has("delete"):
m.handleDelete(w, r, path)
default:
http.Error(w, "not implemented", http.StatusNotImplemented)
@@ -62,6 +76,77 @@ func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path st
w.WriteHeader(http.StatusOK)
}
+func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
+ m.mu.Lock()
+ m.nextID++
+ uploadID := fmt.Sprintf("upload-%d", m.nextID)
+ m.uploads[uploadID] = make(map[int][]byte)
+ m.mu.Unlock()
+
+ w.Header().Set("Content-Type", "application/xml")
+ w.WriteHeader(http.StatusOK)
+	fmt.Fprintf(w, `<InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
+}
+
+func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
+ uploadID := r.URL.Query().Get("uploadId")
+ var partNumber int
+ fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
+
+ body, err := io.ReadAll(r.Body)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ m.mu.Lock()
+ parts, ok := m.uploads[uploadID]
+ if !ok {
+ m.mu.Unlock()
+ http.Error(w, "NoSuchUpload", http.StatusNotFound)
+ return
+ }
+ parts[partNumber] = body
+ m.mu.Unlock()
+
+ etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
+ w.Header().Set("ETag", etag)
+ w.WriteHeader(http.StatusOK)
+}
+
+func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
+ uploadID := r.URL.Query().Get("uploadId")
+
+ m.mu.Lock()
+ parts, ok := m.uploads[uploadID]
+ if !ok {
+ m.mu.Unlock()
+ http.Error(w, "NoSuchUpload", http.StatusNotFound)
+ return
+ }
+
+ // Assemble parts in order
+ var assembled []byte
+ for i := 1; i <= len(parts); i++ {
+ assembled = append(assembled, parts[i]...)
+ }
+ m.objects[path] = assembled
+ delete(m.uploads, uploadID)
+ m.mu.Unlock()
+
+ w.Header().Set("Content-Type", "application/xml")
+ w.WriteHeader(http.StatusOK)
+	fmt.Fprintf(w, `<CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
+}
+
+func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
+ uploadID := r.URL.Query().Get("uploadId")
+ m.mu.Lock()
+ delete(m.uploads, uploadID)
+ m.mu.Unlock()
+ w.WriteHeader(http.StatusNoContent)
+}
+
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
m.mu.RLock()
body, ok := m.objects[path]
@@ -333,7 +418,7 @@ func TestClient_PutGetObject(t *testing.T) {
ctx := context.Background()
// Put
- err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"), 11)
+ err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"))
require.Nil(t, err)
// Get
@@ -353,7 +438,7 @@ func TestClient_PutGetObject_WithPrefix(t *testing.T) {
ctx := context.Background()
- err := client.PutObject(ctx, "test-key", strings.NewReader("hello"), 5)
+ err := client.PutObject(ctx, "test-key", strings.NewReader("hello"))
require.Nil(t, err)
reader, _, err := client.GetObject(ctx, "test-key")
@@ -385,7 +470,7 @@ func TestClient_DeleteObjects(t *testing.T) {
// Put several objects
for i := 0; i < 5; i++ {
- err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")), 4)
+ err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")))
require.Nil(t, err)
}
require.Equal(t, 5, mock.objectCount())
@@ -416,13 +501,13 @@ func TestClient_ListObjects(t *testing.T) {
// Client with prefix "pfx": list should only return objects under pfx/
client := newTestClient(server, "my-bucket", "pfx")
for i := 0; i < 3; i++ {
- err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")), 1)
+ err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
// Also put an object outside the prefix using a no-prefix client
clientNoPrefix := newTestClient(server, "my-bucket", "")
- err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")), 1)
+ err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")))
require.Nil(t, err)
// List with prefix client: should only see 3
@@ -446,7 +531,7 @@ func TestClient_ListObjects_Pagination(t *testing.T) {
// Put 5 objects
for i := 0; i < 5; i++ {
- err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 1)
+ err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
@@ -478,7 +563,7 @@ func TestClient_ListAllObjects(t *testing.T) {
ctx := context.Background()
for i := 0; i < 10; i++ {
- err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 1)
+ err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
@@ -499,7 +584,7 @@ func TestClient_PutObject_LargeBody(t *testing.T) {
for i := range data {
data[i] = byte(i % 256)
}
- err := client.PutObject(ctx, "large", bytes.NewReader(data), int64(len(data)))
+ err := client.PutObject(ctx, "large", bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "large")
@@ -511,6 +596,54 @@ func TestClient_PutObject_LargeBody(t *testing.T) {
require.Equal(t, data, got)
}
+func TestClient_PutObject_ChunkedUpload(t *testing.T) {
+ server, _ := newMockS3Server()
+ defer server.Close()
+ client := newTestClient(server, "my-bucket", "")
+
+ ctx := context.Background()
+
+ // 12 MB object, exceeds 5 MB partSize, triggers multipart upload path
+ data := make([]byte, 12*1024*1024)
+ for i := range data {
+ data[i] = byte(i % 256)
+ }
+ err := client.PutObject(ctx, "multipart", bytes.NewReader(data))
+ require.Nil(t, err)
+
+ reader, size, err := client.GetObject(ctx, "multipart")
+ require.Nil(t, err)
+ require.Equal(t, int64(12*1024*1024), size)
+ got, err := io.ReadAll(reader)
+ reader.Close()
+ require.Nil(t, err)
+ require.Equal(t, data, got)
+}
+
+func TestClient_PutObject_ExactPartSize(t *testing.T) {
+ server, _ := newMockS3Server()
+ defer server.Close()
+ client := newTestClient(server, "my-bucket", "")
+
+ ctx := context.Background()
+
+	// Exactly 5 MB (partSize): io.ReadFull fills the buffer with err == nil, so this takes the multipart path with a single part
+ data := make([]byte, 5*1024*1024)
+ for i := range data {
+ data[i] = byte(i % 256)
+ }
+ err := client.PutObject(ctx, "exact", bytes.NewReader(data))
+ require.Nil(t, err)
+
+ reader, size, err := client.GetObject(ctx, "exact")
+ require.Nil(t, err)
+ require.Equal(t, int64(5*1024*1024), size)
+ got, err := io.ReadAll(reader)
+ reader.Close()
+ require.Nil(t, err)
+ require.Equal(t, data, got)
+}
+
func TestClient_PutObject_NestedKey(t *testing.T) {
server, _ := newMockS3Server()
defer server.Close()
@@ -518,7 +651,7 @@ func TestClient_PutObject_NestedKey(t *testing.T) {
ctx := context.Background()
- err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"), 6)
+ err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"))
require.Nil(t, err)
reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt")
@@ -548,7 +681,7 @@ func TestClient_ListAllObjects_20k(t *testing.T) {
for i := 0; i < batchSize; i++ {
idx := batch*batchSize + i
key := fmt.Sprintf("%08d", idx)
- err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")), 1)
+ err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")))
require.Nil(t, err)
}
}
@@ -647,7 +780,7 @@ func TestClient_RealBucket(t *testing.T) {
content := "hello from ntfy s3 test"
// Put
- err := client.PutObject(ctx, key, strings.NewReader(content), int64(len(content)))
+ err := client.PutObject(ctx, key, strings.NewReader(content))
require.Nil(t, err)
// Get
@@ -685,7 +818,7 @@ func TestClient_RealBucket(t *testing.T) {
// Put 10 objects
for i := 0; i < 10; i++ {
- err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"), 1)
+ err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"))
require.Nil(t, err)
}
@@ -710,7 +843,7 @@ func TestClient_RealBucket(t *testing.T) {
data[i] = byte(i % 256)
}
- err := client.PutObject(ctx, key, bytes.NewReader(data), int64(len(data)))
+ err := client.PutObject(ctx, key, bytes.NewReader(data))
require.Nil(t, err)
reader, size, err := client.GetObject(ctx, key)
diff --git a/s3/types.go b/s3/types.go
index 5929ec6c..201c570b 100644
--- a/s3/types.go
+++ b/s3/types.go
@@ -63,3 +63,14 @@ type deleteError struct {
Code string `xml:"Code"`
Message string `xml:"Message"`
}
+
+// initiateMultipartUploadResult is the XML response from S3 InitiateMultipartUpload
+type initiateMultipartUploadResult struct {
+ UploadID string `xml:"UploadId"`
+}
+
+// completedPart represents a successfully uploaded part for CompleteMultipartUpload
+type completedPart struct {
+ PartNumber int
+ ETag string
+}
diff --git a/s3/util.go b/s3/util.go
index cf9d4ba8..c24c1c5b 100644
--- a/s3/util.go
+++ b/s3/util.go
@@ -22,6 +22,11 @@ const (
// maxResponseBytes caps the size of S3 response bodies we read into memory (10 MB)
maxResponseBytes = 10 * 1024 * 1024
+
+ // partSize is the size of each part for multipart uploads (5 MB). This is also the threshold
+ // above which PutObject switches from a simple PUT to multipart upload. S3 requires a minimum
+ // part size of 5 MB for all parts except the last.
+ partSize = 5 * 1024 * 1024
)
// ParseURL parses an S3 URL of the form:
diff --git a/tools/s3cli/main.go b/tools/s3cli/main.go
index 697d4e71..1dbac0cf 100644
--- a/tools/s3cli/main.go
+++ b/tools/s3cli/main.go
@@ -58,44 +58,21 @@ func cmdPut(ctx context.Context, client *s3.Client) {
path := os.Args[3]
var r io.Reader
- var size int64
if path == "-" {
- // Read stdin into a temp file to get the size
- tmp, err := os.CreateTemp("", "s3cli-*")
- if err != nil {
- fail("create temp file: %s", err)
- }
- defer os.Remove(tmp.Name())
- n, err := io.Copy(tmp, os.Stdin)
- if err != nil {
- tmp.Close()
- fail("read stdin: %s", err)
- }
- if _, err := tmp.Seek(0, io.SeekStart); err != nil {
- tmp.Close()
- fail("seek: %s", err)
- }
- r = tmp
- size = n
- defer tmp.Close()
+ r = os.Stdin
} else {
f, err := os.Open(path)
if err != nil {
fail("open %s: %s", path, err)
}
defer f.Close()
- info, err := f.Stat()
- if err != nil {
- fail("stat %s: %s", path, err)
- }
r = f
- size = info.Size()
}
- if err := client.PutObject(ctx, key, r, size); err != nil {
+ if err := client.PutObject(ctx, key, r); err != nil {
fail("put: %s", err)
}
- fmt.Fprintf(os.Stderr, "uploaded %s (%d bytes)\n", key, size)
+ fmt.Fprintf(os.Stderr, "uploaded %s\n", key)
}
func cmdGet(ctx context.Context, client *s3.Client) {