integration-tests/go-sql-server-driver: Some cleanup to finalization and sql.Rows handling.

This commit is contained in:
Aaron Son
2022-10-03 15:14:37 -07:00
parent bf7480eb1b
commit 04947a44a0
@@ -22,6 +22,7 @@ import (
"database/sql"
"database/sql/driver"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v2"
)
@@ -185,9 +186,9 @@ func MakeRepo(t *testing.T, rs RepoStore, r TestRepo) Repo {
return repo
}
func MakeServer(t *testing.T, dc DoltCmdable, s *Server) (*SqlServer, func()) {
func MakeServer(t *testing.T, dc DoltCmdable, s *Server) *SqlServer {
if s == nil {
return nil, nil
return nil
}
opts := []SqlServerOpt{WithArgs(s.Args...)}
if s.Port != 0 {
@@ -202,16 +203,21 @@ func MakeServer(t *testing.T, dc DoltCmdable, s *Server) (*SqlServer, func()) {
for _, a := range s.ErrorMatches {
require.Regexp(t, a, output)
}
return nil, nil
return nil
} else {
return server, func() {
t.Cleanup(func() {
// We use assert, not require here, since FailNow() in
// a Cleanup does not make sense.
err := server.GracefulStop()
require.NoError(t, err)
output := string(server.Output.Bytes())
for _, a := range s.LogMatches {
require.Regexp(t, a, output)
if assert.NoError(t, err) {
output := string(server.Output.Bytes())
for _, a := range s.LogMatches {
assert.Regexp(t, a, output)
}
}
}
})
return server
}
}
@@ -227,20 +233,14 @@ func (test Test) Run(t *testing.T) {
servers := make(map[string]*SqlServer)
dbs := make(map[string]*sql.DB)
defer func() {
for _, db := range dbs {
db.Close()
}
}()
for _, r := range test.Repos {
repo := MakeRepo(t, rs, r)
server, close := MakeServer(t, repo, r.Server)
server := MakeServer(t, repo, r.Server)
if server != nil {
server.DBName = r.Name
servers[r.Name] = server
defer close()
db, err := server.DB()
require.NoError(t, err)
@@ -260,10 +260,9 @@ func (test Test) Run(t *testing.T) {
require.NoError(t, rs.WriteFile(f.Name, f.Contents))
}
server, close := MakeServer(t, rs, mr.Server)
server := MakeServer(t, rs, mr.Server)
if server != nil {
servers[mr.Name] = server
defer close()
db, err := server.DB()
require.NoError(t, err)
@@ -271,6 +270,12 @@ func (test Test) Run(t *testing.T) {
}
}
t.Cleanup(func() {
for _, db := range dbs {
db.Close()
}
})
for i, c := range test.Conns {
db := dbs[c.On]
require.NotNilf(t, db, "error in test spec: could not find database %s for connection %d", c.On, i)
@@ -307,22 +312,25 @@ func RunTestsFile(t *testing.T, path string) {
}
type retryTestingT struct {
*testing.T
errorfStrings []string
errorfArgs [][]interface{}
failNow bool
}
func (r *retryTestingT) Errorf(format string, args ...interface{}) {
r.T.Helper()
r.errorfStrings = append(r.errorfStrings, format)
r.errorfArgs = append(r.errorfArgs, args)
}
func (r *retryTestingT) FailNow() {
r.T.Helper()
r.failNow = true
panic(r)
}
func RetryTestRun(t require.TestingT, attempts int, test func(require.TestingT)) {
func RetryTestRun(t *testing.T, attempts int, test func(require.TestingT)) {
if attempts == 0 {
attempts = 1
}
@@ -332,6 +340,7 @@ func RetryTestRun(t require.TestingT, attempts int, test func(require.TestingT))
time.Sleep(50 * time.Millisecond)
}
rtt = new(retryTestingT)
rtt.T = t
func() {
defer func() {
if r := recover(); r != nil {
@@ -355,7 +364,7 @@ func RetryTestRun(t require.TestingT, attempts int, test func(require.TestingT))
}
}
func RunQuery(t require.TestingT, conn *sql.Conn, q Query) {
func RunQuery(t *testing.T, conn *sql.Conn, q Query) {
RetryTestRun(t, q.RetryAttempts, func(t require.TestingT) {
RunQueryAttempt(t, conn, q)
})
@@ -368,36 +377,23 @@ func RunQueryAttempt(t require.TestingT, conn *sql.Conn, q Query) {
}
if q.Query != "" {
rows, err := conn.QueryContext(context.Background(), q.Query, args...)
if err == nil {
defer rows.Close()
}
if q.ErrorMatch != "" {
require.Error(t, err)
require.Regexp(t, q.ErrorMatch, err.Error())
return
}
require.NoError(t, err)
defer rows.Close()
cols, err := rows.Columns()
require.NoError(t, err)
require.Equal(t, q.Result.Columns, cols)
for _, r := range q.Result.Rows {
require.True(t, rows.Next())
scanned := make([]any, len(r))
for j := range scanned {
scanned[j] = new(sql.NullString)
}
require.NoError(t, rows.Scan(scanned...))
printed := make([]string, len(r))
for j := range scanned {
s := scanned[j].(*sql.NullString)
if !s.Valid {
printed[j] = "NULL"
} else {
printed[j] = s.String
}
}
require.Equal(t, r, printed)
}
require.False(t, rows.Next())
require.NoError(t, rows.Err())
rowstrings, err := RowsToStrings(len(cols), rows)
require.NoError(t, err)
require.Equal(t, q.Result.Rows, rowstrings)
} else if q.Exec != "" {
_, err := conn.ExecContext(context.Background(), q.Exec, args...)
if q.ErrorMatch == "" {
@@ -408,3 +404,28 @@ func RunQueryAttempt(t require.TestingT, conn *sql.Conn, q Query) {
}
}
}
func RowsToStrings(cols int, rows *sql.Rows) ([][]string, error) {
ret := make([][]string, 0)
for rows.Next() {
scanned := make([]any, cols)
for j := range scanned {
scanned[j] = new(sql.NullString)
}
err := rows.Scan(scanned...)
if err != nil {
return nil, err
}
printed := make([]string, cols)
for j := range scanned {
s := scanned[j].(*sql.NullString)
if !s.Valid {
printed[j] = "NULL"
} else {
printed[j] = s.String
}
}
ret = append(ret, printed)
}
return ret, rows.Err()
}