diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go index 37590c2867..7457767b27 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go @@ -415,6 +415,49 @@ var DoltRevisionDbScripts = []queries.ScriptTest{ // DoltScripts are script tests specific to Dolt (not the engine in general), e.g. by involving Dolt functions. Break // this slice into others with good names as it grows. var DoltScripts = []queries.ScriptTest{ + { + Name: "test null filtering in secondary indexes (https://github.com/dolthub/dolt/issues/4199)", + SetUpScript: []string{ + "create table t (pk int primary key auto_increment, d datetime, index index1 (d));", + "insert into t (d) values (NOW()), (NOW());", + "insert into t (d) values (NULL), (NULL);", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select count(*) from t where d is not null", + Expected: []sql.Row{{2}}, + }, + { + Query: "select count(*) from t where d is null", + Expected: []sql.Row{{2}}, + }, + { + // Test the null-safe equals operator + Query: "select count(*) from t where d <=> NULL", + Expected: []sql.Row{{2}}, + }, + { + // Test the null-safe equals operator + Query: "select count(*) from t where not(d <=> null)", + Expected: []sql.Row{{2}}, + }, + { + // Test an IndexedJoin + Query: "select count(ifnull(t.d, 1)) from t, t as t2 where t.d is not null and t.pk = t2.pk and t2.d is not null;", + Expected: []sql.Row{{2}}, + }, + { + // Test an IndexedJoin + Query: "select count(ifnull(t.d, 1)) from t, t as t2 where t.d is null and t.pk = t2.pk and t2.d is null;", + Expected: []sql.Row{{2}}, + }, + { + // Test an IndexedJoin + Query: "select count(ifnull(t.d, 1)) from t, t as t2 where t.d is null and t.pk = t2.pk and t2.d is not null;", + Expected: []sql.Row{{0}}, + }, + }, + }, { Name: "test backticks in index name (https://github.com/dolthub/dolt/issues/3776)", SetUpScript: []string{ diff --git a/go/store/types/tuple.go b/go/store/types/tuple.go index 78b73c14ab..2c24dba846 100644 --- a/go/store/types/tuple.go +++ b/go/store/types/tuple.go @@ -690,6 +690,15 @@ func (t Tuple) TupleCompare(nbf *NomsBinFormat, otherTuple Tuple) (int, error) { otherKind := otherDec.ReadKind() if kind != otherKind { + // If we are comparing any type to a null type, always evaluate + // the null value as greater than the non-null value. This is needed + // to keep null value ordering consistent in indexes. + if kind == NullKind { + return 1, nil + } else if otherKind == NullKind { + return -1, nil + } + return int(kind) - int(otherKind), nil } diff --git a/go/store/types/tuple_test.go b/go/store/types/tuple_test.go index c9bf154161..bd36161a2c 100644 --- a/go/store/types/tuple_test.go +++ b/go/store/types/tuple_test.go @@ -186,6 +186,46 @@ func TestTupleLess(t *testing.T) { []Value{UUID(uuid.MustParse(OneUUID)), String("abc")}, true, }, + { + []Value{UUID(uuid.MustParse(OneUUID)), String("abc")}, + []Value{UUID(uuid.MustParse(OneUUID)), NullValue}, + true, + }, + { + []Value{UUID(uuid.MustParse(OneUUID)), NullValue}, + []Value{UUID(uuid.MustParse(OneUUID)), String("abc")}, + false, + }, + { + []Value{UUID(uuid.MustParse(OneUUID)), Timestamp(time.Now())}, + []Value{UUID(uuid.MustParse(OneUUID)), NullValue}, + true, + }, + { + []Value{UUID(uuid.MustParse(OneUUID)), NullValue}, + []Value{UUID(uuid.MustParse(OneUUID)), Timestamp(time.Now())}, + false, + }, + { + []Value{UUID(uuid.MustParse(OneUUID)), Int(100)}, + []Value{UUID(uuid.MustParse(OneUUID)), NullValue}, + true, + }, + { + []Value{UUID(uuid.MustParse(OneUUID)), NullValue}, + []Value{UUID(uuid.MustParse(OneUUID)), Int(100)}, + false, + }, + { + []Value{UUID(uuid.MustParse(OneUUID)), Point{1, 1.0, 1.0}}, + []Value{UUID(uuid.MustParse(OneUUID)), NullValue}, + true, + }, + { + []Value{UUID(uuid.MustParse(OneUUID)), NullValue}, + []Value{UUID(uuid.MustParse(OneUUID)), Point{1, 1.0, 1.0}}, + false, + }, } isLTZero := func(n int) bool { @@ -194,23 +234,25 @@ func TestTupleLess(t *testing.T) { nbf := Format_Default for _, test := range tests { - tpl1, err := NewTuple(nbf, test.vals1...) + t.Run("", func(t *testing.T) { + tpl1, err := NewTuple(nbf, test.vals1...) - require.NoError(t, err) + require.NoError(t, err) - tpl2, err := NewTuple(nbf, test.vals2...) - require.NoError(t, err) + tpl2, err := NewTuple(nbf, test.vals2...) + require.NoError(t, err) - actual, err := tpl1.Less(nbf, tpl2) - require.NoError(t, err) + actual, err := tpl1.Less(nbf, tpl2) + require.NoError(t, err) - if actual != test.expected { - t.Error("tpl1:", mustString(EncodedValue(context.Background(), tpl1)), "tpl2:", mustString(EncodedValue(context.Background(), tpl2)), "expected", test.expected, "actual:", actual) - } + if actual != test.expected { + t.Error("tpl1:", mustString(EncodedValue(context.Background(), tpl1)), "tpl2:", mustString(EncodedValue(context.Background(), tpl2)), "expected", test.expected, "actual:", actual) + } - res, err := tpl1.Compare(nbf, tpl2) - require.NoError(t, err) - require.Equal(t, actual, isLTZero(res)) + res, err := tpl1.Compare(nbf, tpl2) + require.NoError(t, err) + require.Equal(t, actual, isLTZero(res)) + }) } }