Files
dolt/integration-tests/go-sql-server-driver/cmd.go

282 lines
5.9 KiB
Go

// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"bytes"
"database/sql"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"sync"
"time"
_ "github.com/go-sql-driver/mysql"
)
// DoltPath is the resolved location of the dolt binary. It is assigned once,
// by this file's init function.
var DoltPath string

const (
	// TestUserName is the user.name written to the dolt global config by
	// NewDoltUser.
	TestUserName = "Bats Tests"

	// TestEmailAddress is the user.email written to the dolt global config
	// by NewDoltUser.
	TestEmailAddress = "bats@email.fake"

	// ConnectAttempts is the maximum number of pings SqlServer.DB issues
	// while waiting for the server to accept connections.
	ConnectAttempts = 50

	// RetrySleepDuration is how long SqlServer.DB sleeps after a failed
	// ping before trying again.
	RetrySleepDuration = 50 * time.Millisecond
)
// init resolves the dolt binary used by every command built in this file and
// stores the result in DoltPath. The candidate is taken from the
// DOLT_BIN_PATH environment variable, or "dolt" when that variable is unset
// or empty. It is normalized with filepath.Clean and then resolved with
// exec.LookPath; a failed lookup panics.
func init() {
	candidate := "dolt"
	if fromEnv := os.Getenv("DOLT_BIN_PATH"); fromEnv != "" {
		candidate = fromEnv
	}

	var err error
	if DoltPath, err = exec.LookPath(filepath.Clean(candidate)); err != nil {
		panic(fmt.Sprintf("did not find dolt binary: %v", err.Error()))
	}
}
// DoltUser is an abstraction for a user account that calls `dolt` CLI
// commands. All of our dolt binary invocations are done through DoltUser.
//
// For our purposes, it does the following:
//   - owns a tmpdir, which it passes as DOLT_ROOT_PATH (and uses as the
//     working directory) when invoking dolt;
//   - carries some initial dolt global config, written by NewDoltUser:
//     user.name, user.email and metrics.disabled = true;
//   - can create repo stores, each of which is a tmpdir under the user's
//     tmpdir used to store a repo and/or subrepos.
type DoltUser struct {
// tmpdir is the directory created for this user by NewDoltUser.
tmpdir string
}
// NewDoltUser creates a DoltUser backed by a fresh temporary directory and
// seeds its dolt global config with metrics.disabled, user.name and
// user.email (in that order) by running `dolt config --global --add`.
//
// If any config command fails, the temporary directory is removed again so a
// failed setup does not leave it behind, and the returned error names the
// config key that could not be set. On error the returned DoltUser is the
// zero value.
func NewDoltUser() (DoltUser, error) {
	tmpdir, err := os.MkdirTemp("", "go-sql-server-dirver-")
	if err != nil {
		return DoltUser{}, err
	}
	res := DoltUser{tmpdir}

	settings := [][2]string{
		{"metrics.disabled", "true"},
		{"user.name", TestUserName},
		{"user.email", TestEmailAddress},
	}
	for _, kv := range settings {
		if err := res.DoltExec("config", "--global", "--add", kv[0], kv[1]); err != nil {
			// Best-effort cleanup: the config failure is the error worth
			// reporting, so a secondary removal error is dropped.
			_ = os.RemoveAll(tmpdir)
			return DoltUser{}, fmt.Errorf("setting dolt global config %s: %w", kv[0], err)
		}
	}
	return res, nil
}
// DoltCmd builds, but does not start, an invocation of the dolt binary at
// DoltPath with the given arguments. The command runs in the user's tmpdir
// and inherits the current process environment with DOLT_ROOT_PATH appended,
// pointing at that same tmpdir.
func (u DoltUser) DoltCmd(args ...string) *exec.Cmd {
	env := append(os.Environ(), "DOLT_ROOT_PATH="+u.tmpdir)

	c := exec.Command(DoltPath, args...)
	c.Env = env
	c.Dir = u.tmpdir
	return c
}
// DoltExec runs dolt with the given arguments as this user (see DoltCmd) and
// waits for it to finish, returning the error from running the command.
func (u DoltUser) DoltExec(args ...string) error {
	return u.DoltCmd(args...).Run()
}
// MakeRepoStore creates a new "repo-store-" temporary directory inside the
// user's tmpdir and returns a RepoStore rooted there. On failure the zero
// RepoStore is returned together with the error from os.MkdirTemp.
func (u DoltUser) MakeRepoStore() (RepoStore, error) {
	storeDir, err := os.MkdirTemp(u.tmpdir, "repo-store-")
	if err != nil {
		return RepoStore{}, err
	}
	return RepoStore{user: u, dir: storeDir}, nil
}
// RepoStore is a directory, created by DoltUser.MakeRepoStore, in which
// repositories are created as subdirectories (see MakeRepo). Commands built
// by its DoltCmd run with the store directory as their working directory.
type RepoStore struct {
// user is the DoltUser whose environment is used for dolt invocations.
user DoltUser
// dir is the path of the store directory.
dir string
}
// MakeRepo creates the directory name inside the store (mode 0750) and runs
// `dolt init` in it. It returns the initialized Repo, or the zero Repo and
// an error if either the directory creation or the init command fails.
func (rs RepoStore) MakeRepo(name string) (Repo, error) {
	repoDir := filepath.Join(rs.dir, name)
	if err := os.Mkdir(repoDir, 0750); err != nil {
		return Repo{}, err
	}

	repo := Repo{user: rs.user, dir: repoDir}
	if err := repo.DoltExec("init"); err != nil {
		return Repo{}, err
	}
	return repo, nil
}
// DoltCmd builds, but does not start, a dolt invocation for the store's user
// whose working directory is the store directory instead of the user's
// tmpdir.
func (rs RepoStore) DoltCmd(args ...string) *exec.Cmd {
	c := rs.user.DoltCmd(args...)
	// Override the working directory chosen by the user-level builder.
	c.Dir = rs.dir
	return c
}
// Repo is a directory in which dolt commands are run on behalf of a
// DoltUser. RepoStore.MakeRepo creates one and runs `dolt init` in it.
type Repo struct {
// user is the DoltUser whose environment is used for dolt invocations.
user DoltUser
// dir is the path of the repository directory.
dir string
}
// DoltCmd builds, but does not start, a dolt invocation for the repo's user
// whose working directory is the repository directory.
func (r Repo) DoltCmd(args ...string) *exec.Cmd {
	c := r.user.DoltCmd(args...)
	// Override the working directory chosen by the user-level builder.
	c.Dir = r.dir
	return c
}
// DoltExec runs dolt with the given arguments inside the repository
// directory and waits for it to finish. It returns the error from starting
// the command, or otherwise the error from waiting on it.
func (r Repo) DoltExec(args ...string) error {
	// Run is exactly Start followed by Wait.
	return r.DoltCmd(args...).Run()
}
// CreateRemote runs `dolt remote add <name> <url>` in the repository and
// returns the error, if any, from running that command.
func (r Repo) CreateRemote(name, url string) error {
	return r.DoltExec("remote", "add", name, url)
}
// SqlServer is a handle on a `dolt sql-server` child process created by
// StartSqlServer, together with its captured output and the information
// needed to connect to it or start it again.
type SqlServer struct {
// Done is closed once the goroutine copying the process's output has
// finished. Restart replaces it with a new channel.
Done chan struct{}
// Cmd is the command for the current server process. Restart replaces it.
Cmd *exec.Cmd
// Port is the TCP port DB connects to on 127.0.0.1. StartSqlServer
// defaults it to 3306.
Port int
// Output accumulates the process's stdout and stderr, which are also
// echoed to this process's stdout.
Output *bytes.Buffer
// DBName is the database name placed in the DSN built by DB.
DBName string
// RecreateCmd builds a new, unstarted dolt command from the given
// arguments. Restart uses it to create the replacement process.
RecreateCmd func(args ...string) *exec.Cmd
}
// SqlServerOpt is a functional option for StartSqlServer. Each option is
// called with the SqlServer after it has been populated with defaults and
// before its command is started.
type SqlServerOpt func(s *SqlServer)
// WithArgs returns an option that appends args to the argument list of the
// server's command. The option dereferences the server's Cmd, so it must be
// applied to a SqlServer whose Cmd is already set.
func WithArgs(args ...string) SqlServerOpt {
	appendArgs := func(srv *SqlServer) {
		srv.Cmd.Args = append(srv.Cmd.Args, args...)
	}
	return appendArgs
}
// WithPort returns an option that records port in the server's Port field.
// The option only sets that field; it does not add any argument to the
// server's command line.
func WithPort(port int) SqlServerOpt {
	setPort := func(srv *SqlServer) {
		srv.Port = port
	}
	return setPort
}
// DoltCmdable is anything that can build a dolt command from a list of
// arguments. StartSqlServer accepts it to decide where and as whom the
// server runs.
type DoltCmdable interface {
// DoltCmd returns a command that invokes dolt with the given arguments.
DoltCmd(...string) *exec.Cmd
}
// StartSqlServer starts `dolt sql-server` using a command built by dc and
// returns a handle to it.
//
// The process's stdout and stderr share a single pipe; everything read from
// it is echoed to this process's stdout and accumulated in the returned
// server's Output buffer. Done is closed once that copy has finished.
//
// The handle defaults to Port 3306. Each opt is applied, in order, after the
// defaults are set and before the command is started. RecreateCmd is bound
// to dc so that the server can later be started again with the same
// environment. If the command fails to start, the error is returned and the
// handle is discarded.
func StartSqlServer(dc DoltCmdable, opts ...SqlServerOpt) (*SqlServer, error) {
	serverCmd := dc.DoltCmd("sql-server")
	pipe, err := serverCmd.StdoutPipe()
	if err != nil {
		return nil, err
	}
	// StdoutPipe has installed the pipe's write end as Stdout; point Stderr
	// at the same writer so both streams arrive on one pipe.
	serverCmd.Stderr = serverCmd.Stdout

	captured := new(bytes.Buffer)
	finished := make(chan struct{})

	var copiers sync.WaitGroup
	copiers.Add(1)
	go func() {
		defer copiers.Done()
		io.Copy(io.MultiWriter(os.Stdout, captured), pipe)
	}()
	go func() {
		defer close(finished)
		copiers.Wait()
	}()

	server := &SqlServer{
		Done:        finished,
		Cmd:         serverCmd,
		Port:        3306,
		Output:      captured,
		RecreateCmd: dc.DoltCmd,
	}
	for _, apply := range opts {
		apply(server)
	}

	if err := server.Cmd.Start(); err != nil {
		return nil, err
	}
	return server, nil
}
// StartSqlServer starts a `dolt sql-server` whose commands are built by this
// repo, so it runs in the repository directory. It is a convenience wrapper
// around the package-level StartSqlServer.
func (r Repo) StartSqlServer(opts ...SqlServerOpt) (*SqlServer, error) {
	var target DoltCmdable = r
	return StartSqlServer(target, opts...)
}
// ErrorStop blocks until Done is closed, i.e. until the server's output has
// been fully copied, and then returns the result of waiting on the command.
// It does not itself signal the process to stop.
func (s *SqlServer) ErrorStop() error {
	// Drain the output pipe before Wait: Wait closes the pipe, so calling
	// it while the copy is still reading could lose output.
	<-s.Done
	waitErr := s.Cmd.Wait()
	return waitErr
}
// Restart stops the server with GracefulStop and starts a new process in
// its place.
//
// When newargs is nil the new process reuses the previous command's
// arguments (everything after the program name). Otherwise it is started as
// `sql-server` followed by *newargs. The new command is built with
// RecreateCmd and stored in s.Cmd; its stdout and stderr share one pipe that
// is echoed to this process's stdout and appended to the existing s.Output
// buffer. s.Done is replaced with a new channel that is closed once that
// copy has finished.
//
// It returns the error from GracefulStop, from creating the pipe, or from
// starting the new process.
func (s *SqlServer) Restart(newargs *[]string) error {
	err := s.GracefulStop()
	if err != nil {
		return err
	}

	args := s.Cmd.Args[1:]
	if newargs != nil {
		args = append([]string{"sql-server"}, (*newargs)...)
	}

	s.Cmd = s.RecreateCmd(args...)
	stdout, err := s.Cmd.StdoutPipe()
	if err != nil {
		return err
	}
	// StdoutPipe has installed the pipe's write end as Stdout; point Stderr
	// at the same writer so both streams arrive on one pipe.
	s.Cmd.Stderr = s.Cmd.Stdout

	// Capture the buffer and the channel in locals so the goroutines below
	// operate on the values belonging to this process, even if the fields
	// on s are reassigned before the process exits.
	output := s.Output
	done := make(chan struct{})
	s.Done = done

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		io.Copy(io.MultiWriter(os.Stdout, output), stdout)
	}()
	go func() {
		wg.Wait()
		close(done)
	}()

	return s.Cmd.Start()
}
// DB opens a connection pool to the server using the "mysql" driver and
// waits for the server to accept connections.
//
// The connection is made as user "root" with no password to 127.0.0.1 on
// s.Port, selecting the database s.DBName, with allowAllFiles=true and
// tls=preferred. The pool is pinged up to ConnectAttempts times, sleeping
// RetrySleepDuration after each failure, and is returned as soon as a ping
// succeeds.
//
// If no ping succeeds, the pool is closed so that it is not leaked, and the
// error from the last ping is returned unchanged.
func (s *SqlServer) DB() (*sql.DB, error) {
	authority := "root"
	location := fmt.Sprintf("tcp(127.0.0.1:%d)", s.Port)
	dsn := fmt.Sprintf("%s@%s/%s?allowAllFiles=true&tls=preferred", authority, location, s.DBName)

	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, err
	}

	for i := 0; i < ConnectAttempts; i++ {
		if err = db.Ping(); err == nil {
			return db, nil
		}
		time.Sleep(RetrySleepDuration)
	}
	if err != nil {
		// The caller never sees this pool, so release it here. The ping
		// error is the one worth reporting; a Close error is dropped.
		_ = db.Close()
		return nil, err
	}
	return db, nil
}