Change NBS s3TablePersister to use Multipart upload (#2922)
Instead of uploading an entire table to S3 with a single PutObject request, split it into 5MB parts (the smallest part size S3 allows) and upload the parts in parallel.
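For intuition about how the data gets carved up: the part count is len(data)/partSize (minimum one part), and the final part absorbs any remainder, so it can approach twice the part size. A minimal standalone sketch of that arithmetic — partBounds is a hypothetical helper written for illustration, not code from this change, but it mirrors the getNumParts/overflow logic in the diff below:

package main

import "fmt"

// partBounds splits dataLen bytes into dataLen/partSize parts (at least
// one), with the final part absorbing the remainder — so the last part
// may be up to just under twice partSize.
func partBounds(dataLen, partSize int) [][2]int {
	numParts := dataLen / partSize
	if numParts == 0 {
		numParts = 1
	}
	bounds := make([][2]int, numParts)
	for partNum := 1; partNum <= numParts; partNum++ {
		start, end := (partNum-1)*partSize, partNum*partSize
		if partNum == numParts { // last part includes any overflow
			end = dataLen
		}
		bounds[partNum-1] = [2]int{start, end}
	}
	return bounds
}

func main() {
	const partSize = 5 * 1 << 20 // 5MiB
	for _, b := range partBounds(12*1<<20, partSize) {
		fmt.Printf("[%d, %d) -> %d bytes\n", b[0], b[1], b[1]-b[0])
	}
	// Output:
	// [0, 5242880) -> 5242880 bytes
	// [5242880, 12582912) -> 7340032 bytes
}

So a 12MiB table uploads as two parts of 5MiB and 7MiB; S3 only requires that non-final parts be at least 5MB, so the oversized tail is legal.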
@@ -10,8 +10,11 @@ import (
 	"io/ioutil"
+	"strconv"
 	"strings"
+	"sync"
+	"time"
 
 	"github.com/attic-labs/noms/go/d"
 	"github.com/attic-labs/noms/go/hash"
 	"github.com/attic-labs/testify/assert"
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/service/s3"
@@ -24,13 +27,113 @@ func (m mockAWSError) Code() string { return string(m) }
 func (m mockAWSError) Message() string { return string(m) }
 func (m mockAWSError) OrigErr() error { return nil }
 
-type fakeS3 struct {
-	data   map[string][]byte
-	assert *assert.Assertions
-}
-
-func makeFakeS3(a *assert.Assertions) *fakeS3 {
-	return &fakeS3{data: map[string][]byte{}, assert: a}
-}
+type fakeS3 struct {
+	assert *assert.Assertions
+
+	mu                sync.Mutex
+	data              map[string][]byte
+	inProgressCounter int
+	inProgress        map[string]fakeS3Multipart // Key -> {UploadId, Etags...}
+	parts             map[string][]byte          // ETag -> data
+}
+
+type fakeS3Multipart struct {
+	uploadID string
+	etags    []string
+}
+
+func makeFakeS3(a *assert.Assertions) *fakeS3 {
+	return &fakeS3{
+		assert:     a,
+		data:       map[string][]byte{},
+		inProgress: map[string]fakeS3Multipart{},
+		parts:      map[string][]byte{},
+	}
+}
+
+func (m *fakeS3) readerForTable(name addr) chunkReader {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	if buff, present := m.data[name.String()]; present {
+		return newTableReader(buff, bytes.NewReader(buff))
+	}
+	return nil
+}
+
+func (m *fakeS3) AbortMultipartUpload(input *s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error) {
+	m.assert.NotNil(input.Bucket, "Bucket is a required field")
+	m.assert.NotNil(input.Key, "Key is a required field")
+	m.assert.NotNil(input.UploadId, "UploadId is a required field")
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.assert.Equal(m.inProgress[*input.Key].uploadID, *input.UploadId)
+	for _, etag := range m.inProgress[*input.Key].etags {
+		delete(m.parts, etag)
+	}
+	delete(m.inProgress, *input.Key)
+	return &s3.AbortMultipartUploadOutput{}, nil
+}
+
+func (m *fakeS3) CreateMultipartUpload(input *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) {
+	m.assert.NotNil(input.Bucket, "Bucket is a required field")
+	m.assert.NotNil(input.Key, "Key is a required field")
+
+	out := &s3.CreateMultipartUploadOutput{
+		Bucket: input.Bucket,
+		Key:    input.Key,
+	}
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	uploadID := strconv.Itoa(m.inProgressCounter)
+	out.UploadId = aws.String(uploadID)
+	m.inProgress[*input.Key] = fakeS3Multipart{uploadID, nil}
+	m.inProgressCounter++
+	return out, nil
+}
+
+func (m *fakeS3) UploadPart(input *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
+	m.assert.NotNil(input.Bucket, "Bucket is a required field")
+	m.assert.NotNil(input.Key, "Key is a required field")
+	m.assert.NotNil(input.PartNumber, "PartNumber is a required field")
+	m.assert.NotNil(input.UploadId, "UploadId is a required field")
+	m.assert.NotNil(input.Body, "Body is a required field")
+
+	data, err := ioutil.ReadAll(input.Body)
+	m.assert.NoError(err)
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	etag := hash.Of(data).String() + time.Now().String()
+	m.parts[etag] = data
+
+	inProgress, present := m.inProgress[*input.Key]
+	m.assert.True(present)
+	m.assert.Equal(inProgress.uploadID, *input.UploadId)
+	inProgress.etags = append(inProgress.etags, etag)
+	m.inProgress[*input.Key] = inProgress
+	return &s3.UploadPartOutput{ETag: aws.String(etag)}, nil
+}
+
+func (m *fakeS3) CompleteMultipartUpload(input *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) {
+	m.assert.NotNil(input.Bucket, "Bucket is a required field")
+	m.assert.NotNil(input.Key, "Key is a required field")
+	m.assert.NotNil(input.UploadId, "UploadId is a required field")
+	m.assert.NotNil(input.MultipartUpload, "MultipartUpload is a required field")
+	m.assert.True(len(input.MultipartUpload.Parts) > 0)
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.assert.Equal(m.inProgress[*input.Key].uploadID, *input.UploadId)
+	for idx, part := range input.MultipartUpload.Parts {
+		m.assert.EqualValues(idx+1, *part.PartNumber) // Part numbers are 1-indexed
+		m.data[*input.Key] = append(m.data[*input.Key], m.parts[*part.ETag]...)
+		delete(m.parts, *part.ETag)
+	}
+	delete(m.inProgress, *input.Key)
+
+	return &s3.CompleteMultipartUploadOutput{Bucket: input.Bucket, Key: input.Key}, nil
+}
+
 func (m *fakeS3) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) {
@@ -23,6 +23,10 @@ type s3TableReader struct {
 }
 
 type s3svc interface {
+	AbortMultipartUpload(input *s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error)
+	CreateMultipartUpload(input *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error)
+	UploadPart(input *s3.UploadPartInput) (*s3.UploadPartOutput, error)
+	CompleteMultipartUpload(input *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error)
 	GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error)
 	PutObject(input *s3.PutObjectInput) (*s3.PutObjectOutput, error)
 }
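The s3svc interface pins down the exact slice of the S3 API that NBS uses, which is what lets the tests substitute fakeS3 for a real client. As a sketch — assuming it lands in the same package as these types, and not part of the change itself — the usual Go idiom for locking that in at compile time:

// Compile-time assertions that both the real AWS SDK client and the
// test fake satisfy s3svc; a build error here flags signature drift.
var (
	_ s3svc = (*s3.S3)(nil)  // real client from aws-sdk-go
	_ s3svc = (*fakeS3)(nil) // test double defined in the test file
)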
@@ -10,14 +10,18 @@ import (
 	"io/ioutil"
 	"os"
 	"path/filepath"
+	"sort"
+	"sync"
 
 	"github.com/attic-labs/noms/go/d"
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/service/s3"
 )
 
+const defaultS3PartSize = 5 * 1 << 20 // 5MiB, smallest allowed by S3
+
 func newS3TableSet(s3 s3svc, bucket string) tableSet {
-	return tableSet{p: s3TablePersister{s3, bucket}}
+	return tableSet{p: s3TablePersister{s3, bucket, defaultS3PartSize}}
 }
 
 func newFSTableSet(dir string) tableSet {
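A quick note on the constant above: in Go, * and << share the same precedence level and associate left-to-right, so 5 * 1 << 20 parses as (5 * 1) << 20 = 5,242,880, i.e. 5MiB. A tiny standalone check:

package main

import "fmt"

func main() {
	// * and << are both multiplicative-precedence operators in Go,
	// evaluated left-to-right, so these two lines print the same value:
	fmt.Println(5 * 1 << 20)   // 5242880
	fmt.Println((5 * 1) << 20) // 5242880 (5 MiB)
}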
@@ -81,27 +85,148 @@ type tablePersister interface {
 }
 
 type s3TablePersister struct {
-	s3     s3svc
-	bucket string
+	s3       s3svc
+	bucket   string
+	partSize int
+}
+
+func (s3p s3TablePersister) Open(name addr, chunkCount uint32) chunkSource {
+	return newS3TableReader(s3p.s3, s3p.bucket, name, chunkCount)
+}
+
+type s3UploadedPart struct {
+	idx  int64
+	etag string
 }
 
 func (s3p s3TablePersister) Compact(mt *memTable, haver chunkReader) (name addr, chunkCount uint32) {
 	name, data, chunkCount := mt.write(haver)
 
 	if chunkCount > 0 {
-		_, err := s3p.s3.PutObject(&s3.PutObjectInput{
-			Bucket:        aws.String(s3p.bucket),
-			Key:           aws.String(name.String()),
-			Body:          bytes.NewReader(data),
-			ContentLength: aws.Int64(int64(len(data))),
+		result, err := s3p.s3.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
+			Bucket: aws.String(s3p.bucket),
+			Key:    aws.String(name.String()),
 		})
 		d.PanicIfError(err)
+		uploadID := *result.UploadId
+
+		multipartUpload, err := s3p.uploadParts(data, name.String(), uploadID)
+		if err != nil {
+			_, abrtErr := s3p.s3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
+				Bucket:   aws.String(s3p.bucket),
+				Key:      aws.String(name.String()),
+				UploadId: aws.String(uploadID),
+			})
+			d.Chk.NoError(abrtErr)
+			panic(err) // TODO: Better error handling here
+		}
+
+		_, err = s3p.s3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
+			Bucket:          aws.String(s3p.bucket),
+			Key:             aws.String(name.String()),
+			MultipartUpload: multipartUpload,
+			UploadId:        aws.String(uploadID),
+		})
+		d.Chk.NoError(err)
 	}
 	return name, chunkCount
 }
 
-func (s3p s3TablePersister) Open(name addr, chunkCount uint32) chunkSource {
-	return newS3TableReader(s3p.s3, s3p.bucket, name, chunkCount)
-}
+func (s3p s3TablePersister) uploadParts(data []byte, key, uploadID string) (*s3.CompletedMultipartUpload, error) {
+	sent, failed, done := make(chan s3UploadedPart), make(chan error), make(chan struct{})
+
+	numParts := getNumParts(len(data), s3p.partSize)
+	var wg sync.WaitGroup
+	wg.Add(numParts)
+	sendPart := func(partNum int) {
+		defer wg.Done()
+
+		// Check if upload has been terminated
+		select {
+		case <-done:
+			return
+		default:
+		}
+		// Upload the desired part
+		start, end := (partNum-1)*s3p.partSize, partNum*s3p.partSize
+		if partNum == numParts { // If this is the last part, make sure it includes any overflow
+			end = len(data)
+		}
+		result, err := s3p.s3.UploadPart(&s3.UploadPartInput{
+			Bucket:     aws.String(s3p.bucket),
+			Key:        aws.String(key),
+			PartNumber: aws.Int64(int64(partNum)),
+			UploadId:   aws.String(uploadID),
+			Body:       bytes.NewReader(data[start:end]),
+		})
+		if err != nil {
+			failed <- err
+			return
+		}
+		// Try to send along part info. In the case that the upload was aborted, reading from done allows this worker to exit correctly.
+		select {
+		case sent <- s3UploadedPart{int64(partNum), *result.ETag}:
+		case <-done:
+			return
+		}
+	}
+	for i := 1; i <= numParts; i++ {
+		go sendPart(i)
+	}
+	go func() {
+		wg.Wait()
+		close(sent)
+		close(failed)
+	}()
+
+	multipartUpload := &s3.CompletedMultipartUpload{}
+	var lastFailure error
+	for cont := true; cont; {
+		select {
+		case sentPart, open := <-sent:
+			if open {
+				multipartUpload.Parts = append(multipartUpload.Parts, &s3.CompletedPart{
+					ETag:       aws.String(sentPart.etag),
+					PartNumber: aws.Int64(sentPart.idx),
+				})
+			}
+			cont = open
+
+		case err := <-failed:
+			if err != nil { // nil err may happen when failed gets closed
+				lastFailure = err
+				close(done)
+			}
+		}
+	}
+
+	if lastFailure == nil {
+		close(done)
+	}
+	sort.Sort(partsByPartNum(multipartUpload.Parts))
+	return multipartUpload, lastFailure
+}
+
+func getNumParts(dataLen, partSize int) int {
+	numParts := dataLen / partSize
+	if numParts == 0 {
+		numParts = 1
+	}
+	return numParts
+}
+
+type partsByPartNum []*s3.CompletedPart
+
+func (s partsByPartNum) Len() int {
+	return len(s)
+}
+
+func (s partsByPartNum) Less(i, j int) bool {
+	return *s[i].PartNumber < *s[j].PartNumber
+}
+
+func (s partsByPartNum) Swap(i, j int) {
+	s[i], s[j] = s[j], s[i]
+}
 
 type fsTablePersister struct {
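uploadParts above is a fan-out/fan-in pipeline: one goroutine per part, successes funneled through sent, errors through failed, and a done channel that both stops workers early after a failure and unblocks any worker stuck delivering a result. A stripped-down, self-contained sketch of the same pattern — fanOut, work, and the string results are hypothetical names, and the sketch additionally guards against a second failure closing done twice:

package main

import (
	"errors"
	"fmt"
	"sync"
)

// fanOut runs one worker per job and collects results until the first
// error, after which the done channel tells remaining workers to bail.
func fanOut(jobs []int, work func(int) (string, error)) ([]string, error) {
	sent, failed, done := make(chan string), make(chan error), make(chan struct{})

	var wg sync.WaitGroup
	wg.Add(len(jobs))
	for _, j := range jobs {
		go func(j int) {
			defer wg.Done()
			select { // give up early if another worker already failed
			case <-done:
				return
			default:
			}
			res, err := work(j)
			if err != nil {
				failed <- err
				return
			}
			select { // deliver, unless the collector has stopped listening
			case sent <- res:
			case <-done:
			}
		}(j)
	}
	go func() { wg.Wait(); close(sent); close(failed) }()

	var results []string
	var lastErr error
	for cont := true; cont; {
		select {
		case r, open := <-sent:
			if open {
				results = append(results, r)
			}
			cont = open
		case err := <-failed:
			if err != nil && lastErr == nil { // first failure wins; nil means failed was closed
				lastErr = err
				close(done)
			}
		}
	}
	if lastErr == nil {
		close(done) // all workers have finished; closed only for symmetry
	}
	return results, lastErr
}

func main() {
	out, err := fanOut([]int{1, 2, 3}, func(j int) (string, error) {
		if j == 2 {
			return "", errors.New("part 2 failed")
		}
		return fmt.Sprintf("part %d", j), nil
	})
	fmt.Println(out, err)
}

Note that the collector keeps draining the unbuffered failed channel until sent closes, so a worker that fails after done is already closed can still deliver its error and exit rather than leak.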
@@ -9,9 +9,11 @@ import (
 	"io/ioutil"
 	"os"
 	"path/filepath"
+	"sync"
 	"testing"
 
 	"github.com/attic-labs/testify/assert"
+	"github.com/aws/aws-sdk-go/service/s3"
 )
 
 var testChunks = [][]byte{[]byte("hello2"), []byte("goodbye2"), []byte("badbye2")}
@@ -81,19 +83,70 @@ func TestS3TablePersisterCompact(t *testing.T) {
 	}
 
 	s3svc := makeFakeS3(assert)
-	s3p := s3TablePersister{s3svc, "bucket"}
+	s3p := s3TablePersister{s3: s3svc, bucket: "bucket", partSize: calcPartSize(mt, 3)}
 
 	tableAddr, chunkCount := s3p.Compact(mt, nil)
 	if assert.True(chunkCount > 0) {
-		buff, present := s3svc.data[tableAddr.String()]
-		assert.True(present)
-		tr := newTableReader(buff, bytes.NewReader(buff))
-		for _, c := range testChunks {
-			assert.True(tr.has(computeAddr(c)))
+		if r := s3svc.readerForTable(tableAddr); assert.NotNil(r) {
+			assertChunksInReader(testChunks, r, assert)
 		}
 	}
 }
 
+func calcPartSize(mt *memTable, maxPartNum int) int {
+	return int(maxTableSize(uint64(mt.count()), mt.totalData)) / maxPartNum
+}
+
+func TestS3TablePersisterCompactSinglePart(t *testing.T) {
+	assert := assert.New(t)
+	mt := newMemTable(testMemTableSize)
+
+	for _, c := range testChunks {
+		assert.True(mt.addChunk(computeAddr(c), c))
+	}
+
+	s3svc := makeFakeS3(assert)
+	s3p := s3TablePersister{s3: s3svc, bucket: "bucket", partSize: calcPartSize(mt, 1)}
+
+	tableAddr, chunkCount := s3p.Compact(mt, nil)
+	if assert.True(chunkCount > 0) {
+		if r := s3svc.readerForTable(tableAddr); assert.NotNil(r) {
+			assertChunksInReader(testChunks, r, assert)
+		}
+	}
+}
+
+func TestS3TablePersisterCompactAbort(t *testing.T) {
+	assert := assert.New(t)
+	mt := newMemTable(testMemTableSize)
+
+	for _, c := range testChunks {
+		assert.True(mt.addChunk(computeAddr(c), c))
+	}
+
+	numParts := 4
+	s3svc := &failingFakeS3{makeFakeS3(assert), sync.Mutex{}, 1}
+	s3p := s3TablePersister{s3: s3svc, bucket: "bucket", partSize: calcPartSize(mt, numParts)}
+
+	assert.Panics(func() { s3p.Compact(mt, nil) })
+}
+
+type failingFakeS3 struct {
+	*fakeS3
+	mu           sync.Mutex
+	numSuccesses int
+}
+
+func (m *failingFakeS3) UploadPart(input *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	if m.numSuccesses > 0 {
+		m.numSuccesses--
+		return m.fakeS3.UploadPart(input)
+	}
+	return nil, mockAWSError("MalformedXML")
+}
+
 func TestS3TablePersisterCompactNoData(t *testing.T) {
 	assert := assert.New(t)
 	mt := newMemTable(testMemTableSize)
@@ -105,7 +158,7 @@ func TestS3TablePersisterCompactNoData(t *testing.T) {
 	}
 
 	s3svc := makeFakeS3(assert)
-	s3p := s3TablePersister{s3svc, "bucket"}
+	s3p := s3TablePersister{s3: s3svc, bucket: "bucket", partSize: 1 << 10}
 
 	tableAddr, chunkCount := s3p.Compact(mt, existingTable)
 	assert.True(chunkCount == 0)