go/utils/remotesrv: Allow http-port and grpc-port to be the same. Allow GRPC server to echo :authority header if no -http-host is supplied.

This commit is contained in:
Aaron Son
2022-09-12 15:54:27 -07:00
parent af64b04ad2
commit 6099c71966
4 changed files with 104 additions and 37 deletions
+36 -8
View File
@@ -20,9 +20,11 @@ import (
"io"
"log"
"path/filepath"
"strings"
"sync/atomic"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
@@ -130,6 +132,8 @@ func (rs *RemoteChunkStore) GetDownloadLocations(ctx context.Context, req *remot
return nil, err
}
md, _ := metadata.FromIncomingContext(ctx)
var locs []*remotesapi.DownloadLoc
for loc, hashToRange := range locations {
var ranges []*remotesapi.RangeChunk
@@ -138,7 +142,7 @@ func (rs *RemoteChunkStore) GetDownloadLocations(ctx context.Context, req *remot
ranges = append(ranges, &remotesapi.RangeChunk{Hash: hCpy[:], Offset: r.Offset, Length: r.Length})
}
url, err := rs.getDownloadUrl(logger, prefix+"/"+loc)
url, err := rs.getDownloadUrl(logger, md, prefix+"/"+loc)
if err != nil {
log.Println("Failed to sign request", err)
return nil, err
@@ -157,6 +161,8 @@ func (rs *RemoteChunkStore) StreamDownloadLocations(stream remotesapi.ChunkStore
logger := getReqLogger("GRPC", "StreamDownloadLocations")
defer func() { logger("finished") }()
md, _ := metadata.FromIncomingContext(stream.Context())
var repoID *remotesapi.RepoId
var cs RemoteSrvStore
var prefix string
@@ -197,7 +203,7 @@ func (rs *RemoteChunkStore) StreamDownloadLocations(stream remotesapi.ChunkStore
ranges = append(ranges, &remotesapi.RangeChunk{Hash: hCpy[:], Offset: r.Offset, Length: r.Length})
}
url, err := rs.getDownloadUrl(logger, prefix+"/"+loc)
url, err := rs.getDownloadUrl(logger, md, prefix+"/"+loc)
if err != nil {
log.Println("Failed to sign request", err)
return err
@@ -215,8 +221,21 @@ func (rs *RemoteChunkStore) StreamDownloadLocations(stream remotesapi.ChunkStore
}
}
func (rs *RemoteChunkStore) getDownloadUrl(logger func(string), path string) (string, error) {
return fmt.Sprintf("http://%s/%s", rs.HttpHost, path), nil
func (rs *RemoteChunkStore) getDownloadUrl(logger func(string), md metadata.MD, path string) (string, error) {
host := rs.HttpHost
if strings.HasPrefix(rs.HttpHost, ":") && rs.HttpHost != ":80" {
hosts := md.Get(":authority")
if len(hosts) > 0 {
host = strings.Split(hosts[0], ":")[0] + rs.HttpHost
}
} else if rs.HttpHost == "" || rs.HttpHost == ":80" {
hosts := md.Get(":authority")
if len(hosts) > 0 {
host = hosts[0]
}
}
return fmt.Sprintf("http://%s/%s", host, path), nil
}
func parseTableFileDetails(req *remotesapi.GetUploadLocsRequest) []*remotesapi.TableFileDetails {
@@ -404,12 +423,14 @@ func (rs *RemoteChunkStore) ListTableFiles(ctx context.Context, req *remotesapi.
return nil, status.Error(codes.Internal, "failed to get sources")
}
tableFileInfo, err := getTableFileInfo(rs, logger, tables, req, cs)
md, _ := metadata.FromIncomingContext(ctx)
tableFileInfo, err := getTableFileInfo(logger, md, rs, tables, req, cs)
if err != nil {
return nil, err
}
appendixTableFileInfo, err := getTableFileInfo(rs, logger, appendixTables, req, cs)
appendixTableFileInfo, err := getTableFileInfo(logger, md, rs, appendixTables, req, cs)
if err != nil {
return nil, err
}
@@ -423,14 +444,21 @@ func (rs *RemoteChunkStore) ListTableFiles(ctx context.Context, req *remotesapi.
return resp, nil
}
func getTableFileInfo(rs *RemoteChunkStore, logger func(string), tableList []nbs.TableFile, req *remotesapi.ListTableFilesRequest, cs RemoteSrvStore) ([]*remotesapi.TableFileInfo, error) {
func getTableFileInfo(
logger func(string),
md metadata.MD,
rs *RemoteChunkStore,
tableList []nbs.TableFile,
req *remotesapi.ListTableFilesRequest,
cs RemoteSrvStore,
) ([]*remotesapi.TableFileInfo, error) {
prefix, err := rs.getRelativeStorePath(cs)
if err != nil {
return nil, err
}
appendixTableFileInfo := make([]*remotesapi.TableFileInfo, 0)
for _, t := range tableList {
url, err := rs.getDownloadUrl(logger, prefix+"/"+t.FileID())
url, err := rs.getDownloadUrl(logger, md, prefix+"/"+t.FileID())
if err != nil {
return nil, status.Error(codes.Internal, "failed to get download url for "+t.FileID())
}
+47 -22
View File
@@ -20,8 +20,11 @@ import (
"log"
"net"
"net/http"
"strings"
"sync"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"google.golang.org/grpc"
remotesapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/remotesapi/v1alpha1"
@@ -46,49 +49,71 @@ func (s *Server) GracefulStop() {
func NewServer(httpHost string, httpPort, grpcPort int, fs filesys.Filesys, dbCache DBCache, readOnly bool) *Server {
s := new(Server)
s.stopChan = make(chan struct{})
s.wg.Add(4)
expectedFiles := newFileDetails()
s.wg.Add(2)
s.grpcPort = grpcPort
s.grpcSrv = grpc.NewServer(grpc.MaxRecvMsgSize(128 * 1024 * 1024))
var chnkSt remotesapi.ChunkStoreServiceServer = NewHttpFSBackedChunkStore(httpHost, dbCache, expectedFiles, fs)
if readOnly {
chnkSt = ReadOnlyChunkStore{chnkSt}
}
var chnkSt remotesapi.ChunkStoreServiceServer = NewHttpFSBackedChunkStore(httpHost, dbCache, expectedFiles, fs)
if readOnly {
chnkSt = ReadOnlyChunkStore{chnkSt}
}
remotesapi.RegisterChunkStoreServiceServer(s.grpcSrv, chnkSt)
var handler http.Handler = newFileHandler(dbCache, expectedFiles, fs, readOnly)
if httpPort == grpcPort {
handler = grpcMultiplexHandler(s.grpcSrv, handler)
} else {
s.wg.Add(2)
}
s.httpPort = httpPort
s.httpSrv = http.Server{
Addr: fmt.Sprintf(":%d", httpPort),
Handler: newFileHandler(dbCache, expectedFiles, fs, readOnly),
Handler: handler,
}
return s
}
func grpcMultiplexHandler(grpcSrv *grpc.Server, handler http.Handler) http.Handler {
h2s := &http2.Server{}
newHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ProtoMajor == 2 && strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") {
grpcSrv.ServeHTTP(w, r)
} else {
handler.ServeHTTP(w, r)
}
})
return h2c.NewHandler(newHandler, h2s)
}
func (s *Server) Serve() error {
grpcListener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.grpcPort))
if err != nil {
return err
}
httpListener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.httpPort))
if err != nil {
grpcListener.Close()
return err
}
go func() {
defer s.wg.Done()
log.Println("Starting grpc server on port", s.grpcPort)
err = s.grpcSrv.Serve(grpcListener)
log.Println("grpc server exited. error:", err)
}()
go func() {
defer s.wg.Done()
<-s.stopChan
s.grpcSrv.GracefulStop()
}()
if s.grpcPort != s.httpPort {
grpcListener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.grpcPort))
if err != nil {
httpListener.Close()
return err
}
go func() {
defer s.wg.Done()
log.Println("Starting grpc server on port", s.grpcPort)
err = s.grpcSrv.Serve(grpcListener)
log.Println("grpc server exited. error:", err)
}()
go func() {
defer s.wg.Done()
<-s.stopChan
s.grpcSrv.GracefulStop()
}()
}
go func() {
defer s.wg.Done()
+6 -7
View File
@@ -29,14 +29,13 @@ import (
"github.com/dolthub/dolt/go/store/datas"
)
var readOnlyParam *bool = flag.Bool("read-only", false, "run a read-only server which does not allow writes")
func main() {
repoModeParam := flag.Bool("repo-mode", false, "act as a remote for a dolt directory, instead of stand alone")
dirParam := flag.String("dir", "", "root directory that this command will run in.")
grpcPortParam := flag.Int("grpc-port", -1, "root directory that this command will run in.")
httpPortParam := flag.Int("http-port", -1, "root directory that this command will run in.")
httpHostParam := flag.String("http-host", "localhost", "host url that this command will assume.")
readOnlyParam := flag.Bool("read-only", false, "run a read-only server which does not allow writes")
repoModeParam := flag.Bool("repo-mode", false, "act as a remote for an existing dolt directory, instead of stand alone")
dirParam := flag.String("dir", "", "root directory that this command will run in; default cwd")
grpcPortParam := flag.Int("grpc-port", -1, "the port the grpc server will listen on; default 50051")
httpPortParam := flag.Int("http-port", -1, "the port the http server will listen on; default 80; if http-port is equal to grpc-port, both services will serve over the same port")
httpHostParam := flag.String("http-host", "", "hostname to use in the host component of the URLs that the server generates; default ''; if '', server will echo the :authority header")
flag.Parse()
if dirParam != nil && len(*dirParam) > 0 {
+15
View File
@@ -94,3 +94,18 @@ teardown() {
run dolt push origin main:main
[[ "$status" != 0 ]] || false
}
@test "remotesrv: can run grpc and http on same port" {
mkdir remote
cd remote
dolt init
dolt sql -q 'create table vals (i int);'
dolt add vals
dolt commit -m 'create vals table.'
remotesrv --grpc-port 1234 --http-port 1234 --repo-mode &
remotesrv_pid=$!
cd ../
dolt clone http://localhost:1234/test-org/test-repo repo1
}