diff --git a/go/nbs/s3_fake_test.go b/go/nbs/s3_fake_test.go
index 9ccb8b3008..ab8d96759e 100644
--- a/go/nbs/s3_fake_test.go
+++ b/go/nbs/s3_fake_test.go
@@ -10,8 +10,11 @@ import (
 	"io/ioutil"
 	"strconv"
 	"strings"
+	"sync"
+	"time"
 
 	"github.com/attic-labs/noms/go/d"
+	"github.com/attic-labs/noms/go/hash"
 	"github.com/attic-labs/testify/assert"
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/service/s3"
@@ -24,13 +27,113 @@ func (m mockAWSError) Code() string { return string(m) }
 func (m mockAWSError) Message() string { return string(m) }
 func (m mockAWSError) OrigErr() error { return nil }
 
-type fakeS3 struct {
-	data   map[string][]byte
-	assert *assert.Assertions
+func makeFakeS3(a *assert.Assertions) *fakeS3 {
+	return &fakeS3{
+		assert:     a,
+		data:       map[string][]byte{},
+		inProgress: map[string]fakeS3Multipart{},
+		parts:      map[string][]byte{},
+	}
 }
 
-func makeFakeS3(a *assert.Assertions) *fakeS3 {
-	return &fakeS3{data: map[string][]byte{}, assert: a}
+type fakeS3 struct {
+	assert *assert.Assertions
+
+	mu                sync.Mutex
+	data              map[string][]byte
+	inProgressCounter int
+	inProgress        map[string]fakeS3Multipart // Key -> {UploadId, Etags...}
+	parts             map[string][]byte          // ETag -> data
+}
+
+type fakeS3Multipart struct {
+	uploadID string
+	etags    []string
+}
+
+func (m *fakeS3) readerForTable(name addr) chunkReader {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	if buff, present := m.data[name.String()]; present {
+		return newTableReader(buff, bytes.NewReader(buff))
+	}
+	return nil
+}
+
+func (m *fakeS3) AbortMultipartUpload(input *s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error) {
+	m.assert.NotNil(input.Bucket, "Bucket is a required field")
+	m.assert.NotNil(input.Key, "Key is a required field")
+	m.assert.NotNil(input.UploadId, "UploadId is a required field")
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.assert.Equal(m.inProgress[*input.Key].uploadID, *input.UploadId)
+	for _, etag := range m.inProgress[*input.Key].etags {
+		delete(m.parts, etag)
+	}
+	delete(m.inProgress, *input.Key)
+	return &s3.AbortMultipartUploadOutput{}, nil
+}
+
+func (m *fakeS3) CreateMultipartUpload(input *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) {
+	m.assert.NotNil(input.Bucket, "Bucket is a required field")
+	m.assert.NotNil(input.Key, "Key is a required field")
+
+	out := &s3.CreateMultipartUploadOutput{
+		Bucket: input.Bucket,
+		Key:    input.Key,
+	}
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	uploadID := strconv.Itoa(m.inProgressCounter)
+	out.UploadId = aws.String(uploadID)
+	m.inProgress[*input.Key] = fakeS3Multipart{uploadID, nil}
+	m.inProgressCounter++
+	return out, nil
+}
+
+func (m *fakeS3) UploadPart(input *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
+	m.assert.NotNil(input.Bucket, "Bucket is a required field")
+	m.assert.NotNil(input.Key, "Key is a required field")
+	m.assert.NotNil(input.PartNumber, "PartNumber is a required field")
+	m.assert.NotNil(input.UploadId, "UploadId is a required field")
+	m.assert.NotNil(input.Body, "Body is a required field")
+
+	data, err := ioutil.ReadAll(input.Body)
+	m.assert.NoError(err)
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	etag := hash.Of(data).String() + time.Now().String()
+	m.parts[etag] = data
+
+	inProgress, present := m.inProgress[*input.Key]
+	m.assert.True(present)
+	m.assert.Equal(inProgress.uploadID, *input.UploadId)
+	inProgress.etags = append(inProgress.etags, etag)
+	m.inProgress[*input.Key] = inProgress
+	return &s3.UploadPartOutput{ETag: aws.String(etag)}, nil
+}
+
+func (m *fakeS3) CompleteMultipartUpload(input *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) {
+	m.assert.NotNil(input.Bucket, "Bucket is a required field")
+	m.assert.NotNil(input.Key, "Key is a required field")
+	m.assert.NotNil(input.UploadId, "UploadId is a required field")
+	m.assert.NotNil(input.MultipartUpload, "MultipartUpload is a required field")
+	m.assert.True(len(input.MultipartUpload.Parts) > 0)
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.assert.Equal(m.inProgress[*input.Key].uploadID, *input.UploadId)
+	for idx, part := range input.MultipartUpload.Parts {
+		m.assert.EqualValues(idx+1, *part.PartNumber) // Part numbers are 1-indexed
+		m.data[*input.Key] = append(m.data[*input.Key], m.parts[*part.ETag]...)
+		delete(m.parts, *part.ETag)
+	}
+	delete(m.inProgress, *input.Key)
+
+	return &s3.CompleteMultipartUploadOutput{Bucket: input.Bucket, Key: input.Key}, nil
 }
 
 func (m *fakeS3) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) {
diff --git a/go/nbs/s3_table_reader.go b/go/nbs/s3_table_reader.go
index aa12ba03eb..adda8a9023 100644
--- a/go/nbs/s3_table_reader.go
+++ b/go/nbs/s3_table_reader.go
@@ -23,6 +23,10 @@ type s3TableReader struct {
 }
 
 type s3svc interface {
+	AbortMultipartUpload(input *s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error)
+	CreateMultipartUpload(input *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error)
+	UploadPart(input *s3.UploadPartInput) (*s3.UploadPartOutput, error)
+	CompleteMultipartUpload(input *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error)
 	GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error)
 	PutObject(input *s3.PutObjectInput) (*s3.PutObjectOutput, error)
 }
diff --git a/go/nbs/table_set.go b/go/nbs/table_set.go
index b1ede58924..de3817b83c 100644
--- a/go/nbs/table_set.go
+++ b/go/nbs/table_set.go
@@ -10,14 +10,18 @@ import (
 	"io/ioutil"
 	"os"
 	"path/filepath"
+	"sort"
+	"sync"
 
 	"github.com/attic-labs/noms/go/d"
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/service/s3"
 )
 
+const defaultS3PartSize = 5 * 1 << 20 // 5MiB, smallest allowed by S3
+
 func newS3TableSet(s3 s3svc, bucket string) tableSet {
-	return tableSet{p: s3TablePersister{s3, bucket}}
+	return tableSet{p: s3TablePersister{s3, bucket, defaultS3PartSize}}
 }
 
 func newFSTableSet(dir string) tableSet {
@@ -81,27 +85,148 @@ type tablePersister interface {
 }
 
 type s3TablePersister struct {
-	s3     s3svc
-	bucket string
+	s3       s3svc
+	bucket   string
+	partSize int
+}
+
+func (s3p s3TablePersister) Open(name addr, chunkCount uint32) chunkSource {
+	return newS3TableReader(s3p.s3, s3p.bucket, name, chunkCount)
+}
+
+type s3UploadedPart struct {
+	idx  int64
+	etag string
 }
 
 func (s3p s3TablePersister) Compact(mt *memTable, haver chunkReader) (name addr, chunkCount uint32) {
 	name, data, chunkCount := mt.write(haver)
 	if chunkCount > 0 {
-		_, err := s3p.s3.PutObject(&s3.PutObjectInput{
-			Bucket:        aws.String(s3p.bucket),
-			Key:           aws.String(name.String()),
-			Body:          bytes.NewReader(data),
-			ContentLength: aws.Int64(int64(len(data))),
+		result, err := s3p.s3.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
+			Bucket: aws.String(s3p.bucket),
+			Key:    aws.String(name.String()),
 		})
 		d.PanicIfError(err)
+		uploadID := *result.UploadId
+
+		multipartUpload, err := s3p.uploadParts(data, name.String(), uploadID)
+		if err != nil {
+			_, abrtErr := s3p.s3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
+				Bucket:   aws.String(s3p.bucket),
+				Key:      aws.String(name.String()),
+				UploadId: aws.String(uploadID),
+			})
+			d.Chk.NoError(abrtErr)
+			panic(err) // TODO: Better error handling here
+		}
+
+		_, err = s3p.s3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
+			Bucket:          aws.String(s3p.bucket),
+			Key:             aws.String(name.String()),
+			MultipartUpload: multipartUpload,
+			UploadId:        aws.String(uploadID),
+		})
+		d.Chk.NoError(err)
 	}
 	return name, chunkCount
 }
 
-func (s3p s3TablePersister) Open(name addr, chunkCount uint32) chunkSource {
-	return newS3TableReader(s3p.s3, s3p.bucket, name, chunkCount)
+func (s3p s3TablePersister) uploadParts(data []byte, key, uploadID string) (*s3.CompletedMultipartUpload, error) {
+	sent, failed, done := make(chan s3UploadedPart), make(chan error), make(chan struct{})
+
+	numParts := getNumParts(len(data), s3p.partSize)
+	var wg sync.WaitGroup
+	wg.Add(numParts)
+	sendPart := func(partNum int) {
+		defer wg.Done()
+
+		// Check if upload has been terminated
+		select {
+		case <-done:
+			return
+		default:
+		}
+		// Upload the desired part
+		start, end := (partNum-1)*s3p.partSize, partNum*s3p.partSize
+		if partNum == numParts { // If this is the last part, make sure it includes any overflow
+			end = len(data)
+		}
+		result, err := s3p.s3.UploadPart(&s3.UploadPartInput{
+			Bucket:     aws.String(s3p.bucket),
+			Key:        aws.String(key),
+			PartNumber: aws.Int64(int64(partNum)),
+			UploadId:   aws.String(uploadID),
+			Body:       bytes.NewReader(data[start:end]),
+		})
+		if err != nil {
+			failed <- err
+			return
+		}
+		// Try to send along part info. In the case that the upload was aborted, reading from done allows this worker to exit correctly.
+		select {
+		case sent <- s3UploadedPart{int64(partNum), *result.ETag}:
+		case <-done:
+			return
+		}
+	}
+	for i := 1; i <= numParts; i++ {
+		go sendPart(i)
+	}
+	go func() {
+		wg.Wait()
+		close(sent)
+		close(failed)
+	}()
+
+	multipartUpload := &s3.CompletedMultipartUpload{}
+	var lastFailure error
+	for cont := true; cont; {
+		select {
+		case sentPart, open := <-sent:
+			if open {
+				multipartUpload.Parts = append(multipartUpload.Parts, &s3.CompletedPart{
+					ETag:       aws.String(sentPart.etag),
+					PartNumber: aws.Int64(sentPart.idx),
+				})
+			}
+			cont = open
+
+		case err := <-failed:
+			if err != nil { // nil err may happen when failed gets closed
+				lastFailure = err
+				close(done)
+			}
+		}
+	}
+
+	if lastFailure == nil {
+		close(done)
+	}
+	sort.Sort(partsByPartNum(multipartUpload.Parts))
+	return multipartUpload, lastFailure
+}
+
+func getNumParts(dataLen, partSize int) int {
+	numParts := dataLen / partSize
+	if numParts == 0 {
+		numParts = 1
+	}
+	return numParts
+}
+
+type partsByPartNum []*s3.CompletedPart
+
+func (s partsByPartNum) Len() int {
+	return len(s)
+}
+
+func (s partsByPartNum) Less(i, j int) bool {
+	return *s[i].PartNumber < *s[j].PartNumber
+}
+
+func (s partsByPartNum) Swap(i, j int) {
+	s[i], s[j] = s[j], s[i]
 }
 
 type fsTablePersister struct {
diff --git a/go/nbs/table_set_test.go b/go/nbs/table_set_test.go
index cafd4f9f7b..228bd1e2b7 100644
--- a/go/nbs/table_set_test.go
+++ b/go/nbs/table_set_test.go
@@ -9,9 +9,11 @@ import (
 	"io/ioutil"
 	"os"
 	"path/filepath"
+	"sync"
 	"testing"
 
 	"github.com/attic-labs/testify/assert"
+	"github.com/aws/aws-sdk-go/service/s3"
 )
 
 var testChunks = [][]byte{[]byte("hello2"), []byte("goodbye2"), []byte("badbye2")}
@@ -81,19 +83,70 @@ func TestS3TablePersisterCompact(t *testing.T) {
 	}
 
 	s3svc := makeFakeS3(assert)
-	s3p := s3TablePersister{s3svc, "bucket"}
+	s3p := s3TablePersister{s3: s3svc, bucket: "bucket", partSize: calcPartSize(mt, 3)}
 
 	tableAddr, chunkCount := s3p.Compact(mt, nil)
 	if assert.True(chunkCount > 0) {
-		buff, present := s3svc.data[tableAddr.String()]
-		assert.True(present)
-		tr := newTableReader(buff, bytes.NewReader(buff))
-		for _, c := range testChunks {
-			assert.True(tr.has(computeAddr(c)))
+		if r := s3svc.readerForTable(tableAddr); assert.NotNil(r) {
+			assertChunksInReader(testChunks, r, assert)
 		}
 	}
 }
 
+func calcPartSize(mt *memTable, maxPartNum int) int {
+	return int(maxTableSize(uint64(mt.count()), mt.totalData)) / maxPartNum
+}
+
+func TestS3TablePersisterCompactSinglePart(t *testing.T) {
+	assert := assert.New(t)
+	mt := newMemTable(testMemTableSize)
+
+	for _, c := range testChunks {
+		assert.True(mt.addChunk(computeAddr(c), c))
+	}
+
+	s3svc := makeFakeS3(assert)
+	s3p := s3TablePersister{s3: s3svc, bucket: "bucket", partSize: calcPartSize(mt, 1)}
+
+	tableAddr, chunkCount := s3p.Compact(mt, nil)
+	if assert.True(chunkCount > 0) {
+		if r := s3svc.readerForTable(tableAddr); assert.NotNil(r) {
+			assertChunksInReader(testChunks, r, assert)
+		}
+	}
+}
+
+func TestS3TablePersisterCompactAbort(t *testing.T) {
+	assert := assert.New(t)
+	mt := newMemTable(testMemTableSize)
+
+	for _, c := range testChunks {
+		assert.True(mt.addChunk(computeAddr(c), c))
+	}
+
+	numParts := 4
+	s3svc := &failingFakeS3{makeFakeS3(assert), sync.Mutex{}, 1}
+	s3p := s3TablePersister{s3: s3svc, bucket: "bucket", partSize: calcPartSize(mt, numParts)}
+
+	assert.Panics(func() { s3p.Compact(mt, nil) })
+}
+
+type failingFakeS3 struct {
+	*fakeS3
+	mu           sync.Mutex
+	numSuccesses int
+}
+
+func (m *failingFakeS3) UploadPart(input *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	if m.numSuccesses > 0 {
+		m.numSuccesses--
+		return m.fakeS3.UploadPart(input)
+	}
+	return nil, mockAWSError("MalformedXML")
+}
+
 func TestS3TablePersisterCompactNoData(t *testing.T) {
 	assert := assert.New(t)
 	mt := newMemTable(testMemTableSize)
@@ -105,7 +158,7 @@ func TestS3TablePersisterCompactNoData(t *testing.T) {
 	}
 
 	s3svc := makeFakeS3(assert)
-	s3p := s3TablePersister{s3svc, "bucket"}
+	s3p := s3TablePersister{s3: s3svc, bucket: "bucket", partSize: 1 << 10}
 
 	tableAddr, chunkCount := s3p.Compact(mt, existingTable)
 	assert.True(chunkCount == 0)