Added a debug port option to launch dlv, and fixed problems with graceful shutdown on windows

This commit is contained in:
Zach Musgrave
2023-05-19 18:38:38 -07:00
parent b4c7e26e04
commit a46ebe1f9f
4 changed files with 80 additions and 53 deletions

View File

@@ -25,12 +25,14 @@ import (
"os/exec"
"path/filepath"
"sync"
"syscall"
"time"
_ "github.com/go-sql-driver/mysql"
)
var DoltPath string
var DelvePath string
const TestUserName = "Bats Tests"
const TestEmailAddress = "bats@email.fake"
@@ -45,10 +47,13 @@ func init() {
}
path = filepath.Clean(path)
var err error
DoltPath, err = exec.LookPath(path)
if err != nil {
log.Printf("did not find dolt binary: %v\n", err.Error())
}
DelvePath, _ = exec.LookPath("dlv")
}
// DoltUser is an abstraction for a user account that calls `dolt` CLI
@@ -66,8 +71,11 @@ type DoltUser struct {
tmpdir string
}
var _ DoltCmdable = DoltUser{}
var _ DoltDebuggable = DoltUser{}
func NewDoltUser() (DoltUser, error) {
tmpdir, err := os.MkdirTemp("", "go-sql-server-dirver-")
tmpdir, err := os.MkdirTemp("", "go-sql-server-driver-")
if err != nil {
return DoltUser{}, err
}
@@ -91,9 +99,33 @@ func (u DoltUser) DoltCmd(args ...string) *exec.Cmd {
cmd := exec.Command(DoltPath, args...)
cmd.Dir = u.tmpdir
cmd.Env = append(os.Environ(), "DOLT_ROOT_PATH="+u.tmpdir)
// TODO: only on windows
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
}
return cmd
}
func (u DoltUser) DoltDebug(debuggerPort int, args ...string) *exec.Cmd {
if DelvePath != "" {
dlvArgs := []string {
fmt.Sprintf("--listen=:%d", debuggerPort),
"--headless",
"--api-version=2",
"--accept-multiclient",
"exec",
DoltPath,
"--",
}
cmd := exec.Command(DelvePath, append(dlvArgs, args...)...)
cmd.Dir = u.tmpdir
cmd.Env = append(os.Environ(), "DOLT_ROOT_PATH="+u.tmpdir)
return cmd
} else {
panic("dlv not found")
}
}
func (u DoltUser) DoltExec(args ...string) error {
cmd := u.DoltCmd(args...)
return cmd.Run()
@@ -116,6 +148,9 @@ type RepoStore struct {
Dir string
}
var _ DoltCmdable = RepoStore{}
var _ DoltDebuggable = RepoStore{}
func (rs RepoStore) MakeRepo(name string) (Repo, error) {
path := filepath.Join(rs.Dir, name)
err := os.Mkdir(path, 0750)
@@ -136,6 +171,12 @@ func (rs RepoStore) DoltCmd(args ...string) *exec.Cmd {
return cmd
}
func (rs RepoStore) DoltDebug(debuggerPort int, args ...string) *exec.Cmd {
cmd := rs.user.DoltDebug(debuggerPort, args...)
cmd.Dir = rs.Dir
return cmd
}
type Repo struct {
user DoltUser
Dir string
@@ -191,11 +232,29 @@ func WithPort(port int) SqlServerOpt {
}
type DoltCmdable interface {
DoltCmd(...string) *exec.Cmd
DoltCmd(args ...string) *exec.Cmd
}
type DoltDebuggable interface {
DoltDebug(debuggerPort int, args ...string) *exec.Cmd
}
func StartSqlServer(dc DoltCmdable, opts ...SqlServerOpt) (*SqlServer, error) {
cmd := dc.DoltCmd("sql-server")
return runSqlServerCommand(dc, opts, cmd)
}
func DebugSqlServer(dc DoltCmdable, debuggerPort int, opts ...SqlServerOpt) (*SqlServer, error) {
ddb, ok := dc.(DoltDebuggable)
if !ok {
return nil, fmt.Errorf("%T does not implement DoltDebuggable", dc)
}
cmd := ddb.DoltDebug(debuggerPort, "sql-server")
return runSqlServerCommand(dc, opts, cmd)
}
func runSqlServerCommand(dc DoltCmdable, opts []SqlServerOpt, cmd *exec.Cmd) (*SqlServer, error) {
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, err
@@ -232,10 +291,6 @@ func StartSqlServer(dc DoltCmdable, opts ...SqlServerOpt) (*SqlServer, error) {
return ret, nil
}
func (r Repo) StartSqlServer(opts ...SqlServerOpt) (*SqlServer, error) {
return StartSqlServer(r, opts...)
}
func (s *SqlServer) ErrorStop() error {
<-s.Done
return s.Cmd.Wait()

View File

@@ -15,61 +15,21 @@
package sql_server_driver
import (
"syscall"
"golang.org/x/sys/windows"
)
func (s *SqlServer) GracefulStop() error {
dll, err := windows.LoadDLL("kernel32.dll")
err := windows.GenerateConsoleCtrlEvent(windows.CTRL_BREAK_EVENT, uint32(s.Cmd.Process.Pid))
if err != nil {
return err
}
defer dll.Release()
pid := s.Cmd.Process.Pid
f, err := dll.FindProc("AttachConsole")
if err != nil {
return err
}
r1, _, err := f.Call(uintptr(pid))
if r1 == 0 && err != syscall.ERROR_ACCESS_DENIED {
return err
}
set, err := dll.FindProc("SetConsoleCtrlHandler")
if err != nil {
return err
}
r1, _, err = set.Call(0, 1)
if r1 == 0 {
return err
}
f, err = dll.FindProc("GenerateConsoleCtrlEvent")
if err != nil {
return err
}
r1, _, err = f.Call(windows.CTRL_BREAK_EVENT, uintptr(pid))
if r1 == 0 {
return err
}
f, err = dll.FindProc("FreeConsole")
if err != nil {
return err
}
_, _, err = f.Call()
if err != nil {
return err
}
<-s.Done
r1, _, err = set.Call(0, 0)
if r1 == 0 {
_, err = s.Cmd.Process.Wait()
if err != nil {
return err
}
return s.Cmd.Wait()
return nil
}

View File

@@ -162,6 +162,10 @@ type Server struct {
// the |Args| to make sure this is true. Defaults to 3308.
Port int `yaml:"port"`
// DebugPort if set to a non-zero value will cause this server to be started with |dlv| listening for a debugger
// connection on the port given.
DebugPort int `yaml:"debug_port"`
// Assertions to be run against the log output of the server process
// after the server process successfully terminates.
LogMatches []string `yaml:"log_matches"`

View File

@@ -22,7 +22,7 @@ import (
"time"
"database/sql"
driver "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/sql_server_driver"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -78,7 +78,15 @@ func MakeServer(t *testing.T, dc driver.DoltCmdable, s *driver.Server) *driver.S
if s.Port != 0 {
opts = append(opts, driver.WithPort(s.Port))
}
server, err := driver.StartSqlServer(dc, opts...)
var server *driver.SqlServer
var err error
if s.DebugPort != 0 {
server, err = driver.DebugSqlServer(dc, s.DebugPort, opts...)
} else {
server, err = driver.StartSqlServer(dc, opts...)
}
require.NoError(t, err)
if len(s.ErrorMatches) > 0 {
err := server.ErrorStop()