Address reviewer comments

This commit is contained in:
Tan Yong Zhi
2022-09-09 21:21:44 +08:00
parent f8ff89b471
commit 74edb799fa

View File

@@ -131,7 +131,7 @@ func (cmd DiffCmd) ArgParser() *argparser.ArgParser {
ap.SupportsString(whereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.")
ap.SupportsInt(limitParam, "", "record_count", "limits to the first N diffs.")
ap.SupportsFlag(CachedFlag, "c", "Show only the unstaged data changes.")
ap.SupportsFlag(SkinnyFlag, "sk", "Shows only the primary key and the rows that changed")
ap.SupportsFlag(SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.")
return ap
}
@@ -558,6 +558,7 @@ func diffRows(
if err != nil {
return errhand.VerboseErrorFromError(err)
}
defer rowWriter.Close(ctx)
// can't diff
if !diffable {
@@ -619,7 +620,7 @@ func diffRows(
defer rowIter.Close(sqlCtx)
if dArgs.skinny {
oldRows, newRows, modifiedColNames, err := getDiffRows(sqlCtx, sch, rowIter, unionSch, dArgs.skinny)
modifiedColNames, err := getModifiedCols(sqlCtx, rowIter, unionSch, sch)
if err != nil {
return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build()
}
@@ -639,21 +640,22 @@ func diffRows(
if err != nil {
return errhand.VerboseErrorFromError(err)
}
defer rowWriter.Close(ctx)
writeFilteredResults(sqlCtx, oldRows, newRows, rowWriter)
if err != nil {
return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build()
}
} else {
oldRows, newRows, _, err := getDiffRows(sqlCtx, sch, rowIter, unionSch, dArgs.skinny)
if err != nil {
// reset the row iterator
rowIter.Close(sqlCtx)
_, rowIter, err = se.Query(sqlCtx, query)
if sql.ErrSyntaxError.Is(err) {
return errhand.BuildDError("Failed to parse diff query. Invalid where clause?\nDiff query: %s", query).AddCause(err).Build()
} else if err != nil {
return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build()
}
defer rowIter.Close(sqlCtx)
}
writeFilteredResults(sqlCtx, oldRows, newRows, rowWriter)
if err != nil {
return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build()
}
err = writeDiffResults(sqlCtx, sch, unionSch, rowIter, rowWriter, dArgs.skinny)
if err != nil {
return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build()
}
return nil
@@ -689,38 +691,36 @@ func getColumnNamesString(fromSch, toSch schema.Schema) string {
return strings.Join(cols, ",")
}
func getDiffRows(
sqlCtx *sql.Context,
func writeDiffResults(
ctx *sql.Context,
diffQuerySch sql.Schema,
rowIter sql.RowIter,
unionSch sql.Schema,
targetSch sql.Schema,
iter sql.RowIter,
writer diff.SqlRowDiffWriter,
filterChangedCols bool,
) ([]rowDiff, []rowDiff, map[string]bool, error) {
oldRows, newRows := []rowDiff{}, []rowDiff{}
modifiedColNames := make(map[string]bool)
ds, err := newDiffSplitter(diffQuerySch, unionSch)
) error {
ds, err := newDiffSplitter(diffQuerySch, targetSch)
if err != nil {
return oldRows, newRows, modifiedColNames, err
return err
}
for {
r, err := rowIter.Next(sqlCtx)
r, err := iter.Next(ctx)
if err == io.EOF {
break
return nil
} else if err != nil {
return err
}
oldRow, newRow, err := ds.splitDiffResultRow(r)
if err != nil {
return oldRows, newRows, modifiedColNames, err
return err
}
if filterChangedCols {
var filteredOldRow, filteredNewRow rowDiff
for i, changeType := range oldRow.colDiffs {
if changeType != diff.None || unionSch[i].PrimaryKey {
modifiedColNames[unionSch[i].Name] = true
if changeType != diff.None || targetSch[i].PrimaryKey {
filteredOldRow.row = append(filteredOldRow.row, oldRow.row[i])
filteredOldRow.colDiffs = append(filteredOldRow.colDiffs, oldRow.colDiffs[i])
filteredOldRow.rowDiff = oldRow.rowDiff
@@ -731,26 +731,9 @@ func getDiffRows(
}
}
oldRows = append(oldRows, filteredOldRow)
newRows = append(newRows, filteredNewRow)
} else {
oldRows = append(oldRows, oldRow)
newRows = append(newRows, newRow)
oldRow = filteredOldRow
newRow = filteredNewRow
}
}
return oldRows, newRows, modifiedColNames, nil
}
func writeFilteredResults(
ctx *sql.Context,
oldRows []rowDiff,
newRows []rowDiff,
writer diff.SqlRowDiffWriter,
) error {
for i := range oldRows {
oldRow := oldRows[i]
newRow := newRows[i]
if oldRow.row != nil {
err := writer.WriteRow(ctx, oldRow.row, oldRow.rowDiff, oldRow.colDiffs)
@@ -766,6 +749,37 @@ func writeFilteredResults(
}
}
}
return writer.Close(ctx)
}
func getModifiedCols(
ctx *sql.Context,
iter sql.RowIter,
unionSch sql.Schema,
diffQuerySch sql.Schema,
) (map[string]bool, error) {
modifiedColNames := make(map[string]bool)
for {
r, err := iter.Next(ctx)
if err == io.EOF {
break
}
ds, err := newDiffSplitter(diffQuerySch, unionSch)
if err != nil {
return modifiedColNames, err
}
oldRow, _, err := ds.splitDiffResultRow(r)
if err != nil {
return modifiedColNames, err
}
for i, changeType := range oldRow.colDiffs {
if changeType != diff.None || unionSch[i].PrimaryKey {
modifiedColNames[unionSch[i].Name] = true
}
}
}
return modifiedColNames, nil
}