diff --git a/go/libraries/doltcore/doltdb/root_val.go b/go/libraries/doltcore/doltdb/root_val.go index ed3b595294..502f89d14e 100644 --- a/go/libraries/doltcore/doltdb/root_val.go +++ b/go/libraries/doltcore/doltdb/root_val.go @@ -654,16 +654,15 @@ func TablesWithDataConflicts(ctx context.Context, root RootValue) ([]string, err } // TablesWithConstraintViolations returns all tables that have constraint violations. -func TablesWithConstraintViolations(ctx context.Context, root RootValue) ([]string, error) { - // TODO: schema name - names, err := root.GetTableNames(ctx, DefaultSchemaName) +func TablesWithConstraintViolations(ctx context.Context, root RootValue) ([]TableName, error) { + names, err := UnionTableNames(ctx, root) if err != nil { return nil, err } - violating := make([]string, 0, len(names)) + violating := make([]TableName, 0, len(names)) for _, name := range names { - tbl, _, err := root.GetTable(ctx, TableName{Name: name}) + tbl, _, err := root.GetTable(ctx, name) if err != nil { return nil, err } diff --git a/go/libraries/doltcore/env/actions/commit.go b/go/libraries/doltcore/env/actions/commit.go index 0bebcf21f5..6063be3e46 100644 --- a/go/libraries/doltcore/env/actions/commit.go +++ b/go/libraries/doltcore/env/actions/commit.go @@ -89,7 +89,7 @@ func GetCommitStaged( return nil, err } if len(violatesConstraints) > 0 { - return nil, NewTblHasConstraintViolations(violatesConstraints) + return nil, NewTblHasConstraintViolations(doltdb.FlattenTableNames(violatesConstraints)) } if ws.MergeActive() { diff --git a/go/libraries/doltcore/merge/merge.go b/go/libraries/doltcore/merge/merge.go index 9452acbd5c..41f14c3c05 100644 --- a/go/libraries/doltcore/merge/merge.go +++ b/go/libraries/doltcore/merge/merge.go @@ -538,10 +538,12 @@ func GetMergeArtifactStatus(ctx context.Context, working *doltdb.WorkingSet) (as return as, err } - as.ConstraintViolationsTables, err = doltdb.TablesWithConstraintViolations(ctx, working.WorkingRoot()) + violations, err := doltdb.TablesWithConstraintViolations(ctx, working.WorkingRoot()) if err != nil { - return as, err + return ArtifactStatus{}, err } + + as.ConstraintViolationsTables = doltdb.FlattenTableNames(violations) return } diff --git a/go/libraries/doltcore/sqle/dsess/transactions.go b/go/libraries/doltcore/sqle/dsess/transactions.go index 90eab235e9..f9711a7668 100644 --- a/go/libraries/doltcore/sqle/dsess/transactions.go +++ b/go/libraries/doltcore/sqle/dsess/transactions.go @@ -648,7 +648,7 @@ func (tx *DoltTransaction) validateWorkingSetForCommit(ctx *sql.Context, working violations := make([]string, len(badTbls)) for i, name := range badTbls { - tbl, _, err := workingRoot.GetTable(ctx, doltdb.TableName{Name: name}) + tbl, _, err := workingRoot.GetTable(ctx, name) if err != nil { return err } diff --git a/go/libraries/doltcore/sqle/dtables/merge_status_table.go b/go/libraries/doltcore/sqle/dtables/merge_status_table.go index 58833538ad..fc6f7e2ced 100644 --- a/go/libraries/doltcore/sqle/dtables/merge_status_table.go +++ b/go/libraries/doltcore/sqle/dtables/merge_status_table.go @@ -104,7 +104,7 @@ func newMergeStatusItr(ctx context.Context, ws *doltdb.WorkingSet) (*MergeStatus } unmergedTblNames := set.NewStrSet(inConflict) - unmergedTblNames.Add(tblsWithViolations...) + unmergedTblNames.Add(doltdb.FlattenTableNames(tblsWithViolations)...) unmergedTblNames.Add(schConflicts...) var sourceCommitSpecStr *string diff --git a/go/libraries/doltcore/sqle/dtables/status_table.go b/go/libraries/doltcore/sqle/dtables/status_table.go index 9b006dc83d..fc885c535a 100644 --- a/go/libraries/doltcore/sqle/dtables/status_table.go +++ b/go/libraries/doltcore/sqle/dtables/status_table.go @@ -100,9 +100,9 @@ type statusTableRow struct { status string } -func contains(str string, strs []string) bool { - for _, s := range strs { - if s == str { +func containsTableName(name string, names []doltdb.TableName) bool { + for _, s := range names { + if s.Name == name { return true } } @@ -136,7 +136,7 @@ func newStatusItr(ctx *sql.Context, st *StatusTable) (*StatusItr, error) { for _, tbl := range cvTables { rows = append(rows, statusTableRow{ - tableName: tbl, + tableName: tbl.Name, status: "constraint violation", }) } @@ -176,7 +176,7 @@ func newStatusItr(ctx *sql.Context, st *StatusTable) (*StatusItr, error) { if doltdb.IsFullTextTable(tblName) { continue } - if contains(tblName, cvTables) { + if containsTableName(tblName, cvTables) { continue } rows = append(rows, statusTableRow{ @@ -190,7 +190,7 @@ func newStatusItr(ctx *sql.Context, st *StatusTable) (*StatusItr, error) { if doltdb.IsFullTextTable(tblName) { continue } - if contains(tblName, cvTables) { + if containsTableName(tblName, cvTables) { continue } rows = append(rows, statusTableRow{ diff --git a/go/libraries/doltcore/sqle/dtables/table_of_tables_with_violations.go b/go/libraries/doltcore/sqle/dtables/table_of_tables_with_violations.go index 21ee8e2bea..5bb1dae246 100644 --- a/go/libraries/doltcore/sqle/dtables/table_of_tables_with_violations.go +++ b/go/libraries/doltcore/sqle/dtables/table_of_tables_with_violations.go @@ -15,6 +15,7 @@ package dtables import ( + "bytes" "fmt" "io" @@ -78,9 +79,9 @@ func (totwv *TableOfTablesWithViolations) Partitions(ctx *sql.Context) (sql.Part // PartitionRows implements the interface sql.Table. func (totwv *TableOfTablesWithViolations) PartitionRows(ctx *sql.Context, part sql.Partition) (sql.RowIter, error) { - tblName := string(part.Key()) + tblName := decodeTableName(part.Key()) var rows []sql.Row - tbl, _, ok, err := doltdb.GetTableInsensitive(ctx, totwv.root, doltdb.TableName{Name: tblName}) + tbl, ok, err := totwv.root.GetTable(ctx, tblName) if err != nil { return nil, err } @@ -91,7 +92,7 @@ func (totwv *TableOfTablesWithViolations) PartitionRows(ctx *sql.Context, part s if err != nil { return nil, err } - rows = append(rows, sql.Row{tblName, n}) + rows = append(rows, sql.Row{tblName.Name, n}) return sql.RowsToRowIter(rows...), nil } @@ -119,11 +120,24 @@ func (t *tableOfTablesPartitionIter) Close(context *sql.Context) error { } // tableOfTablesPartition is a partition returned from tableOfTablesPartitionIter, which is just a table name. -type tableOfTablesPartition string +type tableOfTablesPartition doltdb.TableName -var _ sql.Partition = tableOfTablesPartition("") +var _ sql.Partition = tableOfTablesPartition(doltdb.TableName{}) // Key implements the interface sql.Partition. func (t tableOfTablesPartition) Key() []byte { - return []byte(t) + return encodeTableName(doltdb.TableName(t)) +} + +func encodeTableName(name doltdb.TableName) []byte { + b := bytes.Buffer{} + b.WriteString(name.Schema) + b.WriteByte(0) + b.WriteString(name.Name) + return b.Bytes() +} + +func decodeTableName(b []byte) doltdb.TableName { + parts := bytes.SplitN(b, []byte{0}, 2) + return doltdb.TableName{Schema: string(parts[0]), Name: string(parts[1])} }