diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index 78660f60db..7e7b2ea91c 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -16,6 +16,7 @@ package sqlserver import ( "context" + "fmt" "net" "strconv" "time" @@ -123,7 +124,14 @@ func Serve(ctx context.Context, version string, serverConfig ServerConfig, serve sqlEngine.AddDatabase(information_schema.NewInformationSchemaDatabase(sqlEngine.Catalog)) - hostPort := net.JoinHostPort(serverConfig.Host(), strconv.Itoa(serverConfig.Port())) + portAsString := strconv.Itoa(serverConfig.Port()) + hostPort := net.JoinHostPort(serverConfig.Host(), portAsString) + + if portInUse(hostPort) { + portInUseError := fmt.Errorf("Port %s already in use.", portAsString) + return portInUseError, nil + } + readTimeout := time.Duration(serverConfig.ReadTimeout()) * time.Millisecond writeTimeout := time.Duration(serverConfig.WriteTimeout()) * time.Millisecond mySQLServer, startError = server.NewServer( @@ -155,6 +163,16 @@ func Serve(ctx context.Context, version string, serverConfig ServerConfig, serve return } +func portInUse(hostPort string) bool { + timeout := time.Second + conn, _ := net.DialTimeout("tcp", hostPort, timeout) + if conn != nil { + defer conn.Close() + return true + } + return false +} + func newSessionBuilder(sqlEngine *sqle.Engine, username, email string, autocommit bool) server.SessionBuilder { return func(ctx context.Context, conn *mysql.Conn, host string) (sql.Session, *sql.IndexRegistry, *sql.ViewRegistry, error) { mysqlSess := sql.NewSession(host, conn.RemoteAddr().String(), conn.User, conn.ConnectionID) diff --git a/go/cmd/dolt/commands/sqlserver/server_test.go b/go/cmd/dolt/commands/sqlserver/server_test.go index 1874f79564..9fd32ae8ae 100644 --- a/go/cmd/dolt/commands/sqlserver/server_test.go +++ b/go/cmd/dolt/commands/sqlserver/server_test.go @@ -15,6 +15,7 @@ package sqlserver import ( + "net/http" "strings" "testing" @@ -216,3 +217,27 @@ func TestServerSelect(t *testing.T) { }) } } + +// If a port is already in use, throw error "Port XXXX already in use." +func TestServerFailsIfPortInUse(t *testing.T) { + serverController := CreateServerController() + server := &http.Server{ + Addr: ":15200", + Handler: http.DefaultServeMux, + } + go server.ListenAndServe() + go func() { + startServer(context.Background(), "test", "dolt sql-server", []string{ + "-H", "localhost", + "-P", "15200", + "-u", "username", + "-p", "password", + "-t", "5", + "-l", "info", + "-r", + }, dtestutils.CreateEnvWithSeedData(t), serverController) + }() + err := serverController.WaitForStart() + require.Error(t, err) + server.Close() +} diff --git a/integration-tests/bats/sql-server.bats b/integration-tests/bats/sql-server.bats index 7acd6caef5..1ed0e63a3f 100644 --- a/integration-tests/bats/sql-server.bats +++ b/integration-tests/bats/sql-server.bats @@ -20,6 +20,19 @@ teardown() { teardown_common } +@test "sql-server: port in use" { + cd repo1 + + let PORT="$$ % (65536-1024) + 1024" + dolt sql-server --host 0.0.0.0 --port=$PORT --user dolt & + SERVER_PID=$! # will get killed by teardown_common + sleep 5 # not using python wait so this works on windows + + run dolt sql-server --host 0.0.0.0 --port=$PORT --user dolt + [ "$status" -eq 1 ] + [[ "$output" =~ "in use" ]] || false +} + @test "sql-server: multi-client" { skiponwindows "Has dependencies that are missing on the Jenkins Windows installation." @@ -715,4 +728,4 @@ SQL insert_query 1 "INSERT INTO js_test VALUES (1, '{\"a\":1}');" server_query 1 "SELECT * FROM js_test;" "pk,js\n1,{\"a\": 1}" -} \ No newline at end of file +}