From 97e2aba3d2ddb3baca1511b5927bbe92e34c912a Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Mon, 22 May 2023 11:27:23 -0700 Subject: [PATCH] Tidying up --- .../doltcore/merge/merge_prolly_rows.go | 77 ++++++++++--------- 1 file changed, 41 insertions(+), 36 deletions(-) diff --git a/go/libraries/doltcore/merge/merge_prolly_rows.go b/go/libraries/doltcore/merge/merge_prolly_rows.go index 1f227645dc..04ddf97764 100644 --- a/go/libraries/doltcore/merge/merge_prolly_rows.go +++ b/go/libraries/doltcore/merge/merge_prolly_rows.go @@ -366,7 +366,7 @@ type checkValidator struct { // about the table being merged, |vm| provides the details on how the value tuples are being merged between the ancestor, // right and left sides of the merge, |sch| provides the final schema of the merge, and |edits| is used to write // constraint validation artifacts. -func newCheckValidator(sqlCtx *sql.Context, tm *TableMerger, vm *valueMerger, sch schema.Schema, edits *prolly.ArtifactsEditor) (checkValidator, error) { +func newCheckValidator(ctx *sql.Context, tm *TableMerger, vm *valueMerger, sch schema.Schema, edits *prolly.ArtifactsEditor) (checkValidator, error) { checkExpressions := make(map[string]sql.Expression) checks := sch.Checks() @@ -375,35 +375,10 @@ func newCheckValidator(sqlCtx *sql.Context, tm *TableMerger, vm *valueMerger, sc continue } - // TODO: Hide in a Helper function! - query := fmt.Sprintf("SELECT %s from %s.%s", check.Expression(), "mydb", tm.name) - sqlSch, err := sqlutil.FromDoltSchema(tm.name, sch) + expr, err := resolveExpression(ctx, check.Expression(), sch, tm.name) if err != nil { return checkValidator{}, err } - mockTable := memory.NewTable(tm.name, sqlSch, nil) - mockDatabase := memory.NewDatabase("mydb") - mockDatabase.AddTable(tm.name, mockTable) - mockProvider := memory.NewDBProvider(mockDatabase) - catalog := analyzer.NewCatalog(mockProvider) - - pseudoAnalyzedQuery, err := planbuilder.Parse(sqlCtx, catalog, query) - if err != nil { - return checkValidator{}, err - } - - var expr sql.Expression - transform.Inspect(pseudoAnalyzedQuery, func(n sql.Node) bool { - if projector, ok := n.(sql.Projector); ok { - expr = projector.ProjectedExprs()[0] - return false - } - return true - }) - if expr == nil { - return checkValidator{}, fmt.Errorf("unable to find expression in analyzed query") - } - checkExpressions[check.Name()] = expr } @@ -425,11 +400,7 @@ func newCheckValidator(sqlCtx *sql.Context, tm *TableMerger, vm *valueMerger, sc // validateDiff inspects the three-way diff event |diff| and evaluates any check constraint expressions that need to // be rechecked after the merge. If any check constraint violations are detected, the violation count is returned as // the first return parameter and the violations are also written to the artifact editor passed in on creation. -func (cv checkValidator) validateDiff(ctx context.Context, diff tree.ThreeWayDiff) (int, error) { - // TODO: This sql Context creation is expensive; do this higher up so we don't have to recreate this over and over - // Has this already been done even? Can we just change the signature to sql.Context? - sqlCtx := sql.NewContext(ctx) - +func (cv checkValidator) validateDiff(ctx *sql.Context, diff tree.ThreeWayDiff) (int, error) { conflictCount := 0 for checkName, checkExpression := range cv.checkExpressions { @@ -442,7 +413,6 @@ func (cv checkValidator) validateDiff(ctx context.Context, diff tree.ThreeWayDif valueTuple = diff.Left valueDesc = cv.tableMerger.leftSch.GetValueDescriptor() case tree.DiffOpRightAdd, tree.DiffOpRightDelete, tree.DiffOpRightModify: - // TODO: Can this actually happen when we're always merging from right into left? valueTuple = diff.Right valueDesc = cv.tableMerger.rightSch.GetValueDescriptor() case tree.DiffOpConvergentAdd, tree.DiffOpConvergentDelete, tree.DiffOpConvergentModify: @@ -474,7 +444,7 @@ func (cv checkValidator) validateDiff(ctx context.Context, diff tree.ThreeWayDif row = append(row, nil) keyDesc := cv.sch.GetKeyDescriptor() for i := range keyDesc.Types { - value, err := index.GetField(sqlCtx, keyDesc, i, diff.Key, cv.tableMerger.ns) + value, err := index.GetField(ctx, keyDesc, i, diff.Key, cv.tableMerger.ns) if err != nil { return 0, err } @@ -482,14 +452,14 @@ func (cv checkValidator) validateDiff(ctx context.Context, diff tree.ThreeWayDif } for i := range cv.sch.GetNonPKCols().GetColumns() { - value, err := index.GetField(sqlCtx, cv.sch.GetValueDescriptor(), i, newTuple, cv.tableMerger.ns) + value, err := index.GetField(ctx, cv.sch.GetValueDescriptor(), i, newTuple, cv.tableMerger.ns) if err != nil { return 0, err } row = append(row, value) } - result, err := checkExpression.Eval(sqlCtx, row) + result, err := checkExpression.Eval(ctx, row) if err != nil { return 0, err } @@ -1048,6 +1018,41 @@ func (m *secondaryMerger) finalize(ctx context.Context) (durable.IndexSet, durab return m.leftSet, m.rightSet, nil } +// resolveExpression takes in a string |expression| and does basic resolution on it (e.g. column names and function +// names) so that the returned sql.Expression can be evaluated. The schema of the table is specified in |sch| and the +// name of the table in |tableName|. +func resolveExpression(ctx *sql.Context, expression string, sch schema.Schema, tableName string) (sql.Expression, error) { + query := fmt.Sprintf("SELECT %s from %s.%s", expression, "mydb", tableName) + sqlSch, err := sqlutil.FromDoltSchema(tableName, sch) + if err != nil { + return nil, err + } + mockTable := memory.NewTable(tableName, sqlSch, nil) + mockDatabase := memory.NewDatabase("mydb") + mockDatabase.AddTable(tableName, mockTable) + mockProvider := memory.NewDBProvider(mockDatabase) + catalog := analyzer.NewCatalog(mockProvider) + + pseudoAnalyzedQuery, err := planbuilder.Parse(ctx, catalog, query) + if err != nil { + return nil, err + } + + var expr sql.Expression + transform.Inspect(pseudoAnalyzedQuery, func(n sql.Node) bool { + if projector, ok := n.(sql.Projector); ok { + expr = projector.ProjectedExprs()[0] + return false + } + return true + }) + if expr == nil { + return nil, fmt.Errorf("unable to find expression in analyzed query") + } + + return expr, nil +} + // remapTuple takes the given |tuple| and the |desc| that describes its data, and uses |mapping| to map the tuple's // data into a new [][]byte, as indicated by the specified ordinal mapping. func remapTuple(tuple val.Tuple, desc val.TupleDesc, mapping val.OrdinalMapping) [][]byte {