Fixed query differ to work with subquery expressions.

Signed-off-by: Zach Musgrave <zach@liquidata.co>
This commit is contained in:
Zach Musgrave
2020-08-06 15:59:34 -07:00
parent 8a314ecf07
commit e05a5f5bdc
2 changed files with 85 additions and 60 deletions

View File

@@ -36,6 +36,7 @@ func lazyQueryPlan(node sql.Node) (lazyNode sql.Node, projections []sql.Expressi
}
offset := 0
originalChildrenSchema := schema(node.Children())
lazyChildren := make([]sql.Node, len(children))
for i, c := range children {
c, pjs, ord, err := lazyQueryPlan(c)
@@ -57,8 +58,8 @@ func lazyQueryPlan(node sql.Node) (lazyNode sql.Node, projections []sql.Expressi
return nil, nil, nil, err
}
lazyNode, err = plan.TransformExpressions(node, func(e sql.Expression) (sql.Expression, error) {
return makeExpressionLazy(e, projections)
lazyNode, err = plan.TransformExpressionsWithNode(node, func(n sql.Node, e sql.Expression) (sql.Expression, error) {
return makeExpressionLazy(node, originalChildrenSchema, e, projections)
})
if err != nil {
return nil, nil, nil, err
@@ -83,6 +84,14 @@ func lazyQueryPlan(node sql.Node) (lazyNode sql.Node, projections []sql.Expressi
return lazyNode, projections, order, nil
}
func schema(nodes []sql.Node) sql.Schema {
var schema sql.Schema
for _, node := range nodes {
schema = append(schema, node.Schema()...)
}
return schema
}
// wrapGroupBy wraps a GroupBy node in a Sort node so its output can be ordered in query diffs.
func wrapGroupBy(g *plan.GroupBy) (node sql.Node, projections []sql.Expression, order []plan.SortField) {
projections = make([]sql.Expression, len(g.Schema()))
@@ -102,14 +111,36 @@ func wrapGroupBy(g *plan.GroupBy) (node sql.Node, projections []sql.Expression,
return plan.NewSort(order, g), projections, order
}
func makeExpressionLazy(e sql.Expression, composite []sql.Expression) (sql.Expression, error) {
gf, ok := e.(*expression.GetField)
if ok {
if gf.Index() >= len(composite) {
func makeExpressionLazy(node sql.Node, originalChildSchema sql.Schema, e sql.Expression, exprs []sql.Expression) (sql.Expression, error) {
if gf, ok := e.(*expression.GetField); ok {
if gf.Index() >= len(exprs) {
return nil, fmt.Errorf("index out of bounds in lazy expression substitution")
}
e = composite[gf.Index()]
return exprs[gf.Index()], nil
}
// For subqueries, we need to apply the same lazy substitution to any expressions in the outer scope, and then shift
// the indexes of the inner scope to handle any erased projections.
if s, ok := e.(*plan.Subquery); ok {
childSchema := schema(node.Children())
newSubquery, err := plan.TransformExpressionsUp(s.Query, func(e sql.Expression) (sql.Expression, error) {
if gf, ok := e.(*expression.GetField); ok {
if gf.Index() < len(exprs) {
return exprs[gf.Index()], nil
} else {
// Part of the inner scope, shift indexes
offset := len(childSchema) - len(originalChildSchema)
return shiftFieldIndices(offset, e)[0], nil
}
}
return e, nil
})
if err != nil {
return nil, err
}
return s.WithQuery(newSubquery), nil
}
return e, nil
}
@@ -136,9 +167,9 @@ func getOrderForTable(tbl sql.Table) (order []plan.SortField) {
return order
}
func shiftFieldIndices(offset int, composite ...sql.Expression) []sql.Expression {
shifted := make([]sql.Expression, len(composite))
for i, e := range composite {
func shiftFieldIndices(offset int, exprs ...sql.Expression) []sql.Expression {
shifted := make([]sql.Expression, len(exprs))
for i, e := range exprs {
shifted[i], _ = expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) {
if gf, ok := e.(*expression.GetField); ok {
return gf.WithIndex(gf.Index() + offset), nil

View File

@@ -36,12 +36,6 @@ import (
det "github.com/liquidata-inc/dolt/go/libraries/doltcore/sqle/enginetest"
)
func TestQueryDiffer(t *testing.T) {
testQueryDiffer(t)
TestEngineTestQueryDifferBefore(t)
EngineTestQueryDifferAfter(t)
}
type queryDifferTest struct {
name string
query string
@@ -318,49 +312,6 @@ var groupByTests = []queryDifferTest{
},
}
func testQueryDiffer(t *testing.T) {
inner := func(t *testing.T, test queryDifferTest) {
dEnv := dtestutils.CreateTestEnv()
ctx := context.Background()
for _, c := range queryDifferTestSetup {
exitCode := c.cmd.Exec(ctx, c.cmd.Name(), c.args, dEnv)
assert.Equal(t, 0, exitCode)
}
for _, c := range test.setup {
exitCode := c.cmd.Exec(ctx, c.cmd.Name(), c.args, dEnv)
assert.Equal(t, 0, exitCode)
}
fromRoot, err := dEnv.HeadRoot(ctx)
require.NoError(t, err)
toRoot, err := dEnv.WorkingRoot(ctx)
require.NoError(t, err)
qd, err := querydiff.MakeQueryDiffer(ctx, dEnv, fromRoot, toRoot, test.query)
require.NoError(t, err)
qd.Start()
for _, expected := range test.diffRows {
from, to, err := qd.NextDiff()
assert.NoError(t, err)
assert.Equal(t, expected.from, from)
assert.Equal(t, expected.to, to)
}
from, to, err := qd.NextDiff()
assert.Nil(t, from)
assert.Nil(t, to)
assert.Equal(t, io.EOF, err)
}
for _, testSet := range queryDiffTests {
for _, test := range testSet {
t.Run(test.name, func(t *testing.T) {
inner(t, test)
})
}
}
}
var engineTestSetup = []testCommand{
{commands.SqlCmd{}, []string{"-q", "create table mytable (" +
"i bigint primary key," +
@@ -542,6 +493,49 @@ func skipEngineTest(test enginetest.QueryTest) bool {
return false
}
func TestQueryDiffer(t *testing.T) {
inner := func(t *testing.T, test queryDifferTest) {
dEnv := dtestutils.CreateTestEnv()
ctx := context.Background()
for _, c := range queryDifferTestSetup {
exitCode := c.cmd.Exec(ctx, c.cmd.Name(), c.args, dEnv)
assert.Equal(t, 0, exitCode)
}
for _, c := range test.setup {
exitCode := c.cmd.Exec(ctx, c.cmd.Name(), c.args, dEnv)
assert.Equal(t, 0, exitCode)
}
fromRoot, err := dEnv.HeadRoot(ctx)
require.NoError(t, err)
toRoot, err := dEnv.WorkingRoot(ctx)
require.NoError(t, err)
qd, err := querydiff.MakeQueryDiffer(ctx, dEnv, fromRoot, toRoot, test.query)
require.NoError(t, err)
qd.Start()
for _, expected := range test.diffRows {
from, to, err := qd.NextDiff()
assert.NoError(t, err)
assert.Equal(t, expected.from, from)
assert.Equal(t, expected.to, to)
}
from, to, err := qd.NextDiff()
assert.Nil(t, from)
assert.Nil(t, to)
assert.Equal(t, io.EOF, err)
}
for _, testSet := range queryDiffTests {
for _, test := range testSet {
t.Run(test.name, func(t *testing.T) {
inner(t, test)
})
}
}
}
func TestEngineTestQueryDifferBefore(t *testing.T) {
inner := func(t *testing.T, test enginetest.QueryTest, dEnv *env.DoltEnv) {
if skipEngineTest(test) {
@@ -593,7 +587,7 @@ func TestEngineTestQueryDifferBefore(t *testing.T) {
}
}
func EngineTestQueryDifferAfter(t *testing.T) {
func TestEngineTestQueryDifferAfter(t *testing.T) {
inner := func(t *testing.T, test enginetest.QueryTest, dEnv *env.DoltEnv) {
if skipEngineTest(test) {
t.Skip()