Merge branch 'main' into james/import

This commit is contained in:
James Cor
2022-10-26 11:22:13 -07:00
8 changed files with 345 additions and 31 deletions

View File

@@ -232,7 +232,12 @@ func Serve(
HttpPort: *serverConfig.RemotesapiPort(),
GrpcPort: *serverConfig.RemotesapiPort(),
})
remoteSrv = remotesrv.NewServer(args)
remoteSrv, err = remotesrv.NewServer(args)
if err != nil {
lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err)
startError = err
return
}
listeners, err := remoteSrv.Listeners()
if err != nil {
lgr.Errorf("error starting remotesapi server listeners on port %d: %v", *serverConfig.RemotesapiPort(), err)
@@ -256,7 +261,12 @@ func Serve(
args := clusterController.RemoteSrvServerArgs(remoteSrvSqlCtx, remotesrv.ServerArgs{
Logger: logrus.NewEntry(lgr),
})
clusterRemoteSrv = remotesrv.NewServer(args)
clusterRemoteSrv, err = remotesrv.NewServer(args)
if err != nil {
lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err)
startError = err
return
}
listeners, err := clusterRemoteSrv.Listeners()
if err != nil {
lgr.Errorf("error starting remotesapi server listeners for cluster config on port %d: %v", clusterController.RemoteSrvPort(), err)

View File

@@ -47,10 +47,11 @@ type RemoteChunkStore struct {
bucket string
fs filesys.Filesys
lgr *logrus.Entry
sealer Sealer
remotesapi.UnimplementedChunkStoreServiceServer
}
func NewHttpFSBackedChunkStore(lgr *logrus.Entry, httpHost string, csCache DBCache, fs filesys.Filesys) *RemoteChunkStore {
func NewHttpFSBackedChunkStore(lgr *logrus.Entry, httpHost string, csCache DBCache, fs filesys.Filesys, sealer Sealer) *RemoteChunkStore {
return &RemoteChunkStore{
HttpHost: httpHost,
csCache: csCache,
@@ -59,6 +60,7 @@ func NewHttpFSBackedChunkStore(lgr *logrus.Entry, httpHost string, csCache DBCac
lgr: lgr.WithFields(logrus.Fields{
"service": "dolt.services.remotesapi.v1alpha1.ChunkStoreServiceServer",
}),
sealer: sealer,
}
}
@@ -177,10 +179,15 @@ func (rs *RemoteChunkStore) GetDownloadLocations(ctx context.Context, req *remot
logger.Println("Failed to sign request", err)
return nil, err
}
preurl := url.String()
url, err = rs.sealer.Seal(url)
if err != nil {
logger.Println("Failed to seal request", err)
return nil, err
}
logger.Println("The URL is", preurl, "the ranges are", ranges, "sealed url", url.String())
logger.Println("The URL is", url)
getRange := &remotesapi.HttpGetRange{Url: url, Ranges: ranges}
getRange := &remotesapi.HttpGetRange{Url: url.String(), Ranges: ranges}
locs = append(locs, &remotesapi.DownloadLoc{Location: &remotesapi.DownloadLoc_HttpGetRange{HttpGetRange: getRange}})
}
@@ -242,10 +249,15 @@ func (rs *RemoteChunkStore) StreamDownloadLocations(stream remotesapi.ChunkStore
logger.Println("Failed to sign request", err)
return err
}
preurl := url.String()
url, err = rs.sealer.Seal(url)
if err != nil {
logger.Println("Failed to seal request", err)
return err
}
logger.Println("The URL is", preurl, "the ranges are", ranges, "sealed url", url.String())
logger.Println("The URL is", url)
getRange := &remotesapi.HttpGetRange{Url: url, Ranges: ranges}
getRange := &remotesapi.HttpGetRange{Url: url.String(), Ranges: ranges}
locs = append(locs, &remotesapi.DownloadLoc{Location: &remotesapi.DownloadLoc_HttpGetRange{HttpGetRange: getRange}})
}
@@ -271,13 +283,13 @@ func (rs *RemoteChunkStore) getHost(md metadata.MD) string {
return host
}
func (rs *RemoteChunkStore) getDownloadUrl(logger *logrus.Entry, md metadata.MD, path string) (string, error) {
func (rs *RemoteChunkStore) getDownloadUrl(logger *logrus.Entry, md metadata.MD, path string) (*url.URL, error) {
host := rs.getHost(md)
return (&url.URL{
return &url.URL{
Scheme: "http",
Host: host,
Path: path,
}).String(), nil
}, nil
}
func parseTableFileDetails(req *remotesapi.GetUploadLocsRequest) []*remotesapi.TableFileDetails {
@@ -323,32 +335,35 @@ func (rs *RemoteChunkStore) GetUploadLocations(ctx context.Context, req *remotes
for _, tfd := range tfds {
h := hash.New(tfd.Id)
url, err := rs.getUploadUrl(logger, md, repoPath, tfd)
if err != nil {
return nil, status.Error(codes.Internal, "Failed to get upload Url.")
}
url, err = rs.sealer.Seal(url)
if err != nil {
return nil, status.Error(codes.Internal, "Failed to seal upload Url.")
}
loc := &remotesapi.UploadLoc_HttpPost{HttpPost: &remotesapi.HttpPostTableFile{Url: url}}
loc := &remotesapi.UploadLoc_HttpPost{HttpPost: &remotesapi.HttpPostTableFile{Url: url.String()}}
locs = append(locs, &remotesapi.UploadLoc{TableFileHash: h[:], Location: loc})
logger.Printf("sending upload location for chunk %s: %s", h.String(), url)
logger.Printf("sending upload location for chunk %s: %s", h.String(), url.String())
}
return &remotesapi.GetUploadLocsResponse{Locs: locs}, nil
}
func (rs *RemoteChunkStore) getUploadUrl(logger *logrus.Entry, md metadata.MD, repoPath string, tfd *remotesapi.TableFileDetails) (string, error) {
func (rs *RemoteChunkStore) getUploadUrl(logger *logrus.Entry, md metadata.MD, repoPath string, tfd *remotesapi.TableFileDetails) (*url.URL, error) {
fileID := hash.New(tfd.Id).String()
params := url.Values{}
params.Add("num_chunks", strconv.Itoa(int(tfd.NumChunks)))
params.Add("content_length", strconv.Itoa(int(tfd.ContentLength)))
params.Add("content_hash", base64.RawURLEncoding.EncodeToString(tfd.ContentHash))
return (&url.URL{
return &url.URL{
Scheme: "http",
Host: rs.getHost(md),
Path: fmt.Sprintf("%s/%s", repoPath, fileID),
RawQuery: params.Encode(),
}).String(), nil
}, nil
}
func (rs *RemoteChunkStore) Rebase(ctx context.Context, req *remotesapi.RebaseRequest) (*remotesapi.RebaseResponse, error) {
@@ -536,11 +551,15 @@ func getTableFileInfo(
if err != nil {
return nil, status.Error(codes.Internal, "failed to get download url for "+t.FileID())
}
url, err = rs.sealer.Seal(url)
if err != nil {
return nil, status.Error(codes.Internal, "failed to get seal download url for "+t.FileID())
}
appendixTableFileInfo = append(appendixTableFileInfo, &remotesapi.TableFileInfo{
FileId: t.FileID(),
NumChunks: uint32(t.NumChunks()),
Url: url,
Url: url.String(),
})
}
return appendixTableFileInfo, nil

View File

@@ -46,9 +46,10 @@ type filehandler struct {
fs filesys.Filesys
readOnly bool
lgr *logrus.Entry
sealer Sealer
}
func newFileHandler(lgr *logrus.Entry, dbCache DBCache, fs filesys.Filesys, readOnly bool) filehandler {
func newFileHandler(lgr *logrus.Entry, dbCache DBCache, fs filesys.Filesys, readOnly bool, sealer Sealer) filehandler {
return filehandler{
dbCache,
fs,
@@ -56,6 +57,7 @@ func newFileHandler(lgr *logrus.Entry, dbCache DBCache, fs filesys.Filesys, read
lgr.WithFields(logrus.Fields{
"service": "dolt.services.remotesapi.v1alpha1.HttpFileServer",
}),
sealer,
}
}
@@ -63,6 +65,15 @@ func (fh filehandler) ServeHTTP(respWr http.ResponseWriter, req *http.Request) {
logger := getReqLogger(fh.lgr, req.Method+"_"+req.RequestURI)
defer func() { logger.Println("finished") }()
var err error
req.URL, err = fh.sealer.Unseal(req.URL)
if err != nil {
logger.Printf("could not unseal incoming request URL: %s", err.Error())
respWr.WriteHeader(http.StatusBadRequest)
return
}
logger.Printf("unsealed url %s", req.URL.String())
path := strings.TrimLeft(req.URL.Path, "/")
statusCode := http.StatusMethodNotAllowed
@@ -92,7 +103,7 @@ func (fh filehandler) ServeHTTP(respWr http.ResponseWriter, req *http.Request) {
respWr.WriteHeader(http.StatusInternalServerError)
return
}
statusCode = readTableFile(logger, abs, respWr, req)
statusCode = readTableFile(logger, abs, respWr, req.Header.Get("Range"))
case http.MethodPost, http.MethodPut:
if fh.readOnly {
@@ -157,15 +168,13 @@ func (fh filehandler) ServeHTTP(respWr http.ResponseWriter, req *http.Request) {
}
}
func readTableFile(logger *logrus.Entry, path string, respWr http.ResponseWriter, req *http.Request) int {
rangeStr := req.Header.Get("Range")
func readTableFile(logger *logrus.Entry, path string, respWr http.ResponseWriter, rangeStr string) int {
var r io.ReadCloser
var readSize int64
var fileErr error
{
if rangeStr == "" {
logger.Println("going to read entire file")
logger.Println("going to read entire file", path)
r, readSize, fileErr = getFileReader(path)
} else {
offset, length, err := offsetAndLenFromRange(rangeStr)
@@ -173,7 +182,7 @@ func readTableFile(logger *logrus.Entry, path string, respWr http.ResponseWriter
logger.Println(err.Error())
return http.StatusBadRequest
}
logger.Printf("going to read file at offset %d, length %d", offset, length)
logger.Printf("going to read file %s at offset %d, length %d", path, offset, length)
readSize = length
r, fileErr = getFileReaderAt(path, offset, length)
}

View File

@@ -0,0 +1,179 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package remotesrv
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net/url"
"strconv"
"strings"
"time"
)
// Interface to seal requests to the HTTP server so that they cannot be forged.
// The gRPC server seals URLs and the HTTP server unseals them.
type Sealer interface {
Seal(*url.URL) (*url.URL, error)
Unseal(*url.URL) (*url.URL, error)
}
var _ Sealer = identitySealer{}
type identitySealer struct {
}
func (identitySealer) Seal(u *url.URL) (*url.URL, error) {
return u, nil
}
func (identitySealer) Unseal(u *url.URL) (*url.URL, error) {
return u, nil
}
// Seals a URL by encrypting its Path and Query components and passing those in
// a base64 encoded query parameter. Adds a not before timestamp (nbf) and an
// expiration timestamp (exp) as query parameters. Encrypts the URL with
// AES-256 GCM and adds the nbf and exp parameters as authenticated data.
type singleSymmetricKeySealer struct {
privateKeyBytes []byte
}
func NewSingleSymmetricKeySealer() (Sealer, error) {
var key [32]byte
_, err := rand.Read(key[:])
if err != nil {
return nil, err
}
return singleSymmetricKeySealer{privateKeyBytes: key[:]}, nil
}
func (s singleSymmetricKeySealer) Seal(u *url.URL) (*url.URL, error) {
requestURI := (&url.URL{
Path: u.EscapedPath(),
RawQuery: u.RawQuery,
}).String()
nbf := time.Now().Add(-10 * time.Second)
exp := time.Now().Add(15 * time.Minute)
nbfStr := strconv.FormatInt(nbf.UnixMilli(), 10)
expStr := strconv.FormatInt(exp.UnixMilli(), 10)
var nonceBytes [12]byte
_, err := rand.Read(nonceBytes[:])
if err != nil {
return nil, err
}
nonceStr := base64.RawURLEncoding.EncodeToString(nonceBytes[:])
block, err := aes.NewCipher(s.privateKeyBytes)
if err != nil {
return nil, fmt.Errorf("internal error: error making aes cipher with key: %w", err)
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("internal error: error making gcm mode opener with key: %w", err)
}
reqBytes := aesgcm.Seal(nil, nonceBytes[:], []byte(requestURI), []byte(nbfStr+":"+expStr))
reqStr := base64.RawURLEncoding.EncodeToString(reqBytes)
ret := *u
ret.Path = "/single_symmetric_key_sealed_request/" + u.EscapedPath()
ret.RawQuery = url.Values(map[string][]string{
"req": []string{reqStr},
"nbf": []string{strconv.FormatInt(nbf.UnixMilli(), 10)},
"exp": []string{strconv.FormatInt(exp.UnixMilli(), 10)},
"nonce": []string{nonceStr},
}).Encode()
return &ret, nil
}
func (s singleSymmetricKeySealer) Unseal(u *url.URL) (*url.URL, error) {
if !strings.HasPrefix(u.Path, "/single_symmetric_key_sealed_request/") {
return nil, errors.New("bad request: cannot unseal URL whose path does not start with /single_symmetric_key_sealed_request/")
}
q := u.Query()
if !q.Has("nbf") {
return nil, errors.New("bad request: cannot unseal URL which does not include an nbf")
}
if !q.Has("exp") {
return nil, errors.New("bad request: cannot unseal URL which does not include an exp")
}
if !q.Has("nonce") {
return nil, errors.New("bad request: cannot unseal URL which does not include a nonce")
}
if !q.Has("req") {
return nil, errors.New("bad request: cannot unseal URL which does not include a req")
}
nbfStr := q.Get("nbf")
expStr := q.Get("exp")
nonceStr := q.Get("nonce")
nbf, err := strconv.ParseInt(nbfStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("bad request: error parsing nbf as int64: %w", err)
}
exp, err := strconv.ParseInt(expStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("bad request: error parsing exp as int64: %w", err)
}
nonce, err := base64.RawURLEncoding.DecodeString(nonceStr)
if err != nil {
return nil, fmt.Errorf("bad request: error parsing nonce as base64 URL encoded: %w", err)
}
if time.Now().Before(time.UnixMilli(nbf)) {
return nil, fmt.Errorf("bad request: nbf is invalid")
}
if time.Now().After(time.UnixMilli(exp)) {
return nil, fmt.Errorf("bad request: exp is invalid")
}
block, err := aes.NewCipher(s.privateKeyBytes)
if err != nil {
return nil, fmt.Errorf("internal error: error making aes cipher with key: %w", err)
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("internal error: error making gcm mode opener with key: %w", err)
}
reqStr := q.Get("req")
reqBytes, err := base64.RawURLEncoding.DecodeString(reqStr)
if err != nil {
return nil, fmt.Errorf("bad request: error parsing req as base64 URL encoded: %w", err)
}
requestURI, err := aesgcm.Open(nil, nonce, reqBytes, []byte(nbfStr+":"+expStr))
if err != nil {
return nil, fmt.Errorf("bad request: error opening sealed url: %w", err)
}
requestURL, err := url.Parse(string(requestURI))
if err != nil {
return nil, fmt.Errorf("bad request: error parsing unsealed request uri: %w", err)
}
if strings.TrimPrefix(u.Path, "/single_symmetric_key_sealed_request/") != requestURL.EscapedPath() {
return nil, fmt.Errorf("bad request: unsealed request path did not equal request path in sealed request")
}
ret := *u
ret.Path = requestURL.Path
ret.RawQuery = requestURL.RawQuery
return &ret, nil
}

View File

@@ -0,0 +1,88 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package remotesrv
import (
"crypto/rand"
"encoding/base64"
"fmt"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestSingleSymmetricKeySealer(t *testing.T) {
s, err := NewSingleSymmetricKeySealer()
assert.NoError(t, err)
assert.NotNil(t, s)
u := &url.URL{
Scheme: "https",
Host: "remotesapi.dolthub.com:443",
Path: "somedatabasename/sometablefilename",
}
sealed, err := s.Seal(u)
assert.NoError(t, err)
unsealed, err := s.Unseal(sealed)
assert.NoError(t, err)
assert.Equal(t, u, unsealed)
corruptednbf := &(*sealed)
ps := corruptednbf.Query()
ps.Set("nbf", fmt.Sprintf("%v", time.Now()))
corruptednbf.RawQuery = ps.Encode()
unsealed, err = s.Unseal(corruptednbf)
assert.Error(t, err)
nonbf := &(*sealed)
ps = nonbf.Query()
ps.Del("nbf")
nonbf.RawQuery = ps.Encode()
unsealed, err = s.Unseal(nonbf)
assert.Error(t, err)
corruptedexp := &(*sealed)
ps = corruptedexp.Query()
ps.Set("exp", fmt.Sprintf("%v", time.Now()))
corruptedexp.RawQuery = ps.Encode()
unsealed, err = s.Unseal(corruptedexp)
assert.Error(t, err)
noexp := &(*sealed)
ps = noexp.Query()
ps.Del("exp")
noexp.RawQuery = ps.Encode()
unsealed, err = s.Unseal(noexp)
assert.Error(t, err)
corruptednonce := &(*sealed)
ps = corruptednonce.Query()
var differentnonce [12]byte
_, err = rand.Read(differentnonce[:])
assert.NoError(t, err)
ps.Set("nonce", base64.RawURLEncoding.EncodeToString(differentnonce[:]))
corruptednonce.RawQuery = ps.Encode()
unsealed, err = s.Unseal(corruptednonce)
assert.Error(t, err)
nononce := &(*sealed)
ps = nononce.Query()
ps.Del("nonce")
nononce.RawQuery = ps.Encode()
unsealed, err = s.Unseal(nononce)
assert.Error(t, err)
}

View File

@@ -57,7 +57,7 @@ type ServerArgs struct {
Options []grpc.ServerOption
}
func NewServer(args ServerArgs) *Server {
func NewServer(args ServerArgs) (*Server, error) {
if args.Logger == nil {
args.Logger = logrus.NewEntry(logrus.StandardLogger())
}
@@ -65,16 +65,21 @@ func NewServer(args ServerArgs) *Server {
s := new(Server)
s.stopChan = make(chan struct{})
sealer, err := NewSingleSymmetricKeySealer()
if err != nil {
return nil, err
}
s.wg.Add(2)
s.grpcPort = args.GrpcPort
s.grpcSrv = grpc.NewServer(append([]grpc.ServerOption{grpc.MaxRecvMsgSize(128 * 1024 * 1024)}, args.Options...)...)
var chnkSt remotesapi.ChunkStoreServiceServer = NewHttpFSBackedChunkStore(args.Logger, args.HttpHost, args.DBCache, args.FS)
var chnkSt remotesapi.ChunkStoreServiceServer = NewHttpFSBackedChunkStore(args.Logger, args.HttpHost, args.DBCache, args.FS, sealer)
if args.ReadOnly {
chnkSt = ReadOnlyChunkStore{chnkSt}
}
remotesapi.RegisterChunkStoreServiceServer(s.grpcSrv, chnkSt)
var handler http.Handler = newFileHandler(args.Logger, args.DBCache, args.FS, args.ReadOnly)
var handler http.Handler = newFileHandler(args.Logger, args.DBCache, args.FS, args.ReadOnly, sealer)
if args.HttpPort == args.GrpcPort {
handler = grpcMultiplexHandler(s.grpcSrv, handler)
} else {
@@ -87,7 +92,7 @@ func NewServer(args ServerArgs) *Server {
Handler: handler,
}
return s
return s, nil
}
func grpcMultiplexHandler(grpcSrv *grpc.Server, handler http.Handler) http.Handler {

View File

@@ -81,7 +81,7 @@ func main() {
dbCache = NewLocalCSCache(fs)
}
server := remotesrv.NewServer(remotesrv.ServerArgs{
server, err := remotesrv.NewServer(remotesrv.ServerArgs{
HttpHost: *httpHostParam,
HttpPort: *httpPortParam,
GrpcPort: *grpcPortParam,
@@ -89,6 +89,9 @@ func main() {
DBCache: dbCache,
ReadOnly: *readOnlyParam,
})
if err != nil {
log.Fatalf("error creating remotesrv Server: %v\n", err)
}
listeners, err := server.Listeners()
if err != nil {
log.Fatalf("error starting remotesrv Server listeners: %v\n", err)

View File

@@ -49,6 +49,7 @@ call dolt_commit('-am', 'add some vals');
SQL
dolt pull
run dolt sql -q 'select count(*) from vals;'
[[ "$output" =~ "10" ]] || false
}