From e6d0c8be37e2766a869edbf54b6117f5c73316f0 Mon Sep 17 00:00:00 2001 From: Neil Macneale IV Date: Wed, 1 Oct 2025 19:38:23 +0000 Subject: [PATCH] Support boolean values in dolt_test_run() --- .../env/actions/test_table_helpers.go | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/go/libraries/doltcore/env/actions/test_table_helpers.go b/go/libraries/doltcore/env/actions/test_table_helpers.go index b1f0ab885b..dc72230056 100644 --- a/go/libraries/doltcore/env/actions/test_table_helpers.go +++ b/go/libraries/doltcore/env/actions/test_table_helpers.go @@ -81,6 +81,18 @@ func expectSingleValue(sqlCtx *sql.Context, comparison string, value *string, qu return compareNullValue(comparison, row[0], AssertionExpectedSingleValue), nil } + // Check if the expected value is a boolean string, and if so, coerce the actual value to boolean, with the exception + // of "0" and "1", which are valid integers and are covered below. + if *value != "0" && *value != "1" { + if expectedBool, err := strconv.ParseBool(*value); err == nil { + actualBool, boolErr := getInterfaceAsBool(row[0]) + if boolErr != nil { + return fmt.Sprintf("Could not convert value to boolean: %v", boolErr), nil + } + return compareBooleans(comparison, expectedBool, actualBool, AssertionExpectedSingleValue), nil + } + } + switch actualValue := row[0].(type) { case int8: expectedInt, err := strconv.ParseInt(*value, 10, 64) @@ -346,6 +358,59 @@ func compareDecimals(comparison string, expectedValue, realValue decimal.Decimal return "" } +// getTinyIntColAsBool returns the value interface{} as a bool +// This is necessary because the query engine may return a tinyint column as a bool, int, or other types. +// Based on GetTinyIntColAsBool from commands/utils.go, which we can't depend on here due to package cycles. +func getInterfaceAsBool(col interface{}) (bool, error) { + switch v := col.(type) { + case bool: + return v, nil + case int: + return v == 1, nil + case int8: + return v == 1, nil + case int16: + return v == 1, nil + case int32: + return v == 1, nil + case int64: + return v == 1, nil + case uint: + return v == 1, nil + case uint8: + return v == 1, nil + case uint16: + return v == 1, nil + case uint32: + return v == 1, nil + case uint64: + return v == 1, nil + case string: + return v == "1", nil + default: + return false, fmt.Errorf("unexpected type %T, was expecting bool, int, or string", v) + } +} + +// compareBooleans is a function used for comparing boolean values. +// It takes in a comparison string from one of: "==", "!=" +// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise +func compareBooleans(comparison string, expectedValue, realValue bool, assertionType string) string { + switch comparison { + case "==": + if expectedValue != realValue { + return fmt.Sprintf("Assertion failed: %s equal to %t, got %t", assertionType, expectedValue, realValue) + } + case "!=": + if expectedValue == realValue { + return fmt.Sprintf("Assertion failed: %s not equal to %t, got %t", assertionType, expectedValue, realValue) + } + default: + return fmt.Sprintf("%s is not a valid comparison for boolean values. Only '==' and '!=' are supported", comparison) + } + return "" +} + // compareNullValue is a function used for comparing a null value. // It takes in a comparison string from one of: "==", "!=" // It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise