Progress on dolt clone auth

This commit is contained in:
Taylor Bantle
2023-03-20 14:42:32 -07:00
parent 9a7ad41ebe
commit d7b06fd1ea
10 changed files with 165 additions and 11 deletions

View File

@@ -115,6 +115,7 @@ const (
ShallowFlag = "shallow"
CachedFlag = "cached"
ListFlag = "list"
UserParam = "user"
)
const (
@@ -192,6 +193,7 @@ func CreateCloneArgParser() *argparser.ArgParser {
ap.SupportsString(dbfactory.AWSCredsProfile, "", "profile", "AWS profile to use.")
ap.SupportsString(dbfactory.OSSCredsFileParam, "", "file", "OSS credentials file.")
ap.SupportsString(dbfactory.OSSCredsProfile, "", "profile", "OSS profile to use.")
ap.SupportsString(UserParam, "u", "user", "User name to use when authenticating with the remote. Gets password from the environment variable {{.EmphasisLeft}}DOLT_REMOTE_PASSWORD{{.EmphasisRight}}.")
return ap
}

View File

@@ -16,6 +16,7 @@ package commands
import (
"context"
"os"
"path"
"strings"
@@ -24,6 +25,7 @@ import (
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/cmd/dolt/errhand"
eventsapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi/v1alpha1"
"github.com/dolthub/dolt/go/libraries/doltcore/creds"
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
@@ -100,6 +102,8 @@ func clone(ctx context.Context, apr *argparser.ArgParseResults, dEnv *env.DoltEn
return verr
}
dEnv.UserPassConfig = getUserAndPassConfig(apr)
userDirExists, _ := dEnv.FS.Exists(dir)
// Check for a valid dolthub url and replace the urlStr with the parsed repoName.
@@ -186,7 +190,7 @@ func parseArgs(apr *argparser.ArgParseResults) (string, string, errhand.VerboseE
if dir == "." {
dir = path.Dir(urlStr)
} else if dir == "/" {
return "", "", errhand.BuildDError("Could not infer repo name. Please explicitily define a directory for this url").Build()
return "", "", errhand.BuildDError("Could not infer repo name. Please explicitly define a directory for this url").Build()
}
}
@@ -228,3 +232,14 @@ func validateAndParseDolthubUrl(urlStr string) (string, bool) {
return "", false
}
func getUserAndPassConfig(apr *argparser.ArgParseResults) *creds.DoltCredsForPass {
if !apr.Contains(cli.UserParam) {
return nil
}
pass := os.Getenv("DOLT_REMOTE_PASSWORD")
return &creds.DoltCredsForPass{
Username: apr.GetValueOrDefault(cli.UserParam, ""),
Password: pass,
}
}

View File

@@ -238,7 +238,7 @@ func Serve(
ReadOnly: true,
HttpListenAddr: listenaddr,
GrpcListenAddr: listenaddr,
})
}, &remotesrv.UserAuth{User: serverConfig.User(), Password: serverConfig.Password()})
args.TLSConfig = serverConf.TLSConfig
remoteSrv, err = remotesrv.NewServer(args)
if err != nil {

View File

@@ -18,7 +18,9 @@ import (
"context"
"crypto/sha512"
"encoding/base32"
"encoding/base64"
"errors"
"fmt"
"time"
"golang.org/x/crypto/ed25519"
@@ -58,6 +60,11 @@ type DoltCreds struct {
KeyID []byte
}
type DoltCredsForPass struct {
Username string
Password string
}
func PubKeyStrToKIDStr(pub string) (string, error) {
data, err := B32CredsEncoding.DecodeString(pub)
@@ -119,15 +126,20 @@ func (dc DoltCreds) Sign(data []byte) []byte {
}
type RPCCreds struct {
PrivKey ed25519.PrivateKey
KeyID string
Audience string
Issuer string
Subject string
RequireTLS bool
PrivKey ed25519.PrivateKey
KeyID string
Audience string
Issuer string
Subject string
RequireTLS bool
UserPassContents string
}
func (c *RPCCreds) toBearerToken() (string, error) {
if len(c.UserPassContents) > 0 {
return "", fmt.Errorf("cannot create bearer token with user/pass credentials")
}
key := jose.SigningKey{Algorithm: jose.EdDSA, Key: c.PrivKey}
opts := &jose.SignerOptions{ExtraHeaders: map[jose.HeaderKey]interface{}{
JWTKIDHeader: c.KeyID,
@@ -151,6 +163,11 @@ func (c *RPCCreds) toBearerToken() (string, error) {
}
func (c *RPCCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if len(c.UserPassContents) > 0 {
return map[string]string{
"authorization": "Basic " + c.UserPassContents,
}, nil
}
t, err := c.toBearerToken()
if err != nil {
return nil, err
@@ -177,3 +194,14 @@ func (dc DoltCreds) RPCCreds(audience string) *RPCCreds {
RequireTLS: false,
}
}
func (dcp DoltCredsForPass) ToBase64Str() string {
return base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", dcp.Username, dcp.Password)))
}
func (dc DoltCredsForPass) RPCCreds() *RPCCreds {
return &RPCCreds{
RequireTLS: false,
UserPassContents: dc.ToBase64Str(),
}
}

View File

@@ -96,6 +96,7 @@ type DoltEnv struct {
hdp HomeDirProvider
IgnoreLockFile bool
UserPassConfig *creds.DoltCredsForPass
}
func (dEnv *DoltEnv) GetRemoteDB(ctx context.Context, format *types.NomsBinFormat, r Remote, withCaching bool) (*doltdb.DoltDB, error) {

View File

@@ -110,6 +110,10 @@ func (p GRPCDialProvider) getRPCCreds(endpoint string) (credentials.PerRPCCreden
return nil, nil
}
if p.dEnv.UserPassConfig != nil {
return p.dEnv.UserPassConfig.RPCCreds(), nil
}
dCreds, valid, err := p.dEnv.UserDoltCreds()
if err != nil {
return nil, ErrInvalidCredsFile

View File

@@ -0,0 +1,98 @@
// Copyright 2023 DoltHub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package remotesrv
import (
"context"
"crypto/subtle"
"encoding/base64"
"fmt"
"strings"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
type UserAuth struct {
User string
Password string
}
type serverinterceptor struct {
Lgr *logrus.Entry
ExpectedUserAuth UserAuth
}
func (si *serverinterceptor) Stream() grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if err := si.authenticate(ss.Context()); err != nil {
return err
}
return handler(srv, ss)
}
}
func (si *serverinterceptor) Unary() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if err := si.authenticate(ctx); err != nil {
return nil, err
}
return handler(ctx, req)
}
}
func (si *serverinterceptor) Options() []grpc.ServerOption {
return []grpc.ServerOption{
grpc.ChainUnaryInterceptor(si.Unary()),
grpc.ChainStreamInterceptor(si.Stream()),
}
}
func (si *serverinterceptor) authenticate(ctx context.Context) error {
if len(si.ExpectedUserAuth.User) == 0 && len(si.ExpectedUserAuth.Password) == 0 {
return nil
}
if md, ok := metadata.FromIncomingContext(ctx); ok {
auths := md.Get("authorization")
if len(auths) != 1 {
si.Lgr.Info("incoming request had no authorization")
return status.Error(codes.Unauthenticated, "unauthenticated")
}
auth := auths[0]
if !strings.HasPrefix(auth, "Basic ") {
si.Lgr.Info("incoming request had malformed authentication header")
return status.Error(codes.Unauthenticated, "unauthenticated")
}
authTrim := strings.TrimPrefix(auth, "Basic ")
uDec, err := base64.URLEncoding.DecodeString(authTrim)
if err != nil {
si.Lgr.Infof("incoming request authorization header failed to decode: %v", err)
return status.Error(codes.Unauthenticated, "unauthenticated")
}
compare := subtle.ConstantTimeCompare(uDec, []byte(fmt.Sprintf("%s:%s", si.ExpectedUserAuth.User, si.ExpectedUserAuth.Password)))
if compare == 0 {
si.Lgr.Infof("incoming request authorization header failed to match")
return status.Error(codes.Unauthenticated, "unauthenticated")
}
return nil
}
return status.Error(codes.Unauthenticated, "unauthenticated")
}

View File

@@ -61,7 +61,8 @@ type ServerArgs struct {
ReadOnly bool
Options []grpc.ServerOption
HttpInterceptor func(http.Handler) http.Handler
HttpInterceptor func(http.Handler) http.Handler
ServerInterceptor serverinterceptor
// If supplied, the listener(s) returned from Listeners() will be TLS
// listeners. The scheme used in the URLs returned from the gRPC server

View File

@@ -465,7 +465,7 @@ func (c *Controller) RemoteSrvServerArgs(ctx *sql.Context, args remotesrv.Server
args.HttpListenAddr = listenaddr
args.GrpcListenAddr = listenaddr
args.Options = c.ServerOptions()
args = sqle.RemoteSrvServerArgs(ctx, args)
args = sqle.RemoteSrvServerArgs(ctx, args, nil)
args.DBCache = remotesrvStoreCache{args.DBCache, c}
keyID := creds.PubKeyToKID(c.pub)

View File

@@ -60,9 +60,14 @@ func (s remotesrvStore) Get(path, nbfVerStr string) (remotesrv.RemoteSrvStore, e
return rss, nil
}
func RemoteSrvServerArgs(ctx *sql.Context, args remotesrv.ServerArgs) remotesrv.ServerArgs {
func RemoteSrvServerArgs(ctx *sql.Context, args remotesrv.ServerArgs, userAuth *remotesrv.UserAuth) remotesrv.ServerArgs {
sess := dsess.DSessFromSess(ctx.Session)
args.FS = sess.Provider().FileSystem()
args.DBCache = remotesrvStore{ctx, args.ReadOnly}
if userAuth != nil {
args.ServerInterceptor.Lgr = args.Logger
args.ServerInterceptor.ExpectedUserAuth = *userAuth
args.Options = append(args.Options, args.ServerInterceptor.Options()...)
}
return args
}