Closer to working joins without breaking any existing supported queries. Mostly this involves doing a more intelligent job resolving column expressions to their source table and schema.

This commit is contained in:
Zach Musgrave
2019-04-26 13:21:30 -07:00
parent 617d84f89d
commit c55d394b95
3 changed files with 106 additions and 80 deletions
+56 -10
View File
@@ -9,6 +9,7 @@ import (
"github.com/liquidata-inc/ld/dolt/go/libraries/doltcore/schema"
"github.com/xwb1989/sqlparser"
"strconv"
"strings"
)
// SQL keyword constants for use in switches and comparisons
@@ -232,16 +233,36 @@ func nodeToString(node sqlparser.SQLNode) string {
return buffer.String()
}
// Finds the schema that contains the column name given among the tables given. Returns an error if no schema contains
// such a column name, or if multiple do. This method is only used for naked column names, not qualified ones. Assumes
// that table names have already been verified to exist.
func findSchemaForColumn(colName string, schemas map[string]schema.Schema) (string, schema.Schema, error) {
// Finds the schema that contains the column name given among the tables given, and returns the fully qualified column,
// with the full (unaliased) name of the table and column being referenced. Returns an error if no schema contains such
// a column name, or if multiple do.
func resolveColumn(colName string, schemas map[string]schema.Schema, aliases *Aliases) (QualifiedColumn, error) {
// First try getting the table from the column name string itself, eg. "t.col"
qc := parseColumnAlias(colName)
if qc.TableName != "" {
tableName := aliases.TablesByAlias[qc.TableName]
if resolvedName, ok := aliases.ColumnsByAlias[qc.ColumnName]; ok {
return resolvedName, nil
}
if _, ok := schemas[tableName]; ok {
return QualifiedColumn{TableName: tableName, ColumnName: qc.ColumnName}, nil
} else {
return QualifiedColumn{}, errFmt("Unrecognized table name: '%v'", tableName)
}
}
// Then try matching it with known aliases
if qc, ok := aliases.ColumnsByAlias[colName]; ok {
return qc, nil
}
// Finally, look through all input schemas to see if there's an exact match and dying if there's any ambiguity
var colSchema schema.Schema
var tableName string
for tbl, sch := range schemas {
if _, ok := sch.GetAllCols().GetByName(colName); ok {
if colSchema != nil {
return "", nil, errFmt("Ambiguous column: %v", colName)
return QualifiedColumn{}, errFmt("Ambiguous column: %v", colName)
}
colSchema = sch
tableName = tbl
@@ -249,10 +270,19 @@ func findSchemaForColumn(colName string, schemas map[string]schema.Schema) (stri
}
if colSchema == nil {
return "", nil, errFmt("Unknown column: '%v'", colName)
return QualifiedColumn{}, errFmt("Unknown column: '%v'", colName)
}
return tableName, colSchema, nil
return QualifiedColumn{TableName: tableName, ColumnName: colName}, nil
}
// Parses a column alias (e.g.: "a.id") into a qualified column name, where either the table name or the column name may
// be an alias. If there is no table qualifier, the returned QualifiedColumn will have an empty TableName
func parseColumnAlias(colName string) QualifiedColumn {
if idx := strings.Index(colName, "."); idx > 0 {
return QualifiedColumn{colName[:idx], colName[idx+1:]}
}
return QualifiedColumn{"", colName}
}
type valGetterKind uint8
@@ -283,16 +313,22 @@ type valGetter struct {
func getComparisonValueGetter(expr sqlparser.Expr, inputSchemas map[string]schema.Schema, aliases *Aliases) (*valGetter, error) {
switch e := expr.(type) {
case *sqlparser.ColName:
colNameStr := e.Name.String()
colNameStr := getColumnNameString(e)
if col, ok := aliases.ColumnsByAlias[colNameStr]; ok {
colNameStr = col.ColumnName
}
_, tableSch, err := findSchemaForColumn(colNameStr, inputSchemas)
qc, err := resolveColumn(colNameStr, inputSchemas, aliases)
if err != nil {
return nil, err
}
column, _ := tableSch.GetAllCols().GetByName(colNameStr)
tableSch := inputSchemas[qc.TableName]
column, ok := tableSch.GetAllCols().GetByName(qc.ColumnName)
if !ok {
return nil, errFmt("Unknown column %v", colNameStr)
}
getter := valGetter{Kind: COLNAME, NomsKind: column.Kind}
getter.Init = func() error {
@@ -342,6 +378,16 @@ func getComparisonValueGetter(expr sqlparser.Expr, inputSchemas map[string]schem
}
}
func getColumnNameString(e *sqlparser.ColName) string {
var b strings.Builder
if !e.Qualifier.IsEmpty() {
b.WriteString(e.Qualifier.Name.String())
b.WriteString(".")
}
b.WriteString(e.Name.String())
return b.String()
}
// createFilter creates a filter function from the where clause given, or returns an error if it cannot
func createFilterForWhere(whereClause *sqlparser.Where, inputSchemas map[string]schema.Schema, aliases *Aliases) (rowFilterFn, error) {
if whereClause != nil && whereClause.Type != sqlparser.WhereStr {