mirror of
https://github.com/dolthub/dolt.git
synced 2026-02-14 10:09:09 -06:00
Implemented binary expression evaluation (e.g. addition, subtraction. Only works in the where clause, not the expression list, and has a few bugs.
This commit is contained in:
@@ -288,6 +288,9 @@ func parseColumnAlias(colName string) QualifiedColumn {
|
||||
return QualifiedColumn{"", colName}
|
||||
}
|
||||
|
||||
// nomsOperation knows how to combine two noms values into a single one, e.g. addition
|
||||
type nomsOperation func(left, right types.Value) types.Value
|
||||
|
||||
type valGetterKind uint8
|
||||
const (
|
||||
COLNAME valGetterKind = iota
|
||||
@@ -409,8 +412,149 @@ func getterFor(expr sqlparser.Expr, inputSchemas map[string]schema.Schema, alias
|
||||
}
|
||||
|
||||
return &getter, nil
|
||||
case *sqlparser.BinaryExpr:
|
||||
leftGetter, err := getterFor(e.Left, inputSchemas, aliases, rss)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rightGetter, err := getterFor(e.Right, inputSchemas, aliases, rss)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fill in target noms kinds for SQL_VAL fields if possible
|
||||
if leftGetter.Kind == SQL_VAL && rightGetter.Kind != SQL_VAL {
|
||||
leftGetter.NomsKind = rightGetter.NomsKind
|
||||
}
|
||||
if rightGetter.Kind == SQL_VAL && leftGetter.Kind != SQL_VAL {
|
||||
rightGetter.NomsKind = leftGetter.NomsKind
|
||||
}
|
||||
|
||||
if rightGetter.NomsKind != leftGetter.NomsKind {
|
||||
return nil, errFmt("Unsupported binary operation types: %v, %v", types.KindToString[leftGetter.NomsKind], types.KindToString[rightGetter.NomsKind])
|
||||
}
|
||||
|
||||
// Fill in comparison kinds before doing error checking
|
||||
rightGetter.CmpKind, leftGetter.CmpKind = leftGetter.NomsKind, rightGetter.NomsKind
|
||||
|
||||
// Initialize the getters. This uses the type hints from above to enforce type constraints between columns and
|
||||
// literals.
|
||||
if err := leftGetter.Init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rightGetter.Init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getter := valGetter{Kind: SQL_VAL, NomsKind: leftGetter.NomsKind, CmpKind: rightGetter.NomsKind}
|
||||
|
||||
// All the operations differ only in their filter logic
|
||||
var opFn nomsOperation
|
||||
switch e.Operator {
|
||||
case sqlparser.PlusStr:
|
||||
switch getter.NomsKind {
|
||||
case types.UintKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Uint(uint64(left.(types.Int)) + uint64(right.(types.Int)))
|
||||
}
|
||||
case types.IntKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Int(int64(left.(types.Int)) + int64(right.(types.Int)))
|
||||
}
|
||||
case types.FloatKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Float(float64(left.(types.Float)) + float64(right.(types.Float)))
|
||||
}
|
||||
default:
|
||||
return nil, errFmt("Unsupported type for + operation: %v", types.KindToString[getter.NomsKind])
|
||||
}
|
||||
case sqlparser.MinusStr:
|
||||
switch getter.NomsKind {
|
||||
case types.UintKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Uint(uint64(left.(types.Int)) - uint64(right.(types.Int)))
|
||||
}
|
||||
case types.IntKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Int(int64(left.(types.Int)) - int64(right.(types.Int)))
|
||||
}
|
||||
case types.FloatKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Float(float64(left.(types.Float)) - float64(right.(types.Float)))
|
||||
}
|
||||
default:
|
||||
return nil, errFmt("Unsupported type for - operation: %v", types.KindToString[getter.NomsKind])
|
||||
}
|
||||
case sqlparser.MultStr:
|
||||
switch getter.NomsKind {
|
||||
case types.UintKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Uint(uint64(left.(types.Int)) * uint64(right.(types.Int)))
|
||||
}
|
||||
case types.IntKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Int(int64(left.(types.Int)) * int64(right.(types.Int)))
|
||||
}
|
||||
case types.FloatKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Float(float64(left.(types.Float)) * float64(right.(types.Float)))
|
||||
}
|
||||
default:
|
||||
return nil, errFmt("Unsupported type for * operation: %v", types.KindToString[getter.NomsKind])
|
||||
}
|
||||
case sqlparser.DivStr:
|
||||
switch getter.NomsKind {
|
||||
case types.UintKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Uint(uint64(left.(types.Int)) / uint64(right.(types.Int)))
|
||||
}
|
||||
case types.IntKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Int(int64(left.(types.Int)) / int64(right.(types.Int)))
|
||||
}
|
||||
case types.FloatKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Float(float64(left.(types.Float)) / float64(right.(types.Float)))
|
||||
}
|
||||
default:
|
||||
return nil, errFmt("Unsupported type for / operation: %v", types.KindToString[getter.NomsKind])
|
||||
}
|
||||
case sqlparser.ModStr:
|
||||
switch getter.NomsKind {
|
||||
case types.UintKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Uint(uint64(left.(types.Int)) % uint64(right.(types.Int)))
|
||||
}
|
||||
case types.IntKind:
|
||||
opFn = func(left, right types.Value) types.Value {
|
||||
return types.Int(int64(left.(types.Int)) % int64(right.(types.Int)))
|
||||
}
|
||||
default:
|
||||
return nil, errFmt("Unsupported type for %% operation: %v", types.KindToString[getter.NomsKind])
|
||||
}
|
||||
default:
|
||||
return nil, errFmt("Unsupported binary operation: %v", e.Operator)
|
||||
}
|
||||
|
||||
getter.Init = func() error {
|
||||
// Already did type checking explicitly
|
||||
return nil
|
||||
}
|
||||
|
||||
getter.Get = func(r row.Row) types.Value {
|
||||
leftVal := leftGetter.Get(r)
|
||||
rightVal := rightGetter.Get(r)
|
||||
if types.IsNull(leftVal) || types.IsNull(rightVal) {
|
||||
return nil
|
||||
}
|
||||
return opFn(leftVal, rightVal)
|
||||
}
|
||||
|
||||
return &getter, nil
|
||||
case *sqlparser.UnaryExpr:
|
||||
return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr))
|
||||
default:
|
||||
return nil, errFmt("Unsupported comparison %v", nodeToString(e))
|
||||
return nil, errFmt("Unsupported type %v", nodeToString(e))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -491,11 +635,11 @@ func resolveColumnsInWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string
|
||||
}
|
||||
cols = append(cols, qc)
|
||||
case *sqlparser.IsExpr:
|
||||
cols, err := resolveColumnsInWhereExpr(expr.Expr, inputSchemas, aliases)
|
||||
isCols, err := resolveColumnsInWhereExpr(expr.Expr, inputSchemas, aliases)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cols = append(cols, cols...)
|
||||
cols = append(cols, isCols...)
|
||||
case *sqlparser.AndExpr:
|
||||
leftCols, err := resolveColumnsInWhereExpr(expr.Left, inputSchemas, aliases)
|
||||
if err != nil {
|
||||
@@ -518,7 +662,24 @@ func resolveColumnsInWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string
|
||||
}
|
||||
cols = append(cols, leftCols...)
|
||||
cols = append(cols, rightCols...)
|
||||
case *sqlparser.BinaryExpr:
|
||||
leftCols, err := resolveColumnsInWhereExpr(expr.Left, inputSchemas, aliases)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rightCols, err := resolveColumnsInWhereExpr(expr.Right, inputSchemas, aliases)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cols = append(cols, leftCols...)
|
||||
cols = append(cols, rightCols...)
|
||||
case *sqlparser.UnaryExpr:
|
||||
unaryCols, err := resolveColumnsInWhereExpr(expr.Expr, inputSchemas, aliases)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cols = append(cols, unaryCols...)
|
||||
case *sqlparser.SQLVal, sqlparser.BoolVal, sqlparser.ValTuple:
|
||||
// No columns, just a SQL literal
|
||||
case *sqlparser.NotExpr:
|
||||
@@ -537,10 +698,6 @@ func resolveColumnsInWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string
|
||||
return nil, errFmt("Subquery expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.ListArg:
|
||||
return nil, errFmt("List expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.BinaryExpr:
|
||||
return nil, errFmt("Binary expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.UnaryExpr:
|
||||
return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.IntervalExpr:
|
||||
return nil, errFmt("Interval expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.CollateExpr:
|
||||
@@ -834,6 +991,13 @@ func createFilterForWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string]
|
||||
return nil, errFmt("Unrecognized is comparison: %v", expr.Operator)
|
||||
}
|
||||
|
||||
// Unary and Binary operators are supported in getGetter(), but not as top-level nodes here.
|
||||
case *sqlparser.BinaryExpr:
|
||||
return nil, errFmt("Binary expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.UnaryExpr:
|
||||
return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr))
|
||||
|
||||
// Full listing of the unsupported types for informative error messages
|
||||
case *sqlparser.NotExpr:
|
||||
return nil, errFmt("Not expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.ParenExpr:
|
||||
@@ -854,10 +1018,6 @@ func createFilterForWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string]
|
||||
return nil, errFmt("Subquery expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.ListArg:
|
||||
return nil, errFmt("List expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.BinaryExpr:
|
||||
return nil, errFmt("Binary expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.UnaryExpr:
|
||||
return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.IntervalExpr:
|
||||
return nil, errFmt("Interval expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.CollateExpr:
|
||||
|
||||
Reference in New Issue
Block a user