Implemented binary expression evaluation (e.g. addition, subtraction. Only works in the where clause, not the expression list, and has a few bugs.

This commit is contained in:
Zach Musgrave
2019-05-06 14:58:59 -07:00
parent 0755df13a3
commit f8bf738dc0
3 changed files with 262 additions and 32 deletions

View File

@@ -288,6 +288,9 @@ func parseColumnAlias(colName string) QualifiedColumn {
return QualifiedColumn{"", colName}
}
// nomsOperation knows how to combine two noms values into a single one, e.g. addition
type nomsOperation func(left, right types.Value) types.Value
type valGetterKind uint8
const (
COLNAME valGetterKind = iota
@@ -409,8 +412,149 @@ func getterFor(expr sqlparser.Expr, inputSchemas map[string]schema.Schema, alias
}
return &getter, nil
case *sqlparser.BinaryExpr:
leftGetter, err := getterFor(e.Left, inputSchemas, aliases, rss)
if err != nil {
return nil, err
}
rightGetter, err := getterFor(e.Right, inputSchemas, aliases, rss)
if err != nil {
return nil, err
}
// Fill in target noms kinds for SQL_VAL fields if possible
if leftGetter.Kind == SQL_VAL && rightGetter.Kind != SQL_VAL {
leftGetter.NomsKind = rightGetter.NomsKind
}
if rightGetter.Kind == SQL_VAL && leftGetter.Kind != SQL_VAL {
rightGetter.NomsKind = leftGetter.NomsKind
}
if rightGetter.NomsKind != leftGetter.NomsKind {
return nil, errFmt("Unsupported binary operation types: %v, %v", types.KindToString[leftGetter.NomsKind], types.KindToString[rightGetter.NomsKind])
}
// Fill in comparison kinds before doing error checking
rightGetter.CmpKind, leftGetter.CmpKind = leftGetter.NomsKind, rightGetter.NomsKind
// Initialize the getters. This uses the type hints from above to enforce type constraints between columns and
// literals.
if err := leftGetter.Init(); err != nil {
return nil, err
}
if err := rightGetter.Init(); err != nil {
return nil, err
}
getter := valGetter{Kind: SQL_VAL, NomsKind: leftGetter.NomsKind, CmpKind: rightGetter.NomsKind}
// All the operations differ only in their filter logic
var opFn nomsOperation
switch e.Operator {
case sqlparser.PlusStr:
switch getter.NomsKind {
case types.UintKind:
opFn = func(left, right types.Value) types.Value {
return types.Uint(uint64(left.(types.Int)) + uint64(right.(types.Int)))
}
case types.IntKind:
opFn = func(left, right types.Value) types.Value {
return types.Int(int64(left.(types.Int)) + int64(right.(types.Int)))
}
case types.FloatKind:
opFn = func(left, right types.Value) types.Value {
return types.Float(float64(left.(types.Float)) + float64(right.(types.Float)))
}
default:
return nil, errFmt("Unsupported type for + operation: %v", types.KindToString[getter.NomsKind])
}
case sqlparser.MinusStr:
switch getter.NomsKind {
case types.UintKind:
opFn = func(left, right types.Value) types.Value {
return types.Uint(uint64(left.(types.Int)) - uint64(right.(types.Int)))
}
case types.IntKind:
opFn = func(left, right types.Value) types.Value {
return types.Int(int64(left.(types.Int)) - int64(right.(types.Int)))
}
case types.FloatKind:
opFn = func(left, right types.Value) types.Value {
return types.Float(float64(left.(types.Float)) - float64(right.(types.Float)))
}
default:
return nil, errFmt("Unsupported type for - operation: %v", types.KindToString[getter.NomsKind])
}
case sqlparser.MultStr:
switch getter.NomsKind {
case types.UintKind:
opFn = func(left, right types.Value) types.Value {
return types.Uint(uint64(left.(types.Int)) * uint64(right.(types.Int)))
}
case types.IntKind:
opFn = func(left, right types.Value) types.Value {
return types.Int(int64(left.(types.Int)) * int64(right.(types.Int)))
}
case types.FloatKind:
opFn = func(left, right types.Value) types.Value {
return types.Float(float64(left.(types.Float)) * float64(right.(types.Float)))
}
default:
return nil, errFmt("Unsupported type for * operation: %v", types.KindToString[getter.NomsKind])
}
case sqlparser.DivStr:
switch getter.NomsKind {
case types.UintKind:
opFn = func(left, right types.Value) types.Value {
return types.Uint(uint64(left.(types.Int)) / uint64(right.(types.Int)))
}
case types.IntKind:
opFn = func(left, right types.Value) types.Value {
return types.Int(int64(left.(types.Int)) / int64(right.(types.Int)))
}
case types.FloatKind:
opFn = func(left, right types.Value) types.Value {
return types.Float(float64(left.(types.Float)) / float64(right.(types.Float)))
}
default:
return nil, errFmt("Unsupported type for / operation: %v", types.KindToString[getter.NomsKind])
}
case sqlparser.ModStr:
switch getter.NomsKind {
case types.UintKind:
opFn = func(left, right types.Value) types.Value {
return types.Uint(uint64(left.(types.Int)) % uint64(right.(types.Int)))
}
case types.IntKind:
opFn = func(left, right types.Value) types.Value {
return types.Int(int64(left.(types.Int)) % int64(right.(types.Int)))
}
default:
return nil, errFmt("Unsupported type for %% operation: %v", types.KindToString[getter.NomsKind])
}
default:
return nil, errFmt("Unsupported binary operation: %v", e.Operator)
}
getter.Init = func() error {
// Already did type checking explicitly
return nil
}
getter.Get = func(r row.Row) types.Value {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return nil
}
return opFn(leftVal, rightVal)
}
return &getter, nil
case *sqlparser.UnaryExpr:
return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr))
default:
return nil, errFmt("Unsupported comparison %v", nodeToString(e))
return nil, errFmt("Unsupported type %v", nodeToString(e))
}
}
@@ -491,11 +635,11 @@ func resolveColumnsInWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string
}
cols = append(cols, qc)
case *sqlparser.IsExpr:
cols, err := resolveColumnsInWhereExpr(expr.Expr, inputSchemas, aliases)
isCols, err := resolveColumnsInWhereExpr(expr.Expr, inputSchemas, aliases)
if err != nil {
return nil, err
}
cols = append(cols, cols...)
cols = append(cols, isCols...)
case *sqlparser.AndExpr:
leftCols, err := resolveColumnsInWhereExpr(expr.Left, inputSchemas, aliases)
if err != nil {
@@ -518,7 +662,24 @@ func resolveColumnsInWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string
}
cols = append(cols, leftCols...)
cols = append(cols, rightCols...)
case *sqlparser.BinaryExpr:
leftCols, err := resolveColumnsInWhereExpr(expr.Left, inputSchemas, aliases)
if err != nil {
return nil, err
}
rightCols, err := resolveColumnsInWhereExpr(expr.Right, inputSchemas, aliases)
if err != nil {
return nil, err
}
cols = append(cols, leftCols...)
cols = append(cols, rightCols...)
case *sqlparser.UnaryExpr:
unaryCols, err := resolveColumnsInWhereExpr(expr.Expr, inputSchemas, aliases)
if err != nil {
return nil, err
}
cols = append(cols, unaryCols...)
case *sqlparser.SQLVal, sqlparser.BoolVal, sqlparser.ValTuple:
// No columns, just a SQL literal
case *sqlparser.NotExpr:
@@ -537,10 +698,6 @@ func resolveColumnsInWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string
return nil, errFmt("Subquery expressions not supported: %v", nodeToString(expr))
case *sqlparser.ListArg:
return nil, errFmt("List expressions not supported: %v", nodeToString(expr))
case *sqlparser.BinaryExpr:
return nil, errFmt("Binary expressions not supported: %v", nodeToString(expr))
case *sqlparser.UnaryExpr:
return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr))
case *sqlparser.IntervalExpr:
return nil, errFmt("Interval expressions not supported: %v", nodeToString(expr))
case *sqlparser.CollateExpr:
@@ -834,6 +991,13 @@ func createFilterForWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string]
return nil, errFmt("Unrecognized is comparison: %v", expr.Operator)
}
// Unary and Binary operators are supported in getGetter(), but not as top-level nodes here.
case *sqlparser.BinaryExpr:
return nil, errFmt("Binary expressions not supported: %v", nodeToString(expr))
case *sqlparser.UnaryExpr:
return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr))
// Full listing of the unsupported types for informative error messages
case *sqlparser.NotExpr:
return nil, errFmt("Not expressions not supported: %v", nodeToString(expr))
case *sqlparser.ParenExpr:
@@ -854,10 +1018,6 @@ func createFilterForWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string]
return nil, errFmt("Subquery expressions not supported: %v", nodeToString(expr))
case *sqlparser.ListArg:
return nil, errFmt("List expressions not supported: %v", nodeToString(expr))
case *sqlparser.BinaryExpr:
return nil, errFmt("Binary expressions not supported: %v", nodeToString(expr))
case *sqlparser.UnaryExpr:
return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr))
case *sqlparser.IntervalExpr:
return nil, errFmt("Interval expressions not supported: %v", nodeToString(expr))
case *sqlparser.CollateExpr: