diff --git a/go/cmd/dolt/commands/sqlserver/queryist_utils.go b/go/cmd/dolt/commands/sqlserver/queryist_utils.go index bc5b06e911..42211b8fe5 100644 --- a/go/cmd/dolt/commands/sqlserver/queryist_utils.go +++ b/go/cmd/dolt/commands/sqlserver/queryist_utils.go @@ -16,6 +16,7 @@ package sqlserver import ( "context" + "crypto/tls" sql2 "database/sql" "fmt" "io" @@ -36,9 +37,33 @@ import ( "github.com/dolthub/dolt/go/libraries/utils/filesys" ) +type QueryistTLSMode int + +const ( + QueryistTLSMode_Disabled QueryistTLSMode = iota + // Require TLS, verify the server certificate using the system + // trust store, do not allow fallback to plaintext. + // + // Used for `dolt --host ... sql ...` when `--no-tls-` is not + // specified. Often used for connecting to Hosted DoltDB + // instances using the CLI commands posted on + // hosted.doltdb.com. + QueryistTLSMode_Enabled + // Used for local Dolt CLI queryist connecting to the running + // local server. In this mode, TLS is allowed but not required + // and the client does not verify the remote TLS + // certificate. It is assumed connecting to the port locally + // is secure and lands the client in the correct place, given + // the contents of sql-server.info, for example. + // + // This mode still does not allow the Dolt CLI to connect to a + // server which requires a client certificate. + QueryistTLSMode_NoVerify_FallbackToPlaintext +) + // BuildConnectionStringQueryist returns a Queryist that connects to the server specified by the given server config. Presence in this // module isn't ideal, but it's the only way to get the server config into the queryist. -func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, creds *cli.UserPassword, apr *argparser.ArgParseResults, host string, port int, useTLS bool, dbRev string) (cli.LateBindQueryist, error) { +func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, creds *cli.UserPassword, apr *argparser.ArgParseResults, host string, port int, tlsMode QueryistTLSMode, dbRev string) (cli.LateBindQueryist, error) { clientConfig, err := GetClientConfig(cwdFS, creds, apr) if err != nil { return nil, err @@ -54,8 +79,13 @@ func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, c parsedMySQLConfig.DBName = dbRev parsedMySQLConfig.Addr = fmt.Sprintf("%s:%d", host, port) - if useTLS { - parsedMySQLConfig.TLSConfig = "true" + switch tlsMode { + case QueryistTLSMode_Disabled: + case QueryistTLSMode_Enabled: + parsedMySQLConfig.TLS = &tls.Config{} + case QueryistTLSMode_NoVerify_FallbackToPlaintext: + parsedMySQLConfig.TLS = &tls.Config{InsecureSkipVerify: true} + parsedMySQLConfig.AllowFallbackToPlaintext = true } mysqlConnector, err := mysql.NewConnector(parsedMySQLConfig) diff --git a/go/cmd/dolt/dolt.go b/go/cmd/dolt/dolt.go index b01c201f1b..7356732bc4 100644 --- a/go/cmd/dolt/dolt.go +++ b/go/cmd/dolt/dolt.go @@ -623,8 +623,11 @@ If you're interested in running this command against a remote host, hit us up on if !hasPort { port = 3306 } - useTLS := !apr.Contains(cli.NoTLSFlag) - return sqlserver.BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, host, port, useTLS, useDb) + tlsMode := sqlserver.QueryistTLSMode_Enabled + if apr.Contains(cli.NoTLSFlag) { + tlsMode = sqlserver.QueryistTLSMode_Disabled + } + return sqlserver.BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, host, port, tlsMode, useDb) } else { _, hasPort := apr.GetInt(cli.PortFlag) if hasPort { @@ -712,7 +715,7 @@ If you're interested in running this command against a remote host, hit us up on if !creds.Specified { creds = &cli.UserPassword{Username: sqlserver.LocalConnectionUser, Password: localCreds.Secret, Specified: false} } - return sqlserver.BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, "localhost", localCreds.Port, false, useDb) + return sqlserver.BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, "localhost", localCreds.Port, sqlserver.QueryistTLSMode_NoVerify_FallbackToPlaintext, useDb) } } diff --git a/integration-tests/bats/mutual-tls-auth.bats b/integration-tests/bats/mutual-tls-auth.bats index fc23db2093..250c9e1261 100644 --- a/integration-tests/bats/mutual-tls-auth.bats +++ b/integration-tests/bats/mutual-tls-auth.bats @@ -3,6 +3,7 @@ load $BATS_TEST_DIRNAME/helper/common.bash load $BATS_TEST_DIRNAME/helper/query-server-common.bash REQUIRE_CLIENT_CERT=false +REQUIRE_SECURE_TRANSPORT=false setup() { skiponwindows "tests are flaky on Windows" @@ -27,6 +28,7 @@ listener: host: "0.0.0.0" port: $PORT require_client_cert: $REQUIRE_CLIENT_CERT + require_secure_transport: $REQUIRE_SECURE_TRANSPORT tls_cert: $CERTS_DIR/server-cert.pem tls_key: $CERTS_DIR/server-key.pem EOF @@ -48,6 +50,7 @@ listener: host: "0.0.0.0" port: $PORT require_client_cert: $REQUIRE_CLIENT_CERT + require_secure_transport: $REQUIRE_SECURE_TRANSPORT ca_cert: $CERTS_DIR/ca.pem tls_cert: $CERTS_DIR/server-cert.pem tls_key: $CERTS_DIR/server-key.pem @@ -381,6 +384,16 @@ EOF [[ "$output" =~ "123" ]] || false } +# bats test_tags=no_lambda +@test "mutual-tls-auth: dolt cli works with require_secure_transport" { + REQUIRE_SECURE_TRANSPORT=true + start_sql_server_with_TLS + + run dolt sql -q 'show databases' + [ "$status" -eq 0 ] || false + [[ "$output" =~ "Database" ]] || false +} + # bats test_tags=no_lambda @test "mutual-tls-auth: auth works with require_client_cert (without cert verification)" { dolt sql -q "create user user1@'%';" @@ -447,7 +460,7 @@ EOF run dolt sql -q "SELECT 1;" [ "$status" -ne 0 ] - [[ "$output" =~ "UNAVAILABLE" ]] || false + [[ "$output" =~ "remote error: tls: certificate required" ]] || false } # bats test_tags=no_lambda