diff --git a/go/cmd/dolt/commands/diff.go b/go/cmd/dolt/commands/diff.go index e07e93238e..c051e5c07d 100644 --- a/go/cmd/dolt/commands/diff.go +++ b/go/cmd/dolt/commands/diff.go @@ -131,7 +131,7 @@ func (cmd DiffCmd) ArgParser() *argparser.ArgParser { ap.SupportsString(whereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.") ap.SupportsInt(limitParam, "", "record_count", "limits to the first N diffs.") ap.SupportsFlag(CachedFlag, "c", "Show only the unstaged data changes.") - ap.SupportsFlag(SkinnyFlag, "sk", "Shows only the primary key and the rows that changed") + ap.SupportsFlag(SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.") return ap } @@ -558,6 +558,7 @@ func diffRows( if err != nil { return errhand.VerboseErrorFromError(err) } + defer rowWriter.Close(ctx) // can't diff if !diffable { @@ -619,7 +620,7 @@ func diffRows( defer rowIter.Close(sqlCtx) if dArgs.skinny { - oldRows, newRows, modifiedColNames, err := getDiffRows(sqlCtx, sch, rowIter, unionSch, dArgs.skinny) + modifiedColNames, err := getModifiedCols(sqlCtx, rowIter, unionSch, sch) if err != nil { return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build() } @@ -639,21 +640,22 @@ func diffRows( if err != nil { return errhand.VerboseErrorFromError(err) } + defer rowWriter.Close(ctx) - writeFilteredResults(sqlCtx, oldRows, newRows, rowWriter) - if err != nil { - return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build() - } - } else { - oldRows, newRows, _, err := getDiffRows(sqlCtx, sch, rowIter, unionSch, dArgs.skinny) - if err != nil { + // reset the row iterator + rowIter.Close(sqlCtx) + _, rowIter, err = se.Query(sqlCtx, query) + if sql.ErrSyntaxError.Is(err) { + return errhand.BuildDError("Failed to parse diff query. Invalid where clause?\nDiff query: %s", query).AddCause(err).Build() + } else if err != nil { return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build() } + defer rowIter.Close(sqlCtx) + } - writeFilteredResults(sqlCtx, oldRows, newRows, rowWriter) - if err != nil { - return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build() - } + err = writeDiffResults(sqlCtx, sch, unionSch, rowIter, rowWriter, dArgs.skinny) + if err != nil { + return errhand.BuildDError("Error running diff query:\n%s", query).AddCause(err).Build() } return nil @@ -689,38 +691,36 @@ func getColumnNamesString(fromSch, toSch schema.Schema) string { return strings.Join(cols, ",") } -func getDiffRows( - sqlCtx *sql.Context, +func writeDiffResults( + ctx *sql.Context, diffQuerySch sql.Schema, - rowIter sql.RowIter, - unionSch sql.Schema, + targetSch sql.Schema, + iter sql.RowIter, + writer diff.SqlRowDiffWriter, filterChangedCols bool, -) ([]rowDiff, []rowDiff, map[string]bool, error) { - oldRows, newRows := []rowDiff{}, []rowDiff{} - modifiedColNames := make(map[string]bool) - - ds, err := newDiffSplitter(diffQuerySch, unionSch) +) error { + ds, err := newDiffSplitter(diffQuerySch, targetSch) if err != nil { - return oldRows, newRows, modifiedColNames, err + return err } for { - r, err := rowIter.Next(sqlCtx) + r, err := iter.Next(ctx) if err == io.EOF { - break + return nil + } else if err != nil { + return err } oldRow, newRow, err := ds.splitDiffResultRow(r) if err != nil { - return oldRows, newRows, modifiedColNames, err + return err } if filterChangedCols { var filteredOldRow, filteredNewRow rowDiff for i, changeType := range oldRow.colDiffs { - if changeType != diff.None || unionSch[i].PrimaryKey { - modifiedColNames[unionSch[i].Name] = true - + if changeType != diff.None || targetSch[i].PrimaryKey { filteredOldRow.row = append(filteredOldRow.row, oldRow.row[i]) filteredOldRow.colDiffs = append(filteredOldRow.colDiffs, oldRow.colDiffs[i]) filteredOldRow.rowDiff = oldRow.rowDiff @@ -731,26 +731,9 @@ func getDiffRows( } } - oldRows = append(oldRows, filteredOldRow) - newRows = append(newRows, filteredNewRow) - } else { - oldRows = append(oldRows, oldRow) - newRows = append(newRows, newRow) + oldRow = filteredOldRow + newRow = filteredNewRow } - } - - return oldRows, newRows, modifiedColNames, nil -} - -func writeFilteredResults( - ctx *sql.Context, - oldRows []rowDiff, - newRows []rowDiff, - writer diff.SqlRowDiffWriter, -) error { - for i := range oldRows { - oldRow := oldRows[i] - newRow := newRows[i] if oldRow.row != nil { err := writer.WriteRow(ctx, oldRow.row, oldRow.rowDiff, oldRow.colDiffs) @@ -766,6 +749,37 @@ func writeFilteredResults( } } } - - return writer.Close(ctx) +} + +func getModifiedCols( + ctx *sql.Context, + iter sql.RowIter, + unionSch sql.Schema, + diffQuerySch sql.Schema, +) (map[string]bool, error) { + modifiedColNames := make(map[string]bool) + for { + r, err := iter.Next(ctx) + if err == io.EOF { + break + } + + ds, err := newDiffSplitter(diffQuerySch, unionSch) + if err != nil { + return modifiedColNames, err + } + + oldRow, _, err := ds.splitDiffResultRow(r) + if err != nil { + return modifiedColNames, err + } + + for i, changeType := range oldRow.colDiffs { + if changeType != diff.None || unionSch[i].PrimaryKey { + modifiedColNames[unionSch[i].Name] = true + } + } + } + + return modifiedColNames, nil }