Multipart upload

This commit is contained in:
binwiederhier
2026-03-16 20:00:19 -04:00
parent 458fbad770
commit 86015e100c
8 changed files with 462 additions and 90 deletions

1
.gitignore vendored
View File

@@ -9,6 +9,7 @@ server/site/
tools/fbsend/fbsend tools/fbsend/fbsend
tools/pgimport/pgimport tools/pgimport/pgimport
tools/loadtest/loadtest tools/loadtest/loadtest
tools/s3cli/s3cli
playground/ playground/
secrets/ secrets/
*.iml *.iml

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"os"
"sync" "sync"
"heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/log"
@@ -51,33 +50,31 @@ func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int6
} }
log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3") log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3")
// Write through limiters into a temp file. This avoids buffering the full attachment in // Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked
// memory while still giving us the Content-Length that PutObject requires. // uploads, so no temp file or Content-Length is needed.
limiters = append(limiters, util.NewFixedLimiter(c.Remaining())) limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
tmpFile, err := os.CreateTemp("", "ntfy-s3-upload-*") pr, pw := io.Pipe()
if err != nil { lw := util.NewLimitWriter(pw, limiters...)
return 0, fmt.Errorf("s3 store: failed to create temp file: %w", err) var size int64
var copyErr error
done := make(chan struct{})
go func() {
defer close(done)
size, copyErr = io.Copy(lw, in)
if copyErr != nil {
pw.CloseWithError(copyErr)
} else {
pw.Close()
}
}()
putErr := c.client.PutObject(context.Background(), id, pr)
pr.Close()
<-done
if copyErr != nil {
return 0, copyErr
} }
tmpPath := tmpFile.Name() if putErr != nil {
defer os.Remove(tmpPath) return 0, putErr
limitWriter := util.NewLimitWriter(tmpFile, limiters...)
size, err := io.Copy(limitWriter, in)
if err != nil {
tmpFile.Close()
return 0, err
}
if err := tmpFile.Close(); err != nil {
return 0, err
}
// Re-open the temp file for reading and stream it to S3
f, err := os.Open(tmpPath)
if err != nil {
return 0, err
}
defer f.Close()
if err := c.client.PutObject(context.Background(), id, f, size); err != nil {
return 0, err
} }
c.mu.Lock() c.mu.Lock()
c.totalSizeCurrent += size c.totalSizeCurrent += size

View File

@@ -167,27 +167,41 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. // ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
type mockS3Server struct { type mockS3Server struct {
objects map[string][]byte // full key (bucket/key) -> body objects map[string][]byte // full key (bucket/key) -> body
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
nextID int // counter for generating upload IDs
mu sync.RWMutex mu sync.RWMutex
} }
func newMockS3Server() *httptest.Server { func newMockS3Server() *httptest.Server {
m := &mockS3Server{objects: make(map[string][]byte)} m := &mockS3Server{
objects: make(map[string][]byte),
uploads: make(map[string]map[int][]byte),
}
return httptest.NewTLSServer(m) return httptest.NewTLSServer(m)
} }
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Path is /{bucket}[/{key...}] // Path is /{bucket}[/{key...}]
path := strings.TrimPrefix(r.URL.Path, "/") path := strings.TrimPrefix(r.URL.Path, "/")
q := r.URL.Query()
switch { switch {
case r.Method == http.MethodPut && q.Has("partNumber"):
m.handleUploadPart(w, r, path)
case r.Method == http.MethodPut: case r.Method == http.MethodPut:
m.handlePut(w, r, path) m.handlePut(w, r, path)
case r.Method == http.MethodGet && r.URL.Query().Get("list-type") == "2": case r.Method == http.MethodPost && q.Has("uploads"):
m.handleInitiateMultipart(w, r, path)
case r.Method == http.MethodPost && q.Has("uploadId"):
m.handleCompleteMultipart(w, r, path)
case r.Method == http.MethodDelete && q.Has("uploadId"):
m.handleAbortMultipart(w, r, path)
case r.Method == http.MethodGet && q.Get("list-type") == "2":
m.handleList(w, r, path) m.handleList(w, r, path)
case r.Method == http.MethodGet: case r.Method == http.MethodGet:
m.handleGet(w, r, path) m.handleGet(w, r, path)
case r.Method == http.MethodPost && r.URL.Query().Has("delete"): case r.Method == http.MethodPost && q.Has("delete"):
m.handleDelete(w, r, path) m.handleDelete(w, r, path)
default: default:
http.Error(w, "not implemented", http.StatusNotImplemented) http.Error(w, "not implemented", http.StatusNotImplemented)
@@ -206,6 +220,77 @@ func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path st
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
// handleInitiateMultipart answers POST ?uploads by allocating a fresh upload ID and
// returning it in an InitiateMultipartUploadResult XML document.
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
	m.mu.Lock()
	m.nextID++
	id := fmt.Sprintf("upload-%d", m.nextID)
	m.uploads[id] = map[int][]byte{}
	m.mu.Unlock()
	w.Header().Set("Content-Type", "application/xml")
	w.WriteHeader(http.StatusOK)
	fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, id)
}
// handleUploadPart answers PUT ?partNumber=N&uploadId=ID by storing the request body
// as part N of the given upload, and replies with a deterministic fake ETag header.
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
	uploadID := r.URL.Query().Get("uploadId")
	// S3 part numbers start at 1. Reject unparseable or non-positive values: ignoring
	// the Sscanf error would silently store such parts under number 0, which the
	// complete handler never assembles, corrupting the resulting object.
	var partNumber int
	if _, err := fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber); err != nil || partNumber < 1 {
		http.Error(w, "InvalidArgument", http.StatusBadRequest)
		return
	}
	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	m.mu.Lock()
	parts, ok := m.uploads[uploadID]
	if !ok {
		m.mu.Unlock()
		http.Error(w, "NoSuchUpload", http.StatusNotFound)
		return
	}
	parts[partNumber] = body
	m.mu.Unlock()
	etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
	w.Header().Set("ETag", etag)
	w.WriteHeader(http.StatusOK)
}
// handleCompleteMultipart answers POST ?uploadId=ID by concatenating the uploaded
// parts in part-number order into a regular object and discarding the upload state.
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
	uploadID := r.URL.Query().Get("uploadId")
	m.mu.Lock()
	parts, ok := m.uploads[uploadID]
	if !ok {
		m.mu.Unlock()
		http.Error(w, "NoSuchUpload", http.StatusNotFound)
		return
	}
	// Assemble parts in ascending order. Part numbers must form the contiguous range
	// 1..len(parts); previously a gap was silently skipped, which would have produced
	// a truncated object instead of an error.
	var assembled []byte
	for i := 1; i <= len(parts); i++ {
		data, found := parts[i]
		if !found {
			m.mu.Unlock()
			http.Error(w, "InvalidPart", http.StatusBadRequest)
			return
		}
		assembled = append(assembled, data...)
	}
	m.objects[path] = assembled
	delete(m.uploads, uploadID)
	m.mu.Unlock()
	w.Header().Set("Content-Type", "application/xml")
	w.WriteHeader(http.StatusOK)
	fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
}
// handleAbortMultipart answers DELETE ?uploadId=ID by discarding any buffered parts
// for that upload. Aborting an unknown upload is a no-op, mirroring S3's idempotency.
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
	id := r.URL.Query().Get("uploadId")
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.uploads, id)
	w.WriteHeader(http.StatusNoContent)
}
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) { func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
m.mu.RLock() m.mu.RLock()
body, ok := m.objects[path] body, ok := m.objects[path]

View File

@@ -10,6 +10,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/xml" "encoding/xml"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -50,25 +51,22 @@ func New(config *Config) *Client {
} }
// PutObject uploads body to the given key. The key is automatically prefixed with the client's // PutObject uploads body to the given key. The key is automatically prefixed with the client's
// configured prefix. The body size must be known in advance. The payload is sent as // configured prefix. The body size does not need to be known in advance.
// UNSIGNED-PAYLOAD, which is supported by all major S3-compatible providers over HTTPS. //
func (c *Client) PutObject(ctx context.Context, key string, body io.Reader, size int64) error { // If the entire body fits in a single part (5 MB), it is uploaded with a simple PUT request.
fullKey := c.objectKey(key) // Otherwise, the body is uploaded using S3 multipart upload, reading one part at a time
req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body) // into memory.
func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error {
first := make([]byte, partSize)
n, err := io.ReadFull(body, first)
if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF {
return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n))
}
if err != nil { if err != nil {
return fmt.Errorf("s3: PutObject request: %w", err) return fmt.Errorf("s3: PutObject read: %w", err)
} }
req.ContentLength = size combined := io.MultiReader(bytes.NewReader(first), body)
c.signV4(req, unsignedPayload) return c.putObjectMultipart(ctx, key, combined)
resp, err := c.httpClient().Do(req)
if err != nil {
return fmt.Errorf("s3: PutObject: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
return parseError(resp)
}
return nil
} }
// GetObject downloads an object. The key is automatically prefixed with the client's configured // GetObject downloads an object. The key is automatically prefixed with the client's configured
@@ -216,6 +214,171 @@ func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) {
return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages) return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages)
} }
// putObject uploads a body of known size in a single request, using a plain PUT
// signed with UNSIGNED-PAYLOAD so the body does not need to be hashed up front.
func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error {
	objKey := c.objectKey(key)
	req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(objKey), body)
	if err != nil {
		return fmt.Errorf("s3: PutObject request: %w", err)
	}
	req.ContentLength = size
	c.signV4(req, unsignedPayload)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return fmt.Errorf("s3: PutObject: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode/100 == 2 {
		return nil
	}
	return parseError(resp)
}
// putObjectMultipart uploads body using S3 multipart upload, buffering at most one
// partSize chunk in memory at a time. This allows uploading without knowing the total
// body size in advance. On any part-upload or read failure, the in-progress upload is
// aborted (best-effort) before the error is returned.
func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error {
	fullKey := c.objectKey(key)
	// Step 1: Initiate the upload and obtain the upload ID used by all part requests.
	uploadID, err := c.initiateMultipartUpload(ctx, fullKey)
	if err != nil {
		return err
	}
	// Step 2: Read and upload one partSize chunk per iteration until the body runs dry.
	var completed []completedPart
	chunk := make([]byte, partSize)
	for num := 1; ; num++ {
		n, readErr := io.ReadFull(body, chunk)
		if n > 0 {
			etag, upErr := c.uploadPart(ctx, fullKey, uploadID, num, chunk[:n])
			if upErr != nil {
				c.abortMultipartUpload(ctx, fullKey, uploadID)
				return upErr
			}
			completed = append(completed, completedPart{PartNumber: num, ETag: etag})
		}
		// io.ReadFull reports EOF (nothing read) or ErrUnexpectedEOF (short final
		// chunk) when the body is exhausted; any other error is a real read failure.
		if readErr == io.EOF || errors.Is(readErr, io.ErrUnexpectedEOF) {
			break
		}
		if readErr != nil {
			c.abortMultipartUpload(ctx, fullKey, uploadID)
			return fmt.Errorf("s3: PutObject read: %w", readErr)
		}
	}
	// Step 3: Stitch the parts together server-side.
	return c.completeMultipartUpload(ctx, fullKey, uploadID, completed)
}
// initiateMultipartUpload starts a new multipart upload for fullKey and returns the
// server-assigned upload ID that all subsequent part requests must reference.
func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) {
	reqURL := c.objectURL(fullKey) + "?uploads"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil)
	if err != nil {
		return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err)
	}
	req.ContentLength = 0
	c.signV4(req, emptyPayloadHash)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode/100 != 2 {
		return "", parseError(resp)
	}
	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
	if err != nil {
		return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err)
	}
	var result initiateMultipartUploadResult
	if err := xml.Unmarshal(respBody, &result); err != nil {
		return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err)
	}
	// A 2xx response with an empty UploadId would make every subsequent UploadPart fail
	// with a confusing NoSuchUpload error; fail fast here with a clear message instead.
	if result.UploadID == "" {
		return "", fmt.Errorf("s3: InitiateMultipartUpload: response contained no upload ID")
	}
	return result.UploadID, nil
}
// uploadPart uploads a single part of a multipart upload and returns the ETag the
// server assigned to it. The ETag is required later by CompleteMultipartUpload.
func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) {
	reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID))
	req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data))
	if err != nil {
		return "", fmt.Errorf("s3: UploadPart request: %w", err)
	}
	req.ContentLength = int64(len(data))
	c.signV4(req, unsignedPayload)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return "", fmt.Errorf("s3: UploadPart: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode/100 != 2 {
		return "", parseError(resp)
	}
	// CompleteMultipartUpload cannot succeed without the part's ETag, so a missing
	// header (e.g. stripped by a proxy) is an error here rather than a cryptic
	// failure at completion time.
	etag := resp.Header.Get("ETag")
	if etag == "" {
		return "", fmt.Errorf("s3: UploadPart: missing ETag in response for part %d", partNumber)
	}
	return etag, nil
}
// completeMultipartUpload finalizes a multipart upload with the given parts.
//
// The request body lists each part number alongside the ETag returned by uploadPart.
// Unlike the streaming uploads, this body is small and fully known, so it is signed
// with its real SHA-256 rather than UNSIGNED-PAYLOAD.
//
// NOTE(review): p.ETag is interpolated into the XML without escaping; this assumes
// ETags never contain XML metacharacters (they are quoted hex in practice) — confirm
// if targeting exotic providers.
func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error {
	// Build the <CompleteMultipartUpload> XML payload by hand; the structure is trivial.
	var body bytes.Buffer
	body.WriteString("<CompleteMultipartUpload>")
	for _, p := range parts {
		fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag)
	}
	body.WriteString("</CompleteMultipartUpload>")
	bodyBytes := body.Bytes()
	payloadHash := sha256Hex(bodyBytes)
	reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
	if err != nil {
		return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err)
	}
	req.ContentLength = int64(len(bodyBytes))
	req.Header.Set("Content-Type", "application/xml")
	c.signV4(req, payloadHash)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return fmt.Errorf("s3: CompleteMultipartUpload: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode/100 != 2 {
		return parseError(resp)
	}
	// Read response body to check for errors (S3 can return 200 with an error body)
	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
	if err != nil {
		return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err)
	}
	// Check if the response contains an error: a parseable <Error> element with a
	// non-empty Code means the completion failed despite the 2xx status.
	var errResp ErrorResponse
	if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
		errResp.StatusCode = resp.StatusCode
		return &errResp
	}
	return nil
}
// abortMultipartUpload cancels an in-progress multipart upload so the provider can
// discard buffered parts. It is called only on error paths and is strictly
// best-effort: all failures are silently ignored, since the original error is the
// one the caller needs to see.
func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) {
	reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
	req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
	if err != nil {
		return
	}
	c.signV4(req, emptyPayloadHash)
	resp, err := c.httpClient().Do(req)
	if err != nil {
		return
	}
	// Drain the body before closing so the HTTP transport can reuse the connection.
	io.Copy(io.Discard, resp.Body) //nolint:errcheck // best-effort cleanup
	resp.Body.Close()
}
// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256 // signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256
// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads. // of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads.
func (c *Client) signV4(req *http.Request, payloadHash string) { func (c *Client) signV4(req *http.Request, payloadHash string) {

View File

@@ -23,27 +23,41 @@ import (
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. // ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
type mockS3Server struct { type mockS3Server struct {
objects map[string][]byte // full key (bucket/key) -> body objects map[string][]byte // full key (bucket/key) -> body
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
nextID int // counter for generating upload IDs
mu sync.RWMutex mu sync.RWMutex
} }
func newMockS3Server() (*httptest.Server, *mockS3Server) { func newMockS3Server() (*httptest.Server, *mockS3Server) {
m := &mockS3Server{objects: make(map[string][]byte)} m := &mockS3Server{
objects: make(map[string][]byte),
uploads: make(map[string]map[int][]byte),
}
return httptest.NewTLSServer(m), m return httptest.NewTLSServer(m), m
} }
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Path is /{bucket}[/{key...}] // Path is /{bucket}[/{key...}]
path := strings.TrimPrefix(r.URL.Path, "/") path := strings.TrimPrefix(r.URL.Path, "/")
q := r.URL.Query()
switch { switch {
case r.Method == http.MethodPut && q.Has("partNumber"):
m.handleUploadPart(w, r, path)
case r.Method == http.MethodPut: case r.Method == http.MethodPut:
m.handlePut(w, r, path) m.handlePut(w, r, path)
case r.Method == http.MethodGet && r.URL.Query().Get("list-type") == "2": case r.Method == http.MethodPost && q.Has("uploads"):
m.handleInitiateMultipart(w, r, path)
case r.Method == http.MethodPost && q.Has("uploadId"):
m.handleCompleteMultipart(w, r, path)
case r.Method == http.MethodDelete && q.Has("uploadId"):
m.handleAbortMultipart(w, r, path)
case r.Method == http.MethodGet && q.Get("list-type") == "2":
m.handleList(w, r, path) m.handleList(w, r, path)
case r.Method == http.MethodGet: case r.Method == http.MethodGet:
m.handleGet(w, r, path) m.handleGet(w, r, path)
case r.Method == http.MethodPost && r.URL.Query().Has("delete"): case r.Method == http.MethodPost && q.Has("delete"):
m.handleDelete(w, r, path) m.handleDelete(w, r, path)
default: default:
http.Error(w, "not implemented", http.StatusNotImplemented) http.Error(w, "not implemented", http.StatusNotImplemented)
@@ -62,6 +76,77 @@ func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path st
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
// handleInitiateMultipart answers POST ?uploads by allocating a fresh upload ID and
// returning it in an InitiateMultipartUploadResult XML document.
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
	m.mu.Lock()
	m.nextID++
	id := fmt.Sprintf("upload-%d", m.nextID)
	m.uploads[id] = map[int][]byte{}
	m.mu.Unlock()
	w.Header().Set("Content-Type", "application/xml")
	w.WriteHeader(http.StatusOK)
	fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, id)
}
// handleUploadPart answers PUT ?partNumber=N&uploadId=ID by storing the request body
// as part N of the given upload, and replies with a deterministic fake ETag header.
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
	uploadID := r.URL.Query().Get("uploadId")
	// S3 part numbers start at 1. Reject unparseable or non-positive values: ignoring
	// the Sscanf error would silently store such parts under number 0, which the
	// complete handler never assembles, corrupting the resulting object.
	var partNumber int
	if _, err := fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber); err != nil || partNumber < 1 {
		http.Error(w, "InvalidArgument", http.StatusBadRequest)
		return
	}
	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	m.mu.Lock()
	parts, ok := m.uploads[uploadID]
	if !ok {
		m.mu.Unlock()
		http.Error(w, "NoSuchUpload", http.StatusNotFound)
		return
	}
	parts[partNumber] = body
	m.mu.Unlock()
	etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
	w.Header().Set("ETag", etag)
	w.WriteHeader(http.StatusOK)
}
// handleCompleteMultipart answers POST ?uploadId=ID by concatenating the uploaded
// parts in part-number order into a regular object and discarding the upload state.
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
	uploadID := r.URL.Query().Get("uploadId")
	m.mu.Lock()
	parts, ok := m.uploads[uploadID]
	if !ok {
		m.mu.Unlock()
		http.Error(w, "NoSuchUpload", http.StatusNotFound)
		return
	}
	// Assemble parts in ascending order. Part numbers must form the contiguous range
	// 1..len(parts); previously a gap was silently skipped, which would have produced
	// a truncated object instead of an error.
	var assembled []byte
	for i := 1; i <= len(parts); i++ {
		data, found := parts[i]
		if !found {
			m.mu.Unlock()
			http.Error(w, "InvalidPart", http.StatusBadRequest)
			return
		}
		assembled = append(assembled, data...)
	}
	m.objects[path] = assembled
	delete(m.uploads, uploadID)
	m.mu.Unlock()
	w.Header().Set("Content-Type", "application/xml")
	w.WriteHeader(http.StatusOK)
	fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
}
// handleAbortMultipart answers DELETE ?uploadId=ID by discarding any buffered parts
// for that upload. Aborting an unknown upload is a no-op, mirroring S3's idempotency.
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
	id := r.URL.Query().Get("uploadId")
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.uploads, id)
	w.WriteHeader(http.StatusNoContent)
}
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) { func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
m.mu.RLock() m.mu.RLock()
body, ok := m.objects[path] body, ok := m.objects[path]
@@ -333,7 +418,7 @@ func TestClient_PutGetObject(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// Put // Put
err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"), 11) err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"))
require.Nil(t, err) require.Nil(t, err)
// Get // Get
@@ -353,7 +438,7 @@ func TestClient_PutGetObject_WithPrefix(t *testing.T) {
ctx := context.Background() ctx := context.Background()
err := client.PutObject(ctx, "test-key", strings.NewReader("hello"), 5) err := client.PutObject(ctx, "test-key", strings.NewReader("hello"))
require.Nil(t, err) require.Nil(t, err)
reader, _, err := client.GetObject(ctx, "test-key") reader, _, err := client.GetObject(ctx, "test-key")
@@ -385,7 +470,7 @@ func TestClient_DeleteObjects(t *testing.T) {
// Put several objects // Put several objects
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")), 4) err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")))
require.Nil(t, err) require.Nil(t, err)
} }
require.Equal(t, 5, mock.objectCount()) require.Equal(t, 5, mock.objectCount())
@@ -416,13 +501,13 @@ func TestClient_ListObjects(t *testing.T) {
// Client with prefix "pfx": list should only return objects under pfx/ // Client with prefix "pfx": list should only return objects under pfx/
client := newTestClient(server, "my-bucket", "pfx") client := newTestClient(server, "my-bucket", "pfx")
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")), 1) err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err) require.Nil(t, err)
} }
// Also put an object outside the prefix using a no-prefix client // Also put an object outside the prefix using a no-prefix client
clientNoPrefix := newTestClient(server, "my-bucket", "") clientNoPrefix := newTestClient(server, "my-bucket", "")
err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")), 1) err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")))
require.Nil(t, err) require.Nil(t, err)
// List with prefix client: should only see 3 // List with prefix client: should only see 3
@@ -446,7 +531,7 @@ func TestClient_ListObjects_Pagination(t *testing.T) {
// Put 5 objects // Put 5 objects
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 1) err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err) require.Nil(t, err)
} }
@@ -478,7 +563,7 @@ func TestClient_ListAllObjects(t *testing.T) {
ctx := context.Background() ctx := context.Background()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 1) err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
require.Nil(t, err) require.Nil(t, err)
} }
@@ -499,7 +584,7 @@ func TestClient_PutObject_LargeBody(t *testing.T) {
for i := range data { for i := range data {
data[i] = byte(i % 256) data[i] = byte(i % 256)
} }
err := client.PutObject(ctx, "large", bytes.NewReader(data), int64(len(data))) err := client.PutObject(ctx, "large", bytes.NewReader(data))
require.Nil(t, err) require.Nil(t, err)
reader, size, err := client.GetObject(ctx, "large") reader, size, err := client.GetObject(ctx, "large")
@@ -511,6 +596,54 @@ func TestClient_PutObject_LargeBody(t *testing.T) {
require.Equal(t, data, got) require.Equal(t, data, got)
} }
// TestClient_PutObject_ChunkedUpload verifies that a body larger than partSize is
// uploaded via the multipart path and can be read back byte-for-byte.
func TestClient_PutObject_ChunkedUpload(t *testing.T) {
	server, _ := newMockS3Server()
	defer server.Close()
	client := newTestClient(server, "my-bucket", "")
	ctx := context.Background()
	// 12 MB payload: larger than the 5 MB partSize, so PutObject must go multipart.
	const total = 12 * 1024 * 1024
	payload := make([]byte, total)
	for i := range payload {
		payload[i] = byte(i % 256)
	}
	require.Nil(t, client.PutObject(ctx, "multipart", bytes.NewReader(payload)))
	reader, size, err := client.GetObject(ctx, "multipart")
	require.Nil(t, err)
	require.Equal(t, int64(total), size)
	roundTrip, err := io.ReadAll(reader)
	reader.Close()
	require.Nil(t, err)
	require.Equal(t, payload, roundTrip)
}
// TestClient_PutObject_ExactPartSize uploads a body of exactly partSize bytes and
// verifies it round-trips intact.
func TestClient_PutObject_ExactPartSize(t *testing.T) {
	server, _ := newMockS3Server()
	defer server.Close()
	client := newTestClient(server, "my-bucket", "")
	ctx := context.Background()
	// Exactly 5 MB (partSize). Note: the first io.ReadFull in PutObject fills its
	// buffer completely and returns a nil error, so this actually takes the MULTIPART
	// path with a single 5 MB part — only bodies strictly smaller than partSize
	// short-circuit to the simple PUT path.
	data := make([]byte, 5*1024*1024)
	for i := range data {
		data[i] = byte(i % 256)
	}
	err := client.PutObject(ctx, "exact", bytes.NewReader(data))
	require.Nil(t, err)
	reader, size, err := client.GetObject(ctx, "exact")
	require.Nil(t, err)
	require.Equal(t, int64(5*1024*1024), size)
	got, err := io.ReadAll(reader)
	reader.Close()
	require.Nil(t, err)
	require.Equal(t, data, got)
}
func TestClient_PutObject_NestedKey(t *testing.T) { func TestClient_PutObject_NestedKey(t *testing.T) {
server, _ := newMockS3Server() server, _ := newMockS3Server()
defer server.Close() defer server.Close()
@@ -518,7 +651,7 @@ func TestClient_PutObject_NestedKey(t *testing.T) {
ctx := context.Background() ctx := context.Background()
err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"), 6) err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"))
require.Nil(t, err) require.Nil(t, err)
reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt") reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt")
@@ -548,7 +681,7 @@ func TestClient_ListAllObjects_20k(t *testing.T) {
for i := 0; i < batchSize; i++ { for i := 0; i < batchSize; i++ {
idx := batch*batchSize + i idx := batch*batchSize + i
key := fmt.Sprintf("%08d", idx) key := fmt.Sprintf("%08d", idx)
err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")), 1) err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")))
require.Nil(t, err) require.Nil(t, err)
} }
} }
@@ -647,7 +780,7 @@ func TestClient_RealBucket(t *testing.T) {
content := "hello from ntfy s3 test" content := "hello from ntfy s3 test"
// Put // Put
err := client.PutObject(ctx, key, strings.NewReader(content), int64(len(content))) err := client.PutObject(ctx, key, strings.NewReader(content))
require.Nil(t, err) require.Nil(t, err)
// Get // Get
@@ -685,7 +818,7 @@ func TestClient_RealBucket(t *testing.T) {
// Put 10 objects // Put 10 objects
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"), 1) err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"))
require.Nil(t, err) require.Nil(t, err)
} }
@@ -710,7 +843,7 @@ func TestClient_RealBucket(t *testing.T) {
data[i] = byte(i % 256) data[i] = byte(i % 256)
} }
err := client.PutObject(ctx, key, bytes.NewReader(data), int64(len(data))) err := client.PutObject(ctx, key, bytes.NewReader(data))
require.Nil(t, err) require.Nil(t, err)
reader, size, err := client.GetObject(ctx, key) reader, size, err := client.GetObject(ctx, key)

View File

@@ -63,3 +63,14 @@ type deleteError struct {
Code string `xml:"Code"` Code string `xml:"Code"`
Message string `xml:"Message"` Message string `xml:"Message"`
} }
// initiateMultipartUploadResult is the XML response from S3 InitiateMultipartUpload.
type initiateMultipartUploadResult struct {
	UploadID string `xml:"UploadId"` // server-assigned ID referenced by all subsequent part requests
}

// completedPart represents a successfully uploaded part for CompleteMultipartUpload.
type completedPart struct {
	PartNumber int    // 1-based part number, as sent in the UploadPart request
	ETag       string // ETag header returned by UploadPart (typically includes surrounding quotes)
}

View File

@@ -22,6 +22,11 @@ const (
// maxResponseBytes caps the size of S3 response bodies we read into memory (10 MB) // maxResponseBytes caps the size of S3 response bodies we read into memory (10 MB)
maxResponseBytes = 10 * 1024 * 1024 maxResponseBytes = 10 * 1024 * 1024
// partSize is the size of each part for multipart uploads (5 MB). This is also the threshold
// above which PutObject switches from a simple PUT to multipart upload. S3 requires a minimum
// part size of 5 MB for all parts except the last.
partSize = 5 * 1024 * 1024
) )
// ParseURL parses an S3 URL of the form: // ParseURL parses an S3 URL of the form:

View File

@@ -58,44 +58,21 @@ func cmdPut(ctx context.Context, client *s3.Client) {
path := os.Args[3] path := os.Args[3]
var r io.Reader var r io.Reader
var size int64
if path == "-" { if path == "-" {
// Read stdin into a temp file to get the size r = os.Stdin
tmp, err := os.CreateTemp("", "s3cli-*")
if err != nil {
fail("create temp file: %s", err)
}
defer os.Remove(tmp.Name())
n, err := io.Copy(tmp, os.Stdin)
if err != nil {
tmp.Close()
fail("read stdin: %s", err)
}
if _, err := tmp.Seek(0, io.SeekStart); err != nil {
tmp.Close()
fail("seek: %s", err)
}
r = tmp
size = n
defer tmp.Close()
} else { } else {
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
fail("open %s: %s", path, err) fail("open %s: %s", path, err)
} }
defer f.Close() defer f.Close()
info, err := f.Stat()
if err != nil {
fail("stat %s: %s", path, err)
}
r = f r = f
size = info.Size()
} }
if err := client.PutObject(ctx, key, r, size); err != nil { if err := client.PutObject(ctx, key, r); err != nil {
fail("put: %s", err) fail("put: %s", err)
} }
fmt.Fprintf(os.Stderr, "uploaded %s (%d bytes)\n", key, size) fmt.Fprintf(os.Stderr, "uploaded %s\n", key)
} }
func cmdGet(ctx context.Context, client *s3.Client) { func cmdGet(ctx context.Context, client *s3.Client) {