Merge pull request #7592 from dolthub/aaron/fix-parallel-hasmany

go/store/datas/pull: Restore puller optimization for parallel HasMany calls. Fix crash in pull_chunk_tracker.
This commit is contained in:
Aaron Son
2024-03-14 13:45:35 -07:00
committed by GitHub
3 changed files with 583 additions and 31 deletions

@@ -0,0 +1,282 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pull
import (
"context"
"errors"
"sync"
"github.com/dolthub/dolt/go/store/hash"
)
type HasManyer interface {
HasMany(context.Context, hash.HashSet) (hash.HashSet, error)
}
type TrackerConfig struct {
BatchSize int
HasManyer HasManyer
}
const hasManyThreadCount = 3
// A PullChunkTracker keeps track of seen chunk addresses and returns each seen
// chunk address which is not already in the destination database exactly
// once. A Puller instantiates one of these with the initial set of addresses
// to pull, and repeatedly calls |GetChunksToFetch|. It passes every
// reference it finds in the fetched chunks to |Seen|, and continues to call
// |GetChunksToFetch| and deliver new addresses to |Seen| until
// |GetChunksToFetch| returns |false| from its |more| return boolean.
//
// PullChunkTracker is able to call |HasMany| on the destination database in
// parallel with other work the Puller does, and it abstracts out the logic for
// keeping track of seen, unchecked, and to-pull chunk addresses.
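//
// A minimal sketch of the intended call pattern (illustrative only; the real
// caller is the Puller, and |initialAddrs| / |destDB| are assumed names):
//
//	tracker := NewPullChunkTracker(ctx, initialAddrs, TrackerConfig{
//	    BatchSize: 64 * 1024,
//	    HasManyer: destDB,
//	})
//	defer tracker.Close()
//	for {
//	    toFetch, more, err := tracker.GetChunksToFetch()
//	    if err != nil {
//	        return err
//	    }
//	    if !more {
//	        break
//	    }
//	    // Fetch |toFetch| from the source; for every reference found in a
//	    // fetched chunk, call tracker.Seen(ref).
//	}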
type PullChunkTracker struct {
ctx context.Context
seen hash.HashSet
cfg TrackerConfig
wg sync.WaitGroup
uncheckedCh chan hash.Hash
reqCh chan *trackerGetAbsentReq
}
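// NewPullChunkTracker seeds the tracker with |initial| and starts the
// background request/response goroutine. The caller must call |Close| when it
// is done with the tracker.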
func NewPullChunkTracker(ctx context.Context, initial hash.HashSet, cfg TrackerConfig) *PullChunkTracker {
ret := &PullChunkTracker{
ctx: ctx,
seen: make(hash.HashSet),
cfg: cfg,
uncheckedCh: make(chan hash.Hash),
reqCh: make(chan *trackerGetAbsentReq),
}
ret.seen.InsertAll(initial)
ret.wg.Add(1)
go func() {
defer ret.wg.Done()
ret.reqRespThread(initial)
}()
return ret
}
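// Seen records a chunk address encountered while walking fetched chunks. An
// address is forwarded for a |HasMany| check only the first time it is seen.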
func (t *PullChunkTracker) Seen(h hash.Hash) {
if !t.seen.Has(h) {
t.seen.Insert(h)
t.addUnchecked(h)
}
}
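// Close shuts down the tracker's background goroutines and waits for them to
// exit.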
func (t *PullChunkTracker) Close() {
close(t.uncheckedCh)
t.wg.Wait()
}
func (t *PullChunkTracker) addUnchecked(h hash.Hash) {
select {
case t.uncheckedCh <- h:
case <-t.ctx.Done():
}
}
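// GetChunksToFetch returns a batch of addresses which are not present in the
// destination and should be fetched from the source. It blocks while |HasMany|
// checks are still outstanding and there is nothing to return; |more| is
// |false| once every checked address has been delivered.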
func (t *PullChunkTracker) GetChunksToFetch() (hash.HashSet, bool, error) {
var req trackerGetAbsentReq
req.ready = make(chan struct{})
select {
case t.reqCh <- &req:
case <-t.ctx.Done():
return nil, false, t.ctx.Err()
}
select {
case <-req.ready:
case <-t.ctx.Done():
return nil, false, t.ctx.Err()
}
return req.hs, req.ok, req.err
}
// reqRespThread runs the main logic of the PullChunkTracker. It receives
// requests from the other threads and responds to them.
func (t *PullChunkTracker) reqRespThread(initial hash.HashSet) {
doneCh := make(chan struct{})
hasManyReqCh := make(chan trackerHasManyReq)
hasManyRespCh := make(chan trackerHasManyResp)
var wg sync.WaitGroup
wg.Add(hasManyThreadCount)
for i := 0; i < hasManyThreadCount; i++ {
go func() {
defer wg.Done()
hasManyThread(t.ctx, t.cfg.HasManyer, hasManyReqCh, hasManyRespCh, doneCh)
}()
}
defer func() {
close(doneCh)
wg.Wait()
}()
unchecked := make([]hash.HashSet, 0)
absent := make([]hash.HashSet, 0)
var err error
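// |outstanding| counts unchecked batches whose |HasMany| responses have not
// come back yet.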
outstanding := 0
if len(initial) > 0 {
unchecked = append(unchecked, initial)
outstanding += 1
}
for {
var thisReqCh = t.reqCh
if outstanding != 0 && len(absent) == 0 {
// If we are waiting for a HasMany response and we don't currently have any
// absent addresses to return, block any absent requests.
thisReqCh = nil
}
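// Only enable dispatch to the HasMany threads when there is at least one batch
// of unchecked addresses; a nil channel keeps that select case from firing.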
var thisHasManyReqCh chan trackerHasManyReq
var hasManyReq trackerHasManyReq
if len(unchecked) > 0 {
hasManyReq.hs = unchecked[0]
thisHasManyReqCh = hasManyReqCh
}
select {
case h, ok := <-t.uncheckedCh:
if !ok {
return
}
if len(unchecked) == 0 || len(unchecked[len(unchecked)-1]) >= t.cfg.BatchSize {
outstanding += 1
unchecked = append(unchecked, make(hash.HashSet))
}
unchecked[len(unchecked)-1].Insert(h)
case resp := <-hasManyRespCh:
outstanding -= 1
if resp.err != nil {
err = errors.Join(err, resp.err)
} else if len(resp.hs) > 0 {
absent = append(absent, resp.hs)
}
case thisHasManyReqCh <- hasManyReq:
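// A HasMany thread accepted unchecked[0]; shift the remaining batches down and
// drop the now-unused tail slot.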
copy(unchecked[:], unchecked[1:])
if len(unchecked) > 1 {
unchecked[len(unchecked)-1] = nil
}
unchecked = unchecked[:len(unchecked)-1]
case req := <-thisReqCh:
if err != nil {
req.err = err
close(req.ready)
err = nil
} else if len(absent) == 0 {
req.ok = false
close(req.ready)
} else {
req.ok = true
req.hs = absent[0]
var i int
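// Merge whole absent sets into the response while it stays under BatchSize,
// then top it off from the next set and leave the remainder for a later call.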
for i = 1; i < len(absent); i++ {
l := len(absent[i])
if len(req.hs)+l < t.cfg.BatchSize {
req.hs.InsertAll(absent[i])
} else {
for h := range absent[i] {
if len(req.hs) >= t.cfg.BatchSize {
break
}
req.hs.Insert(h)
absent[i].Remove(h)
}
break
}
}
copy(absent[:], absent[i:])
for j := len(absent) - i; j < len(absent); j++ {
absent[j] = nil
}
absent = absent[:len(absent)-i]
close(req.ready)
}
case <-t.ctx.Done():
return
}
}
}
// hasManyThread is run by a PullChunkTracker; it calls HasMany on each batch of addresses it receives and delivers the results.
func hasManyThread(ctx context.Context, hasManyer HasManyer, reqCh <-chan trackerHasManyReq, respCh chan<- trackerHasManyResp, doneCh <-chan struct{}) {
for {
select {
case req := <-reqCh:
hs, err := hasManyer.HasMany(ctx, req.hs)
if err != nil {
select {
case respCh <- trackerHasManyResp{err: err}:
case <-ctx.Done():
return
case <-doneCh:
return
}
} else {
select {
case respCh <- trackerHasManyResp{hs: hs}:
case <-ctx.Done():
return
case <-doneCh:
return
}
}
case <-doneCh:
return
case <-ctx.Done():
return
}
}
}
// Sent by the tracker thread to a HasMany thread; carries a batch of
// addresses to check with |HasMany|. The response comes back to the tracker
// thread on a separate channel as a |trackerHasManyResp|.
type trackerHasManyReq struct {
hs hash.HashSet
}
// Sent by the HasMany thread back to the tracker thread.
// If HasMany returned an error, it will be returned here.
type trackerHasManyResp struct {
hs hash.HashSet
err error
}
// Sent by a client calling |GetChunksToFetch| to the tracker thread. The
// tracker thread responds with a batch of chunk addresses that need to be
// fetched from the source and added to the destination.
//
// This will block until HasMany requests are completed.
//
// If |ok| is |false|, then the Tracker is closing because every absent address
// has been delivered.
type trackerGetAbsentReq struct {
hs hash.HashSet
err error
ok bool
ready chan struct{}
}

@@ -0,0 +1,287 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pull
import (
"context"
"errors"
"testing"
"github.com/dolthub/dolt/go/store/hash"
"github.com/stretchr/testify/assert"
)
func TestPullChunkTracker(t *testing.T) {
t.Run("Empty", func(t *testing.T) {
tracker := NewPullChunkTracker(context.Background(), make(hash.HashSet), TrackerConfig{
BatchSize: 64 * 1024,
HasManyer: nil,
})
hs, ok, err := tracker.GetChunksToFetch()
assert.Len(t, hs, 0)
assert.False(t, ok)
assert.NoError(t, err)
tracker.Close()
})
t.Run("HasAllInitial", func(t *testing.T) {
hs := make(hash.HashSet)
for i := byte(0); i < byte(10); i++ {
var h hash.Hash
h[0] = i
hs.Insert(h)
}
tracker := NewPullChunkTracker(context.Background(), hs, TrackerConfig{
BatchSize: 64 * 1024,
HasManyer: hasAllHaser{},
})
hs, ok, err := tracker.GetChunksToFetch()
assert.Len(t, hs, 0)
assert.False(t, ok)
assert.NoError(t, err)
tracker.Close()
})
t.Run("HasNoneInitial", func(t *testing.T) {
hs := make(hash.HashSet)
for i := byte(1); i <= byte(10); i++ {
var h hash.Hash
h[0] = i
hs.Insert(h)
}
tracker := NewPullChunkTracker(context.Background(), hs, TrackerConfig{
BatchSize: 64 * 1024,
HasManyer: hasNoneHaser{},
})
hs, ok, err := tracker.GetChunksToFetch()
assert.Len(t, hs, 10)
assert.True(t, ok)
assert.NoError(t, err)
hs, ok, err = tracker.GetChunksToFetch()
assert.Len(t, hs, 0)
assert.False(t, ok)
assert.NoError(t, err)
for i := byte(1); i <= byte(10); i++ {
var h hash.Hash
h[1] = i
tracker.Seen(h)
}
cnt := 0
for {
hs, ok, err := tracker.GetChunksToFetch()
assert.NoError(t, err)
if !ok {
assert.Equal(t, 10, cnt)
break
}
cnt += len(hs)
}
tracker.Close()
})
t.Run("HasManyError", func(t *testing.T) {
hs := make(hash.HashSet)
for i := byte(0); i < byte(10); i++ {
var h hash.Hash
h[0] = i
hs.Insert(h)
}
tracker := NewPullChunkTracker(context.Background(), hs, TrackerConfig{
BatchSize: 64 * 1024,
HasManyer: errHaser{},
})
_, _, err := tracker.GetChunksToFetch()
assert.Error(t, err)
tracker.Close()
})
t.Run("InitialAreSeen", func(t *testing.T) {
hs := make(hash.HashSet)
for i := byte(0); i < byte(10); i++ {
var h hash.Hash
h[0] = i
hs.Insert(h)
}
tracker := NewPullChunkTracker(context.Background(), hs, TrackerConfig{
BatchSize: 64 * 1024,
HasManyer: hasNoneHaser{},
})
hs, ok, err := tracker.GetChunksToFetch()
assert.Len(t, hs, 10)
assert.True(t, ok)
assert.NoError(t, err)
for i := byte(0); i < byte(10); i++ {
var h hash.Hash
h[0] = i
tracker.Seen(h)
}
hs, ok, err = tracker.GetChunksToFetch()
assert.Len(t, hs, 0)
assert.False(t, ok)
assert.NoError(t, err)
tracker.Close()
})
t.Run("StaticHaser", func(t *testing.T) {
haser := staticHaser{make(hash.HashSet)}
initial := make([]hash.Hash, 4)
initial[0][0] = 1
initial[1][0] = 2
initial[2][0] = 1
initial[2][1] = 1
initial[3][0] = 1
initial[3][1] = 2
haser.has.Insert(initial[0])
haser.has.Insert(initial[1])
haser.has.Insert(initial[2])
haser.has.Insert(initial[3])
hs := make(hash.HashSet)
// Start with 1 - 5
for i := byte(1); i <= byte(5); i++ {
var h hash.Hash
h[0] = i
hs.Insert(h)
}
tracker := NewPullChunkTracker(context.Background(), hs, TrackerConfig{
BatchSize: 64 * 1024,
HasManyer: haser,
})
// Should get back 03, 04, 05
hs, ok, err := tracker.GetChunksToFetch()
assert.Len(t, hs, 3)
assert.True(t, ok)
assert.NoError(t, err)
for i := byte(1); i <= byte(10); i++ {
var h hash.Hash
h[0] = 1
h[1] = i
tracker.Seen(h)
}
// Should get back 13, 14, 15, 16, 17, 18, 19, 1(10).
cnt := 0
for {
hs, ok, err := tracker.GetChunksToFetch()
assert.NoError(t, err)
if !ok {
break
}
cnt += len(hs)
}
assert.Equal(t, 8, cnt)
tracker.Close()
})
t.Run("SmallBatches", func(t *testing.T) {
haser := staticHaser{make(hash.HashSet)}
initial := make([]hash.Hash, 4)
initial[0][0] = 1
initial[1][0] = 2
initial[2][0] = 1
initial[2][1] = 1
initial[3][0] = 1
initial[3][1] = 2
haser.has.Insert(initial[0])
haser.has.Insert(initial[1])
haser.has.Insert(initial[2])
haser.has.Insert(initial[3])
hs := make(hash.HashSet)
// Start with 1 - 5
for i := byte(1); i <= byte(5); i++ {
var h hash.Hash
h[0] = i
hs.Insert(h)
}
tracker := NewPullChunkTracker(context.Background(), hs, TrackerConfig{
BatchSize: 1,
HasManyer: haser,
})
// First call doesn't actually respect batch size.
hs, ok, err := tracker.GetChunksToFetch()
assert.Len(t, hs, 3)
assert.True(t, ok)
assert.NoError(t, err)
for i := byte(1); i <= byte(10); i++ {
var h hash.Hash
h[0] = 1
h[1] = i
tracker.Seen(h)
}
// Should get back 13, 14, 15, 16, 17, 18, 19, 1(10); one at a time.
cnt := 0
for {
hs, ok, err := tracker.GetChunksToFetch()
assert.NoError(t, err)
if !ok {
break
}
assert.Len(t, hs, 1)
cnt += len(hs)
}
assert.Equal(t, 8, cnt)
tracker.Close()
})
}
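// hasAllHaser simulates a destination that already has every chunk: HasMany
// reports nothing absent.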
type hasAllHaser struct {
}
func (hasAllHaser) HasMany(context.Context, hash.HashSet) (hash.HashSet, error) {
return make(hash.HashSet), nil
}
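// hasNoneHaser simulates an empty destination: HasMany reports every queried
// address as absent.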
type hasNoneHaser struct {
}
func (hasNoneHaser) HasMany(ctx context.Context, hs hash.HashSet) (hash.HashSet, error) {
return hs, nil
}
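// staticHaser has a fixed set of present chunks and reports everything else as
// absent.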
type staticHaser struct {
has hash.HashSet
}
func (s staticHaser) HasMany(ctx context.Context, query hash.HashSet) (hash.HashSet, error) {
ret := make(hash.HashSet)
for h := range query {
if !s.has.Has(h) {
ret.Insert(h)
}
}
return ret, nil
}
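// errHaser fails every HasMany call; used to exercise error propagation.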
type errHaser struct {
}
func (errHaser) HasMany(ctx context.Context, hs hash.HashSet) (hash.HashSet, error) {
return nil, errors.New("always throws an error")
}

@@ -45,7 +45,6 @@ var ErrDBUpToDate = errors.New("the database does not need to be pulled as it's
var ErrIncompatibleSourceChunkStore = errors.New("the chunk store of the source database does not implement NBSCompressedChunkStore.")
const (
maxChunkWorkers = 2
outstandingTableFiles = 2
)
@@ -251,6 +250,7 @@ func emitStats(s *stats, ch chan Stats) (cancel func()) {
defer wg.Done()
updateduration := 1 * time.Second
ticker := time.NewTicker(updateduration)
defer ticker.Stop()
for {
select {
case <-ticker.C:
@@ -328,6 +328,7 @@ func (p *Puller) uploadTempTableFile(ctx context.Context, tmpTblFile tempTblFile
atomic.AddUint64(&p.stats.bufferedSendBytes, uint64(localUploaded))
localUploaded = 0
}
fWithStats := countingReader{countingReader{rc, &localUploaded}, &p.stats.finishedSendBytes}
return fWithStats, uint64(fileSize), nil
@@ -398,35 +399,21 @@ func (p *Puller) Pull(ctx context.Context) error {
}
const batchSize = 64 * 1024
// refs are added to |visited| on first sight
visited := p.hashes
// |absent| are visited, un-batched refs
absent := p.hashes.Copy()
// |batches| are visited, un-fetched refs
batches := make([]hash.HashSet, 0, 64)
tracker := NewPullChunkTracker(ctx, p.hashes, TrackerConfig{
BatchSize: batchSize,
HasManyer: p.sinkDBCS,
})
for absent.Size() > 0 || len(batches) > 0 {
if absent.Size() >= batchSize {
var bb []hash.HashSet
absent, bb = batchNovel(absent, batchSize)
batches = append(batches, bb...)
}
if len(batches) == 0 {
batches = append(batches, absent)
absent = make(hash.HashSet)
}
b := batches[len(batches)-1]
batches = batches[:len(batches)-1]
b, err = p.sinkDBCS.HasMany(ctx, b)
for {
toFetch, hasMore, err := tracker.GetChunksToFetch()
if err != nil {
return err
} else if b.Size() == 0 {
continue
}
if !hasMore {
break
}
err = p.getCmp(ctx, b, absent, visited, completedTables)
err = p.getCmp(ctx, toFetch, tracker, completedTables)
if err != nil {
return err
}
@@ -460,7 +447,7 @@ func batchNovel(absent hash.HashSet, batch int) (remainder hash.HashSet, batches
return
}
func (p *Puller) getCmp(ctx context.Context, batch, absent, visited hash.HashSet, completedTables chan FilledWriters) error {
func (p *Puller) getCmp(ctx context.Context, batch hash.HashSet, tracker *PullChunkTracker, completedTables chan FilledWriters) error {
found := make(chan nbs.CompressedChunk, 4096)
processed := make(chan CmpChnkAndRefs, 4096)
@@ -496,11 +483,7 @@ func (p *Puller) getCmp(ctx context.Context, batch, absent, visited hash.HashSet
return err
}
err = p.waf(chnk, func(h hash.Hash, _ bool) error {
if !visited.Has(h) {
// first sight of |h|
visited.Insert(h)
absent.Insert(h)
}
tracker.Seen(h)
return nil
})
if err != nil {