Merge pull request #7023 from dolthub/macneale4/fix_dolt_fetch_pull

fix dolt fetch and dolt pull commands to properly authenticate
This commit is contained in:
Neil Macneale IV
2023-11-20 20:32:04 -08:00
committed by GitHub
11 changed files with 208 additions and 84 deletions

View File

@@ -53,6 +53,7 @@ import (
const (
LocalConnectionUser = "__dolt_local_user__"
ApiSqleContextKey = "__sqle_context__"
)
// ExternalDisableUsers is called by implementing applications to disable users. This is not used by Dolt itself,
@@ -384,7 +385,7 @@ func Serve(
}
ctxFactory := func() (*sql.Context, error) { return sqlEngine.NewDefaultContext(ctx) }
authenticator := newAuthenticator(ctxFactory, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb)
authenticator := newAccessController(ctxFactory, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb)
args = sqle.WithUserPasswordAuth(args, authenticator)
args.TLSConfig = serverConf.TLSConfig
@@ -587,29 +588,66 @@ func acquireGlobalSqlServerLock(port int, dEnv *env.DoltEnv) (*env.DBLock, error
return &lck, nil
}
// remotesapiAuth facilitates the implementation remotesrv.AccessControl for the remotesapi server.
type remotesapiAuth struct {
// ctxFactory is a function that returns a new sql.Context. This will create a new conext every time it is called,
// so it should be called once per API request.
ctxFactory func() (*sql.Context, error)
rawDb *mysql_db.MySQLDb
}
func newAuthenticator(ctxFactory func() (*sql.Context, error), rawDb *mysql_db.MySQLDb) remotesrv.Authenticator {
func newAccessController(ctxFactory func() (*sql.Context, error), rawDb *mysql_db.MySQLDb) remotesrv.AccessControl {
return &remotesapiAuth{ctxFactory, rawDb}
}
func (r *remotesapiAuth) Authenticate(creds *remotesrv.RequestCredentials) bool {
err := commands.ValidatePasswordWithAuthResponse(r.rawDb, creds.Username, creds.Password)
// ApiAuthenticate checks the provided credentials against the database and return a SQL context if the credentials are
// valid. If the credentials are invalid, then a nil context is returned. Failures to authenticate are logged.
func (r *remotesapiAuth) ApiAuthenticate(ctx context.Context) (context.Context, error) {
creds, err := remotesrv.ExtractBasicAuthCreds(ctx)
if err != nil {
return false
return nil, err
}
ctx, err := r.ctxFactory()
err = commands.ValidatePasswordWithAuthResponse(r.rawDb, creds.Username, creds.Password)
if err != nil {
return false
return nil, fmt.Errorf("API Authentication Failure: %v", err)
}
address := creds.Address
if strings.Index(address, ":") > 0 {
address, _, err = net.SplitHostPort(creds.Address)
if err != nil {
return nil, fmt.Errorf("Invlaid Host string for authentication: %s", creds.Address)
}
}
sqlCtx, err := r.ctxFactory()
if err != nil {
return nil, fmt.Errorf("API Runtime error: %v", err)
}
sqlCtx.Session.SetClient(sql.Client{User: creds.Username, Address: address, Capabilities: 0})
updatedCtx := context.WithValue(ctx, ApiSqleContextKey, sqlCtx)
return updatedCtx, nil
}
func (r *remotesapiAuth) ApiAuthorize(ctx context.Context) (bool, error) {
sqlCtx, ok := ctx.Value(ApiSqleContextKey).(*sql.Context)
if !ok {
return false, fmt.Errorf("Runtime error: could not get SQL context from context")
}
ctx.Session.SetClient(sql.Client{User: creds.Username, Address: creds.Address, Capabilities: 0})
privOp := sql.NewDynamicPrivilegedOperation(plan.DynamicPrivilege_CloneAdmin)
return r.rawDb.UserHasPrivileges(ctx, privOp)
authorized := r.rawDb.UserHasPrivileges(sqlCtx, privOp)
if !authorized {
return false, fmt.Errorf("API Authorization Failure: %s has not been granted CLONE_ADMIN access", sqlCtx.Session.Client().User)
}
return true, nil
}
func LoadClusterTLSConfig(cfg cluster.Config) (*tls.Config, error) {

View File

@@ -33,6 +33,7 @@ import (
)
var GRPCDialProviderParam = "__DOLT__grpc_dial_provider"
var GRPCUsernameAuthParam = "__DOLT__grpc_username"
type GRPCRemoteConfig struct {
Endpoint string
@@ -100,10 +101,15 @@ func (fact DoltRemoteFactory) CreateDB(ctx context.Context, nbf *types.NomsBinFo
var NoCachingParameter = "__dolt__NO_CACHING"
func (fact DoltRemoteFactory) newChunkStore(ctx context.Context, nbf *types.NomsBinFormat, urlObj *url.URL, params map[string]interface{}, dp GRPCDialProvider) (chunks.ChunkStore, error) {
var user string
if userParam := params[GRPCUsernameAuthParam]; userParam != nil {
user = userParam.(string)
}
cfg, err := dp.GetGRPCDialParams(grpcendpoint.Config{
Endpoint: urlObj.Host,
Insecure: fact.insecure,
WithEnvCreds: true,
Endpoint: urlObj.Host,
Insecure: fact.insecure,
UserIdForOsEnvAuth: user,
WithEnvCreds: true,
})
if err != nil {
return nil, err

View File

@@ -16,8 +16,10 @@ package env
import (
"crypto/tls"
"errors"
"net"
"net/http"
"os"
"runtime"
"strings"
"unicode"
@@ -25,7 +27,9 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/dolthub/dolt/go/libraries/doltcore/creds"
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/doltcore/dconfig"
"github.com/dolthub/dolt/go/libraries/doltcore/grpcendpoint"
)
@@ -88,9 +92,18 @@ func (p GRPCDialProvider) GetGRPCDialParams(config grpcendpoint.Config) (dbfacto
if config.Creds != nil {
opts = append(opts, grpc.WithPerRPCCredentials(config.Creds))
} else if config.WithEnvCreds {
rpcCreds, err := p.getRPCCreds(endpoint)
if err != nil {
return dbfactory.GRPCRemoteConfig{}, err
var rpcCreds credentials.PerRPCCredentials
var err error
if config.UserIdForOsEnvAuth != "" {
rpcCreds, err = p.getRPCCredsFromOSEnv(config.UserIdForOsEnvAuth)
if err != nil {
return dbfactory.GRPCRemoteConfig{}, err
}
} else {
rpcCreds, err = p.getRPCCreds(endpoint)
if err != nil {
return dbfactory.GRPCRemoteConfig{}, err
}
}
if rpcCreds != nil {
opts = append(opts, grpc.WithPerRPCCredentials(rpcCreds))
@@ -103,6 +116,24 @@ func (p GRPCDialProvider) GetGRPCDialParams(config grpcendpoint.Config) (dbfacto
}, nil
}
// getRPCCredsFromOSEnv returns RPC Credentials for the specified username, using the DOLT_REMOTE_PASSWORD
func (p GRPCDialProvider) getRPCCredsFromOSEnv(username string) (credentials.PerRPCCredentials, error) {
if username == "" {
return nil, errors.New("Runtime error: username must be provided to getRPCCredsFromOSEnv")
}
pass, found := os.LookupEnv(dconfig.EnvDoltRemotePassword)
if !found {
return nil, errors.New("error: must set DOLT_REMOTE_PASSWORD environment variable to use --user param")
}
c := creds.DoltCredsForPass{
Username: username,
Password: pass,
}
return c.RPCCreds(), nil
}
// getRPCCreds returns any RPC credentials available to this dial provider. If a DoltEnv has been configured
// in this dial provider, it will be used to load custom user credentials, otherwise nil will be returned.
func (p GRPCDialProvider) getRPCCreds(endpoint string) (credentials.PerRPCCredentials, error) {

View File

@@ -131,6 +131,16 @@ func (r *Remote) GetRemoteDBWithoutCaching(ctx context.Context, nbf *types.NomsB
return doltdb.LoadDoltDBWithParams(ctx, nbf, r.Url, filesys2.LocalFS, params)
}
func (r Remote) WithParams(params map[string]string) Remote {
fetchSpecs := make([]string, len(r.FetchSpecs))
copy(fetchSpecs, r.FetchSpecs)
for k, v := range r.Params {
params[k] = v
}
r.Params = params
return r
}
// PushOptions contains information needed for push for
// one or more branches or a tag for a specific remote database.
type PushOptions struct {

View File

@@ -27,6 +27,12 @@ type Config struct {
Creds credentials.PerRPCCredentials
WithEnvCreds bool
// If this is non-empty, and WithEnvCreds is true, then the caller is
// requesting to use username/password authentication instead of JWT
// authentication against the gRPC endpoint. Currently, the password
// comes from the OS environment variable DOLT_REMOTE_PASSWORD.
UserIdForOsEnvAuth string
// If non-nil, this is used for transport level security in the dial
// options, instead of a default option based on `Insecure`.
TLSConfig *tls.Config

View File

@@ -30,6 +30,8 @@ import (
"strings"
"github.com/sirupsen/logrus"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/store/hash"
@@ -397,3 +399,40 @@ func getFileReaderAt(path string, offset int64, length int64) (io.ReadCloser, in
r := closerReaderWrapper{io.LimitReader(f, length), f}
return r, fSize, nil
}
// ExtractBasicAuthCreds extracts the username and password from the incoming request. It returns RequestCredentials
// populated with necessary information to authenticate the request. nil and an error will be returned if any error
// occurs.
func ExtractBasicAuthCreds(ctx context.Context) (*RequestCredentials, error) {
if md, ok := metadata.FromIncomingContext(ctx); !ok {
return nil, errors.New("no metadata in context")
} else {
var username string
var password string
auths := md.Get("authorization")
if len(auths) != 1 {
username = "root"
password = ""
} else {
auth := auths[0]
if !strings.HasPrefix(auth, "Basic ") {
return nil, fmt.Errorf("bad request: authorization header did not start with 'Basic '")
}
authTrim := strings.TrimPrefix(auth, "Basic ")
uDec, err := base64.URLEncoding.DecodeString(authTrim)
if err != nil {
return nil, fmt.Errorf("incoming request authorization header failed to decode: %v", err)
}
userPass := strings.Split(string(uDec), ":")
username = userPass[0]
password = userPass[1]
}
addr, ok := peer.FromContext(ctx)
if !ok {
return nil, errors.New("incoming request had no peer")
}
return &RequestCredentials{Username: username, Password: password, Address: addr.Addr.String()}, nil
}
}

View File

@@ -16,14 +16,10 @@ package remotesrv
import (
"context"
"encoding/base64"
"strings"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
@@ -34,12 +30,20 @@ type RequestCredentials struct {
}
type ServerInterceptor struct {
Lgr *logrus.Entry
Authenticator Authenticator
Lgr *logrus.Entry
AccessController AccessControl
}
type Authenticator interface {
Authenticate(creds *RequestCredentials) bool
// AccessControl is an interface that provides authentication and authorization for the gRPC server.
type AccessControl interface {
// ApiAuthenticate checks the incoming request for authentication credentials and validates them. If the user's
// identity checks out, the returned context will have the sqlContext within it, which contains the user's ID.
// If the user is not legitimate, an error is returned.
ApiAuthenticate(ctx context.Context) (context.Context, error)
// ApiAuthorize checks that the authenticated user has sufficient privileges to perform the requested action.
// Currently, CLONE_ADMIN is required. True and a nil error returned if the user is authorized, otherwise false
// with an error.
ApiAuthorize(ctx context.Context) (bool, error)
}
func (si *ServerInterceptor) Stream() grpc.StreamServerInterceptor {
@@ -69,40 +73,23 @@ func (si *ServerInterceptor) Options() []grpc.ServerOption {
}
}
// authenticate checks the incoming request for authentication credentials and validates them. If the user is
// legitimate, an authorization check is performed. If no error is returned, the user should be allowed to proceed.
func (si *ServerInterceptor) authenticate(ctx context.Context) error {
if md, ok := metadata.FromIncomingContext(ctx); ok {
var username string
var password string
auths := md.Get("authorization")
if len(auths) != 1 {
username = "root"
} else {
auth := auths[0]
if !strings.HasPrefix(auth, "Basic ") {
si.Lgr.Info("incoming request had malformed authentication header")
return status.Error(codes.Unauthenticated, "unauthenticated")
}
authTrim := strings.TrimPrefix(auth, "Basic ")
uDec, err := base64.URLEncoding.DecodeString(authTrim)
if err != nil {
si.Lgr.Infof("incoming request authorization header failed to decode: %v", err)
return status.Error(codes.Unauthenticated, "unauthenticated")
}
userPass := strings.Split(string(uDec), ":")
username = userPass[0]
password = userPass[1]
}
addr, ok := peer.FromContext(ctx)
if !ok {
si.Lgr.Info("incoming request had no peer")
return status.Error(codes.Unauthenticated, "unauthenticated")
}
if authed := si.Authenticator.Authenticate(&RequestCredentials{Username: username, Password: password, Address: addr.Addr.String()}); !authed {
return status.Error(codes.Unauthenticated, "unauthenticated")
}
return nil
ctx, err := si.AccessController.ApiAuthenticate(ctx)
if err != nil {
si.Lgr.Warnf("authentication failed: %s", err.Error())
status.Error(codes.Unauthenticated, "unauthenticated")
return err
}
return status.Error(codes.Unauthenticated, "unauthenticated 1")
// Have a valid user in the context. Check authorization.
if authorized, err := si.AccessController.ApiAuthorize(ctx); !authorized {
si.Lgr.Warnf("authorization failed: %s", err.Error())
status.Error(codes.PermissionDenied, "unauthorized")
return err
}
// Access Granted.
return nil
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
@@ -73,6 +74,12 @@ func doDoltFetch(ctx *sql.Context, args []string) (int, error) {
return cmdFailure, err
}
if user, hasUser := apr.GetValue(cli.UserFlag); hasUser {
remote = remote.WithParams(map[string]string{
dbfactory.GRPCUsernameAuthParam: user,
})
}
srcDB, err := sess.Provider().GetRemoteDB(ctx, dbData.Ddb.ValueReadWriter().Format(), remote, false)
if err != nil {
return 1, err

View File

@@ -24,6 +24,7 @@ import (
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
@@ -93,6 +94,12 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, error) {
return noConflictsOrViolations, threeWayMerge, err
}
if user, hasUser := apr.GetValue(cli.UserFlag); hasUser {
pullSpec.Remote = pullSpec.Remote.WithParams(map[string]string{
dbfactory.GRPCUsernameAuthParam: user,
})
}
srcDB, err := sess.Provider().GetRemoteDB(ctx, dbData.Ddb.ValueReadWriter().Format(), pullSpec.Remote, false)
if err != nil {
return noConflictsOrViolations, threeWayMerge, fmt.Errorf("failed to get remote db; %w", err)

View File

@@ -78,10 +78,10 @@ func RemoteSrvServerArgs(ctxFactory func(context.Context) (*sql.Context, error),
return args, nil
}
func WithUserPasswordAuth(args remotesrv.ServerArgs, auth remotesrv.Authenticator) remotesrv.ServerArgs {
func WithUserPasswordAuth(args remotesrv.ServerArgs, authnz remotesrv.AccessControl) remotesrv.ServerArgs {
si := remotesrv.ServerInterceptor{
Lgr: args.Logger,
Authenticator: auth,
Lgr: args.Logger,
AccessController: authnz,
}
args.Options = append(args.Options, si.Options()...)
return args

View File

@@ -162,12 +162,10 @@ select count(*) from vals;
}
@test "sql-server-remotesrv: clone/fetch/pull from remotesapi port with authentication" {
skip "only support authenticating fetch with dolthub for now."
mkdir remote
cd remote
dolt init
dolt --privilege-file=privs.json sql -q "CREATE USER user IDENTIFIED BY 'pass0'"
dolt --privilege-file=privs.json sql -q "CREATE USER user0 IDENTIFIED BY 'pass0'"
dolt sql -q 'create table vals (i int);'
dolt sql -q 'insert into vals (i) values (1), (2), (3), (4), (5);'
dolt add vals
@@ -187,7 +185,7 @@ select count(*) from vals;
run dolt sql -q 'select count(*) from vals'
[[ "$output" =~ "5" ]] || false
dolt --port 3307 --host localhost -u $DOLT_REMOTE_USER -p $DOLT_REMOTE_PASSWORD sql -q "
dolt --port 3307 --host localhost --no-tls -u $DOLT_REMOTE_USER -p $DOLT_REMOTE_PASSWORD sql -q "
use remote;
call dolt_checkout('-b', 'new_branch');
insert into vals (i) values (6), (7), (8), (9), (10);
@@ -202,7 +200,7 @@ call dolt_commit('-am', 'add some vals');
# No auth fetch
run dolt fetch
[[ "$status" != 0 ]] || false
[[ "$output" =~ "Unauthenticated" ]] || false
[[ "$output" =~ "Access denied for user 'root'" ]] || false
# # With auth fetch
run dolt fetch -u $DOLT_REMOTE_USER
@@ -216,7 +214,7 @@ call dolt_commit('-am', 'add some vals');
run dolt checkout new_branch
[[ "$status" -eq 0 ]] || false
dolt --port 3307 --host localhost -u $DOLT_REMOTE_USER -p $DOLT_REMOTE_PASSWORD sql -q "
dolt --port 3307 --host localhost --no-tls -u $DOLT_REMOTE_USER -p $DOLT_REMOTE_PASSWORD sql -q "
use remote;
call dolt_checkout('new_branch');
insert into vals (i) values (11);
@@ -226,7 +224,7 @@ call dolt_commit('-am', 'add one val');
# No auth pull
run dolt pull
[[ "$status" != 0 ]] || false
[[ "$output" =~ "Unauthenticated" ]] || false
[[ "$output" =~ "Access denied for user 'root'" ]] || false
# With auth pull
run dolt pull -u $DOLT_REMOTE_USER
@@ -236,8 +234,6 @@ call dolt_commit('-am', 'add one val');
}
@test "sql-server-remotesrv: clone/fetch/pull from remotesapi port with clone_admin authentication" {
skip "only support authenticating fetch with dolthub for now."
mkdir remote
cd remote
dolt init
@@ -250,11 +246,12 @@ call dolt_commit('-am', 'add one val');
srv_pid=$!
sleep 2 # wait for server to start so we don't lock it out
run dolt --port 3307 --host localhost -u user0 -p pass0 sql -q "
CREATE USER clone_admin_user@'%' IDENTIFIED BY 'pass1';
GRANT CLONE_ADMIN ON *.* TO clone_admin_user@'%';
run dolt sql -q "
CREATE USER clone_admin_user@'localhost' IDENTIFIED BY 'pass1';
GRANT CLONE_ADMIN ON *.* TO clone_admin_user@'localhost';
select user from mysql.user;
"
[ $status -eq 0 ]
[[ $output =~ user0 ]] || false
[[ $output =~ clone_admin_user ]] || false
@@ -268,12 +265,10 @@ select user from mysql.user;
run dolt sql -q 'select count(*) from vals'
[[ "$output" =~ "5" ]] || false
dolt --port 3307 --host localhost -u user0 -p pass0 sql -q "
use remote;
dolt --port 3307 --host localhost -u user0 -p pass0 --no-tls --use-db remote sql -q "
call dolt_checkout('-b', 'new_branch');
insert into vals (i) values (6), (7), (8), (9), (10);
call dolt_commit('-am', 'add some vals');
"
call dolt_commit('-am', 'add some vals');"
run dolt branch -v -a
[ "$status" -eq 0 ]
@@ -283,7 +278,7 @@ call dolt_commit('-am', 'add some vals');
# No auth fetch
run dolt fetch
[[ "$status" != 0 ]] || false
[[ "$output" =~ "Unauthenticated" ]] || false
[[ "$output" =~ "Access denied for user 'root'" ]] || false
# # With auth fetch
run dolt fetch -u clone_admin_user
@@ -297,17 +292,15 @@ call dolt_commit('-am', 'add some vals');
run dolt checkout new_branch
[[ "$status" -eq 0 ]] || false
dolt --port 3307 --host localhost -u user0 -p pass0 sql -q "
use remote;
dolt sql -q "
call dolt_checkout('new_branch');
insert into vals (i) values (11);
call dolt_commit('-am', 'add one val');
"
call dolt_commit('-am', 'add one val');"
# No auth pull
run dolt pull
[[ "$status" != 0 ]] || false
[[ "$output" =~ "Unauthenticated" ]] || false
[[ "$output" =~ "Access denied for user 'root'" ]] || false
# With auth pull
run dolt pull -u clone_admin_user
@@ -334,7 +327,7 @@ call dolt_commit('-am', 'add one val');
cd ../
run dolt clone http://localhost:50051/remote repo1
[[ "$status" != 0 ]] || false
[[ "$output" =~ "Unauthenticated" ]] || false
[[ "$output" =~ "Access denied for user 'root'" ]] || false
}
@test "sql-server-remotesrv: dolt clone with incorrect authentication returns error" {
@@ -361,10 +354,10 @@ call dolt_commit('-am', 'add one val');
export DOLT_REMOTE_PASSWORD="wrong-password"
run dolt clone http://localhost:50051/remote repo1 -u $DOLT_REMOTE_USER
[[ "$status" != 0 ]] || false
[[ "$output" =~ "Unauthenticated" ]] || false
[[ "$output" =~ "Access denied for user 'user0'" ]] || false
export DOLT_REMOTE_PASSWORD="pass0"
run dolt clone http://localhost:50051/remote repo1 -u doesnt_exist
[[ "$status" != 0 ]] || false
[[ "$output" =~ "Unauthenticated" ]] || false
[[ "$output" =~ "Access denied for user 'doesnt_exist'" ]] || false
}