diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index ea8dbc7e6e..3d550e2058 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -232,7 +232,12 @@ func Serve( HttpPort: *serverConfig.RemotesapiPort(), GrpcPort: *serverConfig.RemotesapiPort(), }) - remoteSrv = remotesrv.NewServer(args) + remoteSrv, err = remotesrv.NewServer(args) + if err != nil { + lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err) + startError = err + return + } listeners, err := remoteSrv.Listeners() if err != nil { lgr.Errorf("error starting remotesapi server listeners on port %d: %v", *serverConfig.RemotesapiPort(), err) @@ -256,7 +261,12 @@ func Serve( args := clusterController.RemoteSrvServerArgs(remoteSrvSqlCtx, remotesrv.ServerArgs{ Logger: logrus.NewEntry(lgr), }) - clusterRemoteSrv = remotesrv.NewServer(args) + clusterRemoteSrv, err = remotesrv.NewServer(args) + if err != nil { + lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err) + startError = err + return + } listeners, err := clusterRemoteSrv.Listeners() if err != nil { lgr.Errorf("error starting remotesapi server listeners for cluster config on port %d: %v", clusterController.RemoteSrvPort(), err) diff --git a/go/libraries/doltcore/remotesrv/grpc.go b/go/libraries/doltcore/remotesrv/grpc.go index 4f8407a6f4..c00ec393a5 100644 --- a/go/libraries/doltcore/remotesrv/grpc.go +++ b/go/libraries/doltcore/remotesrv/grpc.go @@ -47,10 +47,11 @@ type RemoteChunkStore struct { bucket string fs filesys.Filesys lgr *logrus.Entry + sealer Sealer remotesapi.UnimplementedChunkStoreServiceServer } -func NewHttpFSBackedChunkStore(lgr *logrus.Entry, httpHost string, csCache DBCache, fs filesys.Filesys) *RemoteChunkStore { +func NewHttpFSBackedChunkStore(lgr *logrus.Entry, httpHost string, csCache DBCache, fs filesys.Filesys, sealer Sealer) *RemoteChunkStore { return &RemoteChunkStore{ HttpHost: httpHost, csCache: csCache, @@ -59,6 +60,7 @@ func NewHttpFSBackedChunkStore(lgr *logrus.Entry, httpHost string, csCache DBCac lgr: lgr.WithFields(logrus.Fields{ "service": "dolt.services.remotesapi.v1alpha1.ChunkStoreServiceServer", }), + sealer: sealer, } } @@ -177,10 +179,15 @@ func (rs *RemoteChunkStore) GetDownloadLocations(ctx context.Context, req *remot logger.Println("Failed to sign request", err) return nil, err } + preurl := url.String() + url, err = rs.sealer.Seal(url) + if err != nil { + logger.Println("Failed to seal request", err) + return nil, err + } + logger.Println("The URL is", preurl, "the ranges are", ranges, "sealed url", url.String()) - logger.Println("The URL is", url) - - getRange := &remotesapi.HttpGetRange{Url: url, Ranges: ranges} + getRange := &remotesapi.HttpGetRange{Url: url.String(), Ranges: ranges} locs = append(locs, &remotesapi.DownloadLoc{Location: &remotesapi.DownloadLoc_HttpGetRange{HttpGetRange: getRange}}) } @@ -242,10 +249,15 @@ func (rs *RemoteChunkStore) StreamDownloadLocations(stream remotesapi.ChunkStore logger.Println("Failed to sign request", err) return err } + preurl := url.String() + url, err = rs.sealer.Seal(url) + if err != nil { + logger.Println("Failed to seal request", err) + return err + } + logger.Println("The URL is", preurl, "the ranges are", ranges, "sealed url", url.String()) - logger.Println("The URL is", url) - - getRange := &remotesapi.HttpGetRange{Url: url, Ranges: ranges} + getRange := &remotesapi.HttpGetRange{Url: url.String(), Ranges: ranges} locs = append(locs, &remotesapi.DownloadLoc{Location: &remotesapi.DownloadLoc_HttpGetRange{HttpGetRange: getRange}}) } @@ -271,13 +283,13 @@ func (rs *RemoteChunkStore) getHost(md metadata.MD) string { return host } -func (rs *RemoteChunkStore) getDownloadUrl(logger *logrus.Entry, md metadata.MD, path string) (string, error) { +func (rs *RemoteChunkStore) getDownloadUrl(logger *logrus.Entry, md metadata.MD, path string) (*url.URL, error) { host := rs.getHost(md) - return (&url.URL{ + return &url.URL{ Scheme: "http", Host: host, Path: path, - }).String(), nil + }, nil } func parseTableFileDetails(req *remotesapi.GetUploadLocsRequest) []*remotesapi.TableFileDetails { @@ -323,32 +335,35 @@ func (rs *RemoteChunkStore) GetUploadLocations(ctx context.Context, req *remotes for _, tfd := range tfds { h := hash.New(tfd.Id) url, err := rs.getUploadUrl(logger, md, repoPath, tfd) - if err != nil { return nil, status.Error(codes.Internal, "Failed to get upload Url.") } + url, err = rs.sealer.Seal(url) + if err != nil { + return nil, status.Error(codes.Internal, "Failed to seal upload Url.") + } - loc := &remotesapi.UploadLoc_HttpPost{HttpPost: &remotesapi.HttpPostTableFile{Url: url}} + loc := &remotesapi.UploadLoc_HttpPost{HttpPost: &remotesapi.HttpPostTableFile{Url: url.String()}} locs = append(locs, &remotesapi.UploadLoc{TableFileHash: h[:], Location: loc}) - logger.Printf("sending upload location for chunk %s: %s", h.String(), url) + logger.Printf("sending upload location for chunk %s: %s", h.String(), url.String()) } return &remotesapi.GetUploadLocsResponse{Locs: locs}, nil } -func (rs *RemoteChunkStore) getUploadUrl(logger *logrus.Entry, md metadata.MD, repoPath string, tfd *remotesapi.TableFileDetails) (string, error) { +func (rs *RemoteChunkStore) getUploadUrl(logger *logrus.Entry, md metadata.MD, repoPath string, tfd *remotesapi.TableFileDetails) (*url.URL, error) { fileID := hash.New(tfd.Id).String() params := url.Values{} params.Add("num_chunks", strconv.Itoa(int(tfd.NumChunks))) params.Add("content_length", strconv.Itoa(int(tfd.ContentLength))) params.Add("content_hash", base64.RawURLEncoding.EncodeToString(tfd.ContentHash)) - return (&url.URL{ + return &url.URL{ Scheme: "http", Host: rs.getHost(md), Path: fmt.Sprintf("%s/%s", repoPath, fileID), RawQuery: params.Encode(), - }).String(), nil + }, nil } func (rs *RemoteChunkStore) Rebase(ctx context.Context, req *remotesapi.RebaseRequest) (*remotesapi.RebaseResponse, error) { @@ -536,11 +551,15 @@ func getTableFileInfo( if err != nil { return nil, status.Error(codes.Internal, "failed to get download url for "+t.FileID()) } + url, err = rs.sealer.Seal(url) + if err != nil { + return nil, status.Error(codes.Internal, "failed to get seal download url for "+t.FileID()) + } appendixTableFileInfo = append(appendixTableFileInfo, &remotesapi.TableFileInfo{ FileId: t.FileID(), NumChunks: uint32(t.NumChunks()), - Url: url, + Url: url.String(), }) } return appendixTableFileInfo, nil diff --git a/go/libraries/doltcore/remotesrv/http.go b/go/libraries/doltcore/remotesrv/http.go index f94698e977..b399041d12 100644 --- a/go/libraries/doltcore/remotesrv/http.go +++ b/go/libraries/doltcore/remotesrv/http.go @@ -46,9 +46,10 @@ type filehandler struct { fs filesys.Filesys readOnly bool lgr *logrus.Entry + sealer Sealer } -func newFileHandler(lgr *logrus.Entry, dbCache DBCache, fs filesys.Filesys, readOnly bool) filehandler { +func newFileHandler(lgr *logrus.Entry, dbCache DBCache, fs filesys.Filesys, readOnly bool, sealer Sealer) filehandler { return filehandler{ dbCache, fs, @@ -56,6 +57,7 @@ func newFileHandler(lgr *logrus.Entry, dbCache DBCache, fs filesys.Filesys, read lgr.WithFields(logrus.Fields{ "service": "dolt.services.remotesapi.v1alpha1.HttpFileServer", }), + sealer, } } @@ -63,6 +65,15 @@ func (fh filehandler) ServeHTTP(respWr http.ResponseWriter, req *http.Request) { logger := getReqLogger(fh.lgr, req.Method+"_"+req.RequestURI) defer func() { logger.Println("finished") }() + var err error + req.URL, err = fh.sealer.Unseal(req.URL) + if err != nil { + logger.Printf("could not unseal incoming request URL: %s", err.Error()) + respWr.WriteHeader(http.StatusBadRequest) + return + } + logger.Printf("unsealed url %s", req.URL.String()) + path := strings.TrimLeft(req.URL.Path, "/") statusCode := http.StatusMethodNotAllowed @@ -92,7 +103,7 @@ func (fh filehandler) ServeHTTP(respWr http.ResponseWriter, req *http.Request) { respWr.WriteHeader(http.StatusInternalServerError) return } - statusCode = readTableFile(logger, abs, respWr, req) + statusCode = readTableFile(logger, abs, respWr, req.Header.Get("Range")) case http.MethodPost, http.MethodPut: if fh.readOnly { @@ -157,15 +168,13 @@ func (fh filehandler) ServeHTTP(respWr http.ResponseWriter, req *http.Request) { } } -func readTableFile(logger *logrus.Entry, path string, respWr http.ResponseWriter, req *http.Request) int { - rangeStr := req.Header.Get("Range") - +func readTableFile(logger *logrus.Entry, path string, respWr http.ResponseWriter, rangeStr string) int { var r io.ReadCloser var readSize int64 var fileErr error { if rangeStr == "" { - logger.Println("going to read entire file") + logger.Println("going to read entire file", path) r, readSize, fileErr = getFileReader(path) } else { offset, length, err := offsetAndLenFromRange(rangeStr) @@ -173,7 +182,7 @@ func readTableFile(logger *logrus.Entry, path string, respWr http.ResponseWriter logger.Println(err.Error()) return http.StatusBadRequest } - logger.Printf("going to read file at offset %d, length %d", offset, length) + logger.Printf("going to read file %s at offset %d, length %d", path, offset, length) readSize = length r, fileErr = getFileReaderAt(path, offset, length) } diff --git a/go/libraries/doltcore/remotesrv/sealer.go b/go/libraries/doltcore/remotesrv/sealer.go new file mode 100644 index 0000000000..65bd4c1dbd --- /dev/null +++ b/go/libraries/doltcore/remotesrv/sealer.go @@ -0,0 +1,179 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package remotesrv + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "net/url" + "strconv" + "strings" + "time" +) + +// Interface to seal requests to the HTTP server so that they cannot be forged. +// The gRPC server seals URLs and the HTTP server unseals them. +type Sealer interface { + Seal(*url.URL) (*url.URL, error) + Unseal(*url.URL) (*url.URL, error) +} + +var _ Sealer = identitySealer{} + +type identitySealer struct { +} + +func (identitySealer) Seal(u *url.URL) (*url.URL, error) { + return u, nil +} + +func (identitySealer) Unseal(u *url.URL) (*url.URL, error) { + return u, nil +} + +// Seals a URL by encrypting its Path and Query components and passing those in +// a base64 encoded query parameter. Adds a not before timestamp (nbf) and an +// expiration timestamp (exp) as query parameters. Encrypts the URL with +// AES-256 GCM and adds the nbf and exp parameters as authenticated data. +type singleSymmetricKeySealer struct { + privateKeyBytes []byte +} + +func NewSingleSymmetricKeySealer() (Sealer, error) { + var key [32]byte + _, err := rand.Read(key[:]) + if err != nil { + return nil, err + } + return singleSymmetricKeySealer{privateKeyBytes: key[:]}, nil +} + +func (s singleSymmetricKeySealer) Seal(u *url.URL) (*url.URL, error) { + requestURI := (&url.URL{ + Path: u.EscapedPath(), + RawQuery: u.RawQuery, + }).String() + nbf := time.Now().Add(-10 * time.Second) + exp := time.Now().Add(15 * time.Minute) + nbfStr := strconv.FormatInt(nbf.UnixMilli(), 10) + expStr := strconv.FormatInt(exp.UnixMilli(), 10) + var nonceBytes [12]byte + _, err := rand.Read(nonceBytes[:]) + if err != nil { + return nil, err + } + nonceStr := base64.RawURLEncoding.EncodeToString(nonceBytes[:]) + + block, err := aes.NewCipher(s.privateKeyBytes) + if err != nil { + return nil, fmt.Errorf("internal error: error making aes cipher with key: %w", err) + } + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("internal error: error making gcm mode opener with key: %w", err) + } + + reqBytes := aesgcm.Seal(nil, nonceBytes[:], []byte(requestURI), []byte(nbfStr+":"+expStr)) + reqStr := base64.RawURLEncoding.EncodeToString(reqBytes) + + ret := *u + ret.Path = "/single_symmetric_key_sealed_request/" + u.EscapedPath() + ret.RawQuery = url.Values(map[string][]string{ + "req": []string{reqStr}, + "nbf": []string{strconv.FormatInt(nbf.UnixMilli(), 10)}, + "exp": []string{strconv.FormatInt(exp.UnixMilli(), 10)}, + "nonce": []string{nonceStr}, + }).Encode() + return &ret, nil +} + +func (s singleSymmetricKeySealer) Unseal(u *url.URL) (*url.URL, error) { + if !strings.HasPrefix(u.Path, "/single_symmetric_key_sealed_request/") { + return nil, errors.New("bad request: cannot unseal URL whose path does not start with /single_symmetric_key_sealed_request/") + } + q := u.Query() + if !q.Has("nbf") { + return nil, errors.New("bad request: cannot unseal URL which does not include an nbf") + } + if !q.Has("exp") { + return nil, errors.New("bad request: cannot unseal URL which does not include an exp") + } + if !q.Has("nonce") { + return nil, errors.New("bad request: cannot unseal URL which does not include a nonce") + } + if !q.Has("req") { + return nil, errors.New("bad request: cannot unseal URL which does not include a req") + } + nbfStr := q.Get("nbf") + expStr := q.Get("exp") + nonceStr := q.Get("nonce") + + nbf, err := strconv.ParseInt(nbfStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("bad request: error parsing nbf as int64: %w", err) + } + exp, err := strconv.ParseInt(expStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("bad request: error parsing exp as int64: %w", err) + } + nonce, err := base64.RawURLEncoding.DecodeString(nonceStr) + if err != nil { + return nil, fmt.Errorf("bad request: error parsing nonce as base64 URL encoded: %w", err) + } + + if time.Now().Before(time.UnixMilli(nbf)) { + return nil, fmt.Errorf("bad request: nbf is invalid") + } + if time.Now().After(time.UnixMilli(exp)) { + return nil, fmt.Errorf("bad request: exp is invalid") + } + + block, err := aes.NewCipher(s.privateKeyBytes) + if err != nil { + return nil, fmt.Errorf("internal error: error making aes cipher with key: %w", err) + } + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("internal error: error making gcm mode opener with key: %w", err) + } + + reqStr := q.Get("req") + reqBytes, err := base64.RawURLEncoding.DecodeString(reqStr) + if err != nil { + return nil, fmt.Errorf("bad request: error parsing req as base64 URL encoded: %w", err) + } + + requestURI, err := aesgcm.Open(nil, nonce, reqBytes, []byte(nbfStr+":"+expStr)) + if err != nil { + return nil, fmt.Errorf("bad request: error opening sealed url: %w", err) + } + requestURL, err := url.Parse(string(requestURI)) + if err != nil { + return nil, fmt.Errorf("bad request: error parsing unsealed request uri: %w", err) + } + + if strings.TrimPrefix(u.Path, "/single_symmetric_key_sealed_request/") != requestURL.EscapedPath() { + return nil, fmt.Errorf("bad request: unsealed request path did not equal request path in sealed request") + } + + ret := *u + ret.Path = requestURL.Path + ret.RawQuery = requestURL.RawQuery + return &ret, nil +} diff --git a/go/libraries/doltcore/remotesrv/sealer_test.go b/go/libraries/doltcore/remotesrv/sealer_test.go new file mode 100644 index 0000000000..ab1fb88c32 --- /dev/null +++ b/go/libraries/doltcore/remotesrv/sealer_test.go @@ -0,0 +1,88 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package remotesrv + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSingleSymmetricKeySealer(t *testing.T) { + s, err := NewSingleSymmetricKeySealer() + assert.NoError(t, err) + assert.NotNil(t, s) + + u := &url.URL{ + Scheme: "https", + Host: "remotesapi.dolthub.com:443", + Path: "somedatabasename/sometablefilename", + } + sealed, err := s.Seal(u) + assert.NoError(t, err) + unsealed, err := s.Unseal(sealed) + assert.NoError(t, err) + assert.Equal(t, u, unsealed) + + corruptednbf := &(*sealed) + ps := corruptednbf.Query() + ps.Set("nbf", fmt.Sprintf("%v", time.Now())) + corruptednbf.RawQuery = ps.Encode() + unsealed, err = s.Unseal(corruptednbf) + assert.Error(t, err) + + nonbf := &(*sealed) + ps = nonbf.Query() + ps.Del("nbf") + nonbf.RawQuery = ps.Encode() + unsealed, err = s.Unseal(nonbf) + assert.Error(t, err) + + corruptedexp := &(*sealed) + ps = corruptedexp.Query() + ps.Set("exp", fmt.Sprintf("%v", time.Now())) + corruptedexp.RawQuery = ps.Encode() + unsealed, err = s.Unseal(corruptedexp) + assert.Error(t, err) + + noexp := &(*sealed) + ps = noexp.Query() + ps.Del("exp") + noexp.RawQuery = ps.Encode() + unsealed, err = s.Unseal(noexp) + assert.Error(t, err) + + corruptednonce := &(*sealed) + ps = corruptednonce.Query() + var differentnonce [12]byte + _, err = rand.Read(differentnonce[:]) + assert.NoError(t, err) + ps.Set("nonce", base64.RawURLEncoding.EncodeToString(differentnonce[:])) + corruptednonce.RawQuery = ps.Encode() + unsealed, err = s.Unseal(corruptednonce) + assert.Error(t, err) + + nononce := &(*sealed) + ps = nononce.Query() + ps.Del("nonce") + nononce.RawQuery = ps.Encode() + unsealed, err = s.Unseal(nononce) + assert.Error(t, err) +} diff --git a/go/libraries/doltcore/remotesrv/server.go b/go/libraries/doltcore/remotesrv/server.go index 9263623dab..43612ef2e6 100644 --- a/go/libraries/doltcore/remotesrv/server.go +++ b/go/libraries/doltcore/remotesrv/server.go @@ -57,7 +57,7 @@ type ServerArgs struct { Options []grpc.ServerOption } -func NewServer(args ServerArgs) *Server { +func NewServer(args ServerArgs) (*Server, error) { if args.Logger == nil { args.Logger = logrus.NewEntry(logrus.StandardLogger()) } @@ -65,16 +65,21 @@ func NewServer(args ServerArgs) *Server { s := new(Server) s.stopChan = make(chan struct{}) + sealer, err := NewSingleSymmetricKeySealer() + if err != nil { + return nil, err + } + s.wg.Add(2) s.grpcPort = args.GrpcPort s.grpcSrv = grpc.NewServer(append([]grpc.ServerOption{grpc.MaxRecvMsgSize(128 * 1024 * 1024)}, args.Options...)...) - var chnkSt remotesapi.ChunkStoreServiceServer = NewHttpFSBackedChunkStore(args.Logger, args.HttpHost, args.DBCache, args.FS) + var chnkSt remotesapi.ChunkStoreServiceServer = NewHttpFSBackedChunkStore(args.Logger, args.HttpHost, args.DBCache, args.FS, sealer) if args.ReadOnly { chnkSt = ReadOnlyChunkStore{chnkSt} } remotesapi.RegisterChunkStoreServiceServer(s.grpcSrv, chnkSt) - var handler http.Handler = newFileHandler(args.Logger, args.DBCache, args.FS, args.ReadOnly) + var handler http.Handler = newFileHandler(args.Logger, args.DBCache, args.FS, args.ReadOnly, sealer) if args.HttpPort == args.GrpcPort { handler = grpcMultiplexHandler(s.grpcSrv, handler) } else { @@ -87,7 +92,7 @@ func NewServer(args ServerArgs) *Server { Handler: handler, } - return s + return s, nil } func grpcMultiplexHandler(grpcSrv *grpc.Server, handler http.Handler) http.Handler { diff --git a/go/utils/remotesrv/main.go b/go/utils/remotesrv/main.go index 08f4faa32f..8a43f99ae1 100644 --- a/go/utils/remotesrv/main.go +++ b/go/utils/remotesrv/main.go @@ -81,7 +81,7 @@ func main() { dbCache = NewLocalCSCache(fs) } - server := remotesrv.NewServer(remotesrv.ServerArgs{ + server, err := remotesrv.NewServer(remotesrv.ServerArgs{ HttpHost: *httpHostParam, HttpPort: *httpPortParam, GrpcPort: *grpcPortParam, @@ -89,6 +89,9 @@ func main() { DBCache: dbCache, ReadOnly: *readOnlyParam, }) + if err != nil { + log.Fatalf("error creating remotesrv Server: %v\n", err) + } listeners, err := server.Listeners() if err != nil { log.Fatalf("error starting remotesrv Server listeners: %v\n", err) diff --git a/integration-tests/bats/sql-server-remotesrv.bats b/integration-tests/bats/sql-server-remotesrv.bats index fa52cff5dc..635b393d7c 100644 --- a/integration-tests/bats/sql-server-remotesrv.bats +++ b/integration-tests/bats/sql-server-remotesrv.bats @@ -49,6 +49,7 @@ call dolt_commit('-am', 'add some vals'); SQL dolt pull + run dolt sql -q 'select count(*) from vals;' [[ "$output" =~ "10" ]] || false }