Merge branch 'main' into james/index

This commit is contained in:
James Cor
2023-05-09 09:39:13 -07:00
3 changed files with 99 additions and 22 deletions

View File

@@ -778,7 +778,7 @@ func NameAndTypeTransform(row sql.Row, rowOperationSchema sql.PrimaryKeySchema,
// For non string types we want empty strings to be converted to nils. String types should be allowed to take on
// an empty string value
switch col.Type.(type) {
case sql.StringType:
case sql.StringType, sql.EnumType, sql.SetType:
default:
row[i] = emptyStringToNil(row[i])
}

View File

@@ -43,12 +43,13 @@ var batseeDoc = cli.CommandDocumentationContent{
Output for each test is written to a file in the batsee_output directory.
Example: batsee -t 42 --max-time 1h15m -r 2 --only types.bats,foreign-keys.bats`,
Synopsis: []string{
`[-t {{.LessThan}}threads{{.GreaterThan}}] [--skip-slow] [--max-time {{.LessThan}}time{{.GreaterThan}}] [--retries {{.LessThan}}retries{{.GreaterThan}}] [--only test1,test2,...]`,
`[-t threads] [-o dir] [--skip-slow] [--max-time time] [--only test1,test2,...]`,
},
}
const (
threadsFlag = "threads"
outputDir = "output"
skipSlowFlag = "skip-slow"
maxTimeFlag = "max-time"
onlyFLag = "only"
@@ -57,7 +58,8 @@ const (
func buildArgParser() *argparser.ArgParser {
ap := argparser.NewArgParserWithMaxArgs("batsee", 0)
ap.SupportsUint(threadsFlag, "t", "threads", "Number of tests to execute in parallel. Defaults to 12")
ap.SupportsInt(threadsFlag, "t", "threads", "Number of tests to execute in parallel. Defaults to 12")
ap.SupportsString(outputDir, "o", "directory", "Directory to write output to. Defaults to 'batsee_results'")
ap.SupportsFlag(skipSlowFlag, "s", "Skip slow tests. This is a static list of test we know are slow, may grow stale.")
ap.SupportsString(maxTimeFlag, "", "duration", "Maximum time to run tests. Defaults to 30m")
ap.SupportsString(onlyFLag, "", "", "Only run the specified test, or tests (comma separated)")
@@ -87,17 +89,26 @@ var slowCommands = map[string]bool{
"remotes.bats": true,
}
func main() {
ap := buildArgParser()
help, _ := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString("batsee", batseeDoc, ap))
args := os.Args[1:]
apr := cli.ParseArgsOrDie(ap, args, help)
type config struct {
threads int
output string
duration time.Duration
skipSlow bool
limitTo map[string]bool
retries int
}
threads, hasThreads := apr.GetUint(threadsFlag)
func buildConfig(apr *argparser.ArgParseResults) config {
threads, hasThreads := apr.GetInt(threadsFlag)
if !hasThreads {
threads = 12
}
output, hasOutput := apr.GetValue(outputDir)
if !hasOutput {
output = "batsee_results"
}
durationInput, hasDuration := apr.GetValue(maxTimeFlag)
if !hasDuration {
durationInput = "30m"
@@ -124,6 +135,24 @@ func main() {
retries = 1
}
return config{
threads: threads,
output: output,
duration: duration,
skipSlow: skipSlow,
limitTo: limitTo,
retries: retries,
}
}
func main() {
ap := buildArgParser()
help, _ := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString("batsee", batseeDoc, ap))
args := os.Args[1:]
apr := cli.ParseArgsOrDie(ap, args, help)
config := buildConfig(apr)
startTime := time.Now()
cwd, err := os.Getwd()
@@ -147,7 +176,7 @@ func main() {
workQueue := []string{}
// Insert the slow tests first
for key, _ := range slowCommands {
if !skipSlow {
if !config.skipSlow {
workQueue = append(workQueue, key)
}
}
@@ -166,15 +195,15 @@ func main() {
ctx := context.Background()
ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
defer stop()
ctx, cancel := context.WithTimeout(ctx, duration)
ctx, cancel := context.WithTimeout(ctx, config.duration)
defer cancel()
var wg sync.WaitGroup
for i := uint64(0); i < threads; i++ {
for i := 0; i < config.threads; i++ {
go func() {
wg.Add(1)
defer wg.Done()
worker(jobs, retries, results, ctx, limitTo)
worker(jobs, results, ctx, config)
}()
}
@@ -183,7 +212,7 @@ func main() {
}
close(jobs)
cli.Println(fmt.Sprintf("Waiting for workers (%d) to finish", threads))
cli.Println(fmt.Sprintf("Waiting for workers (%d) to finish", config.threads))
comprehensiveWait(ctx, &wg)
close(results)
@@ -261,24 +290,24 @@ func durationStr(duration time.Duration) string {
return fmt.Sprintf("%02d:%02d", int(duration.Minutes()), int(duration.Seconds())%60)
}
func worker(jobs <-chan string, retries int, results chan<- batsResult, ctx context.Context, limitTo map[string]bool) {
func worker(jobs <-chan string, results chan<- batsResult, ctx context.Context, config config) {
for job := range jobs {
runBats(job, retries, results, ctx, limitTo)
runBats(job, results, ctx, config)
}
}
// runBats runs a single bats test and sends the result to the results channel. Stdout and stderr are written to files
// in the batsee_results directory in the CWD, and the error is written to the result.err field.
func runBats(path string, retries int, resultChan chan<- batsResult, ctx context.Context, limitTo map[string]bool) {
func runBats(path string, resultChan chan<- batsResult, ctx context.Context, cfg config) {
cmd := exec.CommandContext(ctx, "bats", path)
// Set the process group ID so that we can kill the entire process tree if it runs too long. We need to differenciate
// process group of the sub process from this one, because kill the primary process if we don't.
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
cmd.Env = append(os.Environ(), fmt.Sprintf("DOLT_TEST_RETRIES=%d", retries))
cmd.Env = append(os.Environ(), fmt.Sprintf("DOLT_TEST_RETRIES=%d", cfg.retries))
result := batsResult{path: path}
if limitTo != nil && len(limitTo) != 0 && !limitTo[path] {
if cfg.limitTo != nil && len(cfg.limitTo) != 0 && !cfg.limitTo[path] {
result.skipped = true
resultChan <- result
return
@@ -286,7 +315,7 @@ func runBats(path string, retries int, resultChan chan<- batsResult, ctx context
startTime := time.Now()
outPath := fmt.Sprintf("batsee_results/%s.stdout.log", path)
outPath := fmt.Sprintf("%s/%s.stdout.log", cfg.output, path)
output, err := os.Create(outPath)
if err != nil {
cli.Println("Error creating stdout log:", err.Error())
@@ -299,7 +328,7 @@ func runBats(path string, retries int, resultChan chan<- batsResult, ctx context
result.err = err
}
errPath := fmt.Sprintf("batsee_results/%s.stderr.log", path)
errPath := fmt.Sprintf("%s/%s.stderr.log", cfg.output, path)
errput, err := os.Create(errPath)
if err != nil {
cli.Println("Error creating stderr log:", err.Error())

View File

@@ -1315,4 +1315,52 @@ v1,v2
DELIM
dolt table import -u tbl auto-increment.csv
}
}
@test "import-update-tables: distinguish between empty string and null for ENUMs" {
dolt sql <<SQL
create table alphabet(pk int primary key, letter enum('', 'a', 'b'));
SQL
dolt commit -Am "add a table"
expected=$(cat <<DELIM
pk,letter
1,a
2,""
3,
DELIM
)
echo "$expected" > data.csv
run dolt table import -u alphabet data.csv
[ $status -eq 0 ]
[[ "$output" =~ "Rows Processed: 3, Additions: 3, Modifications: 0, Had No Effect: 0" ]] || false
run dolt sql -r csv -q "select * from alphabet;"
[ $status -eq 0 ]
[[ "$output" = "$expected" ]] || false
}
@test "import-update-tables: distinguish between empty string and null for SETs" {
dolt sql <<SQL
create table word(pk int primary key, letters set('', 'a', 'b'));
SQL
dolt commit -Am "add a table"
expected=$(cat <<DELIM
pk,letters
1,"a,b"
2,a
3,""
4,
DELIM
)
echo "$expected" > word_data.csv
run dolt table import -u word word_data.csv
[ $status -eq 0 ]
[[ "$output" =~ "Rows Processed: 4, Additions: 4, Modifications: 0, Had No Effect: 0" ]] || false
run dolt sql -r csv -q "select * from word;"
[ $status -eq 0 ]
[[ "$output" = "$expected" ]] || false
}