mirror of
https://github.com/dolthub/dolt.git
synced 2026-03-11 01:55:08 -05:00
Tidying up
This commit is contained in:
@@ -366,7 +366,7 @@ type checkValidator struct {
|
||||
// about the table being merged, |vm| provides the details on how the value tuples are being merged between the ancestor,
|
||||
// right and left sides of the merge, |sch| provides the final schema of the merge, and |edits| is used to write
|
||||
// constraint validation artifacts.
|
||||
func newCheckValidator(sqlCtx *sql.Context, tm *TableMerger, vm *valueMerger, sch schema.Schema, edits *prolly.ArtifactsEditor) (checkValidator, error) {
|
||||
func newCheckValidator(ctx *sql.Context, tm *TableMerger, vm *valueMerger, sch schema.Schema, edits *prolly.ArtifactsEditor) (checkValidator, error) {
|
||||
checkExpressions := make(map[string]sql.Expression)
|
||||
|
||||
checks := sch.Checks()
|
||||
@@ -375,35 +375,10 @@ func newCheckValidator(sqlCtx *sql.Context, tm *TableMerger, vm *valueMerger, sc
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO: Hide in a Helper function!
|
||||
query := fmt.Sprintf("SELECT %s from %s.%s", check.Expression(), "mydb", tm.name)
|
||||
sqlSch, err := sqlutil.FromDoltSchema(tm.name, sch)
|
||||
expr, err := resolveExpression(ctx, check.Expression(), sch, tm.name)
|
||||
if err != nil {
|
||||
return checkValidator{}, err
|
||||
}
|
||||
mockTable := memory.NewTable(tm.name, sqlSch, nil)
|
||||
mockDatabase := memory.NewDatabase("mydb")
|
||||
mockDatabase.AddTable(tm.name, mockTable)
|
||||
mockProvider := memory.NewDBProvider(mockDatabase)
|
||||
catalog := analyzer.NewCatalog(mockProvider)
|
||||
|
||||
pseudoAnalyzedQuery, err := planbuilder.Parse(sqlCtx, catalog, query)
|
||||
if err != nil {
|
||||
return checkValidator{}, err
|
||||
}
|
||||
|
||||
var expr sql.Expression
|
||||
transform.Inspect(pseudoAnalyzedQuery, func(n sql.Node) bool {
|
||||
if projector, ok := n.(sql.Projector); ok {
|
||||
expr = projector.ProjectedExprs()[0]
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if expr == nil {
|
||||
return checkValidator{}, fmt.Errorf("unable to find expression in analyzed query")
|
||||
}
|
||||
|
||||
checkExpressions[check.Name()] = expr
|
||||
}
|
||||
|
||||
@@ -425,11 +400,7 @@ func newCheckValidator(sqlCtx *sql.Context, tm *TableMerger, vm *valueMerger, sc
|
||||
// validateDiff inspects the three-way diff event |diff| and evaluates any check constraint expressions that need to
|
||||
// be rechecked after the merge. If any check constraint violations are detected, the violation count is returned as
|
||||
// the first return parameter and the violations are also written to the artifact editor passed in on creation.
|
||||
func (cv checkValidator) validateDiff(ctx context.Context, diff tree.ThreeWayDiff) (int, error) {
|
||||
// TODO: This sql Context creation is expensive; do this higher up so we don't have to recreate this over and over
|
||||
// Has this already been done even? Can we just change the signature to sql.Context?
|
||||
sqlCtx := sql.NewContext(ctx)
|
||||
|
||||
func (cv checkValidator) validateDiff(ctx *sql.Context, diff tree.ThreeWayDiff) (int, error) {
|
||||
conflictCount := 0
|
||||
|
||||
for checkName, checkExpression := range cv.checkExpressions {
|
||||
@@ -442,7 +413,6 @@ func (cv checkValidator) validateDiff(ctx context.Context, diff tree.ThreeWayDif
|
||||
valueTuple = diff.Left
|
||||
valueDesc = cv.tableMerger.leftSch.GetValueDescriptor()
|
||||
case tree.DiffOpRightAdd, tree.DiffOpRightDelete, tree.DiffOpRightModify:
|
||||
// TODO: Can this actually happen when we're always merging from right into left?
|
||||
valueTuple = diff.Right
|
||||
valueDesc = cv.tableMerger.rightSch.GetValueDescriptor()
|
||||
case tree.DiffOpConvergentAdd, tree.DiffOpConvergentDelete, tree.DiffOpConvergentModify:
|
||||
@@ -474,7 +444,7 @@ func (cv checkValidator) validateDiff(ctx context.Context, diff tree.ThreeWayDif
|
||||
row = append(row, nil)
|
||||
keyDesc := cv.sch.GetKeyDescriptor()
|
||||
for i := range keyDesc.Types {
|
||||
value, err := index.GetField(sqlCtx, keyDesc, i, diff.Key, cv.tableMerger.ns)
|
||||
value, err := index.GetField(ctx, keyDesc, i, diff.Key, cv.tableMerger.ns)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -482,14 +452,14 @@ func (cv checkValidator) validateDiff(ctx context.Context, diff tree.ThreeWayDif
|
||||
}
|
||||
|
||||
for i := range cv.sch.GetNonPKCols().GetColumns() {
|
||||
value, err := index.GetField(sqlCtx, cv.sch.GetValueDescriptor(), i, newTuple, cv.tableMerger.ns)
|
||||
value, err := index.GetField(ctx, cv.sch.GetValueDescriptor(), i, newTuple, cv.tableMerger.ns)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
row = append(row, value)
|
||||
}
|
||||
|
||||
result, err := checkExpression.Eval(sqlCtx, row)
|
||||
result, err := checkExpression.Eval(ctx, row)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -1048,6 +1018,41 @@ func (m *secondaryMerger) finalize(ctx context.Context) (durable.IndexSet, durab
|
||||
return m.leftSet, m.rightSet, nil
|
||||
}
|
||||
|
||||
// resolveExpression takes in a string |expression| and does basic resolution on it (e.g. column names and function
|
||||
// names) so that the returned sql.Expression can be evaluated. The schema of the table is specified in |sch| and the
|
||||
// name of the table in |tableName|.
|
||||
func resolveExpression(ctx *sql.Context, expression string, sch schema.Schema, tableName string) (sql.Expression, error) {
|
||||
query := fmt.Sprintf("SELECT %s from %s.%s", expression, "mydb", tableName)
|
||||
sqlSch, err := sqlutil.FromDoltSchema(tableName, sch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mockTable := memory.NewTable(tableName, sqlSch, nil)
|
||||
mockDatabase := memory.NewDatabase("mydb")
|
||||
mockDatabase.AddTable(tableName, mockTable)
|
||||
mockProvider := memory.NewDBProvider(mockDatabase)
|
||||
catalog := analyzer.NewCatalog(mockProvider)
|
||||
|
||||
pseudoAnalyzedQuery, err := planbuilder.Parse(ctx, catalog, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var expr sql.Expression
|
||||
transform.Inspect(pseudoAnalyzedQuery, func(n sql.Node) bool {
|
||||
if projector, ok := n.(sql.Projector); ok {
|
||||
expr = projector.ProjectedExprs()[0]
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if expr == nil {
|
||||
return nil, fmt.Errorf("unable to find expression in analyzed query")
|
||||
}
|
||||
|
||||
return expr, nil
|
||||
}
|
||||
|
||||
// remapTuple takes the given |tuple| and the |desc| that describes its data, and uses |mapping| to map the tuple's
|
||||
// data into a new [][]byte, as indicated by the specified ordinal mapping.
|
||||
func remapTuple(tuple val.Tuple, desc val.TupleDesc, mapping val.OrdinalMapping) [][]byte {
|
||||
|
||||
Reference in New Issue
Block a user