ntfy/s3/client_multipart.go
2026-03-21 16:27:41 -04:00

187 lines
6.4 KiB
Go

package s3
import (
"bytes"
"context"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
"heckel.io/ntfy/v2/log"
)
// AbortIncompleteUploads aborts every in-progress multipart upload that was initiated before
// cutoff, so that parts left behind by interrupted uploads do not linger in the bucket.
// Uploads with a zero initiation time are left alone. An error is returned only if listing
// the uploads fails; the individual abort calls do not report errors to this function.
//
// See https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListMultipartUploads.html
// and https://docs.aws.amazon.com/AmazonS3/latest/API/API_AbortMultipartUpload.html
func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) error {
	pending, listErr := c.listMultipartUploads(ctx)
	if listErr != nil {
		return listErr
	}
	for _, upload := range pending {
		// Keep uploads with an unknown (zero) start time and those started at or after the cutoff
		if upload.Initiated.IsZero() || !upload.Initiated.Before(cutoff) {
			continue
		}
		c.abortMultipartUpload(ctx, upload.Key, upload.UploadID)
	}
	return nil
}
// listMultipartUploads returns the in-progress multipart uploads in the bucket, restricted to
// the client's list prefix if one is configured. It follows the key-marker/upload-id-marker
// pagination until a response is no longer truncated. It fails if more than maxPages pages
// would be needed, or as soon as a truncated response does not advance the pagination markers
// (which would otherwise cause the same page to be fetched again and again).
//
// Uploads whose initiation timestamp is missing or is not valid RFC 3339 are returned with a
// zero Initiated time.
func (c *Client) listMultipartUploads(ctx context.Context) ([]*multipartUpload, error) {
	var all []*multipartUpload
	var keyMarker, uploadIDMarker string
	for page := 0; page < maxPages; page++ {
		query := url.Values{"uploads": {""}}
		if prefix := c.config.ListPrefix(); prefix != "" {
			query.Set("prefix", prefix)
		}
		if keyMarker != "" {
			query.Set("key-marker", keyMarker)
			query.Set("upload-id-marker", uploadIDMarker)
		}
		respBody, err := c.do(ctx, "ListMultipartUploads", http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, nil)
		if err != nil {
			return nil, err
		}
		var result listMultipartUploadsResult
		if err := xml.Unmarshal(respBody, &result); err != nil {
			return nil, fmt.Errorf("error unmarshalling multipart upload result: %w", err)
		}
		for _, u := range result.Uploads {
			var initiated time.Time
			if u.Initiated != "" {
				// The parse error is deliberately ignored: an unparsable timestamp leaves the
				// zero time in place, which marks the initiation time as unknown.
				initiated, _ = time.Parse(time.RFC3339, u.Initiated)
			}
			all = append(all, &multipartUpload{
				Key:       u.Key,
				UploadID:  u.UploadID,
				Initiated: initiated,
			})
		}
		if !result.IsTruncated {
			return all, nil
		}
		// A truncated page must move the markers forward. Without this check, the next request
		// would be identical to this one, and the loop would repeat it until maxPages is hit.
		stalled := result.NextKeyMarker == keyMarker && result.NextUploadIDMarker == uploadIDMarker
		if result.NextKeyMarker == "" || stalled {
			return nil, fmt.Errorf("error listing multipart uploads, pagination markers did not advance after page %d", page+1)
		}
		keyMarker = result.NextKeyMarker
		uploadIDMarker = result.NextUploadIDMarker
	}
	return nil, fmt.Errorf("error listing multipart uploads, exceeded %d pages", maxPages)
}
// abortMultipartUpload cancels an in-progress multipart upload by sending a signed DELETE
// request for the object with the given upload ID. It is used to clean up after a failed
// upload and to remove stale uploads. The call is best-effort: it has no return value, and
// any failure (request creation, transport error, or non-success HTTP status) is only logged.
func (c *Client) abortMultipartUpload(ctx context.Context, key, uploadID string) {
	log.Tag(tagS3Client).Info("Aborting multipart upload for object %s", key)
	reqURL := fmt.Sprintf("%s?uploadId=%s", c.config.ObjectURL(key), url.QueryEscape(uploadID))
	req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
	if err != nil {
		log.Tag(tagS3Client).Warn("Error creating abort multipart upload request for object %s: %v", key, err)
		return
	}
	c.signV4(req, emptyPayloadHash)
	resp, err := c.http.Do(req)
	if err != nil {
		log.Tag(tagS3Client).Warn("Error aborting multipart upload for object %s: %v", key, err)
		return
	}
	defer resp.Body.Close()
	// Surface rejected aborts in the log; otherwise leftover parts would go unnoticed
	if !isHTTPSuccess(resp) {
		log.Tag(tagS3Client).Warn("Error aborting multipart upload for object %s: %v", key, parseError(resp))
	}
}
// putObjectMultipart uploads body to key using an S3 multipart upload. It reads the body in
// chunks of up to partSize bytes and uploads each chunk as a separate part, numbered from 1,
// so the total body size does not need to be known in advance.
//
// If reading the body, uploading a part, or completing the upload fails, the multipart upload
// is aborted (best-effort) so that already-uploaded parts are not left behind, and the error
// is returned. A chunk whose read failed with anything other than end-of-input is not uploaded.
func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error {
	log.Tag(tagS3Client).Debug("Uploading multipart object %s", key)

	// Step 1: Initiate multipart upload
	uploadID, err := c.initiateMultipartUpload(ctx, key)
	if err != nil {
		return err
	}

	// Step 2: Upload parts
	partNumber := 1
	buf := make([]byte, partSize)
	var parts []*completedPart
	for {
		n, readErr := io.ReadFull(body, buf)
		// io.ReadFull reports io.EOF when nothing was read, and io.ErrUnexpectedEOF when the
		// input ended before the buffer was full; both mean the body has been fully consumed.
		endOfBody := readErr == io.EOF || errors.Is(readErr, io.ErrUnexpectedEOF)
		if readErr != nil && !endOfBody {
			// The upload cannot succeed anymore, so do not bother sending the partial chunk
			c.abortMultipartUpload(ctx, key, uploadID)
			return fmt.Errorf("error uploading object %s, reading from client failed: %w", key, readErr)
		}
		if n > 0 {
			etag, uploadErr := c.uploadPart(ctx, key, uploadID, partNumber, buf[:n])
			if uploadErr != nil {
				c.abortMultipartUpload(ctx, key, uploadID)
				return uploadErr
			}
			parts = append(parts, &completedPart{
				PartNumber: partNumber,
				ETag:       etag,
			})
			partNumber++
		}
		if endOfBody {
			break
		}
	}

	// Step 3: Complete multipart upload; abort on failure so the uploaded parts are released
	if err := c.completeMultipartUpload(ctx, key, uploadID, parts); err != nil {
		c.abortMultipartUpload(ctx, key, uploadID)
		return err
	}
	return nil
}
// initiateMultipartUpload starts a new multipart upload for key by sending a POST request to
// the object URL with the "uploads" query parameter, and returns the upload ID parsed from
// the XML response. It returns an error if the request fails, if the response cannot be
// parsed, or if the response does not contain an upload ID.
func (c *Client) initiateMultipartUpload(ctx context.Context, key string) (string, error) {
	respBody, err := c.do(ctx, "InitiateMultipartUpload", http.MethodPost, c.config.ObjectURL(key)+"?uploads", nil, nil)
	if err != nil {
		return "", err
	}
	var result initiateMultipartUploadResult
	if err := xml.Unmarshal(respBody, &result); err != nil {
		return "", fmt.Errorf("error unmarshalling initiate multipart upload response: %w", err)
	}
	// All follow-up requests identify the upload via the uploadId query parameter, so a
	// response without an upload ID must not be treated as success.
	if result.UploadID == "" {
		return "", fmt.Errorf("error initiating multipart upload for object %s: response contained no upload ID", key)
	}
	return result.UploadID, nil
}
// uploadPart sends data as part partNumber of the multipart upload identified by uploadID,
// using a PUT request that is signed via signV4 with the unsignedPayload marker. On success,
// it returns the value of the response's ETag header. A non-success HTTP status is converted
// into an error by parseError.
func (c *Client) uploadPart(ctx context.Context, key, uploadID string, partNumber int, data []byte) (string, error) {
	log.Tag(tagS3Client).Debug("Uploading multipart part for object %s, part %d, size %d", key, partNumber, len(data))

	partURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.config.ObjectURL(key), partNumber, url.QueryEscape(uploadID))
	httpReq, reqErr := http.NewRequestWithContext(ctx, http.MethodPut, partURL, bytes.NewReader(data))
	if reqErr != nil {
		return "", fmt.Errorf("error creating multipart upload part request for object %s: %w", key, reqErr)
	}
	httpReq.ContentLength = int64(len(data))
	c.signV4(httpReq, unsignedPayload)

	httpResp, doErr := c.http.Do(httpReq)
	if doErr != nil {
		return "", fmt.Errorf("error uploading multipart part for object %s: %w", key, doErr)
	}
	defer httpResp.Body.Close()

	if isHTTPSuccess(httpResp) {
		return httpResp.Header.Get("ETag"), nil
	}
	return "", parseError(httpResp)
}
// completeMultipartUpload finishes the multipart upload identified by uploadID by POSTing the
// XML-encoded list of parts to the object URL. It returns an error if the part list cannot be
// marshalled, if the request fails, or if the response body parses as an error document with
// a non-empty code.
func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID string, parts []*completedPart) error {
	log.Tag(tagS3Client).Debug("Completing multipart upload for object %s, %d parts", key, len(parts))

	payload, marshalErr := xml.Marshal(&completeMultipartUploadRequest{Parts: parts})
	if marshalErr != nil {
		return fmt.Errorf("error marshalling complete multipart upload request: %w", marshalErr)
	}

	completeURL := fmt.Sprintf("%s?uploadId=%s", c.config.ObjectURL(key), url.QueryEscape(uploadID))
	respBody, doErr := c.do(ctx, "CompleteMultipartUpload", http.MethodPost, completeURL, payload, nil)
	if doErr != nil {
		return doErr
	}

	// S3 can respond with HTTP 200 and still carry an error in the body, so try to read the
	// body as an error document; a body that does not parse as one counts as success.
	var failure errorResponse
	if parseErr := xml.Unmarshal(respBody, &failure); parseErr != nil {
		return nil
	}
	if failure.Code == "" {
		return nil
	}
	return &failure
}