mirror of
https://github.com/dolthub/dolt.git
synced 2026-04-27 23:51:59 -05:00
Refactored sql insert, select, and update to pull out common code dealing with parser AST into a util file
This commit is contained in:
@@ -1,5 +1,17 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/attic-labs/noms/go/types"
|
||||
"github.com/google/uuid"
|
||||
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore"
|
||||
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/row"
|
||||
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/xwb1989/sqlparser"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// SQL keyword constants for use in switches and comparisons
|
||||
const (
|
||||
ADD = "add"
|
||||
@@ -209,3 +221,256 @@ const (
|
||||
ZEROFILL = "zerofill"
|
||||
)
|
||||
|
||||
// createFilter creates a filter function from the where clause given, or returns an error if it cannot
|
||||
func createFilterForWhere(whereClause *sqlparser.Where, tableSch schema.Schema) (filterFn, error) {
|
||||
if whereClause != nil && whereClause.Type != sqlparser.WhereStr {
|
||||
return nil, errFmt("Having clause not supported")
|
||||
}
|
||||
|
||||
var filter filterFn
|
||||
if whereClause == nil {
|
||||
filter = func(r row.Row) bool {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
switch expr := whereClause.Expr.(type) {
|
||||
case *sqlparser.ComparisonExpr:
|
||||
left := expr.Left
|
||||
right := expr.Right
|
||||
op := expr.Operator
|
||||
|
||||
colExpr := left
|
||||
valExpr := right
|
||||
|
||||
// Swap the column and value expr as necessary
|
||||
colName, ok := colExpr.(*sqlparser.ColName)
|
||||
if !ok {
|
||||
colExpr = right
|
||||
valExpr = left
|
||||
}
|
||||
|
||||
colName, ok = colExpr.(*sqlparser.ColName)
|
||||
if !ok {
|
||||
return nil, errFmt("Only column names and value literals are supported")
|
||||
}
|
||||
|
||||
colNameStr := colName.Name.String()
|
||||
|
||||
var sqlVal string
|
||||
switch r := valExpr.(type) {
|
||||
case *sqlparser.SQLVal:
|
||||
switch r.Type {
|
||||
// String-like values will print with quotes or other markers by default, so use the naked asci
|
||||
// bytes coerced into a string for them
|
||||
case sqlparser.HexVal, sqlparser.BitVal, sqlparser.StrVal:
|
||||
sqlVal = string(r.Val)
|
||||
default:
|
||||
// Default is to use the string value of the SQL node and hope it works
|
||||
sqlVal = nodeToString(valExpr)
|
||||
}
|
||||
default:
|
||||
// Default is to use the string value of the SQL node and hope it works
|
||||
sqlVal = nodeToString(valExpr)
|
||||
}
|
||||
|
||||
col, ok := tableSch.GetAllCols().GetByName(colNameStr)
|
||||
if !ok {
|
||||
return nil, errFmt("%v is not a known column", colNameStr)
|
||||
}
|
||||
|
||||
tag := col.Tag
|
||||
convFunc := doltcore.GetConvFunc(types.StringKind, col.Kind)
|
||||
comparisonVal, err := convFunc(types.String(string(sqlVal)))
|
||||
|
||||
if err != nil {
|
||||
return nil, errFmt("Couldn't convert column to string: %v", err.Error())
|
||||
}
|
||||
|
||||
// All the operations differ only in their filter logic
|
||||
switch op {
|
||||
case sqlparser.EqualStr:
|
||||
filter = func(r row.Row) bool {
|
||||
rowVal, ok := r.GetColVal(tag)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return comparisonVal.Equals(rowVal)
|
||||
}
|
||||
case sqlparser.LessThanStr:
|
||||
filter = func(r row.Row) bool {
|
||||
rowVal, ok := r.GetColVal(tag)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return rowVal.Less(comparisonVal)
|
||||
}
|
||||
case sqlparser.GreaterThanStr:
|
||||
filter = func(r row.Row) bool {
|
||||
rowVal, ok := r.GetColVal(tag)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return comparisonVal.Less(rowVal)
|
||||
}
|
||||
case sqlparser.LessEqualStr:
|
||||
filter = func(r row.Row) bool {
|
||||
rowVal, ok := r.GetColVal(tag)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return rowVal.Less(comparisonVal) || rowVal.Equals(comparisonVal)
|
||||
}
|
||||
case sqlparser.GreaterEqualStr:
|
||||
filter = func(r row.Row) bool {
|
||||
rowVal, ok := r.GetColVal(tag)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return comparisonVal.Less(rowVal) || comparisonVal.Equals(rowVal)
|
||||
}
|
||||
case sqlparser.NotEqualStr:
|
||||
filter = func(r row.Row) bool {
|
||||
rowVal, ok := r.GetColVal(tag)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return !comparisonVal.Equals(rowVal)
|
||||
}
|
||||
case sqlparser.NullSafeEqualStr:
|
||||
return nil, errFmt("null safe equal operation not supported")
|
||||
case sqlparser.InStr:
|
||||
return nil, errFmt("in keyword not supported")
|
||||
case sqlparser.NotInStr:
|
||||
return nil, errFmt("in keyword not supported")
|
||||
case sqlparser.LikeStr:
|
||||
return nil, errFmt("like keyword not supported")
|
||||
case sqlparser.NotLikeStr:
|
||||
return nil, errFmt("like keyword not supported")
|
||||
case sqlparser.RegexpStr:
|
||||
return nil, errFmt("regular expressions not supported")
|
||||
case sqlparser.NotRegexpStr:
|
||||
return nil, errFmt("regular expressions not supported")
|
||||
case sqlparser.JSONExtractOp:
|
||||
return nil, errFmt("json not supported")
|
||||
case sqlparser.JSONUnquoteExtractOp:
|
||||
return nil, errFmt("json not supported")
|
||||
}
|
||||
case *sqlparser.AndExpr:
|
||||
return nil, errFmt("And expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.OrExpr:
|
||||
return nil, errFmt("Or expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.NotExpr:
|
||||
return nil, errFmt("Not expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.ParenExpr:
|
||||
return nil, errFmt("Parenthetical expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.RangeCond:
|
||||
return nil, errFmt("Range expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.IsExpr:
|
||||
return nil, errFmt("Is expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.ExistsExpr:
|
||||
return nil, errFmt("Exists expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.SQLVal:
|
||||
return nil, errFmt("Literal expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.NullVal:
|
||||
return nil, errFmt("NULL expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.BoolVal:
|
||||
return nil, errFmt("Bool expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.ColName:
|
||||
return nil, errFmt("Column name expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.ValTuple:
|
||||
return nil, errFmt("Tuple expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.Subquery:
|
||||
return nil, errFmt("Subquery expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.ListArg:
|
||||
return nil, errFmt("List expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.BinaryExpr:
|
||||
return nil, errFmt("Binary expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.UnaryExpr:
|
||||
return nil, errFmt("Unary expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.IntervalExpr:
|
||||
return nil, errFmt("Interval expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.CollateExpr:
|
||||
return nil, errFmt("Collate expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.FuncExpr:
|
||||
return nil, errFmt("Function expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.CaseExpr:
|
||||
return nil, errFmt("Case expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.ValuesFuncExpr:
|
||||
return nil, errFmt("Values func expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.ConvertExpr:
|
||||
return nil, errFmt("Conversion expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.SubstrExpr:
|
||||
return nil, errFmt("Substr expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.ConvertUsingExpr:
|
||||
return nil, errFmt("Convert expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.MatchExpr:
|
||||
return nil, errFmt("Match expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.GroupConcatExpr:
|
||||
return nil, errFmt("Group concat expressions not supported: %v", nodeToString(expr))
|
||||
case *sqlparser.Default:
|
||||
return nil, errFmt("Unrecognized expression: %v", nodeToString(expr))
|
||||
default:
|
||||
return nil, errFmt("Unrecognized expression: %v", nodeToString(expr))
|
||||
}
|
||||
}
|
||||
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
// extractNomsValueFromSQLVal extracts a noms value from the given SQLVal, using type info in the dolt column given as
|
||||
// a hint and for type-checking
|
||||
func extractNomsValueFromSQLVal(val *sqlparser.SQLVal, column schema.Column) (types.Value, error) {
|
||||
switch val.Type {
|
||||
// Integer-like values
|
||||
case sqlparser.HexVal, sqlparser.HexNum, sqlparser.IntVal, sqlparser.BitVal:
|
||||
intVal, err := strconv.ParseInt(string(val.Val), 0, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch column.Kind {
|
||||
case types.IntKind:
|
||||
return types.Int(intVal), nil
|
||||
case types.FloatKind:
|
||||
return types.Float(intVal), nil
|
||||
case types.UintKind:
|
||||
return types.Uint(intVal), nil
|
||||
default:
|
||||
return nil, errFmt("Type mismatch: numeric value but non-numeric column: %v", nodeToString(val))
|
||||
}
|
||||
// Float values
|
||||
case sqlparser.FloatVal:
|
||||
floatVal, err := strconv.ParseFloat(string(val.Val), 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch column.Kind {
|
||||
case types.FloatKind:
|
||||
return types.Float(floatVal), nil
|
||||
default:
|
||||
return nil, errFmt("Type mismatch: float value but non-float column: %v", nodeToString(val))
|
||||
}
|
||||
// Strings, which can be coerced into lots of other types
|
||||
case sqlparser.StrVal:
|
||||
strVal := string(val.Val)
|
||||
switch column.Kind {
|
||||
case types.StringKind:
|
||||
return types.String(strVal), nil
|
||||
case types.UUIDKind:
|
||||
id, err := uuid.Parse(strVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return types.UUID(id), nil
|
||||
default:
|
||||
return nil, errFmt("Type mismatch: string value but non-string column: %v", nodeToString(val))
|
||||
}
|
||||
case sqlparser.ValArg:
|
||||
return nil, errFmt("Value args not supported")
|
||||
default:
|
||||
return nil, errFmt("Unrecognized SQLVal type %v", val.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func errFmt(fmtMsg string, args... interface{}) error {
|
||||
return errors.New(fmt.Sprintf(fmtMsg, args...))
|
||||
}
|
||||
Reference in New Issue
Block a user