Merge pull request #1686 from dolthub/aaron/remotestorage-streaming-download-locations

go/libraries/doltcore/remotestorage: chunk_store.go: Stream chunk download locations instead of batching.
This commit is contained in:
Aaron Son
2021-05-07 16:27:30 -07:00
committed by GitHub
2 changed files with 123 additions and 39 deletions

View File

@@ -249,8 +249,7 @@ func (dcs *DoltChunkStore) GetManyCompressed(ctx context.Context, hashes hash.Ha
}
const (
getLocsBatchSize = 32 * 1024
getLocsMaxConcurrency = 4
getLocsBatchSize = (4 * 1024) / 20
)
type GetRange remotesapi.HttpGetRange
@@ -351,10 +350,14 @@ func (gr *GetRange) GetDownloadFunc(ctx context.Context, fetcher HTTPFetcher, ch
}
func (dcs *DoltChunkStore) getDLLocs(ctx context.Context, hashes []hash.Hash) (map[string]*GetRange, error) {
span, ctx := tracing.StartSpan(ctx, "remotestorage.getDLLocs")
span.LogKV("num_hashes", len(hashes))
defer span.Finish()
res := make(map[string]*GetRange)
// channel for receiving results from go routines making grpc calls to get download locations for chunks
dlLocChan := make(chan []*remotesapi.HttpGetRange)
resCh := make(chan []*remotesapi.HttpGetRange)
eg, ctx := errgroup.WithContext(ctx)
@@ -362,7 +365,7 @@ func (dcs *DoltChunkStore) getDLLocs(ctx context.Context, hashes []hash.Hash) (m
eg.Go(func() error {
for {
select {
case locs, ok := <-dlLocChan:
case locs, ok := <-resCh:
if !ok {
return nil
}
@@ -380,47 +383,67 @@ func (dcs *DoltChunkStore) getDLLocs(ctx context.Context, hashes []hash.Hash) (m
}
})
hashesBytes := HashesToSlices(hashes)
var work []func() error
// batchItr creates work functions which request a batch of chunk download locations and write the results to the
// dlLocChan
batchItr(len(hashesBytes), getLocsBatchSize, func(st, end int) (stop bool) {
batch := hashesBytes[st:end]
f := func() error {
req := &remotesapi.GetDownloadLocsRequest{RepoId: dcs.getRepoId(), ChunkHashes: batch}
resp, err := dcs.csClient.GetDownloadLocations(ctx, req)
if err != nil {
return NewRpcError(err, "GetDownloadLocations", dcs.host, req)
}
tosend := make([]*remotesapi.HttpGetRange, len(resp.Locs))
for i, l := range resp.Locs {
tosend[i] = l.Location.(*remotesapi.DownloadLoc_HttpGetRange).HttpGetRange
}
select {
case dlLocChan <- tosend:
case <-ctx.Done():
}
return nil
}
work = append(work, f)
return false
})
span, ctx := tracing.StartSpan(ctx, "remotestorage.getDLLocs")
span.LogKV("num_batches", len(work), "num_hashes", len(hashes))
defer span.Finish()
// execute the work and close the channel after as no more results will come in
// go routine for batching the get location requests, streaming the requests and streaming the responses.
eg.Go(func() error {
defer close(dlLocChan)
return concurrentExec(work, getLocsMaxConcurrency)
var reqs []*remotesapi.GetDownloadLocsRequest
hashesBytes := HashesToSlices(hashes)
batchItr(len(hashesBytes), getLocsBatchSize, func(st, end int) (stop bool) {
batch := hashesBytes[st:end]
req := &remotesapi.GetDownloadLocsRequest{RepoId: dcs.getRepoId(), ChunkHashes: batch}
reqs = append(reqs, req)
return false
})
op := func() error {
stream, err := dcs.csClient.StreamDownloadLocations(ctx)
if err != nil {
return NewRpcError(err, "StreamDownloadLocations", dcs.host, nil)
}
seg, ctx := errgroup.WithContext(ctx)
completedReqs := 0
// Write requests
seg.Go(func() error {
for i := range reqs {
if err := stream.Send(reqs[i]); err != nil {
return NewRpcError(err, "StreamDownloadLocations", dcs.host, reqs[i])
}
}
return stream.CloseSend()
})
// Read responses
seg.Go(func() error {
for {
resp, err := stream.Recv()
if err != nil {
if err == io.EOF {
return nil
}
return NewRpcError(err, "StreamDownloadLocations", dcs.host, reqs[completedReqs])
}
tosend := make([]*remotesapi.HttpGetRange, len(resp.Locs))
for i, l := range resp.Locs {
tosend[i] = l.Location.(*remotesapi.DownloadLoc_HttpGetRange).HttpGetRange
}
select {
case resCh <- tosend:
completedReqs += 1
case <-ctx.Done():
return ctx.Err()
}
}
})
err = seg.Wait()
reqs = reqs[completedReqs:]
if len(reqs) == 0 {
close(resCh)
}
return processGrpcErr(err)
}
return backoff.Retry(op, backoff.WithMaxRetries(csRetryParams, csClientRetries))
})
if err := eg.Wait(); err != nil {
return nil, err
}
return res, nil
}

View File

@@ -17,6 +17,7 @@ package main
import (
"context"
"fmt"
"io"
"log"
"os"
"path/filepath"
@@ -24,6 +25,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
remotesapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/remotesapi/v1alpha1"
"github.com/dolthub/dolt/go/libraries/doltcore/remotestorage"
@@ -116,6 +118,7 @@ func (rs *RemoteChunkStore) GetDownloadLocations(ctx context.Context, req *remot
url, err := rs.getDownloadUrl(logger, org, repoName, loc.String())
if err != nil {
log.Println("Failed to sign request", err)
return nil, err
}
logger("The URL is " + url)
@@ -127,6 +130,64 @@ func (rs *RemoteChunkStore) GetDownloadLocations(ctx context.Context, req *remot
return &remotesapi.GetDownloadLocsResponse{Locs: locs}, nil
}
// StreamDownloadLocations is the streaming counterpart of GetDownloadLocations.
// It reads GetDownloadLocsRequest messages from the client, resolves each
// request's chunk hashes to HTTP byte-range download locations, and sends one
// GetDownloadLocsResponse back per request. Returns nil when the client closes
// its send side (io.EOF), or the first error encountered otherwise.
func (rs *RemoteChunkStore) StreamDownloadLocations(stream remotesapi.ChunkStoreService_StreamDownloadLocationsServer) error {
	logger := getReqLogger("GRPC", "StreamDownloadLocations")
	defer func() { logger("finished") }()
	var repoID *remotesapi.RepoId
	var cs *nbs.NomsBlockStore
	for {
		req, err := stream.Recv()
		if err != nil {
			if err == io.EOF {
				// Client finished sending requests; normal termination.
				return nil
			}
			return err
		}
		// Only re-resolve the chunk store when the repo changes between
		// consecutive requests on this stream.
		if !proto.Equal(req.RepoId, repoID) {
			repoID = req.RepoId
			// BUGFIX: label was misspelled "StreamDownloadLoctions".
			cs = rs.getStore(repoID, "StreamDownloadLocations")
			if cs == nil {
				return status.Error(codes.Internal, "Could not get chunkstore")
			}
			logger(fmt.Sprintf("found repo %s/%s", repoID.Org, repoID.RepoName))
		}
		org := req.RepoId.Org
		repoName := req.RepoId.RepoName
		// NOTE(review): second return value is deliberately discarded here,
		// matching GetDownloadLocations above — confirm it is not an error.
		hashes, _ := remotestorage.ParseByteSlices(req.ChunkHashes)
		locations, err := cs.GetChunkLocations(hashes)
		if err != nil {
			return err
		}
		var locs []*remotesapi.DownloadLoc
		for loc, hashToRange := range locations {
			var ranges []*remotesapi.RangeChunk
			for h, r := range hashToRange {
				// Copy the loop variable before taking a slice of it; h is
				// reused on every iteration (pre-Go-1.22 semantics).
				hCpy := h
				ranges = append(ranges, &remotesapi.RangeChunk{Hash: hCpy[:], Offset: r.Offset, Length: r.Length})
			}
			url, err := rs.getDownloadUrl(logger, org, repoName, loc.String())
			if err != nil {
				log.Println("Failed to sign request", err)
				return err
			}
			logger("The URL is " + url)
			getRange := &remotesapi.HttpGetRange{Url: url, Ranges: ranges}
			locs = append(locs, &remotesapi.DownloadLoc{Location: &remotesapi.DownloadLoc_HttpGetRange{HttpGetRange: getRange}})
		}
		if err := stream.Send(&remotesapi.GetDownloadLocsResponse{Locs: locs}); err != nil {
			return err
		}
	}
}
// getDownloadUrl builds the plain-HTTP URL a client should use to fetch the
// table file identified by fileId within the given org/repo on this server.
// The logger parameter is accepted for signature compatibility and unused here.
func (rs *RemoteChunkStore) getDownloadUrl(logger func(string), org, repoName, fileId string) (string, error) {
	// Assemble "http://<host>/<org>/<repo>/<file>" by direct concatenation.
	url := "http://" + rs.HttpHost + "/" + org + "/" + repoName + "/" + fileId
	return url, nil
}