diff --git a/.gitignore b/.gitignore index ed17b2d4..6d5deb67 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ server/site/ tools/fbsend/fbsend tools/pgimport/pgimport tools/loadtest/loadtest +tools/s3cli/s3cli playground/ secrets/ *.iml diff --git a/attachment/store_s3.go b/attachment/store_s3.go index 5c47a81b..38f0353a 100644 --- a/attachment/store_s3.go +++ b/attachment/store_s3.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "os" "sync" "heckel.io/ntfy/v2/log" @@ -51,33 +50,31 @@ func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int6 } log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3") - // Write through limiters into a temp file. This avoids buffering the full attachment in - // memory while still giving us the Content-Length that PutObject requires. + // Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked + // uploads, so no temp file or Content-Length is needed. limiters = append(limiters, util.NewFixedLimiter(c.Remaining())) - tmpFile, err := os.CreateTemp("", "ntfy-s3-upload-*") - if err != nil { - return 0, fmt.Errorf("s3 store: failed to create temp file: %w", err) + pr, pw := io.Pipe() + lw := util.NewLimitWriter(pw, limiters...) + var size int64 + var copyErr error + done := make(chan struct{}) + go func() { + defer close(done) + size, copyErr = io.Copy(lw, in) + if copyErr != nil { + pw.CloseWithError(copyErr) + } else { + pw.Close() + } + }() + putErr := c.client.PutObject(context.Background(), id, pr) + pr.Close() + <-done + if copyErr != nil { + return 0, copyErr } - tmpPath := tmpFile.Name() - defer os.Remove(tmpPath) - limitWriter := util.NewLimitWriter(tmpFile, limiters...) 
- size, err := io.Copy(limitWriter, in) - if err != nil { - tmpFile.Close() - return 0, err - } - if err := tmpFile.Close(); err != nil { - return 0, err - } - - // Re-open the temp file for reading and stream it to S3 - f, err := os.Open(tmpPath) - if err != nil { - return 0, err - } - defer f.Close() - if err := c.client.PutObject(context.Background(), id, f, size); err != nil { - return 0, err + if putErr != nil { + return 0, putErr } c.mu.Lock() c.totalSizeCurrent += size diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go index c898244d..872e8c23 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -167,27 +167,41 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string // ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. type mockS3Server struct { - objects map[string][]byte // full key (bucket/key) -> body + objects map[string][]byte // full key (bucket/key) -> body + uploads map[string]map[int][]byte // uploadID -> partNumber -> data + nextID int // counter for generating upload IDs mu sync.RWMutex } func newMockS3Server() *httptest.Server { - m := &mockS3Server{objects: make(map[string][]byte)} + m := &mockS3Server{ + objects: make(map[string][]byte), + uploads: make(map[string]map[int][]byte), + } return httptest.NewTLSServer(m) } func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Path is /{bucket}[/{key...}] path := strings.TrimPrefix(r.URL.Path, "/") + q := r.URL.Query() switch { + case r.Method == http.MethodPut && q.Has("partNumber"): + m.handleUploadPart(w, r, path) case r.Method == http.MethodPut: m.handlePut(w, r, path) - case r.Method == http.MethodGet && r.URL.Query().Get("list-type") == "2": + case r.Method == http.MethodPost && q.Has("uploads"): + m.handleInitiateMultipart(w, r, path) + case r.Method == http.MethodPost && q.Has("uploadId"): + m.handleCompleteMultipart(w, r, path) + case r.Method == 
http.MethodDelete && q.Has("uploadId"): + m.handleAbortMultipart(w, r, path) + case r.Method == http.MethodGet && q.Get("list-type") == "2": m.handleList(w, r, path) case r.Method == http.MethodGet: m.handleGet(w, r, path) - case r.Method == http.MethodPost && r.URL.Query().Has("delete"): + case r.Method == http.MethodPost && q.Has("delete"): m.handleDelete(w, r, path) default: http.Error(w, "not implemented", http.StatusNotImplemented) @@ -206,6 +220,77 @@ func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path st w.WriteHeader(http.StatusOK) } +func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) { + m.mu.Lock() + m.nextID++ + uploadID := fmt.Sprintf("upload-%d", m.nextID) + m.uploads[uploadID] = make(map[int][]byte) + m.mu.Unlock() + + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `<InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID) +} + +func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) { + uploadID := r.URL.Query().Get("uploadId") + var partNumber int + fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber) + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + m.mu.Lock() + parts, ok := m.uploads[uploadID] + if !ok { + m.mu.Unlock() + http.Error(w, "NoSuchUpload", http.StatusNotFound) + return + } + parts[partNumber] = body + m.mu.Unlock() + + etag := fmt.Sprintf(`"etag-part-%d"`, partNumber) + w.Header().Set("ETag", etag) + w.WriteHeader(http.StatusOK) +} + +func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) { + uploadID := r.URL.Query().Get("uploadId") + + m.mu.Lock() + parts, ok := m.uploads[uploadID] + if !ok { + m.mu.Unlock() + http.Error(w, "NoSuchUpload", http.StatusNotFound) + return + } + + // Assemble parts in order + var assembled []byte + for i := 1; i <= len(parts); i++ { + assembled = 
append(assembled, parts[i]...) + } + m.objects[path] = assembled + delete(m.uploads, uploadID) + m.mu.Unlock() + + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `<CompleteMultipartUploadResult><Key>%s</Key></CompleteMultipartUploadResult>`, path) +} + +func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) { + uploadID := r.URL.Query().Get("uploadId") + m.mu.Lock() + delete(m.uploads, uploadID) + m.mu.Unlock() + w.WriteHeader(http.StatusNoContent) +} + func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) { m.mu.RLock() body, ok := m.objects[path] diff --git a/s3/client.go b/s3/client.go index 7fdd8093..29e10d2d 100644 --- a/s3/client.go +++ b/s3/client.go @@ -10,6 +10,7 @@ import ( "encoding/base64" "encoding/hex" "encoding/xml" + "errors" "fmt" "io" "net/http" @@ -50,25 +51,22 @@ func New(config *Config) *Client { } // PutObject uploads body to the given key. The key is automatically prefixed with the client's -// configured prefix. The body size must be known in advance. The payload is sent as -// UNSIGNED-PAYLOAD, which is supported by all major S3-compatible providers over HTTPS. -func (c *Client) PutObject(ctx context.Context, key string, body io.Reader, size int64) error { - fullKey := c.objectKey(key) - req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body) +// configured prefix. The body size does not need to be known in advance. +// +// If the entire body is smaller than the part size (5 MB), it is uploaded with a simple PUT request. +// Otherwise, the body is uploaded using S3 multipart upload, reading one part at a time +// into memory. 
+func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error { + first := make([]byte, partSize) + n, err := io.ReadFull(body, first) + if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF { + return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n)) + } if err != nil { - return fmt.Errorf("s3: PutObject request: %w", err) + return fmt.Errorf("s3: PutObject read: %w", err) } - req.ContentLength = size - c.signV4(req, unsignedPayload) - resp, err := c.httpClient().Do(req) - if err != nil { - return fmt.Errorf("s3: PutObject: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode/100 != 2 { - return parseError(resp) - } - return nil + combined := io.MultiReader(bytes.NewReader(first), body) + return c.putObjectMultipart(ctx, key, combined) } // GetObject downloads an object. The key is automatically prefixed with the client's configured @@ -216,6 +214,171 @@ func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) { return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages) } +// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD. +func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error { + fullKey := c.objectKey(key) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body) + if err != nil { + return fmt.Errorf("s3: PutObject request: %w", err) + } + req.ContentLength = size + c.signV4(req, unsignedPayload) + resp, err := c.httpClient().Do(req) + if err != nil { + return fmt.Errorf("s3: PutObject: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode/100 != 2 { + return parseError(resp) + } + return nil +} + +// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize +// chunks, uploading each as a separate part. This allows uploading without knowing the total +// body size in advance. 
+func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error { + fullKey := c.objectKey(key) + + // Step 1: Initiate multipart upload + uploadID, err := c.initiateMultipartUpload(ctx, fullKey) + if err != nil { + return err + } + + // Step 2: Upload parts + var parts []completedPart + buf := make([]byte, partSize) + partNumber := 1 + for { + n, err := io.ReadFull(body, buf) + if n > 0 { + etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n]) + if uploadErr != nil { + c.abortMultipartUpload(ctx, fullKey, uploadID) + return uploadErr + } + parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag}) + partNumber++ + } + if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { + break + } + if err != nil { + c.abortMultipartUpload(ctx, fullKey, uploadID) + return fmt.Errorf("s3: PutObject read: %w", err) + } + } + + // Step 3: Complete multipart upload + return c.completeMultipartUpload(ctx, fullKey, uploadID, parts) +} + +// initiateMultipartUpload starts a new multipart upload and returns the upload ID. 
+func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) { + reqURL := c.objectURL(fullKey) + "?uploads" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil) + if err != nil { + return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err) + } + req.ContentLength = 0 + c.signV4(req, emptyPayloadHash) + resp, err := c.httpClient().Do(req) + if err != nil { + return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode/100 != 2 { + return "", parseError(resp) + } + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + if err != nil { + return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err) + } + var result initiateMultipartUploadResult + if err := xml.Unmarshal(respBody, &result); err != nil { + return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err) + } + return result.UploadID, nil +} + +// uploadPart uploads a single part of a multipart upload and returns the ETag. +func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) { + reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID)) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data)) + if err != nil { + return "", fmt.Errorf("s3: UploadPart request: %w", err) + } + req.ContentLength = int64(len(data)) + c.signV4(req, unsignedPayload) + resp, err := c.httpClient().Do(req) + if err != nil { + return "", fmt.Errorf("s3: UploadPart: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode/100 != 2 { + return "", parseError(resp) + } + etag := resp.Header.Get("ETag") + return etag, nil +} + +// completeMultipartUpload finalizes a multipart upload with the given parts. 
+func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error { + var body bytes.Buffer + body.WriteString("<CompleteMultipartUpload>") + for _, p := range parts { + fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag) + } + body.WriteString("</CompleteMultipartUpload>") + bodyBytes := body.Bytes() + payloadHash := sha256Hex(bodyBytes) + + reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes)) + if err != nil { + return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err) + } + req.ContentLength = int64(len(bodyBytes)) + req.Header.Set("Content-Type", "application/xml") + c.signV4(req, payloadHash) + resp, err := c.httpClient().Do(req) + if err != nil { + return fmt.Errorf("s3: CompleteMultipartUpload: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode/100 != 2 { + return parseError(resp) + } + // Read response body to check for errors (S3 can return 200 with an error body) + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + if err != nil { + return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err) + } + // Check if the response contains an error + var errResp ErrorResponse + if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" { + errResp.StatusCode = resp.StatusCode + return &errResp + } + return nil +} + +// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up. +func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) { + reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)) + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) + if err != nil { + return + } + c.signV4(req, emptyPayloadHash) + resp, err := c.httpClient().Do(req) + if err != nil { + return + } + resp.Body.Close() +} + // signV4 signs req in place using AWS Signature V4. 
payloadHash is the hex-encoded SHA-256 // of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads. func (c *Client) signV4(req *http.Request, payloadHash string) { diff --git a/s3/client_test.go b/s3/client_test.go index f4a10213..c3a8fe2c 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -23,27 +23,41 @@ import ( // ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. type mockS3Server struct { - objects map[string][]byte // full key (bucket/key) -> body + objects map[string][]byte // full key (bucket/key) -> body + uploads map[string]map[int][]byte // uploadID -> partNumber -> data + nextID int // counter for generating upload IDs mu sync.RWMutex } func newMockS3Server() (*httptest.Server, *mockS3Server) { - m := &mockS3Server{objects: make(map[string][]byte)} + m := &mockS3Server{ + objects: make(map[string][]byte), + uploads: make(map[string]map[int][]byte), + } return httptest.NewTLSServer(m), m } func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Path is /{bucket}[/{key...}] path := strings.TrimPrefix(r.URL.Path, "/") + q := r.URL.Query() switch { + case r.Method == http.MethodPut && q.Has("partNumber"): + m.handleUploadPart(w, r, path) case r.Method == http.MethodPut: m.handlePut(w, r, path) - case r.Method == http.MethodGet && r.URL.Query().Get("list-type") == "2": + case r.Method == http.MethodPost && q.Has("uploads"): + m.handleInitiateMultipart(w, r, path) + case r.Method == http.MethodPost && q.Has("uploadId"): + m.handleCompleteMultipart(w, r, path) + case r.Method == http.MethodDelete && q.Has("uploadId"): + m.handleAbortMultipart(w, r, path) + case r.Method == http.MethodGet && q.Get("list-type") == "2": m.handleList(w, r, path) case r.Method == http.MethodGet: m.handleGet(w, r, path) - case r.Method == http.MethodPost && r.URL.Query().Has("delete"): + case r.Method == http.MethodPost && q.Has("delete"): m.handleDelete(w, r, path) default: 
http.Error(w, "not implemented", http.StatusNotImplemented) @@ -62,6 +76,77 @@ func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path st w.WriteHeader(http.StatusOK) } +func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) { + m.mu.Lock() + m.nextID++ + uploadID := fmt.Sprintf("upload-%d", m.nextID) + m.uploads[uploadID] = make(map[int][]byte) + m.mu.Unlock() + + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `<InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, uploadID) +} + +func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) { + uploadID := r.URL.Query().Get("uploadId") + var partNumber int + fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber) + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + m.mu.Lock() + parts, ok := m.uploads[uploadID] + if !ok { + m.mu.Unlock() + http.Error(w, "NoSuchUpload", http.StatusNotFound) + return + } + parts[partNumber] = body + m.mu.Unlock() + + etag := fmt.Sprintf(`"etag-part-%d"`, partNumber) + w.Header().Set("ETag", etag) + w.WriteHeader(http.StatusOK) +} + +func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) { + uploadID := r.URL.Query().Get("uploadId") + + m.mu.Lock() + parts, ok := m.uploads[uploadID] + if !ok { + m.mu.Unlock() + http.Error(w, "NoSuchUpload", http.StatusNotFound) + return + } + + // Assemble parts in order + var assembled []byte + for i := 1; i <= len(parts); i++ { + assembled = append(assembled, parts[i]...) 
+ } + m.objects[path] = assembled + delete(m.uploads, uploadID) + m.mu.Unlock() + + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `%s`, path) +} + +func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) { + uploadID := r.URL.Query().Get("uploadId") + m.mu.Lock() + delete(m.uploads, uploadID) + m.mu.Unlock() + w.WriteHeader(http.StatusNoContent) +} + func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) { m.mu.RLock() body, ok := m.objects[path] @@ -333,7 +418,7 @@ func TestClient_PutGetObject(t *testing.T) { ctx := context.Background() // Put - err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"), 11) + err := client.PutObject(ctx, "test-key", strings.NewReader("hello world")) require.Nil(t, err) // Get @@ -353,7 +438,7 @@ func TestClient_PutGetObject_WithPrefix(t *testing.T) { ctx := context.Background() - err := client.PutObject(ctx, "test-key", strings.NewReader("hello"), 5) + err := client.PutObject(ctx, "test-key", strings.NewReader("hello")) require.Nil(t, err) reader, _, err := client.GetObject(ctx, "test-key") @@ -385,7 +470,7 @@ func TestClient_DeleteObjects(t *testing.T) { // Put several objects for i := 0; i < 5; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")), 4) + err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data"))) require.Nil(t, err) } require.Equal(t, 5, mock.objectCount()) @@ -416,13 +501,13 @@ func TestClient_ListObjects(t *testing.T) { // Client with prefix "pfx": list should only return objects under pfx/ client := newTestClient(server, "my-bucket", "pfx") for i := 0; i < 3; i++ { - err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")), 1) + err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x"))) require.Nil(t, err) } // Also put an object outside the prefix using a 
no-prefix client clientNoPrefix := newTestClient(server, "my-bucket", "") - err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")), 1) + err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y"))) require.Nil(t, err) // List with prefix client: should only see 3 @@ -446,7 +531,7 @@ func TestClient_ListObjects_Pagination(t *testing.T) { // Put 5 objects for i := 0; i < 5; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 1) + err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x"))) require.Nil(t, err) } @@ -478,7 +563,7 @@ func TestClient_ListAllObjects(t *testing.T) { ctx := context.Background() for i := 0; i < 10; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 1) + err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x"))) require.Nil(t, err) } @@ -499,7 +584,7 @@ func TestClient_PutObject_LargeBody(t *testing.T) { for i := range data { data[i] = byte(i % 256) } - err := client.PutObject(ctx, "large", bytes.NewReader(data), int64(len(data))) + err := client.PutObject(ctx, "large", bytes.NewReader(data)) require.Nil(t, err) reader, size, err := client.GetObject(ctx, "large") @@ -511,6 +596,54 @@ func TestClient_PutObject_LargeBody(t *testing.T) { require.Equal(t, data, got) } +func TestClient_PutObject_ChunkedUpload(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "") + + ctx := context.Background() + + // 12 MB object, exceeds 5 MB partSize, triggers multipart upload path + data := make([]byte, 12*1024*1024) + for i := range data { + data[i] = byte(i % 256) + } + err := client.PutObject(ctx, "multipart", bytes.NewReader(data)) + require.Nil(t, err) + + reader, size, err := client.GetObject(ctx, "multipart") + require.Nil(t, err) + require.Equal(t, int64(12*1024*1024), size) + got, err := io.ReadAll(reader) + 
reader.Close() + require.Nil(t, err) + require.Equal(t, data, got) +} + +func TestClient_PutObject_ExactPartSize(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "") + + ctx := context.Background() + + // Exactly 5 MB (partSize), should use the simple put path (ReadFull succeeds fully) + data := make([]byte, 5*1024*1024) + for i := range data { + data[i] = byte(i % 256) + } + err := client.PutObject(ctx, "exact", bytes.NewReader(data)) + require.Nil(t, err) + + reader, size, err := client.GetObject(ctx, "exact") + require.Nil(t, err) + require.Equal(t, int64(5*1024*1024), size) + got, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, data, got) +} + func TestClient_PutObject_NestedKey(t *testing.T) { server, _ := newMockS3Server() defer server.Close() @@ -518,7 +651,7 @@ func TestClient_PutObject_NestedKey(t *testing.T) { ctx := context.Background() - err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"), 6) + err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested")) require.Nil(t, err) reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt") @@ -548,7 +681,7 @@ func TestClient_ListAllObjects_20k(t *testing.T) { for i := 0; i < batchSize; i++ { idx := batch*batchSize + i key := fmt.Sprintf("%08d", idx) - err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")), 1) + err := client.PutObject(ctx, key, bytes.NewReader([]byte("x"))) require.Nil(t, err) } } @@ -647,7 +780,7 @@ func TestClient_RealBucket(t *testing.T) { content := "hello from ntfy s3 test" // Put - err := client.PutObject(ctx, key, strings.NewReader(content), int64(len(content))) + err := client.PutObject(ctx, key, strings.NewReader(content)) require.Nil(t, err) // Get @@ -685,7 +818,7 @@ func TestClient_RealBucket(t *testing.T) { // Put 10 objects for i := 0; i < 10; i++ { - err := listClient.PutObject(ctx, 
fmt.Sprintf("%d", i), strings.NewReader("x"), 1) + err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x")) require.Nil(t, err) } @@ -710,7 +843,7 @@ func TestClient_RealBucket(t *testing.T) { data[i] = byte(i % 256) } - err := client.PutObject(ctx, key, bytes.NewReader(data), int64(len(data))) + err := client.PutObject(ctx, key, bytes.NewReader(data)) require.Nil(t, err) reader, size, err := client.GetObject(ctx, key) diff --git a/s3/types.go b/s3/types.go index 5929ec6c..201c570b 100644 --- a/s3/types.go +++ b/s3/types.go @@ -63,3 +63,14 @@ type deleteError struct { Code string `xml:"Code"` Message string `xml:"Message"` } + +// initiateMultipartUploadResult is the XML response from S3 InitiateMultipartUpload +type initiateMultipartUploadResult struct { + UploadID string `xml:"UploadId"` +} + +// completedPart represents a successfully uploaded part for CompleteMultipartUpload +type completedPart struct { + PartNumber int + ETag string +} diff --git a/s3/util.go b/s3/util.go index cf9d4ba8..c24c1c5b 100644 --- a/s3/util.go +++ b/s3/util.go @@ -22,6 +22,11 @@ const ( // maxResponseBytes caps the size of S3 response bodies we read into memory (10 MB) maxResponseBytes = 10 * 1024 * 1024 + + // partSize is the size of each part for multipart uploads (5 MB). This is also the threshold + // above which PutObject switches from a simple PUT to multipart upload. S3 requires a minimum + // part size of 5 MB for all parts except the last. 
+ partSize = 5 * 1024 * 1024 ) // ParseURL parses an S3 URL of the form: diff --git a/tools/s3cli/main.go b/tools/s3cli/main.go index 697d4e71..1dbac0cf 100644 --- a/tools/s3cli/main.go +++ b/tools/s3cli/main.go @@ -58,44 +58,21 @@ func cmdPut(ctx context.Context, client *s3.Client) { path := os.Args[3] var r io.Reader - var size int64 if path == "-" { - // Read stdin into a temp file to get the size - tmp, err := os.CreateTemp("", "s3cli-*") - if err != nil { - fail("create temp file: %s", err) - } - defer os.Remove(tmp.Name()) - n, err := io.Copy(tmp, os.Stdin) - if err != nil { - tmp.Close() - fail("read stdin: %s", err) - } - if _, err := tmp.Seek(0, io.SeekStart); err != nil { - tmp.Close() - fail("seek: %s", err) - } - r = tmp - size = n - defer tmp.Close() + r = os.Stdin } else { f, err := os.Open(path) if err != nil { fail("open %s: %s", path, err) } defer f.Close() - info, err := f.Stat() - if err != nil { - fail("stat %s: %s", path, err) - } r = f - size = info.Size() } - if err := client.PutObject(ctx, key, r, size); err != nil { + if err := client.PutObject(ctx, key, r); err != nil { fail("put: %s", err) } - fmt.Fprintf(os.Stderr, "uploaded %s (%d bytes)\n", key, size) + fmt.Fprintf(os.Stderr, "uploaded %s\n", key) } func cmdGet(ctx context.Context, client *s3.Client) {