mirror of
https://github.com/dolthub/dolt.git
synced 2026-02-21 00:54:26 -06:00
Progress on dolt clone auth
This commit is contained in:
@@ -115,6 +115,7 @@ const (
|
||||
ShallowFlag = "shallow"
|
||||
CachedFlag = "cached"
|
||||
ListFlag = "list"
|
||||
UserParam = "user"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -192,6 +193,7 @@ func CreateCloneArgParser() *argparser.ArgParser {
|
||||
ap.SupportsString(dbfactory.AWSCredsProfile, "", "profile", "AWS profile to use.")
|
||||
ap.SupportsString(dbfactory.OSSCredsFileParam, "", "file", "OSS credentials file.")
|
||||
ap.SupportsString(dbfactory.OSSCredsProfile, "", "profile", "OSS profile to use.")
|
||||
ap.SupportsString(UserParam, "u", "user", "User name to use when authenticating with the remote. Gets password from the environment variable {{.EmphasisLeft}}DOLT_REMOTE_PASSWORD{{.EmphasisRight}}.")
|
||||
return ap
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/cmd/dolt/cli"
|
||||
"github.com/dolthub/dolt/go/cmd/dolt/errhand"
|
||||
eventsapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi/v1alpha1"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/creds"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
@@ -100,6 +102,8 @@ func clone(ctx context.Context, apr *argparser.ArgParseResults, dEnv *env.DoltEn
|
||||
return verr
|
||||
}
|
||||
|
||||
dEnv.UserPassConfig = getUserAndPassConfig(apr)
|
||||
|
||||
userDirExists, _ := dEnv.FS.Exists(dir)
|
||||
|
||||
// Check for a valid dolthub url and replace the urlStr with the parsed repoName.
|
||||
@@ -186,7 +190,7 @@ func parseArgs(apr *argparser.ArgParseResults) (string, string, errhand.VerboseE
|
||||
if dir == "." {
|
||||
dir = path.Dir(urlStr)
|
||||
} else if dir == "/" {
|
||||
return "", "", errhand.BuildDError("Could not infer repo name. Please explicitily define a directory for this url").Build()
|
||||
return "", "", errhand.BuildDError("Could not infer repo name. Please explicitly define a directory for this url").Build()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,3 +232,14 @@ func validateAndParseDolthubUrl(urlStr string) (string, bool) {
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func getUserAndPassConfig(apr *argparser.ArgParseResults) *creds.DoltCredsForPass {
|
||||
if !apr.Contains(cli.UserParam) {
|
||||
return nil
|
||||
}
|
||||
pass := os.Getenv("DOLT_REMOTE_PASSWORD")
|
||||
return &creds.DoltCredsForPass{
|
||||
Username: apr.GetValueOrDefault(cli.UserParam, ""),
|
||||
Password: pass,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,7 +238,7 @@ func Serve(
|
||||
ReadOnly: true,
|
||||
HttpListenAddr: listenaddr,
|
||||
GrpcListenAddr: listenaddr,
|
||||
})
|
||||
}, &remotesrv.UserAuth{User: serverConfig.User(), Password: serverConfig.Password()})
|
||||
args.TLSConfig = serverConf.TLSConfig
|
||||
remoteSrv, err = remotesrv.NewServer(args)
|
||||
if err != nil {
|
||||
|
||||
@@ -18,7 +18,9 @@ import (
|
||||
"context"
|
||||
"crypto/sha512"
|
||||
"encoding/base32"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ed25519"
|
||||
@@ -58,6 +60,11 @@ type DoltCreds struct {
|
||||
KeyID []byte
|
||||
}
|
||||
|
||||
type DoltCredsForPass struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
func PubKeyStrToKIDStr(pub string) (string, error) {
|
||||
data, err := B32CredsEncoding.DecodeString(pub)
|
||||
|
||||
@@ -119,15 +126,20 @@ func (dc DoltCreds) Sign(data []byte) []byte {
|
||||
}
|
||||
|
||||
type RPCCreds struct {
|
||||
PrivKey ed25519.PrivateKey
|
||||
KeyID string
|
||||
Audience string
|
||||
Issuer string
|
||||
Subject string
|
||||
RequireTLS bool
|
||||
PrivKey ed25519.PrivateKey
|
||||
KeyID string
|
||||
Audience string
|
||||
Issuer string
|
||||
Subject string
|
||||
RequireTLS bool
|
||||
UserPassContents string
|
||||
}
|
||||
|
||||
func (c *RPCCreds) toBearerToken() (string, error) {
|
||||
if len(c.UserPassContents) > 0 {
|
||||
return "", fmt.Errorf("cannot create bearer token with user/pass credentials")
|
||||
}
|
||||
|
||||
key := jose.SigningKey{Algorithm: jose.EdDSA, Key: c.PrivKey}
|
||||
opts := &jose.SignerOptions{ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
JWTKIDHeader: c.KeyID,
|
||||
@@ -151,6 +163,11 @@ func (c *RPCCreds) toBearerToken() (string, error) {
|
||||
}
|
||||
|
||||
func (c *RPCCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
|
||||
if len(c.UserPassContents) > 0 {
|
||||
return map[string]string{
|
||||
"authorization": "Basic " + c.UserPassContents,
|
||||
}, nil
|
||||
}
|
||||
t, err := c.toBearerToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -177,3 +194,14 @@ func (dc DoltCreds) RPCCreds(audience string) *RPCCreds {
|
||||
RequireTLS: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (dcp DoltCredsForPass) ToBase64Str() string {
|
||||
return base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", dcp.Username, dcp.Password)))
|
||||
}
|
||||
|
||||
func (dc DoltCredsForPass) RPCCreds() *RPCCreds {
|
||||
return &RPCCreds{
|
||||
RequireTLS: false,
|
||||
UserPassContents: dc.ToBase64Str(),
|
||||
}
|
||||
}
|
||||
|
||||
1
go/libraries/doltcore/env/environment.go
vendored
1
go/libraries/doltcore/env/environment.go
vendored
@@ -96,6 +96,7 @@ type DoltEnv struct {
|
||||
hdp HomeDirProvider
|
||||
|
||||
IgnoreLockFile bool
|
||||
UserPassConfig *creds.DoltCredsForPass
|
||||
}
|
||||
|
||||
func (dEnv *DoltEnv) GetRemoteDB(ctx context.Context, format *types.NomsBinFormat, r Remote, withCaching bool) (*doltdb.DoltDB, error) {
|
||||
|
||||
@@ -110,6 +110,10 @@ func (p GRPCDialProvider) getRPCCreds(endpoint string) (credentials.PerRPCCreden
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if p.dEnv.UserPassConfig != nil {
|
||||
return p.dEnv.UserPassConfig.RPCCreds(), nil
|
||||
}
|
||||
|
||||
dCreds, valid, err := p.dEnv.UserDoltCreds()
|
||||
if err != nil {
|
||||
return nil, ErrInvalidCredsFile
|
||||
|
||||
98
go/libraries/doltcore/remotesrv/interceptors.go
Normal file
98
go/libraries/doltcore/remotesrv/interceptors.go
Normal file
@@ -0,0 +1,98 @@
|
||||
// Copyright 2023 DoltHub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package remotesrv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type UserAuth struct {
|
||||
User string
|
||||
Password string
|
||||
}
|
||||
|
||||
type serverinterceptor struct {
|
||||
Lgr *logrus.Entry
|
||||
ExpectedUserAuth UserAuth
|
||||
}
|
||||
|
||||
func (si *serverinterceptor) Stream() grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
if err := si.authenticate(ss.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return handler(srv, ss)
|
||||
}
|
||||
}
|
||||
|
||||
func (si *serverinterceptor) Unary() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
if err := si.authenticate(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (si *serverinterceptor) Options() []grpc.ServerOption {
|
||||
return []grpc.ServerOption{
|
||||
grpc.ChainUnaryInterceptor(si.Unary()),
|
||||
grpc.ChainStreamInterceptor(si.Stream()),
|
||||
}
|
||||
}
|
||||
|
||||
func (si *serverinterceptor) authenticate(ctx context.Context) error {
|
||||
if len(si.ExpectedUserAuth.User) == 0 && len(si.ExpectedUserAuth.Password) == 0 {
|
||||
return nil
|
||||
}
|
||||
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||||
auths := md.Get("authorization")
|
||||
if len(auths) != 1 {
|
||||
si.Lgr.Info("incoming request had no authorization")
|
||||
return status.Error(codes.Unauthenticated, "unauthenticated")
|
||||
}
|
||||
auth := auths[0]
|
||||
if !strings.HasPrefix(auth, "Basic ") {
|
||||
si.Lgr.Info("incoming request had malformed authentication header")
|
||||
return status.Error(codes.Unauthenticated, "unauthenticated")
|
||||
}
|
||||
authTrim := strings.TrimPrefix(auth, "Basic ")
|
||||
uDec, err := base64.URLEncoding.DecodeString(authTrim)
|
||||
if err != nil {
|
||||
si.Lgr.Infof("incoming request authorization header failed to decode: %v", err)
|
||||
return status.Error(codes.Unauthenticated, "unauthenticated")
|
||||
}
|
||||
compare := subtle.ConstantTimeCompare(uDec, []byte(fmt.Sprintf("%s:%s", si.ExpectedUserAuth.User, si.ExpectedUserAuth.Password)))
|
||||
if compare == 0 {
|
||||
|
||||
si.Lgr.Infof("incoming request authorization header failed to match")
|
||||
return status.Error(codes.Unauthenticated, "unauthenticated")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return status.Error(codes.Unauthenticated, "unauthenticated")
|
||||
}
|
||||
@@ -61,7 +61,8 @@ type ServerArgs struct {
|
||||
ReadOnly bool
|
||||
Options []grpc.ServerOption
|
||||
|
||||
HttpInterceptor func(http.Handler) http.Handler
|
||||
HttpInterceptor func(http.Handler) http.Handler
|
||||
ServerInterceptor serverinterceptor
|
||||
|
||||
// If supplied, the listener(s) returned from Listeners() will be TLS
|
||||
// listeners. The scheme used in the URLs returned from the gRPC server
|
||||
|
||||
@@ -465,7 +465,7 @@ func (c *Controller) RemoteSrvServerArgs(ctx *sql.Context, args remotesrv.Server
|
||||
args.HttpListenAddr = listenaddr
|
||||
args.GrpcListenAddr = listenaddr
|
||||
args.Options = c.ServerOptions()
|
||||
args = sqle.RemoteSrvServerArgs(ctx, args)
|
||||
args = sqle.RemoteSrvServerArgs(ctx, args, nil)
|
||||
args.DBCache = remotesrvStoreCache{args.DBCache, c}
|
||||
|
||||
keyID := creds.PubKeyToKID(c.pub)
|
||||
|
||||
@@ -60,9 +60,14 @@ func (s remotesrvStore) Get(path, nbfVerStr string) (remotesrv.RemoteSrvStore, e
|
||||
return rss, nil
|
||||
}
|
||||
|
||||
func RemoteSrvServerArgs(ctx *sql.Context, args remotesrv.ServerArgs) remotesrv.ServerArgs {
|
||||
func RemoteSrvServerArgs(ctx *sql.Context, args remotesrv.ServerArgs, userAuth *remotesrv.UserAuth) remotesrv.ServerArgs {
|
||||
sess := dsess.DSessFromSess(ctx.Session)
|
||||
args.FS = sess.Provider().FileSystem()
|
||||
args.DBCache = remotesrvStore{ctx, args.ReadOnly}
|
||||
if userAuth != nil {
|
||||
args.ServerInterceptor.Lgr = args.Logger
|
||||
args.ServerInterceptor.ExpectedUserAuth = *userAuth
|
||||
args.Options = append(args.Options, args.ServerInterceptor.Options()...)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user