Rewrote where clause handling in preparation for making it work with joins

This commit is contained in:
Zach Musgrave
2019-04-25 17:41:34 -07:00
parent 2a054d5165
commit 9b29512150
7 changed files with 274 additions and 183 deletions

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/doltdb"
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/schema"
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/table/typed/noms"
"github.com/xwb1989/sqlparser"
"io"
@@ -47,7 +48,7 @@ func ExecuteDelete(db *doltdb.DoltDB, root *doltdb.RootValue, s *sqlparser.Delet
tableSch := table.GetSchema()
// TODO: support aliases
filter, err := createFilterForWhere(s.Where, tableSch, NewAliases())
filter, err := createFilterForWhere(s.Where, map[string]schema.Schema{tableName: tableSch}, NewAliases())
if err != nil {
return errDelete(err.Error())
}

View File

@@ -157,7 +157,7 @@ func makeRow(columns []schema.Column, tableSch schema.Schema, tuple sqlparser.Va
column := columns[i]
switch val := expr.(type) {
case *sqlparser.SQLVal:
nomsVal, err := extractNomsValueFromSQLVal(val, column)
nomsVal, err := extractNomsValueFromSQLVal(val, column.Kind)
if err != nil {
return nil, err
}

View File

@@ -216,7 +216,7 @@ func processSelectedColumns(root *doltdb.RootValue, selectStmt *SelectStatement,
}
} else {
var err error
tableName, tableSch, err = findSchemaForColumn(colName, selectStmt)
tableName, tableSch, err = findSchemaForColumn(colName, selectStmt.inputSchemas)
if err != nil {
return err
}
@@ -244,31 +244,6 @@ func processSelectedColumns(root *doltdb.RootValue, selectStmt *SelectStatement,
return nil
}
// Finds the schema that contains the column name given among the tables given. Returns an error if no schema contains
// such a column name, or if multiple do. This method is only used for naked column names, not qualified ones. Assumes
// that table names have already been verified to exist.
func findSchemaForColumn(colName string, statement *SelectStatement) (string, schema.Schema, error) {
schemas := statement.inputSchemas
var colSchema schema.Schema
var tableName string
for tbl, sch := range schemas {
if _, ok := sch.GetAllCols().GetByName(colName); ok {
if colSchema != nil {
return "", nil, errFmt("Ambiguous column: %v", colName)
}
colSchema = sch
tableName = tbl
}
}
if colSchema == nil {
return "", nil, errFmt("Unknown column '%v'", colName)
}
return tableName, colSchema, nil
}
// Gets the schema for the table name given. Will cause a panic if the table doesn't exist.
func mustGetSchema(root *doltdb.RootValue, tableName string) schema.Schema {
tbl, _:= root.GetTable(tableName)
@@ -278,14 +253,7 @@ func mustGetSchema(root *doltdb.RootValue, tableName string) schema.Schema {
// Processes the where clause by applying an appropriate filter fn to the SelectStatement given. Returns an error if the
// where clause can't be processed.
func processWhereClause(selectStmt *SelectStatement, s *sqlparser.Select) error {
// TODO: make work for more than 1 table
var tableSch schema.Schema
for _, sch := range selectStmt.inputSchemas {
tableSch = sch
break
}
filter, err := createFilterForWhere(s.Where, tableSch, selectStmt.aliases)
filter, err := createFilterForWhere(s.Where, selectStmt.inputSchemas, selectStmt.aliases)
if err != nil {
return err
}
@@ -340,7 +308,8 @@ func createPipeline(root *doltdb.RootValue, statement *SelectStatement) (*pipeli
}
results := make([]resultset.TableResult, 0)
for tableName, result := range pipelines {
for _, tableName := range statement.inputTables {
result := pipelines[tableName]
if err := result.p.Wait(); err != nil || result.err != nil {
return nil, err
}
@@ -402,7 +371,6 @@ func createOutputSchemaMappingTransform(tableSch schema.Schema, rss *resultset.R
// Returns a ResultSetSchema for the given select statement, which contains a target schema and mappings to get there
// from the individual table schemas.
func createResultSetSchema(statement *SelectStatement) error {
// Iterate over the columns twice: first to get an ordered list to use to create an output schema with
cols := make([]schema.Column, 0, len(statement.selectedCols))
for _, selectedCol := range statement.selectedCols {

View File

@@ -309,37 +309,6 @@ func TestExecuteSelect(t *testing.T) {
"l", types.StringKind, "m", types.BoolKind, "a", types.IntKind, "r", types.FloatKind,
"u", types.UUIDKind, "n", types.UintKind),
},
{
name: "Test selecting from multiple tables",
query: `select * from people, episodes`,
expectedRows: rs(
concatRows(peopleTestSchema, homer, episodesTestSchema, ep1),
concatRows(peopleTestSchema, homer, episodesTestSchema, ep2),
concatRows(peopleTestSchema, homer, episodesTestSchema, ep3),
concatRows(peopleTestSchema, homer, episodesTestSchema, ep4),
concatRows(peopleTestSchema, marge, episodesTestSchema, ep1),
concatRows(peopleTestSchema, marge, episodesTestSchema, ep2),
concatRows(peopleTestSchema, marge, episodesTestSchema, ep3),
concatRows(peopleTestSchema, marge, episodesTestSchema, ep4),
concatRows(peopleTestSchema, bart, episodesTestSchema, ep1),
concatRows(peopleTestSchema, bart, episodesTestSchema, ep2),
concatRows(peopleTestSchema, bart, episodesTestSchema, ep3),
concatRows(peopleTestSchema, bart, episodesTestSchema, ep4),
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep1),
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep2),
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep3),
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep4),
concatRows(peopleTestSchema, moe, episodesTestSchema, ep1),
concatRows(peopleTestSchema, moe, episodesTestSchema, ep2),
concatRows(peopleTestSchema, moe, episodesTestSchema, ep3),
concatRows(peopleTestSchema, moe, episodesTestSchema, ep4),
concatRows(peopleTestSchema, barney, episodesTestSchema, ep1),
concatRows(peopleTestSchema, barney, episodesTestSchema, ep2),
concatRows(peopleTestSchema, barney, episodesTestSchema, ep3),
concatRows(peopleTestSchema, barney, episodesTestSchema, ep4),
),
expectedSchema: compressSchemas(peopleTestSchema, episodesTestSchema),
},
{
name: "Test select *, not equals",
query: "select * from people where age <> 40",
@@ -401,6 +370,103 @@ func TestExecuteSelect(t *testing.T) {
}
}
func TestJoins(t *testing.T) {
tests := []struct {
name string
query string
expectedRows []row.Row
expectedSchema schema.Schema
expectedErr bool
}{
{
name: "Test full cross product",
query: `select * from people, episodes`,
expectedRows: rs(
concatRows(peopleTestSchema, homer, episodesTestSchema, ep1),
concatRows(peopleTestSchema, homer, episodesTestSchema, ep2),
concatRows(peopleTestSchema, homer, episodesTestSchema, ep3),
concatRows(peopleTestSchema, homer, episodesTestSchema, ep4),
concatRows(peopleTestSchema, marge, episodesTestSchema, ep1),
concatRows(peopleTestSchema, marge, episodesTestSchema, ep2),
concatRows(peopleTestSchema, marge, episodesTestSchema, ep3),
concatRows(peopleTestSchema, marge, episodesTestSchema, ep4),
concatRows(peopleTestSchema, bart, episodesTestSchema, ep1),
concatRows(peopleTestSchema, bart, episodesTestSchema, ep2),
concatRows(peopleTestSchema, bart, episodesTestSchema, ep3),
concatRows(peopleTestSchema, bart, episodesTestSchema, ep4),
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep1),
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep2),
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep3),
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep4),
concatRows(peopleTestSchema, moe, episodesTestSchema, ep1),
concatRows(peopleTestSchema, moe, episodesTestSchema, ep2),
concatRows(peopleTestSchema, moe, episodesTestSchema, ep3),
concatRows(peopleTestSchema, moe, episodesTestSchema, ep4),
concatRows(peopleTestSchema, barney, episodesTestSchema, ep1),
concatRows(peopleTestSchema, barney, episodesTestSchema, ep2),
concatRows(peopleTestSchema, barney, episodesTestSchema, ep3),
concatRows(peopleTestSchema, barney, episodesTestSchema, ep4),
),
expectedSchema: compressSchemas(peopleTestSchema, episodesTestSchema),
},
{
name: "Test natural join with where clause",
query: `select * from people p, appearances a where a.character_id = p.id`,
expectedRows: rs(
concatRows(peopleTestSchema, homer, episodesTestSchema, ep1),
concatRows(peopleTestSchema, homer, episodesTestSchema, ep2),
concatRows(peopleTestSchema, homer, episodesTestSchema, ep3),
concatRows(peopleTestSchema, homer, episodesTestSchema, ep4),
concatRows(peopleTestSchema, marge, episodesTestSchema, ep1),
concatRows(peopleTestSchema, marge, episodesTestSchema, ep2),
concatRows(peopleTestSchema, marge, episodesTestSchema, ep3),
concatRows(peopleTestSchema, marge, episodesTestSchema, ep4),
concatRows(peopleTestSchema, bart, episodesTestSchema, ep1),
concatRows(peopleTestSchema, bart, episodesTestSchema, ep2),
concatRows(peopleTestSchema, bart, episodesTestSchema, ep3),
concatRows(peopleTestSchema, bart, episodesTestSchema, ep4),
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep1),
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep2),
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep3),
concatRows(peopleTestSchema, lisa, episodesTestSchema, ep4),
concatRows(peopleTestSchema, moe, episodesTestSchema, ep1),
concatRows(peopleTestSchema, moe, episodesTestSchema, ep2),
concatRows(peopleTestSchema, moe, episodesTestSchema, ep3),
concatRows(peopleTestSchema, moe, episodesTestSchema, ep4),
concatRows(peopleTestSchema, barney, episodesTestSchema, ep1),
concatRows(peopleTestSchema, barney, episodesTestSchema, ep2),
concatRows(peopleTestSchema, barney, episodesTestSchema, ep3),
concatRows(peopleTestSchema, barney, episodesTestSchema, ep4),
),
expectedSchema: compressSchemas(peopleTestSchema, episodesTestSchema),
},
}
for _, tt := range tests {
dEnv := dtestutils.CreateTestEnv()
createTestDatabase(dEnv, t)
root, _ := dEnv.WorkingRoot()
sqlStatement, _ := sqlparser.Parse(tt.query)
s := sqlStatement.(*sqlparser.Select)
t.Run(tt.name, func(t *testing.T) {
if tt.expectedRows != nil && tt.expectedSchema == nil {
assert.Fail(t, "Incorrect test setup: schema must both be provided when rows are")
t.FailNow()
}
rows, sch, err := ExecuteSelect(root, s)
if err != nil {
assert.True(t, tt.expectedErr, err.Error())
} else {
assert.False(t, tt.expectedErr, "unexpected error")
}
assert.Equal(t, tt.expectedRows, rows)
assert.Equal(t, tt.expectedSchema, sch)
})
}
}
// Returns the logical concatenation of the schemas and rows given, rewriting all tag numbers to begin at zero. The row
// returned will have a new schema identical to the result of compressSchema.
func concatRows(schemasAndRows ...interface{}) row.Row {

View File

@@ -6,6 +6,7 @@ import (
"github.com/attic-labs/noms/go/types"
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/doltdb"
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/row"
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/schema"
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/table/typed/noms"
"github.com/xwb1989/sqlparser"
"io"
@@ -66,7 +67,7 @@ func ExecuteUpdate(db *doltdb.DoltDB, root *doltdb.RootValue, s *sqlparser.Updat
switch val := update.Expr.(type) {
case *sqlparser.SQLVal:
nomsVal, err := extractNomsValueFromSQLVal(val, column)
nomsVal, err := extractNomsValueFromSQLVal(val, column.Kind)
if err != nil {
return nil, err
}
@@ -133,7 +134,7 @@ func ExecuteUpdate(db *doltdb.DoltDB, root *doltdb.RootValue, s *sqlparser.Updat
}
// TODO: support aliases in update
filter, err := createFilterForWhere(s.Where, tableSch, NewAliases())
filter, err := createFilterForWhere(s.Where, map[string]schema.Schema{tableName: tableSch}, NewAliases())
if err != nil {
return errUpdate(err.Error())
}

View File

@@ -232,8 +232,118 @@ func nodeToString(node sqlparser.SQLNode) string {
return buffer.String()
}
// Finds the schema that contains the column name given among the tables given. Returns an error if no schema contains
// such a column name, or if multiple do. This method is only used for naked column names, not qualified ones. Assumes
// that table names have already been verified to exist.
func findSchemaForColumn(colName string, schemas map[string]schema.Schema) (string, schema.Schema, error) {
var colSchema schema.Schema
var tableName string
for tbl, sch := range schemas {
if _, ok := sch.GetAllCols().GetByName(colName); ok {
if colSchema != nil {
return "", nil, errFmt("Ambiguous column: %v", colName)
}
colSchema = sch
tableName = tbl
}
}
if colSchema == nil {
return "", nil, errFmt("Unknown column: '%v'", colName)
}
return tableName, colSchema, nil
}
type valGetterKind uint8
const (
COLNAME valGetterKind = iota
SQL_VAL
BOOL_VAL
)
// valGetter is a convenience object used for comparing the right and left side of an expression
type valGetter struct {
// The kind of this val getter
Kind valGetterKind
// The value type returned by this getter
NomsKind types.NomsKind
// The kind of the value that this getter's result will be compared against, filled in elsewhere
CmpKind types.NomsKind
// Init() performs error checking and does any labor-saving pre-calculation that doens't need to be done for every
// row in the result set
Init func() error
// Get() returns the value for this getter for the row given
Get func(r row.Row) types.Value
// CachedVal is a handy place to put a pre-computed value for getters that deal with constants or literals
CachedVal types.Value
}
// Returns a comparison value getter for the expression given, which could be a column value or a literal
func getComparisonValueGetter(expr sqlparser.Expr, inputSchemas map[string]schema.Schema, aliases *Aliases) (*valGetter, error) {
switch e := expr.(type) {
case *sqlparser.ColName:
colNameStr := e.Name.String()
if col, ok := aliases.ColumnsByAlias[colNameStr]; ok {
colNameStr = col.ColumnName
}
_, tableSch, err := findSchemaForColumn(colNameStr, inputSchemas)
if err != nil {
return nil, err
}
column, _ := tableSch.GetAllCols().GetByName(colNameStr)
getter := valGetter{Kind: COLNAME, NomsKind: column.Kind}
getter.Init = func() error {
return nil
}
getter.Get = func(r row.Row) types.Value {
value, _ := r.GetColVal(column.Tag)
return value
}
return &getter, nil
case *sqlparser.SQLVal:
getter := valGetter{Kind: SQL_VAL}
getter.Init = func() error {
val, err := extractNomsValueFromSQLVal(e, getter.CmpKind)
if err != nil {
return err
}
getter.CachedVal = val
return nil
}
getter.Get = func(r row.Row) types.Value {
return getter.CachedVal
}
return &getter, nil
case sqlparser.BoolVal:
val := types.Bool(bool(e))
getter := valGetter{Kind: BOOL_VAL, NomsKind: types.BoolKind}
getter.Init = func() error {
switch getter.CmpKind {
case types.BoolKind:
return nil
default:
return errFmt("Type mismatch: boolean value but non-numeric column: %v", nodeToString(e))
}
}
getter.Get = func(r row.Row) types.Value {
return val
}
return &getter, nil
default:
return nil, errFmt("Unsupported comparison %v", nodeToString(e))
}
}
// createFilter creates a filter function from the where clause given, or returns an error if it cannot
func createFilterForWhere(whereClause *sqlparser.Where, tableSch schema.Schema, aliases *Aliases) (rowFilterFn, error) {
func createFilterForWhere(whereClause *sqlparser.Where, inputSchemas map[string]schema.Schema, aliases *Aliases) (rowFilterFn, error) {
if whereClause != nil && whereClause.Type != sqlparser.WhereStr {
return nil, errFmt("Having clause not supported")
}
@@ -246,133 +356,90 @@ func createFilterForWhere(whereClause *sqlparser.Where, tableSch schema.Schema,
} else {
switch expr := whereClause.Expr.(type) {
case *sqlparser.ComparisonExpr:
left := expr.Left
right := expr.Right
op := expr.Operator
colValOnLeft := true
colExpr := left
valExpr := right
// Swap the column and value expr as necessary
colName, ok := colExpr.(*sqlparser.ColName)
if !ok {
colValOnLeft = false
colExpr = right
valExpr = left
leftGetter, err := getComparisonValueGetter(expr.Left, inputSchemas, aliases)
if err != nil {
return nil, err
}
rightGetter, err := getComparisonValueGetter(expr.Right, inputSchemas, aliases)
if err != nil {
return nil, err
}
colName, ok = colExpr.(*sqlparser.ColName)
if !ok {
return nil, errFmt("Only column names and value literals are supported")
// Fill in noms kinds for SQL_VAL fields if possible
if leftGetter.Kind == SQL_VAL && rightGetter.Kind != SQL_VAL {
leftGetter.NomsKind = rightGetter.NomsKind
}
if rightGetter.Kind == SQL_VAL && leftGetter.Kind != SQL_VAL {
rightGetter.NomsKind = leftGetter.NomsKind
}
colNameStr := colName.Name.String()
if col, ok := aliases.ColumnsByAlias[colNameStr]; ok {
colNameStr = col.ColumnName
}
column, ok := tableSch.GetAllCols().GetByName(colNameStr)
if !ok {
return nil, errFmt("Unknown column: '%v'", colNameStr)
}
// Fill in comparison kinds before doing error checking
rightGetter.CmpKind, leftGetter.CmpKind = leftGetter.NomsKind, rightGetter.NomsKind
var comparisonVal types.Value
switch val := valExpr.(type) {
case *sqlparser.SQLVal:
var err error
comparisonVal, err = extractNomsValueFromSQLVal(val, column)
if err != nil {
return nil, err
}
case sqlparser.BoolVal:
switch column.Kind {
case types.BoolKind:
comparisonVal = types.Bool(bool(val))
default:
return nil, errFmt("Type mismatch: boolean value but non-numeric column: %v", nodeToString(val))
}
default:
return nil, errFmt("Only SQL literal values are supported in comparisons: %v", nodeToString(val))
// Initialize the getters, mostly so that literal vals can do type error checking and cache results
if err := leftGetter.Init(); err != nil {
return nil, err
}
if err := rightGetter.Init(); err != nil {
return nil, err
}
// All the operations differ only in their filter logic
switch op {
switch expr.Operator {
case sqlparser.EqualStr:
filter = func(r row.Row) bool {
colVal, ok := r.GetColVal(column.Tag)
if !ok {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return comparisonVal.Equals(colVal)
return leftVal.Equals(rightVal)
}
case sqlparser.LessThanStr:
filter = func(r row.Row) bool {
colVal, ok := r.GetColVal(column.Tag)
if !ok {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
leftVal := colVal
rightVal := comparisonVal
if !colValOnLeft {
swap(&leftVal, &rightVal)
}
return leftVal.Less(rightVal)
}
case sqlparser.GreaterThanStr:
filter = func(r row.Row) bool {
colVal, ok := r.GetColVal(column.Tag)
if !ok {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
leftVal := colVal
rightVal := comparisonVal
if !colValOnLeft {
swap(&leftVal, &rightVal)
}
return rightVal.Less(leftVal)
}
case sqlparser.LessEqualStr:
filter = func(r row.Row) bool {
colVal, ok := r.GetColVal(column.Tag)
if !ok {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
leftVal := colVal
rightVal := comparisonVal
if !colValOnLeft {
swap(&leftVal, &rightVal)
}
return leftVal.Less(rightVal) || leftVal.Equals(rightVal)
}
case sqlparser.GreaterEqualStr:
filter = func(r row.Row) bool {
colVal, ok := r.GetColVal(column.Tag)
if !ok {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
leftVal := colVal
rightVal := comparisonVal
if !colValOnLeft {
swap(&leftVal, &rightVal)
}
return rightVal.Less(leftVal) || rightVal.Equals(leftVal)
}
case sqlparser.NotEqualStr:
filter = func(r row.Row) bool {
colVal, ok := r.GetColVal(column.Tag)
if !ok {
leftVal := leftGetter.Get(r)
rightVal := rightGetter.Get(r)
if types.IsNull(leftVal) || types.IsNull(rightVal) {
return false
}
return !comparisonVal.Equals(colVal)
return !leftVal.Equals(rightVal)
}
case sqlparser.NullSafeEqualStr:
return nil, errFmt("null safe equal operation not supported")
@@ -398,10 +465,12 @@ func createFilterForWhere(whereClause *sqlparser.Where, tableSch schema.Schema,
if col, ok := aliases.ColumnsByAlias[colNameStr]; ok {
colNameStr = col.ColumnName
}
column, ok := tableSch.GetAllCols().GetByName(colNameStr)
if !ok {
return nil, errFmt("Unknown column: '%v'", colNameStr)
_, tableSch, err := findSchemaForColumn(colNameStr, inputSchemas)
if err != nil {
return nil, err
}
column, _ := tableSch.GetAllCols().GetByName(colNameStr)
if column.Kind != types.BoolKind {
return nil, errFmt("Type mismatch: cannot use column %v as boolean expression", colNameStr)
}
@@ -475,14 +544,12 @@ func createFilterForWhere(whereClause *sqlparser.Where, tableSch schema.Schema,
}
func swap(left, right *types.Value) {
temp := *right
*right = *left
*left = temp
*right, *left = *left, *right
}
// extractNomsValueFromSQLVal extracts a noms value from the given SQLVal, using type info in the dolt column given as
// a hint and for type-checking
func extractNomsValueFromSQLVal(val *sqlparser.SQLVal, column schema.Column) (types.Value, error) {
func extractNomsValueFromSQLVal(val *sqlparser.SQLVal, kind types.NomsKind) (types.Value, error) {
switch val.Type {
// Integer-like values
case sqlparser.HexVal, sqlparser.HexNum, sqlparser.IntVal, sqlparser.BitVal:
@@ -490,7 +557,7 @@ func extractNomsValueFromSQLVal(val *sqlparser.SQLVal, column schema.Column) (ty
if err != nil {
return nil, err
}
switch column.Kind {
switch kind {
case types.IntKind:
return types.Int(intVal), nil
case types.FloatKind:
@@ -506,7 +573,7 @@ func extractNomsValueFromSQLVal(val *sqlparser.SQLVal, column schema.Column) (ty
if err != nil {
return nil, err
}
switch column.Kind {
switch kind {
case types.FloatKind:
return types.Float(floatVal), nil
default:
@@ -515,7 +582,7 @@ func extractNomsValueFromSQLVal(val *sqlparser.SQLVal, column schema.Column) (ty
// Strings, which can be coerced into lots of other types
case sqlparser.StrVal:
strVal := string(val.Val)
switch column.Kind {
switch kind {
case types.StringKind:
return types.String(strVal), nil
case types.UUIDKind:

View File

@@ -2,8 +2,6 @@ package sql
import (
"github.com/liquidata-inc/ld/dolt/go/cmd/dolt/dtestutils"
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/row"
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/schema"
"github.com/stretchr/testify/assert"
"github.com/xwb1989/sqlparser"
"testing"
@@ -13,8 +11,6 @@ func TestWhereClauseErrorHandling(t *testing.T) {
tests := []struct {
name string
query string
expectedRows []row.Row
expectedSchema schema.Schema
expectedErr string
}{
{
@@ -93,6 +89,7 @@ func TestWhereClauseErrorHandling(t *testing.T) {
expectedErr: "Type mismatch:",
},
}
for _, tt := range tests {
dEnv := dtestutils.CreateTestEnv()
createTestDatabase(dEnv, t)
@@ -102,21 +99,12 @@ func TestWhereClauseErrorHandling(t *testing.T) {
s := sqlStatement.(*sqlparser.Select)
t.Run(tt.name, func(t *testing.T) {
if tt.expectedRows != nil && tt.expectedSchema == nil {
assert.Fail(t, "Incorrect test setup: schema must both be provided when rows are")
t.FailNow()
}
rows, sch, err := ExecuteSelect(root, s)
if tt.expectedErr != "" {
assert.NotNil(t, err)
_, _, err := ExecuteSelect(root, s)
if err != nil {
assert.Contains(t, err.Error(), tt.expectedErr)
} else {
assert.Nil(t, err)
assert.Equal(t,"", tt.expectedErr)
}
untypedRows := convertRows(t, tt.expectedRows, peopleTestSchema, tt.expectedSchema)
assert.Equal(t, untypedRows, rows)
assert.Equal(t, tt.expectedSchema, sch)
})
}
}