diff --git a/go/libraries/doltcore/merge/merge_prolly_rows.go b/go/libraries/doltcore/merge/merge_prolly_rows.go index d5a2dee89e..5186eb1bb6 100644 --- a/go/libraries/doltcore/merge/merge_prolly_rows.go +++ b/go/libraries/doltcore/merge/merge_prolly_rows.go @@ -22,11 +22,7 @@ import ( "fmt" "io" - "github.com/dolthub/go-mysql-server/memory" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/analyzer" - "github.com/dolthub/go-mysql-server/sql/planbuilder" - "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql/types" errorkinds "gopkg.in/src-d/go-errors.v1" @@ -35,7 +31,6 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/index" - "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" "github.com/dolthub/dolt/go/store/hash" "github.com/dolthub/dolt/go/store/pool" "github.com/dolthub/dolt/go/store/prolly" @@ -370,7 +365,7 @@ func newCheckValidator(ctx *sql.Context, tm *TableMerger, vm *valueMerger, sch s continue } - expr, err := resolveExpression(ctx, check.Expression(), sch, tm.name) + expr, err := index.ResolveCheckExpression(ctx, tm.name, sch, check.Expression()) if err != nil { return checkValidator{}, err } @@ -433,7 +428,7 @@ func (cv checkValidator) validateDiff(ctx *sql.Context, diff tree.ThreeWayDiff) newTuple = val.NewTuple(cv.valueMerger.syncPool, newTupleBytes...) } - row, err := buildRow(ctx, diff.Key, newTuple, cv.sch, cv.tableMerger) + row, err := index.BuildRow(ctx, diff.Key, newTuple, cv.sch, cv.valueMerger.ns) if err != nil { return 0, err } @@ -487,54 +482,6 @@ func (cv checkValidator) insertArtifact(ctx context.Context, key, value val.Tupl return cv.edits.ReplaceConstraintViolation(ctx, key, cv.srcHash, prolly.ArtifactTypeChkConsViol, cvm) } -// buildRow takes the |key| and |value| tuple and returns a new sql.Row, along with any errors encountered. -func buildRow(ctx *sql.Context, key, value val.Tuple, sch schema.Schema, tableMerger *TableMerger) (sql.Row, error) { - pkCols := sch.GetPKCols() - valueCols := sch.GetNonPKCols() - allCols := sch.GetAllCols() - - // When we parse and resolve the check constraint expression with planbuilder, it leaves row position 0 - // for the expression itself, so we add an empty spot in index 0 of our row to account for that to make sure - // the GetField expressions' indexes match up to the right columns. - row := make(sql.Row, allCols.Size()+1) - - // Skip adding the key tuple if we're working with a keyless table, since the table row data is - // always all contained in the value tuple for keyless tables. - if !schema.IsKeyless(sch) { - keyDesc := sch.GetKeyDescriptor() - for i := range keyDesc.Types { - value, err := index.GetField(ctx, keyDesc, i, key, tableMerger.ns) - if err != nil { - return nil, err - } - - pkCol := pkCols.GetColumns()[i] - row[allCols.TagToIdx[pkCol.Tag]+1] = value - } - } - - valueColIndex := 0 - valueDescriptor := sch.GetValueDescriptor() - for valueTupleIndex := range valueDescriptor.Types { - // Skip processing the first value in the value tuple for keyless tables, since that field - // always holds the cardinality of the row and shouldn't be passed in to an expression. - if schema.IsKeyless(sch) && valueTupleIndex == 0 { - continue - } - - value, err := index.GetField(ctx, valueDescriptor, valueTupleIndex, value, tableMerger.ns) - if err != nil { - return nil, err - } - - col := valueCols.GetColumns()[valueColIndex] - row[allCols.TagToIdx[col.Tag]+1] = value - valueColIndex += 1 - } - - return row, nil -} - // uniqValidator checks whether new additions from the merge-right // duplicate secondary index entries. type uniqValidator struct { @@ -1238,41 +1185,6 @@ func (m *secondaryMerger) finalize(ctx context.Context) (durable.IndexSet, durab return m.leftSet, m.rightSet, nil } -// resolveExpression takes in a string |expression| and does basic resolution on it (e.g. column names and function -// names) so that the returned sql.Expression can be evaluated. The schema of the table is specified in |sch| and the -// name of the table in |tableName|. -func resolveExpression(ctx *sql.Context, expression string, sch schema.Schema, tableName string) (sql.Expression, error) { - query := fmt.Sprintf("SELECT %s from %s.%s", expression, "mydb", tableName) - sqlSch, err := sqlutil.FromDoltSchema("", tableName, sch) - if err != nil { - return nil, err - } - mockDatabase := memory.NewDatabase("mydb") - mockTable := memory.NewLocalTable(mockDatabase.BaseDatabase, tableName, sqlSch, nil) - mockDatabase.AddTable(tableName, mockTable) - mockProvider := memory.NewDBProvider(mockDatabase) - catalog := analyzer.NewCatalog(mockProvider) - - pseudoAnalyzedQuery, err := planbuilder.Parse(ctx, catalog, query) - if err != nil { - return nil, err - } - - var expr sql.Expression - transform.Inspect(pseudoAnalyzedQuery, func(n sql.Node) bool { - if projector, ok := n.(sql.Projector); ok { - expr = projector.ProjectedExprs()[0] - return false - } - return true - }) - if expr == nil { - return nil, fmt.Errorf("unable to find expression in analyzed query") - } - - return expr, nil -} - // remapTuple takes the given |tuple| and the |desc| that describes its data, and uses |mapping| to map the tuple's // data into a new [][]byte, as indicated by the specified ordinal mapping. func remapTuple(tuple val.Tuple, desc val.TupleDesc, mapping val.OrdinalMapping) [][]byte { diff --git a/go/libraries/doltcore/sqle/index/key_builder.go b/go/libraries/doltcore/sqle/index/key_builder.go index 52f346250e..96a56223b3 100644 --- a/go/libraries/doltcore/sqle/index/key_builder.go +++ b/go/libraries/doltcore/sqle/index/key_builder.go @@ -67,7 +67,7 @@ func NewSecondaryKeyBuilder(ctx context.Context, tableName string, sch schema.Sc sqlCtx = sql.NewContext(ctx) } - expr, err := resolveDefaultExpression(sqlCtx, col, sch, tableName) + expr, err := ResolveDefaultExpression(sqlCtx, tableName, sch, col) if err != nil { return SecondaryKeyBuilder{}, err } @@ -93,20 +93,60 @@ func NewSecondaryKeyBuilder(ctx context.Context, tableName string, sch schema.Sc return b, nil } -// resolveDefaultExpression returns an sql.Expression for the column default or generated expression for the +// ResolveDefaultExpression returns a sql.Expression for the column default or generated expression for the // column provided -func resolveDefaultExpression(ctx *sql.Context, col schema.Column, sch schema.Schema, tableName string) (sql.Expression, error) { +func ResolveDefaultExpression(ctx *sql.Context, tableName string, sch schema.Schema, col schema.Column) (sql.Expression, error) { + ct, err := parseCreateTable(ctx, tableName, sch) + if err != nil { + return nil, err + } + + colIdx := ct.CreateSchema.Schema.IndexOfColName(col.Name) + if colIdx < 0 { + return nil, fmt.Errorf("unable to find column %s in analyzed query", col.Name) + } + + sqlCol := ct.CreateSchema.Schema[colIdx] + expr := sqlCol.Default + if expr == nil || expr.Expr == nil { + expr = sqlCol.Generated + } + + if expr == nil || expr.Expr == nil { + return nil, fmt.Errorf("unable to find default or generated expression") + } + + return expr.Expr, nil +} + +// ResolveCheckExpression returns a sql.Expression for the check provided +func ResolveCheckExpression(ctx *sql.Context, tableName string, sch schema.Schema, checkExpr string) (sql.Expression, error) { + ct, err := parseCreateTable(ctx, tableName, sch) + if err != nil { + return nil, err + } + + for _, check := range ct.Checks() { + if check.Expr.String() == checkExpr { + return check.Expr, nil + } + } + + return nil, fmt.Errorf("unable to find check expression") +} + +func parseCreateTable(ctx *sql.Context, tableName string, sch schema.Schema) (*plan.CreateTable, error) { createTable, err := sqlfmt.GenerateCreateTableStatement(tableName, sch, nil, nil) if err != nil { return nil, err } - + query := createTable sqlSch, err := sqlutil.FromDoltSchema("", tableName, sch) if err != nil { return nil, err } - + mockDatabase := memory.NewDatabase("mydb") mockTable := memory.NewLocalTable(mockDatabase.BaseDatabase, tableName, sqlSch, nil) mockDatabase.AddTable(tableName, mockTable) @@ -122,23 +162,7 @@ func resolveDefaultExpression(ctx *sql.Context, col schema.Column, sch schema.Sc if !ok { return nil, fmt.Errorf("expected a *plan.CreateTable node, but got %T", pseudoAnalyzedQuery) } - - colIdx := ct.CreateSchema.Schema.IndexOfColName(col.Name) - if colIdx == -1 { - return nil, fmt.Errorf("unable to find column %s in analyzed query", col.Name) - } - - sqlCol := ct.CreateSchema.Schema[colIdx] - expr := sqlCol.Default - if expr == nil || expr.Expr == nil { - expr = sqlCol.Generated - } - - if expr == nil || expr.Expr == nil { - return nil, fmt.Errorf("unable to find default or generated expression") - } - - return expr.Expr, nil + return ct, nil } type SecondaryKeyBuilder struct { @@ -170,7 +194,7 @@ func (b SecondaryKeyBuilder) SecondaryKeyFromRow(ctx context.Context, k, v val.T sqlCtx = sql.NewContext(ctx) } - sqlRow, err := buildRow(sqlCtx, k, v, b.sch, b.nodeStore) + sqlRow, err := BuildRow(sqlCtx, k, v, b.sch, b.nodeStore) if err != nil { return nil, err } @@ -220,18 +244,13 @@ func (b SecondaryKeyBuilder) SecondaryKeyFromRow(ctx context.Context, k, v val.T return b.builder.Build(b.pool), nil } -// buildRow returns a row for the given key/value tuple pair -func buildRow(ctx *sql.Context, key, value val.Tuple, sch schema.Schema, ns tree.NodeStore) (sql.Row, error) { - prollyRowIter := prolly.NewPointLookup(key, value) - iter, err := NewProllyRowIterForSchema(sch, prollyRowIter, sch.GetKeyDescriptor(), sch.GetValueDescriptor(), sch.GetAllCols().Tags, ns) - if err != nil { - return nil, err - } - - return iter.Next(ctx) +// BuildRow returns a sql.Row for the given key/value tuple pair +func BuildRow(ctx *sql.Context, key, value val.Tuple, sch schema.Schema, ns tree.NodeStore) (sql.Row, error) { + prollyIter := prolly.NewPointLookup(key, value) + rowIter := NewProllyRowIterForSchema(sch, prollyIter, sch.GetKeyDescriptor(), sch.GetValueDescriptor(), sch.GetAllCols().Tags, ns) + return rowIter.Next(ctx) } - // canCopyRawBytes returns true if the bytes for |idxField| can // be copied directly. This is a faster way to populate an index // but requires that no data transformation is needed. For example, diff --git a/go/libraries/doltcore/sqle/index/prolly_row_iter.go b/go/libraries/doltcore/sqle/index/prolly_row_iter.go index c00be1efc9..42dc942833 100644 --- a/go/libraries/doltcore/sqle/index/prolly_row_iter.go +++ b/go/libraries/doltcore/sqle/index/prolly_row_iter.go @@ -39,7 +39,7 @@ type prollyRowIter struct { var _ sql.RowIter = prollyRowIter{} -func NewProllyRowIterForMap(sch schema.Schema, rows prolly.Map, iter prolly.MapIter, projections []uint64) (sql.RowIter, error) { +func NewProllyRowIterForMap(sch schema.Schema, rows prolly.Map, iter prolly.MapIter, projections []uint64) sql.RowIter { if projections == nil { projections = sch.GetAllCols().Tags } @@ -57,7 +57,7 @@ func NewProllyRowIterForSchema( vd val.TupleDesc, projections []uint64, ns tree.NodeStore, -) (sql.RowIter, error) { +) sql.RowIter { if schema.IsKeyless(sch) { return NewKeylessProllyRowIter(sch, iter, vd, projections, ns) } @@ -72,7 +72,7 @@ func NewKeyedProllyRowIter( vd val.TupleDesc, projections []uint64, ns tree.NodeStore, -) (sql.RowIter, error) { +) sql.RowIter { keyProj, valProj, ordProj := projectionMappings(sch, projections) return prollyRowIter{ @@ -84,7 +84,7 @@ func NewKeyedProllyRowIter( ordProj: ordProj, rowLen: len(projections), ns: ns, - }, nil + } } func NewKeylessProllyRowIter( @@ -93,7 +93,7 @@ func NewKeylessProllyRowIter( vd val.TupleDesc, projections []uint64, ns tree.NodeStore, -) (sql.RowIter, error) { +) sql.RowIter { _, valProj, ordProj := projectionMappings(sch, projections) return &prollyKeylessIter{ @@ -103,7 +103,7 @@ func NewKeylessProllyRowIter( ordProj: ordProj, rowLen: len(projections), ns: ns, - }, nil + } } // projectionMappings returns data structures that specify 1) which fields we read diff --git a/go/libraries/doltcore/sqle/rows.go b/go/libraries/doltcore/sqle/rows.go index cbd75848a6..73689d3cdd 100644 --- a/go/libraries/doltcore/sqle/rows.go +++ b/go/libraries/doltcore/sqle/rows.go @@ -255,7 +255,7 @@ func DoltTablePartitionToRowIter(ctx *sql.Context, name string, table *doltdb.Ta if err != nil { return nil, nil, err } - rowIter, err := index.NewProllyRowIterForMap(sch, idx, iter, nil) + rowIter := index.NewProllyRowIterForMap(sch, idx, iter, nil) if err != nil { return nil, nil, err } diff --git a/go/libraries/doltcore/table/table_iterator.go b/go/libraries/doltcore/table/table_iterator.go index b331dcd0a2..e771539345 100644 --- a/go/libraries/doltcore/table/table_iterator.go +++ b/go/libraries/doltcore/table/table_iterator.go @@ -71,7 +71,7 @@ func NewTableIterator(ctx context.Context, sch schema.Schema, idx durable.Index, if err != nil { return nil, err } - rowItr, err = index.NewProllyRowIterForMap(sch, m, itr, nil) + rowItr = index.NewProllyRowIterForMap(sch, m, itr, nil) if err != nil { return nil, err }