diff --git a/go/cmd/dolt/cli/arg_parser_helpers.go b/go/cmd/dolt/cli/arg_parser_helpers.go index 4d8e5e5237..8b7115aea6 100644 --- a/go/cmd/dolt/cli/arg_parser_helpers.go +++ b/go/cmd/dolt/cli/arg_parser_helpers.go @@ -308,6 +308,29 @@ func CreateLogArgParser(isTableFunction bool) *argparser.ArgParser { return ap } +func CreateDiffArgParser(isTableFunction bool) *argparser.ArgParser { + ap := argparser.NewArgParserWithVariableArgs("diff") + ap.SupportsFlag(SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.") + ap.SupportsStringList(IncludeCols, "ic", "columns", "A list of columns to include in the diff.") + if !isTableFunction { // TODO: support for table function + ap.SupportsFlag(DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).") + ap.SupportsFlag(SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).") + ap.SupportsFlag(StatFlag, "", "Show stats of data changes") + ap.SupportsFlag(SummaryFlag, "", "Show summary of data and schema changes") + ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.") + ap.SupportsString(WhereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.") + ap.SupportsInt(LimitParam, "", "record_count", "limits to the first N diffs.") + ap.SupportsFlag(StagedFlag, "", "Show only the staged data changes.") + ap.SupportsFlag(CachedFlag, "c", "Synonym for --staged") + ap.SupportsFlag(MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit") + ap.SupportsString(DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.") + ap.SupportsFlag(ReverseFlag, "R", "Reverses the direction of the diff.") + ap.SupportsFlag(NameOnlyFlag, "", "Only shows table names.") + ap.SupportsFlag(SystemFlag, "", "Show system tables in addition to user tables") + } + return ap +} + func CreateGCArgParser() *argparser.ArgParser { ap := argparser.NewArgParserWithMaxArgs("gc", 0) ap.SupportsFlag(ShallowFlag, "s", "perform a fast, but incomplete garbage collection pass") diff --git a/go/cmd/dolt/cli/flags.go b/go/cmd/dolt/cli/flags.go index 5eddc811d9..691ddbd973 100644 --- a/go/cmd/dolt/cli/flags.go +++ b/go/cmd/dolt/cli/flags.go @@ -88,3 +88,19 @@ const ( UpperCaseAllFlag = "ALL" UserFlag = "user" ) + +// Flags used by `dolt diff` command and `dolt_diff()` table function. +const ( + SkinnyFlag = "skinny" + IncludeCols = "include-cols" + DataFlag = "data" + SchemaFlag = "schema" + NameOnlyFlag = "name-only" + SummaryFlag = "summary" + WhereParam = "where" + LimitParam = "limit" + MergeBase = "merge-base" + DiffMode = "diff-mode" + ReverseFlag = "reverse" + FormatFlag = "result-format" +) diff --git a/go/cmd/dolt/commands/diff.go b/go/cmd/dolt/commands/diff.go index 68e57bb572..e25416d009 100644 --- a/go/cmd/dolt/commands/diff.go +++ b/go/cmd/dolt/commands/diff.go @@ -60,18 +60,6 @@ const ( TabularDiffOutput diffOutput = 1 SQLDiffOutput diffOutput = 2 JsonDiffOutput diffOutput = 3 - - DataFlag = "data" - SchemaFlag = "schema" - NameOnlyFlag = "name-only" - StatFlag = "stat" - SummaryFlag = "summary" - whereParam = "where" - limitParam = "limit" - SkinnyFlag = "skinny" - MergeBase = "merge-base" - DiffMode = "diff-mode" - ReverseFlag = "reverse" ) var diffDocs = cli.CommandDocumentationContent{ @@ -107,12 +95,13 @@ The {{.EmphasisLeft}}--diff-mode{{.EmphasisRight}} argument controls how modifie } type diffDisplaySettings struct { - diffParts diffPart - diffOutput diffOutput - diffMode diff.Mode - limit int - where string - skinny bool + diffParts diffPart + diffOutput diffOutput + diffMode diff.Mode + limit int + where string + skinny bool + includeCols []string } type diffDatasets struct { @@ -164,23 +153,7 @@ func (cmd DiffCmd) Docs() *cli.CommandDocumentation { } func (cmd DiffCmd) ArgParser() *argparser.ArgParser { - ap := argparser.NewArgParserWithVariableArgs(cmd.Name()) - ap.SupportsFlag(DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).") - ap.SupportsFlag(SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).") - ap.SupportsFlag(StatFlag, "", "Show stats of data changes") - ap.SupportsFlag(SummaryFlag, "", "Show summary of data and schema changes") - ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.") - ap.SupportsString(whereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.") - ap.SupportsInt(limitParam, "", "record_count", "limits to the first N diffs.") - ap.SupportsFlag(cli.StagedFlag, "", "Show only the staged data changes.") - ap.SupportsFlag(cli.CachedFlag, "c", "Synonym for --staged") - ap.SupportsFlag(SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.") - ap.SupportsFlag(MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit") - ap.SupportsString(DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.") - ap.SupportsFlag(ReverseFlag, "R", "Reverses the direction of the diff.") - ap.SupportsFlag(NameOnlyFlag, "", "Only shows table names.") - ap.SupportsFlag(cli.SystemFlag, "", "Show system tables in addition to user tables") - return ap + return cli.CreateDiffArgParser(false) } func (cmd DiffCmd) RequiresRepo() bool { @@ -228,14 +201,14 @@ func (cmd DiffCmd) Exec(ctx context.Context, commandStr string, args []string, _ } func (cmd DiffCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseError { - if apr.Contains(StatFlag) || apr.Contains(SummaryFlag) { - if apr.Contains(SchemaFlag) || apr.Contains(DataFlag) { + if apr.Contains(cli.StatFlag) || apr.Contains(cli.SummaryFlag) { + if apr.Contains(cli.SchemaFlag) || apr.Contains(cli.DataFlag) { return errhand.BuildDError("invalid Arguments: --stat and --summary cannot be combined with --schema or --data").Build() } } - if apr.Contains(NameOnlyFlag) { - if apr.Contains(SchemaFlag) || apr.Contains(DataFlag) || apr.Contains(StatFlag) || apr.Contains(SummaryFlag) { + if apr.Contains(cli.NameOnlyFlag) { + if apr.Contains(cli.SchemaFlag) || apr.Contains(cli.DataFlag) || apr.Contains(cli.StatFlag) || apr.Contains(cli.SummaryFlag) { return errhand.BuildDError("invalid Arguments: --name-only cannot be combined with --schema, --data, --stat, or --summary").Build() } } @@ -254,25 +227,29 @@ func parseDiffDisplaySettings(apr *argparser.ArgParseResults) *diffDisplaySettin displaySettings := &diffDisplaySettings{} displaySettings.diffParts = SchemaAndDataDiff - if apr.Contains(DataFlag) && !apr.Contains(SchemaFlag) { + if apr.Contains(cli.DataFlag) && !apr.Contains(cli.SchemaFlag) { displaySettings.diffParts = DataOnlyDiff - } else if apr.Contains(SchemaFlag) && !apr.Contains(DataFlag) { + } else if apr.Contains(cli.SchemaFlag) && !apr.Contains(cli.DataFlag) { displaySettings.diffParts = SchemaOnlyDiff - } else if apr.Contains(StatFlag) { + } else if apr.Contains(cli.StatFlag) { displaySettings.diffParts = Stat - } else if apr.Contains(SummaryFlag) { + } else if apr.Contains(cli.SummaryFlag) { displaySettings.diffParts = Summary - } else if apr.Contains(NameOnlyFlag) { + } else if apr.Contains(cli.NameOnlyFlag) { displaySettings.diffParts = NameOnlyDiff } - displaySettings.skinny = apr.Contains(SkinnyFlag) + displaySettings.skinny = apr.Contains(cli.SkinnyFlag) + + if cols, ok := apr.GetValueList(cli.IncludeCols); ok { + displaySettings.includeCols = cols + } f := apr.GetValueOrDefault(FormatFlag, "tabular") switch strings.ToLower(f) { case "tabular": displaySettings.diffOutput = TabularDiffOutput - switch strings.ToLower(apr.GetValueOrDefault(DiffMode, "context")) { + switch strings.ToLower(apr.GetValueOrDefault(cli.DiffMode, "context")) { case "row": displaySettings.diffMode = diff.ModeRow case "line": @@ -288,8 +265,8 @@ func parseDiffDisplaySettings(apr *argparser.ArgParseResults) *diffDisplaySettin displaySettings.diffOutput = JsonDiffOutput } - displaySettings.limit, _ = apr.GetInt(limitParam) - displaySettings.where = apr.GetValueOrDefault(whereParam, "") + displaySettings.limit, _ = apr.GetInt(cli.LimitParam) + displaySettings.where = apr.GetValueOrDefault(cli.WhereParam, "") return displaySettings } @@ -301,12 +278,12 @@ func parseDiffArgs(queryist cli.Queryist, sqlCtx *sql.Context, apr *argparser.Ar staged := apr.Contains(cli.StagedFlag) || apr.Contains(cli.CachedFlag) - tableNames, err := dArgs.applyDiffRoots(queryist, sqlCtx, apr.Args, staged, apr.Contains(MergeBase)) + tableNames, err := dArgs.applyDiffRoots(queryist, sqlCtx, apr.Args, staged, apr.Contains(cli.MergeBase)) if err != nil { return nil, err } - if apr.Contains(ReverseFlag) { + if apr.Contains(cli.ReverseFlag) { dArgs.diffDatasets = &diffDatasets{ fromRef: dArgs.toRef, toRef: dArgs.fromRef, @@ -1556,7 +1533,9 @@ func diffRows( if err != nil { return errhand.BuildDError("Error running diff query:\n%s", interpolatedQuery).AddCause(err).Build() } - + for _, col := range dArgs.includeCols { + modifiedColNames[col] = true // ensure included columns are always present + } // instantiate a new schema that only contains the columns with changes var filteredUnionSch sql.Schema for _, s := range unionSch { diff --git a/go/cmd/dolt/commands/show.go b/go/cmd/dolt/commands/show.go index bc86246d53..858bef9740 100644 --- a/go/cmd/dolt/commands/show.go +++ b/go/cmd/dolt/commands/show.go @@ -83,17 +83,17 @@ func (cmd ShowCmd) ArgParser() *argparser.ArgParser { ap.SupportsFlag(cli.NoPrettyFlag, "", "Show the object without making it pretty.") // Flags inherited from Diff - ap.SupportsFlag(DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).") - ap.SupportsFlag(SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).") - ap.SupportsFlag(StatFlag, "", "Show stats of data changes") - ap.SupportsFlag(SummaryFlag, "", "Show summary of data and schema changes") + ap.SupportsFlag(cli.DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).") + ap.SupportsFlag(cli.SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).") + ap.SupportsFlag(cli.StatFlag, "", "Show stats of data changes") + ap.SupportsFlag(cli.SummaryFlag, "", "Show summary of data and schema changes") ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.") - ap.SupportsString(whereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.") - ap.SupportsInt(limitParam, "", "record_count", "limits to the first N diffs.") + ap.SupportsString(cli.WhereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.") + ap.SupportsInt(cli.LimitParam, "", "record_count", "limits to the first N diffs.") ap.SupportsFlag(cli.CachedFlag, "c", "Show only the staged data changes.") - ap.SupportsFlag(SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.") - ap.SupportsFlag(MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit") - ap.SupportsString(DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.") + ap.SupportsFlag(cli.SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.") + ap.SupportsFlag(cli.MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit") + ap.SupportsString(cli.DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.") return ap } @@ -275,8 +275,8 @@ func getValueFromRefSpec(ctx context.Context, dEnv *env.DoltEnv, specRef string) } func (cmd ShowCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseError { - if apr.Contains(StatFlag) || apr.Contains(SummaryFlag) { - if apr.Contains(SchemaFlag) || apr.Contains(DataFlag) { + if apr.Contains(cli.StatFlag) || apr.Contains(cli.SummaryFlag) { + if apr.Contains(cli.SchemaFlag) || apr.Contains(cli.DataFlag) { return errhand.BuildDError("invalid Arguments: --stat and --summary cannot be combined with --schema or --data").Build() } } diff --git a/go/libraries/doltcore/sqle/dtablefunctions/dolt_diff.go b/go/libraries/doltcore/sqle/dtablefunctions/dolt_diff.go index 07e49437b8..541dc03bfa 100644 --- a/go/libraries/doltcore/sqle/dtablefunctions/dolt_diff.go +++ b/go/libraries/doltcore/sqle/dtablefunctions/dolt_diff.go @@ -16,12 +16,15 @@ package dtablefunctions import ( "fmt" + "io" "strings" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" gmstypes "github.com/dolthub/go-mysql-server/sql/types" "gopkg.in/src-d/go-errors.v1" + "github.com/dolthub/dolt/go/cmd/dolt/cli" "github.com/dolthub/dolt/go/libraries/doltcore/diff" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/merge" @@ -32,6 +35,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dtables" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" + dolttable "github.com/dolthub/dolt/go/libraries/doltcore/table" "github.com/dolthub/dolt/go/store/types" ) @@ -56,6 +60,8 @@ type DiffTableFunction struct { fromDate *types.Timestamp toDate *types.Timestamp sqlSch sql.Schema + showSkinny bool + includeCols map[string]struct{} } // NewInstance creates a new instance of TableFunction interface @@ -100,26 +106,25 @@ func (dtf *DiffTableFunction) WithDatabase(database sql.Database) (sql.Node, err // Expressions implements the sql.Expressioner interface func (dtf *DiffTableFunction) Expressions() []sql.Expression { + exprs := []sql.Expression{} + if dtf.dotCommitExpr != nil { - return []sql.Expression{ - dtf.dotCommitExpr, dtf.tableNameExpr, - } - } - return []sql.Expression{ - dtf.fromCommitExpr, dtf.toCommitExpr, dtf.tableNameExpr, + exprs = append(exprs, dtf.dotCommitExpr, dtf.tableNameExpr) + } else { + exprs = append(exprs, dtf.fromCommitExpr, dtf.toCommitExpr, dtf.tableNameExpr) } + return exprs } // WithExpressions implements the sql.Expressioner interface -func (dtf *DiffTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) { - if len(expression) < 2 { - return nil, sql.ErrInvalidArgumentNumber.New(dtf.Name(), "2 to 3", len(expression)) - } - +func (dtf *DiffTableFunction) WithExpressions(expressions ...sql.Expression) (sql.Node, error) { + newDtf := *dtf // TODO: For now, we will only support literal / fully-resolved arguments to the // DiffTableFunction to avoid issues where the schema is needed in the analyzer // before the arguments could be resolved. - for _, expr := range expression { + var exprStrs []string + strToExpr := map[string]sql.Expression{} + for _, expr := range expressions { if !expr.Resolved() { return nil, ErrInvalidNonLiteralArgument.New(dtf.Name(), expr.String()) } @@ -127,22 +132,52 @@ func (dtf *DiffTableFunction) WithExpressions(expression ...sql.Expression) (sql if _, ok := expr.(sql.FunctionExpression); ok { return nil, ErrInvalidNonLiteralArgument.New(dtf.Name(), expr.String()) } + strVal := expr.String() + if lit, ok := expr.(*expression.Literal); ok { // rm quotes from string literals + strVal = fmt.Sprintf("%v", lit.Value()) + } + exprStrs = append(exprStrs, strVal) // args extracted from apr later to filter out options + strToExpr[strVal] = expr } - newDtf := *dtf - if strings.Contains(expression[0].String(), "..") { - if len(expression) != 2 { - return nil, sql.ErrInvalidArgumentNumber.New(fmt.Sprintf("%v with .. or ...", newDtf.Name()), 2, len(expression)) + apr, err := cli.CreateDiffArgParser(true).Parse(exprStrs) + if err != nil { + return nil, err + } + + if apr.Contains(cli.SkinnyFlag) { + newDtf.showSkinny = true + } + + if cols, ok := apr.GetValueList(cli.IncludeCols); ok { + newDtf.includeCols = make(map[string]struct{}) + for _, col := range cols { + newDtf.includeCols[col] = struct{}{} } - newDtf.dotCommitExpr = expression[0] - newDtf.tableNameExpr = expression[1] + } + + expressions = []sql.Expression{} + for _, posArg := range apr.Args { + expressions = append(expressions, strToExpr[posArg]) + } + + if len(expressions) < 2 { + return nil, sql.ErrInvalidArgumentNumber.New(dtf.Name(), "2 to 3", len(expressions)) + } + + if strings.Contains(expressions[0].String(), "..") { + if len(expressions) != 2 { + return nil, sql.ErrInvalidArgumentNumber.New(fmt.Sprintf("%v with .. or ...", newDtf.Name()), 2, len(expressions)) + } + newDtf.dotCommitExpr = expressions[0] + newDtf.tableNameExpr = expressions[1] } else { - if len(expression) != 3 { - return nil, sql.ErrInvalidArgumentNumber.New(newDtf.Name(), 3, len(expression)) + if len(expressions) != 3 { + return nil, sql.ErrInvalidArgumentNumber.New(newDtf.Name(), 3, len(expressions)) } - newDtf.fromCommitExpr = expression[0] - newDtf.toCommitExpr = expression[1] - newDtf.tableNameExpr = expression[2] + newDtf.fromCommitExpr = expressions[0] + newDtf.toCommitExpr = expressions[1] + newDtf.tableNameExpr = expressions[2] } fromCommitVal, toCommitVal, dotCommitVal, tableName, err := newDtf.evaluateArguments() @@ -423,6 +458,110 @@ func (dtf *DiffTableFunction) evaluateArguments() (interface{}, interface{}, int return fromCommitVal, toCommitVal, nil, tableName, nil } +// filterDeltaSchemaToSkinnyCols creates a filtered version of the table delta that omits columns which are identical +// in type and value across all rows in both schemas, except for primary key columns or explicitly included using the +// include-cols option. This also updates dtf.tableDelta with the filtered result. +func (dtf *DiffTableFunction) filterDeltaSchemaToSkinnyCols(ctx *sql.Context, delta *diff.TableDelta) (*diff.TableDelta, error) { + if delta.FromTable == nil || delta.ToTable == nil { + return delta, nil + } + + // gather map of potential cols for removal from skinny diff + equalDiffColsIndices := map[string][2]int{} + toCols := delta.ToSch.GetAllCols() + for fromIdx, fromCol := range delta.FromSch.GetAllCols().GetColumns() { + if _, ok := dtf.includeCols[fromCol.Name]; ok { + continue // user explicitly included this column + } + + col, ok := delta.ToSch.GetAllCols().GetByName(fromCol.Name) + if !ok { // column was dropped + continue + } + if fromCol.TypeInfo.Equals(col.TypeInfo) { + toIdx := toCols.TagToIdx[toCols.NameToCol[fromCol.Name].Tag] + equalDiffColsIndices[fromCol.Name] = [2]int{fromIdx, toIdx} + } + } + + fromRowData, err := delta.FromTable.GetRowData(ctx) + if err != nil { + return nil, err + } + + toRowData, err := delta.ToTable.GetRowData(ctx) + if err != nil { + return nil, err + } + + fromIter, err := dolttable.NewTableIterator(ctx, delta.FromSch, fromRowData) + if err != nil { + return nil, err + } + defer fromIter.Close(ctx) + + toIter, err := dolttable.NewTableIterator(ctx, delta.ToSch, toRowData) + if err != nil { + return nil, err + } + defer toIter.Close(ctx) + + for len(equalDiffColsIndices) > 0 { + fromRow, fromErr := fromIter.Next(ctx) + toRow, toErr := toIter.Next(ctx) + + if fromErr == io.EOF && toErr == io.EOF { + break + } + + if fromErr != nil && fromErr != io.EOF { + return nil, fromErr + } + + if toErr != nil && toErr != io.EOF { + return nil, toErr + } + + // xor: if only one is nil, then all cols are diffs + if (fromRow == nil) != (toRow == nil) { + equalDiffColsIndices = map[string][2]int{} + break + } + + if fromRow == nil && toRow == nil { + continue + } + + for colName, idx := range equalDiffColsIndices { + if fromRow[idx[0]] != toRow[idx[1]] { // same row and col, values differ + delete(equalDiffColsIndices, colName) + } + } + } + + var fromSkCols []schema.Column + for _, col := range delta.FromSch.GetAllCols().GetColumns() { + _, ok := equalDiffColsIndices[col.Name] + if col.IsPartOfPK || !ok { + fromSkCols = append(fromSkCols, col) + } + } + + var toSkCols []schema.Column + for _, col := range delta.ToSch.GetAllCols().GetColumns() { + _, ok := equalDiffColsIndices[col.Name] + if col.IsPartOfPK || !ok { + toSkCols = append(toSkCols, col) + } + } + + skDelta := *delta + skDelta.FromSch = schema.MustSchemaFromCols(schema.NewColCollection(fromSkCols...)) + skDelta.ToSch = schema.MustSchemaFromCols(schema.NewColCollection(toSkCols...)) + dtf.tableDelta = skDelta + return &skDelta, nil +} + func (dtf *DiffTableFunction) generateSchema(ctx *sql.Context, fromCommitVal, toCommitVal, dotCommitVal interface{}, tableName string) error { if !dtf.Resolved() { return nil @@ -438,27 +577,27 @@ func (dtf *DiffTableFunction) generateSchema(ctx *sql.Context, fromCommitVal, to return err } + if dtf.showSkinny { + skDelta, err := dtf.filterDeltaSchemaToSkinnyCols(ctx, &delta) + if err != nil { + return err + } + delta = *skDelta + } + fromTable, fromTableExists := delta.FromTable, delta.FromTable != nil toTable, toTableExists := delta.ToTable, delta.ToTable != nil - if !toTableExists && !fromTableExists { + var format *types.NomsBinFormat + if toTableExists { + format = toTable.Format() + } else if fromTableExists { + format = fromTable.Format() + } else { return sql.ErrTableNotFound.New(tableName) } - var toSchema, fromSchema schema.Schema - var format *types.NomsBinFormat - - if fromTableExists { - fromSchema = delta.FromSch - format = fromTable.Format() - } - - if toTableExists { - toSchema = delta.ToSch - format = toTable.Format() - } - - diffTableSch, j, err := dtables.GetDiffTableSchemaAndJoiner(format, fromSchema, toSchema) + diffTableSch, j, err := dtables.GetDiffTableSchemaAndJoiner(format, delta.FromSch, delta.ToSch) if err != nil { return err } @@ -571,15 +710,14 @@ func (dtf *DiffTableFunction) IsReadOnly() bool { // String implements the Stringer interface func (dtf *DiffTableFunction) String() string { + args := []string{} if dtf.dotCommitExpr != nil { - return fmt.Sprintf("DOLT_DIFF(%s, %s)", - dtf.dotCommitExpr.String(), - dtf.tableNameExpr.String()) + args = append(args, dtf.dotCommitExpr.String(), dtf.tableNameExpr.String()) + } else { + args = append(args, dtf.fromCommitExpr.String(), dtf.toCommitExpr.String(), dtf.tableNameExpr.String()) } - return fmt.Sprintf("DOLT_DIFF(%s, %s, %s)", - dtf.fromCommitExpr.String(), - dtf.toCommitExpr.String(), - dtf.tableNameExpr.String()) + + return fmt.Sprintf("DOLT_DIFF(%s)", strings.Join(args, ", ")) } // Name implements the sql.TableFunction interface diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries_diff.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries_diff.go index 19689f622c..e7ffd501b2 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries_diff.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries_diff.go @@ -15,6 +15,9 @@ package enginetest import ( + "fmt" + "strings" + "github.com/dolthub/go-mysql-server/enginetest/queries" "github.com/dolthub/go-mysql-server/sql" gmstypes "github.com/dolthub/go-mysql-server/sql/types" @@ -818,7 +821,138 @@ var Dolt1DiffSystemTableScripts = []queries.ScriptTest{ }, } +// assertDoltDiffColumnCount returns assertions that verify a dolt_diff view +// has the expected number of distinct data columns (excluding commit metadata). +func assertDoltDiffColumnCount(view, selectStmt string, expected int64) []queries.ScriptTestAssertion { + excluded := []string{ + "'to_commit'", + "'from_commit'", + "'to_commit_date'", + "'from_commit_date'", + "'diff_type'", + } + + query := fmt.Sprintf(` + SELECT COUNT(DISTINCT REPLACE(REPLACE(column_name, 'to_', ''), 'from_', '')) + FROM information_schema.columns + WHERE table_schema = DATABASE() + AND table_name = '%s' + AND column_name NOT IN (%s)`, + view, strings.Join(excluded, ", "), + ) + + return []queries.ScriptTestAssertion{ + {Query: fmt.Sprintf("DROP VIEW IF EXISTS %s;", view)}, + {Query: fmt.Sprintf("CREATE VIEW %s AS %s;", view, selectStmt)}, + {Query: query, Expected: []sql.Row{{expected}}}, + {Query: fmt.Sprintf("DROP VIEW %s;", view)}, + } +} + var DiffTableFunctionScriptTests = []queries.ScriptTest{ + { + Name: "dolt_diff: SELECT * skinny schema visibility", + SetUpScript: []string{ + `CREATE TABLE t ( + pk BIGINT NOT NULL COMMENT 'tag:0', + c1 BIGINT COMMENT 'tag:1', + c2 BIGINT COMMENT 'tag:2', + c3 BIGINT COMMENT 'tag:3', + c4 BIGINT COMMENT 'tag:4', + c5 BIGINT COMMENT 'tag:5', + PRIMARY KEY (pk) + );`, + "call dolt_add('.')", + "set @C0 = '';", + "call dolt_commit_hash_out(@C0, '-m', 'Created table t');", + "INSERT INTO t VALUES (0,1,2,3,4,5), (1,1,2,3,4,5);", + "call dolt_add('.')", + "set @C1 = '';", + "call dolt_commit_hash_out(@C1, '-m', 'Added initial data');", + + "UPDATE t SET c1=100, c3=300 WHERE pk=0;", + "UPDATE t SET c2=200 WHERE pk=1;", + "call dolt_add('.')", + "set @C2 = '';", + "call dolt_commit_hash_out(@C2, '-m', 'Updated some columns');", + + "ALTER TABLE t ADD COLUMN c6 BIGINT;", + "UPDATE t SET c6=600 WHERE pk=0;", + "call dolt_add('.')", + "set @C3 = '';", + "call dolt_commit_hash_out(@C3, '-m', 'Added new column and updated it');", + + "DELETE FROM t WHERE pk=1;", + "call dolt_add('.')", + "set @C4 = '';", + "call dolt_commit_hash_out(@C4, '-m', 'Deleted a row');", + }, + Assertions: func() []queries.ScriptTestAssertion { + asserts := []queries.ScriptTestAssertion{ + { + Query: "SELECT d.to_pk, d.to_c1, d.to_c2, d.to_c3, d.to_c4, d.to_c5, d.from_pk, d.from_c1, d.from_c2, d.from_c3, d.from_c4, d.from_c5, d.diff_type " + + "FROM (SELECT * FROM dolt_diff('--skinny', @C0, @C1, 't')) d " + + "ORDER BY COALESCE(d.to_pk, d.from_pk)", + Expected: []sql.Row{ + {int64(0), int64(1), int64(2), int64(3), int64(4), int64(5), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), "added"}, + {int64(1), int64(1), int64(2), int64(3), int64(4), int64(5), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), "added"}, + }, + }, + { + Query: "SELECT d.to_pk, d.to_c1, d.to_c2, d.to_c3, d.from_pk, d.from_c1, d.from_c2, d.from_c3, d.diff_type " + + "FROM (SELECT * FROM dolt_diff(@C1, @C2, 't')) d " + + "ORDER BY COALESCE(d.to_pk, d.from_pk)", + Expected: []sql.Row{ + {int64(0), int64(100), int64(2), int64(300), int64(0), int64(1), int64(2), int64(3), "modified"}, + {int64(1), int64(1), int64(200), int64(3), int64(1), int64(1), int64(2), int64(3), "modified"}, + }, + }, + { + Query: "SELECT d.to_pk, d.to_c1, d.to_c2, d.to_c3, d.diff_type " + + "FROM (SELECT * FROM dolt_diff('--skinny', @C1, @C2, 't')) d " + + "ORDER BY d.to_pk", + Expected: []sql.Row{ + {int64(0), int64(100), int64(2), int64(300), "modified"}, + {int64(1), int64(1), int64(200), int64(3), "modified"}, + }, + }, + { + Query: "SELECT d.to_pk, d.to_c6, d.diff_type " + + "FROM (SELECT * FROM dolt_diff('--skinny', @C2, @C3, 't')) d", + Expected: []sql.Row{ + {int64(0), int64(600), "modified"}, + }, + }, + { + Query: "SELECT d.to_pk, d.to_c1, d.to_c2, d.to_c6, d.diff_type " + + "FROM (SELECT * FROM dolt_diff('--skinny', '--include-cols=c1,c2', @C2, @C3, 't')) d", + Expected: []sql.Row{ + {int64(0), int64(100), int64(2), int64(600), "modified"}, + }, + }, + { + Query: "SELECT d.from_pk, d.from_c1, d.from_c2, d.from_c3, d.from_c4, d.from_c5, d.from_c6, d.diff_type " + + "FROM (SELECT * FROM dolt_diff('--skinny', @C3, @C4, 't')) d", + Expected: []sql.Row{ + {int64(1), int64(1), int64(200), int64(3), int64(4), int64(5), nil, "removed"}, + }, + }, + } + asserts = append(asserts, assertDoltDiffColumnCount("v_all_01", "SELECT * FROM dolt_diff(@C0, @C1, 't')", 6)...) + asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_01", "SELECT * FROM dolt_diff('--skinny', @C0, @C1, 't')", 6)...) + asserts = append(asserts, assertDoltDiffColumnCount("v_all_12", "SELECT * FROM dolt_diff(@C1, @C2, 't')", 6)...) + asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_12", "SELECT * FROM dolt_diff('--skinny', @C1, @C2, 't')", 4)...) + asserts = append(asserts, assertDoltDiffColumnCount("v_all_23", "SELECT * FROM dolt_diff(@C2, @C3, 't')", 7)...) + asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_23", "SELECT * FROM dolt_diff('--skinny', @C2, @C3, 't')", 2)...) + asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_23", "SELECT * FROM dolt_diff(@C2, @C3, 't', '--skinny')", 2)...) + asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_23", "SELECT * FROM dolt_diff('--skinny', '--include-cols=c1,c2', @C2, @C3, 't')", 4)...) + asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_23", "SELECT * FROM dolt_diff('--skinny', '--include-cols=c1,c2,c6', @C2, @C3, 't')", 4)...) + asserts = append(asserts, assertDoltDiffColumnCount("v_all_34", "SELECT * FROM dolt_diff(@C3, @C4, 't')", 7)...) + asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_34", "SELECT * FROM dolt_diff('--skinny', @C3, @C4, 't')", 7)...) + + return asserts + }(), + }, { Name: "invalid arguments", SetUpScript: []string{ diff --git a/integration-tests/bats/helper/sql-diff.bash b/integration-tests/bats/helper/sql-diff.bash new file mode 100644 index 0000000000..ffb9d3020c --- /dev/null +++ b/integration-tests/bats/helper/sql-diff.bash @@ -0,0 +1,137 @@ +# dolt/integration-tests/bats/helper/sql-diff.bash + +: "${SQL_DIFF_DEBUG:=}" # set to any value to enable debug output +_dbg() { [ -n "$SQL_DIFF_DEBUG" ] && printf '%s\n' "$*" >&2; } +_dbg_block() { [ -n "$SQL_DIFF_DEBUG" ] && { printf '%s\n' "$1" >&2; printf '%s\n' "$2" >&2; }; } + +# first table header row from CLI diff (data section), as newline list +_cli_header_cols() { + awk ' + /^\s*\|\s*[-+<>]\s*\|/ && last_header != "" { print last_header; exit } + /^\s*\|/ { last_header = $0 } + ' <<<"$1" \ + | tr '|' '\n' \ + | sed -e 's/^[[:space:]]*//;s/[[:space:]]*$//' \ + | grep -v -E '^(<|>|)$' \ + | grep -v '^$' +} + +# first table header row from SQL diff, strip to_/from_, drop metadata, as newline list +_sql_data_header_cols() { + echo "$1" \ + | awk '/^\|/ {print; exit}' \ + | tr '|' '\n' \ + | sed -e 's/^[[:space:] ]*//;s/[[:space:] ]*$//' \ + | grep -E '^(to_|from_)' \ + | sed -E 's/^(to_|from_)//' \ + | grep -Ev '^(commit|commit_date|diff_type)$' \ + | grep -v '^$' +} + +# count CLI changes by unique PK (includes +, -, <, >) +_cli_change_count() { + awk -F'|' ' + # start counting once we see a data row marker + /^\s*\|\s*[-+<>]\s*\|/ { in_table=1 } + in_table && $2 ~ /^[[:space:]]*[-+<>][[:space:]]*$/ { + pk=$3 + gsub(/^[[:space:]]+|[[:space:]]+$/, "", pk) + if (pk != "") seen[pk]=1 + } + END { c=0; for (k in seen) c++; print c+0 } + ' <<<"$1" +} + +# count SQL data rows (lines starting with '|' minus header) +_sql_row_count() { + echo "$1" | awk '/^\|/ {c++} END{print (c>0?c-1:0)}' +} + +# compare two newline lists as sets (sorted) +_compare_sets_or_err() { + local name="$1" cli_cols="$2" sql_cols="$3" cli_out="$4" sql_out="$5" + + local cli_sorted sql_sorted + cli_sorted=$(echo "$cli_cols" | sort -u) + sql_sorted=$(echo "$sql_cols" | sort -u) + + _dbg_block "$name CLI columns:" "$cli_sorted" + _dbg_block "$name SQL data columns:" "$sql_sorted" + + if [ "$cli_sorted" != "$sql_sorted" ]; then + echo "$name column set mismatch" + echo "--- $name CLI columns ---"; echo "$cli_sorted" + echo "--- $name SQL data columns ---"; echo "$sql_sorted" + echo "--- $name CLI output ---"; echo "$cli_out" + echo "--- $name SQL output ---"; echo "$sql_out" + return 1 + fi + return 0 +} + +# compare change/row counts; on mismatch, print both outputs +_compare_counts_or_err() { + local name="$1" cli_out="$2" sql_out="$3" cli_count="$4" sql_count="$5" + + _dbg "$name counts: CLI=$cli_count SQL=$sql_count" + + if [ "$cli_count" != "$sql_count" ]; then + echo "$name change count mismatch: CLI=$cli_count, SQL=$sql_count" + echo "--- $name CLI output ---"; echo "$cli_out" + echo "--- $name SQL output ---"; echo "$sql_out" + return 1 + fi + return 0 +} + +# ---- main entrypoint ---- + +# Compare CLI diff with SQL dolt_diff +# Usage: compare_dolt_diff [all dolt diff args...] +compare_dolt_diff() { + local args=("$@") # all arguments + + # --- normal diff --- + local cli_output sql_output cli_status sql_status + cli_output=$(dolt diff "${args[@]}" 2>&1) + cli_status=$? + + # Build SQL argument list safely + local sql_args="" + for arg in "${args[@]}"; do + if [ -z "$sql_args" ]; then + sql_args="'$arg'" + else + sql_args+=", '$arg'" + fi + done + sql_output=$(dolt sql -q "SELECT * FROM dolt_diff($sql_args)" 2>&1) + sql_status=$? + + # normally prints in bats using `run`, so no debug blocks here + echo "$cli_output" + echo "$sql_output" + + if [ $cli_status -ne 0 ]; then + _dbg "$cli_output" + return 1 + fi + if [ $sql_status -ne 0 ]; then + _dbg "$sql_output" + return 1 + fi + + # Compare counts + local cli_changes sql_rows + cli_changes=$(_cli_change_count "$cli_output") + sql_rows=$(_sql_row_count "$sql_output") + _compare_counts_or_err "Diff" "$cli_output" "$sql_output" "$cli_changes" "$sql_rows" || return 1 + + # Compare columns + local cli_cols sql_cols + cli_cols=$(_cli_header_cols "$cli_output") + sql_cols=$(_sql_data_header_cols "$sql_output") + _compare_sets_or_err "Diff" "$cli_cols" "$sql_cols" "$cli_output" "$sql_output" || return 1 + + return 0 +} diff --git a/integration-tests/bats/sql-diff.bats b/integration-tests/bats/sql-diff.bats index 3f1daf1a3c..72579392b2 100644 --- a/integration-tests/bats/sql-diff.bats +++ b/integration-tests/bats/sql-diff.bats @@ -1,5 +1,6 @@ #!/usr/bin/env bats load $BATS_TEST_DIRNAME/helper/common.bash +load $BATS_TEST_DIRNAME/helper/sql-diff.bash setup() { setup_common @@ -890,3 +891,60 @@ EOF [ "$status" -eq 1 ] [[ "$output" =~ "invalid output format: sql. SQL format diffs only rendered for schema or data changes" ]] || false } + +@test "sql-diff: skinny flag comparison between CLI and SQL table function" { + dolt sql <