mirror of
https://github.com/binwiederhier/ntfy.git
synced 2026-03-18 21:30:44 +01:00
Multipart upload
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -9,6 +9,7 @@ server/site/
|
||||
tools/fbsend/fbsend
|
||||
tools/pgimport/pgimport
|
||||
tools/loadtest/loadtest
|
||||
tools/s3cli/s3cli
|
||||
playground/
|
||||
secrets/
|
||||
*.iml
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
@@ -51,33 +50,31 @@ func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int6
|
||||
}
|
||||
log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3")
|
||||
|
||||
// Write through limiters into a temp file. This avoids buffering the full attachment in
|
||||
// memory while still giving us the Content-Length that PutObject requires.
|
||||
// Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked
|
||||
// uploads, so no temp file or Content-Length is needed.
|
||||
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
|
||||
tmpFile, err := os.CreateTemp("", "ntfy-s3-upload-*")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("s3 store: failed to create temp file: %w", err)
|
||||
pr, pw := io.Pipe()
|
||||
lw := util.NewLimitWriter(pw, limiters...)
|
||||
var size int64
|
||||
var copyErr error
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
size, copyErr = io.Copy(lw, in)
|
||||
if copyErr != nil {
|
||||
pw.CloseWithError(copyErr)
|
||||
} else {
|
||||
pw.Close()
|
||||
}
|
||||
}()
|
||||
putErr := c.client.PutObject(context.Background(), id, pr)
|
||||
pr.Close()
|
||||
<-done
|
||||
if copyErr != nil {
|
||||
return 0, copyErr
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
defer os.Remove(tmpPath)
|
||||
limitWriter := util.NewLimitWriter(tmpFile, limiters...)
|
||||
size, err := io.Copy(limitWriter, in)
|
||||
if err != nil {
|
||||
tmpFile.Close()
|
||||
return 0, err
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Re-open the temp file for reading and stream it to S3
|
||||
f, err := os.Open(tmpPath)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
if err := c.client.PutObject(context.Background(), id, f, size); err != nil {
|
||||
return 0, err
|
||||
if putErr != nil {
|
||||
return 0, putErr
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent += size
|
||||
|
||||
@@ -167,27 +167,41 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string
|
||||
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
|
||||
|
||||
type mockS3Server struct {
|
||||
objects map[string][]byte // full key (bucket/key) -> body
|
||||
objects map[string][]byte // full key (bucket/key) -> body
|
||||
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
|
||||
nextID int // counter for generating upload IDs
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockS3Server() *httptest.Server {
|
||||
m := &mockS3Server{objects: make(map[string][]byte)}
|
||||
m := &mockS3Server{
|
||||
objects: make(map[string][]byte),
|
||||
uploads: make(map[string]map[int][]byte),
|
||||
}
|
||||
return httptest.NewTLSServer(m)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Path is /{bucket}[/{key...}]
|
||||
path := strings.TrimPrefix(r.URL.Path, "/")
|
||||
q := r.URL.Query()
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodPut && q.Has("partNumber"):
|
||||
m.handleUploadPart(w, r, path)
|
||||
case r.Method == http.MethodPut:
|
||||
m.handlePut(w, r, path)
|
||||
case r.Method == http.MethodGet && r.URL.Query().Get("list-type") == "2":
|
||||
case r.Method == http.MethodPost && q.Has("uploads"):
|
||||
m.handleInitiateMultipart(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("uploadId"):
|
||||
m.handleCompleteMultipart(w, r, path)
|
||||
case r.Method == http.MethodDelete && q.Has("uploadId"):
|
||||
m.handleAbortMultipart(w, r, path)
|
||||
case r.Method == http.MethodGet && q.Get("list-type") == "2":
|
||||
m.handleList(w, r, path)
|
||||
case r.Method == http.MethodGet:
|
||||
m.handleGet(w, r, path)
|
||||
case r.Method == http.MethodPost && r.URL.Query().Has("delete"):
|
||||
case r.Method == http.MethodPost && q.Has("delete"):
|
||||
m.handleDelete(w, r, path)
|
||||
default:
|
||||
http.Error(w, "not implemented", http.StatusNotImplemented)
|
||||
@@ -206,6 +220,77 @@ func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path st
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
m.mu.Lock()
|
||||
m.nextID++
|
||||
uploadID := fmt.Sprintf("upload-%d", m.nextID)
|
||||
m.uploads[uploadID] = make(map[int][]byte)
|
||||
m.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
var partNumber int
|
||||
fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
parts, ok := m.uploads[uploadID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
http.Error(w, "NoSuchUpload", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
parts[partNumber] = body
|
||||
m.mu.Unlock()
|
||||
|
||||
etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
|
||||
w.Header().Set("ETag", etag)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
|
||||
m.mu.Lock()
|
||||
parts, ok := m.uploads[uploadID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
http.Error(w, "NoSuchUpload", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Assemble parts in order
|
||||
var assembled []byte
|
||||
for i := 1; i <= len(parts); i++ {
|
||||
assembled = append(assembled, parts[i]...)
|
||||
}
|
||||
m.objects[path] = assembled
|
||||
delete(m.uploads, uploadID)
|
||||
m.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
m.mu.Lock()
|
||||
delete(m.uploads, uploadID)
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
|
||||
m.mu.RLock()
|
||||
body, ok := m.objects[path]
|
||||
|
||||
197
s3/client.go
197
s3/client.go
@@ -10,6 +10,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -50,25 +51,22 @@ func New(config *Config) *Client {
|
||||
}
|
||||
|
||||
// PutObject uploads body to the given key. The key is automatically prefixed with the client's
|
||||
// configured prefix. The body size must be known in advance. The payload is sent as
|
||||
// UNSIGNED-PAYLOAD, which is supported by all major S3-compatible providers over HTTPS.
|
||||
func (c *Client) PutObject(ctx context.Context, key string, body io.Reader, size int64) error {
|
||||
fullKey := c.objectKey(key)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body)
|
||||
// configured prefix. The body size does not need to be known in advance.
|
||||
//
|
||||
// If the entire body fits in a single part (5 MB), it is uploaded with a simple PUT request.
|
||||
// Otherwise, the body is uploaded using S3 multipart upload, reading one part at a time
|
||||
// into memory.
|
||||
func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error {
|
||||
first := make([]byte, partSize)
|
||||
n, err := io.ReadFull(body, first)
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF {
|
||||
return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n))
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: PutObject request: %w", err)
|
||||
return fmt.Errorf("s3: PutObject read: %w", err)
|
||||
}
|
||||
req.ContentLength = size
|
||||
c.signV4(req, unsignedPayload)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: PutObject: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return parseError(resp)
|
||||
}
|
||||
return nil
|
||||
combined := io.MultiReader(bytes.NewReader(first), body)
|
||||
return c.putObjectMultipart(ctx, key, combined)
|
||||
}
|
||||
|
||||
// GetObject downloads an object. The key is automatically prefixed with the client's configured
|
||||
@@ -216,6 +214,171 @@ func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) {
|
||||
return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages)
|
||||
}
|
||||
|
||||
// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD.
|
||||
func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error {
|
||||
fullKey := c.objectKey(key)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: PutObject request: %w", err)
|
||||
}
|
||||
req.ContentLength = size
|
||||
c.signV4(req, unsignedPayload)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: PutObject: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return parseError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize
|
||||
// chunks, uploading each as a separate part. This allows uploading without knowing the total
|
||||
// body size in advance.
|
||||
func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error {
|
||||
fullKey := c.objectKey(key)
|
||||
|
||||
// Step 1: Initiate multipart upload
|
||||
uploadID, err := c.initiateMultipartUpload(ctx, fullKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Step 2: Upload parts
|
||||
var parts []completedPart
|
||||
buf := make([]byte, partSize)
|
||||
partNumber := 1
|
||||
for {
|
||||
n, err := io.ReadFull(body, buf)
|
||||
if n > 0 {
|
||||
etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n])
|
||||
if uploadErr != nil {
|
||||
c.abortMultipartUpload(ctx, fullKey, uploadID)
|
||||
return uploadErr
|
||||
}
|
||||
parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag})
|
||||
partNumber++
|
||||
}
|
||||
if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
c.abortMultipartUpload(ctx, fullKey, uploadID)
|
||||
return fmt.Errorf("s3: PutObject read: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Complete multipart upload
|
||||
return c.completeMultipartUpload(ctx, fullKey, uploadID, parts)
|
||||
}
|
||||
|
||||
// initiateMultipartUpload starts a new multipart upload and returns the upload ID.
|
||||
func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) {
|
||||
reqURL := c.objectURL(fullKey) + "?uploads"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err)
|
||||
}
|
||||
req.ContentLength = 0
|
||||
c.signV4(req, emptyPayloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return "", parseError(resp)
|
||||
}
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err)
|
||||
}
|
||||
var result initiateMultipartUploadResult
|
||||
if err := xml.Unmarshal(respBody, &result); err != nil {
|
||||
return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err)
|
||||
}
|
||||
return result.UploadID, nil
|
||||
}
|
||||
|
||||
// uploadPart uploads a single part of a multipart upload and returns the ETag.
|
||||
func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) {
|
||||
reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: UploadPart request: %w", err)
|
||||
}
|
||||
req.ContentLength = int64(len(data))
|
||||
c.signV4(req, unsignedPayload)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("s3: UploadPart: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return "", parseError(resp)
|
||||
}
|
||||
etag := resp.Header.Get("ETag")
|
||||
return etag, nil
|
||||
}
|
||||
|
||||
// completeMultipartUpload finalizes a multipart upload with the given parts.
|
||||
func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error {
|
||||
var body bytes.Buffer
|
||||
body.WriteString("<CompleteMultipartUpload>")
|
||||
for _, p := range parts {
|
||||
fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag)
|
||||
}
|
||||
body.WriteString("</CompleteMultipartUpload>")
|
||||
bodyBytes := body.Bytes()
|
||||
payloadHash := sha256Hex(bodyBytes)
|
||||
|
||||
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err)
|
||||
}
|
||||
req.ContentLength = int64(len(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/xml")
|
||||
c.signV4(req, payloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: CompleteMultipartUpload: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
return parseError(resp)
|
||||
}
|
||||
// Read response body to check for errors (S3 can return 200 with an error body)
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err)
|
||||
}
|
||||
// Check if the response contains an error
|
||||
var errResp ErrorResponse
|
||||
if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
|
||||
errResp.StatusCode = resp.StatusCode
|
||||
return &errResp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
|
||||
func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) {
|
||||
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.signV4(req, emptyPayloadHash)
|
||||
resp, err := c.httpClient().Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256
|
||||
// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads.
|
||||
func (c *Client) signV4(req *http.Request, payloadHash string) {
|
||||
|
||||
@@ -23,27 +23,41 @@ import (
|
||||
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
|
||||
|
||||
type mockS3Server struct {
|
||||
objects map[string][]byte // full key (bucket/key) -> body
|
||||
objects map[string][]byte // full key (bucket/key) -> body
|
||||
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
|
||||
nextID int // counter for generating upload IDs
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockS3Server() (*httptest.Server, *mockS3Server) {
|
||||
m := &mockS3Server{objects: make(map[string][]byte)}
|
||||
m := &mockS3Server{
|
||||
objects: make(map[string][]byte),
|
||||
uploads: make(map[string]map[int][]byte),
|
||||
}
|
||||
return httptest.NewTLSServer(m), m
|
||||
}
|
||||
|
||||
func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Path is /{bucket}[/{key...}]
|
||||
path := strings.TrimPrefix(r.URL.Path, "/")
|
||||
q := r.URL.Query()
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodPut && q.Has("partNumber"):
|
||||
m.handleUploadPart(w, r, path)
|
||||
case r.Method == http.MethodPut:
|
||||
m.handlePut(w, r, path)
|
||||
case r.Method == http.MethodGet && r.URL.Query().Get("list-type") == "2":
|
||||
case r.Method == http.MethodPost && q.Has("uploads"):
|
||||
m.handleInitiateMultipart(w, r, path)
|
||||
case r.Method == http.MethodPost && q.Has("uploadId"):
|
||||
m.handleCompleteMultipart(w, r, path)
|
||||
case r.Method == http.MethodDelete && q.Has("uploadId"):
|
||||
m.handleAbortMultipart(w, r, path)
|
||||
case r.Method == http.MethodGet && q.Get("list-type") == "2":
|
||||
m.handleList(w, r, path)
|
||||
case r.Method == http.MethodGet:
|
||||
m.handleGet(w, r, path)
|
||||
case r.Method == http.MethodPost && r.URL.Query().Has("delete"):
|
||||
case r.Method == http.MethodPost && q.Has("delete"):
|
||||
m.handleDelete(w, r, path)
|
||||
default:
|
||||
http.Error(w, "not implemented", http.StatusNotImplemented)
|
||||
@@ -62,6 +76,77 @@ func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path st
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
m.mu.Lock()
|
||||
m.nextID++
|
||||
uploadID := fmt.Sprintf("upload-%d", m.nextID)
|
||||
m.uploads[uploadID] = make(map[int][]byte)
|
||||
m.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
var partNumber int
|
||||
fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber)
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
parts, ok := m.uploads[uploadID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
http.Error(w, "NoSuchUpload", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
parts[partNumber] = body
|
||||
m.mu.Unlock()
|
||||
|
||||
etag := fmt.Sprintf(`"etag-part-%d"`, partNumber)
|
||||
w.Header().Set("ETag", etag)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
|
||||
m.mu.Lock()
|
||||
parts, ok := m.uploads[uploadID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
http.Error(w, "NoSuchUpload", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Assemble parts in order
|
||||
var assembled []byte
|
||||
for i := 1; i <= len(parts); i++ {
|
||||
assembled = append(assembled, parts[i]...)
|
||||
}
|
||||
m.objects[path] = assembled
|
||||
delete(m.uploads, uploadID)
|
||||
m.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) {
|
||||
uploadID := r.URL.Query().Get("uploadId")
|
||||
m.mu.Lock()
|
||||
delete(m.uploads, uploadID)
|
||||
m.mu.Unlock()
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) {
|
||||
m.mu.RLock()
|
||||
body, ok := m.objects[path]
|
||||
@@ -333,7 +418,7 @@ func TestClient_PutGetObject(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Put
|
||||
err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"), 11)
|
||||
err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"))
|
||||
require.Nil(t, err)
|
||||
|
||||
// Get
|
||||
@@ -353,7 +438,7 @@ func TestClient_PutGetObject_WithPrefix(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err := client.PutObject(ctx, "test-key", strings.NewReader("hello"), 5)
|
||||
err := client.PutObject(ctx, "test-key", strings.NewReader("hello"))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, _, err := client.GetObject(ctx, "test-key")
|
||||
@@ -385,7 +470,7 @@ func TestClient_DeleteObjects(t *testing.T) {
|
||||
|
||||
// Put several objects
|
||||
for i := 0; i < 5; i++ {
|
||||
err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")), 4)
|
||||
err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Equal(t, 5, mock.objectCount())
|
||||
@@ -416,13 +501,13 @@ func TestClient_ListObjects(t *testing.T) {
|
||||
// Client with prefix "pfx": list should only return objects under pfx/
|
||||
client := newTestClient(server, "my-bucket", "pfx")
|
||||
for i := 0; i < 3; i++ {
|
||||
err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")), 1)
|
||||
err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
// Also put an object outside the prefix using a no-prefix client
|
||||
clientNoPrefix := newTestClient(server, "my-bucket", "")
|
||||
err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")), 1)
|
||||
err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")))
|
||||
require.Nil(t, err)
|
||||
|
||||
// List with prefix client: should only see 3
|
||||
@@ -446,7 +531,7 @@ func TestClient_ListObjects_Pagination(t *testing.T) {
|
||||
|
||||
// Put 5 objects
|
||||
for i := 0; i < 5; i++ {
|
||||
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 1)
|
||||
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
@@ -478,7 +563,7 @@ func TestClient_ListAllObjects(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 1)
|
||||
err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
@@ -499,7 +584,7 @@ func TestClient_PutObject_LargeBody(t *testing.T) {
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
err := client.PutObject(ctx, "large", bytes.NewReader(data), int64(len(data)))
|
||||
err := client.PutObject(ctx, "large", bytes.NewReader(data))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, size, err := client.GetObject(ctx, "large")
|
||||
@@ -511,6 +596,54 @@ func TestClient_PutObject_LargeBody(t *testing.T) {
|
||||
require.Equal(t, data, got)
|
||||
}
|
||||
|
||||
func TestClient_PutObject_ChunkedUpload(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 12 MB object, exceeds 5 MB partSize, triggers multipart upload path
|
||||
data := make([]byte, 12*1024*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
err := client.PutObject(ctx, "multipart", bytes.NewReader(data))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, size, err := client.GetObject(ctx, "multipart")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(12*1024*1024), size)
|
||||
got, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, data, got)
|
||||
}
|
||||
|
||||
func TestClient_PutObject_ExactPartSize(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
client := newTestClient(server, "my-bucket", "")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Exactly 5 MB (partSize), should use the simple put path (ReadFull succeeds fully)
|
||||
data := make([]byte, 5*1024*1024)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
err := client.PutObject(ctx, "exact", bytes.NewReader(data))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, size, err := client.GetObject(ctx, "exact")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(5*1024*1024), size)
|
||||
got, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, data, got)
|
||||
}
|
||||
|
||||
func TestClient_PutObject_NestedKey(t *testing.T) {
|
||||
server, _ := newMockS3Server()
|
||||
defer server.Close()
|
||||
@@ -518,7 +651,7 @@ func TestClient_PutObject_NestedKey(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"), 6)
|
||||
err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt")
|
||||
@@ -548,7 +681,7 @@ func TestClient_ListAllObjects_20k(t *testing.T) {
|
||||
for i := 0; i < batchSize; i++ {
|
||||
idx := batch*batchSize + i
|
||||
key := fmt.Sprintf("%08d", idx)
|
||||
err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")), 1)
|
||||
err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
}
|
||||
@@ -647,7 +780,7 @@ func TestClient_RealBucket(t *testing.T) {
|
||||
content := "hello from ntfy s3 test"
|
||||
|
||||
// Put
|
||||
err := client.PutObject(ctx, key, strings.NewReader(content), int64(len(content)))
|
||||
err := client.PutObject(ctx, key, strings.NewReader(content))
|
||||
require.Nil(t, err)
|
||||
|
||||
// Get
|
||||
@@ -685,7 +818,7 @@ func TestClient_RealBucket(t *testing.T) {
|
||||
|
||||
// Put 10 objects
|
||||
for i := 0; i < 10; i++ {
|
||||
err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"), 1)
|
||||
err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
@@ -710,7 +843,7 @@ func TestClient_RealBucket(t *testing.T) {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
err := client.PutObject(ctx, key, bytes.NewReader(data), int64(len(data)))
|
||||
err := client.PutObject(ctx, key, bytes.NewReader(data))
|
||||
require.Nil(t, err)
|
||||
|
||||
reader, size, err := client.GetObject(ctx, key)
|
||||
|
||||
11
s3/types.go
11
s3/types.go
@@ -63,3 +63,14 @@ type deleteError struct {
|
||||
Code string `xml:"Code"`
|
||||
Message string `xml:"Message"`
|
||||
}
|
||||
|
||||
// initiateMultipartUploadResult is the XML response from S3 InitiateMultipartUpload
|
||||
type initiateMultipartUploadResult struct {
|
||||
UploadID string `xml:"UploadId"`
|
||||
}
|
||||
|
||||
// completedPart represents a successfully uploaded part for CompleteMultipartUpload
|
||||
type completedPart struct {
|
||||
PartNumber int
|
||||
ETag string
|
||||
}
|
||||
|
||||
@@ -22,6 +22,11 @@ const (
|
||||
|
||||
// maxResponseBytes caps the size of S3 response bodies we read into memory (10 MB)
|
||||
maxResponseBytes = 10 * 1024 * 1024
|
||||
|
||||
// partSize is the size of each part for multipart uploads (5 MB). This is also the threshold
|
||||
// above which PutObject switches from a simple PUT to multipart upload. S3 requires a minimum
|
||||
// part size of 5 MB for all parts except the last.
|
||||
partSize = 5 * 1024 * 1024
|
||||
)
|
||||
|
||||
// ParseURL parses an S3 URL of the form:
|
||||
|
||||
@@ -58,44 +58,21 @@ func cmdPut(ctx context.Context, client *s3.Client) {
|
||||
path := os.Args[3]
|
||||
|
||||
var r io.Reader
|
||||
var size int64
|
||||
if path == "-" {
|
||||
// Read stdin into a temp file to get the size
|
||||
tmp, err := os.CreateTemp("", "s3cli-*")
|
||||
if err != nil {
|
||||
fail("create temp file: %s", err)
|
||||
}
|
||||
defer os.Remove(tmp.Name())
|
||||
n, err := io.Copy(tmp, os.Stdin)
|
||||
if err != nil {
|
||||
tmp.Close()
|
||||
fail("read stdin: %s", err)
|
||||
}
|
||||
if _, err := tmp.Seek(0, io.SeekStart); err != nil {
|
||||
tmp.Close()
|
||||
fail("seek: %s", err)
|
||||
}
|
||||
r = tmp
|
||||
size = n
|
||||
defer tmp.Close()
|
||||
r = os.Stdin
|
||||
} else {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
fail("open %s: %s", path, err)
|
||||
}
|
||||
defer f.Close()
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
fail("stat %s: %s", path, err)
|
||||
}
|
||||
r = f
|
||||
size = info.Size()
|
||||
}
|
||||
|
||||
if err := client.PutObject(ctx, key, r, size); err != nil {
|
||||
if err := client.PutObject(ctx, key, r); err != nil {
|
||||
fail("put: %s", err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "uploaded %s (%d bytes)\n", key, size)
|
||||
fmt.Fprintf(os.Stderr, "uploaded %s\n", key)
|
||||
}
|
||||
|
||||
func cmdGet(ctx context.Context, client *s3.Client) {
|
||||
|
||||
Reference in New Issue
Block a user