diff --git a/go/cmd/dolt/cli/arg_parser_helpers.go b/go/cmd/dolt/cli/arg_parser_helpers.go index 0aab39c0e7..2bf10b18bd 100644 --- a/go/cmd/dolt/cli/arg_parser_helpers.go +++ b/go/cmd/dolt/cli/arg_parser_helpers.go @@ -311,6 +311,7 @@ func CreateDiffArgParser(isTableFunction bool) *argparser.ArgParser { ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.") ap.SupportsString(WhereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.") ap.SupportsInt(LimitParam, "", "record_count", "limits to the first N diffs.") + ap.SupportsString(FilterParam, "", "diff_type", "filters results based on the type of change (added, modified, renamed, dropped). 'removed' is accepted as an alias for 'dropped'.") ap.SupportsFlag(StagedFlag, "", "Show only the staged data changes.") ap.SupportsFlag(CachedFlag, "c", "Synonym for --staged") ap.SupportsFlag(MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit") diff --git a/go/cmd/dolt/cli/flags.go b/go/cmd/dolt/cli/flags.go index 691ddbd973..62a2129c5e 100644 --- a/go/cmd/dolt/cli/flags.go +++ b/go/cmd/dolt/cli/flags.go @@ -99,6 +99,7 @@ const ( SummaryFlag = "summary" WhereParam = "where" LimitParam = "limit" + FilterParam = "filter" MergeBase = "merge-base" DiffMode = "diff-mode" ReverseFlag = "reverse" diff --git a/go/cmd/dolt/commands/diff.go b/go/cmd/dolt/commands/diff.go index d35a2756b4..76b5b1b688 100644 --- a/go/cmd/dolt/commands/diff.go +++ b/go/cmd/dolt/commands/diff.go @@ -86,6 +86,8 @@ The diffs displayed can be limited to show the first N by providing the paramete To filter which data rows are displayed, use {{.EmphasisLeft}}--where {{.EmphasisRight}}. 
Table column names in the filter expression must be prefixed with {{.EmphasisLeft}}from_{{.EmphasisRight}} or {{.EmphasisLeft}}to_{{.EmphasisRight}}, e.g. {{.EmphasisLeft}}to_COLUMN_NAME > 100{{.EmphasisRight}} or {{.EmphasisLeft}}from_COLUMN_NAME + to_COLUMN_NAME = 0{{.EmphasisRight}}. +To filter diff output by change type, use {{.EmphasisLeft}}--filter diff_type{{.EmphasisRight}} where {{.EmphasisLeft}}diff_type{{.EmphasisRight}} is one of {{.EmphasisLeft}}added{{.EmphasisRight}}, {{.EmphasisLeft}}modified{{.EmphasisRight}}, {{.EmphasisLeft}}renamed{{.EmphasisRight}}, or {{.EmphasisLeft}}dropped{{.EmphasisRight}}. The {{.EmphasisLeft}}added{{.EmphasisRight}} filter shows only additions (new tables or rows), {{.EmphasisLeft}}modified{{.EmphasisRight}} shows only schema modifications or row updates, {{.EmphasisLeft}}renamed{{.EmphasisRight}} shows only renamed tables, and {{.EmphasisLeft}}dropped{{.EmphasisRight}} shows only deletions (dropped tables or deleted rows). You can also use {{.EmphasisLeft}}removed{{.EmphasisRight}} as an alias for {{.EmphasisLeft}}dropped{{.EmphasisRight}}. For example, {{.EmphasisLeft}}dolt diff --filter=dropped{{.EmphasisRight}} shows only deleted rows and dropped tables. + The {{.EmphasisLeft}}--diff-mode{{.EmphasisRight}} argument controls how modified rows are presented when the format output is set to {{.EmphasisLeft}}tabular{{.EmphasisRight}}. When set to {{.EmphasisLeft}}row{{.EmphasisRight}}, modified rows are presented as old and new rows. When set to {{.EmphasisLeft}}line{{.EmphasisRight}}, modified rows are presented as a single row, and changes are presented using "+" and "-" within the column. When set to {{.EmphasisLeft}}in-place{{.EmphasisRight}}, modified rows are presented as a single row, and changes are presented side-by-side with a color distinction (requires a color-enabled terminal). 
When set to {{.EmphasisLeft}}context{{.EmphasisRight}}, rows that contain at least one column that spans multiple lines uses {{.EmphasisLeft}}line{{.EmphasisRight}}, while all other rows use {{.EmphasisLeft}}row{{.EmphasisRight}}. The default value is {{.EmphasisLeft}}context{{.EmphasisRight}}. `, Synopsis: []string{ @@ -102,6 +104,7 @@ type diffDisplaySettings struct { where string skinny bool includeCols []string + filter *diffTypeFilter } type diffDatasets struct { @@ -130,6 +133,141 @@ type diffStatistics struct { NewCellCount uint64 } +// diffTypeFilter manages which diff types should be included in the output. +// When filters is nil or empty, all types are included. +type diffTypeFilter struct { + // Map of diff type -> should include + // If nil or empty, includes all types + filters map[string]bool +} + +// newDiffTypeFilter creates a filter for the specified diff type. +// Pass diff.DiffTypeAll or empty string to include all types. +// Accepts "removed" as an alias for "dropped" for user convenience. +func newDiffTypeFilter(filterType string) *diffTypeFilter { + if filterType == "" || filterType == diff.DiffTypeAll { + return &diffTypeFilter{filters: nil} // nil means include all + } + + // Map "removed" to "dropped" (alias for user convenience) + internalFilterType := filterType + if filterType == "removed" { + internalFilterType = diff.DiffTypeDropped + } + + return &diffTypeFilter{ + filters: map[string]bool{ + internalFilterType: true, + }, + } +} + +// shouldInclude checks if the given diff type should be included. +// Uses TableDeltaSummary.DiffType field for table-level filtering. 
+func (df *diffTypeFilter) shouldInclude(diffType string) bool { + // nil or empty filters means include everything + if df.filters == nil || len(df.filters) == 0 { + return true + } + + return df.filters[diffType] +} + +// isValid validates the filter configuration +func (df *diffTypeFilter) isValid() bool { + if df.filters == nil { + return true + } + + for filterType := range df.filters { + if filterType != diff.DiffTypeAdded && + filterType != diff.DiffTypeModified && + filterType != diff.DiffTypeRenamed && + filterType != diff.DiffTypeDropped { + return false + } + } + return true +} + +// shouldSkipRow checks if a row should be skipped based on the filter settings. +// Uses the DiffType infrastructure for consistency with table-level filtering. +func shouldSkipRow(filter *diffTypeFilter, rowChangeType diff.ChangeType) bool { + if filter == nil { + return false + } + + // Don't filter None - it represents "no row" on one side of the diff + if rowChangeType == diff.None { + return false + } + + // Convert row-level ChangeType to table-level DiffType string + diffType := diff.ChangeTypeToDiffType(rowChangeType) + + // Use the map-based shouldInclude method + return !filter.shouldInclude(diffType) +} + +// shouldUseLazyHeader determines if we should delay printing the table header +// until we know there are rows to display. This prevents empty headers when +// all rows are filtered out in data-only diffs. +func shouldUseLazyHeader(dArgs *diffArgs, tableSummary diff.TableDeltaSummary) bool { + return dArgs.filter != nil && dArgs.filter.filters != nil && + !tableSummary.SchemaChange && !tableSummary.IsRename() +} + +// lazyRowWriter wraps a SqlRowDiffWriter and delays calling BeginTable +// until the first row is actually written. This prevents empty table headers +// when all rows are filtered out. 
+type lazyRowWriter struct { + writer diff.SqlRowDiffWriter + + // Callback to invoke before first write + // Set to nil after first call + onFirstWrite func() error +} + +// newLazyRowWriter creates a lazy writer that wraps the given writer. +// The onFirstWrite callback is invoked exactly once before the first write. +func newLazyRowWriter(writer diff.SqlRowDiffWriter, onFirstWrite func() error) *lazyRowWriter { + return &lazyRowWriter{ + writer: writer, + onFirstWrite: onFirstWrite, + } +} + +// WriteRow implements diff.SqlRowDiffWriter +func (l *lazyRowWriter) WriteRow(ctx *sql.Context, row sql.Row, diffType diff.ChangeType, colDiffTypes []diff.ChangeType) error { + // Initialize on first write + if l.onFirstWrite != nil { + if err := l.onFirstWrite(); err != nil { + return err + } + l.onFirstWrite = nil // Prevent double-initialization + } + + return l.writer.WriteRow(ctx, row, diffType, colDiffTypes) +} + +// WriteCombinedRow implements diff.SqlRowDiffWriter +func (l *lazyRowWriter) WriteCombinedRow(ctx *sql.Context, oldRow, newRow sql.Row, mode diff.Mode) error { + // Initialize on first write + if l.onFirstWrite != nil { + if err := l.onFirstWrite(); err != nil { + return err + } + l.onFirstWrite = nil + } + + return l.writer.WriteCombinedRow(ctx, oldRow, newRow, mode) +} + +// Close implements diff.SqlRowDiffWriter +func (l *lazyRowWriter) Close(ctx context.Context) error { + return l.writer.Close(ctx) +} + type DiffCmd struct{} // Name is returns the name of the Dolt cli command. This is what is used on the command line to invoke the command @@ -220,6 +358,15 @@ func (cmd DiffCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseE return errhand.BuildDError("invalid output format: %s", f).Build() } + filterValue, hasFilter := apr.GetValue(cli.FilterParam) + if hasFilter { + filter := newDiffTypeFilter(filterValue) + if !filter.isValid() { + return errhand.BuildDError("invalid filter: %s. 
Valid values are: %s, %s, %s, %s (or %s)", + filterValue, diff.DiffTypeAdded, diff.DiffTypeModified, diff.DiffTypeRenamed, diff.DiffTypeDropped, "removed").Build() + } + } + return nil } @@ -268,6 +415,9 @@ func parseDiffDisplaySettings(apr *argparser.ArgParseResults) *diffDisplaySettin displaySettings.limit, _ = apr.GetInt(cli.LimitParam) displaySettings.where = apr.GetValueOrDefault(cli.WhereParam, "") + filterValue := apr.GetValueOrDefault(cli.FilterParam, diff.DiffTypeAll) + displaySettings.filter = newDiffTypeFilter(filterValue) + return displaySettings } @@ -670,13 +820,13 @@ func getSchemaDiffSummariesBetweenRefs(queryist cli.Queryist, sqlCtx *sql.Contex tableName = fromTable } case fromTable == "": - diffType = "added" + diffType = diff.DiffTypeAdded tableName = toTable case toTable == "": - diffType = "dropped" + diffType = diff.DiffTypeDropped tableName = fromTable case fromTable != "" && toTable != "" && fromTable != toTable: - diffType = "renamed" + diffType = diff.DiffTypeRenamed tableName = toTable default: return nil, fmt.Errorf("error: unexpected schema diff case: fromTable='%s', toTable='%s'", fromTable, toTable) @@ -738,13 +888,13 @@ func getDiffSummariesBetweenRefs(queryist cli.Queryist, sqlCtx *sql.Context, fro } switch summary.DiffType { - case "dropped": + case diff.DiffTypeDropped: summary.TableName = summary.FromTableName - case "added": + case diff.DiffTypeAdded: summary.TableName = summary.ToTableName - case "renamed": + case diff.DiffTypeRenamed: summary.TableName = summary.ToTableName - case "modified": + case diff.DiffTypeModified: summary.TableName = summary.FromTableName default: return nil, fmt.Errorf("error: unexpected diff type '%s'", summary.DiffType) @@ -816,6 +966,16 @@ func diffUserTables(queryist cli.Queryist, sqlCtx *sql.Context, dArgs *diffArgs) continue } + // Apply table-level filtering based on diff type + if dArgs.filter != nil && dArgs.filter.filters != nil { + // For data-only changes (no schema/rename), always let 
them through for row-level filtering + isDataOnlyChange := !delta.SchemaChange && !delta.IsRename() && delta.DataChange + + if !isDataOnlyChange && !dArgs.filter.shouldInclude(delta.DiffType) { + continue // Skip this table + } + } + if strings.HasPrefix(delta.ToTableName.Name, diff.DBPrefix) { verr := diffDatabase(queryist, sqlCtx, delta, dArgs, dw) if verr != nil { @@ -1110,7 +1270,7 @@ func diffUserTable( fromTable := tableSummary.FromTableName toTable := tableSummary.ToTableName - if dArgs.diffParts&NameOnlyDiff == 0 { + if dArgs.diffParts&NameOnlyDiff == 0 && !shouldUseLazyHeader(dArgs, tableSummary) { // TODO: schema names err := dw.BeginTable(tableSummary.FromTableName.Name, tableSummary.ToTableName.Name, tableSummary.IsAdd(), tableSummary.IsDrop()) if err != nil { @@ -1446,11 +1606,27 @@ func diffRows( } // We always instantiate a RowWriter in case the diffWriter needs it to close off any work from schema output - rowWriter, err := dw.RowWriter(fromTableInfo, toTableInfo, tableSummary, unionSch) + var rowWriter diff.SqlRowDiffWriter + realWriter, err := dw.RowWriter(fromTableInfo, toTableInfo, tableSummary, unionSch) if err != nil { return errhand.VerboseErrorFromError(err) } + if shouldUseLazyHeader(dArgs, tableSummary) { + // Wrap with lazy writer to delay BeginTable until first row write + onFirstWrite := func() error { + return dw.BeginTable( + tableSummary.FromTableName.Name, + tableSummary.ToTableName.Name, + tableSummary.IsAdd(), + tableSummary.IsDrop(), + ) + } + rowWriter = newLazyRowWriter(realWriter, onFirstWrite) + } else { + rowWriter = realWriter + } + // can't diff if !diffable { // TODO: this messes up some structured output if the user didn't redirect it @@ -1708,6 +1884,13 @@ func writeDiffResults( return err } + // Apply row-level filtering based on diff type + if dArgs.filter != nil { + if shouldSkipRow(dArgs.filter, oldRow.RowDiff) || shouldSkipRow(dArgs.filter, newRow.RowDiff) { + continue + } + } + if dArgs.skinny { var 
filteredOldRow, filteredNewRow diff.RowDiff for i, changeType := range newRow.ColDiffs { diff --git a/go/cmd/dolt/commands/diff_filter_test.go b/go/cmd/dolt/commands/diff_filter_test.go new file mode 100644 index 0000000000..f3d705f2f0 --- /dev/null +++ b/go/cmd/dolt/commands/diff_filter_test.go @@ -0,0 +1,548 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package commands + +import ( + "context" + "strings" + "testing" + + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/dolt/go/libraries/doltcore/diff" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" +) + +func TestDiffTypeFilter_IsValid(t *testing.T) { + tests := []struct { + name string + filterBy string + want bool + }{ + {"valid: added", diff.DiffTypeAdded, true}, + {"valid: modified", diff.DiffTypeModified, true}, + {"valid: removed", diff.DiffTypeDropped, true}, + {"valid: all", diff.DiffTypeAll, true}, + {"invalid: empty string with nil filter", "", true}, // nil filter is valid + {"invalid: random string", "invalid", false}, + {"invalid: uppercase", "ADDED", false}, + {"invalid: typo addedd", "addedd", false}, + {"invalid: plural adds", "adds", false}, + {"invalid: typo modifiedd", "modifiedd", false}, + {"invalid: typo removedd", "removedd", false}, + {"invalid: insert instead of added", "insert", false}, + {"invalid: update instead of modified", "update", false}, + {"invalid: delete instead of removed", "delete", false}, + } + + for 
_, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + df := newDiffTypeFilter(tt.filterBy) + got := df.isValid() + if got != tt.want { + t.Errorf("isValid() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDiffTypeFilter_ShouldInclude(t *testing.T) { + tests := []struct { + name string + filterType string + checkType string + want bool + }{ + // Testing with filter=added + {"filter=added, check added", diff.DiffTypeAdded, diff.DiffTypeAdded, true}, + {"filter=added, check modified", diff.DiffTypeAdded, diff.DiffTypeModified, false}, + {"filter=added, check removed", diff.DiffTypeAdded, diff.DiffTypeDropped, false}, + + // Testing with filter=modified + {"filter=modified, check added", diff.DiffTypeModified, diff.DiffTypeAdded, false}, + {"filter=modified, check modified", diff.DiffTypeModified, diff.DiffTypeModified, true}, + {"filter=modified, check removed", diff.DiffTypeModified, diff.DiffTypeDropped, false}, + + // Testing with filter=dropped + {"filter=dropped, check added", diff.DiffTypeDropped, diff.DiffTypeAdded, false}, + {"filter=dropped, check modified", diff.DiffTypeDropped, diff.DiffTypeModified, false}, + {"filter=dropped, check dropped", diff.DiffTypeDropped, diff.DiffTypeDropped, true}, + {"filter=dropped, check renamed", diff.DiffTypeDropped, diff.DiffTypeRenamed, false}, + + // Testing with filter=renamed + {"filter=renamed, check added", diff.DiffTypeRenamed, diff.DiffTypeAdded, false}, + {"filter=renamed, check modified", diff.DiffTypeRenamed, diff.DiffTypeModified, false}, + {"filter=renamed, check dropped", diff.DiffTypeRenamed, diff.DiffTypeDropped, false}, + {"filter=renamed, check renamed", diff.DiffTypeRenamed, diff.DiffTypeRenamed, true}, + + // Testing with "removed" alias (should map to dropped) + {"filter=removed (alias), check dropped", "removed", diff.DiffTypeDropped, true}, + {"filter=removed (alias), check added", "removed", diff.DiffTypeAdded, false}, + {"filter=removed (alias), check renamed", "removed", 
diff.DiffTypeRenamed, false}, + + // Testing with filter=all + {"filter=all, check added", diff.DiffTypeAll, diff.DiffTypeAdded, true}, + {"filter=all, check modified", diff.DiffTypeAll, diff.DiffTypeModified, true}, + {"filter=all, check removed", diff.DiffTypeAll, diff.DiffTypeDropped, true}, + + // Testing with empty filter (nil filters map) + {"filter=empty, check added", "", diff.DiffTypeAdded, true}, + {"filter=empty, check modified", "", diff.DiffTypeModified, true}, + {"filter=empty, check removed", "", diff.DiffTypeDropped, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + df := newDiffTypeFilter(tt.filterType) + got := df.shouldInclude(tt.checkType) + if got != tt.want { + t.Errorf("shouldInclude(%s) = %v, want %v", tt.checkType, got, tt.want) + } + }) + } +} + +func TestDiffTypeFilter_ConsistencyAcrossMethods(t *testing.T) { + // Test that filter=all returns true for all diff types + t.Run("filter=all returns true for all types", func(t *testing.T) { + df := newDiffTypeFilter(diff.DiffTypeAll) + + if !df.shouldInclude(diff.DiffTypeAdded) { + t.Error("filter=all should include added") + } + if !df.shouldInclude(diff.DiffTypeDropped) { + t.Error("filter=all should include removed") + } + if !df.shouldInclude(diff.DiffTypeModified) { + t.Error("filter=all should include modified") + } + }) + + // Test that each specific filter only returns true for its type + t.Run("filter=added only includes added", func(t *testing.T) { + df := newDiffTypeFilter(diff.DiffTypeAdded) + + if !df.shouldInclude(diff.DiffTypeAdded) { + t.Error("filter=added should include added") + } + if df.shouldInclude(diff.DiffTypeDropped) { + t.Error("filter=added should not include removed") + } + if df.shouldInclude(diff.DiffTypeModified) { + t.Error("filter=added should not include modified") + } + }) + + t.Run("filter=dropped only includes removed", func(t *testing.T) { + df := newDiffTypeFilter(diff.DiffTypeDropped) + + if 
df.shouldInclude(diff.DiffTypeAdded) { + t.Error("filter=dropped should not include added") + } + if !df.shouldInclude(diff.DiffTypeDropped) { + t.Error("filter=dropped should include removed") + } + if df.shouldInclude(diff.DiffTypeModified) { + t.Error("filter=dropped should not include modified") + } + }) + + t.Run("filter=modified only includes modified", func(t *testing.T) { + df := newDiffTypeFilter(diff.DiffTypeModified) + + if df.shouldInclude(diff.DiffTypeAdded) { + t.Error("filter=modified should not include added") + } + if df.shouldInclude(diff.DiffTypeDropped) { + t.Error("filter=modified should not include removed") + } + if !df.shouldInclude(diff.DiffTypeModified) { + t.Error("filter=modified should include modified") + } + }) +} + +func TestDiffTypeFilter_InvalidFilterBehavior(t *testing.T) { + // Test that invalid filters return false for isValid + invalidFilters := []string{"invalid", "ADDED", "addedd", "delete"} + + for _, filterValue := range invalidFilters { + t.Run("invalid filter: "+filterValue, func(t *testing.T) { + df := newDiffTypeFilter(filterValue) + + if df.isValid() { + t.Errorf("Filter %s should be invalid", filterValue) + } + }) + } +} + +func TestFilterConstants(t *testing.T) { + // Test that filter constants have expected values + tests := []struct { + name string + constant string + expected string + }{ + {"DiffTypeAdded value", diff.DiffTypeAdded, "added"}, + {"DiffTypeModified value", diff.DiffTypeModified, "modified"}, + {"DiffTypeDropped value", diff.DiffTypeDropped, "dropped"}, + {"DiffTypeAll value", diff.DiffTypeAll, "all"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.constant != tt.expected { + t.Errorf("Expected %s = %s, got %s", tt.name, tt.expected, tt.constant) + } + }) + } +} + +func TestFilterConstants_AreUnique(t *testing.T) { + // Test that all filter constants are unique + constants := []string{diff.DiffTypeAdded, diff.DiffTypeModified, diff.DiffTypeDropped, diff.DiffTypeAll} 
+ seen := make(map[string]bool) + + for _, c := range constants { + if seen[c] { + t.Errorf("Duplicate filter constant value: %s", c) + } + seen[c] = true + } + + if len(seen) != 4 { + t.Errorf("Expected 4 unique filter constants, got %d", len(seen)) + } +} + +func TestFilterConstants_AreLowercase(t *testing.T) { + // Test that filter constants are lowercase (convention) + constants := []string{diff.DiffTypeAdded, diff.DiffTypeModified, diff.DiffTypeDropped, diff.DiffTypeAll} + + for _, c := range constants { + if c != strings.ToLower(c) { + t.Errorf("Filter constant %s should be lowercase", c) + } + } +} + +func TestShouldUseLazyHeader(t *testing.T) { + tests := []struct { + name string + filterType string + schemaChange bool + isRename bool + expectedResult bool + }{ + { + name: "use lazy: filter active, data-only change", + filterType: diff.DiffTypeAdded, + schemaChange: false, + isRename: false, + expectedResult: true, + }, + { + name: "don't use lazy: no filter", + filterType: "", + schemaChange: false, + isRename: false, + expectedResult: false, + }, + { + name: "don't use lazy: filter is all", + filterType: diff.DiffTypeAll, + schemaChange: false, + isRename: false, + expectedResult: false, + }, + { + name: "don't use lazy: schema changed", + filterType: diff.DiffTypeModified, + schemaChange: true, + isRename: false, + expectedResult: false, + }, + { + name: "don't use lazy: table renamed", + filterType: diff.DiffTypeDropped, + schemaChange: false, + isRename: true, + expectedResult: false, + }, + { + name: "don't use lazy: schema changed AND renamed", + filterType: diff.DiffTypeAdded, + schemaChange: true, + isRename: true, + expectedResult: false, + }, + { + name: "use lazy: filter=modified, data-only", + filterType: diff.DiffTypeModified, + schemaChange: false, + isRename: false, + expectedResult: true, + }, + { + name: "use lazy: filter=dropped, data-only", + filterType: diff.DiffTypeDropped, + schemaChange: false, + isRename: false, + expectedResult: 
true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var filter *diffTypeFilter + if tt.filterType != "" { + filter = newDiffTypeFilter(tt.filterType) + } + + dArgs := &diffArgs{ + diffDisplaySettings: &diffDisplaySettings{ + filter: filter, + }, + } + tableSummary := diff.TableDeltaSummary{ + SchemaChange: tt.schemaChange, + } + // Create a mock rename by setting different from/to names + if tt.isRename { + tableSummary.FromTableName = doltdb.TableName{Name: "old_table"} + tableSummary.ToTableName = doltdb.TableName{Name: "new_table"} + } else { + tableSummary.FromTableName = doltdb.TableName{Name: "table"} + tableSummary.ToTableName = doltdb.TableName{Name: "table"} + } + + result := shouldUseLazyHeader(dArgs, tableSummary) + + if result != tt.expectedResult { + t.Errorf("%s: expected %v, got %v", tt.name, tt.expectedResult, result) + } + }) + } +} + +// mockDiffWriter is a test implementation of diffWriter +type mockDiffWriter struct { + beginTableCalled bool + beginTableError error +} + +func (m *mockDiffWriter) BeginTable(_ /* fromTableName */, _ /* toTableName */ string, _ /* isAdd */, _ /* isDrop */ bool) error { + m.beginTableCalled = true + return m.beginTableError +} + +func (m *mockDiffWriter) WriteTableSchemaDiff(_ /* fromTableInfo */, _ /* toTableInfo */ *diff.TableInfo, _ /* tds */ diff.TableDeltaSummary) error { + return nil +} + +func (m *mockDiffWriter) WriteEventDiff(_ /* ctx */ context.Context, _ /* eventName */, _ /* oldDefn */, _ /* newDefn */ string) error { + return nil +} + +func (m *mockDiffWriter) WriteTriggerDiff(_ /* ctx */ context.Context, _ /* triggerName */, _ /* oldDefn */, _ /* newDefn */ string) error { + return nil +} + +func (m *mockDiffWriter) WriteViewDiff(_ /* ctx */ context.Context, _ /* viewName */, _ /* oldDefn */, _ /* newDefn */ string) error { + return nil +} + +func (m *mockDiffWriter) WriteTableDiffStats(_ /* diffStats */ []diffStatistics, _ /* oldColLen */, _ /* newColLen */ int, _ /* 
areTablesKeyless */ bool) error { + return nil +} + +func (m *mockDiffWriter) RowWriter(_ /* fromTableInfo */, _ /* toTableInfo */ *diff.TableInfo, _ /* tds */ diff.TableDeltaSummary, _ /* unionSch */ sql.Schema) (diff.SqlRowDiffWriter, error) { + return &mockRowWriter{}, nil +} + +func (m *mockDiffWriter) Close(_ /* ctx */ context.Context) error { + return nil +} + +// mockRowWriter is a test implementation of SqlRowDiffWriter +type mockRowWriter struct { + writeCalled bool + closeCalled bool +} + +func (m *mockRowWriter) WriteRow(_ /* ctx */ *sql.Context, _ /* row */ sql.Row, _ /* diffType */ diff.ChangeType, _ /* colDiffTypes */ []diff.ChangeType) error { + m.writeCalled = true + return nil +} + +func (m *mockRowWriter) WriteCombinedRow(_ /* ctx */ *sql.Context, _ /* oldRow */, _ /* newRow */ sql.Row, _ /* mode */ diff.Mode) error { + m.writeCalled = true + return nil +} + +func (m *mockRowWriter) Close(_ /* ctx */ context.Context) error { + m.closeCalled = true + return nil +} + +func TestLazyRowWriter_NoRowsWritten(t *testing.T) { + mockDW := &mockDiffWriter{} + realWriter := &mockRowWriter{} + + beginTableCalled := false + onFirstWrite := func() error { + beginTableCalled = true + return mockDW.BeginTable("fromTable", "toTable", false, false) + } + + lazyWriter := newLazyRowWriter(realWriter, onFirstWrite) + + // Close without writing any rows + err := lazyWriter.Close(context.Background()) + if err != nil { + t.Fatalf("Close() returned error: %v", err) + } + + // BeginTable should NEVER have been called + if beginTableCalled { + t.Error("BeginTable() was called even though no rows were written - should have been lazy!") + } +} + +func TestLazyRowWriter_RowsWritten(t *testing.T) { + mockDW := &mockDiffWriter{} + realWriter := &mockRowWriter{} + + onFirstWrite := func() error { + return mockDW.BeginTable("fromTable", "toTable", false, false) + } + + lazyWriter := newLazyRowWriter(realWriter, onFirstWrite) + + // Write a row + ctx := sql.NewEmptyContext() + err 
:= lazyWriter.WriteRow(ctx, sql.Row{}, diff.Added, []diff.ChangeType{}) + if err != nil { + t.Fatalf("WriteRow() returned error: %v", err) + } + + // BeginTable should have been called on first write + if !mockDW.beginTableCalled { + t.Error("BeginTable() was NOT called after writing a row - should have been initialized!") + } + + // Close + err = lazyWriter.Close(context.Background()) + if err != nil { + t.Fatalf("Close() returned error: %v", err) + } +} + +func TestLazyRowWriter_CombinedRowsWritten(t *testing.T) { + mockDW := &mockDiffWriter{} + realWriter := &mockRowWriter{} + + onFirstWrite := func() error { + return mockDW.BeginTable("fromTable", "toTable", false, false) + } + + lazyWriter := newLazyRowWriter(realWriter, onFirstWrite) + + // Write a combined row + ctx := sql.NewEmptyContext() + err := lazyWriter.WriteCombinedRow(ctx, sql.Row{}, sql.Row{}, diff.ModeRow) + if err != nil { + t.Fatalf("WriteCombinedRow() returned error: %v", err) + } + + // BeginTable should have been called on first write + if !mockDW.beginTableCalled { + t.Error("BeginTable() was NOT called after writing combined row - should have been initialized!") + } +} + +func TestLazyRowWriter_InitializedOnlyOnce(t *testing.T) { + callCount := 0 + mockDW := &mockDiffWriter{} + realWriter := &mockRowWriter{} + + onFirstWrite := func() error { + callCount++ + return mockDW.BeginTable("fromTable", "toTable", false, false) + } + + lazyWriter := newLazyRowWriter(realWriter, onFirstWrite) + + ctx := sql.NewEmptyContext() + + // Write multiple rows + for i := 0; i < 5; i++ { + err := lazyWriter.WriteRow(ctx, sql.Row{}, diff.Added, []diff.ChangeType{}) + if err != nil { + t.Fatalf("WriteRow() %d returned error: %v", i, err) + } + } + + // BeginTable should have been called exactly ONCE (on first write only) + if callCount != 1 { + t.Errorf("BeginTable() called %d times, expected exactly 1", callCount) + } +} + +func TestShouldSkipRow(t *testing.T) { + tests := []struct { + name string + filterType 
string + rowChangeType diff.ChangeType + expectedResult bool + }{ + {"filter=added, row=Added", diff.DiffTypeAdded, diff.Added, false}, + {"filter=added, row=Dropped", diff.DiffTypeAdded, diff.Removed, true}, + {"filter=added, row=ModifiedOld", diff.DiffTypeAdded, diff.ModifiedOld, true}, + {"filter=added, row=ModifiedNew", diff.DiffTypeAdded, diff.ModifiedNew, true}, + + {"filter=dropped, row=Added", diff.DiffTypeDropped, diff.Added, true}, + {"filter=dropped, row=Dropped", diff.DiffTypeDropped, diff.Removed, false}, + {"filter=dropped, row=ModifiedOld", diff.DiffTypeDropped, diff.ModifiedOld, true}, + + {"filter=modified, row=Added", diff.DiffTypeModified, diff.Added, true}, + {"filter=modified, row=Dropped", diff.DiffTypeModified, diff.Removed, true}, + {"filter=modified, row=ModifiedOld", diff.DiffTypeModified, diff.ModifiedOld, false}, + {"filter=modified, row=ModifiedNew", diff.DiffTypeModified, diff.ModifiedNew, false}, + + {"filter=all, row=Added", diff.DiffTypeAll, diff.Added, false}, + {"filter=all, row=Dropped", diff.DiffTypeAll, diff.Removed, false}, + {"filter=all, row=ModifiedOld", diff.DiffTypeAll, diff.ModifiedOld, false}, + + {"nil filter, row=Added", "", diff.Added, false}, + {"nil filter, row=Dropped", "", diff.Removed, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var filter *diffTypeFilter + if tt.filterType != "" { + filter = newDiffTypeFilter(tt.filterType) + } + + result := shouldSkipRow(filter, tt.rowChangeType) + + if result != tt.expectedResult { + t.Errorf("expected %v, got %v", tt.expectedResult, result) + } + }) + } +} diff --git a/go/cmd/dolt/commands/merge.go b/go/cmd/dolt/commands/merge.go index 6d39ccad69..a2b0dcc73e 100644 --- a/go/cmd/dolt/commands/merge.go +++ b/go/cmd/dolt/commands/merge.go @@ -480,17 +480,17 @@ func calculateMergeStats(queryist cli.Queryist, sqlCtx *sql.Context, mergeStats if strings.HasPrefix(summary.TableName.Name, diff.DBPrefix) { continue } - if summary.DiffType == 
"added" { + if summary.DiffType == diff.DiffTypeAdded { allUnmodified = false mergeStats[summary.TableName.Name] = &merge.MergeStats{ Operation: merge.TableAdded, } - } else if summary.DiffType == "dropped" { + } else if summary.DiffType == diff.DiffTypeDropped { allUnmodified = false mergeStats[summary.TableName.Name] = &merge.MergeStats{ Operation: merge.TableRemoved, } - } else if summary.DiffType == "modified" || summary.DiffType == "renamed" { + } else if summary.DiffType == diff.DiffTypeModified || summary.DiffType == diff.DiffTypeRenamed { allUnmodified = false mergeStats[summary.TableName.Name] = &merge.MergeStats{ Operation: merge.TableModified, diff --git a/go/libraries/doltcore/diff/table_deltas.go b/go/libraries/doltcore/diff/table_deltas.go index 201b555011..1809297220 100644 --- a/go/libraries/doltcore/diff/table_deltas.go +++ b/go/libraries/doltcore/diff/table_deltas.go @@ -39,6 +39,17 @@ const ( RemovedTable ) +// Filter type constants for diff filtering. +// These correspond to the string values used in the --filter flag and +// are stored in TableDeltaSummary.DiffType field. +const ( + DiffTypeAdded = "added" + DiffTypeModified = "modified" + DiffTypeRenamed = "renamed" + DiffTypeDropped = "dropped" + DiffTypeAll = "all" +) + const DBPrefix = "__DATABASE__" type TableInfo struct { @@ -97,6 +108,22 @@ func (tds TableDeltaSummary) IsRename() bool { return tds.FromTableName != tds.ToTableName } +// ChangeTypeToDiffType converts a row-level ChangeType to a table-level DiffType string. +// This allows row-level filtering to use the same DiffType infrastructure as table-level filtering. 
+func ChangeTypeToDiffType(ct ChangeType) string { + switch ct { + case Added: + return DiffTypeAdded + case Removed: + return DiffTypeDropped + case ModifiedOld, ModifiedNew: + // Both ModifiedOld and ModifiedNew represent the same logical change: modified + return DiffTypeModified + default: + return "" + } +} + // GetStagedUnstagedTableDeltas represents staged and unstaged changes as TableDelta slices. func GetStagedUnstagedTableDeltas(ctx context.Context, roots doltdb.Roots) (staged, unstaged []TableDelta, err error) { staged, err = GetTableDeltas(ctx, roots.Head, roots.Staged) @@ -689,7 +716,7 @@ func (td TableDelta) GetSummary(ctx context.Context) (*TableDeltaSummary, error) FromTableName: td.FromName, DataChange: dataChange, SchemaChange: true, - DiffType: "dropped", + DiffType: DiffTypeDropped, }, nil } @@ -700,7 +727,7 @@ func (td TableDelta) GetSummary(ctx context.Context) (*TableDeltaSummary, error) ToTableName: td.ToName, DataChange: dataChange, SchemaChange: true, - DiffType: "added", + DiffType: DiffTypeAdded, }, nil } @@ -712,7 +739,7 @@ func (td TableDelta) GetSummary(ctx context.Context) (*TableDeltaSummary, error) ToTableName: td.ToName, DataChange: dataChange, SchemaChange: true, - DiffType: "renamed", + DiffType: DiffTypeRenamed, }, nil } @@ -727,7 +754,7 @@ func (td TableDelta) GetSummary(ctx context.Context) (*TableDeltaSummary, error) ToTableName: td.ToName, DataChange: dataChange, SchemaChange: schemaChange, - DiffType: "modified", + DiffType: DiffTypeModified, }, nil } diff --git a/go/libraries/doltcore/doltdb/system_table.go b/go/libraries/doltcore/doltdb/system_table.go index 60ac6325c9..9701487a00 100644 --- a/go/libraries/doltcore/doltdb/system_table.go +++ b/go/libraries/doltcore/doltdb/system_table.go @@ -155,12 +155,13 @@ func GeneratedSystemTableNames() []string { GetTableOfTablesWithViolationsName(), GetCommitsTableName(), GetCommitAncestorsTableName(), - GetStatusTableName(), GetRemotesTableName(), GetHelpTableName(), 
GetBackupsTableName(), GetStashesTableName(), GetBranchActivityTableName(), + // [dtables.StatusTable] now uses [adapters.DoltTableAdapterRegistry] in its constructor for Doltgres. + StatusTableName, } } @@ -367,11 +368,6 @@ var GetSchemaConflictsTableName = func() string { return SchemaConflictsTableName } -// GetStatusTableName returns the status system table name. -var GetStatusTableName = func() string { - return StatusTableName -} - // GetTagsTableName returns the tags table name var GetTagsTableName = func() string { return TagsTableName diff --git a/go/libraries/doltcore/sqle/adapters/table.go b/go/libraries/doltcore/sqle/adapters/table.go new file mode 100644 index 0000000000..de5fa3c209 --- /dev/null +++ b/go/libraries/doltcore/sqle/adapters/table.go @@ -0,0 +1,85 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapters + +import ( + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/env" +) + +// TableAdapter provides a hook for extensions to customize or wrap table implementations. For example, this allows +// libraries like Doltgres to intercept system table creation and apply type conversions, schema modifications, or other +// customizations without modifying the core Dolt implementation for their compatibility. +type TableAdapter interface { + // NewTable creates or wraps a system table. 
The function receives all necessary parameters to construct the table + // and can either build it from scratch or call the default Dolt constructor and wrap it. + NewTable(ctx *sql.Context, tableName string, dDb *doltdb.DoltDB, workingSet *doltdb.WorkingSet, rootsProvider env.RootsProvider[*sql.Context]) sql.Table + + // TableName returns the preferred name for the adapter's table. This allows extensions to rename tables while + // preserving the underlying implementation. For example, Doltgres uses "status" while Dolt uses "dolt_status", + // enabling cleaner Postgres-style naming. + TableName() string +} + +var DoltTableAdapterRegistry = newDoltTableAdapterRegistry() + +// doltTableAdapterRegistry is a Dolt table name to TableAdapter map. Integrators populate this registry during package +// initialization, and it's intended to be read-only thereafter. The registry links existing Dolt system tables to +// their integrator-provided adapters, and maps each integrator's internal aliases back to the Dolt table names used as +// keys. +type doltTableAdapterRegistry struct { + Adapters map[string]TableAdapter + internalAliases map[string]string +} + +// newDoltTableAdapterRegistry constructs a Dolt table adapter registry with empty internal alias and adapter maps. +func newDoltTableAdapterRegistry() *doltTableAdapterRegistry { + return &doltTableAdapterRegistry{ + Adapters: make(map[string]TableAdapter), + internalAliases: make(map[string]string), + } +} + +// AddAdapter maps |doltTableName| to an |adapter| in the Dolt table adapter registry, with optional |internalAliases|. +func (as *doltTableAdapterRegistry) AddAdapter(doltTableName string, adapter TableAdapter, internalAliases ...string) { + for _, alias := range internalAliases { + as.internalAliases[alias] = doltTableName + } + as.Adapters[doltTableName] = adapter +} + +// GetAdapter gets a Dolt TableAdapter mapped to |name|, which can be the Dolt table name or an internal alias.
+func (as *doltTableAdapterRegistry) GetAdapter(name string) (TableAdapter, bool) { + adapter, ok := as.Adapters[name] + if !ok { + name = as.internalAliases[name] + adapter, ok = as.Adapters[name] + } + + return adapter, ok +} + +// NormalizeName normalizes |name| if it's an internal alias of the underlying Dolt table name. If no match is found, +// |name| is returned as-is. +func (as *doltTableAdapterRegistry) NormalizeName(name string) string { + doltTableName, ok := as.internalAliases[name] + if !ok { + return name + } + + return doltTableName +} diff --git a/go/libraries/doltcore/sqle/adapters/table_test.go b/go/libraries/doltcore/sqle/adapters/table_test.go new file mode 100644 index 0000000000..1d63792c68 --- /dev/null +++ b/go/libraries/doltcore/sqle/adapters/table_test.go @@ -0,0 +1,64 @@ +package adapters + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/env" +) + +type mockAdapter struct { + name string +} + +func (m mockAdapter) NewTable(_ *sql.Context, _ string, _ *doltdb.DoltDB, _ *doltdb.WorkingSet, _ env.RootsProvider[*sql.Context]) sql.Table { + return nil +} + +func (m mockAdapter) TableName() string { + return m.name +} + +func TestDoltTableAdapterRegistry(t *testing.T) { + registry := newDoltTableAdapterRegistry() + + statusAdapter := mockAdapter{name: "status"} + logAdapter := mockAdapter{name: "log"} + + registry.AddAdapter(doltdb.StatusTableName, statusAdapter, "status") + registry.AddAdapter(doltdb.LogTableName, logAdapter, "log") + + t.Run("GetAdapter", func(t *testing.T) { + adapter, ok := registry.GetAdapter("dolt_status") + require.True(t, ok) + require.Equal(t, "status", adapter.TableName()) + + adapter, ok = registry.GetAdapter("status") + require.True(t, ok) + require.Equal(t, "status", adapter.TableName()) + + _, ok = registry.GetAdapter("unknown_alias") + 
require.False(t, ok) + + _, ok = registry.GetAdapter("dolt_unknown") + require.False(t, ok) + }) + + t.Run("NormalizeName", func(t *testing.T) { + normalized := registry.NormalizeName("status") + require.Equal(t, "dolt_status", normalized) + + normalized = registry.NormalizeName("log") + require.Equal(t, "dolt_log", normalized) + + normalized = registry.NormalizeName("dolt_status") + require.Equal(t, "dolt_status", normalized) + + normalized = registry.NormalizeName("unknown_table") + require.Equal(t, "unknown_table", normalized) + }) +} diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index 3f13f91e71..081933df61 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -44,6 +44,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/rebase" "github.com/dolthub/dolt/go/libraries/doltcore/ref" "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/adapters" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dprocedures" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dtables" @@ -621,7 +622,7 @@ func (db Database) getTableInsensitiveWithRoot(ctx *sql.Context, head *doltdb.Co var dt sql.Table found := false tname := doltdb.TableName{Name: lwrName, Schema: db.schemaName} - switch lwrName { + switch adapters.DoltTableAdapterRegistry.NormalizeName(lwrName) { case doltdb.GetLogTableName(), doltdb.LogTableName: isDoltgresSystemTable, err := resolve.IsDoltgresSystemTable(ctx, tname, root) if err != nil { @@ -750,7 +751,7 @@ func (db Database) getTableInsensitiveWithRoot(ctx *sql.Context, head *doltdb.Co if !resolve.UseSearchPath || isDoltgresSystemTable { dt, found = dtables.NewCommitAncestorsTable(ctx, db.Name(), lwrName, db.ddb), true } - case doltdb.GetStatusTableName(), doltdb.StatusTableName: + case doltdb.StatusTableName: isDoltgresSystemTable, err := 
resolve.IsDoltgresSystemTable(ctx, tname, root) if err != nil { return nil, false, err diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_backup.go b/go/libraries/doltcore/sqle/dprocedures/dolt_backup.go index 00199214a1..4e0f4f8a9d 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_backup.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_backup.go @@ -79,16 +79,8 @@ func doltBackup(ctx *sql.Context, args ...string) (sql.RowIter, error) { return nil, err } - if sqlserver.RunningInServerMode() { - // TODO(elianddb): DoltgreSQL needs an auth handler for stored procedures, i.e. AuthType_CALL, but for now we use - // this. dolt_backup already requires admin privilege on GMS due to its potentially destructive nature. - privileges, counter := ctx.GetPrivilegeSet() - if counter == 0 || !privileges.Has(sql.PrivilegeType_Super) { - return nil, sql.ErrPrivilegeCheckFailed.New(ctx.Session.Client().User) - } - if apr.ContainsAny(cli.AwsParams...) { - return nil, fmt.Errorf("AWS parameters are unavailable when running in server mode") - } + if sqlserver.RunningInServerMode() && apr.ContainsAny(cli.AwsParams...) 
{ + return nil, fmt.Errorf("AWS parameters are unavailable when running in server mode") } doltSess := dsess.DSessFromSess(ctx.Session) diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_purge_dropped_databases.go b/go/libraries/doltcore/sqle/dprocedures/dolt_purge_dropped_databases.go index 63919e0941..e899868291 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_purge_dropped_databases.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_purge_dropped_databases.go @@ -27,11 +27,6 @@ func doltPurgeDroppedDatabases(ctx *sql.Context, args ...string) (sql.RowIter, e return nil, fmt.Errorf("dolt_purge_dropped_databases does not take any arguments") } - // Only allow admins to purge dropped databases - if err := checkDoltPurgeDroppedDatabasesPrivs(ctx); err != nil { - return nil, err - } - doltSession := dsess.DSessFromSess(ctx.Session) err := doltSession.Provider().PurgeDroppedDatabases(ctx) if err != nil { @@ -40,18 +35,3 @@ func doltPurgeDroppedDatabases(ctx *sql.Context, args ...string) (sql.RowIter, e return rowToIter(int64(cmdSuccess)), nil } - -// checkDoltPurgeDroppedDatabasesPrivs returns an error if the user requesting to purge dropped databases -// does not have SUPER access. Since this is a permanent and destructive operation, we restrict it to admins, -// even though the SUPER privilege has been deprecated, since there isn't another appropriate global privilege. 
-func checkDoltPurgeDroppedDatabasesPrivs(ctx *sql.Context) error { - privs, counter := ctx.GetPrivilegeSet() - if counter == 0 { - return fmt.Errorf("unable to check user privileges for dolt_purge_dropped_databases procedure") - } - if privs.Has(sql.PrivilegeType_Super) == false { - return sql.ErrPrivilegeCheckFailed.New(ctx.Session.Client().User) - } - - return nil -} diff --git a/go/libraries/doltcore/sqle/dtables/status_table.go b/go/libraries/doltcore/sqle/dtables/status_table.go index 2e7673e8c8..8213544981 100644 --- a/go/libraries/doltcore/sqle/dtables/status_table.go +++ b/go/libraries/doltcore/sqle/dtables/status_table.go @@ -25,6 +25,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/adapters" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/index" ) @@ -61,20 +62,12 @@ func (st StatusTable) String() string { return st.tableName } -func getDoltStatusSchema(tableName string) sql.Schema { - return []*sql.Column{ - {Name: "table_name", Type: types.Text, Source: tableName, PrimaryKey: true, Nullable: false}, - {Name: "staged", Type: types.Boolean, Source: tableName, PrimaryKey: true, Nullable: false}, - {Name: "status", Type: types.Text, Source: tableName, PrimaryKey: true, Nullable: false}, - } -} - -// GetDoltStatusSchema returns the schema of the dolt_status system table. This is used -// by Doltgres to update the dolt_status schema using Doltgres types. 
-var GetDoltStatusSchema = getDoltStatusSchema - func (st StatusTable) Schema() sql.Schema { - return GetDoltStatusSchema(st.tableName) + return []*sql.Column{ + {Name: "table_name", Type: types.Text, Source: doltdb.StatusTableName, PrimaryKey: true, Nullable: false}, + {Name: "staged", Type: types.Boolean, Source: doltdb.StatusTableName, PrimaryKey: true, Nullable: false}, + {Name: "status", Type: types.Text, Source: doltdb.StatusTableName, PrimaryKey: true, Nullable: false}, + } } func (st StatusTable) Collation() sql.CollationID { @@ -89,8 +82,19 @@ func (st StatusTable) PartitionRows(context *sql.Context, _ sql.Partition) (sql. return newStatusItr(context, &st) } -// NewStatusTable creates a StatusTable -func NewStatusTable(_ *sql.Context, tableName string, ddb *doltdb.DoltDB, ws *doltdb.WorkingSet, rp env.RootsProvider[*sql.Context]) sql.Table { +// NewStatusTable creates a new StatusTable using either an integrator's [adapters.TableAdapter] or the +// NewStatusTableWithNoAdapter constructor (the default implementation provided by Dolt). +func NewStatusTable(ctx *sql.Context, tableName string, ddb *doltdb.DoltDB, ws *doltdb.WorkingSet, rp env.RootsProvider[*sql.Context]) sql.Table { + adapter, ok := adapters.DoltTableAdapterRegistry.GetAdapter(tableName) + if ok { + return adapter.NewTable(ctx, tableName, ddb, ws, rp) + } + + return NewStatusTableWithNoAdapter(ctx, tableName, ddb, ws, rp) +} + +// NewStatusTableWithNoAdapter returns a new StatusTable.
+func NewStatusTableWithNoAdapter(_ *sql.Context, tableName string, ddb *doltdb.DoltDB, ws *doltdb.WorkingSet, rp env.RootsProvider[*sql.Context]) sql.Table { return &StatusTable{ tableName: tableName, ddb: ddb, diff --git a/go/libraries/doltcore/sqle/resolve/system_tables.go b/go/libraries/doltcore/sqle/resolve/system_tables.go index c37ecfc0af..3d8c777795 100755 --- a/go/libraries/doltcore/sqle/resolve/system_tables.go +++ b/go/libraries/doltcore/sqle/resolve/system_tables.go @@ -19,6 +19,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/adapters" ) // GetGeneratedSystemTables returns table names of all generated system tables. @@ -26,15 +27,19 @@ func GetGeneratedSystemTables(ctx context.Context, root doltdb.RootValue) ([]dol s := doltdb.NewTableNameSet(nil) // Depending on whether the search path is used, the generated system tables will either be in the dolt namespace - // or the empty (default) namespace - if !UseSearchPath { - for _, t := range doltdb.GeneratedSystemTableNames() { - s.Add(doltdb.TableName{Name: t}) + // or the empty (default) namespace. 
+ for _, tableName := range doltdb.GeneratedSystemTableNames() { + adapter, ok := adapters.DoltTableAdapterRegistry.Adapters[tableName] + if ok { + tableName = adapter.TableName() } - } else { - for _, t := range doltdb.GeneratedSystemTableNames() { - s.Add(doltdb.TableName{Name: t, Schema: doltdb.DoltNamespace}) + + tableUnique := doltdb.TableName{Name: tableName} + if UseSearchPath { + tableUnique.Schema = doltdb.DoltNamespace } + + s.Add(tableUnique) } schemas, err := root.GetDatabaseSchemas(ctx) diff --git a/integration-tests/bats/diff.bats b/integration-tests/bats/diff.bats index a36eccf6c1..e920fd918b 100644 --- a/integration-tests/bats/diff.bats +++ b/integration-tests/bats/diff.bats @@ -2245,3 +2245,223 @@ EOF [[ "$output" =~ "dolt_tests" ]] || false [[ "$output" =~ "updated description" ]] || false } + +@test "diff: --filter option filters by diff type" { + dolt sql -q "create table t(pk int primary key, val int)" + dolt add . + dolt commit -m "create table" + + # Test filter with table addition + run dolt diff HEAD~1 --filter=modified + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 --filter=removed + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 --filter=added + [ $status -eq 0 ] + [[ $output =~ 'diff --dolt a/t b/t' ]] || false + [[ $output =~ 'added table' ]] || false + + # Test filter with row inserts + dolt sql -q "INSERT INTO t VALUES (1, 10)" + dolt sql -q "INSERT INTO t VALUES (2, 10)" + dolt sql -q "INSERT INTO t VALUES (3, 10)" + dolt add . 
+ dolt commit -m "add initial rows" + + run dolt diff HEAD~1 --filter=modified + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 --filter=removed + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 -r sql --filter=added + [ $status -eq 0 ] + [ "${lines[0]}" = 'INSERT INTO `t` (`pk`,`val`) VALUES (1,10);' ] + [ "${lines[1]}" = 'INSERT INTO `t` (`pk`,`val`) VALUES (2,10);' ] + [ "${lines[2]}" = 'INSERT INTO `t` (`pk`,`val`) VALUES (3,10);' ] + + # Test filter with row updates + dolt sql -q "UPDATE t SET val=12 where pk=1" + dolt add . + dolt commit -m "update row" + + run dolt diff HEAD~1 -r sql --filter=added + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 -r sql --filter=removed + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 -r sql --filter=modified + [ "${lines[0]}" = 'UPDATE `t` SET `val`=12 WHERE `pk`=1;' ] + + # Test filter with row deletes + dolt sql -q "DELETE from t where pk=1" + + dolt add . && dolt commit -m "delete row" + + run dolt diff HEAD~1 -r sql --filter=added + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 --filter=modified + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 -r sql --filter=removed + [ $status -eq 0 ] + [ "${lines[0]}" = 'DELETE FROM `t` WHERE `pk`=1;' ] + + # Test filter with schema changes - add column + dolt sql -q "ALTER TABLE t ADD val2 int" + + dolt add . && dolt commit -m "add a col" + + run dolt diff HEAD~1 -r sql --filter=added + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 -r sql --filter=removed + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 -r sql --filter=modified + [ $status -eq 0 ] + [ "${lines[0]}" = 'ALTER TABLE `t` ADD `val2` int;' ] + + # Test filter with schema changes - modify column type + dolt sql -q "ALTER TABLE t MODIFY COLUMN val2 varchar(255)" + + dolt add . 
&& dolt commit -m "change datatype of column" + + run dolt diff HEAD~1 -r sql --filter=added + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 -r sql --filter=removed + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 -r sql --filter=modified + [ $status -eq 0 ] + [ "${lines[0]}" = 'ALTER TABLE `t` MODIFY COLUMN `val2` varchar(255);' ] + + # Test filter with schema changes - rename column + dolt sql -q "ALTER TABLE t RENAME COLUMN val2 TO val3" + + dolt add . && dolt commit -m "rename column" + + run dolt diff HEAD~1 -r sql --filter=added + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 -r sql --filter=removed + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 -r sql --filter=modified + [ $status -eq 0 ] + [ "${lines[0]}" = 'ALTER TABLE `t` RENAME COLUMN `val2` TO `val3`;' ] + + # Test filter with schema changes - drop column + dolt sql -q "ALTER TABLE t DROP COLUMN val3" + + dolt add . && dolt commit -m "drop column" + + run dolt diff HEAD~1 -r sql --filter=added + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 -r sql --filter=removed + [ $status -eq 0 ] + [[ $output = '' ]] || false + + run dolt diff HEAD~1 -r sql --filter=modified + [ $status -eq 0 ] + [ "${lines[0]}" = 'ALTER TABLE `t` DROP `val3`;' ] +} + +@test "diff: --filter with invalid value returns error" { + dolt sql -q "create table t(pk int primary key)" + dolt add . && dolt commit -m "create table" + + run dolt diff HEAD~1 --filter=invalid + [ $status -eq 1 ] + [[ $output =~ "invalid filter" ]] || false +} + +@test "diff: --filter=renamed filters to only renamed tables" { + dolt sql -q "create table t(pk int primary key, val int)" + dolt sql -q "INSERT INTO t VALUES (1, 10)" + dolt add . && dolt commit -m "create table with data" + + # Rename the table + dolt sql -q "RENAME TABLE t TO t_renamed" + dolt add . 
&& dolt commit -m "rename table" + + # filter=renamed should show the renamed table (shows different from/to names) + run dolt diff HEAD~1 --filter=renamed + [ $status -eq 0 ] + [[ $output =~ 'diff --dolt a/t b/t_renamed' ]] || false + [[ $output =~ '--- a/t' ]] || false + [[ $output =~ '+++ b/t_renamed' ]] || false + + # filter=added should not show the renamed table + run dolt diff HEAD~1 --filter=added + [ $status -eq 0 ] + [[ $output = '' ]] || false + + # filter=modified should not show the renamed table + run dolt diff HEAD~1 --filter=modified + [ $status -eq 0 ] + [[ $output = '' ]] || false + + # filter=dropped should not show the renamed table + run dolt diff HEAD~1 --filter=dropped + [ $status -eq 0 ] + [[ $output = '' ]] || false +} + +@test "diff: --filter=dropped filters to only dropped tables" { + dolt sql -q "create table t(pk int primary key, val int)" + dolt sql -q "INSERT INTO t VALUES (1, 10)" + dolt add . && dolt commit -m "create table with data" + + # Drop the table + dolt sql -q "DROP TABLE t" + dolt add . && dolt commit -m "drop table" + + # filter=dropped should show the dropped table + run dolt diff HEAD~1 --filter=dropped + [ $status -eq 0 ] + [[ $output =~ 'diff --dolt a/t b/t' ]] || false + [[ $output =~ 'deleted table' ]] || false + + # filter=removed (alias for dropped) should also show the dropped table + run dolt diff HEAD~1 --filter=removed + [ $status -eq 0 ] + [[ $output =~ 'diff --dolt a/t b/t' ]] || false + [[ $output =~ 'deleted table' ]] || false + + # filter=added should not show the dropped table + run dolt diff HEAD~1 --filter=added + [ $status -eq 0 ] + [[ $output = '' ]] || false + + # filter=modified should not show the dropped table + run dolt diff HEAD~1 --filter=modified + [ $status -eq 0 ] + [[ $output = '' ]] || false + + # filter=renamed should not show the dropped table + run dolt diff HEAD~1 --filter=renamed + [ $status -eq 0 ] + [[ $output = '' ]] || false +}