From d97c456e09256f2a7847f7c3a0cc99d4bb2d3574 Mon Sep 17 00:00:00 2001 From: Pavel Safronov Date: Fri, 23 Jun 2023 10:14:32 -0700 Subject: [PATCH] Starting work on migrating `dolt diff`: - temporarily disable `dolt show` - introduce TableInfo struct to maintain SQL table state - update diffWriter to use non-doltdb primitives - update diff.go to use non-doltdb primitives and load most data from SQL --- go/cmd/dolt/commands/diff.go | 775 +++++++++++++----- go/cmd/dolt/commands/diff_output.go | 177 ++-- go/cmd/dolt/commands/show.go | 700 ++++++++-------- go/cmd/dolt/dolt.go | 4 +- go/libraries/doltcore/diff/diff.go | 2 +- go/libraries/doltcore/diff/table_deltas.go | 109 ++- .../sqle/dolt_patch_table_function.go | 81 +- .../doltcore/sqle/sqlfmt/schema_fmt.go | 28 +- .../table/typed/json/json_diff_writer.go | 2 +- 9 files changed, 1131 insertions(+), 747 deletions(-) diff --git a/go/cmd/dolt/commands/diff.go b/go/cmd/dolt/commands/diff.go index d0bdb6f10c..b524a1ae67 100644 --- a/go/cmd/dolt/commands/diff.go +++ b/go/cmd/dolt/commands/diff.go @@ -17,8 +17,10 @@ package commands import ( "context" "fmt" + "github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo" + "github.com/dolthub/go-mysql-server/sql/parse" + "github.com/dolthub/vitess/go/vt/sqlparser" "io" - "sort" "strconv" "strings" @@ -26,7 +28,6 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/dolt/go/cmd/dolt/cli" - "github.com/dolthub/dolt/go/cmd/dolt/commands/engine" "github.com/dolthub/dolt/go/cmd/dolt/errhand" eventsapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi/v1alpha1" "github.com/dolthub/dolt/go/libraries/doltcore/diff" @@ -109,10 +110,8 @@ type diffDisplaySettings struct { } type diffDatasets struct { - fromRoot *doltdb.RootValue - toRoot *doltdb.RootValue - fromRef string - toRef string + fromRef string + toRef string } type diffArgs struct { @@ -121,6 +120,21 @@ type diffArgs struct { tableSet *set.StrSet } +type diffStatistics struct { + TableName string + RowsUnmodified uint64 + RowsAdded uint64 + RowsDeleted uint64 + RowsModified uint64 + CellsAdded uint64 + CellsDeleted uint64 + CellsModified uint64 + OldRowCount uint64 + NewRowCount uint64 + OldCellCount uint64 + NewCellCount uint64 +} + type DiffCmd struct{} // Name is returns the name of the Dolt cli command. This is what is used on the command line to invoke the command @@ -171,12 +185,20 @@ func (cmd DiffCmd) Exec(ctx context.Context, commandStr string, args []string, d return HandleVErrAndExitCode(verr, usage) } - dArgs, err := parseDiffArgs(ctx, dEnv, apr) + queryist, sqlCtx, closeFunc, err := cliCtx.QueryEngine(ctx) + if err != nil { + return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) + } + if closeFunc != nil { + defer closeFunc() + } + + dArgs, err := parseDiffArgs(queryist, sqlCtx, ctx, apr) if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } - verr = diffUserTables(ctx, dEnv, dArgs) + verr = diffUserTables(queryist, sqlCtx, ctx, dArgs) return HandleVErrAndExitCode(verr, usage) } @@ -197,7 +219,7 @@ func (cmd DiffCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseE return nil } -func parseDiffDisplaySettings(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults) *diffDisplaySettings { +func parseDiffDisplaySettings(apr *argparser.ArgParseResults) *diffDisplaySettings { displaySettings := &diffDisplaySettings{} displaySettings.diffParts = SchemaAndDataDiff @@ -239,26 +261,24 @@ func parseDiffDisplaySettings(ctx context.Context, dEnv *env.DoltEnv, apr *argpa return displaySettings } -func parseDiffArgs(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults) (*diffArgs, error) { +func parseDiffArgs(queryist cli.Queryist, sqlCtx *sql.Context, ctx context.Context, apr *argparser.ArgParseResults) (*diffArgs, error) { dArgs := &diffArgs{ - diffDisplaySettings: parseDiffDisplaySettings(ctx, dEnv, apr), + diffDisplaySettings: parseDiffDisplaySettings(apr), } - tableNames, err := dArgs.applyDiffRoots(ctx, dEnv, apr.Args, apr.Contains(cli.CachedFlag), apr.Contains(MergeBase)) + tableNames, err := dArgs.applyDiffRoots(queryist, sqlCtx, ctx, apr.Args, apr.Contains(cli.CachedFlag), apr.Contains(MergeBase)) if err != nil { return nil, err } if apr.Contains(ReverseFlag) { dArgs.diffDatasets = &diffDatasets{ - fromRoot: dArgs.toRoot, - fromRef: dArgs.toRef, - toRoot: dArgs.fromRoot, - toRef: dArgs.fromRef, + fromRef: dArgs.toRef, + toRef: dArgs.fromRef, } } - tableSet, err := parseDiffTableSet(ctx, dEnv, dArgs.diffDatasets, tableNames) + tableSet, err := parseDiffTableSetSql(queryist, sqlCtx, ctx, dArgs.diffDatasets, tableNames) if err != nil { return nil, err } @@ -268,75 +288,96 @@ func parseDiffArgs(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPar return dArgs, nil } -func parseDiffTableSet(ctx context.Context, dEnv *env.DoltEnv, datasets *diffDatasets, tableNames []string) (*set.StrSet, error) { +func parseDiffTableSetSql(queryist cli.Queryist, sqlCtx *sql.Context, ctx context.Context, datasets *diffDatasets, tableNames []string) (*set.StrSet, error) { + tablesAtFromRef, err := getTableNamesAtRef(queryist, sqlCtx, datasets.fromRef) + if err != nil { + return nil, err + } + tablesAtToRef, err := getTableNamesAtRef(queryist, sqlCtx, datasets.toRef) + if err != nil { + return nil, err + } tableSet := set.NewStrSet(nil) for _, tableName := range tableNames { // verify table args exist in at least one root - _, ok, err := datasets.fromRoot.GetTable(ctx, tableName) - if err != nil { - return nil, err - } + _, ok := tablesAtFromRef[tableName] if ok { tableSet.Add(tableName) continue } - _, ok, err = datasets.toRoot.GetTable(ctx, tableName) - if err != nil { - return nil, err - } + _, ok = tablesAtToRef[tableName] if ok { tableSet.Add(tableName) continue } - if !ok { - return nil, fmt.Errorf("table %s does not exist in either revision", tableName) - } + + return nil, fmt.Errorf("table %s does not exist in either revision", tableName) } // if no tables or docs were specified as args, diff all tables and docs if len(tableNames) == 0 { - utn, err := doltdb.UnionTableNames(ctx, datasets.fromRoot, datasets.toRoot) - if err != nil { - return nil, err + seenTableNames := make(map[string]bool) + for _, tables := range []map[string]bool{tablesAtFromRef, tablesAtToRef} { + for tableName := range tables { + if _, ok := seenTableNames[tableName]; !ok { + seenTableNames[tableName] = true + tableSet.Add(tableName) + } + } } - tableSet.Add(utn...) } return tableSet, nil } +var doltSystemTables = []string{ + "dolt_procedures", + "dolt_schemas", +} + +func getTableNamesAtRef(queryist cli.Queryist, sqlCtx *sql.Context, ref string) (map[string]bool, error) { + // query for user-created tables + q := fmt.Sprintf("SHOW FULL TABLES AS OF '%s'", ref) + rows, err := getRowsForSql(queryist, sqlCtx, q) + if err != nil { + return nil, err + } + + tableNames := make(map[string]bool) + for _, row := range rows { + tableName := row[0].(string) + tableType := row[1].(string) + isTable := tableType == "BASE TABLE" + if isTable { + tableNames[tableName] = true + } + } + + // add system tables, if they exist at this ref + for _, sysTable := range doltSystemTables { + q = fmt.Sprintf("show create table %s as of '%s'", sysTable, ref) + _, err = getRowsForSql(queryist, sqlCtx, q) + if err == nil { + tableNames[sysTable] = true + } + } + + return tableNames, nil +} + // applyDiffRoots applies the appropriate |from| and |to| root values to the receiver and returns the table names // (if any) given to the command. -func (dArgs *diffArgs) applyDiffRoots(ctx context.Context, dEnv *env.DoltEnv, args []string, isCached, useMergeBase bool) ([]string, error) { - headRoot, err := dEnv.HeadRoot(ctx) - if err != nil { - return nil, err - } - - stagedRoot, err := dEnv.StagedRoot(ctx) - if err != nil { - return nil, err - } - - workingRoot, err := dEnv.WorkingRoot(ctx) - if err != nil { - return nil, err - } - +func (dArgs *diffArgs) applyDiffRoots(queryist cli.Queryist, sqlCtx *sql.Context, ctx context.Context, args []string, isCached, useMergeBase bool) ([]string, error) { dArgs.diffDatasets = &diffDatasets{ - fromRoot: stagedRoot, - fromRef: doltdb.Staged, - toRoot: workingRoot, - toRef: doltdb.Working, + fromRef: doltdb.Staged, + toRef: doltdb.Working, } if isCached { - dArgs.fromRoot = headRoot dArgs.fromRef = "HEAD" - dArgs.toRoot = stagedRoot dArgs.toRef = doltdb.Staged } @@ -352,31 +393,30 @@ func (dArgs *diffArgs) applyDiffRoots(ctx context.Context, dEnv *env.DoltEnv, ar if useMergeBase { return nil, fmt.Errorf("Cannot use `..` or `...` with --merge-base flag") } - err = dArgs.applyDotRevisions(ctx, dEnv, args) + err := dArgs.applyDotRevisions(queryist, sqlCtx, ctx, args) if err != nil { return nil, err } return args[1:], err } + fromRef := args[0] // treat the first arg as a ref spec - fromRoot, ok := diff.MaybeResolveRoot(ctx, dEnv.RepoStateReader(), dEnv.DoltDB, args[0]) + _, err := getTableNamesAtRef(queryist, sqlCtx, fromRef) // if it doesn't resolve, treat it as a table name - if !ok { + if err != nil { // `dolt diff table` if useMergeBase { return nil, fmt.Errorf("Must supply at least one revision when using --merge-base flag") } return args, nil } - - dArgs.fromRoot = fromRoot - dArgs.fromRef = args[0] + dArgs.fromRef = fromRef if len(args) == 1 { // `dolt diff from_commit` if useMergeBase { - err := dArgs.applyMergeBase(ctx, dEnv, args[0], "HEAD") + err := dArgs.applyMergeBase(queryist, sqlCtx, args[0], "HEAD") if err != nil { return nil, err } @@ -384,23 +424,24 @@ func (dArgs *diffArgs) applyDiffRoots(ctx context.Context, dEnv *env.DoltEnv, ar return nil, nil } - toRoot, ok := diff.MaybeResolveRoot(ctx, dEnv.RepoStateReader(), dEnv.DoltDB, args[1]) - if !ok { + toRef := args[1] + // treat the first arg as a ref spec + _, err = getTableNamesAtRef(queryist, sqlCtx, toRef) + // if it doesn't resolve, treat it as a table name + if err != nil { // `dolt diff from_commit [...tables]` if useMergeBase { - err := dArgs.applyMergeBase(ctx, dEnv, args[0], "HEAD") + err := dArgs.applyMergeBase(queryist, sqlCtx, args[0], "HEAD") if err != nil { return nil, err } } return args[1:], nil } - - dArgs.toRoot = toRoot - dArgs.toRef = args[1] + dArgs.toRef = toRef if useMergeBase { - err := dArgs.applyMergeBase(ctx, dEnv, args[0], args[1]) + err := dArgs.applyMergeBase(queryist, sqlCtx, args[0], args[1]) if err != nil { return nil, err } @@ -412,31 +453,37 @@ func (dArgs *diffArgs) applyDiffRoots(ctx context.Context, dEnv *env.DoltEnv, ar // applyMergeBase applies the merge base of two revisions to the |from| root // values. -func (dArgs *diffArgs) applyMergeBase(ctx context.Context, dEnv *env.DoltEnv, leftStr, rightStr string) error { - mergeBaseStr, err := getMergeBaseFromStrings(ctx, dEnv, leftStr, rightStr) +func (dArgs *diffArgs) applyMergeBase(queryist cli.Queryist, sqlCtx *sql.Context, leftStr, rightStr string) error { + //mergeBaseStr, err := getMergeBaseFromStrings(ctx, dEnv, leftStr, rightStr) + mergeBaseStr, err := getCommonAncestor(queryist, sqlCtx, leftStr, rightStr) if err != nil { return err } - fromRoot, ok := diff.MaybeResolveRoot(ctx, dEnv.RepoStateReader(), dEnv.DoltDB, mergeBaseStr) - if !ok { - return fmt.Errorf("merge base invalid %s", mergeBaseStr) - } - - dArgs.fromRoot = fromRoot dArgs.fromRef = mergeBaseStr return nil } +func getCommonAncestor(queryist cli.Queryist, sqlCtx *sql.Context, c1, c2 string) (string, error) { + q := fmt.Sprintf("select dolt_merge_base('%s', '%s')", c1, c2) + rows, err := getRowsForSql(queryist, sqlCtx, q) + if err != nil { + return "", err + } + if len(rows) != 1 { + return "", fmt.Errorf("unexpected number of rows returned from dolt_merge_base") + } + ancestor := rows[0][0].(string) + return ancestor, nil +} + // applyDotRevisions applies the appropriate |from| and |to| root values to the // receiver for arguments containing `..` or `...` -func (dArgs *diffArgs) applyDotRevisions(ctx context.Context, dEnv *env.DoltEnv, args []string) error { +func (dArgs *diffArgs) applyDotRevisions(queryist cli.Queryist, sqlCtx *sql.Context, ctx context.Context, args []string) error { // `dolt diff from_commit...to_commit [...tables]` if strings.Contains(args[0], "...") { refs := strings.Split(args[0], "...") - var toRoot *doltdb.RootValue - ok := true if len(refs[0]) > 0 { right := refs[1] @@ -445,17 +492,13 @@ func (dArgs *diffArgs) applyDotRevisions(ctx context.Context, dEnv *env.DoltEnv, right = "HEAD" } - err := dArgs.applyMergeBase(ctx, dEnv, refs[0], right) + err := dArgs.applyMergeBase(queryist, sqlCtx, refs[0], right) if err != nil { return err } } if len(refs[1]) > 0 { - if toRoot, ok = diff.MaybeResolveRoot(ctx, dEnv.RepoStateReader(), dEnv.DoltDB, refs[1]); !ok { - return fmt.Errorf("to ref in three dot diff must be valid ref: %s", refs[1]) - } - dArgs.toRoot = toRoot dArgs.toRef = refs[1] } @@ -465,23 +508,12 @@ func (dArgs *diffArgs) applyDotRevisions(ctx context.Context, dEnv *env.DoltEnv, // `dolt diff from_commit..to_commit [...tables]` if strings.Contains(args[0], "..") { refs := strings.Split(args[0], "..") - var fromRoot *doltdb.RootValue - var toRoot *doltdb.RootValue - ok := true if len(refs[0]) > 0 { - if fromRoot, ok = diff.MaybeResolveRoot(ctx, dEnv.RepoStateReader(), dEnv.DoltDB, refs[0]); !ok { - return fmt.Errorf("from ref in two dot diff must be valid ref: %s", refs[0]) - } - dArgs.fromRoot = fromRoot dArgs.fromRef = refs[0] } if len(refs[1]) > 0 { - if toRoot, ok = diff.MaybeResolveRoot(ctx, dEnv.RepoStateReader(), dEnv.DoltDB, refs[1]); !ok { - return fmt.Errorf("to ref in two dot diff must be valid ref: %s", refs[1]) - } - dArgs.toRoot = toRoot dArgs.toRef = refs[1] } @@ -498,63 +530,168 @@ var diffSummarySchema = sql.Schema{ &sql.Column{Name: "Schema change", Type: types.Boolean, Nullable: false}, } -func printDiffSummary(ctx context.Context, tds []diff.TableDelta, dArgs *diffArgs) errhand.VerboseError { +func printDiffSummary(ctx context.Context, diffSummaries []diff.TableDeltaSummary, dArgs *diffArgs) errhand.VerboseError { cliWR := iohelp.NopWrCloser(cli.OutStream) wr := tabular.NewFixedWidthTableWriter(diffSummarySchema, cliWR, 100) defer wr.Close(ctx) - for _, td := range tds { - if !dArgs.tableSet.Contains(td.FromName) && !dArgs.tableSet.Contains(td.ToName) { - continue + for _, diffSummary := range diffSummaries { + + shouldPrintTables := dArgs.tableSet.Contains(diffSummary.FromTableName) || dArgs.tableSet.Contains(diffSummary.ToTableName) + if !shouldPrintTables { + return nil } - if td.FromTable == nil && td.ToTable == nil { - return errhand.BuildDError("error: both tables in tableDelta are nil").Build() + tableName := diffSummary.TableName + if diffSummary.DiffType == "renamed" { + tableName = fmt.Sprintf("%s -> %s", diffSummary.FromTableName, diffSummary.ToTableName) } - - summ, err := td.GetSummary(ctx) - if err != nil { - return errhand.BuildDError("could not get table delta summary").AddCause(err).Build() - } - tableName := summ.TableName - if summ.DiffType == "renamed" { - tableName = fmt.Sprintf("%s -> %s", summ.FromTableName, summ.ToTableName) - } - - err = wr.WriteSqlRow(ctx, sql.Row{tableName, summ.DiffType, summ.DataChange, summ.SchemaChange}) + err := wr.WriteSqlRow(ctx, sql.Row{tableName, diffSummary.DiffType, diffSummary.DataChange, diffSummary.SchemaChange}) if err != nil { return errhand.BuildDError("could not write table delta summary").AddCause(err).Build() } + } return nil } -func diffUserTables(ctx context.Context, dEnv *env.DoltEnv, dArgs *diffArgs) errhand.VerboseError { +func getDiffSummariesBetweenRefs(queryist cli.Queryist, sqlCtx *sql.Context, fromRef, toRef string) ([]diff.TableDeltaSummary, error) { + q := fmt.Sprintf("select * from dolt_diff_summary('%s', '%s')", fromRef, toRef) + dataDiffRows, err := getRowsForSql(queryist, sqlCtx, q) + if err != nil { + return nil, fmt.Errorf("error: unable to get diff summary from %s to %s: %w", fromRef, toRef, err) + } + + q = fmt.Sprintf("select * from dolt_schema_diff('%s', '%s')", fromRef, toRef) + schemaDiffRows, err := getRowsForSql(queryist, sqlCtx, q) + if err != nil { + return nil, fmt.Errorf("error: unable to get schema diff from %s to %s: %w", fromRef, toRef, err) + } + + summaries := []diff.TableDeltaSummary{} + + for _, row := range dataDiffRows { + summary := diff.TableDeltaSummary{} + summary.FromTableName = row[0].(string) + summary.ToTableName = row[1].(string) + summary.DiffType = row[2].(string) + summary.DataChange, err = getTinyIntColAsBool(row[3]) + if err != nil { + return nil, fmt.Errorf("error: unable to parse data change value '%s': %w", row[3], err) + } + summary.SchemaChange, err = getTinyIntColAsBool(row[4]) + if err != nil { + return nil, fmt.Errorf("error: unable to parse schema change value '%s': %w", row[4], err) + } + + switch summary.DiffType { + case "dropped": + summary.TableName = summary.FromTableName + case "added": + summary.TableName = summary.ToTableName + case "renamed": + summary.TableName = summary.ToTableName + case "modified": + summary.TableName = summary.FromTableName + default: + return nil, fmt.Errorf("error: unexpected diff type '%s'", summary.DiffType) + } + + summaries = append(summaries, summary) + } + + //schemaSummaries := []diff.TableDeltaSummary{} + for _, row := range schemaDiffRows { + fromTable := row[0].(string) + toTable := row[1].(string) + fromCreateStmt := row[2].(string) + toCreateStmt := row[3].(string) + alterStmt := row[4].(string) + pkChanged, err := getTinyIntColAsBool(row[5]) + if err != nil { + return nil, fmt.Errorf("error: unable to parse pk changed value '%s': %w", row[5], err) + } + + var schemaChanged = alterStmt != "" || pkChanged + var diffType = "" + var tableName = "" + switch { + case fromTable == toTable: + if fromCreateStmt != toCreateStmt { + diffType = "modified" + tableName = fromTable + schemaChanged = true + } + case fromTable == "": + diffType = "added" + tableName = toTable + schemaChanged = true + case toTable == "": + diffType = "dropped" + tableName = fromTable + schemaChanged = true + case fromTable != "" && toTable != "" && fromTable != toTable: + diffType = "renamed" + tableName = toTable + schemaChanged = true + default: + return nil, fmt.Errorf("error: unexpected schema diff case: fromTable='%s', toTable='%s'", fromTable, toTable) + } + + if !schemaChanged { + continue + } + + var existingSummaryIndex = -1 + for i, summary := range summaries { + isSameSummary := + summary.TableName == tableName && + summary.FromTableName == fromTable && + summary.ToTableName == toTable && + summary.DiffType == diffType + if isSameSummary { + existingSummaryIndex = i + break + } + } + if existingSummaryIndex == -1 { + summary := diff.TableDeltaSummary{ + TableName: tableName, + FromTableName: fromTable, + ToTableName: toTable, + DiffType: diffType, + DataChange: false, + SchemaChange: true, + PkChanged: pkChanged, + } + if alterStmt != "" { + summary.AlterStmts = []string{alterStmt} + } + + summaries = append(summaries, summary) + } else { + summary := summaries[existingSummaryIndex] + summary.SchemaChange = true + summary.AlterStmts = append(summary.AlterStmts, alterStmt) + summary.PkChanged = summary.PkChanged || pkChanged + summaries[existingSummaryIndex] = summary + } + } + + return summaries, nil +} + +func diffUserTables(queryist cli.Queryist, sqlCtx *sql.Context, ctx context.Context, dArgs *diffArgs) errhand.VerboseError { var err error - tableDeltas, err := diff.GetTableDeltas(ctx, dArgs.fromRoot, dArgs.toRoot) + diffSummaries, err := getDiffSummariesBetweenRefs(queryist, sqlCtx, dArgs.fromRef, dArgs.toRef) if err != nil { - return errhand.BuildDError("error: unable to diff tables").AddCause(err).Build() + return errhand.BuildDError("error: unable to get diff summary").AddCause(err).Build() } - sqlEng, dbName, err := engine.NewSqlEngineForEnv(ctx, dEnv) - if err != nil { - return errhand.VerboseErrorFromError(err) - } - - sqlCtx, err := sqlEng.NewLocalContext(ctx) - if err != nil { - return errhand.VerboseErrorFromError(err) - } - sqlCtx.SetCurrentDatabase(dbName) - - sort.Slice(tableDeltas, func(i, j int) bool { - return strings.Compare(tableDeltas[i].ToName, tableDeltas[j].ToName) < 0 - }) - if dArgs.diffParts&Summary != 0 { - return printDiffSummary(ctx, tableDeltas, dArgs) + return printDiffSummary(ctx, diffSummaries, dArgs) } dw, err := newDiffWriter(dArgs.diffOutput) @@ -562,36 +699,17 @@ func diffUserTables(ctx context.Context, dEnv *env.DoltEnv, dArgs *diffArgs) err return errhand.VerboseErrorFromError(err) } - roots, err := dEnv.Roots(ctx) - if err != nil { - return errhand.VerboseErrorFromError(fmt.Errorf("couldn't get working root, cause: %w", err)) - } - - ignoredTablePatterns, err := doltdb.GetIgnoredTablePatterns(ctx, roots) + ignoredTablePatterns, err := getIgnoredTablePatternsFromSql(queryist, sqlCtx) if err != nil { return errhand.VerboseErrorFromError(fmt.Errorf("couldn't get ignored table patterns, cause: %w", err)) } - toRootHash, err := dArgs.diffDatasets.toRoot.HashOf() - if err != nil { - return errhand.VerboseErrorFromError(err) - } - - fromRootHash, err := dArgs.diffDatasets.fromRoot.HashOf() - if err != nil { - return errhand.VerboseErrorFromError(err) - } - - workingSetHash, err := roots.Working.HashOf() - if err != nil { - return errhand.VerboseErrorFromError(err) - } - doltSchemasChanged := false - for _, td := range tableDeltas { + for _, diffSummary := range diffSummaries { + // Don't print tables if one side of the diff is an ignored table in the working set being added. - if toRootHash == workingSetHash && td.FromTable == nil { - ignoreResult, err := ignoredTablePatterns.IsTableNameIgnored(td.ToName) + if len(diffSummary.FromTableName) == 0 { + ignoreResult, err := ignoredTablePatterns.IsTableNameIgnored(diffSummary.ToTableName) if err != nil { return errhand.VerboseErrorFromError(err) } @@ -600,8 +718,8 @@ func diffUserTables(ctx context.Context, dEnv *env.DoltEnv, dArgs *diffArgs) err } } - if fromRootHash == workingSetHash && td.ToTable == nil { - ignoreResult, err := ignoredTablePatterns.IsTableNameIgnored(td.FromName) + if len(diffSummary.ToTableName) == 0 { + ignoreResult, err := ignoredTablePatterns.IsTableNameIgnored(diffSummary.FromTableName) if err != nil { return errhand.VerboseErrorFromError(err) } @@ -610,15 +728,15 @@ func diffUserTables(ctx context.Context, dEnv *env.DoltEnv, dArgs *diffArgs) err } } - if !shouldPrintTableDelta(dArgs.tableSet, td) { + if !shouldPrintTableDelta(dArgs.tableSet, diffSummary.ToTableName, diffSummary.FromTableName) { continue } - if isDoltSchemasTable(td) { + if isDoltSchemasTable(diffSummary.ToTableName, diffSummary.FromTableName) { // save dolt_schemas table diff for last in diff output doltSchemasChanged = true } else { - verr := diffUserTable(sqlCtx, td, sqlEng, dArgs, dw) + verr := diffUserTable(queryist, sqlCtx, diffSummary, dArgs, dw) if verr != nil { return verr } @@ -626,7 +744,7 @@ func diffUserTables(ctx context.Context, dEnv *env.DoltEnv, dArgs *diffArgs) err } if doltSchemasChanged { - verr := diffDoltSchemasTable(sqlCtx, sqlEng, dArgs, dw) + verr := diffDoltSchemasTable(queryist, sqlCtx, dArgs, dw) if verr != nil { return verr } @@ -640,57 +758,224 @@ func diffUserTables(ctx context.Context, dEnv *env.DoltEnv, dArgs *diffArgs) err return nil } -func shouldPrintTableDelta(tablesToPrint *set.StrSet, td diff.TableDelta) bool { +func shouldPrintTableDelta(tablesToPrint *set.StrSet, toTableName, fromTableName string) bool { // TODO: this should be case insensitive - return tablesToPrint.Contains(td.FromName) || tablesToPrint.Contains(td.ToName) + return tablesToPrint.Contains(fromTableName) || tablesToPrint.Contains(toTableName) } -func isDoltSchemasTable(td diff.TableDelta) bool { - return td.FromName == doltdb.SchemasTableName || td.ToName == doltdb.SchemasTableName +func isDoltSchemasTable(toTableName, fromTableName string) bool { + return fromTableName == doltdb.SchemasTableName || toTableName == doltdb.SchemasTableName +} + +func getTableInfoAtRef(queryist cli.Queryist, sqlCtx *sql.Context, tableName string, ref string) (diff.TableInfo, error) { + fks, err := getForeignKeysForTable(queryist, sqlCtx, tableName, ref) + if err != nil { + return diff.TableInfo{}, fmt.Errorf("error: unable to get foreign keys for table '%s': %w", tableName, err) + } + + sch, createStmt, err := getTableSchemaAtRef(queryist, sqlCtx, tableName, ref, fks) + if err != nil { + return diff.TableInfo{}, fmt.Errorf("error: unable to get schema for table '%s': %w", tableName, err) + } + + fksParentSch, err := getFkParentSchemas(queryist, sqlCtx, fks, ref) + if err != nil { + return diff.TableInfo{}, fmt.Errorf("error: unable to get parent schemas for foreign keys for table '%s': %w", tableName, err) + } + + tableInfo := diff.TableInfo{ + Name: tableName, + Sch: sch, + CreateStmt: createStmt, + Fks: fks, + FksParentSch: fksParentSch, + } + return tableInfo, nil +} + + + +func getTableSchemaAtRef(queryist cli.Queryist, sqlCtx *sql.Context, tableName string, ref string, fks []diff.ForeignKeyInfo) (sch schema.Schema, createStmt string, err error) { + var rows []sql.Row + q := fmt.Sprintf("show create table %s as of '%s'", tableName, ref) + rows, err = getRowsForSql(queryist, sqlCtx, q) + if err != nil { + return sch, createStmt, err + } + + if len(rows) != 1 { + return sch, createStmt, fmt.Errorf("creating schema, expected 1 row, got %d", len(rows)) + } + createStmt = rows[0][1].(string) + + // append ; at the end, if one isn't there yet + if createStmt[len(createStmt)-1] != ';' { + createStmt += ";" + } + + + sch, err = schemaFromCreateTableStmt(sqlCtx, createStmt) + if err != nil { + return sch, createStmt, err + } + + return sch, createStmt, nil +} + +func schemaFromCreateTableStmt(sqlCtx *sql.Context, createTableStmt string) (schema.Schema, error) { + p, err := sqlparser.Parse(createTableStmt) + if err != nil { + return nil, err + } + ddl := p.(*sqlparser.DDL) + + s, _, err := parse.TableSpecToSchema(sqlCtx, ddl.TableSpec, false) + if err != nil { + return nil, err + } + + cols := []schema.Column{} + for _, col := range s.Schema { + typeInfo, err := typeinfo.FromSqlType(col.Type) + if err != nil { + return nil, err + } + + sCol, err := schema.NewColumnWithTypeInfo( + col.Name, + 0, + typeInfo, + col.PrimaryKey, + col.Default.String(), + col.AutoIncrement, + col.Comment, + ) + cols = append(cols, sCol) + } + + sch, err := schema.NewSchema(schema.NewColCollection(cols...), nil, schema.Collation_Default, nil, nil) + if err != nil { + return nil, err + } + + return sch, err +} + +func getTableDiffStats(queryist cli.Queryist, sqlCtx *sql.Context, tableName, fromRef, toRef string) ([]diffStatistics, error) { + q := fmt.Sprintf("select * from dolt_diff_stat('%s', '%s', '%s')", fromRef, toRef, tableName) + rows, err := getRowsForSql(queryist, sqlCtx, q) + if err != nil { + return nil, fmt.Errorf("error running diff stats query: %w", err) + } + + allStats := []diffStatistics{} + for _, row := range rows { + stats := diffStatistics{ + TableName: row[0].(string), + RowsUnmodified: coallesceNilToUint64(row[1]), + RowsAdded: coallesceNilToUint64(row[2]), + RowsDeleted: coallesceNilToUint64(row[3]), + RowsModified: coallesceNilToUint64(row[4]), + CellsAdded: coallesceNilToUint64(row[5]), + CellsDeleted: coallesceNilToUint64(row[6]), + CellsModified: coallesceNilToUint64(row[7]), + OldRowCount: coallesceNilToUint64(row[8]), + NewRowCount: coallesceNilToUint64(row[9]), + OldCellCount: coallesceNilToUint64(row[10]), + NewCellCount: coallesceNilToUint64(row[11]), + } + allStats = append(allStats, stats) + } + return allStats, nil +} + +func coallesceNilToUint64(val interface{}) uint64 { + if val == nil { + return 0 + } + return uint64(val.(int64)) } func diffUserTable( - ctx *sql.Context, - td diff.TableDelta, - sqlEng *engine.SqlEngine, + queryist cli.Queryist, + sqlCtx *sql.Context, + tableSummary diff.TableDeltaSummary, dArgs *diffArgs, dw diffWriter, ) errhand.VerboseError { - fromTable := td.FromTable - toTable := td.ToTable + fromTable := tableSummary.FromTableName + toTable := tableSummary.ToTableName - if fromTable == nil && toTable == nil { - return errhand.BuildDError("error: both tables in tableDelta are nil").Build() - } - - err := dw.BeginTable(ctx, td) + err := dw.BeginTable(tableSummary.FromTableName, tableSummary.ToTableName, tableSummary.IsAdd(), tableSummary.IsDrop()) if err != nil { return errhand.VerboseErrorFromError(err) } - fromSch, toSch, err := td.GetSchemas(ctx) - if err != nil { - return errhand.BuildDError("cannot retrieve schema for table %s", td.ToName).AddCause(err).Build() + var fromTableInfo, toTableInfo *diff.TableInfo + + from, err := getTableInfoAtRef(queryist, sqlCtx, fromTable, dArgs.fromRef) + if err == nil { + fromTableInfo = &from + } + to, err := getTableInfoAtRef(queryist, sqlCtx, toTable, dArgs.toRef) + if err == nil { + toTableInfo = &to + } + + tableName := fromTable + if tableName == "" { + tableName = toTable } if dArgs.diffParts&Stat != 0 { - return printDiffStat(ctx, td, fromSch.GetAllCols().Size(), toSch.GetAllCols().Size()) + var areTablesKeyless = false + + var fromKeyless = false + if fromTableInfo != nil { + fromKeyless = schema.IsKeyless(fromTableInfo.Sch) + } + var toKeyless = false + if toTableInfo != nil { + toKeyless = schema.IsKeyless(toTableInfo.Sch) + } + + // nil table is neither keyless nor keyed + if fromTableInfo == nil { + areTablesKeyless = toKeyless + } else if toTableInfo == nil { + areTablesKeyless = fromKeyless + } else { + if fromKeyless && toKeyless { + areTablesKeyless = true + } else if !fromKeyless && !toKeyless { + areTablesKeyless = false + } else { + return errhand.BuildDError("mismatched keyless and keyed schemas for table %s", tableName).Build() + } + } + + diffStats, err := getTableDiffStats(queryist, sqlCtx, tableName, dArgs.fromRef, dArgs.toRef) + if err != nil { + return errhand.BuildDError("cannot retrieve diff stats between '%s' and '%s'", dArgs.fromRef, dArgs.toRef).AddCause(err).Build() + } + + return printDiffStat(diffStats, fromTableInfo.Sch.GetAllCols().Size(), toTableInfo.Sch.GetAllCols().Size(), areTablesKeyless) } if dArgs.diffParts&SchemaOnlyDiff != 0 { - err := dw.WriteTableSchemaDiff(ctx, dArgs.fromRoot, dArgs.toRoot, td) + err := dw.WriteTableSchemaDiff(fromTableInfo, toTableInfo, tableSummary) if err != nil { return errhand.VerboseErrorFromError(err) } } - if td.IsDrop() && dArgs.diffOutput == SQLDiffOutput { + if tableSummary.IsDrop() && dArgs.diffOutput == SQLDiffOutput { return nil // don't output DELETE FROM statements after DROP TABLE - } else if td.IsAdd() { - fromSch = toSch + } else if tableSummary.IsAdd() { + //fromSch = toSch } - verr := diffRows(ctx, sqlEng, td, dArgs, dw) + verr := diffRows(queryist, sqlCtx, tableSummary, fromTableInfo, toTableInfo, dArgs, dw) if verr != nil { return verr } @@ -699,8 +984,8 @@ func diffUserTable( } func diffDoltSchemasTable( + queryist cli.Queryist, sqlCtx *sql.Context, - sqlEng *engine.SqlEngine, dArgs *diffArgs, dw diffWriter, ) errhand.VerboseError { @@ -709,7 +994,8 @@ func diffDoltSchemasTable( "order by coalesce(from_type, to_type), coalesce(from_name, to_name)", dArgs.fromRef, dArgs.toRef, doltdb.SchemasTableName) - _, rowIter, err := sqlEng.Query(sqlCtx, query) + //_, rowIter, err := sqlEng.Query(sqlCtx, query) + _, rowIter, err := queryist.Query(sqlCtx, query) if err != nil { return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build() } @@ -779,26 +1065,74 @@ func diffDoltSchemasTable( return nil } +// ArePrimaryKeySetsDiffable checks if two schemas are diffable. Assumes the +// passed in schema are from the same table between commits. If __DOLT__, then +// it also checks if the underlying SQL types of the columns are equal. +func arePrimaryKeySetsDiffable(fromTableInfo, toTableInfo *diff.TableInfo) bool { + var fromSch schema.Schema = nil + var toSch schema.Schema = nil + if fromTableInfo != nil { + fromSch = fromTableInfo.Sch + } + if toTableInfo != nil { + toSch = toTableInfo.Sch + } + + if fromSch == nil && toSch == nil { + return false + // Empty case + } else if fromSch == nil || fromSch.GetAllCols().Size() == 0 || + toSch == nil || toSch.GetAllCols().Size() == 0 { + return true + } + + // Keyless case for comparing + if schema.IsKeyless(fromSch) && schema.IsKeyless(toSch) { + return true + } + + cc1 := fromSch.GetPKCols() + cc2 := toSch.GetPKCols() + + if cc1.Size() != cc2.Size() { + return false + } + + for i := 0; i < cc1.Size(); i++ { + c1 := cc1.GetByIndex(i) + c2 := cc2.GetByIndex(i) + if c1.IsPartOfPK != c2.IsPartOfPK { + return false + } + if !c1.TypeInfo.ToSqlType().Equals(c2.TypeInfo.ToSqlType()) { + return false + } + } + + return true +} + func diffRows( - ctx *sql.Context, - sqlEng *engine.SqlEngine, - td diff.TableDelta, + queryist cli.Queryist, + sqlCtx *sql.Context, + tableSummary diff.TableDeltaSummary, + fromTableInfo, toTableInfo *diff.TableInfo, dArgs *diffArgs, dw diffWriter, ) errhand.VerboseError { - diffable := schema.ArePrimaryKeySetsDiffable(td.Format(), td.FromSch, td.ToSch) - canSqlDiff := !(td.ToSch == nil || (td.FromSch != nil && !schema.SchemasAreEqual(td.FromSch, td.ToSch))) + diffable := arePrimaryKeySetsDiffable(fromTableInfo, toTableInfo) + canSqlDiff := !(toTableInfo == nil || (fromTableInfo != nil && !schema.SchemasAreEqual(fromTableInfo.Sch, toTableInfo.Sch))) var toSch, fromSch sql.Schema - if td.FromSch != nil { - pkSch, err := sqlutil.FromDoltSchema(td.FromName, td.FromSch) + if fromTableInfo != nil { + pkSch, err := sqlutil.FromDoltSchema(fromTableInfo.Name, fromTableInfo.Sch) if err != nil { return errhand.VerboseErrorFromError(err) } fromSch = pkSch.Schema } - if td.ToSch != nil { - pkSch, err := sqlutil.FromDoltSchema(td.ToName, td.ToSch) + if toTableInfo != nil { + pkSch, err := sqlutil.FromDoltSchema(toTableInfo.Name, toTableInfo.Sch) if err != nil { return errhand.VerboseErrorFromError(err) } @@ -808,7 +1142,7 @@ func diffRows( unionSch := unionSchemas(fromSch, toSch) // We always instantiate a RowWriter in case the diffWriter needs it to close off any work from schema output - rowWriter, err := dw.RowWriter(ctx, td, unionSch) + rowWriter, err := dw.RowWriter(fromTableInfo, toTableInfo, tableSummary, unionSch) if err != nil { return errhand.VerboseErrorFromError(err) } @@ -816,16 +1150,16 @@ func diffRows( // can't diff if !diffable { // TODO: this messes up some structured output if the user didn't redirect it - cli.PrintErrf("Primary key sets differ between revisions for table '%s', skipping data diff\n", td.ToName) - err := rowWriter.Close(ctx) + cli.PrintErrf("Primary key sets differ between revisions for table '%s', skipping data diff\n", tableSummary.ToTableName) + err := rowWriter.Close(sqlCtx) if err != nil { return errhand.VerboseErrorFromError(err) } return nil } else if dArgs.diffOutput == SQLDiffOutput && !canSqlDiff { // TODO: this is overly broad, we can absolutely do better - _, _ = fmt.Fprintf(cli.CliErr, "Incompatible schema change, skipping data diff for table '%s'\n", td.ToName) - err := rowWriter.Close(ctx) + _, _ = fmt.Fprintf(cli.CliErr, "Incompatible schema change, skipping data diff for table '%s'\n", tableSummary.ToTableName) + err := rowWriter.Close(sqlCtx) if err != nil { return errhand.VerboseErrorFromError(err) } @@ -834,7 +1168,7 @@ func diffRows( // no data diff requested if dArgs.diffParts&DataOnlyDiff == 0 { - err := rowWriter.Close(ctx) + err := rowWriter.Close(sqlCtx) if err != nil { return errhand.VerboseErrorFromError(err) } @@ -842,12 +1176,12 @@ func diffRows( } // do the data diff - tableName := td.ToName + tableName := tableSummary.ToTableName if len(tableName) == 0 { - tableName = td.FromName + tableName = tableSummary.FromTableName } - columns := getColumnNamesString(td.FromSch, td.ToSch) + columns := getColumnNamesString(fromTableInfo, toTableInfo) query := fmt.Sprintf("select %s, %s from dolt_diff('%s', '%s', '%s')", columns, "diff_type", dArgs.fromRef, dArgs.toRef, tableName) if len(dArgs.where) > 0 { @@ -858,19 +1192,19 @@ func diffRows( query += " limit " + strconv.Itoa(dArgs.limit) } - sch, rowIter, err := sqlEng.Query(ctx, query) + sch, rowIter, err := queryist.Query(sqlCtx, query) if sql.ErrSyntaxError.Is(err) { return errhand.BuildDError("Failed to parse diff query. Invalid where clause?\nDiff query: %s", query).AddCause(err).Build() } else if err != nil { return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build() } - defer rowIter.Close(ctx) - defer rowWriter.Close(ctx) + defer rowIter.Close(sqlCtx) + defer rowWriter.Close(sqlCtx) var modifiedColNames map[string]bool if dArgs.skinny { - modifiedColNames, err = getModifiedCols(ctx, rowIter, unionSch, sch) + modifiedColNames, err = getModifiedCols(sqlCtx, rowIter, unionSch, sch) if err != nil { return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build() } @@ -886,19 +1220,20 @@ func diffRows( } // instantiate a new RowWriter with the new schema that only contains the columns with changes - rowWriter, err = dw.RowWriter(ctx, td, filteredUnionSch) + rowWriter, err = dw.RowWriter(fromTableInfo, toTableInfo, tableSummary, filteredUnionSch) if err != nil { return errhand.VerboseErrorFromError(err) } - defer rowWriter.Close(ctx) + defer rowWriter.Close(sqlCtx) // reset the row iterator - err = rowIter.Close(ctx) + err = rowIter.Close(sqlCtx) if err != nil { return errhand.BuildDError("Error closing row iterator:\n%s", query).AddCause(err).Build() } - _, rowIter, err = sqlEng.Query(ctx, query) - defer rowIter.Close(ctx) + //_, rowIter, err = sqlEng.Query(ctx, query) + _, rowIter, err = queryist.Query(sqlCtx, query) + defer rowIter.Close(sqlCtx) if sql.ErrSyntaxError.Is(err) { return errhand.BuildDError("Failed to parse diff query. Invalid where clause?\nDiff query: %s", query).AddCause(err).Build() } else if err != nil { @@ -906,7 +1241,7 @@ func diffRows( } } - err = writeDiffResults(ctx, sch, unionSch, rowIter, rowWriter, modifiedColNames, dArgs) + err = writeDiffResults(sqlCtx, sch, unionSch, rowIter, rowWriter, modifiedColNames, dArgs) if err != nil { return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build() } @@ -927,7 +1262,15 @@ func unionSchemas(s1 sql.Schema, s2 sql.Schema) sql.Schema { return union } -func getColumnNamesString(fromSch, toSch schema.Schema) string { +func getColumnNamesString(fromTableInfo, toTableInfo *diff.TableInfo) string { + var fromSch, toSch schema.Schema + if fromTableInfo != nil { + fromSch = fromTableInfo.Sch + } + if toTableInfo != nil { + toSch = toTableInfo.Sch + } + var cols []string if fromSch != nil { fromSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { diff --git a/go/cmd/dolt/commands/diff_output.go b/go/cmd/dolt/commands/diff_output.go index dc90a1da32..6381037688 100644 --- a/go/cmd/dolt/commands/diff_output.go +++ b/go/cmd/dolt/commands/diff_output.go @@ -28,24 +28,21 @@ import ( "github.com/dolthub/dolt/go/cmd/dolt/cli" "github.com/dolthub/dolt/go/cmd/dolt/errhand" "github.com/dolthub/dolt/go/libraries/doltcore/diff" - "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/doltcore/sqle" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" - "github.com/dolthub/dolt/go/libraries/doltcore/table/editor" "github.com/dolthub/dolt/go/libraries/doltcore/table/typed/json" "github.com/dolthub/dolt/go/libraries/doltcore/table/untyped/sqlexport" "github.com/dolthub/dolt/go/libraries/doltcore/table/untyped/tabular" "github.com/dolthub/dolt/go/libraries/utils/iohelp" - "github.com/dolthub/dolt/go/store/atomicerr" ) // diffWriter is an interface that lets us write diffs in a variety of output formats type diffWriter interface { // BeginTable is called when a new table is about to be written, before any schema or row diffs are written - BeginTable(ctx context.Context, td diff.TableDelta) error + BeginTable(fromTableName, toTableName string, isAdd, isDrop bool) error // WriteTableSchemaDiff is called to write a schema diff for the table given (if requested by args) - WriteTableSchemaDiff(ctx context.Context, fromRoot *doltdb.RootValue, toRoot *doltdb.RootValue, td diff.TableDelta) error + WriteTableSchemaDiff(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary) error // WriteEventDiff is called to write an event diff WriteEventDiff(ctx context.Context, eventName, oldDefn, newDefn string) error // WriteTriggerDiff is called to write a trigger diff @@ -54,7 +51,7 @@ type diffWriter interface { WriteViewDiff(ctx context.Context, viewName, oldDefn, newDefn string) error // RowWriter returns a row writer for the table delta provided, which will have Close() called on it when rows are // done being written. - RowWriter(ctx context.Context, td diff.TableDelta, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) + RowWriter(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) // Close finalizes the work of the writer Close(ctx context.Context) error } @@ -73,34 +70,20 @@ func newDiffWriter(diffOutput diffOutput) (diffWriter, error) { } } -func printDiffStat(ctx context.Context, td diff.TableDelta, oldColLen, newColLen int) errhand.VerboseError { - // todo: use errgroup.Group - ae := atomicerr.New() - ch := make(chan diff.DiffStatProgress) - go func() { - defer close(ch) - err := diff.StatForTableDelta(ctx, ch, td) - - ae.SetIfError(err) - }() - +func printDiffStat(diffStats []diffStatistics, oldColLen, newColLen int, areTablesKeyless bool) errhand.VerboseError { acc := diff.DiffStatProgress{} var count int64 var pos int eP := cli.NewEphemeralPrinter() - for p := range ch { - if ae.IsSet() { - break - } - - acc.Adds += p.Adds - acc.Removes += p.Removes - acc.Changes += p.Changes - acc.CellChanges += p.CellChanges - acc.NewRowSize += p.NewRowSize - acc.OldRowSize += p.OldRowSize - acc.NewCellSize += p.NewCellSize - acc.OldCellSize += p.OldCellSize + for _, diffStat := range diffStats { + acc.Adds += diffStat.RowsAdded + acc.Removes += diffStat.RowsDeleted + acc.Changes += diffStat.RowsModified + acc.CellChanges += diffStat.CellsModified + acc.NewRowSize += diffStat.NewRowCount + acc.OldRowSize += diffStat.OldRowCount + acc.NewCellSize += diffStat.NewCellCount + acc.OldCellSize += diffStat.OldCellCount if count%10000 == 0 { eP.Printf("prev size: %d, new size: %d, adds: %d, deletes: %d, modifications: %d\n", acc.OldRowSize, acc.NewRowSize, acc.Adds, acc.Removes, acc.Changes) @@ -112,21 +95,12 @@ func printDiffStat(ctx context.Context, td diff.TableDelta, oldColLen, newColLen pos = cli.DeleteAndPrint(pos, "") - if err := ae.Get(); err != nil { - return errhand.BuildDError("").AddCause(err).Build() - } - - keyless, err := td.IsKeyless(ctx) - if err != nil { - return errhand.BuildDError("").AddCause(err).Build() - } - if (acc.Adds+acc.Removes+acc.Changes) == 0 && (acc.OldCellSize-acc.NewCellSize) == 0 { cli.Println("No data changes. See schema changes by using -s or --schema.") return nil } - if keyless { + if areTablesKeyless { printKeylessStat(acc) } else { printStat(acc, oldColLen, newColLen) @@ -195,56 +169,31 @@ func (t tabularDiffWriter) Close(ctx context.Context) error { return nil } -func (t tabularDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) error { +func (t tabularDiffWriter) BeginTable(fromTableName, toTableName string, isAdd, isDrop bool) error { bold := color.New(color.Bold) - if td.IsDrop() { - _, _ = bold.Printf("diff --dolt a/%s b/%s\n", td.FromName, td.FromName) + if isDrop { + _, _ = bold.Printf("diff --dolt a/%s b/%s\n", fromTableName, fromTableName) _, _ = bold.Println("deleted table") - } else if td.IsAdd() { - _, _ = bold.Printf("diff --dolt a/%s b/%s\n", td.ToName, td.ToName) + } else if isAdd { + _, _ = bold.Printf("diff --dolt a/%s b/%s\n", toTableName, toTableName) _, _ = bold.Println("added table") } else { - _, _ = bold.Printf("diff --dolt a/%s b/%s\n", td.FromName, td.ToName) - h1, err := td.FromTable.HashOf() - - if err != nil { - panic(err) - } - - _, _ = bold.Printf("--- a/%s @ %s\n", td.FromName, h1.String()) - - h2, err := td.ToTable.HashOf() - - if err != nil { - panic(err) - } - - _, _ = bold.Printf("+++ b/%s @ %s\n", td.ToName, h2.String()) + _, _ = bold.Printf("diff --dolt a/%s b/%s\n", fromTableName, toTableName) + _, _ = bold.Printf("--- a/%s\n", fromTableName) + _, _ = bold.Printf("+++ b/%s\n", toTableName) } return nil } -func (t tabularDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromRoot *doltdb.RootValue, toRoot *doltdb.RootValue, td diff.TableDelta) error { +func (t tabularDiffWriter) WriteTableSchemaDiff(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary) error { var fromCreateStmt = "" - if td.FromTable != nil { - sqlDb := sqle.NewUserSpaceDatabase(fromRoot, editor.Options{}) - sqlCtx, engine, _ := sqle.PrepareCreateTableStmt(ctx, sqlDb) - var err error - fromCreateStmt, err = sqle.GetCreateTableStmt(sqlCtx, engine, td.FromName) - if err != nil { - return errhand.VerboseErrorFromError(err) - } + if fromTableInfo != nil { + fromCreateStmt = fromTableInfo.CreateStmt } var toCreateStmt = "" - if td.ToTable != nil { - sqlDb := sqle.NewUserSpaceDatabase(toRoot, editor.Options{}) - sqlCtx, engine, _ := sqle.PrepareCreateTableStmt(ctx, sqlDb) - var err error - toCreateStmt, err = sqle.GetCreateTableStmt(sqlCtx, engine, td.ToName) - if err != nil { - return errhand.VerboseErrorFromError(err) - } + if toTableInfo != nil { + toCreateStmt = toTableInfo.CreateStmt } if fromCreateStmt != toCreateStmt { @@ -252,18 +201,22 @@ func (t tabularDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromRoot *d } resolvedFromFks := map[string]struct{}{} - for _, fk := range td.FromFks { - if len(fk.ReferencedTableColumns) > 0 { - resolvedFromFks[fk.Name] = struct{}{} + if fromTableInfo != nil { + for _, fk := range fromTableInfo.Fks { + if len(fk.ReferencedTableColumns) > 0 { + resolvedFromFks[fk.Name] = struct{}{} + } } } - for _, fk := range td.ToFks { - if _, ok := resolvedFromFks[fk.Name]; ok { - continue - } - if len(fk.ReferencedTableColumns) > 0 { - cli.Println(fmt.Sprintf("resolved foreign key `%s` on table `%s`", fk.Name, fk.TableName)) + if toTableInfo != nil { + for _, fk := range toTableInfo.Fks { + if _, ok := resolvedFromFks[fk.Name]; ok { + continue + } + if len(fk.ReferencedTableColumns) > 0 { + cli.Println(fmt.Sprintf("resolved foreign key `%s` on table `%s`", fk.Name, fk.TableName)) + } } } @@ -286,7 +239,7 @@ func (t tabularDiffWriter) WriteViewDiff(ctx context.Context, viewName, oldDefn, return nil } -func (t tabularDiffWriter) RowWriter(ctx context.Context, td diff.TableDelta, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) { +func (t tabularDiffWriter) RowWriter(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) { return tabular.NewFixedWidthDiffTableWriter(unionSch, iohelp.NopWrCloser(cli.CliOut), 100), nil } @@ -298,17 +251,16 @@ func (s sqlDiffWriter) Close(ctx context.Context) error { return nil } -func (s sqlDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) error { +func (s sqlDiffWriter) BeginTable(fromTableName, toTableName string, isAdd, isDrop bool) error { return nil } -func (s sqlDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromRoot *doltdb.RootValue, toRoot *doltdb.RootValue, td diff.TableDelta) error { - toSchemas, err := toRoot.GetAllSchemas(ctx) - if err != nil { - return errhand.BuildDError("could not read schemas from toRoot").AddCause(err).Build() - } +func (s sqlDiffWriter) WriteTableSchemaDiff(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary) error { + //for _, stmt := range tds.AlterStmts { + // cli.Println(stmt) + //} - ddlStatements, err := diff.SqlSchemaDiff(ctx, td, toSchemas) + ddlStatements, err := diff.SqlSchemaDiff(fromTableInfo, toTableInfo, tds) if err != nil { return errhand.VerboseErrorFromError(err) } @@ -362,13 +314,16 @@ func (s sqlDiffWriter) WriteViewDiff(ctx context.Context, viewName, oldDefn, new return nil } -func (s sqlDiffWriter) RowWriter(ctx context.Context, td diff.TableDelta, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) { - targetSch := td.ToSch +func (s sqlDiffWriter) RowWriter(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) { + var targetSch schema.Schema + if toTableInfo != nil { + targetSch = toTableInfo.Sch + } if targetSch == nil { - targetSch = td.FromSch + targetSch = fromTableInfo.Sch } - return sqlexport.NewSqlDiffWriter(td.ToName, targetSch, iohelp.NopWrCloser(cli.CliOut)), nil + return sqlexport.NewSqlDiffWriter(tds.ToTableName, targetSch, iohelp.NopWrCloser(cli.CliOut)), nil } type jsonDiffWriter struct { @@ -402,7 +357,7 @@ func (j *jsonDiffWriter) beginDocumentIfNecessary() error { return nil } -func (j *jsonDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) error { +func (j *jsonDiffWriter) BeginTable(fromTableName, toTableName string, isAdd, isDrop bool) error { err := j.beginDocumentIfNecessary() if err != nil { return err @@ -420,9 +375,9 @@ func (j *jsonDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) err } } - tableName := td.FromName + tableName := fromTableName if len(tableName) == 0 { - tableName = td.ToName + tableName = toTableName } err = iohelp.WriteAll(j.wr, []byte(fmt.Sprintf(jsonDiffTableHeader, tableName))) @@ -436,19 +391,21 @@ func (j *jsonDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) err return err } -func (j *jsonDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromRoot *doltdb.RootValue, toRoot *doltdb.RootValue, td diff.TableDelta) error { - toSchemas, err := toRoot.GetAllSchemas(ctx) - if err != nil { - return errhand.BuildDError("could not read schemas from toRoot").AddCause(err).Build() - } +func (j *jsonDiffWriter) WriteTableSchemaDiff(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary) error { + //for _, stmt := range tds.AlterStmts { + // err := j.schemaDiffWriter.WriteSchemaDiff(stmt) + // if err != nil { + // return err + // } + //} - stmts, err := diff.SqlSchemaDiff(ctx, td, toSchemas) + stmts, err := diff.SqlSchemaDiff(fromTableInfo, toTableInfo, tds) if err != nil { return err } for _, stmt := range stmts { - err := j.schemaDiffWriter.WriteSchemaDiff(ctx, stmt) + err := j.schemaDiffWriter.WriteSchemaDiff(stmt) if err != nil { return err } @@ -457,7 +414,7 @@ func (j *jsonDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromRoot *dol return nil } -func (j *jsonDiffWriter) RowWriter(ctx context.Context, td diff.TableDelta, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) { +func (j *jsonDiffWriter) RowWriter(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) { // close off the schema diff block, start the data block err := iohelp.WriteAll(j.wr, []byte(jsonDiffDataDiffHeader)) if err != nil { diff --git a/go/cmd/dolt/commands/show.go b/go/cmd/dolt/commands/show.go index 026889a099..e12527cb83 100644 --- a/go/cmd/dolt/commands/show.go +++ b/go/cmd/dolt/commands/show.go @@ -14,353 +14,353 @@ package commands -import ( - "context" - "fmt" - "regexp" - "strings" - - "github.com/pkg/errors" - - "github.com/dolthub/dolt/go/cmd/dolt/cli" - "github.com/dolthub/dolt/go/cmd/dolt/errhand" - eventsapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi/v1alpha1" - "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" - "github.com/dolthub/dolt/go/libraries/doltcore/env" - "github.com/dolthub/dolt/go/libraries/utils/argparser" - "github.com/dolthub/dolt/go/store/datas" - "github.com/dolthub/dolt/go/store/hash" - "github.com/dolthub/dolt/go/store/util/outputpager" -) - -var hashRegex = regexp.MustCompile(`^#?[0-9a-v]{32}$`) - -type showOpts struct { - showParents bool - pretty bool - decoration string - specRefs []string - - *diffDisplaySettings -} - -var showDocs = cli.CommandDocumentationContent{ - ShortDesc: `Show information about a specific commit`, - LongDesc: `Show information about a specific commit`, - Synopsis: []string{ - `[{{.LessThan}}revision{{.GreaterThan}}]`, - }, -} - -type ShowCmd struct{} - -// Name returns the name of the Dolt cli command. This is what is used on the command line to invoke the command -func (cmd ShowCmd) Name() string { - return "show" -} - -// Description returns a description of the command -func (cmd ShowCmd) Description() string { - return "Show information about a specific commit." -} - -// EventType returns the type of the event to log -func (cmd ShowCmd) EventType() eventsapi.ClientEventType { - return eventsapi.ClientEventType_SHOW -} - -func (cmd ShowCmd) Docs() *cli.CommandDocumentation { - ap := cmd.ArgParser() - return cli.NewCommandDocumentation(showDocs, ap) -} - -func (cmd ShowCmd) ArgParser() *argparser.ArgParser { - ap := argparser.NewArgParserWithVariableArgs(cmd.Name()) - // Flags inherited from Log - ap.SupportsFlag(cli.ParentsFlag, "", "Shows all parents of each commit in the log.") - ap.SupportsString(cli.DecorateFlag, "", "decorate_fmt", "Shows refs next to commits. Valid options are short, full, no, and auto") - ap.SupportsFlag(cli.NoPrettyFlag, "", "Show the object without making it pretty.") - - // Flags inherited from Diff - ap.SupportsFlag(DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).") - ap.SupportsFlag(SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).") - ap.SupportsFlag(StatFlag, "", "Show stats of data changes") - ap.SupportsFlag(SummaryFlag, "", "Show summary of data and schema changes") - ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.") - ap.SupportsString(whereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.") - ap.SupportsInt(limitParam, "", "record_count", "limits to the first N diffs.") - ap.SupportsFlag(cli.CachedFlag, "c", "Show only the staged data changes.") - ap.SupportsFlag(SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.") - ap.SupportsFlag(MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit") - ap.SupportsString(DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.") - return ap -} - -// Exec executes the command -func (cmd ShowCmd) Exec(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEnv, cliCtx cli.CliContext) int { - ap := cmd.ArgParser() - help, usage := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString(commandStr, showDocs, ap)) - apr := cli.ParseArgsOrDie(ap, args, help) - - opts, err := parseShowArgs(ctx, dEnv, apr) - if err != nil { - return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) - } - - if err := cmd.validateArgs(apr); err != nil { - return handleErrAndExit(err) - } - - if !opts.pretty && !dEnv.DoltDB.Format().UsesFlatbuffers() { - cli.PrintErrln("dolt show --no-pretty is not supported when using old LD_1 storage format.") - return 1 - } - - opts.diffDisplaySettings = parseDiffDisplaySettings(ctx, dEnv, apr) - - err = showObjects(ctx, dEnv, opts) - - return handleErrAndExit(err) -} - -func (cmd ShowCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseError { - if apr.Contains(StatFlag) || apr.Contains(SummaryFlag) { - if apr.Contains(SchemaFlag) || apr.Contains(DataFlag) { - return errhand.BuildDError("invalid Arguments: --stat and --summary cannot be combined with --schema or --data").Build() - } - } - - f, _ := apr.GetValue(FormatFlag) - switch strings.ToLower(f) { - case "tabular", "sql", "json", "": - default: - return errhand.BuildDError("invalid output format: %s", f).Build() - } - - return nil -} - -func parseShowArgs(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults) (*showOpts, error) { - - decorateOption := apr.GetValueOrDefault(cli.DecorateFlag, "auto") - switch decorateOption { - case "short", "full", "auto", "no": - default: - return nil, fmt.Errorf("fatal: invalid --decorate option: %s", decorateOption) - } - - return &showOpts{ - showParents: apr.Contains(cli.ParentsFlag), - pretty: !apr.Contains(cli.NoPrettyFlag), - decoration: decorateOption, - specRefs: apr.Args, - }, nil -} - -func showObjects(ctx context.Context, dEnv *env.DoltEnv, opts *showOpts) error { - if len(opts.specRefs) == 0 { - headRef, err := dEnv.RepoStateReader().CWBHeadSpec() - if err != nil { - return err - } - return showCommitSpec(ctx, dEnv, opts, headRef) - } - - for _, specRef := range opts.specRefs { - err := showSpecRef(ctx, dEnv, opts, specRef) - if err != nil { - return err - } - } - - return nil -} - -// parseHashString converts a string representing a hash into a hash.Hash. -func parseHashString(hashStr string) (hash.Hash, error) { - unprefixed := strings.TrimPrefix(hashStr, "#") - parsedHash, ok := hash.MaybeParse(unprefixed) - if !ok { - return hash.Hash{}, errors.New("invalid hash: " + hashStr) - } - return parsedHash, nil -} - -func showSpecRef(ctx context.Context, dEnv *env.DoltEnv, opts *showOpts, specRef string) error { - roots, err := dEnv.Roots(ctx) - if err != nil { - return err - } - - upperCaseSpecRef := strings.ToUpper(specRef) - if upperCaseSpecRef == doltdb.Working || upperCaseSpecRef == doltdb.Staged || hashRegex.MatchString(specRef) { - var refHash hash.Hash - var err error - if upperCaseSpecRef == doltdb.Working { - refHash, err = roots.Working.HashOf() - } else if upperCaseSpecRef == doltdb.Staged { - refHash, err = roots.Staged.HashOf() - } else { - refHash, err = parseHashString(specRef) - } - if err != nil { - return err - } - value, err := dEnv.DoltDB.ValueReadWriter().ReadValue(ctx, refHash) - if err != nil { - return err - } - if value == nil { - return fmt.Errorf("Unable to resolve object ref %s", specRef) - } - - if !opts.pretty { - cli.Println(value.Kind(), value.HumanReadableString()) - } - - // If this is a commit, use the pretty printer. To determine whether it's a commit, try calling NewCommitFromValue. - commit, err := doltdb.NewCommitFromValue(ctx, dEnv.DoltDB.ValueReadWriter(), dEnv.DoltDB.NodeStore(), value) - - if err == datas.ErrNotACommit { - if !dEnv.DoltDB.Format().UsesFlatbuffers() { - return fmt.Errorf("dolt show cannot show non-commit objects when using the old LD_1 storage format: %s is not a commit", specRef) - } - cli.Println(value.Kind(), value.HumanReadableString()) - } else if err == nil { - showCommit(ctx, dEnv, opts, commit) - } else { - return err - } - } else { // specRef is a CommitSpec, which must resolve to a Commit. - commitSpec, err := getCommitSpec(specRef) - if err != nil { - return err - } - - err = showCommitSpec(ctx, dEnv, opts, commitSpec) - if err != nil { - return err - } - } - return nil -} - -func showCommitSpec(ctx context.Context, dEnv *env.DoltEnv, opts *showOpts, commitSpec *doltdb.CommitSpec) error { - - headRef, err := dEnv.RepoStateReader().CWBHeadRef() - if err != nil { - return err - } - - commit, err := dEnv.DoltDB.Resolve(ctx, commitSpec, headRef) - if err != nil { - return err - } - - if opts.pretty { - err = showCommit(ctx, dEnv, opts, commit) - if err != nil { - return err - } - } else { - value := commit.Value() - cli.Println(value.Kind(), value.HumanReadableString()) - } - return nil -} - -func showCommit(ctx context.Context, dEnv *env.DoltEnv, opts *showOpts, comm *doltdb.Commit) error { - - cHashToRefs, err := getHashToRefs(ctx, dEnv, opts.decoration) - if err != nil { - return err - } - - meta, mErr := comm.GetCommitMeta(ctx) - if mErr != nil { - cli.PrintErrln("error: failed to get commit metadata") - return err - } - pHashes, pErr := comm.ParentHashes(ctx) - if pErr != nil { - cli.PrintErrln("error: failed to get parent hashes") - return err - } - cmHash, cErr := comm.HashOf() - if cErr != nil { - cli.PrintErrln("error: failed to get commit hash") - return err - } - - headRef, err := dEnv.RepoStateReader().CWBHeadRef() - if err != nil { - return err - } - cwbHash, err := dEnv.DoltDB.GetHashForRefStr(ctx, headRef.String()) - if err != nil { - return err - } - - cli.ExecuteWithStdioRestored(func() { - pager := outputpager.Start() - defer pager.Stop() - - PrintCommit(pager, 0, opts.showParents, opts.decoration, logNode{ - commitMeta: meta, - commitHash: cmHash, - parentHashes: pHashes, - branchNames: cHashToRefs[cmHash], - isHead: cmHash == *cwbHash}) - }) - - if comm.NumParents() == 0 { - return nil - } - - if comm.NumParents() > 1 { - return fmt.Errorf("requested commit is a merge commit. 'dolt show' currently only supports viewing non-merge commits") - } - - commitRoot, err := comm.GetRootValue(ctx) - if err != nil { - return err - } - - parent, err := comm.GetParent(ctx, 0) - if err != nil { - return err - } - - parentRoot, err := parent.GetRootValue(ctx) - if err != nil { - return err - } - - parentHash, err := parent.HashOf() - if err != nil { - return err - } - - datasets := &diffDatasets{ - fromRoot: parentRoot, - toRoot: commitRoot, - fromRef: parentHash.String(), - toRef: cmHash.String(), - } - - // An empty string will cause all tables to be printed. - var tableNames []string - - tableSet, err := parseDiffTableSet(ctx, dEnv, datasets, tableNames) - if err != nil { - return err - } - - dArgs := &diffArgs{ - diffDisplaySettings: opts.diffDisplaySettings, - diffDatasets: datasets, - tableSet: tableSet, - } - - return diffUserTables(ctx, dEnv, dArgs) -} +//import ( +// "context" +// "fmt" +// "regexp" +// "strings" +// +// "github.com/pkg/errors" +// +// "github.com/dolthub/dolt/go/cmd/dolt/cli" +// "github.com/dolthub/dolt/go/cmd/dolt/errhand" +// eventsapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi/v1alpha1" +// "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" +// "github.com/dolthub/dolt/go/libraries/doltcore/env" +// "github.com/dolthub/dolt/go/libraries/utils/argparser" +// "github.com/dolthub/dolt/go/store/datas" +// "github.com/dolthub/dolt/go/store/hash" +// "github.com/dolthub/dolt/go/store/util/outputpager" +//) +// +//var hashRegex = regexp.MustCompile(`^#?[0-9a-v]{32}$`) +// +//type showOpts struct { +// showParents bool +// pretty bool +// decoration string +// specRefs []string +// +// *diffDisplaySettings +//} +// +//var showDocs = cli.CommandDocumentationContent{ +// ShortDesc: `Show information about a specific commit`, +// LongDesc: `Show information about a specific commit`, +// Synopsis: []string{ +// `[{{.LessThan}}revision{{.GreaterThan}}]`, +// }, +//} +// +//type ShowCmd struct{} +// +//// Name returns the name of the Dolt cli command. This is what is used on the command line to invoke the command +//func (cmd ShowCmd) Name() string { +// return "show" +//} +// +//// Description returns a description of the command +//func (cmd ShowCmd) Description() string { +// return "Show information about a specific commit." +//} +// +//// EventType returns the type of the event to log +//func (cmd ShowCmd) EventType() eventsapi.ClientEventType { +// return eventsapi.ClientEventType_SHOW +//} +// +//func (cmd ShowCmd) Docs() *cli.CommandDocumentation { +// ap := cmd.ArgParser() +// return cli.NewCommandDocumentation(showDocs, ap) +//} +// +//func (cmd ShowCmd) ArgParser() *argparser.ArgParser { +// ap := argparser.NewArgParserWithVariableArgs(cmd.Name()) +// // Flags inherited from Log +// ap.SupportsFlag(cli.ParentsFlag, "", "Shows all parents of each commit in the log.") +// ap.SupportsString(cli.DecorateFlag, "", "decorate_fmt", "Shows refs next to commits. Valid options are short, full, no, and auto") +// ap.SupportsFlag(cli.NoPrettyFlag, "", "Show the object without making it pretty.") +// +// // Flags inherited from Diff +// ap.SupportsFlag(DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).") +// ap.SupportsFlag(SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).") +// ap.SupportsFlag(StatFlag, "", "Show stats of data changes") +// ap.SupportsFlag(SummaryFlag, "", "Show summary of data and schema changes") +// ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.") +// ap.SupportsString(whereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.") +// ap.SupportsInt(limitParam, "", "record_count", "limits to the first N diffs.") +// ap.SupportsFlag(cli.CachedFlag, "c", "Show only the staged data changes.") +// ap.SupportsFlag(SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.") +// ap.SupportsFlag(MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit") +// ap.SupportsString(DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.") +// return ap +//} +// +//// Exec executes the command +//func (cmd ShowCmd) Exec(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEnv, cliCtx cli.CliContext) int { +// ap := cmd.ArgParser() +// help, usage := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString(commandStr, showDocs, ap)) +// apr := cli.ParseArgsOrDie(ap, args, help) +// +// opts, err := parseShowArgs(ctx, dEnv, apr) +// if err != nil { +// return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) +// } +// +// if err := cmd.validateArgs(apr); err != nil { +// return handleErrAndExit(err) +// } +// +// if !opts.pretty && !dEnv.DoltDB.Format().UsesFlatbuffers() { +// cli.PrintErrln("dolt show --no-pretty is not supported when using old LD_1 storage format.") +// return 1 +// } +// +// opts.diffDisplaySettings = parseDiffDisplaySettings(ctx, dEnv, apr) +// +// err = showObjects(ctx, dEnv, opts) +// +// return handleErrAndExit(err) +//} +// +//func (cmd ShowCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseError { +// if apr.Contains(StatFlag) || apr.Contains(SummaryFlag) { +// if apr.Contains(SchemaFlag) || apr.Contains(DataFlag) { +// return errhand.BuildDError("invalid Arguments: --stat and --summary cannot be combined with --schema or --data").Build() +// } +// } +// +// f, _ := apr.GetValue(FormatFlag) +// switch strings.ToLower(f) { +// case "tabular", "sql", "json", "": +// default: +// return errhand.BuildDError("invalid output format: %s", f).Build() +// } +// +// return nil +//} +// +//func parseShowArgs(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults) (*showOpts, error) { +// +// decorateOption := apr.GetValueOrDefault(cli.DecorateFlag, "auto") +// switch decorateOption { +// case "short", "full", "auto", "no": +// default: +// return nil, fmt.Errorf("fatal: invalid --decorate option: %s", decorateOption) +// } +// +// return &showOpts{ +// showParents: apr.Contains(cli.ParentsFlag), +// pretty: !apr.Contains(cli.NoPrettyFlag), +// decoration: decorateOption, +// specRefs: apr.Args, +// }, nil +//} +// +//func showObjects(ctx context.Context, dEnv *env.DoltEnv, opts *showOpts) error { +// if len(opts.specRefs) == 0 { +// headRef, err := dEnv.RepoStateReader().CWBHeadSpec() +// if err != nil { +// return err +// } +// return showCommitSpec(ctx, dEnv, opts, headRef) +// } +// +// for _, specRef := range opts.specRefs { +// err := showSpecRef(ctx, dEnv, opts, specRef) +// if err != nil { +// return err +// } +// } +// +// return nil +//} +// +//// parseHashString converts a string representing a hash into a hash.Hash. +//func parseHashString(hashStr string) (hash.Hash, error) { +// unprefixed := strings.TrimPrefix(hashStr, "#") +// parsedHash, ok := hash.MaybeParse(unprefixed) +// if !ok { +// return hash.Hash{}, errors.New("invalid hash: " + hashStr) +// } +// return parsedHash, nil +//} +// +//func showSpecRef(ctx context.Context, dEnv *env.DoltEnv, opts *showOpts, specRef string) error { +// roots, err := dEnv.Roots(ctx) +// if err != nil { +// return err +// } +// +// upperCaseSpecRef := strings.ToUpper(specRef) +// if upperCaseSpecRef == doltdb.Working || upperCaseSpecRef == doltdb.Staged || hashRegex.MatchString(specRef) { +// var refHash hash.Hash +// var err error +// if upperCaseSpecRef == doltdb.Working { +// refHash, err = roots.Working.HashOf() +// } else if upperCaseSpecRef == doltdb.Staged { +// refHash, err = roots.Staged.HashOf() +// } else { +// refHash, err = parseHashString(specRef) +// } +// if err != nil { +// return err +// } +// value, err := dEnv.DoltDB.ValueReadWriter().ReadValue(ctx, refHash) +// if err != nil { +// return err +// } +// if value == nil { +// return fmt.Errorf("Unable to resolve object ref %s", specRef) +// } +// +// if !opts.pretty { +// cli.Println(value.Kind(), value.HumanReadableString()) +// } +// +// // If this is a commit, use the pretty printer. To determine whether it's a commit, try calling NewCommitFromValue. +// commit, err := doltdb.NewCommitFromValue(ctx, dEnv.DoltDB.ValueReadWriter(), dEnv.DoltDB.NodeStore(), value) +// +// if err == datas.ErrNotACommit { +// if !dEnv.DoltDB.Format().UsesFlatbuffers() { +// return fmt.Errorf("dolt show cannot show non-commit objects when using the old LD_1 storage format: %s is not a commit", specRef) +// } +// cli.Println(value.Kind(), value.HumanReadableString()) +// } else if err == nil { +// showCommit(ctx, dEnv, opts, commit) +// } else { +// return err +// } +// } else { // specRef is a CommitSpec, which must resolve to a Commit. +// commitSpec, err := getCommitSpec(specRef) +// if err != nil { +// return err +// } +// +// err = showCommitSpec(ctx, dEnv, opts, commitSpec) +// if err != nil { +// return err +// } +// } +// return nil +//} +// +//func showCommitSpec(ctx context.Context, dEnv *env.DoltEnv, opts *showOpts, commitSpec *doltdb.CommitSpec) error { +// +// headRef, err := dEnv.RepoStateReader().CWBHeadRef() +// if err != nil { +// return err +// } +// +// commit, err := dEnv.DoltDB.Resolve(ctx, commitSpec, headRef) +// if err != nil { +// return err +// } +// +// if opts.pretty { +// err = showCommit(ctx, dEnv, opts, commit) +// if err != nil { +// return err +// } +// } else { +// value := commit.Value() +// cli.Println(value.Kind(), value.HumanReadableString()) +// } +// return nil +//} +// +//func showCommit(ctx context.Context, dEnv *env.DoltEnv, opts *showOpts, comm *doltdb.Commit) error { +// +// cHashToRefs, err := getHashToRefs(ctx, dEnv, opts.decoration) +// if err != nil { +// return err +// } +// +// meta, mErr := comm.GetCommitMeta(ctx) +// if mErr != nil { +// cli.PrintErrln("error: failed to get commit metadata") +// return err +// } +// pHashes, pErr := comm.ParentHashes(ctx) +// if pErr != nil { +// cli.PrintErrln("error: failed to get parent hashes") +// return err +// } +// cmHash, cErr := comm.HashOf() +// if cErr != nil { +// cli.PrintErrln("error: failed to get commit hash") +// return err +// } +// +// headRef, err := dEnv.RepoStateReader().CWBHeadRef() +// if err != nil { +// return err +// } +// cwbHash, err := dEnv.DoltDB.GetHashForRefStr(ctx, headRef.String()) +// if err != nil { +// return err +// } +// +// cli.ExecuteWithStdioRestored(func() { +// pager := outputpager.Start() +// defer pager.Stop() +// +// PrintCommit(pager, 0, opts.showParents, opts.decoration, logNode{ +// commitMeta: meta, +// commitHash: cmHash, +// parentHashes: pHashes, +// branchNames: cHashToRefs[cmHash], +// isHead: cmHash == *cwbHash}) +// }) +// +// if comm.NumParents() == 0 { +// return nil +// } +// +// if comm.NumParents() > 1 { +// return fmt.Errorf("requested commit is a merge commit. 'dolt show' currently only supports viewing non-merge commits") +// } +// +// commitRoot, err := comm.GetRootValue(ctx) +// if err != nil { +// return err +// } +// +// parent, err := comm.GetParent(ctx, 0) +// if err != nil { +// return err +// } +// +// parentRoot, err := parent.GetRootValue(ctx) +// if err != nil { +// return err +// } +// +// parentHash, err := parent.HashOf() +// if err != nil { +// return err +// } +// +// datasets := &diffDatasets{ +// fromRoot: parentRoot, +// toRoot: commitRoot, +// fromRef: parentHash.String(), +// toRef: cmHash.String(), +// } +// +// // An empty string will cause all tables to be printed. +// var tableNames []string +// +// tableSet, err := parseDiffTableSet(ctx, dEnv, datasets, tableNames) +// if err != nil { +// return err +// } +// +// dArgs := &diffArgs{ +// diffDisplaySettings: opts.diffDisplaySettings, +// diffDatasets: datasets, +// tableSet: tableSet, +// } +// +// return diffUserTables(ctx, dEnv, dArgs) +//} diff --git a/go/cmd/dolt/dolt.go b/go/cmd/dolt/dolt.go index cc45cf99ac..99beec99cc 100644 --- a/go/cmd/dolt/dolt.go +++ b/go/cmd/dolt/dolt.go @@ -80,7 +80,7 @@ var doltSubCommands = []cli.Command{ sqlserver.SqlServerCmd{VersionStr: Version}, sqlserver.SqlClientCmd{VersionStr: Version}, commands.LogCmd{}, - commands.ShowCmd{}, + //commands.ShowCmd{}, commands.BranchCmd{}, commands.CheckoutCmd{}, commands.MergeCmd{}, @@ -121,14 +121,12 @@ var doltSubCommands = []cli.Command{ } var commandsWithoutCliCtx = []cli.Command{ - commands.DiffCmd{}, commands.ResetCmd{}, commands.CleanCmd{}, admin.Commands, sqlserver.SqlServerCmd{VersionStr: Version}, sqlserver.SqlClientCmd{VersionStr: Version}, commands.LogCmd{}, - commands.ShowCmd{}, commands.CheckoutCmd{}, cnfcmds.Commands, commands.CloneCmd{}, diff --git a/go/libraries/doltcore/diff/diff.go b/go/libraries/doltcore/diff/diff.go index 4bbc1173f8..6fb7d78181 100755 --- a/go/libraries/doltcore/diff/diff.go +++ b/go/libraries/doltcore/diff/diff.go @@ -85,7 +85,7 @@ type SqlRowDiffWriter interface { type SchemaDiffWriter interface { // WriteSchemaDiff writes the schema diff given (a SQL statement) and returns any error. A single table may have // many SQL statements for a single diff. WriteSchemaDiff will be called before any row diffs via |WriteRow| - WriteSchemaDiff(ctx context.Context, schemaDiffStatement string) error + WriteSchemaDiff(schemaDiffStatement string) error // Close finalizes the work of this writer. Close(ctx context.Context) error } diff --git a/go/libraries/doltcore/diff/table_deltas.go b/go/libraries/doltcore/diff/table_deltas.go index 0bc6337e2f..fcf4b145cc 100644 --- a/go/libraries/doltcore/diff/table_deltas.go +++ b/go/libraries/doltcore/diff/table_deltas.go @@ -21,12 +21,10 @@ import ( "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/dolt/go/cmd/dolt/errhand" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable" "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt" - "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" "github.com/dolthub/dolt/go/libraries/utils/set" "github.com/dolthub/dolt/go/store/prolly/tree" "github.com/dolthub/dolt/go/store/types" @@ -41,6 +39,11 @@ const ( RemovedTable ) +type TableInfo struct { + Name string + Sch schema.Schema + CreateStmt string +} // TableDelta represents the change of a single table between two roots. // FromFKs and ToFKs contain Foreign Keys that constrain columns in this table, // they do not contain Foreign Keys that reference this table. @@ -68,6 +71,24 @@ type TableDeltaSummary struct { TableName string FromTableName string ToTableName string + AlterStmts []string +} +// IsAdd returns true if the table was added between the fromRoot and toRoot. +func (tds TableDeltaSummary) IsAdd() bool { + return tds.FromTableName == "" && tds.ToTableName != "" +} + +// IsDrop returns true if the table was dropped between the fromRoot and toRoot. +func (tds TableDeltaSummary) IsDrop() bool { + return tds.FromTableName != "" && tds.ToTableName == "" +} + +// IsRename return true if the table was renamed between the fromRoot and toRoot. +func (tds TableDeltaSummary) IsRename() bool { + if tds.IsAdd() || tds.IsDrop() { + return false + } + return tds.FromTableName != tds.ToTableName } // GetStagedUnstagedTableDeltas represents staged and unstaged changes as TableDelta slices. @@ -562,27 +583,15 @@ func (td TableDelta) GetRowData(ctx context.Context) (from, to durable.Index, er // SqlSchemaDiff returns a slice of DDL statements that will transform the schema in the from delta to the schema in // the to delta. -func SqlSchemaDiff(ctx context.Context, td TableDelta, toSchemas map[string]schema.Schema) ([]string, error) { - fromSch, toSch, err := td.GetSchemas(ctx) - if err != nil { - return nil, fmt.Errorf("cannot retrieve schema for table %s, cause: %s", td.ToName, err.Error()) - } +func SqlSchemaDiff(fromTableInfo, toTableInfo *TableInfo, tds TableDeltaSummary) ([]string, error) { var ddlStatements []string - if td.IsDrop() { - ddlStatements = append(ddlStatements, sqlfmt.DropTableStmt(td.FromName)) - } else if td.IsAdd() { - toPkSch, err := sqlutil.FromDoltSchema(td.ToName, td.ToSch) - if err != nil { - return nil, err - } - stmt, err := GenerateCreateTableStatement(td.ToName, td.ToSch, toPkSch, td.ToFks, td.ToFksParentSch) - if err != nil { - return nil, errhand.VerboseErrorFromError(err) - } - ddlStatements = append(ddlStatements, stmt) + if tds.IsDrop() { + ddlStatements = append(ddlStatements, sqlfmt.DropTableStmt(tds.FromTableName)) + } else if tds.IsAdd() { + ddlStatements = append(ddlStatements, toTableInfo.CreateStmt) } else { - stmts, err := GetNonCreateNonDropTableSqlSchemaDiff(td, toSchemas, fromSch, toSch) + stmts, err := GetNonCreateNonDropTableSqlSchemaDiff(tds, fromTableInfo, toTableInfo) if err != nil { return nil, err } @@ -593,47 +602,32 @@ func SqlSchemaDiff(ctx context.Context, td TableDelta, toSchemas map[string]sche } // GetNonCreateNonDropTableSqlSchemaDiff returns any schema diff in SQL statements that is NEITHER 'CREATE TABLE' NOR 'DROP TABLE' statements. -func GetNonCreateNonDropTableSqlSchemaDiff(td TableDelta, toSchemas map[string]schema.Schema, fromSch, toSch schema.Schema) ([]string, error) { - if td.IsAdd() || td.IsDrop() { +func GetNonCreateNonDropTableSqlSchemaDiff(tds TableDeltaSummary, fromTableInfo, toTableInfo *TableInfo) ([]string, error) { + if tds.IsAdd() || tds.IsDrop() { // use add and drop specific methods return nil, nil } var ddlStatements []string - if td.FromName != td.ToName { - ddlStatements = append(ddlStatements, sqlfmt.RenameTableStmt(td.FromName, td.ToName)) + if tds.FromTableName != tds.ToTableName { + ddlStatements = append(ddlStatements, sqlfmt.RenameTableStmt(tds.FromTableName, tds.ToTableName)) } + fromSch := fromTableInfo.Sch + toSch := toTableInfo.Sch + eq := schema.SchemasAreEqual(fromSch, toSch) - if eq && !td.HasFKChanges() { + if eq && !hasFkChanges(fromTableInfo, toTableInfo) { return ddlStatements, nil } - colDiffs, unionTags := DiffSchColumns(fromSch, toSch) - for _, tag := range unionTags { - cd := colDiffs[tag] - switch cd.DiffType { - case SchDiffNone: - case SchDiffAdded: - ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddColStmt(td.ToName, sqlfmt.GenerateCreateTableColumnDefinition(*cd.New))) - case SchDiffRemoved: - ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropColStmt(td.ToName, cd.Old.Name)) - case SchDiffModified: - // Ignore any primary key set changes here - if cd.Old.IsPartOfPK != cd.New.IsPartOfPK { - continue - } - if cd.Old.Name != cd.New.Name { - ddlStatements = append(ddlStatements, sqlfmt.AlterTableRenameColStmt(td.ToName, cd.Old.Name, cd.New.Name)) - } - } - } + ddlStatements = append(ddlStatements, tds.AlterStmts...) // Print changes between a primary key set change. It contains an ALTER TABLE DROP and an ALTER TABLE ADD if !schema.ColCollsAreEqual(fromSch.GetPKCols(), toSch.GetPKCols()) { - ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropPks(td.ToName)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropPks(tds.ToTableName)) if toSch.GetPKCols().Size() > 0 { - ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddPrimaryKeys(td.ToName, toSch.GetPKCols())) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddPrimaryKeys(tds.ToTableName, toSch.GetPKCols().GetColumnNames())) } } @@ -641,28 +635,27 @@ func GetNonCreateNonDropTableSqlSchemaDiff(td TableDelta, toSchemas map[string]s switch idxDiff.DiffType { case SchDiffNone: case SchDiffAdded: - ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(tds.ToTableName, idxDiff.To)) case SchDiffRemoved: - ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(tds.FromTableName, idxDiff.From)) case SchDiffModified: - ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From)) - ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(tds.FromTableName, idxDiff.From)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(tds.ToTableName, idxDiff.To)) } } - for _, fkDiff := range DiffForeignKeys(td.FromFks, td.ToFks) { + for _, fkDiff := range DiffForeignKeyInfos(fromTableInfo.Fks, toTableInfo.Fks) { switch fkDiff.DiffType { case SchDiffNone: case SchDiffAdded: - parentSch := toSchemas[fkDiff.To.ReferencedTableName] - ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) + to := fkDiff.To + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmtSimple(to.TableName, to.Name, to.ReferencedTableName, to.TableColumns, to.ReferencedTableColumns)) case SchDiffRemoved: - ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(fkDiff.From)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(fkDiff.From.TableName, fkDiff.From.Name)) case SchDiffModified: - ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(fkDiff.From)) - - parentSch := toSchemas[fkDiff.To.ReferencedTableName] - ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(fkDiff.From.TableName, fkDiff.From.Name)) + to := fkDiff.To + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmtSimple(to.TableName, to.Name, to.ReferencedTableName, to.TableColumns, to.ReferencedTableColumns)) } } diff --git a/go/libraries/doltcore/sqle/dolt_patch_table_function.go b/go/libraries/doltcore/sqle/dolt_patch_table_function.go index d87155f79d..cb41499d48 100644 --- a/go/libraries/doltcore/sqle/dolt_patch_table_function.go +++ b/go/libraries/doltcore/sqle/dolt_patch_table_function.go @@ -400,7 +400,7 @@ func getSchemaSqlPatch(ctx *sql.Context, toRoot *doltdb.RootValue, td diff.Table } ddlStatements = append(ddlStatements, stmt) } else { - stmts, err := diff.GetNonCreateNonDropTableSqlSchemaDiff(td, toSchemas, fromSch, toSch) + stmts, err := GetNonCreateNonDropTableSqlSchemaDiff(td, toSchemas, fromSch, toSch) if err != nil { return nil, err } @@ -498,6 +498,85 @@ func getDataSqlPatchResults(ctx *sql.Context, diffQuerySch, targetSch sql.Schema } } +// GetNonCreateNonDropTableSqlSchemaDiff returns any schema diff in SQL statements that is NEITHER 'CREATE TABLE' NOR 'DROP TABLE' statements. +func GetNonCreateNonDropTableSqlSchemaDiff(td diff.TableDelta, toSchemas map[string]schema.Schema, fromSch, toSch schema.Schema) ([]string, error) { + if td.IsAdd() || td.IsDrop() { + // use add and drop specific methods + return nil, nil + } + + var ddlStatements []string + if td.FromName != td.ToName { + ddlStatements = append(ddlStatements, sqlfmt.RenameTableStmt(td.FromName, td.ToName)) + } + + eq := schema.SchemasAreEqual(fromSch, toSch) + if eq && !td.HasFKChanges() { + return ddlStatements, nil + } + + colDiffs, unionTags := diff.DiffSchColumns(fromSch, toSch) + for _, tag := range unionTags { + cd := colDiffs[tag] + switch cd.DiffType { + case diff.SchDiffNone: + case diff.SchDiffAdded: + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddColStmt(td.ToName, sqlfmt.GenerateCreateTableColumnDefinition(*cd.New))) + case diff.SchDiffRemoved: + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropColStmt(td.ToName, cd.Old.Name)) + case diff.SchDiffModified: + // Ignore any primary key set changes here + if cd.Old.IsPartOfPK != cd.New.IsPartOfPK { + continue + } + if cd.Old.Name != cd.New.Name { + ddlStatements = append(ddlStatements, sqlfmt.AlterTableRenameColStmt(td.ToName, cd.Old.Name, cd.New.Name)) + } + } + } + + // Print changes between a primary key set change. It contains an ALTER TABLE DROP and an ALTER TABLE ADD + if !schema.ColCollsAreEqual(fromSch.GetPKCols(), toSch.GetPKCols()) { + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropPks(td.ToName)) + if toSch.GetPKCols().Size() > 0 { + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddPrimaryKeys(td.ToName, toSch.GetPKCols().GetColumnNames())) + } + } + + for _, idxDiff := range diff.DiffSchIndexes(fromSch, toSch) { + switch idxDiff.DiffType { + case diff.SchDiffNone: + case diff.SchDiffAdded: + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To)) + case diff.SchDiffRemoved: + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From)) + case diff.SchDiffModified: + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To)) + } + } + + for _, fkDiff := range diff.DiffForeignKeys(td.FromFks, td.ToFks) { + switch fkDiff.DiffType { + case diff.SchDiffNone: + case diff.SchDiffAdded: + parentSch := toSchemas[fkDiff.To.ReferencedTableName] + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) + case diff.SchDiffRemoved: + from := fkDiff.From + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(from.TableName, from.Name)) + case diff.SchDiffModified: + from := fkDiff.From + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(from.TableName, from.Name)) + + parentSch := toSchemas[fkDiff.To.ReferencedTableName] + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) + } + } + + return ddlStatements, nil +} + // getDiffQuery returns diff schema for specified columns and array of sql.Expression as projection to be used // on diff table function row iter. This function attempts to imitate running a query // fmt.Sprintf("select %s, %s from dolt_diff('%s', '%s', '%s')", columnsWithDiff, "diff_type", fromRef, toRef, tableName) diff --git a/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go b/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go index 9e65d84174..80b0a9c42c 100644 --- a/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go +++ b/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go @@ -148,17 +148,17 @@ func AlterTableDropPks(tableName string) string { return b.String() } -func AlterTableAddPrimaryKeys(tableName string, pks *schema.ColCollection) string { +func AlterTableAddPrimaryKeys(tableName string, pkColNames []string) string { var b strings.Builder b.WriteString("ALTER TABLE ") b.WriteString(QuoteIdentifier(tableName)) b.WriteString(" ADD PRIMARY KEY (") - for i := 0; i < pks.Size(); i++ { + for i := 0; i < len(pkColNames); i++ { if i == 0 { - b.WriteString(pks.GetByIndex(i).Name) + b.WriteString(pkColNames[i]) } else { - b.WriteString("," + pks.GetByIndex(i).Name) + b.WriteString("," + pkColNames[i]) } } b.WriteRune(')') @@ -225,12 +225,26 @@ func AlterTableAddForeignKeyStmt(fk doltdb.ForeignKey, sch, parentSch schema.Sch return b.String() } -func AlterTableDropForeignKeyStmt(fk doltdb.ForeignKey) string { +func AlterTableAddForeignKeyStmtSimple(tableName, fkName, fkReferencedTableName string, fkTableColumns, fkReferencedTableColumns []string) string { var b strings.Builder b.WriteString("ALTER TABLE ") - b.WriteString(QuoteIdentifier(fk.TableName)) + b.WriteString(QuoteIdentifier(tableName)) + b.WriteString(" ADD CONSTRAINT ") + b.WriteString(QuoteIdentifier(fkName)) + b.WriteString(" FOREIGN KEY ") + b.WriteString("(" + strings.Join(fkTableColumns, ",") + ")") + b.WriteString(" REFERENCES ") + b.WriteString(QuoteIdentifier(fkReferencedTableName)) + b.WriteString(" (" + strings.Join(fkReferencedTableColumns, ",") + ");") + return b.String() +} + +func AlterTableDropForeignKeyStmt(tableName, fkName string) string { + var b strings.Builder + b.WriteString("ALTER TABLE ") + b.WriteString(QuoteIdentifier(tableName)) b.WriteString(" DROP FOREIGN KEY ") - b.WriteString(QuoteIdentifier(fk.Name)) + b.WriteString(QuoteIdentifier(fkName)) b.WriteRune(';') return b.String() } diff --git a/go/libraries/doltcore/table/typed/json/json_diff_writer.go b/go/libraries/doltcore/table/typed/json/json_diff_writer.go index 35ff7f4b11..998838bd46 100755 --- a/go/libraries/doltcore/table/typed/json/json_diff_writer.go +++ b/go/libraries/doltcore/table/typed/json/json_diff_writer.go @@ -166,7 +166,7 @@ func NewSchemaDiffWriter(wr io.WriteCloser) (*SchemaDiffWriter, error) { }, nil } -func (j *SchemaDiffWriter) WriteSchemaDiff(ctx context.Context, schemaDiffStatement string) error { +func (j *SchemaDiffWriter) WriteSchemaDiff(schemaDiffStatement string) error { if j.schemaStmtsWritten > 0 { err := iohelp.WriteAll(j.wr, []byte(",")) if err != nil {