diff --git a/go/utils/remotesrv/http.go b/go/utils/remotesrv/http.go index 6c9bdee9b3..b06db3f709 100644 --- a/go/utils/remotesrv/http.go +++ b/go/utils/remotesrv/http.go @@ -16,11 +16,13 @@ package main import ( "bytes" + "context" "crypto/md5" "errors" "fmt" "io" "net/http" + gohash "hash" "os" "path/filepath" "strconv" @@ -30,6 +32,7 @@ import ( remotesapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/remotesapi/v1alpha1" "github.com/dolthub/dolt/go/store/hash" + "github.com/dolthub/dolt/go/store/types" ) var ( @@ -84,7 +87,7 @@ func (fh filehandler) ServeHTTP(respWr http.ResponseWriter, req *http.Request) { statusCode = readTableFile(logger, org, repo, hashStr, respWr, req) case http.MethodPost, http.MethodPut: - statusCode = writeTableFile(logger, fh.expectedFiles, org, repo, hashStr, req) + statusCode = writeTableFile(req.Context(), logger, fh.dbCache, fh.expectedFiles, org, repo, hashStr, req) } if statusCode != -1 { @@ -149,7 +152,42 @@ func readTableFile(logger func(string), org, repo, fileId string, respWr http.Re return http.StatusOK } -func writeTableFile(logger func(string), expectedFiles fileDetails, org, repo, fileId string, request *http.Request) int { +type uploadreader struct { + r io.ReadCloser + totalread int + expectedread uint64 + expectedsum []byte + checksum gohash.Hash +} + +func (u *uploadreader) Read(p []byte) (n int, err error) { + n, err = u.r.Read(p) + if err == nil || err == io.EOF { + u.totalread += n + u.checksum.Write(p[:n]) + } + return n, err +} + +var errBodyLengthTFDMismatch = errors.New("body upload length did not match table file details") +var errBodyHashTFDMismatch = errors.New("body upload hash did not match table file details") + +func (u *uploadreader) Close() error { + cerr := u.r.Close() + if cerr != nil { + return cerr + } + if u.expectedread != 0 && u.expectedread != uint64(u.totalread) { + return errBodyLengthTFDMismatch + } + sum := u.checksum.Sum(nil) + if !bytes.Equal(u.expectedsum, sum[:]) { + return errBodyHashTFDMismatch + } + return nil +} + +func writeTableFile(ctx context.Context, logger func(string), dbCache *DBCache, expectedFiles fileDetails, org, repo, fileId string, request *http.Request) int { _, ok := hash.MaybeParse(fileId) if !ok { @@ -159,30 +197,41 @@ func writeTableFile(logger func(string), expectedFiles fileDetails, org, repo, f tfd, ok := expectedFiles.Get(fileId) if !ok { + logger("bad request for " + fileId + ": tfd not found") return http.StatusBadRequest } logger(fileId + " is valid") - data, err := io.ReadAll(request.Body) - - if tfd.ContentLength != 0 && tfd.ContentLength != uint64(len(data)) { - return http.StatusBadRequest - } - - if len(tfd.ContentHash) > 0 { - actualMD5Bytes := md5.Sum(data) - if !bytes.Equal(tfd.ContentHash, actualMD5Bytes[:]) { - return http.StatusBadRequest - } - } + cs, err := dbCache.Get(org, repo, types.Format_Default.VersionString()) if err != nil { - logger("failed to read body " + err.Error()) + logger("failed to get " + org + "/" + repo + " repository: " + err.Error()) return http.StatusInternalServerError } - err = writeLocal(logger, org, repo, fileId, data) + + err = cs.WriteTableFile(ctx, fileId, int(tfd.NumChunks), tfd.ContentHash, func() (io.ReadCloser, uint64, error) { + reader := request.Body + size := tfd.ContentLength + return &uploadreader{ + reader, + 0, + tfd.ContentLength, + tfd.ContentHash, + md5.New(), + }, size, nil + }) + if err != nil { + if errors.Is(err, errBodyLengthTFDMismatch) { + logger("bad write file request for " + fileId + ": body length mismatch") + return http.StatusBadRequest + } + if errors.Is(err, errBodyHashTFDMismatch) { + logger("bad write file request for " + fileId + ": body hash mismatch") + return http.StatusBadRequest + } + logger("failed to read body " + err.Error()) return http.StatusInternalServerError }