// Mirror of https://github.com/binwiederhier/ntfy.git, synced 2026-03-18 21:30:44 +01:00.
// Go source file: 167 lines, 4.6 KiB.
package s3
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/xml"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"sort"
|
|
"strings"
|
|
)
|
|
|
|
const (
	// emptyPayloadHash is the hex-encoded SHA-256 digest of the empty string. It is sent as the
	// payload hash for requests that carry no body.
	emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"

	// unsignedPayload is sent in place of a payload hash for streaming uploads, where the body is
	// not buffered in memory and therefore cannot be hashed before sending.
	unsignedPayload = "UNSIGNED-PAYLOAD"

	// maxResponseBytes caps how much of an S3 response body is read into memory (10 MiB).
	maxResponseBytes = 10 << 20

	// partSize is the size of each part of a multipart upload (5 MiB). It is also the threshold
	// above which PutObject switches from a single PUT to a multipart upload. S3 requires every
	// part except the last to be at least 5 MiB.
	partSize = 5 << 20
)
|
|
|
|
// ParseURL parses an S3 URL of the form:
|
|
//
|
|
// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
|
|
//
|
|
// When endpoint is specified, path-style addressing is enabled automatically.
|
|
func ParseURL(s3URL string) (*Config, error) {
|
|
u, err := url.Parse(s3URL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("s3: invalid URL: %w", err)
|
|
}
|
|
if u.Scheme != "s3" {
|
|
return nil, fmt.Errorf("s3: URL scheme must be 's3', got '%s'", u.Scheme)
|
|
}
|
|
if u.Host == "" {
|
|
return nil, fmt.Errorf("s3: bucket name must be specified as host")
|
|
}
|
|
bucket := u.Host
|
|
prefix := strings.TrimPrefix(u.Path, "/")
|
|
accessKey := u.User.Username()
|
|
secretKey, _ := u.User.Password()
|
|
if accessKey == "" || secretKey == "" {
|
|
return nil, fmt.Errorf("s3: access key and secret key must be specified in URL")
|
|
}
|
|
region := u.Query().Get("region")
|
|
if region == "" {
|
|
return nil, fmt.Errorf("s3: region query parameter is required")
|
|
}
|
|
endpointParam := u.Query().Get("endpoint")
|
|
var endpoint string
|
|
var pathStyle bool
|
|
if endpointParam != "" {
|
|
// Custom endpoint: strip scheme prefix to extract host[:port]
|
|
ep := strings.TrimRight(endpointParam, "/")
|
|
ep = strings.TrimPrefix(ep, "https://")
|
|
ep = strings.TrimPrefix(ep, "http://")
|
|
endpoint = ep
|
|
pathStyle = true
|
|
} else {
|
|
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region)
|
|
pathStyle = false
|
|
}
|
|
return &Config{
|
|
Endpoint: endpoint,
|
|
PathStyle: pathStyle,
|
|
Bucket: bucket,
|
|
Prefix: prefix,
|
|
Region: region,
|
|
AccessKey: accessKey,
|
|
SecretKey: secretKey,
|
|
}, nil
|
|
}
|
|
|
|
// parseError reads an S3 error response and returns an *ErrorResponse.
|
|
func parseError(resp *http.Response) error {
|
|
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
|
|
if err != nil {
|
|
return fmt.Errorf("s3: reading error response: %w", err)
|
|
}
|
|
return parseErrorFromBytes(resp.StatusCode, body)
|
|
}
|
|
|
|
func parseErrorFromBytes(statusCode int, body []byte) error {
|
|
errResp := &ErrorResponse{
|
|
StatusCode: statusCode,
|
|
Body: string(body),
|
|
}
|
|
// Try to parse XML error; if it fails, we still have StatusCode and Body
|
|
_ = xml.Unmarshal(body, errResp)
|
|
return errResp
|
|
}
|
|
|
|
// canonicalURI returns the URI-encoded path for the canonical request. Each path segment is
|
|
// percent-encoded per RFC 3986; forward slashes are preserved.
|
|
func canonicalURI(u *url.URL) string {
|
|
p := u.Path
|
|
if p == "" {
|
|
return "/"
|
|
}
|
|
segments := strings.Split(p, "/")
|
|
for i, seg := range segments {
|
|
segments[i] = uriEncode(seg)
|
|
}
|
|
return strings.Join(segments, "/")
|
|
}
|
|
|
|
// canonicalQueryString builds the query string for the canonical request. Keys and values
|
|
// are URI-encoded per RFC 3986 (using %20, not +) and sorted lexically by key.
|
|
func canonicalQueryString(values url.Values) string {
|
|
if len(values) == 0 {
|
|
return ""
|
|
}
|
|
keys := make([]string, 0, len(values))
|
|
for k := range values {
|
|
keys = append(keys, k)
|
|
}
|
|
sort.Strings(keys)
|
|
var pairs []string
|
|
for _, k := range keys {
|
|
ek := uriEncode(k)
|
|
vs := make([]string, len(values[k]))
|
|
copy(vs, values[k])
|
|
sort.Strings(vs)
|
|
for _, v := range vs {
|
|
pairs = append(pairs, ek+"="+uriEncode(v))
|
|
}
|
|
}
|
|
return strings.Join(pairs, "&")
|
|
}
|
|
|
|
// uriEncode percent-encodes s per RFC 3986, encoding everything except the unreserved
// characters (A-Z a-z 0-9 - _ . ~). Each other byte is written as %XX with uppercase hex
// digits. The input is processed byte by byte, so multi-byte UTF-8 characters are encoded one
// byte at a time (e.g. "ü" becomes "%C3%BC"), and '/' is encoded too.
func uriEncode(s string) string {
	const upperHex = "0123456789ABCDEF"
	var buf strings.Builder
	// Typical keys and values are mostly unreserved, so len(s) is a good lower bound.
	buf.Grow(len(s))
	for i := 0; i < len(s); i++ {
		b := s[i]
		switch {
		case 'A' <= b && b <= 'Z', 'a' <= b && b <= 'z', '0' <= b && b <= '9',
			b == '-', b == '_', b == '.', b == '~':
			buf.WriteByte(b)
		default:
			// Emit the escape directly rather than via fmt.Fprintf, which parses a format string
			// and boxes its argument for every reserved byte on the request-signing path.
			buf.WriteByte('%')
			buf.WriteByte(upperHex[b>>4])
			buf.WriteByte(upperHex[b&0x0F])
		}
	}
	return buf.String()
}
|
|
|
|
// sha256Hex returns the lowercase hex encoding of the SHA-256 digest of data.
func sha256Hex(data []byte) string {
	hasher := sha256.New()
	hasher.Write(data) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
|
|
|
|
// hmacSHA256 returns the raw (not hex-encoded) 32-byte HMAC-SHA256 of data keyed with key.
func hmacSHA256(key, data []byte) []byte {
	mac := hmac.New(sha256.New, key)
	_, _ = mac.Write(data) // writes to a hash.Hash never fail
	return mac.Sum(nil)
}
|