Refactored sql insert, select, and update to pull out common code dealing with parser AST into a util file

This commit is contained in:
Zach Musgrave
2019-04-12 16:34:56 -07:00
parent e6cbb97cdf
commit 06dec8bd9f
4 changed files with 292 additions and 495 deletions
+265
View File
@@ -1,5 +1,17 @@
package sql
import (
"errors"
"fmt"
"github.com/attic-labs/noms/go/types"
"github.com/google/uuid"
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore"
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/row"
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/schema"
"github.com/xwb1989/sqlparser"
"strconv"
)
// SQL keyword constants for use in switches and comparisons
const (
ADD = "add"
@@ -209,3 +221,256 @@ const (
ZEROFILL = "zerofill"
)
// createFilter creates a filter function from the where clause given, or returns an error if it cannot
func createFilterForWhere(whereClause *sqlparser.Where, tableSch schema.Schema) (filterFn, error) {
if whereClause != nil && whereClause.Type != sqlparser.WhereStr {
return nil, errFmt("Having clause not supported")
}
var filter filterFn
if whereClause == nil {
filter = func(r row.Row) bool {
return true
}
} else {
switch expr := whereClause.Expr.(type) {
case *sqlparser.ComparisonExpr:
left := expr.Left
right := expr.Right
op := expr.Operator
colExpr := left
valExpr := right
// Swap the column and value expr as necessary
colName, ok := colExpr.(*sqlparser.ColName)
if !ok {
colExpr = right
valExpr = left
}
colName, ok = colExpr.(*sqlparser.ColName)
if !ok {
return nil, errFmt("Only column names and value literals are supported")
}
colNameStr := colName.Name.String()
var sqlVal string
switch r := valExpr.(type) {
case *sqlparser.SQLVal:
switch r.Type {
// String-like values will print with quotes or other markers by default, so use the naked asci
// bytes coerced into a string for them
case sqlparser.HexVal, sqlparser.BitVal, sqlparser.StrVal:
sqlVal = string(r.Val)
default:
// Default is to use the string value of the SQL node and hope it works
sqlVal = nodeToString(valExpr)
}
default:
// Default is to use the string value of the SQL node and hope it works
sqlVal = nodeToString(valExpr)
}
col, ok := tableSch.GetAllCols().GetByName(colNameStr)
if !ok {
return nil, errFmt("%v is not a known column", colNameStr)
}
tag := col.Tag
convFunc := doltcore.GetConvFunc(types.StringKind, col.Kind)
comparisonVal, err := convFunc(types.String(string(sqlVal)))
if err != nil {
return nil, errFmt("Couldn't convert column to string: %v", err.Error())
}
// All the operations differ only in their filter logic
switch op {
case sqlparser.EqualStr:
filter = func(r row.Row) bool {
rowVal, ok := r.GetColVal(tag)
if !ok {
return false
}
return comparisonVal.Equals(rowVal)
}
case sqlparser.LessThanStr:
filter = func(r row.Row) bool {
rowVal, ok := r.GetColVal(tag)
if !ok {
return false
}
return rowVal.Less(comparisonVal)
}
case sqlparser.GreaterThanStr:
filter = func(r row.Row) bool {
rowVal, ok := r.GetColVal(tag)
if !ok {
return false
}
return comparisonVal.Less(rowVal)
}
case sqlparser.LessEqualStr:
filter = func(r row.Row) bool {
rowVal, ok := r.GetColVal(tag)
if !ok {
return false
}
return rowVal.Less(comparisonVal) || rowVal.Equals(comparisonVal)
}
case sqlparser.GreaterEqualStr:
filter = func(r row.Row) bool {
rowVal, ok := r.GetColVal(tag)
if !ok {
return false
}
return comparisonVal.Less(rowVal) || comparisonVal.Equals(rowVal)
}
case sqlparser.NotEqualStr:
filter = func(r row.Row) bool {
rowVal, ok := r.GetColVal(tag)
if !ok {
return false
}
return !comparisonVal.Equals(rowVal)
}
case sqlparser.NullSafeEqualStr:
return nil, errFmt("null safe equal operation not supported")
case sqlparser.InStr:
return nil, errFmt("in keyword not supported")
case sqlparser.NotInStr:
return nil, errFmt("in keyword not supported")
case sqlparser.LikeStr:
return nil, errFmt("like keyword not supported")
case sqlparser.NotLikeStr:
return nil, errFmt("like keyword not supported")
case sqlparser.RegexpStr:
return nil, errFmt("regular expressions not supported")
case sqlparser.NotRegexpStr:
return nil, errFmt("regular expressions not supported")
case sqlparser.JSONExtractOp:
return nil, errFmt("json not supported")
case sqlparser.JSONUnquoteExtractOp:
return nil, errFmt("json not supported")
}
case *sqlparser.AndExpr:
return nil, errFmt("And expressions not supported: %v", nodeToString(expr))
case *sqlparser.OrExpr:
return nil, errFmt("Or expressions not supported: %v", nodeToString(expr))
case *sqlparser.NotExpr:
return nil, errFmt("Not expressions not supported: %v", nodeToString(expr))
case *sqlparser.ParenExpr:
return nil, errFmt("Parenthetical expressions not supported: %v", nodeToString(expr))
case *sqlparser.RangeCond:
return nil, errFmt("Range expressions not supported: %v", nodeToString(expr))
case *sqlparser.IsExpr:
return nil, errFmt("Is expressions not supported: %v", nodeToString(expr))
case *sqlparser.ExistsExpr:
return nil, errFmt("Exists expressions not supported: %v", nodeToString(expr))
case *sqlparser.SQLVal:
return nil, errFmt("Literal expressions not supported: %v", nodeToString(expr))
case *sqlparser.NullVal:
return nil, errFmt("NULL expressions not supported: %v", nodeToString(expr))
case *sqlparser.BoolVal:
return nil, errFmt("Bool expressions not supported: %v", nodeToString(expr))
case *sqlparser.ColName:
return nil, errFmt("Column name expressions not supported: %v", nodeToString(expr))
case *sqlparser.ValTuple:
return nil, errFmt("Tuple expressions not supported: %v", nodeToString(expr))
case *sqlparser.Subquery:
return nil, errFmt("Subquery expressions not supported: %v", nodeToString(expr))
case *sqlparser.ListArg:
return nil, errFmt("List expressions not supported: %v", nodeToString(expr))
case *sqlparser.BinaryExpr:
return nil, errFmt("Binary expressions not supported: %v", nodeToString(expr))
case *sqlparser.UnaryExpr:
return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr))
case *sqlparser.IntervalExpr:
return nil, errFmt("Interval expressions not supported: %v", nodeToString(expr))
case *sqlparser.CollateExpr:
return nil, errFmt("Collate expressions not supported: %v", nodeToString(expr))
case *sqlparser.FuncExpr:
return nil, errFmt("Function expressions not supported: %v", nodeToString(expr))
case *sqlparser.CaseExpr:
return nil, errFmt("Case expressions not supported: %v", nodeToString(expr))
case *sqlparser.ValuesFuncExpr:
return nil, errFmt("Values func expressions not supported: %v", nodeToString(expr))
case *sqlparser.ConvertExpr:
return nil, errFmt("Conversion expressions not supported: %v", nodeToString(expr))
case *sqlparser.SubstrExpr:
return nil, errFmt("Substr expressions not supported: %v", nodeToString(expr))
case *sqlparser.ConvertUsingExpr:
return nil, errFmt("Convert expressions not supported: %v", nodeToString(expr))
case *sqlparser.MatchExpr:
return nil, errFmt("Match expressions not supported: %v", nodeToString(expr))
case *sqlparser.GroupConcatExpr:
return nil, errFmt("Group concat expressions not supported: %v", nodeToString(expr))
case *sqlparser.Default:
return nil, errFmt("Unrecognized expression: %v", nodeToString(expr))
default:
return nil, errFmt("Unrecognized expression: %v", nodeToString(expr))
}
}
return filter, nil
}
// extractNomsValueFromSQLVal extracts a noms value from the given SQLVal, using type info in the dolt column given as
// a hint and for type-checking
func extractNomsValueFromSQLVal(val *sqlparser.SQLVal, column schema.Column) (types.Value, error) {
switch val.Type {
// Integer-like values
case sqlparser.HexVal, sqlparser.HexNum, sqlparser.IntVal, sqlparser.BitVal:
intVal, err := strconv.ParseInt(string(val.Val), 0, 64)
if err != nil {
return nil, err
}
switch column.Kind {
case types.IntKind:
return types.Int(intVal), nil
case types.FloatKind:
return types.Float(intVal), nil
case types.UintKind:
return types.Uint(intVal), nil
default:
return nil, errFmt("Type mismatch: numeric value but non-numeric column: %v", nodeToString(val))
}
// Float values
case sqlparser.FloatVal:
floatVal, err := strconv.ParseFloat(string(val.Val), 64)
if err != nil {
return nil, err
}
switch column.Kind {
case types.FloatKind:
return types.Float(floatVal), nil
default:
return nil, errFmt("Type mismatch: float value but non-float column: %v", nodeToString(val))
}
// Strings, which can be coerced into lots of other types
case sqlparser.StrVal:
strVal := string(val.Val)
switch column.Kind {
case types.StringKind:
return types.String(strVal), nil
case types.UUIDKind:
id, err := uuid.Parse(strVal)
if err != nil {
return nil, err
}
return types.UUID(id), nil
default:
return nil, errFmt("Type mismatch: string value but non-string column: %v", nodeToString(val))
}
case sqlparser.ValArg:
return nil, errFmt("Value args not supported")
default:
return nil, errFmt("Unrecognized SQLVal type %v", val.Type)
}
}
func errFmt(fmtMsg string, args... interface{}) error {
return errors.New(fmt.Sprintf(fmtMsg, args...))
}