Progress on new diff --summary

This commit is contained in:
Taylor Bantle
2023-02-22 15:36:22 -08:00
parent e7f5c3a6d1
commit 30ec1e7e79
4 changed files with 187 additions and 65 deletions

View File

@@ -47,6 +47,7 @@ const (
SchemaOnlyDiff diffPart = 1 // 0b0001
DataOnlyDiff diffPart = 2 // 0b0010
Stat diffPart = 4 // 0b0100
Summary diffPart = 8 // 0b0101
SchemaAndDataDiff = SchemaOnlyDiff | DataOnlyDiff
@@ -54,16 +55,17 @@ const (
SQLDiffOutput diffOutput = 2
JsonDiffOutput diffOutput = 3
DataFlag = "data"
SchemaFlag = "schema"
StatFlag = "stat"
whereParam = "where"
limitParam = "limit"
SQLFlag = "sql"
CachedFlag = "cached"
SkinnyFlag = "skinny"
MergeBase = "merge-base"
DiffMode = "diff-mode"
DataFlag = "data"
SchemaFlag = "schema"
StatFlag = "stat"
SummaryFlag = "summary"
whereParam = "where"
limitParam = "limit"
SQLFlag = "sql"
CachedFlag = "cached"
SkinnyFlag = "skinny"
MergeBase = "merge-base"
DiffMode = "diff-mode"
)
var diffDocs = cli.CommandDocumentationContent{
@@ -139,6 +141,7 @@ func (cmd DiffCmd) ArgParser() *argparser.ArgParser {
ap.SupportsFlag(DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).")
ap.SupportsFlag(SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).")
ap.SupportsFlag(StatFlag, "", "Show stats of data changes")
ap.SupportsFlag(SummaryFlag, "", "Show summary of data and schema changes")
ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.")
ap.SupportsString(whereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.")
ap.SupportsInt(limitParam, "", "record_count", "limits to the first N diffs.")
@@ -173,9 +176,9 @@ func (cmd DiffCmd) Exec(ctx context.Context, commandStr string, args []string, d
}
func (cmd DiffCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseError {
if apr.Contains(StatFlag) {
if apr.Contains(StatFlag) || apr.Contains(SummaryFlag) {
if apr.Contains(SchemaFlag) || apr.Contains(DataFlag) {
return errhand.BuildDError("invalid Arguments: --stat cannot be combined with --schema or --data").Build()
return errhand.BuildDError("invalid Arguments: --stat and --summary cannot be combined with --schema or --data").Build()
}
}
@@ -199,6 +202,8 @@ func parseDiffArgs(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPar
dArgs.diffParts = SchemaOnlyDiff
} else if apr.Contains(StatFlag) {
dArgs.diffParts = Stat
} else if apr.Contains(SummaryFlag) {
dArgs.diffParts = Summary
}
dArgs.skinny = apr.Contains(SkinnyFlag)
@@ -522,9 +527,13 @@ func diffUserTable(
return errhand.BuildDError("error: both tables in tableDelta are nil").Build()
}
err := dw.BeginTable(ctx, td)
if err != nil {
return errhand.VerboseErrorFromError(err)
shouldSummary := dArgs.diffParts&Summary != 0
if !shouldSummary {
err := dw.BeginTable(ctx, td)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
}
fromSch, toSch, err := td.GetSchemas(ctx)
@@ -536,6 +545,18 @@ func diffUserTable(
return printDiffStat(ctx, td, fromSch.GetAllCols().Size(), toSch.GetAllCols().Size())
}
if shouldSummary {
dataChanged, verr := getDataHasChanged(ctx, engine, td, dArgs)
if verr != nil {
return verr
}
summ, err := td.GetSummary(ctx, dataChanged)
if err != nil {
return errhand.BuildDError("could not get table delta summary").AddCause(err).Build()
}
return printDiffSummary(ctx, summ)
}
if dArgs.diffParts&SchemaOnlyDiff != 0 {
err := dw.WriteSchemaDiff(ctx, dArgs.toRoot, td)
if err != nil {
@@ -663,6 +684,79 @@ func sqlSchemaDiff(ctx context.Context, td diff.TableDelta, toSchemas map[string
return ddlStatements, nil
}
func getRowDiffIter(
ctx context.Context,
se *engine.SqlEngine,
td diff.TableDelta,
dArgs *diffArgs,
where string,
limit int,
) (*sql.Context, sql.Schema, sql.RowIter, string, errhand.VerboseError) {
diffable := schema.ArePrimaryKeySetsDiffable(td.Format(), td.FromSch, td.ToSch)
canSqlDiff := !(td.ToSch == nil || (td.FromSch != nil && !schema.SchemasAreEqual(td.FromSch, td.ToSch)))
// can't diff
if !diffable {
// TODO: this messes up some structured output if the user didn't redirect it
cli.PrintErrf("Primary key sets differ between revisions for table %s, skipping data diff\n", td.ToName)
return nil, nil, nil, "", nil
} else if dArgs.diffOutput == SQLDiffOutput && !canSqlDiff {
// TODO: this is overly broad, we can absolutely do better
_, _ = fmt.Fprintf(cli.CliErr, "Incompatible schema change, skipping data diff\n")
return nil, nil, nil, "", nil
}
// do the data diff
tableName := td.CurName()
columns := getColumnNamesString(td.FromSch, td.ToSch)
query := fmt.Sprintf("select %s, %s from dolt_diff('%s', '%s', '%s')", columns, "diff_type", dArgs.fromRef, dArgs.toRef, tableName)
if len(where) > 0 {
query += " where " + where
}
if limit >= 0 {
query += " limit " + strconv.Itoa(limit)
}
sqlCtx, err := engine.NewLocalSqlContext(ctx, se)
if err != nil {
return nil, nil, nil, "", errhand.VerboseErrorFromError(err)
}
sch, rowIter, err := se.Query(sqlCtx, query)
if sql.ErrSyntaxError.Is(err) {
return nil, nil, nil, "", errhand.BuildDError("Failed to parse diff query. Invalid where clause?\nDiff query: %s", query).AddCause(err).Build()
} else if err != nil {
return nil, nil, nil, "", errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build()
}
return sqlCtx, sch, rowIter, query, nil
}
func getDataHasChanged(ctx context.Context,
se *engine.SqlEngine,
td diff.TableDelta,
dArgs *diffArgs,
) (bool, errhand.VerboseError) {
sqlCtx, _, rowIter, _, verr := getRowDiffIter(ctx, se, td, dArgs, "", 1)
if verr != nil {
return false, verr
}
defer rowIter.Close(sqlCtx)
_, err := rowIter.Next(sqlCtx)
if err == io.EOF {
return false, nil
} else if err != nil {
return false, errhand.VerboseErrorFromError(err)
}
return true, nil
}
func diffRows(
ctx context.Context,
se *engine.SqlEngine,
@@ -670,9 +764,6 @@ func diffRows(
dArgs *diffArgs,
dw diffWriter,
) errhand.VerboseError {
diffable := schema.ArePrimaryKeySetsDiffable(td.Format(), td.FromSch, td.ToSch)
canSqlDiff := !(td.ToSch == nil || (td.FromSch != nil && !schema.SchemasAreEqual(td.FromSch, td.ToSch)))
var toSch, fromSch sql.Schema
if td.FromSch != nil {
pkSch, err := sqlutil.FromDoltSchema(td.FromName, td.FromSch)
@@ -698,25 +789,6 @@ func diffRows(
return errhand.VerboseErrorFromError(err)
}
// can't diff
if !diffable {
// TODO: this messes up some structured output if the user didn't redirect it
cli.PrintErrf("Primary key sets differ between revisions for table %s, skipping data diff\n", td.ToName)
err := rowWriter.Close(ctx)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
return nil
} else if dArgs.diffOutput == SQLDiffOutput && !canSqlDiff {
// TODO: this is overly broad, we can absolutely do better
_, _ = fmt.Fprintf(cli.CliErr, "Incompatible schema change, skipping data diff\n")
err := rowWriter.Close(ctx)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
return nil
}
// no data diff requested
if dArgs.diffParts&DataOnlyDiff == 0 {
err := rowWriter.Close(ctx)
@@ -726,33 +798,13 @@ func diffRows(
return nil
}
// do the data diff
tableName := td.ToName
if len(tableName) == 0 {
tableName = td.FromName
}
columns := getColumnNamesString(td.FromSch, td.ToSch)
query := fmt.Sprintf("select %s, %s from dolt_diff('%s', '%s', '%s')", columns, "diff_type", dArgs.fromRef, dArgs.toRef, tableName)
if len(dArgs.where) > 0 {
query += " where " + dArgs.where
}
if dArgs.limit >= 0 {
query += " limit " + strconv.Itoa(dArgs.limit)
}
sqlCtx, err := engine.NewLocalSqlContext(ctx, se)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
sch, rowIter, err := se.Query(sqlCtx, query)
if sql.ErrSyntaxError.Is(err) {
return errhand.BuildDError("Failed to parse diff query. Invalid where clause?\nDiff query: %s", query).AddCause(err).Build()
} else if err != nil {
return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build()
sqlCtx, sch, rowIter, query, verr := getRowDiffIter(ctx, se, td, dArgs, dArgs.where, dArgs.limit)
if verr != nil {
err := rowWriter.Close(ctx)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
return verr
}
defer rowIter.Close(sqlCtx)

View File

@@ -65,6 +65,27 @@ func newDiffWriter(diffOutput diffOutput) (diffWriter, error) {
}
}
func printDiffSummary(ctx context.Context, summ *diff.TableDeltaSummary) errhand.VerboseError {
bold := color.New(color.Bold)
fmtStr := fmt.Sprintf("%%s%%%ds\t| %%s", 35-len(summ.TableName))
line := fmt.Sprintf(fmtStr, bold.Sprintf(summ.TableName), "", getDiffTypeString(summ.DiffType))
line = fmt.Sprintf("%s | Data changed: %t | Schema changed: %t", line, summ.HasDataChanges, summ.HasSchemaChanges)
cli.Println(line)
return nil
}
func getDiffTypeString(t string) string {
if t == "added" {
return color.HiGreenString("added ")
}
if t == "dropped" {
return color.HiRedString("dropped ")
}
return color.HiYellowString("modified")
}
func printDiffStat(ctx context.Context, td diff.TableDelta, oldColLen, newColLen int) errhand.VerboseError {
// todo: use errgroup.Group
ae := atomicerr.New()

View File

@@ -315,7 +315,7 @@ func logCommits(ctx context.Context, dEnv *env.DoltEnv, opts *logOpts) int {
// Get all remote branches
remotes, err := dEnv.DoltDB.GetRemotesWithHashes(ctx)
if err != nil {
cli.PrintErrln(color.HiRedString("Fatal error: cannot get Remotes information."))
cli.PrintErrln(color.RedString("Fatal error: cannot get Remotes information."))
return 1
}
for _, r := range remotes {

View File

@@ -18,6 +18,7 @@ import (
"context"
"fmt"
"sort"
"strings"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
"github.com/dolthub/dolt/go/libraries/utils/set"
@@ -57,6 +58,13 @@ type TableDelta struct {
FromFksParentSch map[string]schema.Schema
}
type TableDeltaSummary struct {
DiffType string
HasDataChanges bool
HasSchemaChanges bool
TableName string
}
// GetStagedUnstagedTableDeltas represents staged and unstaged changes as TableDelta slices.
func GetStagedUnstagedTableDeltas(ctx context.Context, roots doltdb.Roots) (staged, unstaged []TableDelta, err error) {
staged, err = GetTableDeltas(ctx, roots.Head, roots.Staged)
@@ -280,6 +288,16 @@ func (td TableDelta) IsRename() bool {
return td.FromName != td.ToName
}
func (td TableDelta) Type() string {
if td.IsAdd() {
return "added"
}
if td.IsDrop() {
return "dropped"
}
return "modified"
}
// HasHashChanged returns true if the hash of the table content has changed between
// the fromRoot and toRoot.
func (td TableDelta) HasHashChanged() (bool, error) {
@@ -387,6 +405,20 @@ func (td TableDelta) IsKeyless(ctx context.Context) (bool, error) {
}
}
// GetSummary returns a summary of the table delta.
func (td TableDelta) GetSummary(ctx context.Context, dataChanged bool) (*TableDeltaSummary, error) {
schemaChanged, err := td.HasSchemaChanged(ctx)
if err != nil {
return nil, err
}
return &TableDeltaSummary{
HasSchemaChanges: schemaChanged,
HasDataChanges: dataChanged,
DiffType: td.Type(),
TableName: td.CurName(),
}, nil
}
// GetRowData returns the table's row data at the fromRoot and toRoot, or an empty map if the table did not exist.
func (td TableDelta) GetRowData(ctx context.Context) (from, to durable.Index, err error) {
if td.FromTable == nil && td.ToTable == nil {
@@ -435,3 +467,20 @@ func fkSlicesAreEqual(from, to []doltdb.ForeignKey) bool {
}
return true
}
func getColumnNamesString(fromSch, toSch schema.Schema) string {
var cols []string
if fromSch != nil {
fromSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
cols = append(cols, fmt.Sprintf("`from_%s`", col.Name))
return false, nil
})
}
if toSch != nil {
toSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
cols = append(cols, fmt.Sprintf("`to_%s`", col.Name))
return false, nil
})
}
return strings.Join(cols, ",")
}