Merge pull request #1686 from dolthub/aaron/remotestorage-streaming-download-locations
go/libraries/doltcore/remotestorage: chunk_store.go: Stream chunk download locations instead of batching.
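The client-side change replaces batched unary GetDownloadLocations calls with a single bidirectional StreamDownloadLocations stream: one goroutine sends the batched requests and then half-closes the stream, while a second goroutine receives responses until io.EOF and forwards them on a channel. Below is a minimal, self-contained sketch of that send/receive split (not the dolt code itself); the locStream interface and streamLocations function are illustrative stand-ins for the generated remotesapi streaming client.

package example

import (
	"context"
	"io"

	"golang.org/x/sync/errgroup"
)

// locStream is a hypothetical stand-in for the generated bidirectional
// streaming client (e.g. a ChunkStoreService_StreamDownloadLocationsClient).
type locStream interface {
	Send(req []byte) error
	Recv() ([]byte, error) // returns io.EOF once the server has sent everything
	CloseSend() error
}

// streamLocations sends every batched request on the stream, half-closes the
// send side, and concurrently drains responses until io.EOF, forwarding each
// response on resCh.
func streamLocations(ctx context.Context, stream locStream, reqs [][]byte, resCh chan<- []byte) error {
	eg, ctx := errgroup.WithContext(ctx)

	// Writer: stream all requests, then signal that no more are coming.
	eg.Go(func() error {
		for _, r := range reqs {
			if err := stream.Send(r); err != nil {
				return err
			}
		}
		return stream.CloseSend()
	})

	// Reader: receive until the server closes its side of the stream.
	eg.Go(func() error {
		for {
			resp, err := stream.Recv()
			if err != nil {
				if err == io.EOF {
					return nil
				}
				return err
			}
			select {
			case resCh <- resp:
			case <-ctx.Done():
				return ctx.Err()
			}
		}
	})

	return eg.Wait()
}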
@@ -249,8 +249,7 @@ func (dcs *DoltChunkStore) GetManyCompressed(ctx context.Context, hashes hash.Ha
 }
 
 const (
-	getLocsBatchSize = 32 * 1024
-	getLocsMaxConcurrency = 4
+	getLocsBatchSize = (4 * 1024) / 20
 )
 
 type GetRange remotesapi.HttpGetRange
@@ -351,10 +350,14 @@ func (gr *GetRange) GetDownloadFunc(ctx context.Context, fetcher HTTPFetcher, ch
 }
 
 func (dcs *DoltChunkStore) getDLLocs(ctx context.Context, hashes []hash.Hash) (map[string]*GetRange, error) {
+	span, ctx := tracing.StartSpan(ctx, "remotestorage.getDLLocs")
+	span.LogKV("num_hashes", len(hashes))
+	defer span.Finish()
+
 	res := make(map[string]*GetRange)
 
 	// channel for receiving results from go routines making grpc calls to get download locations for chunks
-	dlLocChan := make(chan []*remotesapi.HttpGetRange)
+	resCh := make(chan []*remotesapi.HttpGetRange)
 
 	eg, ctx := errgroup.WithContext(ctx)
 
@@ -362,7 +365,7 @@ func (dcs *DoltChunkStore) getDLLocs(ctx context.Context, hashes []hash.Hash) (m
 	eg.Go(func() error {
 		for {
 			select {
-			case locs, ok := <-dlLocChan:
+			case locs, ok := <-resCh:
 				if !ok {
 					return nil
 				}
@@ -380,47 +383,67 @@ func (dcs *DoltChunkStore) getDLLocs(ctx context.Context, hashes []hash.Hash) (m
 		}
 	})
 
-	hashesBytes := HashesToSlices(hashes)
-	var work []func() error
-
-	// batchItr creates work functions which request a batch of chunk download locations and write the results to the
-	// dlLocChan
-	batchItr(len(hashesBytes), getLocsBatchSize, func(st, end int) (stop bool) {
-		batch := hashesBytes[st:end]
-		f := func() error {
-			req := &remotesapi.GetDownloadLocsRequest{RepoId: dcs.getRepoId(), ChunkHashes: batch}
-			resp, err := dcs.csClient.GetDownloadLocations(ctx, req)
-			if err != nil {
-				return NewRpcError(err, "GetDownloadLocations", dcs.host, req)
-			}
-			tosend := make([]*remotesapi.HttpGetRange, len(resp.Locs))
-			for i, l := range resp.Locs {
-				tosend[i] = l.Location.(*remotesapi.DownloadLoc_HttpGetRange).HttpGetRange
-			}
-			select {
-			case dlLocChan <- tosend:
-			case <-ctx.Done():
-			}
-			return nil
-		}
-		work = append(work, f)
-		return false
-	})
-
-	span, ctx := tracing.StartSpan(ctx, "remotestorage.getDLLocs")
-	span.LogKV("num_batches", len(work), "num_hashes", len(hashes))
-	defer span.Finish()
-
-	// execute the work and close the channel after as no more results will come in
+	// go routine for batching the get location requests, streaming the requests and streaming the responses.
 	eg.Go(func() error {
-		defer close(dlLocChan)
-		return concurrentExec(work, getLocsMaxConcurrency)
+		var reqs []*remotesapi.GetDownloadLocsRequest
+		hashesBytes := HashesToSlices(hashes)
+		batchItr(len(hashesBytes), getLocsBatchSize, func(st, end int) (stop bool) {
+			batch := hashesBytes[st:end]
+			req := &remotesapi.GetDownloadLocsRequest{RepoId: dcs.getRepoId(), ChunkHashes: batch}
+			reqs = append(reqs, req)
+			return false
+		})
+		op := func() error {
+			stream, err := dcs.csClient.StreamDownloadLocations(ctx)
+			if err != nil {
+				return NewRpcError(err, "StreamDownloadLocations", dcs.host, nil)
+			}
+			seg, ctx := errgroup.WithContext(ctx)
+			completedReqs := 0
+			// Write requests
+			seg.Go(func() error {
+				for i := range reqs {
+					if err := stream.Send(reqs[i]); err != nil {
+						return NewRpcError(err, "StreamDownloadLocations", dcs.host, reqs[i])
+					}
+				}
+				return stream.CloseSend()
+			})
+			// Read responses
+			seg.Go(func() error {
+				for {
+					resp, err := stream.Recv()
+					if err != nil {
+						if err == io.EOF {
+							return nil
+						}
+						return NewRpcError(err, "StreamDownloadLocations", dcs.host, reqs[completedReqs])
+					}
+					tosend := make([]*remotesapi.HttpGetRange, len(resp.Locs))
+					for i, l := range resp.Locs {
+						tosend[i] = l.Location.(*remotesapi.DownloadLoc_HttpGetRange).HttpGetRange
+					}
+					select {
+					case resCh <- tosend:
+						completedReqs += 1
+					case <-ctx.Done():
+						return ctx.Err()
+					}
+				}
+			})
+			err = seg.Wait()
+			reqs = reqs[completedReqs:]
+			if len(reqs) == 0 {
+				close(resCh)
+			}
+			return processGrpcErr(err)
+		}
+		return backoff.Retry(op, backoff.WithMaxRetries(csRetryParams, csClientRetries))
 	})
 
 	if err := eg.Wait(); err != nil {
 		return nil, err
 	}
 
 	return res, nil
 }
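Note how the op closure above cooperates with backoff.Retry: completedReqs counts only responses that were actually delivered on resCh, and reqs is trimmed to reqs[completedReqs:] before op returns, so a retried stream re-sends only the batches that never produced a response, and resCh is closed once nothing is left. A rough sketch of that trim-and-resume retry pattern, assuming the cenkalti/backoff package (processBatches and its parameters are hypothetical names, not dolt code):

package example

import (
	"github.com/cenkalti/backoff"
)

// processBatches runs handle over every pending batch, retrying with
// exponential backoff. Batches that already succeeded are dropped from the
// slice before each retry, so they are never re-submitted.
func processBatches(batches [][]byte, handle func(batch []byte) error) error {
	op := func() error {
		completed := 0
		for _, b := range batches {
			if err := handle(b); err != nil {
				// Keep only the batches that have not completed yet; the next
				// backoff attempt resumes from this point.
				batches = batches[completed:]
				return err
			}
			completed++
		}
		batches = nil
		return nil
	}
	return backoff.Retry(op, backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 5))
}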
@@ -17,6 +17,7 @@ package main
 import (
 	"context"
 	"fmt"
+	"io"
 	"log"
 	"os"
 	"path/filepath"
@@ -24,6 +25,7 @@ import (
 
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
+	"google.golang.org/protobuf/proto"
 
 	remotesapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/remotesapi/v1alpha1"
 	"github.com/dolthub/dolt/go/libraries/doltcore/remotestorage"
@@ -116,6 +118,7 @@ func (rs *RemoteChunkStore) GetDownloadLocations(ctx context.Context, req *remot
 	url, err := rs.getDownloadUrl(logger, org, repoName, loc.String())
 	if err != nil {
 		log.Println("Failed to sign request", err)
+		return nil, err
 	}
 
 	logger("The URL is " + url)
@@ -127,6 +130,64 @@ func (rs *RemoteChunkStore) GetDownloadLocations(ctx context.Context, req *remot
 	return &remotesapi.GetDownloadLocsResponse{Locs: locs}, nil
 }
 
+func (rs *RemoteChunkStore) StreamDownloadLocations(stream remotesapi.ChunkStoreService_StreamDownloadLocationsServer) error {
+	logger := getReqLogger("GRPC", "StreamDownloadLocations")
+	defer func() { logger("finished") }()
+
+	var repoID *remotesapi.RepoId
+	var cs *nbs.NomsBlockStore
+	for {
+		req, err := stream.Recv()
+		if err != nil {
+			if err == io.EOF {
+				return nil
+			}
+			return err
+		}
+
+		if !proto.Equal(req.RepoId, repoID) {
+			repoID = req.RepoId
+			cs = rs.getStore(repoID, "StreamDownloadLoctions")
+			if cs == nil {
+				return status.Error(codes.Internal, "Could not get chunkstore")
+			}
+			logger(fmt.Sprintf("found repo %s/%s", repoID.Org, repoID.RepoName))
+		}
+
+		org := req.RepoId.Org
+		repoName := req.RepoId.RepoName
+		hashes, _ := remotestorage.ParseByteSlices(req.ChunkHashes)
+		locations, err := cs.GetChunkLocations(hashes)
+		if err != nil {
+			return err
+		}
+
+		var locs []*remotesapi.DownloadLoc
+		for loc, hashToRange := range locations {
+			var ranges []*remotesapi.RangeChunk
+			for h, r := range hashToRange {
+				hCpy := h
+				ranges = append(ranges, &remotesapi.RangeChunk{Hash: hCpy[:], Offset: r.Offset, Length: r.Length})
+			}
+
+			url, err := rs.getDownloadUrl(logger, org, repoName, loc.String())
+			if err != nil {
+				log.Println("Failed to sign request", err)
+				return err
+			}
+
+			logger("The URL is " + url)
+
+			getRange := &remotesapi.HttpGetRange{Url: url, Ranges: ranges}
+			locs = append(locs, &remotesapi.DownloadLoc{Location: &remotesapi.DownloadLoc_HttpGetRange{HttpGetRange: getRange}})
+		}
+
+		if err := stream.Send(&remotesapi.GetDownloadLocsResponse{Locs: locs}); err != nil {
+			return err
+		}
+	}
+}
+
 func (rs *RemoteChunkStore) getDownloadUrl(logger func(string), org, repoName, fileId string) (string, error) {
 	return fmt.Sprintf("http://%s/%s/%s/%s", rs.HttpHost, org, repoName, fileId), nil
 }