diff --git a/go/cmd/dolt/commands/tblcmds/import.go b/go/cmd/dolt/commands/tblcmds/import.go index a1f0284d16..4445ed95ee 100644 --- a/go/cmd/dolt/commands/tblcmds/import.go +++ b/go/cmd/dolt/commands/tblcmds/import.go @@ -778,7 +778,7 @@ func NameAndTypeTransform(row sql.Row, rowOperationSchema sql.PrimaryKeySchema, // For non string types we want empty strings to be converted to nils. String types should be allowed to take on // an empty string value switch col.Type.(type) { - case sql.StringType: + case sql.StringType, sql.EnumType, sql.SetType: default: row[i] = emptyStringToNil(row[i]) } diff --git a/go/utils/batsee/main.go b/go/utils/batsee/main.go index 9918877747..b1e92c47df 100644 --- a/go/utils/batsee/main.go +++ b/go/utils/batsee/main.go @@ -43,12 +43,13 @@ var batseeDoc = cli.CommandDocumentationContent{ Output for each test is written to a file in the batsee_output directory. Example: batsee -t 42 --max-time 1h15m -r 2 --only types.bats,foreign-keys.bats`, Synopsis: []string{ - `[-t {{.LessThan}}threads{{.GreaterThan}}] [--skip-slow] [--max-time {{.LessThan}}time{{.GreaterThan}}] [--retries {{.LessThan}}retries{{.GreaterThan}}] [--only test1,test2,...]`, + `[-t threads] [-o dir] [--skip-slow] [--max-time time] [--only test1,test2,...]`, }, } const ( threadsFlag = "threads" + outputDir = "output" skipSlowFlag = "skip-slow" maxTimeFlag = "max-time" onlyFLag = "only" @@ -57,7 +58,8 @@ const ( func buildArgParser() *argparser.ArgParser { ap := argparser.NewArgParserWithMaxArgs("batsee", 0) - ap.SupportsUint(threadsFlag, "t", "threads", "Number of tests to execute in parallel. Defaults to 12") + ap.SupportsInt(threadsFlag, "t", "threads", "Number of tests to execute in parallel. Defaults to 12") + ap.SupportsString(outputDir, "o", "directory", "Directory to write output to. Defaults to 'batsee_results'") ap.SupportsFlag(skipSlowFlag, "s", "Skip slow tests. This is a static list of test we know are slow, may grow stale.") ap.SupportsString(maxTimeFlag, "", "duration", "Maximum time to run tests. Defaults to 30m") ap.SupportsString(onlyFLag, "", "", "Only run the specified test, or tests (comma separated)") @@ -87,17 +89,26 @@ var slowCommands = map[string]bool{ "remotes.bats": true, } -func main() { - ap := buildArgParser() - help, _ := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString("batsee", batseeDoc, ap)) - args := os.Args[1:] - apr := cli.ParseArgsOrDie(ap, args, help) +type config struct { + threads int + output string + duration time.Duration + skipSlow bool + limitTo map[string]bool + retries int +} - threads, hasThreads := apr.GetUint(threadsFlag) +func buildConfig(apr *argparser.ArgParseResults) config { + threads, hasThreads := apr.GetInt(threadsFlag) if !hasThreads { threads = 12 } + output, hasOutput := apr.GetValue(outputDir) + if !hasOutput { + output = "batsee_results" + } + durationInput, hasDuration := apr.GetValue(maxTimeFlag) if !hasDuration { durationInput = "30m" @@ -124,6 +135,24 @@ func main() { retries = 1 } + return config{ + threads: threads, + output: output, + duration: duration, + skipSlow: skipSlow, + limitTo: limitTo, + retries: retries, + } +} + +func main() { + ap := buildArgParser() + help, _ := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString("batsee", batseeDoc, ap)) + args := os.Args[1:] + apr := cli.ParseArgsOrDie(ap, args, help) + + config := buildConfig(apr) + startTime := time.Now() cwd, err := os.Getwd() @@ -147,7 +176,7 @@ func main() { workQueue := []string{} // Insert the slow tests first for key, _ := range slowCommands { - if !skipSlow { + if !config.skipSlow { workQueue = append(workQueue, key) } } @@ -166,15 +195,15 @@ func main() { ctx := context.Background() ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) defer stop() - ctx, cancel := context.WithTimeout(ctx, duration) + ctx, cancel := context.WithTimeout(ctx, config.duration) defer cancel() var wg sync.WaitGroup - for i := uint64(0); i < threads; i++ { + for i := 0; i < config.threads; i++ { go func() { wg.Add(1) defer wg.Done() - worker(jobs, retries, results, ctx, limitTo) + worker(jobs, results, ctx, config) }() } @@ -183,7 +212,7 @@ func main() { } close(jobs) - cli.Println(fmt.Sprintf("Waiting for workers (%d) to finish", threads)) + cli.Println(fmt.Sprintf("Waiting for workers (%d) to finish", config.threads)) comprehensiveWait(ctx, &wg) close(results) @@ -261,24 +290,24 @@ func durationStr(duration time.Duration) string { return fmt.Sprintf("%02d:%02d", int(duration.Minutes()), int(duration.Seconds())%60) } -func worker(jobs <-chan string, retries int, results chan<- batsResult, ctx context.Context, limitTo map[string]bool) { +func worker(jobs <-chan string, results chan<- batsResult, ctx context.Context, config config) { for job := range jobs { - runBats(job, retries, results, ctx, limitTo) + runBats(job, results, ctx, config) } } // runBats runs a single bats test and sends the result to the results channel. Stdout and stderr are written to files // in the batsee_results directory in the CWD, and the error is written to the result.err field. -func runBats(path string, retries int, resultChan chan<- batsResult, ctx context.Context, limitTo map[string]bool) { +func runBats(path string, resultChan chan<- batsResult, ctx context.Context, cfg config) { cmd := exec.CommandContext(ctx, "bats", path) // Set the process group ID so that we can kill the entire process tree if it runs too long. We need to differenciate // process group of the sub process from this one, because kill the primary process if we don't. cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - cmd.Env = append(os.Environ(), fmt.Sprintf("DOLT_TEST_RETRIES=%d", retries)) + cmd.Env = append(os.Environ(), fmt.Sprintf("DOLT_TEST_RETRIES=%d", cfg.retries)) result := batsResult{path: path} - if limitTo != nil && len(limitTo) != 0 && !limitTo[path] { + if cfg.limitTo != nil && len(cfg.limitTo) != 0 && !cfg.limitTo[path] { result.skipped = true resultChan <- result return @@ -286,7 +315,7 @@ func runBats(path string, retries int, resultChan chan<- batsResult, ctx context startTime := time.Now() - outPath := fmt.Sprintf("batsee_results/%s.stdout.log", path) + outPath := fmt.Sprintf("%s/%s.stdout.log", cfg.output, path) output, err := os.Create(outPath) if err != nil { cli.Println("Error creating stdout log:", err.Error()) @@ -299,7 +328,7 @@ func runBats(path string, retries int, resultChan chan<- batsResult, ctx context result.err = err } - errPath := fmt.Sprintf("batsee_results/%s.stderr.log", path) + errPath := fmt.Sprintf("%s/%s.stderr.log", cfg.output, path) errput, err := os.Create(errPath) if err != nil { cli.Println("Error creating stderr log:", err.Error()) diff --git a/integration-tests/bats/import-update-tables.bats b/integration-tests/bats/import-update-tables.bats index 87c7525d8f..ab059f56d5 100644 --- a/integration-tests/bats/import-update-tables.bats +++ b/integration-tests/bats/import-update-tables.bats @@ -1315,4 +1315,52 @@ v1,v2 DELIM dolt table import -u tbl auto-increment.csv -} \ No newline at end of file +} +@test "import-update-tables: distinguish between empty string and null for ENUMs" { + dolt sql < data.csv + + run dolt table import -u alphabet data.csv + [ $status -eq 0 ] + [[ "$output" =~ "Rows Processed: 3, Additions: 3, Modifications: 0, Had No Effect: 0" ]] || false + + run dolt sql -r csv -q "select * from alphabet;" + [ $status -eq 0 ] + [[ "$output" = "$expected" ]] || false +} + +@test "import-update-tables: distinguish between empty string and null for SETs" { + dolt sql < word_data.csv + + run dolt table import -u word word_data.csv + [ $status -eq 0 ] + [[ "$output" =~ "Rows Processed: 4, Additions: 4, Modifications: 0, Had No Effect: 0" ]] || false + + run dolt sql -r csv -q "select * from word;" + [ $status -eq 0 ] + [[ "$output" = "$expected" ]] || false +}