diff --git a/go/cmd/dolt/commands/diff.go b/go/cmd/dolt/commands/diff.go index 4f3ea60489..17d7482aa9 100644 --- a/go/cmd/dolt/commands/diff.go +++ b/go/cmd/dolt/commands/diff.go @@ -22,10 +22,7 @@ import ( "strconv" "strings" - textdiff "github.com/andreyvit/diff" "github.com/dolthub/go-mysql-server/sql" - humanize "github.com/dustin/go-humanize" - "github.com/fatih/color" "github.com/dolthub/dolt/go/cmd/dolt/cli" "github.com/dolthub/dolt/go/cmd/dolt/commands/engine" @@ -34,18 +31,12 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/diff" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" - "github.com/dolthub/dolt/go/libraries/doltcore/row" "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/doltcore/sqle" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" - "github.com/dolthub/dolt/go/libraries/doltcore/table/pipeline" - "github.com/dolthub/dolt/go/libraries/doltcore/table/untyped/sqlexport" - "github.com/dolthub/dolt/go/libraries/doltcore/table/untyped/tabular" "github.com/dolthub/dolt/go/libraries/utils/argparser" - "github.com/dolthub/dolt/go/libraries/utils/iohelp" "github.com/dolthub/dolt/go/libraries/utils/set" - "github.com/dolthub/dolt/go/store/atomicerr" ) type diffOutput int @@ -60,6 +51,7 @@ const ( TabularDiffOutput diffOutput = 1 SQLDiffOutput diffOutput = 2 + JsonDiffOutput diffOutput = 3 DataFlag = "data" SchemaFlag = "schema" @@ -70,12 +62,6 @@ const ( CachedFlag = "cached" ) -type DiffSink interface { - GetSchema() schema.Schema - ProcRowWithProps(r row.Row, props pipeline.ReadableMap) error - Close() error -} - var diffDocs = cli.CommandDocumentationContent{ ShortDesc: "Show changes between commits, commit and working tree, etc", LongDesc: ` @@ -139,7 +125,7 @@ func (cmd DiffCmd) ArgParser() *argparser.ArgParser { ap.SupportsFlag(DataFlag, "d", "Show only the data 
changes, do not show the schema changes (Both shown by default).") ap.SupportsFlag(SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).") ap.SupportsFlag(SummaryFlag, "", "Show summary of data changes") - ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular & sql. Defaults to tabular. ") + ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.") ap.SupportsString(whereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.") ap.SupportsInt(limitParam, "", "record_count", "limits to the first N diffs.") ap.SupportsFlag(CachedFlag, "c", "Show only the unstaged data changes.") @@ -178,7 +164,7 @@ func (cmd DiffCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseE f, _ := apr.GetValue(FormatFlag) switch strings.ToLower(f) { - case "tabular", "sql", "": + case "tabular", "sql", "json", "": default: return errhand.BuildDError("invalid output format: %s", f).Build() } @@ -204,6 +190,8 @@ func parseDiffArgs(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPar dArgs.diffOutput = TabularDiffOutput case "sql": dArgs.diffOutput = SQLDiffOutput + case "json": + dArgs.diffOutput = JsonDiffOutput } dArgs.limit, _ = apr.GetInt(limitParam) @@ -332,7 +320,7 @@ func maybeResolve(ctx context.Context, dEnv *env.DoltEnv, spec string) (*doltdb. 
return root, true } -func diffUserTables(ctx context.Context, dEnv *env.DoltEnv, dArgs *diffArgs) (verr errhand.VerboseError) { +func diffUserTables(ctx context.Context, dEnv *env.DoltEnv, dArgs *diffArgs) errhand.VerboseError { var err error tableDeltas, err := diff.GetTableDeltas(ctx, dArgs.fromRoot, dArgs.toRoot) @@ -348,132 +336,124 @@ func diffUserTables(ctx context.Context, dEnv *env.DoltEnv, dArgs *diffArgs) (ve sort.Slice(tableDeltas, func(i, j int) bool { return strings.Compare(tableDeltas[i].ToName, tableDeltas[j].ToName) < 0 }) + + dw, err := newDiffWriter(dArgs.diffOutput) + if err != nil { + return errhand.VerboseErrorFromError(err) + } + for _, td := range tableDeltas { - if !dArgs.tableSet.Contains(td.FromName) && !dArgs.tableSet.Contains(td.ToName) { - continue - } - - tblName := td.ToName - fromTable := td.FromTable - toTable := td.ToTable - - if fromTable == nil && toTable == nil { - return errhand.BuildDError("error: both tables in tableDelta are nil").Build() - } - - if dArgs.diffOutput == TabularDiffOutput { - printTableDiffSummary(td) - } - - fromSch, toSch, err := td.GetSchemas(ctx) - if err != nil { - return errhand.BuildDError("cannot retrieve schema for table %s", td.ToName).AddCause(err).Build() - } - - if dArgs.diffParts&Summary != 0 { - numCols := fromSch.GetAllCols().Size() - verr = printDiffSummary(ctx, td, numCols) - } - - if dArgs.diffParts&SchemaOnlyDiff != 0 { - verr = diffSchemas(ctx, dArgs.toRoot, td, dArgs) - } - - if dArgs.diffParts&DataOnlyDiff != 0 { - if td.IsDrop() && dArgs.diffOutput == SQLDiffOutput { - continue // don't output DELETE FROM statements after DROP TABLE - } else if td.IsAdd() { - fromSch = toSch - } - - if !schema.ArePrimaryKeySetsDiffable(td.Format(), fromSch, toSch) { - cli.PrintErrf("Primary key sets differ between revisions for table %s, skipping data diff\n", tblName) - continue - } - - verr = diffRows(ctx, engine, td, dArgs) - } - + verr := diffUserTable(ctx, td, engine, dArgs, dw) if verr != nil { 
return verr } } - return nil -} - -func diffSchemas(ctx context.Context, toRoot *doltdb.RootValue, td diff.TableDelta, dArgs *diffArgs) errhand.VerboseError { - toSchemas, err := toRoot.GetAllSchemas(ctx) + err = dw.Close(ctx) if err != nil { - return errhand.BuildDError("could not read schemas from toRoot").AddCause(err).Build() - } - - if dArgs.diffOutput == TabularDiffOutput { - return printShowCreateTableDiff(ctx, td) - } - - return sqlSchemaDiff(ctx, td, toSchemas) -} - -func printShowCreateTableDiff(ctx context.Context, td diff.TableDelta) errhand.VerboseError { - fromSch, toSch, err := td.GetSchemas(ctx) - if err != nil { - return errhand.BuildDError("cannot retrieve schema for table %s", td.ToName).AddCause(err).Build() - } - - var fromCreateStmt = "" - if td.FromTable != nil { - // TODO: use UserSpaceDatabase for these, no reason for this separate database implementation - sqlDb := sqle.NewSingleTableDatabase(td.FromName, fromSch, td.FromFks, td.FromFksParentSch) - sqlCtx, engine, _ := sqle.PrepareCreateTableStmt(ctx, sqlDb) - fromCreateStmt, err = sqle.GetCreateTableStmt(sqlCtx, engine, td.FromName) - if err != nil { - return errhand.VerboseErrorFromError(err) - } - } - - var toCreateStmt = "" - if td.ToTable != nil { - sqlDb := sqle.NewSingleTableDatabase(td.ToName, toSch, td.ToFks, td.ToFksParentSch) - sqlCtx, engine, _ := sqle.PrepareCreateTableStmt(ctx, sqlDb) - toCreateStmt, err = sqle.GetCreateTableStmt(sqlCtx, engine, td.ToName) - if err != nil { - return errhand.VerboseErrorFromError(err) - } - } - - if fromCreateStmt != toCreateStmt { - cli.Println(textdiff.LineDiff(fromCreateStmt, toCreateStmt)) + return errhand.VerboseErrorFromError(err) } return nil } -// TODO: this doesn't handle check constraints or triggers -func sqlSchemaDiff(ctx context.Context, td diff.TableDelta, toSchemas map[string]schema.Schema) errhand.VerboseError { +func diffUserTable( + ctx context.Context, + td diff.TableDelta, + engine *engine.SqlEngine, + dArgs *diffArgs, + dw 
diffWriter, +) errhand.VerboseError { + if !dArgs.tableSet.Contains(td.FromName) && !dArgs.tableSet.Contains(td.ToName) { + return nil + } + + fromTable := td.FromTable + toTable := td.ToTable + + if fromTable == nil && toTable == nil { + return errhand.BuildDError("error: both tables in tableDelta are nil").Build() + } + + err := dw.BeginTable(ctx, td) + if err != nil { + return errhand.VerboseErrorFromError(err) + } + fromSch, toSch, err := td.GetSchemas(ctx) if err != nil { return errhand.BuildDError("cannot retrieve schema for table %s", td.ToName).AddCause(err).Build() } + if dArgs.diffParts&Summary != 0 { + numCols := fromSch.GetAllCols().Size() + return printDiffSummary(ctx, td, numCols) + } + + if dArgs.diffParts&SchemaOnlyDiff != 0 { + err := dw.WriteSchemaDiff(ctx, dArgs.toRoot, td) + if err != nil { + return errhand.VerboseErrorFromError(err) + } + } + + if td.IsDrop() && dArgs.diffOutput == SQLDiffOutput { + return nil // don't output DELETE FROM statements after DROP TABLE + } else if td.IsAdd() { + fromSch = toSch + } + + verr := diffRows(ctx, engine, td, dArgs, dw) + if verr != nil { + return verr + } + + return nil +} + +func writeSqlSchemaDiff(ctx context.Context, td diff.TableDelta, toSchemas map[string]schema.Schema) errhand.VerboseError { + ddlStatements, err := sqlSchemaDiff(ctx, td, toSchemas) + if err != nil { + return err + } + + for _, stmt := range ddlStatements { + cli.Println(stmt) + } + + return nil +} + +// sqlSchemaDiff returns a slice of DDL statements that will transform the schema in the from delta to the schema in +// the to delta. 
+// TODO: this doesn't handle constraints or triggers +// TODO: this should live in the diff package +func sqlSchemaDiff(ctx context.Context, td diff.TableDelta, toSchemas map[string]schema.Schema) ([]string, errhand.VerboseError) { + fromSch, toSch, err := td.GetSchemas(ctx) + if err != nil { + return nil, errhand.BuildDError("cannot retrieve schema for table %s", td.ToName).AddCause(err).Build() + } + + var ddlStatements []string + if td.IsDrop() { - cli.Println(sqlfmt.DropTableStmt(td.FromName)) + ddlStatements = append(ddlStatements, sqlfmt.DropTableStmt(td.FromName)) } else if td.IsAdd() { sqlDb := sqle.NewSingleTableDatabase(td.ToName, toSch, td.ToFks, td.ToFksParentSch) sqlCtx, engine, _ := sqle.PrepareCreateTableStmt(ctx, sqlDb) stmt, err := sqle.GetCreateTableStmt(sqlCtx, engine, td.ToName) if err != nil { - return errhand.VerboseErrorFromError(err) + return nil, errhand.VerboseErrorFromError(err) } - cli.Println(stmt) + ddlStatements = append(ddlStatements, stmt) } else { if td.FromName != td.ToName { - cli.Println(sqlfmt.RenameTableStmt(td.FromName, td.ToName)) + ddlStatements = append(ddlStatements, sqlfmt.RenameTableStmt(td.FromName, td.ToName)) } eq := schema.SchemasAreEqual(fromSch, toSch) if eq && !td.HasFKChanges() { - return nil + return ddlStatements, nil } colDiffs, unionTags := diff.DiffSchColumns(fromSch, toSch) @@ -482,25 +462,25 @@ func sqlSchemaDiff(ctx context.Context, td diff.TableDelta, toSchemas map[string switch cd.DiffType { case diff.SchDiffNone: case diff.SchDiffAdded: - cli.Println(sqlfmt.AlterTableAddColStmt(td.ToName, sqlfmt.FmtCol(0, 0, 0, *cd.New))) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddColStmt(td.ToName, sqlfmt.FmtCol(0, 0, 0, *cd.New))) case diff.SchDiffRemoved: - cli.Println(sqlfmt.AlterTableDropColStmt(td.ToName, cd.Old.Name)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropColStmt(td.ToName, cd.Old.Name)) case diff.SchDiffModified: // Ignore any primary key set changes here if 
cd.Old.IsPartOfPK != cd.New.IsPartOfPK { continue } if cd.Old.Name != cd.New.Name { - cli.Println(sqlfmt.AlterTableRenameColStmt(td.ToName, cd.Old.Name, cd.New.Name)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableRenameColStmt(td.ToName, cd.Old.Name, cd.New.Name)) } } } // Print changes between a primary key set change. It contains an ALTER TABLE DROP and an ALTER TABLE ADD if !schema.ColCollsAreEqual(fromSch.GetPKCols(), toSch.GetPKCols()) { - cli.Println(sqlfmt.AlterTableDropPks(td.ToName)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropPks(td.ToName)) if toSch.GetPKCols().Size() > 0 { - cli.Println(sqlfmt.AlterTableAddPrimaryKeys(td.ToName, toSch.GetPKCols())) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddPrimaryKeys(td.ToName, toSch.GetPKCols())) } } @@ -508,12 +488,12 @@ func sqlSchemaDiff(ctx context.Context, td diff.TableDelta, toSchemas map[string switch idxDiff.DiffType { case diff.SchDiffNone: case diff.SchDiffAdded: - cli.Println(sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To)) case diff.SchDiffRemoved: - cli.Println(sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From)) case diff.SchDiffModified: - cli.Println(sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From)) - cli.Println(sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To)) } } @@ -522,23 +502,87 @@ func sqlSchemaDiff(ctx context.Context, td diff.TableDelta, toSchemas map[string case diff.SchDiffNone: case diff.SchDiffAdded: parentSch := toSchemas[fkDiff.To.ReferencedTableName] - cli.Println(sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) + 
ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) case diff.SchDiffRemoved: - cli.Println(sqlfmt.AlterTableDropForeignKeyStmt(fkDiff.From)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(fkDiff.From)) case diff.SchDiffModified: - cli.Println(sqlfmt.AlterTableDropForeignKeyStmt(fkDiff.From)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(fkDiff.From)) parentSch := toSchemas[fkDiff.To.ReferencedTableName] - cli.Println(sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) + ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) } } } - return nil + + return ddlStatements, nil } -func diffRows(ctx context.Context, se *engine.SqlEngine, td diff.TableDelta, dArgs *diffArgs) errhand.VerboseError { +func diffRows( + ctx context.Context, + se *engine.SqlEngine, + td diff.TableDelta, + dArgs *diffArgs, + dw diffWriter, +) errhand.VerboseError { from, to := dArgs.fromRef, dArgs.toRef + diffable := schema.ArePrimaryKeySetsDiffable(td.Format(), td.FromSch, td.ToSch) + canSqlDiff := !(td.ToSch == nil || (td.FromSch != nil && !schema.SchemasAreEqual(td.FromSch, td.ToSch))) + + var toSch, fromSch sql.Schema + if td.FromSch != nil { + pkSch, err := sqlutil.FromDoltSchema(td.FromName, td.FromSch) + if err != nil { + return errhand.VerboseErrorFromError(err) + } + fromSch = pkSch.Schema + } + + if td.ToSch != nil { + pkSch, err := sqlutil.FromDoltSchema(td.ToName, td.ToSch) + if err != nil { + return errhand.VerboseErrorFromError(err) + } + toSch = pkSch.Schema + } + + unionSch := unionSchemas(fromSch, toSch) + + // We always instantiate a RowWriter in case the diffWriter needs it to close off any work from schema output + rowWriter, err := dw.RowWriter(ctx, td, unionSch) + if err != nil { + return errhand.VerboseErrorFromError(err) + } + + // can't diff + if !diffable { + // TODO: this messes up some structured 
output if the user didn't redirect it + cli.PrintErrf("Primary key sets differ between revisions for table %s, skipping data diff\n", td.ToName) + err := rowWriter.Close(ctx) + if err != nil { + return errhand.VerboseErrorFromError(err) + } + return nil + } else if dArgs.diffOutput == SQLDiffOutput && !canSqlDiff { + // TODO: this is overly broad, we can absolutely do better + _, _ = fmt.Fprintf(cli.CliErr, "Incompatible schema change, skipping data diff\n") + err := rowWriter.Close(ctx) + if err != nil { + return errhand.VerboseErrorFromError(err) + } + return nil + } + + // no data diff requested + if dArgs.diffParts&DataOnlyDiff == 0 { + err := rowWriter.Close(ctx) + if err != nil { + return errhand.VerboseErrorFromError(err) + } + return nil + } + + // do the data diff tableName := td.ToName if len(tableName) == 0 { tableName = td.FromName @@ -569,47 +613,7 @@ func diffRows(ctx context.Context, se *engine.SqlEngine, td diff.TableDelta, dAr defer rowIter.Close(sqlCtx) - var toSch, fromSch sql.Schema - if td.FromSch != nil { - pkSch, err := sqlutil.FromDoltSchema(td.FromName, td.FromSch) - if err != nil { - return errhand.VerboseErrorFromError(err) - } - fromSch = pkSch.Schema - } - - if td.ToSch != nil { - pkSch, err := sqlutil.FromDoltSchema(td.ToName, td.ToSch) - if err != nil { - return errhand.VerboseErrorFromError(err) - } - toSch = pkSch.Schema - } - - unionSch := unionSchemas(fromSch, toSch) - - // In some cases we can't print SQL output diffs - if dArgs.diffOutput == SQLDiffOutput && - (td.ToSch == nil || - (td.FromSch != nil && !schema.SchemasAreEqual(td.FromSch, td.ToSch))) { - _, _ = fmt.Fprintf(cli.CliErr, "Incompatible schema change, skipping data diff\n") - return nil - } - - var diffWriter diff.SqlRowDiffWriter - switch dArgs.diffOutput { - case TabularDiffOutput: - // TODO: default sample size - diffWriter = tabular.NewFixedWidthDiffTableWriter(unionSch, iohelp.NopWrCloser(cli.CliOut), 100) - case SQLDiffOutput: - targetSch := td.ToSch - if 
targetSch == nil { - targetSch = td.FromSch - } - diffWriter = sqlexport.NewSqlDiffWriter(tableName, targetSch, iohelp.NopWrCloser(cli.CliOut)) - } - - err = writeDiffResults(sqlCtx, sch, unionSch, rowIter, diffWriter) + err = writeDiffResults(sqlCtx, sch, unionSch, rowIter, rowWriter) if err != nil { return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build() } @@ -630,6 +634,23 @@ func unionSchemas(s1 sql.Schema, s2 sql.Schema) sql.Schema { return union } +func getColumnNamesString(fromSch, toSch schema.Schema) string { + var cols []string + if fromSch != nil { + fromSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { + cols = append(cols, fmt.Sprintf("`from_%s`", col.Name)) + return false, nil + }) + } + if toSch != nil { + toSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { + cols = append(cols, fmt.Sprintf("`to_%s`", col.Name)) + return false, nil + }) + } + return strings.Join(cols, ",") +} + func writeDiffResults( ctx *sql.Context, diffQuerySch sql.Schema, @@ -670,167 +691,3 @@ func writeDiffResults( } } } - -func getColumnNamesString(fromSch, toSch schema.Schema) string { - var cols []string - if fromSch != nil { - fromSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { - cols = append(cols, fmt.Sprintf("`from_%s`", col.Name)) - return false, nil - }) - } - if toSch != nil { - toSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { - cols = append(cols, fmt.Sprintf("`to_%s`", col.Name)) - return false, nil - }) - } - return strings.Join(cols, ",") -} - -func printDiffLines(bold *color.Color, lines []string) { - for _, line := range lines { - if string(line[0]) == string("+") { - cli.Println(color.GreenString("+ " + line[1:])) - } else if string(line[0]) == ("-") { - cli.Println(color.RedString("- " + line[1:])) - } else { - cli.Println(" " + line) - } - } -} - -func printTableDiffSummary(td diff.TableDelta) 
{ - bold := color.New(color.Bold) - if td.IsDrop() { - _, _ = bold.Printf("diff --dolt a/%s b/%s\n", td.FromName, td.FromName) - _, _ = bold.Println("deleted table") - } else if td.IsAdd() { - _, _ = bold.Printf("diff --dolt a/%s b/%s\n", td.ToName, td.ToName) - _, _ = bold.Println("added table") - } else { - _, _ = bold.Printf("diff --dolt a/%s b/%s\n", td.FromName, td.ToName) - h1, err := td.FromTable.HashOf() - - if err != nil { - panic(err) - } - - _, _ = bold.Printf("--- a/%s @ %s\n", td.FromName, h1.String()) - - h2, err := td.ToTable.HashOf() - - if err != nil { - panic(err) - } - - _, _ = bold.Printf("+++ b/%s @ %s\n", td.ToName, h2.String()) - } -} - -func printDiffSummary(ctx context.Context, td diff.TableDelta, colLen int) errhand.VerboseError { - // todo: use errgroup.Group - ae := atomicerr.New() - ch := make(chan diff.DiffSummaryProgress) - go func() { - defer close(ch) - err := diff.SummaryForTableDelta(ctx, ch, td) - - ae.SetIfError(err) - }() - - acc := diff.DiffSummaryProgress{} - var count int64 - var pos int - eP := cli.NewEphemeralPrinter() - for p := range ch { - if ae.IsSet() { - break - } - - acc.Adds += p.Adds - acc.Removes += p.Removes - acc.Changes += p.Changes - acc.CellChanges += p.CellChanges - acc.NewSize += p.NewSize - acc.OldSize += p.OldSize - - if count%10000 == 0 { - eP.Printf("prev size: %d, new size: %d, adds: %d, deletes: %d, modifications: %d\n", acc.OldSize, acc.NewSize, acc.Adds, acc.Removes, acc.Changes) - eP.Display() - } - - count++ - } - - pos = cli.DeleteAndPrint(pos, "") - - if err := ae.Get(); err != nil { - return errhand.BuildDError("").AddCause(err).Build() - } - - keyless, err := td.IsKeyless(ctx) - if err != nil { - return nil - } - - if (acc.Adds + acc.Removes + acc.Changes) == 0 { - cli.Println("No data changes. 
See schema changes by using -s or --schema.") - return nil - } - - if keyless { - printKeylessSummary(acc) - } else { - printSummary(acc, colLen) - } - - return nil -} - -func printSummary(acc diff.DiffSummaryProgress, colLen int) { - rowsUnmodified := uint64(acc.OldSize - acc.Changes - acc.Removes) - unmodified := pluralize("Row Unmodified", "Rows Unmodified", rowsUnmodified) - insertions := pluralize("Row Added", "Rows Added", acc.Adds) - deletions := pluralize("Row Deleted", "Rows Deleted", acc.Removes) - changes := pluralize("Row Modified", "Rows Modified", acc.Changes) - cellChanges := pluralize("Cell Modified", "Cells Modified", acc.CellChanges) - - oldValues := pluralize("Entry", "Entries", acc.OldSize) - newValues := pluralize("Entry", "Entries", acc.NewSize) - - percentCellsChanged := float64(100*acc.CellChanges) / (float64(acc.OldSize) * float64(colLen)) - - safePercent := func(num, dom uint64) float64 { - // returns +Inf for x/0 where x > 0 - if num == 0 { - return float64(0) - } - return float64(100*num) / (float64(dom)) - } - - cli.Printf("%s (%.2f%%)\n", unmodified, safePercent(rowsUnmodified, acc.OldSize)) - cli.Printf("%s (%.2f%%)\n", insertions, safePercent(acc.Adds, acc.OldSize)) - cli.Printf("%s (%.2f%%)\n", deletions, safePercent(acc.Removes, acc.OldSize)) - cli.Printf("%s (%.2f%%)\n", changes, safePercent(acc.Changes, acc.OldSize)) - cli.Printf("%s (%.2f%%)\n", cellChanges, percentCellsChanged) - cli.Printf("(%s vs %s)\n\n", oldValues, newValues) -} - -func printKeylessSummary(acc diff.DiffSummaryProgress) { - insertions := pluralize("Row Added", "Rows Added", acc.Adds) - deletions := pluralize("Row Deleted", "Rows Deleted", acc.Removes) - - cli.Printf("%s\n", insertions) - cli.Printf("%s\n", deletions) -} - -func pluralize(singular, plural string, n uint64) string { - var noun string - if n != 1 { - noun = plural - } else { - noun = singular - } - return fmt.Sprintf("%s %s", humanize.Comma(int64(n)), noun) -} diff --git 
a/go/cmd/dolt/commands/diff_output.go b/go/cmd/dolt/commands/diff_output.go new file mode 100644 index 0000000000..ff6cf8b7ca --- /dev/null +++ b/go/cmd/dolt/commands/diff_output.go @@ -0,0 +1,385 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package commands + +import ( + "context" + "fmt" + "io" + + textdiff "github.com/andreyvit/diff" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dustin/go-humanize" + "github.com/fatih/color" + + "github.com/dolthub/dolt/go/cmd/dolt/cli" + "github.com/dolthub/dolt/go/cmd/dolt/errhand" + "github.com/dolthub/dolt/go/libraries/doltcore/diff" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" + "github.com/dolthub/dolt/go/libraries/doltcore/table/typed/json" + "github.com/dolthub/dolt/go/libraries/doltcore/table/untyped/sqlexport" + "github.com/dolthub/dolt/go/libraries/doltcore/table/untyped/tabular" + "github.com/dolthub/dolt/go/libraries/utils/iohelp" + "github.com/dolthub/dolt/go/store/atomicerr" +) + +// diffWriter is an interface that lets us write diffs in a variety of output formats +type diffWriter interface { + // BeginTable is called when a new table is about to be written, before any schema or row diffs are written + BeginTable(ctx context.Context, td diff.TableDelta) error + // 
WriteSchemaDiff is called to write a schema diff for the table given (if requested by args) + WriteSchemaDiff(ctx context.Context, toRoot *doltdb.RootValue, td diff.TableDelta) error + // RowWriter returns a row writer for the table delta provided, which will have Close() called on it when rows are + // done being written. + RowWriter(ctx context.Context, td diff.TableDelta, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) + // Close finalizes the work of the writer + Close(ctx context.Context) error +} + +// newDiffWriter returns a diffWriter for the output format given +func newDiffWriter(diffOutput diffOutput) (diffWriter, error) { + switch diffOutput { + case TabularDiffOutput: + return tabularDiffWriter{}, nil + case SQLDiffOutput: + return sqlDiffWriter{}, nil + case JsonDiffOutput: + return newJsonDiffWriter(iohelp.NopWrCloser(cli.CliOut)) + default: + panic(fmt.Sprintf("unexpected diff output: %v", diffOutput)) + } +} + +func printDiffSummary(ctx context.Context, td diff.TableDelta, colLen int) errhand.VerboseError { + // todo: use errgroup.Group + ae := atomicerr.New() + ch := make(chan diff.DiffSummaryProgress) + go func() { + defer close(ch) + err := diff.SummaryForTableDelta(ctx, ch, td) + + ae.SetIfError(err) + }() + + acc := diff.DiffSummaryProgress{} + var count int64 + var pos int + eP := cli.NewEphemeralPrinter() + for p := range ch { + if ae.IsSet() { + break + } + + acc.Adds += p.Adds + acc.Removes += p.Removes + acc.Changes += p.Changes + acc.CellChanges += p.CellChanges + acc.NewSize += p.NewSize + acc.OldSize += p.OldSize + + if count%10000 == 0 { + eP.Printf("prev size: %d, new size: %d, adds: %d, deletes: %d, modifications: %d\n", acc.OldSize, acc.NewSize, acc.Adds, acc.Removes, acc.Changes) + eP.Display() + } + + count++ + } + + pos = cli.DeleteAndPrint(pos, "") + + if err := ae.Get(); err != nil { + return errhand.BuildDError("").AddCause(err).Build() + } + + keyless, err := td.IsKeyless(ctx) + if err != nil { + return nil + } + + if 
(acc.Adds + acc.Removes + acc.Changes) == 0 { + cli.Println("No data changes. See schema changes by using -s or --schema.") + return nil + } + + if keyless { + printKeylessSummary(acc) + } else { + printSummary(acc, colLen) + } + + return nil +} + +func printSummary(acc diff.DiffSummaryProgress, colLen int) { + rowsUnmodified := uint64(acc.OldSize - acc.Changes - acc.Removes) + unmodified := pluralize("Row Unmodified", "Rows Unmodified", rowsUnmodified) + insertions := pluralize("Row Added", "Rows Added", acc.Adds) + deletions := pluralize("Row Deleted", "Rows Deleted", acc.Removes) + changes := pluralize("Row Modified", "Rows Modified", acc.Changes) + cellChanges := pluralize("Cell Modified", "Cells Modified", acc.CellChanges) + + oldValues := pluralize("Entry", "Entries", acc.OldSize) + newValues := pluralize("Entry", "Entries", acc.NewSize) + + percentCellsChanged := float64(100*acc.CellChanges) / (float64(acc.OldSize) * float64(colLen)) + + safePercent := func(num, dom uint64) float64 { + // returns +Inf for x/0 where x > 0 + if num == 0 { + return float64(0) + } + return float64(100*num) / (float64(dom)) + } + + cli.Printf("%s (%.2f%%)\n", unmodified, safePercent(rowsUnmodified, acc.OldSize)) + cli.Printf("%s (%.2f%%)\n", insertions, safePercent(acc.Adds, acc.OldSize)) + cli.Printf("%s (%.2f%%)\n", deletions, safePercent(acc.Removes, acc.OldSize)) + cli.Printf("%s (%.2f%%)\n", changes, safePercent(acc.Changes, acc.OldSize)) + cli.Printf("%s (%.2f%%)\n", cellChanges, percentCellsChanged) + cli.Printf("(%s vs %s)\n\n", oldValues, newValues) +} + +func printKeylessSummary(acc diff.DiffSummaryProgress) { + insertions := pluralize("Row Added", "Rows Added", acc.Adds) + deletions := pluralize("Row Deleted", "Rows Deleted", acc.Removes) + + cli.Printf("%s\n", insertions) + cli.Printf("%s\n", deletions) +} + +func pluralize(singular, plural string, n uint64) string { + var noun string + if n != 1 { + noun = plural + } else { + noun = singular + } + return 
fmt.Sprintf("%s %s", humanize.Comma(int64(n)), noun) +} + +type tabularDiffWriter struct{} + +var _ diffWriter = (*tabularDiffWriter)(nil) + +func (t tabularDiffWriter) Close(ctx context.Context) error { + return nil +} + +func (t tabularDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) error { + bold := color.New(color.Bold) + if td.IsDrop() { + _, _ = bold.Printf("diff --dolt a/%s b/%s\n", td.FromName, td.FromName) + _, _ = bold.Println("deleted table") + } else if td.IsAdd() { + _, _ = bold.Printf("diff --dolt a/%s b/%s\n", td.ToName, td.ToName) + _, _ = bold.Println("added table") + } else { + _, _ = bold.Printf("diff --dolt a/%s b/%s\n", td.FromName, td.ToName) + h1, err := td.FromTable.HashOf() + + if err != nil { + panic(err) + } + + _, _ = bold.Printf("--- a/%s @ %s\n", td.FromName, h1.String()) + + h2, err := td.ToTable.HashOf() + + if err != nil { + panic(err) + } + + _, _ = bold.Printf("+++ b/%s @ %s\n", td.ToName, h2.String()) + } + return nil +} + +func (t tabularDiffWriter) WriteSchemaDiff(ctx context.Context, toRoot *doltdb.RootValue, td diff.TableDelta) error { + fromSch, toSch, err := td.GetSchemas(ctx) + if err != nil { + return errhand.BuildDError("cannot retrieve schema for table %s", td.ToName).AddCause(err).Build() + } + + var fromCreateStmt = "" + if td.FromTable != nil { + // TODO: use UserSpaceDatabase for these, no reason for this separate database implementation + sqlDb := sqle.NewSingleTableDatabase(td.FromName, fromSch, td.FromFks, td.FromFksParentSch) + sqlCtx, engine, _ := sqle.PrepareCreateTableStmt(ctx, sqlDb) + fromCreateStmt, err = sqle.GetCreateTableStmt(sqlCtx, engine, td.FromName) + if err != nil { + return errhand.VerboseErrorFromError(err) + } + } + + var toCreateStmt = "" + if td.ToTable != nil { + sqlDb := sqle.NewSingleTableDatabase(td.ToName, toSch, td.ToFks, td.ToFksParentSch) + sqlCtx, engine, _ := sqle.PrepareCreateTableStmt(ctx, sqlDb) + toCreateStmt, err = sqle.GetCreateTableStmt(sqlCtx, engine, 
td.ToName) + if err != nil { + return errhand.VerboseErrorFromError(err) + } + } + + if fromCreateStmt != toCreateStmt { + cli.Println(textdiff.LineDiff(fromCreateStmt, toCreateStmt)) + } + + return nil +} + +func (t tabularDiffWriter) RowWriter(ctx context.Context, td diff.TableDelta, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) { + return tabular.NewFixedWidthDiffTableWriter(unionSch, iohelp.NopWrCloser(cli.CliOut), 100), nil +} + +type sqlDiffWriter struct{} + +var _ diffWriter = (*sqlDiffWriter)(nil) + +func (s sqlDiffWriter) Close(ctx context.Context) error { + return nil +} + +func (s sqlDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) error { + return nil +} + +func (s sqlDiffWriter) WriteSchemaDiff(ctx context.Context, toRoot *doltdb.RootValue, td diff.TableDelta) error { + toSchemas, err := toRoot.GetAllSchemas(ctx) + if err != nil { + return errhand.BuildDError("could not read schemas from toRoot").AddCause(err).Build() + } + + return writeSqlSchemaDiff(ctx, td, toSchemas) +} + +func (s sqlDiffWriter) RowWriter(ctx context.Context, td diff.TableDelta, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) { + targetSch := td.ToSch + if targetSch == nil { + targetSch = td.FromSch + } + + return sqlexport.NewSqlDiffWriter(td.ToName, targetSch, iohelp.NopWrCloser(cli.CliOut)), nil +} + +type jsonDiffWriter struct { + wr io.WriteCloser + schemaDiffWriter diff.SchemaDiffWriter + rowDiffWriter diff.SqlRowDiffWriter + schemaDiffsWritten int + tablesWritten int +} + +var _ diffWriter = (*jsonDiffWriter)(nil) + +func newJsonDiffWriter(wr io.WriteCloser) (*jsonDiffWriter, error) { + return &jsonDiffWriter{ + wr: wr, + }, nil +} + +const jsonDiffTableHeader = `{"name":"%s","schema_diff":` +const jsonDiffFooter = `}]}` + +func (j *jsonDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) error { + if j.schemaDiffWriter == nil { + err := iohelp.WriteAll(j.wr, []byte(`{"tables":[`)) + if err != nil { + return err + } + } else { + err 
:= iohelp.WriteAll(j.wr, []byte(`},`)) + if err != nil { + return err + } + } + + err := iohelp.WriteAll(j.wr, []byte(fmt.Sprintf(jsonDiffTableHeader, td.ToName))) + if err != nil { + return err + } + + j.tablesWritten++ + + j.schemaDiffWriter, err = json.NewSchemaDiffWriter(iohelp.NopWrCloser(j.wr)) + return err +} + +func (j *jsonDiffWriter) WriteSchemaDiff(ctx context.Context, toRoot *doltdb.RootValue, td diff.TableDelta) error { + toSchemas, err := toRoot.GetAllSchemas(ctx) + if err != nil { + return errhand.BuildDError("could not read schemas from toRoot").AddCause(err).Build() + } + + stmts, err := sqlSchemaDiff(ctx, td, toSchemas) + if err != nil { + return err + } + + for _, stmt := range stmts { + err := j.schemaDiffWriter.WriteSchemaDiff(ctx, stmt) + if err != nil { + return err + } + } + + return nil +} + +func (j *jsonDiffWriter) RowWriter(ctx context.Context, td diff.TableDelta, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) { + // close off the schema diff block, start the data block + err := iohelp.WriteAll(j.wr, []byte(`],"data_diff":[`)) + if err != nil { + return nil, err + } + + // Translate the union schema to its dolt version + cols := schema.NewColCollection() + for i, col := range unionSch { + doltCol, err := sqlutil.ToDoltCol(uint64(i), col) + if err != nil { + return nil, err + } + cols = cols.Append(doltCol) + } + + sch, err := schema.SchemaFromCols(cols) + if err != nil { + return nil, err + } + + j.rowDiffWriter, err = json.NewJsonDiffWriter(iohelp.NopWrCloser(cli.CliOut), sch) + return j.rowDiffWriter, err +} + +func (j *jsonDiffWriter) Close(ctx context.Context) error { + if j.tablesWritten > 0 { + err := iohelp.WriteLine(j.wr, jsonDiffFooter) + if err != nil { + return err + } + } else { + err := iohelp.WriteLine(j.wr, "") + if err != nil { + return err + } + } + + // Writer has already been closed here during row iteration, no need to close it here + return nil +} diff --git a/go/cmd/dolt/commands/dump.go 
b/go/cmd/dolt/commands/dump.go index 483ae75d2a..d8c9ed4e89 100644 --- a/go/cmd/dolt/commands/dump.go +++ b/go/cmd/dolt/commands/dump.go @@ -265,7 +265,7 @@ func dumpTable(ctx context.Context, dEnv *env.DoltEnv, tblOpts *tableOptions, fi return nil } -func getTableWriter(ctx context.Context, dEnv *env.DoltEnv, tblOpts *tableOptions, outSch schema.Schema, filePath string) (table.SqlTableWriter, errhand.VerboseError) { +func getTableWriter(ctx context.Context, dEnv *env.DoltEnv, tblOpts *tableOptions, outSch schema.Schema, filePath string) (table.SqlRowWriter, errhand.VerboseError) { opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()} writer, err := dEnv.FS.OpenForWriteAppend(filePath, os.ModePerm) diff --git a/go/cmd/dolt/commands/indexcmds/cat.go b/go/cmd/dolt/commands/indexcmds/cat.go index 5aa6c55010..9c96d4f9d9 100644 --- a/go/cmd/dolt/commands/indexcmds/cat.go +++ b/go/cmd/dolt/commands/indexcmds/cat.go @@ -176,7 +176,7 @@ func (cmd CatCmd) prettyPrintResults(ctx context.Context, doltSch schema.Schema, return nil } -func getTableWriter(format resultFormat, sch schema.Schema) (wr table.SqlTableWriter, err error) { +func getTableWriter(format resultFormat, sch schema.Schema) (wr table.SqlRowWriter, err error) { s, err := sqlutil.FromDoltSchema("", sch) if err != nil { return nil, err diff --git a/go/cmd/dolt/commands/schcmds/import.go b/go/cmd/dolt/commands/schcmds/import.go index 5aa33531a8..088afab0e8 100644 --- a/go/cmd/dolt/commands/schcmds/import.go +++ b/go/cmd/dolt/commands/schcmds/import.go @@ -344,7 +344,7 @@ func inferSchemaFromFile(ctx context.Context, nbf *types.NomsBinFormat, impOpts impOpts.fileType = impOpts.fileType[1:] } - var rd table.TableReadCloser + var rd table.ReadCloser csvInfo := csv.NewCSVInfo().SetDelim(",") switch impOpts.fileType { diff --git a/go/cmd/dolt/commands/tblcmds/export.go b/go/cmd/dolt/commands/tblcmds/export.go index a910c182b3..66eaffaf33 100644 --- a/go/cmd/dolt/commands/tblcmds/export.go 
+++ b/go/cmd/dolt/commands/tblcmds/export.go @@ -220,7 +220,7 @@ func (cmd ExportCmd) Exec(ctx context.Context, commandStr string, args []string, return 0 } -func getTableWriter(ctx context.Context, root *doltdb.RootValue, dEnv *env.DoltEnv, rdSchema schema.Schema, exOpts *exportOptions) (table.SqlTableWriter, errhand.VerboseError) { +func getTableWriter(ctx context.Context, root *doltdb.RootValue, dEnv *env.DoltEnv, rdSchema schema.Schema, exOpts *exportOptions) (table.SqlRowWriter, errhand.VerboseError) { ow, err := exOpts.checkOverwrite(ctx, root, dEnv.FS) if err != nil { return nil, errhand.VerboseErrorFromError(err) diff --git a/go/libraries/doltcore/diff/diff.go b/go/libraries/doltcore/diff/diff.go index 4250cce468..d00b9c49d1 100755 --- a/go/libraries/doltcore/diff/diff.go +++ b/go/libraries/doltcore/diff/diff.go @@ -57,7 +57,7 @@ type RowDiffer interface { Close() error } -// SqlRowDiffWriter knows how to write diff rows to an arbitrary format and destination. +// SqlRowDiffWriter knows how to write diff rows for a table to an arbitrary format and destination. type SqlRowDiffWriter interface { // WriteRow writes the diff row given, of the diff type provided. colDiffTypes is guaranteed to be the same length as // the input row. @@ -67,5 +67,12 @@ type SqlRowDiffWriter interface { Close(ctx context.Context) error } -// ColorFunc is a function that can color a format string -type ColorFunc func(a ...interface{}) string +// SchemaDiffWriter knows how to write SQL DDL statements for a schema diff for a table to an arbitrary format and +// destination. +type SchemaDiffWriter interface { + // WriteSchemaDiff writes the schema diff given (a SQL statement) and returns any error. A single table may have + // many SQL statements for a single diff. WriteSchemaDiff will be called before any row diffs via |WriteRow| + WriteSchemaDiff(ctx context.Context, schemaDiffStatement string) error + // Close finalizes the work of this writer. 
+ Close(ctx context.Context) error +} diff --git a/go/libraries/doltcore/env/actions/infer_schema.go b/go/libraries/doltcore/env/actions/infer_schema.go index 6e05a419c4..8389611165 100644 --- a/go/libraries/doltcore/env/actions/infer_schema.go +++ b/go/libraries/doltcore/env/actions/infer_schema.go @@ -54,7 +54,7 @@ type InferenceArgs interface { } // InferColumnTypesFromTableReader will infer a data types from a table reader. -func InferColumnTypesFromTableReader(ctx context.Context, root *doltdb.RootValue, rd table.TableReadCloser, args InferenceArgs) (*schema.ColCollection, error) { +func InferColumnTypesFromTableReader(ctx context.Context, root *doltdb.RootValue, rd table.ReadCloser, args InferenceArgs) (*schema.ColCollection, error) { inferrer := newInferrer(rd.GetSchema(), args) var rowFailure *pipeline.TransformRowFailure diff --git a/go/libraries/doltcore/merge/violations_fk.go b/go/libraries/doltcore/merge/violations_fk.go index a147ff1523..91a28edb97 100644 --- a/go/libraries/doltcore/merge/violations_fk.go +++ b/go/libraries/doltcore/merge/violations_fk.go @@ -250,7 +250,7 @@ func nomsParentFkConstraintViolations( } shouldContinue, err := func() (bool, error) { - var mapIter table.TableReadCloser = noms.NewNomsRangeReader( + var mapIter table.ReadCloser = noms.NewNomsRangeReader( postParent.IndexSchema, durable.NomsMapFromIndex(postParent.IndexData), []*noms.ReadRange{{Start: postParentIndexPartialKey, Inclusive: true, Reverse: false, Check: noms.InRangeCheckPartial(postParentIndexPartialKey)}}) @@ -458,7 +458,7 @@ func childFkConstraintViolationsProcess( postChildCVMapEditor *types.MapEditor, vInfo types.JSON, ) (bool, error) { - var mapIter table.TableReadCloser = noms.NewNomsRangeReader( + var mapIter table.ReadCloser = noms.NewNomsRangeReader( postParent.IndexSchema, durable.NomsMapFromIndex(postParent.IndexData), []*noms.ReadRange{{Start: parentPartialKey, Inclusive: true, Reverse: false, Check: noms.InRangeCheckPartial(parentPartialKey)}}) diff 
--git a/go/libraries/doltcore/mvdata/data_loc.go b/go/libraries/doltcore/mvdata/data_loc.go index 6500a518c5..9e1ce37169 100644 --- a/go/libraries/doltcore/mvdata/data_loc.go +++ b/go/libraries/doltcore/mvdata/data_loc.go @@ -93,7 +93,7 @@ type DataLocation interface { // NewCreatingWriter will create a TableWriteCloser for a DataLocation that will create a new table, or overwrite // an existing table. - NewCreatingWriter(ctx context.Context, mvOpts DataMoverOptions, root *doltdb.RootValue, outSch schema.Schema, opts editor.Options, wr io.WriteCloser) (table.SqlTableWriter, error) + NewCreatingWriter(ctx context.Context, mvOpts DataMoverOptions, root *doltdb.RootValue, outSch schema.Schema, opts editor.Options, wr io.WriteCloser) (table.SqlRowWriter, error) } // NewDataLocation creates a DataLocation object from a path and a format string. If the path is the name of a table diff --git a/go/libraries/doltcore/mvdata/data_loc_test.go b/go/libraries/doltcore/mvdata/data_loc_test.go index 770b8024a0..f8639b1815 100644 --- a/go/libraries/doltcore/mvdata/data_loc_test.go +++ b/go/libraries/doltcore/mvdata/data_loc_test.go @@ -188,7 +188,7 @@ func TestCreateRdWr(t *testing.T) { }{ {NewDataLocation("file.csv", ""), reflect.TypeOf((*csv.CSVReader)(nil)).Elem(), reflect.TypeOf((*csv.CSVWriter)(nil)).Elem()}, {NewDataLocation("file.psv", ""), reflect.TypeOf((*csv.CSVReader)(nil)).Elem(), reflect.TypeOf((*csv.CSVWriter)(nil)).Elem()}, - {NewDataLocation("file.json", ""), reflect.TypeOf((*json.JSONReader)(nil)).Elem(), reflect.TypeOf((*json.JSONWriter)(nil)).Elem()}, + {NewDataLocation("file.json", ""), reflect.TypeOf((*json.JSONReader)(nil)).Elem(), reflect.TypeOf((*json.RowWriter)(nil)).Elem()}, //{NewDataLocation("file.nbf", ""), reflect.TypeOf((*nbf.NBFReader)(nil)).Elem(), reflect.TypeOf((*nbf.NBFWriter)(nil)).Elem()}, } diff --git a/go/libraries/doltcore/mvdata/data_mover.go b/go/libraries/doltcore/mvdata/data_mover.go index 7ed4fba915..a132e12f16 100644 --- 
a/go/libraries/doltcore/mvdata/data_mover.go +++ b/go/libraries/doltcore/mvdata/data_mover.go @@ -70,7 +70,7 @@ type DataMoverCloser interface { } type DataMover struct { - Rd table.TableReadCloser + Rd table.ReadCloser Transforms *pipeline.TransformCollection Wr table.TableWriteCloser ContOnErr bool @@ -121,7 +121,7 @@ func SchAndTableNameFromFile(ctx context.Context, path string, fs filesys.Readab } } -func InferSchema(ctx context.Context, root *doltdb.RootValue, rd table.TableReadCloser, tableName string, pks []string, args actions.InferenceArgs) (schema.Schema, error) { +func InferSchema(ctx context.Context, root *doltdb.RootValue, rd table.ReadCloser, tableName string, pks []string, args actions.InferenceArgs) (schema.Schema, error) { var err error infCols, err := actions.InferColumnTypesFromTableReader(ctx, root, rd, args) diff --git a/go/libraries/doltcore/mvdata/file_data_loc.go b/go/libraries/doltcore/mvdata/file_data_loc.go index 6d9576855e..65316375d7 100644 --- a/go/libraries/doltcore/mvdata/file_data_loc.go +++ b/go/libraries/doltcore/mvdata/file_data_loc.go @@ -178,7 +178,7 @@ func (dl FileDataLocation) NewReader(ctx context.Context, root *doltdb.RootValue // NewCreatingWriter will create a TableWriteCloser for a DataLocation that will create a new table, or overwrite // an existing table. 
-func (dl FileDataLocation) NewCreatingWriter(ctx context.Context, mvOpts DataMoverOptions, root *doltdb.RootValue, outSch schema.Schema, opts editor.Options, wr io.WriteCloser) (table.SqlTableWriter, error) { +func (dl FileDataLocation) NewCreatingWriter(ctx context.Context, mvOpts DataMoverOptions, root *doltdb.RootValue, outSch schema.Schema, opts editor.Options, wr io.WriteCloser) (table.SqlRowWriter, error) { switch dl.Format { case CsvFile: return csv.NewCSVWriter(wr, outSch, csv.NewCSVInfo()) diff --git a/go/libraries/doltcore/mvdata/pipeline.go b/go/libraries/doltcore/mvdata/pipeline.go index 4bb64b9f2a..7b99187774 100644 --- a/go/libraries/doltcore/mvdata/pipeline.go +++ b/go/libraries/doltcore/mvdata/pipeline.go @@ -30,10 +30,10 @@ type DataMoverPipeline struct { g *errgroup.Group ctx context.Context rd table.SqlRowReader - wr table.SqlTableWriter + wr table.SqlRowWriter } -func NewDataMoverPipeline(ctx context.Context, rd table.SqlRowReader, wr table.SqlTableWriter) *DataMoverPipeline { +func NewDataMoverPipeline(ctx context.Context, rd table.SqlRowReader, wr table.SqlRowWriter) *DataMoverPipeline { g, ctx := errgroup.WithContext(ctx) return &DataMoverPipeline{ g: g, diff --git a/go/libraries/doltcore/mvdata/stream_data_loc.go b/go/libraries/doltcore/mvdata/stream_data_loc.go index 3573684482..86ac6def4a 100644 --- a/go/libraries/doltcore/mvdata/stream_data_loc.go +++ b/go/libraries/doltcore/mvdata/stream_data_loc.go @@ -73,7 +73,7 @@ func (dl StreamDataLocation) NewReader(ctx context.Context, root *doltdb.RootVal // NewCreatingWriter will create a TableWriteCloser for a DataLocation that will create a new table, or overwrite // an existing table. 
-func (dl StreamDataLocation) NewCreatingWriter(ctx context.Context, mvOpts DataMoverOptions, root *doltdb.RootValue, outSch schema.Schema, opts editor.Options, wr io.WriteCloser) (table.SqlTableWriter, error) { +func (dl StreamDataLocation) NewCreatingWriter(ctx context.Context, mvOpts DataMoverOptions, root *doltdb.RootValue, outSch schema.Schema, opts editor.Options, wr io.WriteCloser) (table.SqlRowWriter, error) { switch dl.Format { case CsvFile: return csv.NewCSVWriter(iohelp.NopWrCloser(dl.Writer), outSch, csv.NewCSVInfo()) diff --git a/go/libraries/doltcore/table/composite_table_reader.go b/go/libraries/doltcore/table/composite_table_reader.go index 23065f4cea..4345e348ac 100644 --- a/go/libraries/doltcore/table/composite_table_reader.go +++ b/go/libraries/doltcore/table/composite_table_reader.go @@ -26,12 +26,12 @@ import ( // of multiple TableReader instances into a single set of results. type CompositeTableReader struct { sch schema.Schema - readers []TableReadCloser + readers []ReadCloser idx int } // NewCompositeTableReader creates a new CompositeTableReader instance from a slice of TableReadClosers. 
-func NewCompositeTableReader(readers []TableReadCloser) (*CompositeTableReader, error) { +func NewCompositeTableReader(readers []ReadCloser) (*CompositeTableReader, error) { if len(readers) == 0 { panic("nothing to iterate") } diff --git a/go/libraries/doltcore/table/composite_table_reader_test.go b/go/libraries/doltcore/table/composite_table_reader_test.go index f57123bd10..cae3643c83 100644 --- a/go/libraries/doltcore/table/composite_table_reader_test.go +++ b/go/libraries/doltcore/table/composite_table_reader_test.go @@ -42,7 +42,7 @@ func TestCompositeTableReader(t *testing.T) { sch, err := schema.SchemaFromCols(coll) require.NoError(t, err) - var readers []TableReadCloser + var readers []ReadCloser var expectedKeys []uint64 var expectedVals []int64 for i := 0; i < numReaders; i++ { diff --git a/go/libraries/doltcore/table/editor/bulk_import_tea.go b/go/libraries/doltcore/table/editor/bulk_import_tea.go index 811bba83fd..60c4fcc8ac 100644 --- a/go/libraries/doltcore/table/editor/bulk_import_tea.go +++ b/go/libraries/doltcore/table/editor/bulk_import_tea.go @@ -226,7 +226,7 @@ func (iea *BulkImportIEA) HasPartial(ctx context.Context, idxSch schema.Schema, var err error var matches []hashedTuple - var mapIter table.TableReadCloser = noms.NewNomsRangeReader(idxSch, iea.rowData, []*noms.ReadRange{ + var mapIter table.ReadCloser = noms.NewNomsRangeReader(idxSch, iea.rowData, []*noms.ReadRange{ {Start: partialKey, Inclusive: true, Reverse: false, Check: noms.InRangeCheckPartial(partialKey)}}) defer mapIter.Close(ctx) var r row.Row diff --git a/go/libraries/doltcore/table/editor/index_edit_accumulator.go b/go/libraries/doltcore/table/editor/index_edit_accumulator.go index aec1e17d93..efded4fb39 100644 --- a/go/libraries/doltcore/table/editor/index_edit_accumulator.go +++ b/go/libraries/doltcore/table/editor/index_edit_accumulator.go @@ -286,7 +286,7 @@ func (iea *indexEditAccumulatorImpl) HasPartial(ctx context.Context, idxSch sche var err error var matches 
[]hashedTuple - var mapIter table.TableReadCloser = noms.NewNomsRangeReader(idxSch, iea.rowData, []*noms.ReadRange{ + var mapIter table.ReadCloser = noms.NewNomsRangeReader(idxSch, iea.rowData, []*noms.ReadRange{ {Start: partialKey, Inclusive: true, Reverse: false, Check: noms.InRangeCheckPartial(partialKey)}}) defer mapIter.Close(ctx) var r row.Row diff --git a/go/libraries/doltcore/table/inmem_table.go b/go/libraries/doltcore/table/inmem_table.go index 42ff302aed..65382b29d6 100644 --- a/go/libraries/doltcore/table/inmem_table.go +++ b/go/libraries/doltcore/table/inmem_table.go @@ -163,12 +163,12 @@ func (rd *InMemTableReader) VerifySchema(outSch schema.Schema) (bool, error) { return schema.VerifyInSchema(rd.tt.sch, outSch) } -// InMemTableWriter is an implementation of a TableWriter for an InMemTable +// InMemTableWriter is an implementation of a RowWriter for an InMemTable type InMemTableWriter struct { tt *InMemTable } -// NewInMemTableWriter creates an instance of a TableWriter from an InMemTable +// NewInMemTableWriter creates an instance of a RowWriter from an InMemTable func NewInMemTableWriter(imt *InMemTable) *InMemTableWriter { return &InMemTableWriter{imt} } diff --git a/go/libraries/doltcore/table/inmem_table_test.go b/go/libraries/doltcore/table/inmem_table_test.go index 0f859d2a33..7903adc903 100644 --- a/go/libraries/doltcore/table/inmem_table_test.go +++ b/go/libraries/doltcore/table/inmem_table_test.go @@ -90,7 +90,7 @@ func TestInMemTable(t *testing.T) { }() func() { - var r TableReadCloser + var r ReadCloser r = NewInMemTableReader(imt) defer r.Close(context.Background()) diff --git a/go/libraries/doltcore/table/io.go b/go/libraries/doltcore/table/io.go index 42c41eaa2e..a26d48f2bd 100644 --- a/go/libraries/doltcore/table/io.go +++ b/go/libraries/doltcore/table/io.go @@ -42,14 +42,14 @@ func GetRow(ctx context.Context, tbl *doltdb.Table, sch schema.Schema, key types return } -// PipeRows will read a row from given TableReader and write it to 
the provided TableWriter. It will do this +// PipeRows will read a row from given TableReader and write it to the provided RowWriter. It will do this // for every row until the TableReader's ReadRow method returns io.EOF or encounters an error in either reading // or writing. The caller will need to handle closing the tables as necessary. If contOnBadRow is true, errors reading // or writing will be ignored and the pipe operation will continue. // // Returns a tuple: (number of rows written, number of errors ignored, error). In the case that err is non-nil, the // row counter fields in the tuple will be set to -1. -func PipeRows(ctx context.Context, rd TableReader, wr TableWriter, contOnBadRow bool) (int, int, error) { +func PipeRows(ctx context.Context, rd Reader, wr RowWriter, contOnBadRow bool) (int, int, error) { var numBad, numGood int for { r, err := rd.ReadRow(ctx) @@ -82,7 +82,7 @@ func PipeRows(ctx context.Context, rd TableReader, wr TableWriter, contOnBadRow // ReadAllRows reads all rows from a TableReader and returns a slice containing those rows. Usually this is used // for testing, or with very small data sets. -func ReadAllRows(ctx context.Context, rd TableReader, contOnBadRow bool) ([]row.Row, int, error) { +func ReadAllRows(ctx context.Context, rd Reader, contOnBadRow bool) ([]row.Row, int, error) { var rows []row.Row var err error diff --git a/go/libraries/doltcore/table/keyless_reader.go b/go/libraries/doltcore/table/keyless_reader.go new file mode 100644 index 0000000000..0f94be3094 --- /dev/null +++ b/go/libraries/doltcore/table/keyless_reader.go @@ -0,0 +1,157 @@ +// Copyright 2020 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package table + +import ( + "context" + "io" + + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/row" + "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" + "github.com/dolthub/dolt/go/store/types" +) + +type keylessTableReader struct { + iter types.MapIterator + sch schema.Schema + + row row.Row + remainingCopies uint64 + bounded bool + remainingEntries uint64 +} + +var _ SqlTableReader = &keylessTableReader{} +var _ ReadCloser = &keylessTableReader{} + +// GetSchema implements the TableReader interface. +func (rdr *keylessTableReader) GetSchema() schema.Schema { + return rdr.sch +} + +// ReadSqlRow implements the SqlTableReader interface. +func (rdr *keylessTableReader) ReadRow(ctx context.Context) (row.Row, error) { + if rdr.remainingCopies <= 0 { + if rdr.bounded && rdr.remainingEntries == 0 { + return nil, io.EOF + } + + key, val, err := rdr.iter.Next(ctx) + if err != nil { + return nil, err + } else if key == nil { + return nil, io.EOF + } + + rdr.row, rdr.remainingCopies, err = row.KeylessRowsFromTuples(key.(types.Tuple), val.(types.Tuple)) + if err != nil { + return nil, err + } + + if rdr.remainingEntries > 0 { + rdr.remainingEntries -= 1 + } + + if rdr.remainingCopies == 0 { + return nil, row.ErrZeroCardinality + } + } + + rdr.remainingCopies -= 1 + + return rdr.row, nil +} + +// ReadSqlRow implements the SqlTableReader interface. 
+func (rdr *keylessTableReader) ReadSqlRow(ctx context.Context) (sql.Row, error) { + r, err := rdr.ReadRow(ctx) + if err != nil { + return nil, err + } + + return sqlutil.DoltRowToSqlRow(r, rdr.sch) +} + +// Close implements the TableReadCloser interface. +func (rdr *keylessTableReader) Close(_ context.Context) error { + return nil +} + +func newKeylessTableReader(ctx context.Context, tbl *doltdb.Table, sch schema.Schema, buffered bool) (*keylessTableReader, error) { + rows, err := tbl.GetNomsRowData(ctx) + if err != nil { + return nil, err + } + + return newKeylessTableReaderForRows(ctx, rows, sch, buffered) +} + +func newKeylessTableReaderForRows(ctx context.Context, rows types.Map, sch schema.Schema, buffered bool) (*keylessTableReader, error) { + var err error + var iter types.MapIterator + if buffered { + iter, err = rows.Iterator(ctx) + } else { + iter, err = rows.BufferedIterator(ctx) + } + if err != nil { + return nil, err + } + + return &keylessTableReader{ + iter: iter, + sch: sch, + }, nil +} + +func newKeylessTableReaderForPartition(ctx context.Context, tbl *doltdb.Table, sch schema.Schema, start, end uint64) (SqlTableReader, error) { + rows, err := tbl.GetNomsRowData(ctx) + if err != nil { + return nil, err + } + + iter, err := rows.BufferedIteratorAt(ctx, start) + if err != nil { + return nil, err + } + + return &keylessTableReader{ + iter: iter, + sch: sch, + remainingEntries: end - start, + bounded: true, + }, nil +} + +func newKeylessTableReaderFrom(ctx context.Context, tbl *doltdb.Table, sch schema.Schema, val types.Value) (SqlTableReader, error) { + rows, err := tbl.GetNomsRowData(ctx) + if err != nil { + return nil, err + } + + iter, err := rows.IteratorFrom(ctx, val) + if err != nil { + return nil, err + } + + return &keylessTableReader{ + iter: iter, + sch: sch, + }, nil +} diff --git a/go/libraries/doltcore/table/pipeline/procfunc_help.go b/go/libraries/doltcore/table/pipeline/procfunc_help.go index b5ae226c0a..56fe7d69a7 100644 --- 
a/go/libraries/doltcore/table/pipeline/procfunc_help.go +++ b/go/libraries/doltcore/table/pipeline/procfunc_help.go @@ -73,7 +73,7 @@ func ProcFuncForSourceFunc(sourceFunc SourceFunc) InFunc { } // ProcFuncForReader adapts a standard TableReader to work as an InFunc for a pipeline -func ProcFuncForReader(ctx context.Context, rd table.TableReader) InFunc { +func ProcFuncForReader(ctx context.Context, rd table.Reader) InFunc { return ProcFuncForSourceFunc(func() (row.Row, ImmutableProperties, error) { r, err := rd.ReadRow(ctx) @@ -135,7 +135,7 @@ func SourceFuncForRows(rows []row.Row) SourceFunc { } // ProcFuncForWriter adapts a standard TableWriter to work as an OutFunc for a pipeline -func ProcFuncForWriter(ctx context.Context, wr table.TableWriter) OutFunc { +func ProcFuncForWriter(ctx context.Context, wr table.RowWriter) OutFunc { return ProcFuncForSinkFunc(func(r row.Row, props ReadableMap) error { return wr.WriteRow(ctx, r) }) diff --git a/go/libraries/doltcore/table/pk_reader.go b/go/libraries/doltcore/table/pk_reader.go new file mode 100644 index 0000000000..9bdbd1eaf6 --- /dev/null +++ b/go/libraries/doltcore/table/pk_reader.go @@ -0,0 +1,153 @@ +// Copyright 2020 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package table + +import ( + "context" + "io" + + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/row" + "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/doltcore/table/typed/noms" + "github.com/dolthub/dolt/go/store/types" +) + +type pkTableReader struct { + iter types.MapIterator + sch schema.Schema +} + +var _ SqlTableReader = pkTableReader{} +var _ ReadCloser = pkTableReader{} + +// GetSchema implements the TableReader interface. +func (rdr pkTableReader) GetSchema() schema.Schema { + return rdr.sch +} + +// ReadRow implements the TableReader interface. +func (rdr pkTableReader) ReadRow(ctx context.Context) (row.Row, error) { + key, val, err := rdr.iter.Next(ctx) + + if err != nil { + return nil, err + } else if key == nil { + return nil, io.EOF + } + + return row.FromNoms(rdr.sch, key.(types.Tuple), val.(types.Tuple)) +} + +// ReadSqlRow implements the SqlTableReader interface. +func (rdr pkTableReader) ReadSqlRow(ctx context.Context) (sql.Row, error) { + key, val, err := rdr.iter.Next(ctx) + + if err != nil { + return nil, err + } else if key == nil { + return nil, io.EOF + } + + return noms.SqlRowFromTuples(rdr.sch, key.(types.Tuple), val.(types.Tuple)) +} + +// Close implements the TableReadCloser interface. 
+func (rdr pkTableReader) Close(_ context.Context) error { + return nil +} + +func newPkTableReader(ctx context.Context, tbl *doltdb.Table, sch schema.Schema, buffered bool) (pkTableReader, error) { + rows, err := tbl.GetNomsRowData(ctx) + if err != nil { + return pkTableReader{}, err + } + + return newPkTableReaderForRows(ctx, rows, sch, buffered) +} + +func newPkTableReaderForRows(ctx context.Context, rows types.Map, sch schema.Schema, buffered bool) (pkTableReader, error) { + var err error + var iter types.MapIterator + if buffered { + iter, err = rows.Iterator(ctx) + } else { + iter, err = rows.BufferedIterator(ctx) + } + if err != nil { + return pkTableReader{}, err + } + + return pkTableReader{ + iter: iter, + sch: sch, + }, nil +} + +func newPkTableReaderFrom(ctx context.Context, tbl *doltdb.Table, sch schema.Schema, val types.Value) (SqlTableReader, error) { + rows, err := tbl.GetNomsRowData(ctx) + if err != nil { + return nil, err + } + + iter, err := rows.IteratorFrom(ctx, val) + if err != nil { + return nil, err + } + + return pkTableReader{ + iter: iter, + sch: sch, + }, nil +} + +type partitionTableReader struct { + SqlTableReader + remaining uint64 +} + +var _ SqlTableReader = &partitionTableReader{} + +func newPkTableReaderForPartition(ctx context.Context, tbl *doltdb.Table, sch schema.Schema, start, end uint64) (SqlTableReader, error) { + rows, err := tbl.GetNomsRowData(ctx) + if err != nil { + return nil, err + } + + iter, err := rows.BufferedIteratorAt(ctx, start) + if err != nil { + return nil, err + } + + return &partitionTableReader{ + SqlTableReader: pkTableReader{ + iter: iter, + sch: sch, + }, + remaining: end - start, + }, nil +} + +// ReadSqlRow implements the SqlTableReader interface. 
+func (rdr *partitionTableReader) ReadSqlRow(ctx context.Context) (sql.Row, error) { + if rdr.remaining == 0 { + return nil, io.EOF + } + rdr.remaining -= 1 + + return rdr.SqlTableReader.ReadSqlRow(ctx) +} diff --git a/go/libraries/doltcore/table/read_ahead_table_reader.go b/go/libraries/doltcore/table/read_ahead_table_reader.go index 35f5967eac..125e5a94ed 100644 --- a/go/libraries/doltcore/table/read_ahead_table_reader.go +++ b/go/libraries/doltcore/table/read_ahead_table_reader.go @@ -22,17 +22,17 @@ import ( "github.com/dolthub/dolt/go/libraries/utils/async" ) -var _ TableReadCloser = (*AsyncReadAheadTableReader)(nil) +var _ ReadCloser = (*AsyncReadAheadTableReader)(nil) // AsyncReadAheadTableReader is a TableReadCloser implementation that spins up a go routine to keep reading data into // a buffered channel so that it is ready when the caller wants it. type AsyncReadAheadTableReader struct { - backingReader TableReadCloser + backingReader ReadCloser reader *async.AsyncReader } // NewAsyncReadAheadTableReader creates a new AsyncReadAheadTableReader -func NewAsyncReadAheadTableReader(tr TableReadCloser, bufferSize int) *AsyncReadAheadTableReader { +func NewAsyncReadAheadTableReader(tr ReadCloser, bufferSize int) *AsyncReadAheadTableReader { read := func(ctx context.Context) (interface{}, error) { return tr.ReadRow(ctx) } diff --git a/go/libraries/doltcore/table/table_reader.go b/go/libraries/doltcore/table/table_reader.go index 34277ad5be..dff02763f4 100644 --- a/go/libraries/doltcore/table/table_reader.go +++ b/go/libraries/doltcore/table/table_reader.go @@ -23,8 +23,8 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/schema" ) -// TableReader is an interface for reading rows from a table -type TableReader interface { +// Reader is an interface for reading rows from a table +type Reader interface { // GetSchema gets the schema of the rows that this reader will return GetSchema() schema.Schema @@ -34,29 +34,29 @@ type TableReader interface { ReadRow(ctx 
context.Context) (row.Row, error) } -// TableCloser is an interface for a table stream that can be closed to release resources -type TableCloser interface { +// Closer is an interface for a writer that can be closed to release resources +type Closer interface { // Close should release resources being held Close(ctx context.Context) error } -// TableReadCloser is an interface for reading rows from a table, that can be closed. -type TableReadCloser interface { - TableReader - TableCloser +// ReadCloser is an interface for reading rows from a table, that can be closed. +type ReadCloser interface { + Reader + Closer } type SqlRowReader interface { - TableReadCloser + ReadCloser ReadSqlRow(ctx context.Context) (sql.Row, error) } -// SqlTableReader is a TableReader that can read rows as sql.Row. +// SqlTableReader is a Reader that can read rows as sql.Row. type SqlTableReader interface { // GetSchema gets the schema of the rows that this reader will return GetSchema() schema.Schema - // ReadRow reads a row from a table as go-mysql-server sql.Row. + // ReadSqlRow reads a row from a table as go-mysql-server sql.Row. 
ReadSqlRow(ctx context.Context) (sql.Row, error) } diff --git a/go/libraries/doltcore/table/table_writer.go b/go/libraries/doltcore/table/table_writer.go index 988c2b5c9d..9bdb049cfd 100644 --- a/go/libraries/doltcore/table/table_writer.go +++ b/go/libraries/doltcore/table/table_writer.go @@ -22,19 +22,19 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/row" ) -// TableWriteCloser is an interface for writing rows to a table -type TableWriter interface { - // WriteRow will write a row to a table +// RowWriter knows how to write table rows to some destination +type RowWriter interface { + // WriteRow writes a row to the destination of this writer WriteRow(ctx context.Context, r row.Row) error } // TableWriteCloser is an interface for writing rows to a table, that can be closed type TableWriteCloser interface { - TableWriter - TableCloser + RowWriter + Closer } -type SqlTableWriter interface { +type SqlRowWriter interface { TableWriteCloser WriteSqlRow(ctx context.Context, r sql.Row) error } diff --git a/go/libraries/doltcore/table/typed/json/json_diff_writer.go b/go/libraries/doltcore/table/typed/json/json_diff_writer.go new file mode 100755 index 0000000000..ec2864a6cb --- /dev/null +++ b/go/libraries/doltcore/table/typed/json/json_diff_writer.go @@ -0,0 +1,194 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package json + +import ( + "context" + "encoding/json" + "fmt" + "io" + + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/dolt/go/libraries/doltcore/diff" + "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/utils/iohelp" +) + +type JsonDiffWriter struct { + rowWriter *RowWriter + wr io.WriteCloser + inModified bool + rowsWritten int +} + +var _ diff.SqlRowDiffWriter = (*JsonDiffWriter)(nil) + +func NewJsonDiffWriter(wr io.WriteCloser, outSch schema.Schema) (*JsonDiffWriter, error) { + writer, err := NewJSONWriterWithHeader(iohelp.NopWrCloser(wr), outSch, "", "", "") + if err != nil { + return nil, err + } + + return &JsonDiffWriter{ + rowWriter: writer, + wr: wr, + }, nil +} + +func (j *JsonDiffWriter) WriteRow( + ctx context.Context, + row sql.Row, + rowDiffType diff.ChangeType, + colDiffTypes []diff.ChangeType, +) error { + if len(row) != len(colDiffTypes) { + return fmt.Errorf("expected the same size for columns and diff types, got %d and %d", len(row), len(colDiffTypes)) + } + + prefix := "" + if j.inModified { + prefix = "," + } else if j.rowsWritten > 0 { + prefix = ",{" + } else { + prefix = "{" + } + + err := iohelp.WriteAll(j.wr, []byte(prefix)) + if err != nil { + return err + } + + diffMarker := "" + switch rowDiffType { + case diff.Removed: + diffMarker = "from_row" + case diff.ModifiedOld: + diffMarker = "from_row" + case diff.Added: + err := iohelp.WriteAll(j.wr, []byte(fmt.Sprintf(`"%s":{},`, "from_row"))) + if err != nil { + return err + } + diffMarker = "to_row" + case diff.ModifiedNew: + diffMarker = "to_row" + } + + err = iohelp.WriteAll(j.wr, []byte(fmt.Sprintf(`"%s":`, diffMarker))) + if err != nil { + return err + } + + err = j.rowWriter.WriteSqlRow(ctx, row) + if err != nil { + return err + } + + // The row writer buffers its output and we share an underlying write stream with it, so we need to flush after + // every call to WriteSqlRow + err = j.rowWriter.Flush() + if err != 
nil { + return err + } + + switch rowDiffType { + case diff.ModifiedNew, diff.ModifiedOld: + j.inModified = !j.inModified + case diff.Added: + case diff.Removed: + err := iohelp.WriteAll(j.wr, []byte(fmt.Sprintf(`,"%s":{}`, "to_row"))) + if err != nil { + return err + } + } + + if !j.inModified { + err := iohelp.WriteAll(j.wr, []byte("}")) + if err != nil { + return err + } + j.rowsWritten++ + } + + return nil +} + +func (j *JsonDiffWriter) Close(ctx context.Context) error { + err := iohelp.WriteAll(j.wr, []byte("]")) + if err != nil { + return err + } + + err = j.rowWriter.Close(ctx) + if err != nil { + return err + } + + return j.wr.Close() +} + +type SchemaDiffWriter struct { + wr io.WriteCloser + schemaStmtsWritten int +} + +var _ diff.SchemaDiffWriter = (*SchemaDiffWriter)(nil) + +const jsonSchemaHeader = `[` +const jsonSchemaFooter = `]` + +func NewSchemaDiffWriter(wr io.WriteCloser) (*SchemaDiffWriter, error) { + err := iohelp.WriteAll(wr, []byte(jsonSchemaHeader)) + if err != nil { + return nil, err + } + + return &SchemaDiffWriter{ + wr: wr, + }, nil +} + +func (j *SchemaDiffWriter) WriteSchemaDiff(ctx context.Context, schemaDiffStatement string) error { + if j.schemaStmtsWritten > 0 { + err := iohelp.WriteAll(j.wr, []byte(",")) + if err != nil { + return err + } + } + + j.schemaStmtsWritten++ + + return iohelp.WriteAll(j.wr, []byte(fmt.Sprintf(`"%s"`, jsonEscape(schemaDiffStatement)))) +} + +func (j *SchemaDiffWriter) Close(ctx context.Context) error { + err := iohelp.WriteAll(j.wr, []byte(jsonSchemaFooter)) + if err != nil { + return err + } + + return j.wr.Close() +} + +func jsonEscape(s string) string { + b, err := json.Marshal(s) + if err != nil { + panic(err) + } + // Trim the beginning and trailing " character + return string(b[1 : len(b)-1]) +} diff --git a/go/libraries/doltcore/table/typed/json/writer.go b/go/libraries/doltcore/table/typed/json/writer.go index 452d73d9d4..b19858b536 100644 --- a/go/libraries/doltcore/table/typed/json/writer.go +++ 
b/go/libraries/doltcore/table/typed/json/writer.go @@ -38,31 +38,50 @@ const jsonFooter = `]}` var WriteBufSize = 256 * 1024 var defaultString = sql.MustCreateStringWithDefaults(sqltypes.VarChar, 16383) -type JSONWriter struct { +type RowWriter struct { closer io.Closer + header string + footer string + separator string bWr *bufio.Writer sch schema.Schema rowsWritten int } -var _ table.SqlTableWriter = (*JSONWriter)(nil) +var _ table.SqlRowWriter = (*RowWriter)(nil) -func NewJSONWriter(wr io.WriteCloser, outSch schema.Schema) (*JSONWriter, error) { +// NewJSONWriter returns a new writer that encodes rows as a single JSON object with a single key: "rows", which is a +// slice of all rows. To customize the output of the JSON object emitted, use |NewJSONWriterWithHeader| +func NewJSONWriter(wr io.WriteCloser, outSch schema.Schema) (*RowWriter, error) { + return NewJSONWriterWithHeader(wr, outSch, jsonHeader, jsonFooter, ",") +} + +func NewJSONWriterWithHeader(wr io.WriteCloser, outSch schema.Schema, header, footer, separator string) (*RowWriter, error) { bwr := bufio.NewWriterSize(wr, WriteBufSize) - err := iohelp.WriteAll(bwr, []byte(jsonHeader)) - if err != nil { - return nil, err + return &RowWriter{ + closer: wr, + bWr: bwr, + sch: outSch, + header: header, + footer: footer, + separator: separator, + }, nil +} + +func (j *RowWriter) GetSchema() schema.Schema { + return j.sch +} + +// WriteRow encodes the row given into JSON format and writes it, returning any error +func (j *RowWriter) WriteRow(ctx context.Context, r row.Row) error { + if j.rowsWritten == 0 { + err := iohelp.WriteAll(j.bWr, []byte(j.header)) + if err != nil { + return err + } } - return &JSONWriter{closer: wr, bWr: bwr, sch: outSch}, nil -} -func (jsonw *JSONWriter) GetSchema() schema.Schema { - return jsonw.sch -} - -// WriteRow will write a row to a table -func (jsonw *JSONWriter) WriteRow(ctx context.Context, r row.Row) error { - allCols := jsonw.sch.GetAllCols() + allCols := j.sch.GetAllCols() 
colValMap := make(map[string]interface{}, allCols.Size()) if err := allCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) { val, ok := r.GetColVal(tag) @@ -108,25 +127,31 @@ func (jsonw *JSONWriter) WriteRow(ctx context.Context, r row.Row) error { return errors.New("marshaling did not work") } - if jsonw.rowsWritten != 0 { - _, err := jsonw.bWr.WriteRune(',') - + if j.rowsWritten != 0 { + _, err := j.bWr.WriteString(j.separator) if err != nil { return err } } - newErr := iohelp.WriteAll(jsonw.bWr, data) + newErr := iohelp.WriteAll(j.bWr, data) if newErr != nil { return newErr } - jsonw.rowsWritten++ + j.rowsWritten++ return nil } -func (jsonw *JSONWriter) WriteSqlRow(ctx context.Context, row sql.Row) error { - allCols := jsonw.sch.GetAllCols() +func (j *RowWriter) WriteSqlRow(ctx context.Context, row sql.Row) error { + if j.rowsWritten == 0 { + err := iohelp.WriteAll(j.bWr, []byte(j.header)) + if err != nil { + return err + } + } + + allCols := j.sch.GetAllCols() colValMap := make(map[string]interface{}, allCols.Size()) if err := allCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) { val := row[allCols.TagToIdx[tag]] @@ -172,35 +197,39 @@ func (jsonw *JSONWriter) WriteSqlRow(ctx context.Context, row sql.Row) error { return errors.New("marshaling did not work") } - if jsonw.rowsWritten != 0 { - _, err := jsonw.bWr.WriteRune(',') - + if j.rowsWritten != 0 { + _, err := j.bWr.WriteString(j.separator) if err != nil { return err } } - newErr := iohelp.WriteAll(jsonw.bWr, data) + newErr := iohelp.WriteAll(j.bWr, data) if newErr != nil { return newErr } - jsonw.rowsWritten++ + j.rowsWritten++ return nil } -// Close should flush all writes, release resources being held -func (jsonw *JSONWriter) Close(ctx context.Context) error { - if jsonw.closer != nil { - err := iohelp.WriteAll(jsonw.bWr, []byte(jsonFooter)) +func (j *RowWriter) Flush() error { + return j.bWr.Flush() +} - if err != nil { - return err +// Close should flush all writes, 
release resources being held +func (j *RowWriter) Close(ctx context.Context) error { + if j.closer != nil { + if j.rowsWritten > 0 { + err := iohelp.WriteAll(j.bWr, []byte(j.footer)) + if err != nil { + return err + } } - errFl := jsonw.bWr.Flush() - errCl := jsonw.closer.Close() - jsonw.closer = nil + errFl := j.bWr.Flush() + errCl := j.closer.Close() + j.closer = nil if errCl != nil { return errCl @@ -208,8 +237,8 @@ func (jsonw *JSONWriter) Close(ctx context.Context) error { return errFl } - return errors.New("already closed") + return errors.New("already closed") } func marshalToJson(valMap interface{}) ([]byte, error) { diff --git a/go/libraries/doltcore/table/typed/parquet/writer.go b/go/libraries/doltcore/table/typed/parquet/writer.go index a532e73b4a..c2393c80cc 100644 --- a/go/libraries/doltcore/table/typed/parquet/writer.go +++ b/go/libraries/doltcore/table/typed/parquet/writer.go @@ -37,7 +37,7 @@ type ParquetWriter struct { sch schema.Schema } -var _ table.SqlTableWriter = (*ParquetWriter)(nil) +var _ table.SqlRowWriter = (*ParquetWriter)(nil) var typeMap = map[typeinfo.Identifier]string{ typeinfo.DatetimeTypeIdentifier: "type=INT64, convertedtype=TIMESTAMP_MICROS", diff --git a/go/libraries/doltcore/table/untyped/csv/writer.go b/go/libraries/doltcore/table/untyped/csv/writer.go index 6ad5acbe8e..96f24b67b6 100644 --- a/go/libraries/doltcore/table/untyped/csv/writer.go +++ b/go/libraries/doltcore/table/untyped/csv/writer.go @@ -45,7 +45,7 @@ type CSVWriter struct { useCRLF bool // True to use \r\n as the line terminator } -var _ table.SqlTableWriter = (*CSVWriter)(nil) +var _ table.SqlRowWriter = (*CSVWriter)(nil) // NewCSVWriter writes rows to the given WriteCloser based on the Schema and CSVFileInfo provided func NewCSVWriter(wr io.WriteCloser, outSch schema.Schema, info *CSVFileInfo) (*CSVWriter, error) { diff --git a/go/libraries/doltcore/table/untyped/tabular/fixedwidth_tablewriter.go 
b/go/libraries/doltcore/table/untyped/tabular/fixedwidth_tablewriter.go index 4224d39f8a..7bf141929e 100644 --- a/go/libraries/doltcore/table/untyped/tabular/fixedwidth_tablewriter.go +++ b/go/libraries/doltcore/table/untyped/tabular/fixedwidth_tablewriter.go @@ -59,7 +59,7 @@ type FixedWidthTableWriter struct { flushedSampleBuffer bool } -var _ table.SqlTableWriter = (*FixedWidthTableWriter)(nil) +var _ table.SqlRowWriter = (*FixedWidthTableWriter)(nil) type tableRow struct { columns []string diff --git a/integration-tests/bats/json-diff.bats b/integration-tests/bats/json-diff.bats new file mode 100644 index 0000000000..69ccb4924a --- /dev/null +++ b/integration-tests/bats/json-diff.bats @@ -0,0 +1,318 @@ +#!/usr/bin/env bats +load $BATS_TEST_DIRNAME/helper/common.bash + +setup() { + setup_common + + dolt sql </dev/null +} + +function no_stdout { + "$@" 1>/dev/null +} + +function count_string { + cmd="echo '$1' | grep -o '$2' | wc -l" + eval "$cmd" +} + +@test "json-diff: works with spaces in column names" { + dolt sql -q 'CREATE table t (pk int primary key, `type of food` varchar(100));' + dolt sql -q "INSERT INTO t VALUES (1, 'ramen');" + + EXPECTED=$(cat <<'EOF' +{"tables":[{"name":"t","schema_diff":["CREATE TABLE `t` (\n `pk` int NOT NULL,\n `type of food` varchar(100),\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin;"],"data_diff":[{"from_row":{},"to_row":{"pk":1,"type of food":"ramen"}}]},{"name":"test","schema_diff":["CREATE TABLE `test` (\n `pk` bigint NOT NULL COMMENT 'tag:0',\n `c1` bigint COMMENT 'tag:1',\n `c2` bigint COMMENT 'tag:2',\n `c3` bigint COMMENT 'tag:3',\n `c4` bigint COMMENT 'tag:4',\n `c5` bigint COMMENT 'tag:5',\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin;"],"data_diff":[]}]} +EOF +) + + dolt diff -r json + run dolt diff -r json + [ $status -eq 0 ] + [[ $output =~ "$EXPECTED" ]] || false +}