integration-tests/go-sql-server-driver: Implement max-connections testing to go-sql-server-driver.

This commit is contained in:
Aaron Son
2025-03-25 09:19:31 -07:00
parent 24c4e7ca20
commit 2ff00c2afc
6 changed files with 299 additions and 9 deletions

View File

@@ -18,7 +18,7 @@ require (
github.com/dustin/go-humanize v1.0.1
github.com/fatih/color v1.13.0
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d
github.com/go-sql-driver/mysql v1.9.1
github.com/gocraft/dbr/v2 v2.7.2
github.com/golang/snappy v0.0.4
github.com/google/uuid v1.6.0

View File

@@ -297,8 +297,8 @@ github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d h1:QQP1nE4qh5aHTGvI1LgOFxZYVxYoGeMfbNHikogPyoA=
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/go-sql-driver/mysql v1.9.1 h1:FrjNGn/BsJQjVRuSa8CBrM5BWA9BWoXXat3KrtSb/iI=
github.com/go-sql-driver/mysql v1.9.1/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=

View File

@@ -17,7 +17,9 @@ package sql_server_driver
import (
"bufio"
"bytes"
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"log"
@@ -28,6 +30,7 @@ import (
"sync"
"time"
"github.com/go-sql-driver/mysql"
_ "github.com/go-sql-driver/mysql"
)
@@ -379,15 +382,38 @@ func (s *SqlServer) Restart(newargs *[]string, newenvs *[]string) error {
}
func (s *SqlServer) DB(c Connection) (*sql.DB, error) {
var pass string
connector, err := s.Connector(c)
if err != nil {
return nil, err
}
return OpenDB(connector)
}
// If a test needs to circumvent the database/sql connection pool
// it can use the raw MySQL connector.
func (s *SqlServer) Connector(c Connection) (driver.Connector, error) {
pass, err := c.Password()
if err != nil {
return nil, err
}
return ConnectDB(c.User, pass, s.DBName, "127.0.0.1", s.Port, c.DriverParams)
dsn := GetDSN(c.User, pass, s.DBName, "127.0.0.1", s.Port, c.DriverParams)
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return nil, err
}
// See the comment on WithConnectRetriesDisabled for why we do this.
cfg.Apply(mysql.BeforeConnect(func(ctx context.Context, cfg *mysql.Config) error {
// TODO: This could be more robust if we sniffed it on first connect.
const numAttemptsGoLibraryMakes = 3
if attempt, ok := incrementConnectRetryAttempts(ctx); ok && attempt < numAttemptsGoLibraryMakes {
return driver.ErrBadConn
}
return nil
}))
return mysql.NewConnector(cfg)
}
func ConnectDB(user, password, name, host string, port int, driverParams map[string]string) (*sql.DB, error) {
func GetDSN(user, password, name, host string, port int, driverParams map[string]string) string {
params := make(url.Values)
params.Set("allowAllFiles", "true")
params.Set("tls", "preferred")
@@ -395,7 +421,27 @@ func ConnectDB(user, password, name, host string, port int, driverParams map[str
params.Set(k, v)
}
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", user, password, host, port, name, params.Encode())
return dsn
}
func OpenDB(connector driver.Connector) (*sql.DB, error) {
db := sql.OpenDB(connector)
var err error
for i := 0; i < ConnectAttempts; i++ {
err = db.Ping()
if err == nil {
return db, nil
}
time.Sleep(RetrySleepDuration)
}
if err != nil {
return nil, err
}
return db, nil
}
func ConnectDB(user, password, name, host string, port int, params map[string]string) (*sql.DB, error) {
dsn := GetDSN(user, password, name, host, port, params)
db, err := sql.Open("mysql", dsn)
if err != nil {
return nil, err
@@ -412,3 +458,40 @@ func ConnectDB(user, password, name, host string, port int, driverParams map[str
}
return db, nil
}
type connectRetryAttemptKeyType int
var connectRetryAttemptKey connectRetryAttemptKeyType
// The database/sql package in Go takes connections out of a
// connection pool or opens a new connection to the database. It has
// logic in it that looks for driver.ErrBadConn responses when it is
// opening a connection and automatically retries those connections a
// fixed number of times before actually surfacing the error to the
// caller. This behavior interferes with some testing we want to do
// against the behavior of the server.
//
// WithConnectionRetriesDisabled is a hack which circumvents these
// retries. It works by embedding a counter into the returned context,
// which should then be passed to *sql.DB.Conn(). An interceptor has
// been installed on the |mysql.Connector| which looks for this
// counter and fast-fails the first few calls into the driver. That
// way the first call which goes through to the driver is the last and
// final retry from *sql.DB.
func WithConnectRetriesDisabled(ctx context.Context) context.Context {
return context.WithValue(ctx, connectRetryAttemptKey, &retryAttempt{})
}
type retryAttempt struct {
attempt int
}
func incrementConnectRetryAttempts(ctx context.Context) (int, bool) {
v := ctx.Value(connectRetryAttemptKey)
if v != nil {
if v, ok := v.(*retryAttempt); ok {
v.attempt += 1
return v.attempt, true
}
}
return 0, false
}

View File

@@ -15,7 +15,7 @@ require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/creasty/defaults v1.6.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d // indirect
github.com/go-sql-driver/mysql v1.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.35.0 // indirect
golang.org/x/sys v0.30.0 // indirect

View File

@@ -4,8 +4,8 @@ github.com/creasty/defaults v1.6.0 h1:ltuE9cfphUtlrBeomuu8PEyISTXnxqkBIoQfXgv7BS
github.com/creasty/defaults v1.6.0/go.mod h1:iGzKe6pbEHnpMPtfDXZEr0NVxWnPTjb1bbDy08fPzYM=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d h1:QQP1nE4qh5aHTGvI1LgOFxZYVxYoGeMfbNHikogPyoA=
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/go-sql-driver/mysql v1.9.1 h1:FrjNGn/BsJQjVRuSa8CBrM5BWA9BWoXXat3KrtSb/iI=
github.com/go-sql-driver/mysql v1.9.1/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=

View File

@@ -0,0 +1,207 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"context"
"database/sql"
"errors"
"testing"
"time"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
driver "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/sql_server_driver"
)
func TestSqlServerMaxConns(t *testing.T) {
t.Run("MaxConns 3", testMaxConns3)
t.Run("MaxConns 3 BackLog 0", testMaxConns3BackLog0)
t.Run("MaxConns 3 BackLog 1", testMaxConns3BackLog1)
t.Run("MaxConns 3 MaxConnectionsTimeout 10s", testMaxConns3Timeout10s)
}
func setupMaxConnsTest(t *testing.T, ctx context.Context, args ...string) (*sql.DB, []*sql.Conn) {
// TODO: These tests should run parallel once |main| is merged.
// Add "--port", `{{get_port "server"}}` to the args and add t.Parallel() to this function.
u, err := driver.NewDoltUser()
require.NoError(t, err)
t.Cleanup(func() {
u.Cleanup()
})
rs, err := u.MakeRepoStore()
require.NoError(t, err)
repo, err := rs.MakeRepo("max_conns_test")
require.NoError(t, err)
args = append(args, "--max-connections", "3")
srvSettings := &driver.Server{
Args: args,
}
server := MakeServer(t, repo, srvSettings)
server.DBName = "max_conns_test"
db, err := server.DB(driver.Connection{User: "root"})
require.NoError(t, err)
t.Cleanup(func() {
db.Close()
})
db.SetMaxIdleConns(0)
var conns []*sql.Conn
t.Cleanup(func() {
closeAll(conns)
})
for i := 0; i < 3; i++ {
conn, err := db.Conn(driver.WithConnectRetriesDisabled(ctx))
require.NoError(t, err)
conns = append(conns, conn)
}
_, err = conns[0].ExecContext(ctx, `
CREATE TABLE test_table (
id INT AUTO_INCREMENT PRIMARY KEY,
str VARCHAR(20)
);`)
require.NoError(t, err)
return db, conns
}
func closeAll(conns []*sql.Conn) {
for i, c := range conns {
if c != nil {
c.Close()
}
conns[i] = nil
}
}
func testMaxConns3BackLog0(t *testing.T) {
ctx := context.Background()
db, _ := setupMaxConnsTest(t, ctx, "--back-log", "0")
if t.Failed() {
return
}
_, err := db.Conn(driver.WithConnectRetriesDisabled(ctx))
require.ErrorIs(t, err, mysql.ErrInvalidConn)
}
func testMaxConns3Timeout10s(t *testing.T) {
ctx := context.Background()
db, _ := setupMaxConnsTest(t, ctx, "--max-connections-timeout", "10s")
if t.Failed() {
return
}
start := time.Now()
_, err := db.Conn(driver.WithConnectRetriesDisabled(ctx))
elapsed := time.Since(start)
require.ErrorIs(t, err, mysql.ErrInvalidConn)
require.True(t, elapsed > 9 * time.Second, "it took more than 9 seconds to fail")
require.True(t, elapsed < 12 * time.Second, "it took less than 12 seconds to fail")
}
func testMaxConns3(t *testing.T) {
ctx := context.Background()
db, conns := setupMaxConnsTest(t, ctx)
if t.Failed() {
return
}
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
conn, err := db.Conn(driver.WithConnectRetriesDisabled(ctx))
if err != nil {
return err
}
defer conn.Close()
_, err = conn.ExecContext(ctx, "insert into test_table (str) values ('test4223')")
return err
})
eg.Go(func() error {
conn, err := db.Conn(driver.WithConnectRetriesDisabled(ctx))
if err != nil {
return err
}
defer conn.Close()
_, err = conn.ExecContext(ctx, "insert into test_table (str) values ('test9119')")
return err
})
conns[0].Close()
conns[0] = nil
require.NoError(t, eg.Wait())
ctx = context.Background()
rows, err := conns[1].QueryContext(ctx, `SELECT * FROM test_table ORDER BY str ASC`)
require.NoError(t, err)
defer rows.Close()
require.True(t, rows.Next())
var id int
var str string
require.NoError(t, rows.Scan(&id, &str))
require.Equal(t, "test4223", str)
require.True(t, rows.Next())
require.NoError(t, rows.Scan(&id, &str))
require.Equal(t, "test9119", str)
require.False(t, rows.Next())
require.NoError(t, rows.Err())
}
func testMaxConns3BackLog1(t *testing.T) {
ctx := context.Background()
db, conns := setupMaxConnsTest(t, ctx, "--back-log", "1")
if t.Failed() {
return
}
eg, ctx := errgroup.WithContext(ctx)
done := make(chan struct{})
eg.Go(func() error {
conn, err := db.Conn(driver.WithConnectRetriesDisabled(ctx))
if err != nil {
return err
}
defer func() {
// Keep this connection alive until the other function
// has a chance to try to connect and fail.
<-done
conn.Close()
}()
_, err = conn.ExecContext(ctx, "insert into test_table (str) values ('test4223')")
return err
})
eg.Go(func() error {
defer close(done)
time.Sleep(1 * time.Second)
_, err := db.Conn(driver.WithConnectRetriesDisabled(ctx))
if !assert.ErrorIs(t, err, mysql.ErrInvalidConn) {
return errors.New("unexpected test failure")
}
return nil
})
<-done
conns[0].Close()
conns[0] = nil
require.NoError(t, eg.Wait())
ctx = context.Background()
rows, err := conns[1].QueryContext(ctx, `SELECT * FROM test_table`)
require.NoError(t, err)
defer rows.Close()
require.True(t, rows.Next())
var id int
var str string
require.NoError(t, rows.Scan(&id, &str))
require.Equal(t, "test4223", str)
require.False(t, rows.Next())
require.NoError(t, rows.Err())
}