Implemented and, or clauses

This commit is contained in:
Zach Musgrave
2019-04-29 14:26:50 -07:00
parent dcc2e4cc85
commit 4ce7f2fbfe
2 changed files with 267 additions and 211 deletions
+196 -171
View File
@@ -386,197 +386,222 @@ func getColumnNameString(e *sqlparser.ColName) string {
return b.String()
}
// createFilter creates a filter function from the where clause given, or returns an error if it cannot
func createFilterForWhere(whereClause *sqlparser.Where, inputSchemas map[string]schema.Schema, aliases *Aliases, rss *resultset.ResultSetSchema) (rowFilterFn, error) {
if whereClause != nil && whereClause.Type != sqlparser.WhereStr {
return nil, errFmt("Having clause not supported")
}
var filter rowFilterFn
if whereClause == nil {
filter = func(r row.Row) bool {
return func(r row.Row) bool {
return true
}
}, nil
} else {
switch expr := whereClause.Expr.(type) {
case *sqlparser.ComparisonExpr:
return createFilterForWhereExpr(whereClause.Expr, inputSchemas, aliases, rss)
}
}
leftGetter, err := getComparisonValueGetter(expr.Left, inputSchemas, aliases, rss)
if err != nil {
return nil, err
}
rightGetter, err := getComparisonValueGetter(expr.Right, inputSchemas, aliases, rss)
if err != nil {
return nil, err
}
// createFilter creates a filter function from the where clause given, or returns an error if it cannot
func createFilterForWhereExpr(whereExpr sqlparser.Expr, inputSchemas map[string]schema.Schema, aliases *Aliases, rss *resultset.ResultSetSchema) (rowFilterFn, error) {
// Fill in noms kinds for SQL_VAL fields if possible
if leftGetter.Kind == SQL_VAL && rightGetter.Kind != SQL_VAL {
leftGetter.NomsKind = rightGetter.NomsKind
}
if rightGetter.Kind == SQL_VAL && leftGetter.Kind != SQL_VAL {
rightGetter.NomsKind = leftGetter.NomsKind
}
var filter rowFilterFn
switch expr := whereExpr.(type) {
case *sqlparser.ComparisonExpr:
// Fill in comparison kinds before doing error checking
rightGetter.CmpKind, leftGetter.CmpKind = leftGetter.NomsKind, rightGetter.NomsKind
leftGetter, err := getComparisonValueGetter(expr.Left, inputSchemas, aliases, rss)
if err != nil {
return nil, err
}
rightGetter, err := getComparisonValueGetter(expr.Right, inputSchemas, aliases, rss)
if err != nil {
return nil, err
}
// Initialize the getters, mostly so that literal vals can do type error checking and cache results
if err := leftGetter.Init(); err != nil {
return nil, err
}
if err := rightGetter.Init(); err != nil {
return nil, err
}
// Fill in noms kinds for SQL_VAL fields if possible
if leftGetter.Kind == SQL_VAL && rightGetter.Kind != SQL_VAL {
leftGetter.NomsKind = rightGetter.NomsKind
}
if rightGetter.Kind == SQL_VAL && leftGetter.Kind != SQL_VAL {
rightGetter.NomsKind = leftGetter.NomsKind
}
// All the operations differ only in their filter logic
switch expr.Operator {
case sqlparser.EqualStr:
filter = func(r row.Row) bool {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return leftVal.Equals(rightVal)
}
case sqlparser.LessThanStr:
filter = func(r row.Row) bool {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return leftVal.Less(rightVal)
}
case sqlparser.GreaterThanStr:
filter = func(r row.Row) bool {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return rightVal.Less(leftVal)
}
case sqlparser.LessEqualStr:
filter = func(r row.Row) bool {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return leftVal.Less(rightVal) || leftVal.Equals(rightVal)
}
case sqlparser.GreaterEqualStr:
filter = func(r row.Row) bool {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return rightVal.Less(leftVal) || rightVal.Equals(leftVal)
}
case sqlparser.NotEqualStr:
filter = func(r row.Row) bool {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return !leftVal.Equals(rightVal)
}
case sqlparser.NullSafeEqualStr:
return nil, errFmt("null safe equal operation not supported")
case sqlparser.InStr:
return nil, errFmt("in keyword not supported")
case sqlparser.NotInStr:
return nil, errFmt("in keyword not supported")
case sqlparser.LikeStr:
return nil, errFmt("like keyword not supported")
case sqlparser.NotLikeStr:
return nil, errFmt("like keyword not supported")
case sqlparser.RegexpStr:
return nil, errFmt("regular expressions not supported")
case sqlparser.NotRegexpStr:
return nil, errFmt("regular expressions not supported")
case sqlparser.JSONExtractOp:
return nil, errFmt("json not supported")
case sqlparser.JSONUnquoteExtractOp:
return nil, errFmt("json not supported")
}
case *sqlparser.ColName:
getter, err := getComparisonValueGetter(expr, inputSchemas, aliases, rss)
if err != nil {
return nil, err
}
// Fill in comparison kinds before doing error checking
rightGetter.CmpKind, leftGetter.CmpKind = leftGetter.NomsKind, rightGetter.NomsKind
if getter.NomsKind != types.BoolKind {
return nil, errFmt("Type mismatch: cannot use column %v as boolean expression", nodeToString(expr))
}
// Initialize the getters, mostly so that literal vals can do type error checking and cache results
if err := leftGetter.Init(); err != nil {
return nil, err
}
if err := rightGetter.Init(); err != nil {
return nil, err
}
// All the operations differ only in their filter logic
switch expr.Operator {
case sqlparser.EqualStr:
filter = func(r row.Row) bool {
colVal := getter.Get(r)
if types.IsNull(colVal) {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return colVal.Equals(types.Bool(true))
return leftVal.Equals(rightVal)
}
case *sqlparser.AndExpr:
return nil, errFmt("And expressions not supported: %v", nodeToString(expr))
case *sqlparser.OrExpr:
return nil, errFmt("Or expressions not supported: %v", nodeToString(expr))
case *sqlparser.NotExpr:
return nil, errFmt("Not expressions not supported: %v", nodeToString(expr))
case *sqlparser.ParenExpr:
return nil, errFmt("Parenthetical expressions not supported: %v", nodeToString(expr))
case *sqlparser.RangeCond:
return nil, errFmt("Range expressions not supported: %v", nodeToString(expr))
case *sqlparser.IsExpr:
return nil, errFmt("Is expressions not supported: %v", nodeToString(expr))
case *sqlparser.ExistsExpr:
return nil, errFmt("Exists expressions not supported: %v", nodeToString(expr))
case *sqlparser.SQLVal:
return nil, errFmt("Literal expressions not supported: %v", nodeToString(expr))
case *sqlparser.NullVal:
return nil, errFmt("NULL expressions not supported: %v", nodeToString(expr))
case *sqlparser.BoolVal:
return nil, errFmt("Bool expressions not supported: %v", nodeToString(expr))
case *sqlparser.ValTuple:
return nil, errFmt("Tuple expressions not supported: %v", nodeToString(expr))
case *sqlparser.Subquery:
return nil, errFmt("Subquery expressions not supported: %v", nodeToString(expr))
case *sqlparser.ListArg:
return nil, errFmt("List expressions not supported: %v", nodeToString(expr))
case *sqlparser.BinaryExpr:
return nil, errFmt("Binary expressions not supported: %v", nodeToString(expr))
case *sqlparser.UnaryExpr:
return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr))
case *sqlparser.IntervalExpr:
return nil, errFmt("Interval expressions not supported: %v", nodeToString(expr))
case *sqlparser.CollateExpr:
return nil, errFmt("Collate expressions not supported: %v", nodeToString(expr))
case *sqlparser.FuncExpr:
return nil, errFmt("Function expressions not supported: %v", nodeToString(expr))
case *sqlparser.CaseExpr:
return nil, errFmt("Case expressions not supported: %v", nodeToString(expr))
case *sqlparser.ValuesFuncExpr:
return nil, errFmt("Values func expressions not supported: %v", nodeToString(expr))
case *sqlparser.ConvertExpr:
return nil, errFmt("Conversion expressions not supported: %v", nodeToString(expr))
case *sqlparser.SubstrExpr:
return nil, errFmt("Substr expressions not supported: %v", nodeToString(expr))
case *sqlparser.ConvertUsingExpr:
return nil, errFmt("Convert expressions not supported: %v", nodeToString(expr))
case *sqlparser.MatchExpr:
return nil, errFmt("Match expressions not supported: %v", nodeToString(expr))
case *sqlparser.GroupConcatExpr:
return nil, errFmt("Group concat expressions not supported: %v", nodeToString(expr))
case *sqlparser.Default:
return nil, errFmt("Unrecognized expression: %v", nodeToString(expr))
default:
return nil, errFmt("Unrecognized expression: %v", nodeToString(expr))
case sqlparser.LessThanStr:
filter = func(r row.Row) bool {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return leftVal.Less(rightVal)
}
case sqlparser.GreaterThanStr:
filter = func(r row.Row) bool {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return rightVal.Less(leftVal)
}
case sqlparser.LessEqualStr:
filter = func(r row.Row) bool {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return leftVal.Less(rightVal) || leftVal.Equals(rightVal)
}
case sqlparser.GreaterEqualStr:
filter = func(r row.Row) bool {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return rightVal.Less(leftVal) || rightVal.Equals(leftVal)
}
case sqlparser.NotEqualStr:
filter = func(r row.Row) bool {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return !leftVal.Equals(rightVal)
}
case sqlparser.NullSafeEqualStr:
return nil, errFmt("null safe equal operation not supported")
case sqlparser.InStr:
return nil, errFmt("in keyword not supported")
case sqlparser.NotInStr:
return nil, errFmt("in keyword not supported")
case sqlparser.LikeStr:
return nil, errFmt("like keyword not supported")
case sqlparser.NotLikeStr:
return nil, errFmt("like keyword not supported")
case sqlparser.RegexpStr:
return nil, errFmt("regular expressions not supported")
case sqlparser.NotRegexpStr:
return nil, errFmt("regular expressions not supported")
case sqlparser.JSONExtractOp:
return nil, errFmt("json not supported")
case sqlparser.JSONUnquoteExtractOp:
return nil, errFmt("json not supported")
}
case *sqlparser.ColName:
getter, err := getComparisonValueGetter(expr, inputSchemas, aliases, rss)
if err != nil {
return nil, err
}
if getter.NomsKind != types.BoolKind {
return nil, errFmt("Type mismatch: cannot use column %v as boolean expression", nodeToString(expr))
}
filter = func(r row.Row) bool {
colVal := getter.Get(r)
if types.IsNull(colVal) {
return false
}
return colVal.Equals(types.Bool(true))
}
case *sqlparser.AndExpr:
var leftFilter, rightFilter rowFilterFn
var err error
if leftFilter, err = createFilterForWhereExpr(expr.Left, inputSchemas, aliases, rss); err != nil {
return nil, err
}
if rightFilter, err = createFilterForWhereExpr(expr.Right, inputSchemas, aliases, rss); err != nil {
return nil, err
}
filter = func(r row.Row) (matchesFilter bool) {
return leftFilter(r) && rightFilter(r)
}
case *sqlparser.OrExpr:
var leftFilter, rightFilter rowFilterFn
var err error
if leftFilter, err = createFilterForWhereExpr(expr.Left, inputSchemas, aliases, rss); err != nil {
return nil, err
}
if rightFilter, err = createFilterForWhereExpr(expr.Right, inputSchemas, aliases, rss); err != nil {
return nil, err
}
filter = func(r row.Row) (matchesFilter bool) {
return leftFilter(r) || rightFilter(r)
}
case *sqlparser.NotExpr:
return nil, errFmt("Not expressions not supported: %v", nodeToString(expr))
case *sqlparser.ParenExpr:
return nil, errFmt("Parenthetical expressions not supported: %v", nodeToString(expr))
case *sqlparser.RangeCond:
return nil, errFmt("Range expressions not supported: %v", nodeToString(expr))
case *sqlparser.IsExpr:
return nil, errFmt("Is expressions not supported: %v", nodeToString(expr))
case *sqlparser.ExistsExpr:
return nil, errFmt("Exists expressions not supported: %v", nodeToString(expr))
case *sqlparser.SQLVal:
return nil, errFmt("Literal expressions not supported: %v", nodeToString(expr))
case *sqlparser.NullVal:
return nil, errFmt("NULL expressions not supported: %v", nodeToString(expr))
case *sqlparser.BoolVal:
return nil, errFmt("Bool expressions not supported: %v", nodeToString(expr))
case *sqlparser.ValTuple:
return nil, errFmt("Tuple expressions not supported: %v", nodeToString(expr))
case *sqlparser.Subquery:
return nil, errFmt("Subquery expressions not supported: %v", nodeToString(expr))
case *sqlparser.ListArg:
return nil, errFmt("List expressions not supported: %v", nodeToString(expr))
case *sqlparser.BinaryExpr:
return nil, errFmt("Binary expressions not supported: %v", nodeToString(expr))
case *sqlparser.UnaryExpr:
return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr))
case *sqlparser.IntervalExpr:
return nil, errFmt("Interval expressions not supported: %v", nodeToString(expr))
case *sqlparser.CollateExpr:
return nil, errFmt("Collate expressions not supported: %v", nodeToString(expr))
case *sqlparser.FuncExpr:
return nil, errFmt("Function expressions not supported: %v", nodeToString(expr))
case *sqlparser.CaseExpr:
return nil, errFmt("Case expressions not supported: %v", nodeToString(expr))
case *sqlparser.ValuesFuncExpr:
return nil, errFmt("Values func expressions not supported: %v", nodeToString(expr))
case *sqlparser.ConvertExpr:
return nil, errFmt("Conversion expressions not supported: %v", nodeToString(expr))
case *sqlparser.SubstrExpr:
return nil, errFmt("Substr expressions not supported: %v", nodeToString(expr))
case *sqlparser.ConvertUsingExpr:
return nil, errFmt("Convert expressions not supported: %v", nodeToString(expr))
case *sqlparser.MatchExpr:
return nil, errFmt("Match expressions not supported: %v", nodeToString(expr))
case *sqlparser.GroupConcatExpr:
return nil, errFmt("Group concat expressions not supported: %v", nodeToString(expr))
case *sqlparser.Default:
return nil, errFmt("Unrecognized expression: %v", nodeToString(expr))
default:
return nil, errFmt("Unrecognized expression: %v", nodeToString(expr))
}
return filter, nil