go/utils/remotesrv: Make http server use TableFileStore.WriteTableFile() instead of writing directly against storage..

This commit is contained in:
Aaron Son
2022-08-31 16:15:41 -07:00
parent faea605f44
commit 6487f317bd

View File

@@ -16,11 +16,13 @@ package main
import (
"bytes"
"context"
"crypto/md5"
"errors"
"fmt"
"io"
"net/http"
gohash "hash"
"os"
"path/filepath"
"strconv"
@@ -30,6 +32,7 @@ import (
remotesapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/remotesapi/v1alpha1"
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/types"
)
var (
@@ -84,7 +87,7 @@ func (fh filehandler) ServeHTTP(respWr http.ResponseWriter, req *http.Request) {
statusCode = readTableFile(logger, org, repo, hashStr, respWr, req)
case http.MethodPost, http.MethodPut:
statusCode = writeTableFile(logger, fh.expectedFiles, org, repo, hashStr, req)
statusCode = writeTableFile(req.Context(), logger, fh.dbCache, fh.expectedFiles, org, repo, hashStr, req)
}
if statusCode != -1 {
@@ -149,7 +152,42 @@ func readTableFile(logger func(string), org, repo, fileId string, respWr http.Re
return http.StatusOK
}
func writeTableFile(logger func(string), expectedFiles fileDetails, org, repo, fileId string, request *http.Request) int {
type uploadreader struct {
r io.ReadCloser
totalread int
expectedread uint64
expectedsum []byte
checksum gohash.Hash
}
func (u *uploadreader) Read(p []byte) (n int, err error) {
n, err = u.r.Read(p)
if err == nil || err == io.EOF {
u.totalread += n
u.checksum.Write(p[:n])
}
return n, err
}
var errBodyLengthTFDMismatch = errors.New("body upload length did not match table file details")
var errBodyHashTFDMismatch = errors.New("body upload hash did not match table file details")
func (u *uploadreader) Close() error {
cerr := u.r.Close()
if cerr != nil {
return cerr
}
if u.expectedread != 0 && u.expectedread != uint64(u.totalread) {
return errBodyLengthTFDMismatch
}
sum := u.checksum.Sum(nil)
if !bytes.Equal(u.expectedsum, sum[:]) {
return errBodyHashTFDMismatch
}
return nil
}
func writeTableFile(ctx context.Context, logger func(string), dbCache *DBCache, expectedFiles fileDetails, org, repo, fileId string, request *http.Request) int {
_, ok := hash.MaybeParse(fileId)
if !ok {
@@ -159,30 +197,41 @@ func writeTableFile(logger func(string), expectedFiles fileDetails, org, repo, f
tfd, ok := expectedFiles.Get(fileId)
if !ok {
logger("bad request for " + fileId + ": tfd not found")
return http.StatusBadRequest
}
logger(fileId + " is valid")
data, err := io.ReadAll(request.Body)
if tfd.ContentLength != 0 && tfd.ContentLength != uint64(len(data)) {
return http.StatusBadRequest
}
if len(tfd.ContentHash) > 0 {
actualMD5Bytes := md5.Sum(data)
if !bytes.Equal(tfd.ContentHash, actualMD5Bytes[:]) {
return http.StatusBadRequest
}
}
cs, err := dbCache.Get(org, repo, types.Format_Default.VersionString())
if err != nil {
logger("failed to read body " + err.Error())
logger("failed to get " + org + "/" + repo + " repository: " + err.Error())
return http.StatusInternalServerError
}
err = writeLocal(logger, org, repo, fileId, data)
err = cs.WriteTableFile(ctx, fileId, int(tfd.NumChunks), tfd.ContentHash, func() (io.ReadCloser, uint64, error) {
reader := request.Body
size := tfd.ContentLength
return &uploadreader{
reader,
0,
tfd.ContentLength,
tfd.ContentHash,
md5.New(),
}, size, nil
})
if err != nil {
if errors.Is(err, errBodyLengthTFDMismatch) {
logger("bad write file request for " + fileId + ": body length mismatch")
return http.StatusBadRequest
}
if errors.Is(err, errBodyHashTFDMismatch) {
logger("bad write file request for " + fileId + ": body hash mismatch")
return http.StatusBadRequest
}
logger("failed to read body " + err.Error())
return http.StatusInternalServerError
}