From f8bf738dc03df72c0f8a368c04710aa35fb89d86 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Mon, 6 May 2019 14:58:59 -0700 Subject: [PATCH] Implemented binary expression evaluation (e.g. addition, subtraction. Only works in the where clause, not the expression list, and has a few bugs. --- go/libraries/doltcore/sql/sqlselect_test.go | 91 +++++++--- go/libraries/doltcore/sql/sqltestutil.go | 21 +++ go/libraries/doltcore/sql/sqlutil.go | 182 ++++++++++++++++++-- 3 files changed, 262 insertions(+), 32 deletions(-) diff --git a/go/libraries/doltcore/sql/sqlselect_test.go b/go/libraries/doltcore/sql/sqlselect_test.go index 63f633d5eb..a286ae320c 100644 --- a/go/libraries/doltcore/sql/sqlselect_test.go +++ b/go/libraries/doltcore/sql/sqlselect_test.go @@ -2,7 +2,6 @@ package sql import ( "context" - "fmt" "github.com/attic-labs/noms/go/types" "github.com/liquidata-inc/ld/dolt/go/cmd/dolt/dtestutils" "github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/table/untyped/resultset" @@ -335,6 +334,76 @@ func TestExecuteSelect(t *testing.T) { query: "select * from people where age is true", expectedErr: "Type mismatch:", }, + // TODO: support operations in select clause + // { + // name: "binary expression in select", + // query: "select age + 1 as age from people where is_married", + // expectedRows: rs(newResultSetRow(types.Int(41)), newResultSetRow(types.Int(39))), + // expectedSchema: newResultSetSchema("age", types.FloatKind), + // }, + { + name: "select *, binary + in where", + query: "select * from people where age + 1 = 41", + expectedRows: compressRows(peopleTestSchema, homer, barney), + expectedSchema: compressSchema(peopleTestSchema), + }, + { + name: "select *, binary - in where", + query: "select * from people where age - 1 = 39", + expectedRows: compressRows(peopleTestSchema, homer, barney), + expectedSchema: compressSchema(peopleTestSchema), + }, + { + name: "select *, binary / in where", + query: "select * from people where age / 2 = 20", + expectedRows: compressRows(peopleTestSchema, homer, barney), + expectedSchema: compressSchema(peopleTestSchema), + }, + { + name: "select *, binary / in where", + query: "select * from people where age * 2 = 80", + expectedRows: compressRows(peopleTestSchema, homer, barney), + expectedSchema: compressSchema(peopleTestSchema), + }, + { + name: "select *, binary % in where", + query: "select * from people where age % 4 = 0", + expectedRows: compressRows(peopleTestSchema, homer, lisa, moe, barney), + expectedSchema: compressSchema(peopleTestSchema), + }, + // TODO: this should work but doesn't. Type checking is getting hung up on 2 + 2, since it has no noms type to + // enforce against + // { + // name: "select *, complex binary expr in where", + // query: "select * from people where age / 4 + 2 * 2 = 14", + // expectedRows: compressRows(peopleTestSchema, homer, barney), + // expectedSchema: compressSchema(peopleTestSchema), + // }, + { + name: "select *, binary + in where type mismatch", + query: "select * from people where first + 1 = 41", + expectedErr: "Type mismatch:", + }, + { + name: "select *, binary - in where type mismatch", + query: "select * from people where first - 1 = 39", + expectedErr: "Type mismatch:", + }, + { + name: "select *, binary / in where type mismatch", + query: "select * from people where first / 2 = 20", + expectedErr: "Type mismatch:", + }, + { + name: "select *, binary / in where type mismatch", + query: "select * from people where first * 2 = 80", + expectedErr: "Type mismatch:", + }, + { + name: "select *, binary % in where type mismatch", + query: "select * from people where first % 4 = 0", + expectedErr: "Type mismatch:", + }, { name: "select subset of cols", query: "select first, last from people where age >= 40", @@ -769,26 +838,6 @@ func compressSchemas(schs ...schema.Schema) schema.Schema { return schema.UnkeyedSchemaFromCols(colCol) } -// Creates a new row for a result set specified by the given values -func newResultSetRow(colVals ...types.Value) row.Row { - - taggedVals := make(row.TaggedValues) - cols := make([]schema.Column, len(colVals)) - for i := 0; i < len(colVals); i++ { - taggedVals[uint64(i)] = colVals[i] - nomsKind := colVals[i].Kind() - cols[i] = schema.NewColumn(fmt.Sprintf("%v", i), uint64(i), nomsKind, false) - } - - collection, err := schema.NewColCollection(cols...) - if err != nil { - panic("unexpected error " + err.Error()) - } - sch := schema.UnkeyedSchemaFromCols(collection) - - return row.New(sch, taggedVals) -} - // Creates a new schema for a result set specified by the given pairs of column names and types. Column names are // strings, types are NomsKinds. func newResultSetSchema(colNamesAndTypes ...interface{}) schema.Schema { diff --git a/go/libraries/doltcore/sql/sqltestutil.go b/go/libraries/doltcore/sql/sqltestutil.go index de340d3755..a60a3b88d6 100644 --- a/go/libraries/doltcore/sql/sqltestutil.go +++ b/go/libraries/doltcore/sql/sqltestutil.go @@ -2,6 +2,7 @@ package sql import ( "context" + "fmt" "github.com/attic-labs/noms/go/types" "github.com/google/go-cmp/cmp" "github.com/google/uuid" @@ -260,6 +261,26 @@ func mutateRow(r row.Row, tagsAndVals ...interface{}) row.Row { return mutated } +// Creates a new row for a result set specified by the given values +func newResultSetRow(colVals ...types.Value) row.Row { + + taggedVals := make(row.TaggedValues) + cols := make([]schema.Column, len(colVals)) + for i := 0; i < len(colVals); i++ { + taggedVals[uint64(i)] = colVals[i] + nomsKind := colVals[i].Kind() + cols[i] = schema.NewColumn(fmt.Sprintf("%v", i), uint64(i), nomsKind, false) + } + + collection, err := schema.NewColCollection(cols...) + if err != nil { + panic("unexpected error " + err.Error()) + } + sch := schema.UnkeyedSchemaFromCols(collection) + + return row.New(sch, taggedVals) +} + func createTestTable(dEnv *env.DoltEnv, t *testing.T, tableName string, sch schema.Schema, rs ...row.Row) { imt := table.NewInMemTable(sch) diff --git a/go/libraries/doltcore/sql/sqlutil.go b/go/libraries/doltcore/sql/sqlutil.go index 5b4a37a3da..941e9be327 100644 --- a/go/libraries/doltcore/sql/sqlutil.go +++ b/go/libraries/doltcore/sql/sqlutil.go @@ -288,6 +288,9 @@ func parseColumnAlias(colName string) QualifiedColumn { return QualifiedColumn{"", colName} } +// nomsOperation knows how to combine two noms values into a single one, e.g. addition +type nomsOperation func(left, right types.Value) types.Value + type valGetterKind uint8 const ( COLNAME valGetterKind = iota @@ -409,8 +412,149 @@ func getterFor(expr sqlparser.Expr, inputSchemas map[string]schema.Schema, alias } return &getter, nil + case *sqlparser.BinaryExpr: + leftGetter, err := getterFor(e.Left, inputSchemas, aliases, rss) + if err != nil { + return nil, err + } + rightGetter, err := getterFor(e.Right, inputSchemas, aliases, rss) + if err != nil { + return nil, err + } + + // Fill in target noms kinds for SQL_VAL fields if possible + if leftGetter.Kind == SQL_VAL && rightGetter.Kind != SQL_VAL { + leftGetter.NomsKind = rightGetter.NomsKind + } + if rightGetter.Kind == SQL_VAL && leftGetter.Kind != SQL_VAL { + rightGetter.NomsKind = leftGetter.NomsKind + } + + if rightGetter.NomsKind != leftGetter.NomsKind { + return nil, errFmt("Unsupported binary operation types: %v, %v", types.KindToString[leftGetter.NomsKind], types.KindToString[rightGetter.NomsKind]) + } + + // Fill in comparison kinds before doing error checking + rightGetter.CmpKind, leftGetter.CmpKind = leftGetter.NomsKind, rightGetter.NomsKind + + // Initialize the getters. This uses the type hints from above to enforce type constraints between columns and + // literals. + if err := leftGetter.Init(); err != nil { + return nil, err + } + if err := rightGetter.Init(); err != nil { + return nil, err + } + + getter := valGetter{Kind: SQL_VAL, NomsKind: leftGetter.NomsKind, CmpKind: rightGetter.NomsKind} + + // All the operations differ only in their filter logic + var opFn nomsOperation + switch e.Operator { + case sqlparser.PlusStr: + switch getter.NomsKind { + case types.UintKind: + opFn = func(left, right types.Value) types.Value { + return types.Uint(uint64(left.(types.Int)) + uint64(right.(types.Int))) + } + case types.IntKind: + opFn = func(left, right types.Value) types.Value { + return types.Int(int64(left.(types.Int)) + int64(right.(types.Int))) + } + case types.FloatKind: + opFn = func(left, right types.Value) types.Value { + return types.Float(float64(left.(types.Float)) + float64(right.(types.Float))) + } + default: + return nil, errFmt("Unsupported type for + operation: %v", types.KindToString[getter.NomsKind]) + } + case sqlparser.MinusStr: + switch getter.NomsKind { + case types.UintKind: + opFn = func(left, right types.Value) types.Value { + return types.Uint(uint64(left.(types.Int)) - uint64(right.(types.Int))) + } + case types.IntKind: + opFn = func(left, right types.Value) types.Value { + return types.Int(int64(left.(types.Int)) - int64(right.(types.Int))) + } + case types.FloatKind: + opFn = func(left, right types.Value) types.Value { + return types.Float(float64(left.(types.Float)) - float64(right.(types.Float))) + } + default: + return nil, errFmt("Unsupported type for - operation: %v", types.KindToString[getter.NomsKind]) + } + case sqlparser.MultStr: + switch getter.NomsKind { + case types.UintKind: + opFn = func(left, right types.Value) types.Value { + return types.Uint(uint64(left.(types.Int)) * uint64(right.(types.Int))) + } + case types.IntKind: + opFn = func(left, right types.Value) types.Value { + return types.Int(int64(left.(types.Int)) * int64(right.(types.Int))) + } + case types.FloatKind: + opFn = func(left, right types.Value) types.Value { + return types.Float(float64(left.(types.Float)) * float64(right.(types.Float))) + } + default: + return nil, errFmt("Unsupported type for * operation: %v", types.KindToString[getter.NomsKind]) + } + case sqlparser.DivStr: + switch getter.NomsKind { + case types.UintKind: + opFn = func(left, right types.Value) types.Value { + return types.Uint(uint64(left.(types.Int)) / uint64(right.(types.Int))) + } + case types.IntKind: + opFn = func(left, right types.Value) types.Value { + return types.Int(int64(left.(types.Int)) / int64(right.(types.Int))) + } + case types.FloatKind: + opFn = func(left, right types.Value) types.Value { + return types.Float(float64(left.(types.Float)) / float64(right.(types.Float))) + } + default: + return nil, errFmt("Unsupported type for / operation: %v", types.KindToString[getter.NomsKind]) + } + case sqlparser.ModStr: + switch getter.NomsKind { + case types.UintKind: + opFn = func(left, right types.Value) types.Value { + return types.Uint(uint64(left.(types.Int)) % uint64(right.(types.Int))) + } + case types.IntKind: + opFn = func(left, right types.Value) types.Value { + return types.Int(int64(left.(types.Int)) % int64(right.(types.Int))) + } + default: + return nil, errFmt("Unsupported type for %% operation: %v", types.KindToString[getter.NomsKind]) + } + default: + return nil, errFmt("Unsupported binary operation: %v", e.Operator) + } + + getter.Init = func() error { + // Already did type checking explicitly + return nil + } + + getter.Get = func(r row.Row) types.Value { + leftVal := leftGetter.Get(r) + rightVal := rightGetter.Get(r) + if types.IsNull(leftVal) || types.IsNull(rightVal) { + return nil + } + return opFn(leftVal, rightVal) + } + + return &getter, nil + case *sqlparser.UnaryExpr: + return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr)) default: - return nil, errFmt("Unsupported comparison %v", nodeToString(e)) + return nil, errFmt("Unsupported type %v", nodeToString(e)) } } @@ -491,11 +635,11 @@ func resolveColumnsInWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string } cols = append(cols, qc) case *sqlparser.IsExpr: - cols, err := resolveColumnsInWhereExpr(expr.Expr, inputSchemas, aliases) + isCols, err := resolveColumnsInWhereExpr(expr.Expr, inputSchemas, aliases) if err != nil { return nil, err } - cols = append(cols, cols...) + cols = append(cols, isCols...) case *sqlparser.AndExpr: leftCols, err := resolveColumnsInWhereExpr(expr.Left, inputSchemas, aliases) if err != nil { @@ -518,7 +662,24 @@ func resolveColumnsInWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string } cols = append(cols, leftCols...) cols = append(cols, rightCols...) + case *sqlparser.BinaryExpr: + leftCols, err := resolveColumnsInWhereExpr(expr.Left, inputSchemas, aliases) + if err != nil { + return nil, err + } + rightCols, err := resolveColumnsInWhereExpr(expr.Right, inputSchemas, aliases) + if err != nil { + return nil, err + } + cols = append(cols, leftCols...) + cols = append(cols, rightCols...) + case *sqlparser.UnaryExpr: + unaryCols, err := resolveColumnsInWhereExpr(expr.Expr, inputSchemas, aliases) + if err != nil { + return nil, err + } + cols = append(cols, unaryCols...) case *sqlparser.SQLVal, sqlparser.BoolVal, sqlparser.ValTuple: // No columns, just a SQL literal case *sqlparser.NotExpr: @@ -537,10 +698,6 @@ func resolveColumnsInWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string return nil, errFmt("Subquery expressions not supported: %v", nodeToString(expr)) case *sqlparser.ListArg: return nil, errFmt("List expressions not supported: %v", nodeToString(expr)) - case *sqlparser.BinaryExpr: - return nil, errFmt("Binary expressions not supported: %v", nodeToString(expr)) - case *sqlparser.UnaryExpr: - return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr)) case *sqlparser.IntervalExpr: return nil, errFmt("Interval expressions not supported: %v", nodeToString(expr)) case *sqlparser.CollateExpr: @@ -834,6 +991,13 @@ func createFilterForWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string] return nil, errFmt("Unrecognized is comparison: %v", expr.Operator) } + // Unary and Binary operators are supported in getGetter(), but not as top-level nodes here. + case *sqlparser.BinaryExpr: + return nil, errFmt("Binary expressions not supported: %v", nodeToString(expr)) + case *sqlparser.UnaryExpr: + return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr)) + + // Full listing of the unsupported types for informative error messages case *sqlparser.NotExpr: return nil, errFmt("Not expressions not supported: %v", nodeToString(expr)) case *sqlparser.ParenExpr: @@ -854,10 +1018,6 @@ func createFilterForWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string] return nil, errFmt("Subquery expressions not supported: %v", nodeToString(expr)) case *sqlparser.ListArg: return nil, errFmt("List expressions not supported: %v", nodeToString(expr)) - case *sqlparser.BinaryExpr: - return nil, errFmt("Binary expressions not supported: %v", nodeToString(expr)) - case *sqlparser.UnaryExpr: - return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr)) case *sqlparser.IntervalExpr: return nil, errFmt("Interval expressions not supported: %v", nodeToString(expr)) case *sqlparser.CollateExpr: