mirror of
https://github.com/dolthub/dolt.git
synced 2026-02-14 10:09:09 -06:00
Rewrote where clause handling in preparation for making it work with joins
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/table/typed/noms"
|
||||
"github.com/xwb1989/sqlparser"
|
||||
"io"
|
||||
@@ -47,7 +48,7 @@ func ExecuteDelete(db *doltdb.DoltDB, root *doltdb.RootValue, s *sqlparser.Delet
|
||||
tableSch := table.GetSchema()
|
||||
|
||||
// TODO: support aliases
|
||||
filter, err := createFilterForWhere(s.Where, tableSch, NewAliases())
|
||||
filter, err := createFilterForWhere(s.Where, map[string]schema.Schema{tableName: tableSch}, NewAliases())
|
||||
if err != nil {
|
||||
return errDelete(err.Error())
|
||||
}
|
||||
|
||||
@@ -157,7 +157,7 @@ func makeRow(columns []schema.Column, tableSch schema.Schema, tuple sqlparser.Va
|
||||
column := columns[i]
|
||||
switch val := expr.(type) {
|
||||
case *sqlparser.SQLVal:
|
||||
nomsVal, err := extractNomsValueFromSQLVal(val, column)
|
||||
nomsVal, err := extractNomsValueFromSQLVal(val, column.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -216,7 +216,7 @@ func processSelectedColumns(root *doltdb.RootValue, selectStmt *SelectStatement,
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
tableName, tableSch, err = findSchemaForColumn(colName, selectStmt)
|
||||
tableName, tableSch, err = findSchemaForColumn(colName, selectStmt.inputSchemas)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -244,31 +244,6 @@ func processSelectedColumns(root *doltdb.RootValue, selectStmt *SelectStatement,
|
||||
return nil
|
||||
}
|
||||
|
||||
// Finds the schema that contains the column name given among the tables given. Returns an error if no schema contains
|
||||
// such a column name, or if multiple do. This method is only used for naked column names, not qualified ones. Assumes
|
||||
// that table names have already been verified to exist.
|
||||
func findSchemaForColumn(colName string, statement *SelectStatement) (string, schema.Schema, error) {
|
||||
schemas := statement.inputSchemas
|
||||
|
||||
var colSchema schema.Schema
|
||||
var tableName string
|
||||
for tbl, sch := range schemas {
|
||||
if _, ok := sch.GetAllCols().GetByName(colName); ok {
|
||||
if colSchema != nil {
|
||||
return "", nil, errFmt("Ambiguous column: %v", colName)
|
||||
}
|
||||
colSchema = sch
|
||||
tableName = tbl
|
||||
}
|
||||
}
|
||||
|
||||
if colSchema == nil {
|
||||
return "", nil, errFmt("Unknown column '%v'", colName)
|
||||
}
|
||||
|
||||
return tableName, colSchema, nil
|
||||
}
|
||||
|
||||
// Gets the schema for the table name given. Will cause a panic if the table doesn't exist.
|
||||
func mustGetSchema(root *doltdb.RootValue, tableName string) schema.Schema {
|
||||
tbl, _:= root.GetTable(tableName)
|
||||
@@ -278,14 +253,7 @@ func mustGetSchema(root *doltdb.RootValue, tableName string) schema.Schema {
|
||||
// Processes the where clause by applying an appropriate filter fn to the SelectStatement given. Returns an error if the
|
||||
// where clause can't be processed.
|
||||
func processWhereClause(selectStmt *SelectStatement, s *sqlparser.Select) error {
|
||||
// TODO: make work for more than 1 table
|
||||
var tableSch schema.Schema
|
||||
for _, sch := range selectStmt.inputSchemas {
|
||||
tableSch = sch
|
||||
break
|
||||
}
|
||||
|
||||
filter, err := createFilterForWhere(s.Where, tableSch, selectStmt.aliases)
|
||||
filter, err := createFilterForWhere(s.Where, selectStmt.inputSchemas, selectStmt.aliases)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -340,7 +308,8 @@ func createPipeline(root *doltdb.RootValue, statement *SelectStatement) (*pipeli
|
||||
}
|
||||
|
||||
results := make([]resultset.TableResult, 0)
|
||||
for tableName, result := range pipelines {
|
||||
for _, tableName := range statement.inputTables {
|
||||
result := pipelines[tableName]
|
||||
if err := result.p.Wait(); err != nil || result.err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -402,7 +371,6 @@ func createOutputSchemaMappingTransform(tableSch schema.Schema, rss *resultset.R
|
||||
// Returns a ResultSetSchema for the given select statement, which contains a target schema and mappings to get there
|
||||
// from the individual table schemas.
|
||||
func createResultSetSchema(statement *SelectStatement) error {
|
||||
|
||||
// Iterate over the columns twice: first to get an ordered list to use to create an output schema with
|
||||
cols := make([]schema.Column, 0, len(statement.selectedCols))
|
||||
for _, selectedCol := range statement.selectedCols {
|
||||
|
||||
@@ -309,37 +309,6 @@ func TestExecuteSelect(t *testing.T) {
|
||||
"l", types.StringKind, "m", types.BoolKind, "a", types.IntKind, "r", types.FloatKind,
|
||||
"u", types.UUIDKind, "n", types.UintKind),
|
||||
},
|
||||
{
|
||||
name: "Test selecting from multiple tables",
|
||||
query: `select * from people, episodes`,
|
||||
expectedRows: rs(
|
||||
concatRows(peopleTestSchema, homer, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, homer, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, homer, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, homer, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, marge, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, marge, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, marge, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, marge, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, bart, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, bart, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, bart, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, bart, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, moe, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, moe, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, moe, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, moe, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, barney, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, barney, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, barney, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, barney, episodesTestSchema, ep4),
|
||||
),
|
||||
expectedSchema: compressSchemas(peopleTestSchema, episodesTestSchema),
|
||||
},
|
||||
{
|
||||
name: "Test select *, not equals",
|
||||
query: "select * from people where age <> 40",
|
||||
@@ -401,6 +370,103 @@ func TestExecuteSelect(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoins(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
expectedRows []row.Row
|
||||
expectedSchema schema.Schema
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "Test full cross product",
|
||||
query: `select * from people, episodes`,
|
||||
expectedRows: rs(
|
||||
concatRows(peopleTestSchema, homer, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, homer, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, homer, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, homer, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, marge, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, marge, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, marge, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, marge, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, bart, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, bart, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, bart, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, bart, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, moe, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, moe, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, moe, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, moe, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, barney, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, barney, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, barney, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, barney, episodesTestSchema, ep4),
|
||||
),
|
||||
expectedSchema: compressSchemas(peopleTestSchema, episodesTestSchema),
|
||||
},
|
||||
{
|
||||
name: "Test natural join with where clause",
|
||||
query: `select * from people p, appearances a where a.character_id = p.id`,
|
||||
expectedRows: rs(
|
||||
concatRows(peopleTestSchema, homer, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, homer, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, homer, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, homer, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, marge, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, marge, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, marge, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, marge, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, bart, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, bart, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, bart, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, bart, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, moe, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, moe, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, moe, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, moe, episodesTestSchema, ep4),
|
||||
concatRows(peopleTestSchema, barney, episodesTestSchema, ep1),
|
||||
concatRows(peopleTestSchema, barney, episodesTestSchema, ep2),
|
||||
concatRows(peopleTestSchema, barney, episodesTestSchema, ep3),
|
||||
concatRows(peopleTestSchema, barney, episodesTestSchema, ep4),
|
||||
),
|
||||
expectedSchema: compressSchemas(peopleTestSchema, episodesTestSchema),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
createTestDatabase(dEnv, t)
|
||||
root, _ := dEnv.WorkingRoot()
|
||||
|
||||
sqlStatement, _ := sqlparser.Parse(tt.query)
|
||||
s := sqlStatement.(*sqlparser.Select)
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.expectedRows != nil && tt.expectedSchema == nil {
|
||||
assert.Fail(t, "Incorrect test setup: schema must both be provided when rows are")
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
rows, sch, err := ExecuteSelect(root, s)
|
||||
if err != nil {
|
||||
assert.True(t, tt.expectedErr, err.Error())
|
||||
} else {
|
||||
assert.False(t, tt.expectedErr, "unexpected error")
|
||||
}
|
||||
assert.Equal(t, tt.expectedRows, rows)
|
||||
assert.Equal(t, tt.expectedSchema, sch)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the logical concatenation of the schemas and rows given, rewriting all tag numbers to begin at zero. The row
|
||||
// returned will have a new schema identical to the result of compressSchema.
|
||||
func concatRows(schemasAndRows ...interface{}) row.Row {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/attic-labs/noms/go/types"
|
||||
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/row"
|
||||
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/table/typed/noms"
|
||||
"github.com/xwb1989/sqlparser"
|
||||
"io"
|
||||
@@ -66,7 +67,7 @@ func ExecuteUpdate(db *doltdb.DoltDB, root *doltdb.RootValue, s *sqlparser.Updat
|
||||
|
||||
switch val := update.Expr.(type) {
|
||||
case *sqlparser.SQLVal:
|
||||
nomsVal, err := extractNomsValueFromSQLVal(val, column)
|
||||
nomsVal, err := extractNomsValueFromSQLVal(val, column.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -133,7 +134,7 @@ func ExecuteUpdate(db *doltdb.DoltDB, root *doltdb.RootValue, s *sqlparser.Updat
|
||||
}
|
||||
|
||||
// TODO: support aliases in update
|
||||
filter, err := createFilterForWhere(s.Where, tableSch, NewAliases())
|
||||
filter, err := createFilterForWhere(s.Where, map[string]schema.Schema{tableName: tableSch}, NewAliases())
|
||||
if err != nil {
|
||||
return errUpdate(err.Error())
|
||||
}
|
||||
|
||||
@@ -232,8 +232,118 @@ func nodeToString(node sqlparser.SQLNode) string {
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
// Finds the schema that contains the column name given among the tables given. Returns an error if no schema contains
|
||||
// such a column name, or if multiple do. This method is only used for naked column names, not qualified ones. Assumes
|
||||
// that table names have already been verified to exist.
|
||||
func findSchemaForColumn(colName string, schemas map[string]schema.Schema) (string, schema.Schema, error) {
|
||||
var colSchema schema.Schema
|
||||
var tableName string
|
||||
for tbl, sch := range schemas {
|
||||
if _, ok := sch.GetAllCols().GetByName(colName); ok {
|
||||
if colSchema != nil {
|
||||
return "", nil, errFmt("Ambiguous column: %v", colName)
|
||||
}
|
||||
colSchema = sch
|
||||
tableName = tbl
|
||||
}
|
||||
}
|
||||
|
||||
if colSchema == nil {
|
||||
return "", nil, errFmt("Unknown column: '%v'", colName)
|
||||
}
|
||||
|
||||
return tableName, colSchema, nil
|
||||
}
|
||||
|
||||
type valGetterKind uint8
|
||||
const (
|
||||
COLNAME valGetterKind = iota
|
||||
SQL_VAL
|
||||
BOOL_VAL
|
||||
)
|
||||
|
||||
// valGetter is a convenience object used for comparing the right and left side of an expression
|
||||
type valGetter struct {
|
||||
// The kind of this val getter
|
||||
Kind valGetterKind
|
||||
// The value type returned by this getter
|
||||
NomsKind types.NomsKind
|
||||
// The kind of the value that this getter's result will be compared against, filled in elsewhere
|
||||
CmpKind types.NomsKind
|
||||
// Init() performs error checking and does any labor-saving pre-calculation that doens't need to be done for every
|
||||
// row in the result set
|
||||
Init func() error
|
||||
// Get() returns the value for this getter for the row given
|
||||
Get func(r row.Row) types.Value
|
||||
// CachedVal is a handy place to put a pre-computed value for getters that deal with constants or literals
|
||||
CachedVal types.Value
|
||||
}
|
||||
|
||||
// Returns a comparison value getter for the expression given, which could be a column value or a literal
|
||||
func getComparisonValueGetter(expr sqlparser.Expr, inputSchemas map[string]schema.Schema, aliases *Aliases) (*valGetter, error) {
|
||||
switch e := expr.(type) {
|
||||
case *sqlparser.ColName:
|
||||
colNameStr := e.Name.String()
|
||||
if col, ok := aliases.ColumnsByAlias[colNameStr]; ok {
|
||||
colNameStr = col.ColumnName
|
||||
}
|
||||
|
||||
_, tableSch, err := findSchemaForColumn(colNameStr, inputSchemas)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
column, _ := tableSch.GetAllCols().GetByName(colNameStr)
|
||||
getter := valGetter{Kind: COLNAME, NomsKind: column.Kind}
|
||||
|
||||
getter.Init = func() error {
|
||||
return nil
|
||||
}
|
||||
getter.Get = func(r row.Row) types.Value {
|
||||
value, _ := r.GetColVal(column.Tag)
|
||||
return value
|
||||
}
|
||||
|
||||
return &getter, nil
|
||||
case *sqlparser.SQLVal:
|
||||
getter := valGetter{Kind: SQL_VAL}
|
||||
|
||||
getter.Init = func() error {
|
||||
val, err := extractNomsValueFromSQLVal(e, getter.CmpKind)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
getter.CachedVal = val
|
||||
return nil
|
||||
}
|
||||
getter.Get = func(r row.Row) types.Value {
|
||||
return getter.CachedVal
|
||||
}
|
||||
|
||||
return &getter, nil
|
||||
case sqlparser.BoolVal:
|
||||
val := types.Bool(bool(e))
|
||||
getter := valGetter{Kind: BOOL_VAL, NomsKind: types.BoolKind}
|
||||
|
||||
getter.Init = func() error {
|
||||
switch getter.CmpKind {
|
||||
case types.BoolKind:
|
||||
return nil
|
||||
default:
|
||||
return errFmt("Type mismatch: boolean value but non-numeric column: %v", nodeToString(e))
|
||||
}
|
||||
}
|
||||
getter.Get = func(r row.Row) types.Value {
|
||||
return val
|
||||
}
|
||||
|
||||
return &getter, nil
|
||||
default:
|
||||
return nil, errFmt("Unsupported comparison %v", nodeToString(e))
|
||||
}
|
||||
}
|
||||
|
||||
// createFilter creates a filter function from the where clause given, or returns an error if it cannot
|
||||
func createFilterForWhere(whereClause *sqlparser.Where, tableSch schema.Schema, aliases *Aliases) (rowFilterFn, error) {
|
||||
func createFilterForWhere(whereClause *sqlparser.Where, inputSchemas map[string]schema.Schema, aliases *Aliases) (rowFilterFn, error) {
|
||||
if whereClause != nil && whereClause.Type != sqlparser.WhereStr {
|
||||
return nil, errFmt("Having clause not supported")
|
||||
}
|
||||
@@ -246,133 +356,90 @@ func createFilterForWhere(whereClause *sqlparser.Where, tableSch schema.Schema,
|
||||
} else {
|
||||
switch expr := whereClause.Expr.(type) {
|
||||
case *sqlparser.ComparisonExpr:
|
||||
left := expr.Left
|
||||
right := expr.Right
|
||||
op := expr.Operator
|
||||
|
||||
colValOnLeft := true
|
||||
colExpr := left
|
||||
valExpr := right
|
||||
|
||||
// Swap the column and value expr as necessary
|
||||
colName, ok := colExpr.(*sqlparser.ColName)
|
||||
if !ok {
|
||||
colValOnLeft = false
|
||||
colExpr = right
|
||||
valExpr = left
|
||||
leftGetter, err := getComparisonValueGetter(expr.Left, inputSchemas, aliases)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rightGetter, err := getComparisonValueGetter(expr.Right, inputSchemas, aliases)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
colName, ok = colExpr.(*sqlparser.ColName)
|
||||
if !ok {
|
||||
return nil, errFmt("Only column names and value literals are supported")
|
||||
// Fill in noms kinds for SQL_VAL fields if possible
|
||||
if leftGetter.Kind == SQL_VAL && rightGetter.Kind != SQL_VAL {
|
||||
leftGetter.NomsKind = rightGetter.NomsKind
|
||||
}
|
||||
if rightGetter.Kind == SQL_VAL && leftGetter.Kind != SQL_VAL {
|
||||
rightGetter.NomsKind = leftGetter.NomsKind
|
||||
}
|
||||
|
||||
colNameStr := colName.Name.String()
|
||||
if col, ok := aliases.ColumnsByAlias[colNameStr]; ok {
|
||||
colNameStr = col.ColumnName
|
||||
}
|
||||
column, ok := tableSch.GetAllCols().GetByName(colNameStr)
|
||||
if !ok {
|
||||
return nil, errFmt("Unknown column: '%v'", colNameStr)
|
||||
}
|
||||
// Fill in comparison kinds before doing error checking
|
||||
rightGetter.CmpKind, leftGetter.CmpKind = leftGetter.NomsKind, rightGetter.NomsKind
|
||||
|
||||
var comparisonVal types.Value
|
||||
switch val := valExpr.(type) {
|
||||
case *sqlparser.SQLVal:
|
||||
var err error
|
||||
comparisonVal, err = extractNomsValueFromSQLVal(val, column)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case sqlparser.BoolVal:
|
||||
switch column.Kind {
|
||||
case types.BoolKind:
|
||||
comparisonVal = types.Bool(bool(val))
|
||||
default:
|
||||
return nil, errFmt("Type mismatch: boolean value but non-numeric column: %v", nodeToString(val))
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, errFmt("Only SQL literal values are supported in comparisons: %v", nodeToString(val))
|
||||
// Initialize the getters, mostly so that literal vals can do type error checking and cache results
|
||||
if err := leftGetter.Init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rightGetter.Init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// All the operations differ only in their filter logic
|
||||
switch op {
|
||||
switch expr.Operator {
|
||||
case sqlparser.EqualStr:
|
||||
filter = func(r row.Row) bool {
|
||||
colVal, ok := r.GetColVal(column.Tag)
|
||||
if !ok {
|
||||
leftVal := leftGetter.Get(r)
|
||||
rightVal := rightGetter.Get(r)
|
||||
if types.IsNull(leftVal) || types.IsNull(rightVal) {
|
||||
return false
|
||||
}
|
||||
return comparisonVal.Equals(colVal)
|
||||
return leftVal.Equals(rightVal)
|
||||
}
|
||||
case sqlparser.LessThanStr:
|
||||
filter = func(r row.Row) bool {
|
||||
colVal, ok := r.GetColVal(column.Tag)
|
||||
if !ok {
|
||||
leftVal := leftGetter.Get(r)
|
||||
rightVal := rightGetter.Get(r)
|
||||
if types.IsNull(leftVal) || types.IsNull(rightVal) {
|
||||
return false
|
||||
}
|
||||
|
||||
leftVal := colVal
|
||||
rightVal := comparisonVal
|
||||
if !colValOnLeft {
|
||||
swap(&leftVal, &rightVal)
|
||||
}
|
||||
|
||||
return leftVal.Less(rightVal)
|
||||
}
|
||||
case sqlparser.GreaterThanStr:
|
||||
filter = func(r row.Row) bool {
|
||||
colVal, ok := r.GetColVal(column.Tag)
|
||||
if !ok {
|
||||
leftVal := leftGetter.Get(r)
|
||||
rightVal := rightGetter.Get(r)
|
||||
if types.IsNull(leftVal) || types.IsNull(rightVal) {
|
||||
return false
|
||||
}
|
||||
|
||||
leftVal := colVal
|
||||
rightVal := comparisonVal
|
||||
if !colValOnLeft {
|
||||
swap(&leftVal, &rightVal)
|
||||
}
|
||||
|
||||
return rightVal.Less(leftVal)
|
||||
}
|
||||
case sqlparser.LessEqualStr:
|
||||
filter = func(r row.Row) bool {
|
||||
colVal, ok := r.GetColVal(column.Tag)
|
||||
if !ok {
|
||||
leftVal := leftGetter.Get(r)
|
||||
rightVal := rightGetter.Get(r)
|
||||
if types.IsNull(leftVal) || types.IsNull(rightVal) {
|
||||
return false
|
||||
}
|
||||
|
||||
leftVal := colVal
|
||||
rightVal := comparisonVal
|
||||
if !colValOnLeft {
|
||||
swap(&leftVal, &rightVal)
|
||||
}
|
||||
|
||||
return leftVal.Less(rightVal) || leftVal.Equals(rightVal)
|
||||
}
|
||||
case sqlparser.GreaterEqualStr:
|
||||
filter = func(r row.Row) bool {
|
||||
colVal, ok := r.GetColVal(column.Tag)
|
||||
if !ok {
|
||||
leftVal := leftGetter.Get(r)
|
||||
rightVal := rightGetter.Get(r)
|
||||
if types.IsNull(leftVal) || types.IsNull(rightVal) {
|
||||
return false
|
||||
}
|
||||
|
||||
leftVal := colVal
|
||||
rightVal := comparisonVal
|
||||
if !colValOnLeft {
|
||||
swap(&leftVal, &rightVal)
|
||||
}
|
||||
|
||||
return rightVal.Less(leftVal) || rightVal.Equals(leftVal)
|
||||
}
|
||||
case sqlparser.NotEqualStr:
|
||||
filter = func(r row.Row) bool {
|
||||
colVal, ok := r.GetColVal(column.Tag)
|
||||
if !ok {
|
||||
leftVal := leftGetter.Get(r)
|
||||
rightVal := rightGetter.Get(r)
|
||||
if types.IsNull(leftVal) || types.IsNull(rightVal) {
|
||||
return false
|
||||
}
|
||||
return !comparisonVal.Equals(colVal)
|
||||
return !leftVal.Equals(rightVal)
|
||||
}
|
||||
case sqlparser.NullSafeEqualStr:
|
||||
return nil, errFmt("null safe equal operation not supported")
|
||||
@@ -398,10 +465,12 @@ func createFilterForWhere(whereClause *sqlparser.Where, tableSch schema.Schema,
|
||||
if col, ok := aliases.ColumnsByAlias[colNameStr]; ok {
|
||||
colNameStr = col.ColumnName
|
||||
}
|
||||
column, ok := tableSch.GetAllCols().GetByName(colNameStr)
|
||||
if !ok {
|
||||
return nil, errFmt("Unknown column: '%v'", colNameStr)
|
||||
_, tableSch, err := findSchemaForColumn(colNameStr, inputSchemas)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
column, _ := tableSch.GetAllCols().GetByName(colNameStr)
|
||||
|
||||
if column.Kind != types.BoolKind {
|
||||
return nil, errFmt("Type mismatch: cannot use column %v as boolean expression", colNameStr)
|
||||
}
|
||||
@@ -475,14 +544,12 @@ func createFilterForWhere(whereClause *sqlparser.Where, tableSch schema.Schema,
|
||||
}
|
||||
|
||||
func swap(left, right *types.Value) {
|
||||
temp := *right
|
||||
*right = *left
|
||||
*left = temp
|
||||
*right, *left = *left, *right
|
||||
}
|
||||
|
||||
// extractNomsValueFromSQLVal extracts a noms value from the given SQLVal, using type info in the dolt column given as
|
||||
// a hint and for type-checking
|
||||
func extractNomsValueFromSQLVal(val *sqlparser.SQLVal, column schema.Column) (types.Value, error) {
|
||||
func extractNomsValueFromSQLVal(val *sqlparser.SQLVal, kind types.NomsKind) (types.Value, error) {
|
||||
switch val.Type {
|
||||
// Integer-like values
|
||||
case sqlparser.HexVal, sqlparser.HexNum, sqlparser.IntVal, sqlparser.BitVal:
|
||||
@@ -490,7 +557,7 @@ func extractNomsValueFromSQLVal(val *sqlparser.SQLVal, column schema.Column) (ty
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch column.Kind {
|
||||
switch kind {
|
||||
case types.IntKind:
|
||||
return types.Int(intVal), nil
|
||||
case types.FloatKind:
|
||||
@@ -506,7 +573,7 @@ func extractNomsValueFromSQLVal(val *sqlparser.SQLVal, column schema.Column) (ty
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch column.Kind {
|
||||
switch kind {
|
||||
case types.FloatKind:
|
||||
return types.Float(floatVal), nil
|
||||
default:
|
||||
@@ -515,7 +582,7 @@ func extractNomsValueFromSQLVal(val *sqlparser.SQLVal, column schema.Column) (ty
|
||||
// Strings, which can be coerced into lots of other types
|
||||
case sqlparser.StrVal:
|
||||
strVal := string(val.Val)
|
||||
switch column.Kind {
|
||||
switch kind {
|
||||
case types.StringKind:
|
||||
return types.String(strVal), nil
|
||||
case types.UUIDKind:
|
||||
|
||||
@@ -2,8 +2,6 @@ package sql
|
||||
|
||||
import (
|
||||
"github.com/liquidata-inc/ld/dolt/go/cmd/dolt/dtestutils"
|
||||
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/row"
|
||||
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/xwb1989/sqlparser"
|
||||
"testing"
|
||||
@@ -13,8 +11,6 @@ func TestWhereClauseErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
expectedRows []row.Row
|
||||
expectedSchema schema.Schema
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
@@ -93,6 +89,7 @@ func TestWhereClauseErrorHandling(t *testing.T) {
|
||||
expectedErr: "Type mismatch:",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
createTestDatabase(dEnv, t)
|
||||
@@ -102,21 +99,12 @@ func TestWhereClauseErrorHandling(t *testing.T) {
|
||||
s := sqlStatement.(*sqlparser.Select)
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.expectedRows != nil && tt.expectedSchema == nil {
|
||||
assert.Fail(t, "Incorrect test setup: schema must both be provided when rows are")
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
rows, sch, err := ExecuteSelect(root, s)
|
||||
if tt.expectedErr != "" {
|
||||
assert.NotNil(t, err)
|
||||
_, _, err := ExecuteSelect(root, s)
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), tt.expectedErr)
|
||||
} else {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t,"", tt.expectedErr)
|
||||
}
|
||||
untypedRows := convertRows(t, tt.expectedRows, peopleTestSchema, tt.expectedSchema)
|
||||
assert.Equal(t, untypedRows, rows)
|
||||
assert.Equal(t, tt.expectedSchema, sch)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user