diff --git a/go/go.mod b/go/go.mod index ea9de6ce73..52a4d6a371 100644 --- a/go/go.mod +++ b/go/go.mod @@ -18,7 +18,7 @@ require ( github.com/dustin/go-humanize v1.0.1 github.com/fatih/color v1.13.0 github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 - github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d + github.com/go-sql-driver/mysql v1.9.1 github.com/gocraft/dbr/v2 v2.7.2 github.com/golang/snappy v0.0.4 github.com/google/uuid v1.6.0 diff --git a/go/go.sum b/go/go.sum index ea48af52a5..f9340f71c5 100644 --- a/go/go.sum +++ b/go/go.sum @@ -297,8 +297,8 @@ github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d h1:QQP1nE4qh5aHTGvI1LgOFxZYVxYoGeMfbNHikogPyoA= -github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/go-sql-driver/mysql v1.9.1 h1:FrjNGn/BsJQjVRuSa8CBrM5BWA9BWoXXat3KrtSb/iI= +github.com/go-sql-driver/mysql v1.9.1/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= diff --git a/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go b/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go index 802abf595c..74325f8d86 100644 --- a/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go +++ b/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go @@ -17,7 +17,9 @@ package sql_server_driver import ( "bufio" "bytes" + "context" "database/sql" + "database/sql/driver" "fmt" "io" "log" @@ -28,6 +30,7 @@ import ( "sync" "time" + "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql" ) @@ -379,15 +382,38 @@ func (s *SqlServer) Restart(newargs *[]string, newenvs *[]string) error { } func (s *SqlServer) DB(c Connection) (*sql.DB, error) { - var pass string + connector, err := s.Connector(c) + if err != nil { + return nil, err + } + return OpenDB(connector) +} + +// If a test needs to circumvent the database/sql connection pool +// it can use the raw MySQL connector. +func (s *SqlServer) Connector(c Connection) (driver.Connector, error) { pass, err := c.Password() if err != nil { return nil, err } - return ConnectDB(c.User, pass, s.DBName, "127.0.0.1", s.Port, c.DriverParams) + dsn := GetDSN(c.User, pass, s.DBName, "127.0.0.1", s.Port, c.DriverParams) + cfg, err := mysql.ParseDSN(dsn) + if err != nil { + return nil, err + } + // See the comment on WithConnectRetriesDisabled for why we do this. + cfg.Apply(mysql.BeforeConnect(func(ctx context.Context, cfg *mysql.Config) error { + // TODO: This could be more robust if we sniffed it on first connect. + const numAttemptsGoLibraryMakes = 3 + if attempt, ok := incrementConnectRetryAttempts(ctx); ok && attempt < numAttemptsGoLibraryMakes { + return driver.ErrBadConn + } + return nil + })) + return mysql.NewConnector(cfg) } -func ConnectDB(user, password, name, host string, port int, driverParams map[string]string) (*sql.DB, error) { +func GetDSN(user, password, name, host string, port int, driverParams map[string]string) string { params := make(url.Values) params.Set("allowAllFiles", "true") params.Set("tls", "preferred") @@ -395,7 +421,27 @@ func ConnectDB(user, password, name, host string, port int, driverParams map[str params.Set(k, v) } dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", user, password, host, port, name, params.Encode()) + return dsn +} +func OpenDB(connector driver.Connector) (*sql.DB, error) { + db := sql.OpenDB(connector) + var err error + for i := 0; i < ConnectAttempts; i++ { + err = db.Ping() + if err == nil { + return db, nil + } + time.Sleep(RetrySleepDuration) + } + if err != nil { + return nil, err + } + return db, nil +} + +func ConnectDB(user, password, name, host string, port int, params map[string]string) (*sql.DB, error) { + dsn := GetDSN(user, password, name, host, port, params) db, err := sql.Open("mysql", dsn) if err != nil { return nil, err @@ -412,3 +458,40 @@ func ConnectDB(user, password, name, host string, port int, driverParams map[str } return db, nil } + +type connectRetryAttemptKeyType int +var connectRetryAttemptKey connectRetryAttemptKeyType + +// The database/sql package in Go takes connections out of a +// connection pool or opens a new connection to the database. It has +// logic in it that looks for driver.ErrBadConn responses when it is +// opening a connection and automatically retries those connections a +// fixed number of times before actually surfacing the error to the +// caller. This behavior interferes with some testing we want to do +// against the behavior of the server. +// +// WithConnectionRetriesDisabled is a hack which circumvents these +// retries. It works by embedding a counter into the returned context, +// which should then be passed to *sql.DB.Conn(). An interceptor has +// been installed on the |mysql.Connector| which looks for this +// counter and fast-fails the first few calls into the driver. That +// way the first call which goes through to the driver is the last and +// final retry from *sql.DB. +func WithConnectRetriesDisabled(ctx context.Context) context.Context { + return context.WithValue(ctx, connectRetryAttemptKey, &retryAttempt{}) +} + +type retryAttempt struct { + attempt int +} + +func incrementConnectRetryAttempts(ctx context.Context) (int, bool) { + v := ctx.Value(connectRetryAttemptKey) + if v != nil { + if v, ok := v.(*retryAttempt); ok { + v.attempt += 1 + return v.attempt, true + } + } + return 0, false +} diff --git a/integration-tests/go-sql-server-driver/go.mod b/integration-tests/go-sql-server-driver/go.mod index befdcd307c..24d7f10462 100644 --- a/integration-tests/go-sql-server-driver/go.mod +++ b/integration-tests/go-sql-server-driver/go.mod @@ -15,7 +15,7 @@ require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/creasty/defaults v1.6.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d // indirect + github.com/go-sql-driver/mysql v1.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/crypto v0.35.0 // indirect golang.org/x/sys v0.30.0 // indirect diff --git a/integration-tests/go-sql-server-driver/go.sum b/integration-tests/go-sql-server-driver/go.sum index 165eb4766d..aa1987b852 100644 --- a/integration-tests/go-sql-server-driver/go.sum +++ b/integration-tests/go-sql-server-driver/go.sum @@ -4,8 +4,8 @@ github.com/creasty/defaults v1.6.0 h1:ltuE9cfphUtlrBeomuu8PEyISTXnxqkBIoQfXgv7BS github.com/creasty/defaults v1.6.0/go.mod h1:iGzKe6pbEHnpMPtfDXZEr0NVxWnPTjb1bbDy08fPzYM= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d h1:QQP1nE4qh5aHTGvI1LgOFxZYVxYoGeMfbNHikogPyoA= -github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/go-sql-driver/mysql v1.9.1 h1:FrjNGn/BsJQjVRuSa8CBrM5BWA9BWoXXat3KrtSb/iI= +github.com/go-sql-driver/mysql v1.9.1/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/integration-tests/go-sql-server-driver/sql_server_max_conns_test.go b/integration-tests/go-sql-server-driver/sql_server_max_conns_test.go new file mode 100644 index 0000000000..9b6080462d --- /dev/null +++ b/integration-tests/go-sql-server-driver/sql_server_max_conns_test.go @@ -0,0 +1,207 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "database/sql" + "errors" + "testing" + "time" + + "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + driver "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/sql_server_driver" +) + +func TestSqlServerMaxConns(t *testing.T) { + t.Run("MaxConns 3", testMaxConns3) + t.Run("MaxConns 3 BackLog 0", testMaxConns3BackLog0) + t.Run("MaxConns 3 BackLog 1", testMaxConns3BackLog1) + t.Run("MaxConns 3 MaxConnectionsTimeout 10s", testMaxConns3Timeout10s) +} + +func setupMaxConnsTest(t *testing.T, ctx context.Context, args ...string) (*sql.DB, []*sql.Conn) { + // TODO: These tests should run parallel once |main| is merged. + // Add "--port", `{{get_port "server"}}` to the args and add t.Parallel() to this function. + u, err := driver.NewDoltUser() + require.NoError(t, err) + t.Cleanup(func() { + u.Cleanup() + }) + rs, err := u.MakeRepoStore() + require.NoError(t, err) + repo, err := rs.MakeRepo("max_conns_test") + require.NoError(t, err) + args = append(args, "--max-connections", "3") + srvSettings := &driver.Server{ + Args: args, + } + server := MakeServer(t, repo, srvSettings) + server.DBName = "max_conns_test" + db, err := server.DB(driver.Connection{User: "root"}) + require.NoError(t, err) + t.Cleanup(func() { + db.Close() + }) + db.SetMaxIdleConns(0) + + var conns []*sql.Conn + t.Cleanup(func() { + closeAll(conns) + }) + for i := 0; i < 3; i++ { + conn, err := db.Conn(driver.WithConnectRetriesDisabled(ctx)) + require.NoError(t, err) + conns = append(conns, conn) + } + + _, err = conns[0].ExecContext(ctx, ` +CREATE TABLE test_table ( + id INT AUTO_INCREMENT PRIMARY KEY, + str VARCHAR(20) +);`) + require.NoError(t, err) + + return db, conns +} + +func closeAll(conns []*sql.Conn) { + for i, c := range conns { + if c != nil { + c.Close() + } + conns[i] = nil + } +} + +func testMaxConns3BackLog0(t *testing.T) { + ctx := context.Background() + db, _ := setupMaxConnsTest(t, ctx, "--back-log", "0") + if t.Failed() { + return + } + _, err := db.Conn(driver.WithConnectRetriesDisabled(ctx)) + require.ErrorIs(t, err, mysql.ErrInvalidConn) +} + +func testMaxConns3Timeout10s(t *testing.T) { + ctx := context.Background() + db, _ := setupMaxConnsTest(t, ctx, "--max-connections-timeout", "10s") + if t.Failed() { + return + } + start := time.Now() + _, err := db.Conn(driver.WithConnectRetriesDisabled(ctx)) + elapsed := time.Since(start) + require.ErrorIs(t, err, mysql.ErrInvalidConn) + require.True(t, elapsed > 9 * time.Second, "it took more than 9 seconds to fail") + require.True(t, elapsed < 12 * time.Second, "it took less than 12 seconds to fail") +} + +func testMaxConns3(t *testing.T) { + ctx := context.Background() + db, conns := setupMaxConnsTest(t, ctx) + if t.Failed() { + return + } + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + conn, err := db.Conn(driver.WithConnectRetriesDisabled(ctx)) + if err != nil { + return err + } + defer conn.Close() + _, err = conn.ExecContext(ctx, "insert into test_table (str) values ('test4223')") + return err + }) + eg.Go(func() error { + conn, err := db.Conn(driver.WithConnectRetriesDisabled(ctx)) + if err != nil { + return err + } + defer conn.Close() + _, err = conn.ExecContext(ctx, "insert into test_table (str) values ('test9119')") + return err + }) + conns[0].Close() + conns[0] = nil + require.NoError(t, eg.Wait()) + ctx = context.Background() + rows, err := conns[1].QueryContext(ctx, `SELECT * FROM test_table ORDER BY str ASC`) + require.NoError(t, err) + defer rows.Close() + require.True(t, rows.Next()) + var id int + var str string + require.NoError(t, rows.Scan(&id, &str)) + require.Equal(t, "test4223", str) + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&id, &str)) + require.Equal(t, "test9119", str) + require.False(t, rows.Next()) + require.NoError(t, rows.Err()) +} + +func testMaxConns3BackLog1(t *testing.T) { + ctx := context.Background() + db, conns := setupMaxConnsTest(t, ctx, "--back-log", "1") + if t.Failed() { + return + } + eg, ctx := errgroup.WithContext(ctx) + done := make(chan struct{}) + eg.Go(func() error { + conn, err := db.Conn(driver.WithConnectRetriesDisabled(ctx)) + if err != nil { + return err + } + defer func() { + // Keep this connection alive until the other function + // has a chance to try to connect and fail. + <-done + conn.Close() + }() + _, err = conn.ExecContext(ctx, "insert into test_table (str) values ('test4223')") + return err + }) + eg.Go(func() error { + defer close(done) + time.Sleep(1 * time.Second) + _, err := db.Conn(driver.WithConnectRetriesDisabled(ctx)) + if !assert.ErrorIs(t, err, mysql.ErrInvalidConn) { + return errors.New("unexpected test failure") + } + return nil + }) + <-done + conns[0].Close() + conns[0] = nil + require.NoError(t, eg.Wait()) + ctx = context.Background() + rows, err := conns[1].QueryContext(ctx, `SELECT * FROM test_table`) + require.NoError(t, err) + defer rows.Close() + require.True(t, rows.Next()) + var id int + var str string + require.NoError(t, rows.Scan(&id, &str)) + require.Equal(t, "test4223", str) + require.False(t, rows.Next()) + require.NoError(t, rows.Err()) +}