mirror of
https://github.com/dolthub/dolt.git
synced 2026-04-23 05:13:00 -05:00
fix dolt log pager panics with ctrl+c on Windows (#2789)
This commit is contained in:
@@ -16,13 +16,19 @@ package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/osutil"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
"github.com/dolthub/dolt/go/store/util/outputpager"
|
||||
)
|
||||
|
||||
func TestLog(t *testing.T) {
|
||||
@@ -38,3 +44,46 @@ func TestLog(t *testing.T) {
|
||||
meta, _ := commit.GetCommitMeta()
|
||||
require.Equal(t, "Bill Billerson", meta.Name)
|
||||
}
|
||||
|
||||
func TestLogSigterm(t *testing.T) {
|
||||
if osutil.IsWindows {
|
||||
t.Skip("Skipping test as function used is not supported on Windows")
|
||||
}
|
||||
|
||||
dEnv := createUninitializedEnv()
|
||||
err := dEnv.InitRepo(context.Background(), types.Format_Default, "Bill Billerson", "bigbillieb@fake.horse", env.DefaultInitBranch)
|
||||
|
||||
if err != nil {
|
||||
t.Error("Failed to init repo")
|
||||
}
|
||||
|
||||
cs, _ := doltdb.NewCommitSpec(env.DefaultInitBranch)
|
||||
commit, _ := dEnv.DoltDB.Resolve(context.Background(), cs, nil)
|
||||
cMeta, _ := commit.GetCommitMeta()
|
||||
cHash, _ := commit.HashOf()
|
||||
|
||||
outputpager.SetTestingArg(true)
|
||||
defer outputpager.SetTestingArg(false)
|
||||
|
||||
pager := outputpager.Start()
|
||||
defer pager.Stop()
|
||||
|
||||
chStr := cHash.String()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
pager.Writer.Write([]byte(fmt.Sprintf("\033[1;33mcommit %s \033[0m", chStr)))
|
||||
pager.Writer.Write([]byte(fmt.Sprintf("\nAuthor: %s <%s>", cMeta.Name, cMeta.Email)))
|
||||
|
||||
timeStr := cMeta.FormatTS()
|
||||
pager.Writer.Write([]byte(fmt.Sprintf("\nDate: %s", timeStr)))
|
||||
|
||||
formattedDesc := "\n\n\t" + strings.Replace(cMeta.Description, "\n", "\n\t", -1) + "\n\n"
|
||||
pager.Writer.Write([]byte(fmt.Sprintf(formattedDesc)))
|
||||
}
|
||||
|
||||
process, err := os.FindProcess(syscall.Getpid())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = process.Signal(syscall.SIGTERM)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -1338,6 +1338,7 @@ func (t *AlterableDoltTable) CreateForeignKey(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if fkChecks.(int8) == 1 {
|
||||
root, foreignKey, err = creation.ResolveForeignKey(ctx, root, table, foreignKey, t.opts)
|
||||
if err != nil {
|
||||
|
||||
@@ -22,10 +22,13 @@
|
||||
package outputpager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
flag "github.com/juju/gnuflag"
|
||||
goisatty "github.com/mattn/go-isatty"
|
||||
@@ -35,6 +38,7 @@ import (
|
||||
|
||||
var (
|
||||
noPager bool
|
||||
testing = false
|
||||
)
|
||||
|
||||
type Pager struct {
|
||||
@@ -45,8 +49,12 @@ type Pager struct {
|
||||
}
|
||||
|
||||
func Start() *Pager {
|
||||
if noPager || !IsStdoutTty() {
|
||||
return &Pager{os.Stdout, nil, nil, nil, nil}
|
||||
// `testing` is set to true only to test this function because when testing this function, stdout is not Terminal.
|
||||
// otherwise, it must be always false.
|
||||
if !testing {
|
||||
if noPager || !IsStdoutTty() {
|
||||
return &Pager{os.Stdout, nil, nil, nil, nil}
|
||||
}
|
||||
}
|
||||
|
||||
var lessPath string
|
||||
@@ -75,9 +83,28 @@ func Start() *Pager {
|
||||
cmd.Start()
|
||||
|
||||
p := &Pager{stdout, stdin, stdout, &sync.Mutex{}, make(chan struct{})}
|
||||
|
||||
interruptChannel := make(chan os.Signal, 1)
|
||||
signal.Notify(interruptChannel, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-interruptChannel:
|
||||
if ok {
|
||||
p.closePipe()
|
||||
p.doneCh <- struct{}{}
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
d.Chk.NoError(err)
|
||||
if err != nil {
|
||||
fmt.Printf("error occurred during exit: %s ", err)
|
||||
}
|
||||
p.closePipe()
|
||||
p.doneCh <- struct{}{}
|
||||
}()
|
||||
@@ -110,3 +137,7 @@ func RegisterOutputpagerFlags(flags *flag.FlagSet) {
|
||||
func IsStdoutTty() bool {
|
||||
return goisatty.IsTerminal(os.Stdout.Fd())
|
||||
}
|
||||
|
||||
func SetTestingArg(s bool) {
|
||||
testing = s
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user