Merge pull request #4646 from dolthub/taylor/diff-summary-dot-sql

Support two and three dot syntax in `dolt_diff_summary` table function
This commit is contained in:
Taylor Bantle
2022-10-28 13:06:50 -07:00
committed by GitHub
3 changed files with 218 additions and 38 deletions

View File

@@ -18,6 +18,7 @@ import (
"fmt"
"io"
"math"
"strings"
"github.com/dolthub/go-mysql-server/sql"
@@ -34,6 +35,7 @@ type DiffSummaryTableFunction struct {
fromCommitExpr sql.Expression
toCommitExpr sql.Expression
dotCommitExpr sql.Expression
tableNameExpr sql.Expression
database sql.Database
}
@@ -84,16 +86,29 @@ func (ds *DiffSummaryTableFunction) FunctionName() string {
return "dolt_diff_summary"
}
// Resolved implements the sql.Resolvable interface
func (ds *DiffSummaryTableFunction) Resolved() bool {
if ds.tableNameExpr != nil {
return ds.fromCommitExpr.Resolved() && ds.toCommitExpr.Resolved() && ds.tableNameExpr.Resolved()
func (ds *DiffSummaryTableFunction) commitsResolved() bool {
if ds.dotCommitExpr != nil {
return ds.dotCommitExpr.Resolved()
}
return ds.fromCommitExpr.Resolved() && ds.toCommitExpr.Resolved()
}
// Resolved implements the sql.Resolvable interface
func (ds *DiffSummaryTableFunction) Resolved() bool {
if ds.tableNameExpr != nil {
return ds.commitsResolved() && ds.tableNameExpr.Resolved()
}
return ds.commitsResolved()
}
// String implements the Stringer interface
func (ds *DiffSummaryTableFunction) String() string {
if ds.dotCommitExpr != nil {
if ds.tableNameExpr != nil {
return fmt.Sprintf("DOLT_DIFF_SUMMARY(%s, %s)", ds.dotCommitExpr.String(), ds.tableNameExpr.String())
}
return fmt.Sprintf("DOLT_DIFF_SUMMARY(%s)", ds.dotCommitExpr.String())
}
if ds.tableNameExpr != nil {
return fmt.Sprintf("DOLT_DIFF_SUMMARY(%s, %s, %s)", ds.fromCommitExpr.String(), ds.toCommitExpr.String(), ds.tableNameExpr.String())
}
@@ -154,7 +169,12 @@ func (ds *DiffSummaryTableFunction) CheckPrivileges(ctx *sql.Context, opChecker
// Expressions implements the sql.Expressioner interface.
func (ds *DiffSummaryTableFunction) Expressions() []sql.Expression {
exprs := []sql.Expression{ds.fromCommitExpr, ds.toCommitExpr}
exprs := []sql.Expression{}
if ds.dotCommitExpr != nil {
exprs = append(exprs, ds.dotCommitExpr)
} else {
exprs = append(exprs, ds.fromCommitExpr, ds.toCommitExpr)
}
if ds.tableNameExpr != nil {
exprs = append(exprs, ds.tableNameExpr)
}
@@ -163,8 +183,8 @@ func (ds *DiffSummaryTableFunction) Expressions() []sql.Expression {
// WithExpressions implements the sql.Expressioner interface.
func (ds *DiffSummaryTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) {
if len(expression) < 2 || len(expression) > 3 {
return nil, sql.ErrInvalidArgumentNumber.New(ds.FunctionName(), "2 or 3", len(expression))
if len(expression) < 1 {
return nil, sql.ErrInvalidArgumentNumber.New(ds.FunctionName(), "1 to 3", len(expression))
}
for _, expr := range expression {
@@ -173,19 +193,37 @@ func (ds *DiffSummaryTableFunction) WithExpressions(expression ...sql.Expression
}
}
ds.fromCommitExpr = expression[0]
ds.toCommitExpr = expression[1]
if len(expression) == 3 {
ds.tableNameExpr = expression[2]
if strings.Contains(expression[0].String(), "..") {
if len(expression) < 1 || len(expression) > 2 {
return nil, sql.ErrInvalidArgumentNumber.New(ds.FunctionName(), "1 or 2", len(expression))
}
ds.dotCommitExpr = expression[0]
if len(expression) == 2 {
ds.tableNameExpr = expression[1]
}
} else {
if len(expression) < 2 || len(expression) > 3 {
return nil, sql.ErrInvalidArgumentNumber.New(ds.FunctionName(), "2 or 3", len(expression))
}
ds.fromCommitExpr = expression[0]
ds.toCommitExpr = expression[1]
if len(expression) == 3 {
ds.tableNameExpr = expression[2]
}
}
// validate the expressions
if !sql.IsText(ds.fromCommitExpr.Type()) {
return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.fromCommitExpr.String())
}
if !sql.IsText(ds.toCommitExpr.Type()) {
return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.toCommitExpr.String())
if ds.dotCommitExpr != nil {
if !sql.IsText(ds.dotCommitExpr.Type()) {
return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.dotCommitExpr.String())
}
} else {
if !sql.IsText(ds.fromCommitExpr.Type()) {
return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.fromCommitExpr.String())
}
if !sql.IsText(ds.toCommitExpr.Type()) {
return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.toCommitExpr.String())
}
}
if ds.tableNameExpr != nil {
@@ -199,7 +237,7 @@ func (ds *DiffSummaryTableFunction) WithExpressions(expression ...sql.Expression
// RowIter implements the sql.Node interface
func (ds *DiffSummaryTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
fromCommitVal, toCommitVal, tableName, err := ds.evaluateArguments()
fromCommitVal, toCommitVal, dotCommitVal, tableName, err := ds.evaluateArguments()
if err != nil {
return nil, err
}
@@ -209,13 +247,18 @@ func (ds *DiffSummaryTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.
return nil, fmt.Errorf("unexpected database type: %T", ds.database)
}
sess := dsess.DSessFromSess(ctx.Session)
fromRoot, _, err := sess.ResolveRootForRef(ctx, sqledb.Name(), fromCommitVal)
fromCommitStr, toCommitStr, err := loadCommitStrings(ctx, fromCommitVal, toCommitVal, dotCommitVal, sqledb)
if err != nil {
return nil, err
}
toRoot, _, err := sess.ResolveRootForRef(ctx, sqledb.Name(), toCommitVal)
sess := dsess.DSessFromSess(ctx.Session)
fromRoot, _, err := sess.ResolveRootForRef(ctx, sqledb.Name(), fromCommitStr)
if err != nil {
return nil, err
}
toRoot, _, err := sess.ResolveRootForRef(ctx, sqledb.Name(), toCommitStr)
if err != nil {
return nil, err
}
@@ -256,42 +299,43 @@ func (ds *DiffSummaryTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.
return NewDiffSummaryTableFunctionRowIter(diffSummaries), nil
}
// evaluateArguments returns fromCommitValStr, toCommitValStr and tableName.
// It evaluates the argument expressions to turn them into values this DiffTableFunction
// evaluateArguments returns fromCommitVal, toCommitVal, dotCommitVal, and tableName.
// It evaluates the argument expressions to turn them into values this DiffSummaryTableFunction
// can use. Note that this method only evals the expressions, and doesn't validate the values.
func (ds *DiffSummaryTableFunction) evaluateArguments() (string, string, string, error) {
func (ds *DiffSummaryTableFunction) evaluateArguments() (interface{}, interface{}, interface{}, string, error) {
var tableName string
if ds.tableNameExpr != nil {
tableNameVal, err := ds.tableNameExpr.Eval(ds.ctx, nil)
if err != nil {
return "", "", "", err
return nil, nil, nil, "", err
}
tn, ok := tableNameVal.(string)
if !ok {
return "", "", "", ErrInvalidTableName.New(ds.tableNameExpr.String())
return nil, nil, nil, "", ErrInvalidTableName.New(ds.tableNameExpr.String())
}
tableName = tn
}
if ds.dotCommitExpr != nil {
dotCommitVal, err := ds.dotCommitExpr.Eval(ds.ctx, nil)
if err != nil {
return nil, nil, nil, "", err
}
return nil, nil, dotCommitVal, tableName, nil
}
fromCommitVal, err := ds.fromCommitExpr.Eval(ds.ctx, nil)
if err != nil {
return "", "", "", err
}
fromCommitValStr, ok := fromCommitVal.(string)
if !ok {
return "", "", "", fmt.Errorf("received '%v' when expecting commit hash string", fromCommitVal)
return nil, nil, nil, "", err
}
toCommitVal, err := ds.toCommitExpr.Eval(ds.ctx, nil)
if err != nil {
return "", "", "", err
}
toCommitValStr, ok := toCommitVal.(string)
if !ok {
return "", "", "", fmt.Errorf("received '%v' when expecting commit hash string", toCommitVal)
return nil, nil, nil, "", err
}
return fromCommitValStr, toCommitValStr, tableName, nil
return fromCommitVal, toCommitVal, nil, tableName, nil
}
// getDiffSummaryNodeFromDelta returns diffSummaryNode object and whether there is data diff or not. It gets tables

View File

@@ -94,6 +94,10 @@ func (dtf *DiffTableFunction) Expressions() []sql.Expression {
// WithExpressions implements the sql.Expressioner interface
func (dtf *DiffTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) {
if len(expression) < 2 {
return nil, sql.ErrInvalidArgumentNumber.New(dtf.FunctionName(), "2 to 3", len(expression))
}
// TODO: For now, we will only support literal / fully-resolved arguments to the
// DiffTableFunction to avoid issues where the schema is needed in the analyzer
// before the arguments could be resolved.

View File

@@ -781,6 +781,13 @@ var DoltUserPrivTests = []queries.UserPrivilegeTest{
Query: "SELECT * FROM dolt_diff_summary('main~', 'main', 'test');",
ExpectedErr: sql.ErrDatabaseAccessDeniedForUser,
},
{
// Without access to the database, dolt_diff_summary with dots should fail with a database access error
User: "tester",
Host: "localhost",
Query: "SELECT * FROM dolt_diff_summary('main~..main', 'test');",
ExpectedErr: sql.ErrDatabaseAccessDeniedForUser,
},
{
// Without access to the database, dolt_log should fail with a database access error
User: "tester",
@@ -830,6 +837,13 @@ var DoltUserPrivTests = []queries.UserPrivilegeTest{
Query: "SELECT * FROM dolt_diff_summary('main~', 'main', 'test2');",
ExpectedErr: sql.ErrPrivilegeCheckFailed,
},
{
// With access to the db, but not the table, dolt_diff_summary with dots should fail
User: "tester",
Host: "localhost",
Query: "SELECT * FROM dolt_diff_summary('main~...main', 'test2');",
ExpectedErr: sql.ErrPrivilegeCheckFailed,
},
{
// With access to the db, dolt_diff_summary should fail for all tables if no access any of tables
User: "tester",
@@ -837,6 +851,13 @@ var DoltUserPrivTests = []queries.UserPrivilegeTest{
Query: "SELECT * FROM dolt_diff_summary('main~', 'main');",
ExpectedErr: sql.ErrPrivilegeCheckFailed,
},
{
// With access to the db, dolt_diff_summary with dots should fail for all tables if no access any of tables
User: "tester",
Host: "localhost",
Query: "SELECT * FROM dolt_diff_summary('main~...main');",
ExpectedErr: sql.ErrPrivilegeCheckFailed,
},
{
// Revoke select on mydb.test
User: "root",
@@ -886,6 +907,13 @@ var DoltUserPrivTests = []queries.UserPrivilegeTest{
Query: "SELECT COUNT(*) FROM dolt_diff_summary('main~', 'main');",
Expected: []sql.Row{{1}},
},
{
// After granting access to the entire db, dolt_diff_summary with dots should work
User: "tester",
Host: "localhost",
Query: "SELECT COUNT(*) FROM dolt_diff_summary('main~...main');",
Expected: []sql.Row{{1}},
},
{
// After granting access to the entire db, dolt_log should work
User: "tester",
@@ -4474,6 +4502,10 @@ var DiffTableFunctionScriptTests = []queries.ScriptTest{
"set @Commit2 = dolt_commit('-am', 'inserting into t');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "SELECT * from dolt_diff();",
ExpectedErr: sql.ErrInvalidArgumentNumber,
},
{
Query: "SELECT * from dolt_diff('t');",
ExpectedErr: sql.ErrInvalidArgumentNumber,
@@ -5621,6 +5653,10 @@ var DiffSummaryTableFunctionScriptTests = []queries.ScriptTest{
"set @Commit2 = dolt_commit('-am', 'inserting into t');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "SELECT * from dolt_diff_summary();",
ExpectedErr: sql.ErrInvalidArgumentNumber,
},
{
Query: "SELECT * from dolt_diff_summary('t');",
ExpectedErr: sql.ErrInvalidArgumentNumber,
@@ -5649,14 +5685,26 @@ var DiffSummaryTableFunctionScriptTests = []queries.ScriptTest{
Query: "SELECT * from dolt_diff_summary('fake-branch', @Commit2, 't');",
ExpectedErrStr: "branch not found: fake-branch",
},
{
Query: "SELECT * from dolt_diff_summary('fake-branch..main', 't');",
ExpectedErrStr: "branch not found: fake-branch",
},
{
Query: "SELECT * from dolt_diff_summary(@Commit1, 'fake-branch', 't');",
ExpectedErrStr: "branch not found: fake-branch",
},
{
Query: "SELECT * from dolt_diff_summary('main..fake-branch', 't');",
ExpectedErrStr: "branch not found: fake-branch",
},
{
Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit2, 'doesnotexist');",
ExpectedErr: sql.ErrTableNotFound,
},
{
Query: "SELECT * from dolt_diff_summary('main^..main', 'doesnotexist');",
ExpectedErr: sql.ErrTableNotFound,
},
{
Query: "SELECT * from dolt_diff_summary(@Commit1, concat('fake', '-', 'branch'), 't');",
ExpectedErr: sqle.ErrInvalidNonLiteralArgument,
@@ -5669,6 +5717,10 @@ var DiffSummaryTableFunctionScriptTests = []queries.ScriptTest{
Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit2, LOWER('T'));",
ExpectedErr: sqle.ErrInvalidNonLiteralArgument,
},
{
Query: "SELECT * from dolt_diff_summary('main..main~', LOWER('T'));",
ExpectedErr: sqle.ErrInvalidNonLiteralArgument,
},
},
},
{
@@ -5872,6 +5924,10 @@ var DiffSummaryTableFunctionScriptTests = []queries.ScriptTest{
Query: "SELECT * from dolt_diff_summary('STAGED', 'WORKING', 't')",
Expected: []sql.Row{{"t", 0, 1, 1, 1, 3, 3, 1, 2, 2, 6, 6}},
},
{
Query: "SELECT * from dolt_diff_summary('STAGED..WORKING', 't')",
Expected: []sql.Row{{"t", 0, 1, 1, 1, 3, 3, 1, 2, 2, 6, 6}},
},
{
Query: "SELECT * from dolt_diff_summary('WORKING', 'STAGED', 't')",
Expected: []sql.Row{{"t", 0, 1, 1, 1, 3, 3, 1, 2, 2, 6, 6}},
@@ -5880,6 +5936,10 @@ var DiffSummaryTableFunctionScriptTests = []queries.ScriptTest{
Query: "SELECT * from dolt_diff_summary('WORKING', 'WORKING', 't')",
Expected: []sql.Row{},
},
{
Query: "SELECT * from dolt_diff_summary('WORKING..WORKING', 't')",
Expected: []sql.Row{},
},
{
Query: "SELECT * from dolt_diff_summary('STAGED', 'STAGED', 't')",
Expected: []sql.Row{},
@@ -5921,20 +5981,83 @@ var DiffSummaryTableFunctionScriptTests = []queries.ScriptTest{
"select dolt_checkout('main');",
"insert into t values (2, 'two', 'three');",
"set @Commit6 = dolt_commit('-am', 'inserting row 2 in main');",
"create table newtable (pk int primary key);",
"insert into newtable values (1), (2);",
"set @Commit7 = dolt_commit('-Am', 'new table newtable');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "SELECT * from dolt_diff_summary('main', 'branch1', 't');",
Expected: []sql.Row{{"t", 0, 0, 1, 1, 0, 4, 0, 2, 1, 6, 2}},
},
{
Query: "SELECT * from dolt_diff_summary('main..branch1', 't');",
Expected: []sql.Row{{"t", 0, 0, 1, 1, 0, 4, 0, 2, 1, 6, 2}},
},
{
Query: "SELECT * from dolt_diff_summary('main', 'branch1');",
Expected: []sql.Row{
{"t", 0, 0, 1, 1, 0, 4, 0, 2, 1, 6, 2},
{"newtable", 0, 0, 2, 0, 0, 2, 0, 2, 0, 2, 0},
},
},
{
Query: "SELECT * from dolt_diff_summary('main..branch1');",
Expected: []sql.Row{
{"t", 0, 0, 1, 1, 0, 4, 0, 2, 1, 6, 2},
{"newtable", 0, 0, 2, 0, 0, 2, 0, 2, 0, 2, 0},
},
},
{
Query: "SELECT * from dolt_diff_summary('branch1', 'main', 't');",
Expected: []sql.Row{{"t", 0, 1, 0, 1, 4, 0, 1, 1, 2, 2, 6}},
},
{
Query: "SELECT * from dolt_diff_summary('main~', 'branch1', 't');",
Query: "SELECT * from dolt_diff_summary('branch1..main', 't');",
Expected: []sql.Row{{"t", 0, 1, 0, 1, 4, 0, 1, 1, 2, 2, 6}},
},
{
Query: "SELECT * from dolt_diff_summary('main~2', 'branch1', 't');",
Expected: []sql.Row{{"t", 0, 1, 1, 0, 2, 3, 0, 1, 1, 3, 2}},
},
{
Query: "SELECT * from dolt_diff_summary('main~2..branch1', 't');",
Expected: []sql.Row{{"t", 0, 1, 1, 0, 2, 3, 0, 1, 1, 3, 2}},
},
// Three dot
{
Query: "SELECT * from dolt_diff_summary('main...branch1', 't');",
Expected: []sql.Row{{"t", 0, 1, 1, 0, 2, 3, 0, 1, 1, 3, 2}},
},
{
Query: "SELECT * from dolt_diff_summary('main...branch1');",
Expected: []sql.Row{{"t", 0, 1, 1, 0, 2, 3, 0, 1, 1, 3, 2}},
},
{
Query: "SELECT * from dolt_diff_summary('branch1...main', 't');",
Expected: []sql.Row{{"t", 1, 1, 0, 0, 3, 0, 0, 1, 2, 3, 6}},
},
{
Query: "SELECT * from dolt_diff_summary('branch1...main');",
Expected: []sql.Row{
{"t", 1, 1, 0, 0, 3, 0, 0, 1, 2, 3, 6},
{"newtable", 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 2},
},
},
{
Query: "SELECT * from dolt_diff_summary('branch1...main^');",
Expected: []sql.Row{{"t", 1, 1, 0, 0, 3, 0, 0, 1, 2, 3, 6}},
},
{
Query: "SELECT * from dolt_diff_summary('branch1...main', 'newtable');",
Expected: []sql.Row{{"newtable", 0, 2, 0, 0, 2, 0, 0, 0, 2, 0, 2}},
},
{
Query: "SELECT * from dolt_diff_summary('main...main', 'newtable');",
Expected: []sql.Row{},
},
},
},
{
@@ -6092,11 +6215,20 @@ var DiffSummaryTableFunctionScriptTests = []queries.ScriptTest{
Query: "select * from dolt_diff_summary('HEAD~', 'HEAD', 't2')",
Expected: []sql.Row{{"t2", 1, 1, 0, 0, 2, 0, 0, 1, 2, 2, 4}},
},
{
Query: "select * from dolt_diff_summary('HEAD~..HEAD', 't2')",
Expected: []sql.Row{{"t2", 1, 1, 0, 0, 2, 0, 0, 1, 2, 2, 4}},
},
{
// Old table name can be matched as well
Query: "select * from dolt_diff_summary('HEAD~', 'HEAD', 't1')",
Expected: []sql.Row{{"t1", 1, 1, 0, 0, 2, 0, 0, 1, 2, 2, 4}},
},
{
// Old table name can be matched as well
Query: "select * from dolt_diff_summary('HEAD~..HEAD', 't1')",
Expected: []sql.Row{{"t1", 1, 1, 0, 0, 2, 0, 0, 1, 2, 2, 4}},
},
},
},
{