Merge pull request #9753 from dolthub/nathan/testFixes

Fixes for dolt_tests
This commit is contained in:
Nathan Gabrielson
2025-08-29 15:01:48 -07:00
committed by GitHub
4 changed files with 119 additions and 36 deletions
+66 -29
View File
@@ -39,7 +39,7 @@ const (
// message is a string used to indicate test failures, and will not halt the overall process.
// message will be empty if the test passed.
// err indicates runtime failures and will stop dolt_test_run from proceeding.
func AssertData(sqlCtx *sql.Context, assertion string, comparison string, value string, queryResult *sql.RowIter) (testPassed bool, message string, err error) {
func AssertData(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult *sql.RowIter) (testPassed bool, message string, err error) {
switch assertion {
case AssertionExpectedRows:
message, err = expectRows(sqlCtx, comparison, value, queryResult)
@@ -59,7 +59,7 @@ func AssertData(sqlCtx *sql.Context, assertion string, comparison string, value
return true, "", nil
}
func expectSingleValue(sqlCtx *sql.Context, comparison string, value string, queryResult *sql.RowIter) (message string, err error) {
func expectSingleValue(sqlCtx *sql.Context, comparison string, value *string, queryResult *sql.RowIter) (message string, err error) {
row, err := (*queryResult).Next(sqlCtx)
if err == io.EOF {
return fmt.Sprintf("expected_single_value expects exactly one cell. Received 0 rows"), nil
@@ -70,7 +70,6 @@ func expectSingleValue(sqlCtx *sql.Context, comparison string, value string, que
if len(row) != 1 {
return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple columns"), nil
}
_, err = (*queryResult).Next(sqlCtx)
(*queryResult).Close(sqlCtx)
if err == nil { //If multiple rows were given, we should error out
@@ -79,41 +78,51 @@ func expectSingleValue(sqlCtx *sql.Context, comparison string, value string, que
return "", err
}
if value == nil { // If we're expecting a null value, we don't need to type switch
return compareNullValue(comparison, row[0], AssertionExpectedSingleValue), nil
}
switch actualValue := row[0].(type) {
case int32:
expectedInt, err := strconv.ParseInt(value, 10, 32)
expectedInt, err := strconv.ParseInt(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", value, actualValue), nil
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, int32(expectedInt), actualValue, AssertionExpectedSingleValue), nil
case uint32:
expectedUint, err := strconv.ParseUint(value, 10, 32)
case int64:
expectedInt, err := strconv.ParseInt(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", value, actualValue), nil
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, expectedInt, actualValue, AssertionExpectedSingleValue), nil
case uint32:
expectedUint, err := strconv.ParseUint(*value, 10, 32)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, uint32(expectedUint), actualValue, AssertionExpectedSingleValue), nil
case float64:
expectedFloat, err := strconv.ParseFloat(value, 64)
expectedFloat, err := strconv.ParseFloat(*value, 64)
if err != nil {
return fmt.Sprintf("Could not compare non float value '%s', with %f", value, actualValue), nil
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
}
return compareTestAssertion(comparison, expectedFloat, actualValue, AssertionExpectedSingleValue), nil
case float32:
expectedFloat, err := strconv.ParseFloat(value, 32)
expectedFloat, err := strconv.ParseFloat(*value, 32)
if err != nil {
return fmt.Sprintf("Could not compare non float value '%s', with %f", value, actualValue), nil
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
}
return compareTestAssertion(comparison, float32(expectedFloat), actualValue, AssertionExpectedSingleValue), nil
case decimal.Decimal:
expectedDecimal, err := decimal.NewFromString(value)
expectedDecimal, err := decimal.NewFromString(*value)
if err != nil {
return fmt.Sprintf("Could not compare non decimal value '%s', with %s", value, actualValue), nil
return fmt.Sprintf("Could not compare non decimal value '%s', with %s", *value, actualValue), nil
}
return compareDecimals(comparison, expectedDecimal, actualValue, AssertionExpectedSingleValue), nil
case time.Time:
expectedTime, format, err := parseTestsDate(value)
expectedTime, format, err := parseTestsDate(*value)
if err != nil {
return fmt.Sprintf("%s does not appear to be a valid date", value), nil
return fmt.Sprintf("%s does not appear to be a valid date", *value), nil
}
return compareDates(comparison, expectedTime, actualValue, format, AssertionExpectedSingleValue), nil
case *val.TextStorage, string:
@@ -121,16 +130,19 @@ func expectSingleValue(sqlCtx *sql.Context, comparison string, value string, que
if err != nil {
return "", err
}
return compareTestAssertion(comparison, value, actualString, AssertionExpectedSingleValue), nil
return compareTestAssertion(comparison, *value, *actualString, AssertionExpectedSingleValue), nil
default:
return fmt.Sprintf("The type of %v is not supported. Open an issue at https://github.com/dolthub/dolt/issues to see it added", actualValue), nil
}
}
func expectRows(sqlCtx *sql.Context, comparison string, value string, queryResult *sql.RowIter) (message string, err error) {
expectedRows, err := strconv.Atoi(value)
func expectRows(sqlCtx *sql.Context, comparison string, value *string, queryResult *sql.RowIter) (message string, err error) {
if value == nil {
return "null is not a valid assertion for expected_rows", nil
}
expectedRows, err := strconv.Atoi(*value)
if err != nil {
return fmt.Sprintf("cannot run assertion on non integer value: %s", value), nil
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
}
var numRows int
@@ -146,10 +158,13 @@ func expectRows(sqlCtx *sql.Context, comparison string, value string, queryResul
return compareTestAssertion(comparison, expectedRows, numRows, AssertionExpectedRows), nil
}
func expectColumns(sqlCtx *sql.Context, comparison string, value string, queryResult *sql.RowIter) (message string, err error) {
expectedColumns, err := strconv.Atoi(value)
func expectColumns(sqlCtx *sql.Context, comparison string, value *string, queryResult *sql.RowIter) (message string, err error) {
if value == nil {
return "null is not a valid assertion for expected_rows", nil
}
expectedColumns, err := strconv.Atoi(*value)
if err != nil {
return fmt.Sprintf("cannot run assertion on non integer value: %s", value), nil
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
}
var numColumns int
@@ -255,7 +270,7 @@ func compareDates(comparison string, expectedValue, realValue time.Time, format
return fmt.Sprintf("Assertion failed: %s greater than or equal to %s, got %s", assertionType, expectedStr, realStr)
}
default:
return fmt.Sprintf("%s is not a valid assertion type", comparison)
return fmt.Sprintf("%s is not a valid comparison type", comparison)
}
return ""
}
@@ -290,7 +305,26 @@ func compareDecimals(comparison string, expectedValue, realValue decimal.Decimal
return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, realValue)
}
default:
return fmt.Sprintf("%s is not a valid assertion type", comparison)
return fmt.Sprintf("%s is not a valid comparison type", comparison)
}
return ""
}
// compareNullValue is a function used for comparing a null value.
// It takes in a comparison string from one of: "==", "!="
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
func compareNullValue(comparison string, actualValue interface{}, assertionType string) string {
switch comparison {
case "==":
if actualValue != nil {
return fmt.Sprintf("Assertion failed: %s equal to NULL, got %v", assertionType, actualValue)
}
case "!=":
if actualValue == nil {
return fmt.Sprintf("Assertion failed: %s not equal to NULL, got NULL", assertionType)
}
default:
return fmt.Sprintf("%s is not a valid comparison for NULL values", comparison)
}
return ""
}
@@ -298,12 +332,15 @@ func compareDecimals(comparison string, expectedValue, realValue decimal.Decimal
// GetStringColAsString is a function that returns a text column as a string.
// This is necessary as the dolt_tests system table returns *val.TextStorage types under certain situations,
// so we use a special parser to get the correct string values
func GetStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (string, error) {
func GetStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) {
if ts, ok := tableValue.(*val.TextStorage); ok {
return ts.Unwrap(sqlCtx)
str, err := ts.Unwrap(sqlCtx)
return &str, err
} else if str, ok := tableValue.(string); ok {
return str, nil
return &str, nil
} else if tableValue == nil {
return nil, nil
} else {
return "", fmt.Errorf("unexpected type %T, was expecting string", tableValue)
return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue)
}
}
@@ -223,18 +223,18 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul
if err != nil {
return
}
message, err := validateQuery(trtf.ctx, trtf.catalog, query)
message, err := validateQuery(trtf.ctx, trtf.catalog, *query)
if err != nil && message == "" {
message = fmt.Sprintf("query error: %s", err.Error())
}
var testPassed bool
if message == "" {
_, queryResult, _, err := trtf.engine.Query(trtf.ctx, query)
_, queryResult, _, err := trtf.engine.Query(trtf.ctx, *query)
if err != nil {
message = fmt.Sprintf("Query error: %s", err.Error())
} else {
testPassed, message, err = actions.AssertData(trtf.ctx, assertion, comparison, value, &queryResult)
testPassed, message, err = actions.AssertData(trtf.ctx, *assertion, *comparison, value, &queryResult)
if err != nil {
return testResult{}, err
}
@@ -245,7 +245,12 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul
if !testPassed {
status = "FAIL"
}
result = testResult{testName, groupName, query, status, message}
var groupString string
if groupName != nil {
groupString = *groupName
}
result = testResult{*testName, groupString, *query, status, message}
return result, nil
}
@@ -301,7 +306,7 @@ func IsWriteQuery(query string, ctx *sql.Context, catalog sql.Catalog) (bool, er
return !node.IsReadOnly(), nil
}
func parseDoltTestsRow(ctx *sql.Context, row sql.Row) (testName, groupName, query, assertion, comparison, value string, err error) {
func parseDoltTestsRow(ctx *sql.Context, row sql.Row) (testName, groupName, query, assertion, comparison, value *string, err error) {
if testName, err = actions.GetStringColAsString(ctx, row[0]); err != nil {
return
}
@@ -337,5 +342,9 @@ func validateQuery(ctx *sql.Context, catalog sql.Catalog, query string) (string,
} else if isWrite {
return "Cannot execute write queries", nil
}
if strings.Contains(strings.ToLower(query), "dolt_test_run(") {
return "Cannot call dolt_test_run in dolt_tests", nil
}
return "", nil
}
@@ -52,7 +52,7 @@ func doltTestsSchema() sql.Schema {
{Name: "test_query", Type: sqlTypes.Text, Source: doltdb.TestsTableName, PrimaryKey: false, Nullable: false},
{Name: "assertion_type", Type: sqlTypes.Text, Source: doltdb.TestsTableName, PrimaryKey: false, Nullable: false},
{Name: "assertion_comparator", Type: sqlTypes.Text, Source: doltdb.TestsTableName, PrimaryKey: false, Nullable: false},
{Name: "assertion_value", Type: sqlTypes.Text, Source: doltdb.TestsTableName, PrimaryKey: false, Nullable: false},
{Name: "assertion_value", Type: sqlTypes.Text, Source: doltdb.TestsTableName, PrimaryKey: false, Nullable: true},
}
}
@@ -148,7 +148,7 @@ var DoltTestRunFunctionScripts = []queries.ScriptTest{
},
},
{
Name: "Delimiter is option for dolt_test_run",
Name: "Delimiter is optional for dolt_test_run",
SetUpScript: []string{
"INSERT INTO dolt_tests VALUES ('should pass', 'delimiter tests', 'show tables', 'expected_rows', '==', '0'), " +
"('should also pass', 'delimiter tests', 'show tables;', 'expected_rows', '==', '0')",
@@ -163,6 +163,20 @@ var DoltTestRunFunctionScripts = []queries.ScriptTest{
},
},
},
{
Name: "Null test group functions correctly",
SetUpScript: []string{
"INSERT INTO dolt_tests VALUES ('should pass', NULL, 'select * from dolt_log;', 'expected_rows', '>=', '1')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "SELECT * FROM dolt_test_run('should pass')",
Expected: []sql.Row{
{"should pass", "", "select * from dolt_log;", "PASS", ""},
},
},
},
},
{
Name: "Simple row and column tests",
SetUpScript: []string{
@@ -293,6 +307,29 @@ var DoltTestRunFunctionScripts = []queries.ScriptTest{
},
},
},
{
Name: "Can handle null values correctly",
SetUpScript: []string{
"CREATE TABLE numbers (i int, t text, j int)",
"INSERT INTO numbers VALUES (NULL, NULL, 4)",
"INSERT INTO dolt_tests (test_name, test_query, assertion_type, assertion_comparator) VALUES " +
"('simple null int equality', 'SELECT i FROM numbers', 'expected_single_value', '=='), " +
"('simple null string equality', 'SELECT t FROM numbers', 'expected_single_value', '=='), " +
"('simple null inequality', 'SELECT i FROM numbers', 'expected_single_value', '!='), " +
"('expect null, get not null', 'SELECT j FROM numbers', 'expected_single_value', '==')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "SELECT * FROM dolt_test_run('*')",
Expected: []sql.Row{
{"expect null, get not null", "", "SELECT j FROM numbers", "FAIL", "Assertion failed: expected_single_value equal to NULL, got 4"},
{"simple null inequality", "", "SELECT i FROM numbers", "FAIL", "Assertion failed: expected_single_value not equal to NULL, got NULL"},
{"simple null int equality", "", "SELECT i FROM numbers", "PASS", ""},
{"simple null string equality", "", "SELECT t FROM numbers", "PASS", ""},
},
},
},
},
{
Name: "Single value will not accept multiple values",
SetUpScript: []string{