Merge pull request #8343 from dolthub/zachmu/schema-commit

support for schemas in various version control operations
This commit is contained in:
Zach Musgrave
2024-09-11 15:47:57 -07:00
committed by GitHub
39 changed files with 761 additions and 230 deletions
@@ -30,7 +30,6 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/merge"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/libraries/utils/set"
"github.com/dolthub/dolt/go/store/types"
)
@@ -85,7 +84,12 @@ func (cmd VerifyConstraintsCmd) Exec(ctx context.Context, commandStr string, arg
return commands.HandleVErrAndExitCode(errhand.BuildDError("Unable to read table names.").AddCause(err).Build(), nil)
}
}
tableSet := set.NewStrSet(tableNames)
tableSet := doltdb.NewTableNameSet(nil)
// TODO: schema names
for _, tableName := range tableNames {
tableSet.Add(doltdb.TableName{Name: tableName})
}
comparingRoot, err := dEnv.HeadRoot(ctx)
if err != nil {
+106
View File
@@ -0,0 +1,106 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package diff
import (
"context"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/utils/set"
)
// DatabaseSchemaDelta represents a change in the set of database schemas between two roots
type DatabaseSchemaDelta struct {
FromName string
ToName string
}
func (d DatabaseSchemaDelta) IsAdd() bool {
return d.FromName == "" && d.ToName != ""
}
func (d DatabaseSchemaDelta) IsDrop() bool {
return d.FromName != "" && d.ToName == ""
}
func (d DatabaseSchemaDelta) CurName() string {
if d.ToName != "" {
return d.ToName
}
return d.FromName
}
// GetDatabaseSchemaDeltas returns a list of DatabaseSchemaDelta objects representing the changes in database schemas
func GetDatabaseSchemaDeltas(ctx context.Context, fromRoot, toRoot doltdb.RootValue) ([]DatabaseSchemaDelta, error) {
fromNames, err := getDatabaseSchemaNames(ctx, fromRoot)
if err != nil {
return nil, err
}
toNames, err := getDatabaseSchemaNames(ctx, toRoot)
if err != nil {
return nil, err
}
// short circuit for common case where there are no schemas (dolt)
if fromNames.Size() == 0 && toNames.Size() == 0 {
return nil, nil
}
// generate a diff for each schema name that's present in one root but not the other
var deltas []DatabaseSchemaDelta
fromNames.Iterate(func(name string) (cont bool) {
if !toNames.Contains(name) {
deltas = append(deltas, DatabaseSchemaDelta{FromName: name})
}
return true
})
toNames.Iterate(func(name string) (cont bool) {
if !fromNames.Contains(name) {
deltas = append(deltas, DatabaseSchemaDelta{ToName: name})
}
return true
})
return deltas, nil
}
// GetStagedUnstagedDatabaseSchemaDeltas represents staged and unstaged changes as DatabaseSchemaDelta slices.
func GetStagedUnstagedDatabaseSchemaDeltas(ctx context.Context, roots doltdb.Roots) (staged, unstaged []DatabaseSchemaDelta, err error) {
staged, err = GetDatabaseSchemaDeltas(ctx, roots.Head, roots.Staged)
if err != nil {
return nil, nil, err
}
unstaged, err = GetDatabaseSchemaDeltas(ctx, roots.Staged, roots.Working)
if err != nil {
return nil, nil, err
}
return staged, unstaged, nil
}
func getDatabaseSchemaNames(ctx context.Context, root doltdb.RootValue) (*set.StrSet, error) {
dbSchemaNames := set.NewEmptyStrSet()
dbSchemas, err := root.GetDatabaseSchemas(ctx)
if err != nil {
return nil, err
}
for _, dbSchema := range dbSchemas {
dbSchemaNames.Add(dbSchema.Name)
}
return dbSchemaNames, nil
}
@@ -490,7 +490,7 @@ OuterLoop:
// and any keys in the collection are unresolved. A "dirty resolution" is performed, which matches the column names to
// tags, and then a standard tag comparison is performed. If a table or column is not in the map, then the foreign key
// is ignored.
func (fkc *ForeignKeyCollection) GetMatchingKey(fk ForeignKey, allSchemas map[string]schema.Schema, matchUnresolvedKeyToResolvedKey bool) (ForeignKey, bool) {
func (fkc *ForeignKeyCollection) GetMatchingKey(fk ForeignKey, allSchemas map[TableName]schema.Schema, matchUnresolvedKeyToResolvedKey bool) (ForeignKey, bool) {
if !fk.IsResolved() {
// The given foreign key is unresolved, so we only look for matches on unresolved keys
OuterLoopUnresolved:
@@ -543,11 +543,13 @@ OuterLoopResolved:
len(fk.ReferencedTableColumns) != len(existingFk.UnresolvedFKDetails.ReferencedTableColumns) {
continue
}
tblSch, ok := allSchemas[existingFk.TableName]
// TODO: schema name
tblSch, ok := allSchemas[TableName{Name: existingFk.TableName}]
if !ok {
continue
}
refTblSch, ok := allSchemas[existingFk.ReferencedTableName]
// TODO: schema name
refTblSch, ok := allSchemas[TableName{Name: existingFk.ReferencedTableName}]
if !ok {
continue
}
+3 -4
View File
@@ -455,11 +455,10 @@ func GetExistingColumns(
return existingCols, nil
}
func GetAllSchemas(ctx context.Context, root RootValue) (map[string]schema.Schema, error) {
m := make(map[string]schema.Schema)
func GetAllSchemas(ctx context.Context, root RootValue) (map[TableName]schema.Schema, error) {
m := make(map[TableName]schema.Schema)
err := root.IterTables(ctx, func(name TableName, table *Table, sch schema.Schema) (stop bool, err error) {
// TODO: schema name
m[name.Name] = sch
m[name] = sch
return false, nil
})
+4 -4
View File
@@ -256,7 +256,7 @@ func (t *Table) clearConflicts(ctx context.Context) (*Table, error) {
}
// GetConflictSchemas returns the merge conflict schemas for this table.
func (t *Table) GetConflictSchemas(ctx context.Context, tblName string) (base, sch, mergeSch schema.Schema, err error) {
func (t *Table) GetConflictSchemas(ctx context.Context, tblName TableName) (base, sch, mergeSch schema.Schema, err error) {
if t.Format() == types.Format_DOLT {
return t.getProllyConflictSchemas(ctx, tblName)
}
@@ -267,7 +267,7 @@ func (t *Table) GetConflictSchemas(ctx context.Context, tblName string) (base, s
// The conflict schema is implicitly determined based on the first conflict in the artifacts table.
// For now, we will enforce that all conflicts in the artifacts table must have the same schema set (base, ours, theirs).
// In the future, we may be able to display conflicts in a way that allows different conflict schemas to coexist.
func (t *Table) getProllyConflictSchemas(ctx context.Context, tblName string) (base, sch, mergeSch schema.Schema, err error) {
func (t *Table) getProllyConflictSchemas(ctx context.Context, tblName TableName) (base, sch, mergeSch schema.Schema, err error) {
arts, err := t.GetArtifacts(ctx)
if err != nil {
return nil, nil, nil, err
@@ -331,12 +331,12 @@ func (t *Table) getProllyConflictSchemas(ctx context.Context, tblName string) (b
return baseSch, ourSch, theirSch, nil
}
func tableFromRootIsh(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, h hash.Hash, tblName string) (*Table, bool, error) {
func tableFromRootIsh(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, h hash.Hash, tblName TableName) (*Table, bool, error) {
rv, err := LoadRootValueFromRootIshAddr(ctx, vrw, ns, h)
if err != nil {
return nil, false, err
}
tbl, ok, err := rv.GetTable(ctx, TableName{Name: tblName})
tbl, ok, err := rv.GetTable(ctx, tblName)
if err != nil {
return nil, false, err
}
+2 -2
View File
@@ -137,8 +137,8 @@ type MergeState struct {
type SchemaConflict struct {
ToSch, FromSch schema.Schema
ToFks, FromFks []ForeignKey
ToParentSchemas map[string]schema.Schema
FromParentSchemas map[string]schema.Schema
ToParentSchemas map[TableName]schema.Schema
FromParentSchemas map[TableName]schema.Schema
toTbl, fromTbl *Table
}
+8 -3
View File
@@ -47,13 +47,13 @@ func GetCommitStaged(
return nil, datas.ErrEmptyCommitMessage
}
staged, notStaged, err := diff.GetStagedUnstagedTableDeltas(ctx, roots)
stagedTables, notStaged, err := diff.GetStagedUnstagedTableDeltas(ctx, roots)
if err != nil {
return nil, err
}
var stagedTblNames []doltdb.TableName
for _, td := range staged {
for _, td := range stagedTables {
n := td.ToName
if td.IsDrop() {
n = td.FromName
@@ -61,7 +61,12 @@ func GetCommitStaged(
stagedTblNames = append(stagedTblNames, n)
}
isEmpty := len(staged) == 0
stagedSchemas, _, err := diff.GetStagedUnstagedDatabaseSchemaDeltas(ctx, roots)
if err != nil {
return nil, err
}
isEmpty := len(stagedTables) == 0 && len(stagedSchemas) == 0
allowEmpty := ws.MergeActive() || props.AllowEmpty || props.Amend
if isEmpty && props.SkipEmpty {
+5 -6
View File
@@ -25,7 +25,6 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/store/datas"
)
@@ -81,7 +80,7 @@ func resetHardTables(ctx *sql.Context, dbData env.DbData, cSpecStr string, roots
return nil, doltdb.Roots{}, err
}
for _, name := range staged {
delete(untracked, name)
delete(untracked, doltdb.TableName{Name: name})
}
newWkRoot := roots.Head
@@ -102,7 +101,7 @@ func resetHardTables(ctx *sql.Context, dbData env.DbData, cSpecStr string, roots
}
for name := range untracked {
tname, tbl, exists, err := resolve.Table(ctx, roots.Working, name)
tbl, exists, err := roots.Working.GetTable(ctx, name)
if err != nil {
return nil, doltdb.Roots{}, err
}
@@ -110,7 +109,7 @@ func resetHardTables(ctx *sql.Context, dbData env.DbData, cSpecStr string, roots
return nil, doltdb.Roots{}, fmt.Errorf("untracked table %s does not exist in working set", name)
}
newWkRoot, err = newWkRoot.PutTable(ctx, tname, tbl)
newWkRoot, err = newWkRoot.PutTable(ctx, name, tbl)
if err != nil {
return nil, doltdb.Roots{}, fmt.Errorf("failed to write table back to database: %s", err)
}
@@ -334,11 +333,11 @@ func CleanUntracked(ctx context.Context, roots doltdb.Roots, tables []string, dr
// mapColumnTags takes a map from table name to schema.Schema and generates
// a map from column tags to table names (see RootValue.GetAllSchemas).
func mapColumnTags(tables map[string]schema.Schema) (m map[uint64]string) {
func mapColumnTags(tables map[doltdb.TableName]schema.Schema) (m map[uint64]string) {
m = make(map[uint64]string, len(tables))
for tbl, sch := range tables {
for _, tag := range sch.GetAllCols().Tags {
m[tag] = tbl
m[tag] = tbl.Name
}
}
return
+53 -1
View File
@@ -21,6 +21,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
)
func StageTables(ctx context.Context, roots doltdb.Roots, tbls []doltdb.TableName, filterIgnoredTables bool) (doltdb.Roots, error) {
@@ -45,7 +46,58 @@ func StageAllTables(ctx context.Context, roots doltdb.Roots, filterIgnoredTables
return doltdb.Roots{}, err
}
return StageTables(ctx, roots, tbls, filterIgnoredTables)
roots, err = StageTables(ctx, roots, tbls, filterIgnoredTables)
if err != nil {
return doltdb.Roots{}, err
}
roots, err = StageAllSchemas(ctx, roots)
if err != nil {
return doltdb.Roots{}, err
}
return roots, nil
}
func StageAllSchemas(ctx context.Context, roots doltdb.Roots) (doltdb.Roots, error) {
newStaged, err := MoveAllSchemasBetweenRoots(ctx, roots.Working, roots.Staged)
if err != nil {
return doltdb.Roots{}, err
}
roots.Staged = newStaged
return roots, nil
}
// MoveAllSchemasBetweenRoots copies all schemas from the src RootValue to the dest RootValue.
func MoveAllSchemasBetweenRoots(ctx context.Context, src, dest doltdb.RootValue) (doltdb.RootValue, error) {
srcSchemaNames, err := getDatabaseSchemaNames(ctx, src)
if err != nil {
return nil, err
}
if srcSchemaNames.Size() == 0 {
return dest, nil
}
destSchemaNames, err := getDatabaseSchemaNames(ctx, dest)
if err != nil {
return nil, err
}
srcSchemaNames.Iterate(func(schemaName string) (cont bool) {
if !destSchemaNames.Contains(schemaName) {
dest, err = dest.CreateDatabaseSchema(ctx, schema.DatabaseSchema{
Name: schemaName,
})
if err != nil {
return false
}
}
return true
})
return dest, nil
}
func StageDatabase(ctx context.Context, roots doltdb.Roots) (doltdb.Roots, error) {
+69 -35
View File
@@ -17,16 +17,16 @@ package actions
import (
"context"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/utils/set"
)
// MoveTablesBetweenRoots copies tables with names in tbls from the src RootValue to the dest RootValue.
// It matches tables between roots by column tags.
func MoveTablesBetweenRoots(ctx context.Context, tbls []doltdb.TableName, src, dest doltdb.RootValue) (doltdb.RootValue, error) {
tblSet := doltdb.NewTableNameSet(tbls)
tablesToMove := doltdb.NewTableNameSet(tbls)
stagedFKs, err := dest.GetForeignKeyCollection(ctx)
if err != nil {
@@ -38,51 +38,28 @@ func MoveTablesBetweenRoots(ctx context.Context, tbls []doltdb.TableName, src, d
return nil, err
}
// We want to include all Full-Text tables for every move
for _, td := range tblDeltas {
var ftIndexes []schema.Index
if tblSet.Contains(td.ToName) && td.ToSch.Indexes().ContainsFullTextIndex() {
for _, idx := range td.ToSch.Indexes().AllIndexes() {
if !idx.IsFullText() {
continue
}
ftIndexes = append(ftIndexes, idx)
}
} else if tblSet.Contains(td.FromName) && td.FromSch.Indexes().ContainsFullTextIndex() {
for _, idx := range td.FromSch.Indexes().AllIndexes() {
if !idx.IsFullText() {
continue
}
ftIndexes = append(ftIndexes, idx)
}
}
for _, ftIndex := range ftIndexes {
props := ftIndex.FullTextProperties()
tblSet.Add(
doltdb.TableName{Name: props.ConfigTable},
doltdb.TableName{Name: props.PositionTable},
doltdb.TableName{Name: props.DocCountTable},
doltdb.TableName{Name: props.GlobalCountTable},
doltdb.TableName{Name: props.RowCountTable},
)
}
addFullTextTablesToDelta(tblDeltas, tablesToMove)
destSchemaNames, err := getDatabaseSchemaNames(ctx, dest)
if err != nil {
return nil, err
}
tblsToDrop := doltdb.NewTableNameSet(nil)
for _, td := range tblDeltas {
if td.IsDrop() {
if !tblSet.Contains(td.FromName) {
if !tablesToMove.Contains(td.FromName) {
continue
}
tblsToDrop.Add(td.FromName)
stagedFKs.RemoveKeys(td.FromFks...)
tblsToDrop.Add(td.FromName)
}
}
for _, td := range tblDeltas {
if !td.IsDrop() {
if !tblSet.Contains(td.ToName) {
if !tablesToMove.Contains(td.ToName) {
continue
}
@@ -95,6 +72,18 @@ func MoveTablesBetweenRoots(ctx context.Context, tbls []doltdb.TableName, src, d
}
}
// edge case: if we're moving a table with a schema name to a root that doesn't have that schema,
// we implicitly create that schema on the destination root in addition to updating the list of schemas
if td.ToName.Schema != "" && !destSchemaNames.Contains(td.ToName.Schema) {
dest, err = dest.CreateDatabaseSchema(ctx, schema.DatabaseSchema{
Name: td.ToName.Schema,
})
destSchemaNames.Add(td.ToName.Schema)
if err != nil {
return nil, err
}
}
dest, err = dest.PutTable(ctx, td.ToName, td.ToTable)
if err != nil {
return nil, err
@@ -122,6 +111,51 @@ func MoveTablesBetweenRoots(ctx context.Context, tbls []doltdb.TableName, src, d
return dest, nil
}
func getDatabaseSchemaNames(ctx context.Context, dest doltdb.RootValue) (*set.StrSet, error) {
dbSchemaNames := set.NewEmptyStrSet()
dbSchemas, err := dest.GetDatabaseSchemas(ctx)
if err != nil {
return nil, err
}
for _, dbSchema := range dbSchemas {
dbSchemaNames.Add(dbSchema.Name)
}
return dbSchemaNames, nil
}
// addFullTextTablesToDelta adds the full text tables associated any full text indexes in the table deltas to the tableset provided
func addFullTextTablesToDelta(tblDeltas []diff.TableDelta, tblSet *doltdb.TableNameSet) {
for _, td := range tblDeltas {
var ftIndexes []schema.Index
if tblSet.Contains(td.ToName) && td.ToSch.Indexes().ContainsFullTextIndex() {
for _, idx := range td.ToSch.Indexes().AllIndexes() {
if !idx.IsFullText() {
continue
}
ftIndexes = append(ftIndexes, idx)
}
} else if tblSet.Contains(td.FromName) && td.FromSch.Indexes().ContainsFullTextIndex() {
for _, idx := range td.FromSch.Indexes().AllIndexes() {
if !idx.IsFullText() {
continue
}
ftIndexes = append(ftIndexes, idx)
}
}
for _, ftIndex := range ftIndexes {
props := ftIndex.FullTextProperties()
// TODO: schema names
tblSet.Add(
doltdb.TableName{Name: props.ConfigTable},
doltdb.TableName{Name: props.PositionTable},
doltdb.TableName{Name: props.DocCountTable},
doltdb.TableName{Name: props.GlobalCountTable},
doltdb.TableName{Name: props.RowCountTable},
)
}
}
}
func validateTablesExist(ctx context.Context, currRoot doltdb.RootValue, unknown []doltdb.TableName) error {
var notExist []doltdb.TableName
for _, tbl := range unknown {
@@ -52,7 +52,7 @@ type ConflictReader struct {
}
// NewConflictReader returns a new conflict reader for a given table
func NewConflictReader(ctx context.Context, tbl *doltdb.Table, tblName string) (*ConflictReader, error) {
func NewConflictReader(ctx context.Context, tbl *doltdb.Table, tblName doltdb.TableName) (*ConflictReader, error) {
base, sch, mergeSch, err := tbl.GetConflictSchemas(ctx, tblName) // tblName unused by old storage format
if err != nil {
return nil, err
+28 -10
View File
@@ -23,8 +23,8 @@ import (
goerrors "gopkg.in/src-d/go-errors.v1"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/set"
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/types"
)
@@ -139,8 +139,8 @@ func (r Result) CountOfTablesWithConstraintViolations() int {
return count
}
func SchemaConflictTableNames(sc []SchemaConflict) (tables []string) {
tables = make([]string, len(sc))
func SchemaConflictTableNames(sc []SchemaConflict) (tables []doltdb.TableName) {
tables = make([]doltdb.TableName, len(sc))
for i := range sc {
tables[i] = sc[i].TableName
}
@@ -241,12 +241,17 @@ func MergeRoots(
return nil, err
}
destSchemaNames, err := getDatabaseSchemaNames(ctx, ourRoot)
if err != nil {
return nil, err
}
// visitedTables holds all tables that were added, removed, or modified (basically not "unmodified")
visitedTables := make(map[string]struct{})
var schConflicts []SchemaConflict
for _, tblName := range tblNames {
// TODO: schema name
mergedTable, stats, err := merger.MergeTable(ctx, tblName.Name, opts, mergeOpts)
mergedTable, stats, err := merger.MergeTable(ctx, tblName, opts, mergeOpts)
if errors.Is(ErrTableDeletedAndModified, err) && doltdb.IsFullTextTable(tblName.Name) {
// If a Full-Text table was both modified and deleted, then we want to ignore the deletion.
// If there's a true conflict, then the parent table will catch the conflict.
@@ -257,7 +262,7 @@ func MergeRoots(
SchemaConflicts: 1,
}
conflict := SchemaConflict{
TableName: tblName.Name,
TableName: tblName,
ModifyDeleteConflict: true,
}
if !mergeOpts.KeepSchemaConflicts {
@@ -289,6 +294,18 @@ func MergeRoots(
if mergedTable.table != nil {
tblToStats[tblName.Name] = stats
// edge case: if we're merging a table with a schema name to a root that doesn't have that schema,
// we implicitly create that schema on the destination root in addition to updating the list of schemas
if tblName.Schema != "" && !destSchemaNames.Contains(tblName.Schema) {
mergedRoot, err = mergedRoot.CreateDatabaseSchema(ctx, schema.DatabaseSchema{
Name: tblName.Schema,
})
if err != nil {
return nil, err
}
destSchemaNames.Add(tblName.Schema)
}
mergedRoot, err = mergedRoot.PutTable(ctx, tblName, mergedTable.table)
if err != nil {
return nil, err
@@ -296,15 +313,16 @@ func MergeRoots(
continue
}
newRootHasTable, err := mergedRoot.HasTable(ctx, tblName)
mergedRootHasTable, err := mergedRoot.HasTable(ctx, tblName)
if err != nil {
return nil, err
}
if newRootHasTable {
if mergedRootHasTable {
// Merge root deleted this table
tblToStats[tblName.Name] = &MergeStats{Operation: TableRemoved}
// TODO: drop schemas as necessary
mergedRoot, err = mergedRoot.RemoveTables(ctx, false, false, tblName)
if err != nil {
return nil, err
@@ -346,9 +364,9 @@ func MergeRoots(
return nil, err
}
var tableSet *set.StrSet = nil
var tableSet *doltdb.TableNameSet = nil
if mergeOpts.RecordViolationsForTables != nil {
tableSet = set.NewCaseInsensitiveStrSet(nil)
tableSet = doltdb.NewCaseInsensitiveTableNameSet(nil)
for tableName, _ := range mergeOpts.RecordViolationsForTables {
tableSet.Add(tableName)
}
@@ -72,7 +72,7 @@ func mergeNomsTable(ctx *sql.Context, tm *TableMerger, mergedSch schema.Schema,
}
}
updatedTblEditor, err := editor.NewTableEditor(ctx, mergeTbl, mergedSch, tm.name, opts)
updatedTblEditor, err := editor.NewTableEditor(ctx, mergeTbl, mergedSch, tm.name.Name, opts)
if err != nil {
return nil, nil, err
}
@@ -92,7 +92,7 @@ func mergeNomsTable(ctx *sql.Context, tm *TableMerger, mergedSch schema.Schema,
return nil, nil, err
}
resultTbl, cons, stats, err := mergeNomsTableData(ctx, vrw, tm.name, mergedSch, rows, mergeRows, durable.NomsMapFromIndex(ancRows), updatedTblEditor)
resultTbl, cons, stats, err := mergeNomsTableData(ctx, vrw, tm.name.Name, mergedSch, rows, mergeRows, durable.NomsMapFromIndex(ancRows), updatedTblEditor)
if err != nil {
return nil, nil, err
}
@@ -88,7 +88,7 @@ func mergeProllySecondaryIndexes(
mergedIndex, err := func() (durable.Index, error) {
if forceIndexRebuild || rebuildRequired {
return buildIndex(ctx, tm.vrw, tm.ns, finalSch, index, mergedM, artifacts, tm.rightSrc, tm.name)
return buildIndex(ctx, tm.vrw, tm.ns, finalSch, index, mergedM, artifacts, tm.rightSrc, tm.name.Name)
}
return durable.IndexFromProllyMap(left), nil
}()
@@ -134,11 +134,11 @@ func buildIndex(
mergedMap, err := creation.BuildUniqueProllyIndex(ctx, vrw, ns, postMergeSchema, tblName, index, m, func(ctx context.Context, existingKey, newKey val.Tuple) (err error) {
eK := getPKFromSecondaryKey(kb, p, pkMapping, existingKey)
nK := getPKFromSecondaryKey(kb, p, pkMapping, newKey)
err = replaceUniqueKeyViolation(ctx, artEditor, m, eK, kd, theirRootIsh, vInfo, tblName)
err = replaceUniqueKeyViolation(ctx, artEditor, m, eK, theirRootIsh, vInfo)
if err != nil {
return err
}
err = replaceUniqueKeyViolation(ctx, artEditor, m, nK, kd, theirRootIsh, vInfo, tblName)
err = replaceUniqueKeyViolation(ctx, artEditor, m, nK, theirRootIsh, vInfo)
if err != nil {
return err
}
@@ -52,7 +52,13 @@ var ErrUnableToMergeColumnDefaultValue = errorkinds.NewKind("unable to automatic
// table's primary index will also be rewritten. This function merges the table's artifacts (e.g. recorded
// conflicts), migrates any existing table data to the specified |mergedSch|, and merges table data from both
// sides of the merge together.
func mergeProllyTable(ctx context.Context, tm *TableMerger, mergedSch schema.Schema, mergeInfo MergeInfo, diffInfo tree.ThreeWayDiffInfo) (*doltdb.Table, *MergeStats, error) {
func mergeProllyTable(
ctx context.Context,
tm *TableMerger,
mergedSch schema.Schema,
mergeInfo MergeInfo,
diffInfo tree.ThreeWayDiffInfo,
) (*doltdb.Table, *MergeStats, error) {
mergeTbl, err := mergeTableArtifacts(ctx, tm, tm.leftTbl)
if err != nil {
return nil, nil, err
@@ -132,7 +138,7 @@ func mergeProllyTableData(ctx *sql.Context, tm *TableMerger, finalSch schema.Sch
keyless := schema.IsKeyless(tm.leftSch)
defaults, err := resolveDefaults(ctx, tm.name, finalSch, tm.leftSch)
defaults, err := resolveDefaults(ctx, tm.name.Name, finalSch, tm.leftSch)
if err != nil {
return nil, nil, err
}
@@ -374,7 +380,7 @@ func newCheckValidator(ctx *sql.Context, tm *TableMerger, vm *valueMerger, sch s
continue
}
expr, err := expranalysis.ResolveCheckExpression(ctx, tm.name, sch, check.Expression())
expr, err := expranalysis.ResolveCheckExpression(ctx, tm.name.Name, sch, check.Expression())
if err != nil {
return checkValidator{}, err
}
@@ -546,7 +552,7 @@ func newUniqValidator(ctx *sql.Context, sch schema.Schema, tm *TableMerger, vm *
}
secondary := durable.ProllyMapFromIndex(idx)
u, err := newUniqIndex(ctx, sch, tm.name, def, clustered, secondary)
u, err := newUniqIndex(ctx, sch, tm.name.Name, def, clustered, secondary)
if err != nil {
return uniqValidator{}, err
}
@@ -853,7 +859,7 @@ func newNullValidator(
return nullValidator{}, err
}
return nullValidator{
table: tm.name,
table: tm.name.Name,
final: final,
leftMap: vm.leftMapping,
rightMap: vm.rightMapping,
@@ -1089,7 +1095,7 @@ func (m *primaryMerger) merge(ctx *sql.Context, diff tree.ThreeWayDiff, sourceSc
} else {
// Remapping when there's no schema change is harmless, but slow.
if m.mergeInfo.RightNeedsRewrite {
defaults, err := resolveDefaults(ctx, m.tableMerger.name, m.finalSch, m.tableMerger.rightSch)
defaults, err := resolveDefaults(ctx, m.tableMerger.name.Name, m.finalSch, m.tableMerger.rightSch)
if err != nil {
return err
}
@@ -1127,7 +1133,7 @@ func (m *primaryMerger) merge(ctx *sql.Context, diff tree.ThreeWayDiff, sourceSc
// the merge
merged := diff.Merged
if hasStoredGeneratedColumns(m.finalSch) {
defaults, err := resolveDefaults(ctx, m.tableMerger.name, m.finalSch, m.tableMerger.rightSch)
defaults, err := resolveDefaults(ctx, m.tableMerger.name.Name, m.finalSch, m.tableMerger.rightSch)
if err != nil {
return err
}
@@ -1281,7 +1287,7 @@ func newSecondaryMerger(ctx *sql.Context, tm *TableMerger, valueMerger *valueMer
}
// Use the mergedSchema to work with the secondary indexes, to pull out row data using the right
// pri_index -> sec_index mapping.
lm, err := GetMutableSecondaryIdxsWithPending(ctx, leftSchema, mergedSchema, tm.name, ls, secondaryMergerPendingSize)
lm, err := GetMutableSecondaryIdxsWithPending(ctx, leftSchema, mergedSchema, tm.name.Name, ls, secondaryMergerPendingSize)
if err != nil {
return nil, err
}
@@ -1326,7 +1332,7 @@ func (m *secondaryMerger) merge(ctx *sql.Context, diff tree.ThreeWayDiff, leftSc
return fmt.Errorf("cannot merge keyless tables with reordered columns")
}
} else {
defaults, err := resolveDefaults(ctx, m.tableMerger.name, m.mergedSchema, m.tableMerger.rightSch)
defaults, err := resolveDefaults(ctx, m.tableMerger.name.Name, m.mergedSchema, m.tableMerger.rightSch)
if err != nil {
return err
}
@@ -1350,7 +1356,7 @@ func (m *secondaryMerger) merge(ctx *sql.Context, diff tree.ThreeWayDiff, leftSc
}
newTupleValue = tempTupleValue
if diff.Base != nil {
defaults, err := resolveDefaults(ctx, m.tableMerger.name, m.mergedSchema, m.tableMerger.ancSch)
defaults, err := resolveDefaults(ctx, m.tableMerger.name.Name, m.mergedSchema, m.tableMerger.ancSch)
if err != nil {
return err
}
@@ -2032,7 +2038,7 @@ func mergeJSON(ctx context.Context, ns tree.NodeStore, base, left, right sql.JSO
return types.JSONDocument{}, true, err
}
if cmp == 0 {
//convergent operation.
// convergent operation.
return left, false, nil
} else {
return types.JSONDocument{}, true, nil
+26 -9
View File
@@ -16,7 +16,6 @@ package merge
import (
"context"
"strings"
"github.com/dolthub/go-mysql-server/sql"
@@ -26,6 +25,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/set"
"github.com/dolthub/dolt/go/store/atomicerr"
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/prolly/tree"
@@ -51,11 +51,11 @@ type MergeOpts struct {
// will have constraint violations recorded. This functionality is primarily used by the
// dolt_verify_constraints() stored procedure to allow callers to verify constraints for a
// subset of tables.
RecordViolationsForTables map[string]struct{}
RecordViolationsForTables map[doltdb.TableName]struct{}
}
type TableMerger struct {
name string
name doltdb.TableName
leftTbl *doltdb.Table
rightTbl *doltdb.Table
@@ -132,9 +132,26 @@ type MergedTable struct {
conflict SchemaConflict
}
func getDatabaseSchemaNames(ctx context.Context, dest doltdb.RootValue) (*set.StrSet, error) {
dbSchemaNames := set.NewEmptyStrSet()
dbSchemas, err := dest.GetDatabaseSchemas(ctx)
if err != nil {
return nil, err
}
for _, dbSchema := range dbSchemas {
dbSchemaNames.Add(dbSchema.Name)
}
return dbSchemaNames, nil
}
// MergeTable merges schema and table data for the table tblName.
// TODO: this code will loop infinitely when merging certain schema changes
func (rm *RootMerger) MergeTable(ctx *sql.Context, tblName string, opts editor.Options, mergeOpts MergeOpts) (*MergedTable, *MergeStats, error) {
func (rm *RootMerger) MergeTable(
ctx *sql.Context,
tblName doltdb.TableName,
opts editor.Options,
mergeOpts MergeOpts,
) (*MergedTable, *MergeStats, error) {
tm, err := rm.makeTableMerger(ctx, tblName, mergeOpts)
if err != nil {
return nil, nil, err
@@ -179,10 +196,10 @@ func (rm *RootMerger) MergeTable(ctx *sql.Context, tblName string, opts editor.O
return &MergedTable{table: tbl}, stats, nil
}
func (rm *RootMerger) makeTableMerger(ctx context.Context, tblName string, mergeOpts MergeOpts) (*TableMerger, error) {
func (rm *RootMerger) makeTableMerger(ctx context.Context, tblName doltdb.TableName, mergeOpts MergeOpts) (*TableMerger, error) {
recordViolations := true
if mergeOpts.RecordViolationsForTables != nil {
if _, ok := mergeOpts.RecordViolationsForTables[strings.ToLower(tblName)]; !ok {
if _, ok := mergeOpts.RecordViolationsForTables[tblName.ToLower()]; !ok {
recordViolations = false
}
}
@@ -199,7 +216,7 @@ func (rm *RootMerger) makeTableMerger(ctx context.Context, tblName string, merge
var err error
var leftSideTableExists, rightSideTableExists, ancTableExists bool
tm.leftTbl, leftSideTableExists, err = rm.left.GetTable(ctx, doltdb.TableName{Name: tblName})
tm.leftTbl, leftSideTableExists, err = rm.left.GetTable(ctx, tblName)
if err != nil {
return nil, err
}
@@ -209,7 +226,7 @@ func (rm *RootMerger) makeTableMerger(ctx context.Context, tblName string, merge
}
}
tm.rightTbl, rightSideTableExists, err = rm.right.GetTable(ctx, doltdb.TableName{Name: tblName})
tm.rightTbl, rightSideTableExists, err = rm.right.GetTable(ctx, tblName)
if err != nil {
return nil, err
}
@@ -241,7 +258,7 @@ func (rm *RootMerger) makeTableMerger(ctx context.Context, tblName string, merge
}
}
tm.ancTbl, ancTableExists, err = rm.anc.GetTable(ctx, doltdb.TableName{Name: tblName})
tm.ancTbl, ancTableExists, err = rm.anc.GetTable(ctx, tblName)
if err != nil {
return nil, err
}
+18 -7
View File
@@ -48,7 +48,7 @@ var ErrUnmergeableNewColumn = errorkinds.NewKind("Unable to merge new column `%s
var ErrDefaultCollationConflict = errorkinds.NewKind("Unable to merge table '%s', because its default collation setting has changed on both sides of the merge. Manually change the table's default collation setting on one of the sides of the merge and retry this merge.")
type SchemaConflict struct {
TableName string
TableName doltdb.TableName
ColConflicts []ColConflict
IdxConflicts []IdxConflict
ChkConflicts []ChkConflict
@@ -164,7 +164,12 @@ var ErrMergeWithDifferentPksFromAncestor = errorkinds.NewKind("error: cannot mer
// SchemaMerge performs a three-way merge of |ourSch|, |theirSch|, and |ancSch|, and returns: the merged schema,
// any schema conflicts identified, whether moving to the new schema requires a full table rewrite, and any
// unexpected error encountered while merging the schemas.
func SchemaMerge(ctx context.Context, format *storetypes.NomsBinFormat, ourSch, theirSch, ancSch schema.Schema, tblName string) (sch schema.Schema, sc SchemaConflict, mergeInfo MergeInfo, diffInfo tree.ThreeWayDiffInfo, err error) {
func SchemaMerge(
ctx context.Context,
format *storetypes.NomsBinFormat,
ourSch, theirSch, ancSch schema.Schema,
tblName doltdb.TableName,
) (sch schema.Schema, sc SchemaConflict, mergeInfo MergeInfo, diffInfo tree.ThreeWayDiffInfo, err error) {
// (sch - ancSch) (mergeSch - ancSch) (sch ∩ mergeSch)
sc = SchemaConflict{
TableName: tblName,
@@ -180,7 +185,7 @@ func SchemaMerge(ctx context.Context, format *storetypes.NomsBinFormat, ourSch,
}
var mergedCC *schema.ColCollection
mergedCC, sc.ColConflicts, mergeInfo, diffInfo, err = mergeColumns(tblName, format, ourSch.GetAllCols(), theirSch.GetAllCols(), ancSch.GetAllCols())
mergedCC, sc.ColConflicts, mergeInfo, diffInfo, err = mergeColumns(tblName.Name, format, ourSch.GetAllCols(), theirSch.GetAllCols(), ancSch.GetAllCols())
if err != nil {
return nil, SchemaConflict{}, mergeInfo, diffInfo, err
}
@@ -199,7 +204,7 @@ func SchemaMerge(ctx context.Context, format *storetypes.NomsBinFormat, ourSch,
return nil, sc, mergeInfo, diffInfo, err
}
sch, err = mergeTableCollation(ctx, tblName, ancSch, ourSch, theirSch, sch)
sch, err = mergeTableCollation(ctx, tblName.Name, ancSch, ourSch, theirSch, sch)
if err != nil {
return nil, sc, mergeInfo, diffInfo, err
}
@@ -285,7 +290,7 @@ func ForeignKeysMerge(ctx context.Context, mergedRoot, ourRoot, theirRoot, ancRo
}
// check for conflicts between foreign keys added on each branch since the ancestor
//TODO: figure out the best way to handle unresolved foreign keys here if one branch added an unresolved one and
// TODO: figure out the best way to handle unresolved foreign keys here if one branch added an unresolved one and
// another branch added the same one but resolved
_ = ourNewFKs.Iter(func(ourFK doltdb.ForeignKey) (stop bool, err error) {
theirFK, ok := theirNewFKs.GetByTags(ourFK.TableColumns, ourFK.ReferencedTableColumns)
@@ -905,7 +910,10 @@ func indexCollSetDifference(left, right schema.IndexCollection, cc *schema.ColCo
return d
}
func foreignKeysInCommon(ourFKs, theirFKs, ancFKs *doltdb.ForeignKeyCollection, ancSchs map[string]schema.Schema) (common *doltdb.ForeignKeyCollection, conflicts []FKConflict, err error) {
func foreignKeysInCommon(
ourFKs, theirFKs, ancFKs *doltdb.ForeignKeyCollection,
ancSchs map[doltdb.TableName]schema.Schema,
) (common *doltdb.ForeignKeyCollection, conflicts []FKConflict, err error) {
common, _ = doltdb.NewForeignKeyCollection()
err = ourFKs.Iter(func(ours doltdb.ForeignKey) (stop bool, err error) {
@@ -982,7 +990,10 @@ func foreignKeysInCommon(ourFKs, theirFKs, ancFKs *doltdb.ForeignKeyCollection,
// fkCollSetDifference returns a collection of all foreign keys that are in the given collection but not the ancestor
// collection. This is specifically for finding differences between a descendant and an ancestor, and therefore should
// not be used in the general case.
func fkCollSetDifference(fkColl, ancestorFkColl *doltdb.ForeignKeyCollection, ancSchs map[string]schema.Schema) (d *doltdb.ForeignKeyCollection, err error) {
func fkCollSetDifference(
fkColl, ancestorFkColl *doltdb.ForeignKeyCollection,
ancSchs map[doltdb.TableName]schema.Schema,
) (d *doltdb.ForeignKeyCollection, err error) {
d, _ = doltdb.NewForeignKeyCollection()
err = fkColl.Iter(func(fk doltdb.ForeignKey) (stop bool, err error) {
_, ok := ancestorFkColl.GetMatchingKey(fk, ancSchs, false)
+2 -2
View File
@@ -306,7 +306,7 @@ func TestMergeCommits(t *testing.T) {
}
opts := editor.TestEditorOptions(vrw)
// TODO: stats
merged, _, err := merger.MergeTable(sql.NewContext(context.Background()), tableName, opts, MergeOpts{IsCherryPick: false})
merged, _, err := merger.MergeTable(sql.NewContext(context.Background()), doltdb.TableName{Name: tableName}, opts, MergeOpts{IsCherryPick: false})
if err != nil {
t.Fatal(err)
}
@@ -361,7 +361,7 @@ func TestNomsMergeCommits(t *testing.T) {
t.Fatal(err)
}
opts := editor.TestEditorOptions(vrw)
merged, stats, err := merger.MergeTable(sql.NewContext(context.Background()), tableName, opts, MergeOpts{IsCherryPick: false})
merged, stats, err := merger.MergeTable(sql.NewContext(context.Background()), doltdb.TableName{Name: tableName}, opts, MergeOpts{IsCherryPick: false})
if err != nil {
t.Fatal(err)
}
@@ -266,7 +266,7 @@ var mergeSchemaConflictTests = []mergeSchemaConflictTest{
{
name: "no conflicts",
expConflict: merge.SchemaConflict{
TableName: "test",
TableName: doltdb.TableName{Name: "test"},
},
},
{
@@ -284,7 +284,7 @@ var mergeSchemaConflictTests = []mergeSchemaConflictTest{
{commands.CheckoutCmd{}, []string{env.DefaultInitBranch}},
},
expConflict: merge.SchemaConflict{
TableName: "test",
TableName: doltdb.TableName{Name: "test"},
ColConflicts: []merge.ColConflict{
{
Kind: merge.NameCollision,
@@ -312,7 +312,7 @@ var mergeSchemaConflictTests = []mergeSchemaConflictTest{
{commands.CheckoutCmd{}, []string{env.DefaultInitBranch}},
},
expConflict: merge.SchemaConflict{
TableName: "test",
TableName: doltdb.TableName{Name: "test"},
IdxConflicts: []merge.IdxConflict{
{
Kind: merge.NameCollision,
@@ -338,7 +338,7 @@ var mergeSchemaConflictTests = []mergeSchemaConflictTest{
{commands.CheckoutCmd{}, []string{env.DefaultInitBranch}},
},
expConflict: merge.SchemaConflict{
TableName: "test",
TableName: doltdb.TableName{Name: "test"},
ColConflicts: []merge.ColConflict{
{
Kind: merge.TagCollision,
@@ -366,7 +366,7 @@ var mergeSchemaConflictTests = []mergeSchemaConflictTest{
{commands.CheckoutCmd{}, []string{env.DefaultInitBranch}},
},
expConflict: merge.SchemaConflict{
TableName: "test",
TableName: doltdb.TableName{Name: "test"},
IdxConflicts: []merge.IdxConflict{
{
Kind: merge.TagCollision,
@@ -389,7 +389,7 @@ var mergeSchemaConflictTests = []mergeSchemaConflictTest{
{commands.CheckoutCmd{}, []string{env.DefaultInitBranch}},
},
expConflict: merge.SchemaConflict{
TableName: "test",
TableName: doltdb.TableName{Name: "test"},
ChkConflicts: []merge.ChkConflict{
{
Kind: merge.TagCollision,
@@ -432,7 +432,7 @@ var mergeSchemaConflictTests = []mergeSchemaConflictTest{
{commands.CommitCmd{}, []string{"-m", "modified branch other"}},
},
expConflict: merge.SchemaConflict{
TableName: "test",
TableName: doltdb.TableName{Name: "test"},
ChkConflicts: []merge.ChkConflict{
{
Kind: merge.TagCollision,
@@ -497,7 +497,7 @@ var mergeForeignKeyTests = []mergeForeignKeyTest{
}),
expFKConflict: []merge.FKConflict{},
},
//{
// {
// name: "add foreign key, drop foreign key, merge",
// setup: []testCommand{
// {commands.SqlCmd{}, []string{"-q", "alter table quiz add constraint q2_fk foreign key (q2) references test(t2);"}},
@@ -519,7 +519,7 @@ var mergeForeignKeyTests = []mergeForeignKeyTest{
// ReferencedTableIndex: "dolt_fk_2",
// ReferencedTableColumns: []uint64{2}}),
// expFKConflict: []merge.FKConflict{},
//},
// },
}
func colCollection(cols ...schema.Column) *schema.ColCollection {
@@ -631,7 +631,7 @@ func testMergeSchemasWithConflicts(t *testing.T, test mergeSchemaConflictTest) {
otherSch := getSchema(t, dEnv)
_, actConflicts, mergeInfo, _, err := merge.SchemaMerge(context.Background(), types.Format_Default, mainSch, otherSch, ancSch, "test")
_, actConflicts, mergeInfo, _, err := merge.SchemaMerge(context.Background(), types.Format_Default, mainSch, otherSch, ancSch, doltdb.TableName{Name: "test"})
assert.False(t, mergeInfo.InvalidateSecondaryIndexes)
if test.expectedErr != nil {
// We don't use errors.Is here because errors generated by `Kind.New` compare stack traces in their `Is` implementation.
@@ -640,7 +640,7 @@ func testMergeSchemasWithConflicts(t *testing.T, test mergeSchemaConflictTest) {
}
require.NoError(t, err)
assert.Equal(t, actConflicts.TableName, "test")
assert.Equal(t, actConflicts.TableName.Name, "test")
assert.Equal(t, test.expConflict.Count(), actConflicts.Count())
+5 -4
View File
@@ -69,13 +69,14 @@ type FKViolationReceiver interface {
// GetForeignKeyViolations returns the violations that have been created as a
// result of the diff between |baseRoot| and |newRoot|. It sends the violations to |receiver|.
func GetForeignKeyViolations(ctx context.Context, newRoot, baseRoot doltdb.RootValue, tables *set.StrSet, receiver FKViolationReceiver) error {
func GetForeignKeyViolations(ctx context.Context, newRoot, baseRoot doltdb.RootValue, tables *doltdb.TableNameSet, receiver FKViolationReceiver) error {
fkColl, err := newRoot.GetForeignKeyCollection(ctx)
if err != nil {
return err
}
for _, foreignKey := range fkColl.AllKeys() {
if !foreignKey.IsResolved() || (tables.Size() != 0 && !tables.Contains(foreignKey.TableName)) {
// TODO: schema names
if !foreignKey.IsResolved() || (tables.Size() != 0 && !tables.Contains(doltdb.TableName{Name: foreignKey.TableName})) {
continue
}
@@ -156,7 +157,7 @@ func GetForeignKeyViolations(ctx context.Context, newRoot, baseRoot doltdb.RootV
// AddForeignKeyViolations adds foreign key constraint violations to each table.
// todo(andy): pass doltdb.Rootish
func AddForeignKeyViolations(ctx context.Context, newRoot, baseRoot doltdb.RootValue, tables *set.StrSet, theirRootIsh hash.Hash) (doltdb.RootValue, *set.StrSet, error) {
func AddForeignKeyViolations(ctx context.Context, newRoot, baseRoot doltdb.RootValue, tables *doltdb.TableNameSet, theirRootIsh hash.Hash) (doltdb.RootValue, *set.StrSet, error) {
violationWriter := &foreignKeyViolationWriter{rootValue: newRoot, theirRootIsh: theirRootIsh, violatedTables: set.NewStrSet(nil)}
err := GetForeignKeyViolations(ctx, newRoot, baseRoot, tables, violationWriter)
if err != nil {
@@ -167,7 +168,7 @@ func AddForeignKeyViolations(ctx context.Context, newRoot, baseRoot doltdb.RootV
// GetForeignKeyViolatedTables returns a list of tables that have foreign key
// violations based on the diff between |newRoot| and |baseRoot|.
func GetForeignKeyViolatedTables(ctx context.Context, newRoot, baseRoot doltdb.RootValue, tables *set.StrSet) (*set.StrSet, error) {
func GetForeignKeyViolatedTables(ctx context.Context, newRoot, baseRoot doltdb.RootValue, tables *doltdb.TableNameSet) (*set.StrSet, error) {
handler := &foreignKeyViolationTracker{tableSet: set.NewStrSet(nil)}
err := GetForeignKeyViolations(ctx, newRoot, baseRoot, tables, handler)
if err != nil {
@@ -80,7 +80,7 @@ func (m UniqCVMeta) PrettyPrint() string {
return jsonStr
}
func replaceUniqueKeyViolation(ctx context.Context, edt *prolly.ArtifactsEditor, m prolly.Map, k val.Tuple, kd val.TupleDesc, theirRootIsh doltdb.Rootish, vInfo []byte, tblName string) error {
func replaceUniqueKeyViolation(ctx context.Context, edt *prolly.ArtifactsEditor, m prolly.Map, k val.Tuple, theirRootIsh doltdb.Rootish, vInfo []byte) error {
var value val.Tuple
err := m.Get(ctx, k, func(_, v val.Tuple) error {
value = v
@@ -94,7 +94,7 @@ func TestRenameTable(t *testing.T) {
schemas, err := doltdb.GetAllSchemas(ctx, root)
require.NoError(t, err)
beforeSch := schemas[tt.oldName]
beforeSch := schemas[doltdb.TableName{Name: tt.oldName}]
updatedRoot, err := renameTable(ctx, root, tt.oldName, tt.newName)
if len(tt.expectedErr) > 0 {
@@ -115,7 +115,7 @@ func TestRenameTable(t *testing.T) {
schemas, err = doltdb.GetAllSchemas(ctx, updatedRoot)
require.NoError(t, err)
require.Equal(t, beforeSch, schemas[tt.newName])
require.Equal(t, beforeSch, schemas[doltdb.TableName{Name: tt.newName}])
})
}
}
@@ -411,10 +411,27 @@ func (p *DoltDatabaseProvider) CreateDatabase(ctx *sql.Context, name string) err
return p.CreateCollatedDatabase(ctx, name, sql.Collation_Default)
}
func (p *DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name string, collation sql.CollationID) (err error) {
p.mu.Lock()
defer p.mu.Unlock()
func commitTransaction(ctx *sql.Context, dSess *dsess.DoltSession, rsc *doltdb.ReplicationStatusController) error {
currentTx := ctx.GetTransaction()
err := dSess.CommitTransaction(ctx, currentTx)
if err != nil {
return err
}
newTx, err := dSess.StartTransaction(ctx, sql.ReadWrite)
if err != nil {
return err
}
ctx.SetTransaction(newTx)
if rsc != nil {
dsess.WaitForReplicationController(ctx, *rsc)
}
return nil
}
func (p *DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name string, collation sql.CollationID) (err error) {
exists, isDir := p.fs.Exists(name)
if exists && isDir {
return sql.ErrDatabaseExists.New(name)
@@ -422,6 +439,24 @@ func (p *DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name str
return fmt.Errorf("Cannot create DB, file exists at %s", name)
}
sess := dsess.DSessFromSess(ctx.Session)
var rsc doltdb.ReplicationStatusController
// before we create a new database, we need to implicitly commit any current transaction, because we'll begin a new
// one after we create the new DB
err = commitTransaction(ctx, sess, &rsc)
if err != nil {
return err
}
p.mu.Lock()
needUnlock := true
defer func() {
if needUnlock {
p.mu.Unlock()
}
}()
err = p.fs.MkDirs(name)
if err != nil {
return err
@@ -440,7 +475,6 @@ func (p *DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name str
}
// TODO: fill in version appropriately
sess := dsess.DSessFromSess(ctx.Session)
newEnv := env.Load(ctx, env.GetCurrentUserHomeDir, newFs, p.dbFactoryUrl, "TODO")
newDbStorageFormat := types.Format_Default
@@ -449,6 +483,8 @@ func (p *DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name str
return err
}
updatedCollation, updatedSchemas := false, false
// Set the collation
if collation != sql.Collation_Default {
workingRoot, err := newEnv.WorkingRoot(ctx)
@@ -466,6 +502,8 @@ func (p *DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name str
if err = newEnv.UpdateStagedRoot(ctx, newRoot); err != nil {
return err
}
updatedCollation = true
}
// If the search path is enabled, we need to create our initial schema object (public and pg_catalog are available
@@ -495,9 +533,60 @@ func (p *DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name str
if err = newEnv.UpdateStagedRoot(ctx, workingRoot); err != nil {
return err
}
updatedSchemas = true
}
return p.registerNewDatabase(ctx, name, newEnv)
err = p.registerNewDatabase(ctx, name, newEnv)
if err != nil {
return err
}
// Since we just created this database, we need to commit the current transaction so that the new database is
// usable in this session.
// We need to unlock the provider early to avoid a deadlock with the commit
needUnlock = false
p.mu.Unlock()
err = commitTransaction(ctx, sess, &rsc)
if err != nil {
return err
}
needsDoltCommit := updatedSchemas || updatedCollation
if needsDoltCommit {
// After making changes to the working set for the DB, create a new dolt commit so that any newly created
// branches have those changes
// TODO: it would be better if there weren't a commit for this database where these changes didn't exist, but
// we always create an empty commit as part of initializing a repo right now, and you cannot amend the initial
// commit
roots, ok := sess.GetRoots(ctx, name)
if !ok {
return fmt.Errorf("unable to get roots for database %s", name)
}
t := ctx.QueryTime()
userName := ctx.Client().User
userEmail := fmt.Sprintf("%s@%s", ctx.Client().User, ctx.Client().Address)
pendingCommit, err := sess.NewPendingCommit(ctx, name, roots, actions.CommitStagedProps{
Message: "CREATE DATABASE",
Date: t,
Name: userName,
Email: userEmail,
})
if err != nil {
return err
}
_, err = sess.DoltCommit(ctx, name, sess.GetTransaction(), pendingCommit)
if err != nil {
return err
}
}
return nil
}
type InitDatabaseHook func(ctx *sql.Context, pro *DoltDatabaseProvider, name string, env *env.DoltEnv, db dsess.SqlDatabase) error
@@ -310,13 +310,13 @@ func resolveNomsConflicts(ctx *sql.Context, opts editor.Options, tbl *doltdb.Tab
return resolvePkConflicts(ctx, opts, tbl, tblName, sch, conflicts)
}
func validateConstraintViolations(ctx *sql.Context, before, after doltdb.RootValue, table string) error {
tables, err := after.GetTableNames(ctx, doltdb.DefaultSchemaName)
func validateConstraintViolations(ctx *sql.Context, before, after doltdb.RootValue, table doltdb.TableName) error {
tables, err := after.GetTableNames(ctx, table.Schema)
if err != nil {
return err
}
violators, err := merge.GetForeignKeyViolatedTables(ctx, after, before, set.NewStrSet(tables))
violators, err := merge.GetForeignKeyViolatedTables(ctx, after, before, doltdb.NewTableNameSet(doltdb.ToTableNames(tables, table.Schema)))
if err != nil {
return err
}
@@ -327,12 +327,12 @@ func validateConstraintViolations(ctx *sql.Context, before, after doltdb.RootVal
return nil
}
func clearTableAndUpdateRoot(ctx *sql.Context, root doltdb.RootValue, tbl *doltdb.Table, tblName string) (doltdb.RootValue, error) {
func clearTableAndUpdateRoot(ctx *sql.Context, root doltdb.RootValue, tbl *doltdb.Table, tblName doltdb.TableName) (doltdb.RootValue, error) {
newTbl, err := tbl.ClearConflicts(ctx)
if err != nil {
return nil, err
}
newRoot, err := root.PutTable(ctx, doltdb.TableName{Name: tblName}, newTbl)
newRoot, err := root.PutTable(ctx, tblName, newTbl)
if err != nil {
return nil, err
}
@@ -360,6 +360,7 @@ func ResolveSchemaConflicts(ctx *sql.Context, ddb *doltdb.DoltDB, ws *doltdb.Wor
"To track resolution of this limitation, follow https://github.com/dolthub/dolt/issues/6616")
}
// TODO: schema names
tblSet := set.NewStrSet(tables)
updates := make(map[string]*doltdb.Table)
err := ws.MergeState().IterSchemaConflicts(ctx, ddb, func(table string, conflict doltdb.SchemaConflict) error {
@@ -399,9 +400,9 @@ func ResolveSchemaConflicts(ctx *sql.Context, ddb *doltdb.DoltDB, ws *doltdb.Wor
return ws.WithWorkingRoot(root).WithUnmergableTables(unmerged).WithMergedTables(merged), nil
}
func ResolveDataConflicts(ctx *sql.Context, dSess *dsess.DoltSession, root doltdb.RootValue, dbName string, ours bool, tblNames []string) error {
func ResolveDataConflicts(ctx *sql.Context, dSess *dsess.DoltSession, root doltdb.RootValue, dbName string, ours bool, tblNames []doltdb.TableName) error {
for _, tblName := range tblNames {
tbl, ok, err := root.GetTable(ctx, doltdb.TableName{Name: tblName})
tbl, ok, err := root.GetTable(ctx, tblName)
if err != nil {
return err
}
@@ -432,7 +433,7 @@ func ResolveDataConflicts(ctx *sql.Context, dSess *dsess.DoltSession, root doltd
if !ours {
if tbl.Format() == types.Format_DOLT {
tbl, err = resolveProllyConflicts(ctx, tbl, tblName, ourSch, sch)
tbl, err = resolveProllyConflicts(ctx, tbl, tblName.Name, ourSch, sch)
} else {
state, _, err := dSess.LookupDbState(ctx, dbName)
if err != nil {
@@ -442,7 +443,7 @@ func ResolveDataConflicts(ctx *sql.Context, dSess *dsess.DoltSession, root doltd
if ws := state.WriteSession(); ws != nil {
opts = ws.GetOptions()
}
tbl, err = resolveNomsConflicts(ctx, opts, tbl, tblName, sch)
tbl, err = resolveNomsConflicts(ctx, opts, tbl, tblName.Name, sch)
}
if err != nil {
return err
@@ -499,21 +500,22 @@ func DoDoltConflictsResolve(ctx *sql.Context, args []string) (int, error) {
}
// get all tables in conflict
tbls := apr.Args
if len(tbls) == 1 && tbls[0] == "." {
strTableNames := apr.Args
if len(strTableNames) == 1 && strTableNames[0] == "." {
// TODO: schema names
all, err := ws.WorkingRoot().GetTableNames(ctx, doltdb.DefaultSchemaName)
if err != nil {
return 1, nil
}
tbls = all
strTableNames = all
}
ws, err = ResolveSchemaConflicts(ctx, ddb, ws, ours, tbls)
ws, err = ResolveSchemaConflicts(ctx, ddb, ws, ours, strTableNames)
if err != nil {
return 1, err
}
err = ResolveDataConflicts(ctx, dSess, ws.WorkingRoot(), dbName, ours, tbls)
err = ResolveDataConflicts(ctx, dSess, ws.WorkingRoot(), dbName, ours, doltdb.ToTableNames(strTableNames, doltdb.DefaultSchemaName))
if err != nil {
return 1, err
}
@@ -533,7 +533,8 @@ func mergeRootToWorking(
if !squash || merged.HasSchemaConflicts() {
ws = ws.StartMerge(cm2, cm2Spec)
tt := merge.SchemaConflictTableNames(merged.SchemaConflicts)
ws = ws.WithUnmergableTables(tt)
// TODO: schema names
ws = ws.WithUnmergableTables(doltdb.FlattenTableNames(tt))
}
ws = ws.WithWorkingRoot(working)
@@ -16,7 +16,6 @@ package dprocedures
import (
"fmt"
"strings"
"github.com/dolthub/go-mysql-server/sql"
@@ -25,9 +24,9 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/merge"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/libraries/utils/set"
)
// doltVerifyConstraints is the stored procedure version for the CLI command `dolt constraints verify`.
@@ -111,13 +110,12 @@ func doDoltConstraintsVerify(ctx *sql.Context, args []string) (int, error) {
// tables in |tableSet|. Returns the new root with the violations, and a set of table names that have violations.
// Note that constraint violations detected for ALL existing tables will be stored in the dolt_constraint_violations
// tables, but the returned set of table names will be a subset of |tableSet|.
func calculateViolations(ctx *sql.Context, workingRoot, comparingRoot doltdb.RootValue, tableSet *set.StrSet) (doltdb.RootValue, *set.StrSet, error) {
var recordViolationsForTables map[string]struct{} = nil
func calculateViolations(ctx *sql.Context, workingRoot, comparingRoot doltdb.RootValue, tableSet *doltdb.TableNameSet) (doltdb.RootValue, *doltdb.TableNameSet, error) {
var recordViolationsForTables map[doltdb.TableName]struct{} = nil
if tableSet.Size() > 0 {
recordViolationsForTables = make(map[string]struct{})
recordViolationsForTables = make(map[doltdb.TableName]struct{})
for _, table := range tableSet.AsSlice() {
table = strings.ToLower(table)
recordViolationsForTables[table] = struct{}{}
recordViolationsForTables[table.ToLower()] = struct{}{}
}
}
@@ -133,9 +131,9 @@ func calculateViolations(ctx *sql.Context, workingRoot, comparingRoot doltdb.Roo
return nil, nil, fmt.Errorf("error calculating constraint violations: %w", err)
}
tablesWithViolations := set.NewStrSet(nil)
tablesWithViolations := doltdb.NewTableNameSet(nil)
for _, tableName := range tableSet.AsSlice() {
table, ok, err := mergeResults.Root.GetTable(ctx, doltdb.TableName{Name: tableName})
table, ok, err := mergeResults.Root.GetTable(ctx, tableName)
if err != nil {
return nil, nil, err
}
@@ -160,26 +158,29 @@ func calculateViolations(ctx *sql.Context, workingRoot, comparingRoot doltdb.Roo
// parseTablesToCheck returns a set of table names to check for constraint violations. If no tables are specified, then
// all tables in the root are returned.
func parseTablesToCheck(ctx *sql.Context, workingRoot doltdb.RootValue, apr *argparser.ArgParseResults) (*set.StrSet, error) {
tableSet := set.NewStrSet(nil)
func parseTablesToCheck(ctx *sql.Context, workingRoot doltdb.RootValue, apr *argparser.ArgParseResults) (*doltdb.TableNameSet, error) {
tableSet := doltdb.NewTableNameSet(nil)
for _, val := range apr.Args {
_, tableName, ok, err := doltdb.GetTableInsensitive(ctx, workingRoot, doltdb.TableName{Name: val})
tableName, _, ok, err := resolve.Table(ctx, workingRoot, val)
if err != nil {
return nil, err
}
if !ok {
return nil, sql.ErrTableNotFound.New(tableName)
return nil, sql.ErrTableNotFound.New(val)
}
tableSet.Add(tableName)
}
// If no tables were explicitly specified, then check all tables
if tableSet.Size() == 0 {
// TODO: schema search path
names, err := workingRoot.GetTableNames(ctx, doltdb.DefaultSchemaName)
if err != nil {
return nil, err
}
tableSet.Add(names...)
tableSet.Add(doltdb.ToTableNames(names, doltdb.DefaultSchemaName)...)
}
return tableSet, nil
+6 -1
View File
@@ -682,7 +682,12 @@ func (d *DoltSession) PendingCommitAllStaged(ctx *sql.Context, branchState *bran
// NewPendingCommit returns a new |doltdb.PendingCommit| for the database named, using the roots given, adding any
// merge parent from an in progress merge as appropriate. The session working set is not updated with these new roots,
// but they are set in the returned |doltdb.PendingCommit|. If there are no changes staged, this method returns nil.
func (d *DoltSession) NewPendingCommit(ctx *sql.Context, dbName string, roots doltdb.Roots, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) {
func (d *DoltSession) NewPendingCommit(
ctx *sql.Context,
dbName string,
roots doltdb.Roots,
props actions.CommitStagedProps,
) (*doltdb.PendingCommit, error) {
branchState, ok, err := d.lookupDbState(ctx, dbName)
if err != nil {
return nil, err
@@ -23,13 +23,14 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/merge"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/store/types"
)
// NewConflictsTable returns a new ConflictsTable instance
func NewConflictsTable(ctx *sql.Context, tblName string, srcTbl sql.Table, root doltdb.RootValue, rs RootSetter) (sql.Table, error) {
tbl, tblName, ok, err := doltdb.GetTableInsensitive(ctx, root, doltdb.TableName{Name: tblName})
resolvedTableName, tbl, ok, err := resolve.Table(ctx, root, tblName)
if err != nil {
return nil, err
} else if !ok {
@@ -41,14 +42,14 @@ func NewConflictsTable(ctx *sql.Context, tblName string, srcTbl sql.Table, root
if !ok {
return nil, fmt.Errorf("%s can not have conflicts because it is not updateable", tblName)
}
return newProllyConflictsTable(ctx, tbl, upd, tblName, root, rs)
return newProllyConflictsTable(ctx, tbl, upd, resolvedTableName, root, rs)
}
return newNomsConflictsTable(ctx, tbl, tblName, root, rs)
return newNomsConflictsTable(ctx, tbl, resolvedTableName.Name, root, rs)
}
func newNomsConflictsTable(ctx *sql.Context, tbl *doltdb.Table, tblName string, root doltdb.RootValue, rs RootSetter) (sql.Table, error) {
rd, err := merge.NewConflictReader(ctx, tbl, tblName)
rd, err := merge.NewConflictReader(ctx, tbl, doltdb.TableName{Name: tblName})
if err != nil {
return nil, err
}
@@ -114,7 +115,8 @@ func (ct ConflictsTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error)
// PartitionRows returns a RowIter for the given partition
func (ct ConflictsTable) PartitionRows(ctx *sql.Context, part sql.Partition) (sql.RowIter, error) {
// conflict reader must be reset each time partitionRows is called.
rd, err := merge.NewConflictReader(ctx, ct.tbl, ct.tblName)
// TODO: schema name
rd, err := merge.NewConflictReader(ctx, ct.tbl, doltdb.TableName{Name: ct.tblName})
if err != nil {
return nil, err
}
@@ -36,7 +36,14 @@ import (
"github.com/dolthub/dolt/go/store/val"
)
func newProllyConflictsTable(ctx *sql.Context, tbl *doltdb.Table, sourceUpdatableTbl sql.UpdatableTable, tblName string, root doltdb.RootValue, rs RootSetter) (sql.Table, error) {
func newProllyConflictsTable(
ctx *sql.Context,
tbl *doltdb.Table,
sourceUpdatableTbl sql.UpdatableTable,
tblName doltdb.TableName,
root doltdb.RootValue,
rs RootSetter,
) (sql.Table, error) {
arts, err := tbl.GetArtifacts(ctx)
if err != nil {
return nil, err
@@ -51,13 +58,13 @@ func newProllyConflictsTable(ctx *sql.Context, tbl *doltdb.Table, sourceUpdatabl
if err != nil {
return nil, err
}
sqlSch, err := sqlutil.FromDoltSchema("", doltdb.DoltConfTablePrefix+tblName, confSch)
sqlSch, err := sqlutil.FromDoltSchema("", doltdb.DoltConfTablePrefix+tblName.Name, confSch)
if err != nil {
return nil, err
}
return ProllyConflictsTable{
tblName: tblName,
tblName: tblName.Name,
sqlSch: sqlSch,
baseSch: baseSch,
ourSch: ourSch,
@@ -119,8 +119,8 @@ func (dt *SchemaConflictsTable) PartitionRows(ctx *sql.Context, part sql.Partiti
var conflicts []schemaConflict
err = p.state.IterSchemaConflicts(ctx, p.ddb, func(table string, cnf doltdb.SchemaConflict) error {
c, err := newSchemaConflict(ctx, table, baseRoot, cnf)
// TODO: schema name
c, err := newSchemaConflict(ctx, doltdb.TableName{Name: table}, baseRoot, cnf)
if err != nil {
return err
}
@@ -147,14 +147,14 @@ func (p schemaConflictsPartition) Key() []byte {
}
type schemaConflict struct {
table string
table doltdb.TableName
baseSch string
ourSch string
theirSch string
description string
}
func newSchemaConflict(ctx context.Context, table string, baseRoot doltdb.RootValue, c doltdb.SchemaConflict) (schemaConflict, error) {
func newSchemaConflict(ctx context.Context, table doltdb.TableName, baseRoot doltdb.RootValue, c doltdb.SchemaConflict) (schemaConflict, error) {
bs, err := doltdb.GetAllSchemas(ctx, baseRoot)
if err != nil {
return schemaConflict{}, err
@@ -166,13 +166,12 @@ func newSchemaConflict(ctx context.Context, table string, baseRoot doltdb.RootVa
return schemaConflict{}, err
}
// TODO: schema name
baseFKs, _ := fkc.KeysForTable(doltdb.TableName{Name: table})
baseFKs, _ := fkc.KeysForTable(table)
var base string
if baseSch != nil {
var err error
base, err = getCreateTableStatement(table, baseSch, baseFKs, bs)
base, err = getCreateTableStatement(table.Name, baseSch, baseFKs, bs)
if err != nil {
return schemaConflict{}, err
}
@@ -183,7 +182,7 @@ func newSchemaConflict(ctx context.Context, table string, baseRoot doltdb.RootVa
var ours string
if c.ToSch != nil {
var err error
ours, err = getCreateTableStatement(table, c.ToSch, c.ToFks, c.ToParentSchemas)
ours, err = getCreateTableStatement(table.Name, c.ToSch, c.ToFks, c.ToParentSchemas)
if err != nil {
return schemaConflict{}, err
}
@@ -194,7 +193,7 @@ func newSchemaConflict(ctx context.Context, table string, baseRoot doltdb.RootVa
var theirs string
if c.FromSch != nil {
var err error
theirs, err = getCreateTableStatement(table, c.FromSch, c.FromFks, c.FromParentSchemas)
theirs, err = getCreateTableStatement(table.Name, c.FromSch, c.FromFks, c.FromParentSchemas)
if err != nil {
return schemaConflict{}, err
}
@@ -226,13 +225,12 @@ func newSchemaConflict(ctx context.Context, table string, baseRoot doltdb.RootVa
}, nil
}
func getCreateTableStatement(table string, sch schema.Schema, fks []doltdb.ForeignKey, parents map[string]schema.Schema) (string, error) {
func getCreateTableStatement(table string, sch schema.Schema, fks []doltdb.ForeignKey, parents map[doltdb.TableName]schema.Schema) (string, error) {
return sqlfmt.GenerateCreateTableStatement(table, sch, fks, parents)
}
func getSchemaConflictDescription(ctx context.Context, table string, base, ours, theirs schema.Schema) (string, error) {
nbf := noms.Format_Default
_, conflict, _, _, err := merge.SchemaMerge(ctx, nbf, ours, theirs, base, table)
func getSchemaConflictDescription(ctx context.Context, table doltdb.TableName, base, ours, theirs schema.Schema) (string, error) {
_, conflict, _, _, err := merge.SchemaMerge(ctx, noms.Format_Default, ours, theirs, base, table)
if err != nil {
return "", err
}
@@ -251,7 +249,7 @@ func (it *schemaConflictsIter) Next(ctx *sql.Context) (sql.Row, error) {
}
c := it.conflicts[0] // pop next conflict
it.conflicts = it.conflicts[1:]
return sql.NewRow(c.table, c.baseSch, c.ourSch, c.theirSch, c.description), nil
return sql.NewRow(c.table.Name, c.baseSch, c.ourSch, c.theirSch, c.description), nil
}
func (it *schemaConflictsIter) Close(ctx *sql.Context) error {
@@ -122,7 +122,12 @@ func newStatusItr(ctx *sql.Context, st *StatusTable) (*StatusItr, error) {
return nil, err
}
rows := make([]statusTableRow, 0, len(stagedTables)+len(unstagedTables))
stagedSchemas, unstagedSchemas, err := diff.GetStagedUnstagedDatabaseSchemaDeltas(ctx, roots)
if err != nil {
return nil, err
}
rows := make([]statusTableRow, 0, len(stagedTables)+len(unstagedTables)+len(stagedSchemas)+len(unstagedSchemas))
cvTables, err := doltdb.TablesWithConstraintViolations(ctx, roots.Working)
if err != nil {
@@ -195,9 +200,35 @@ func newStatusItr(ctx *sql.Context, st *StatusTable) (*StatusItr, error) {
})
}
for _, sd := range stagedSchemas {
rows = append(rows, statusTableRow{
tableName: sd.CurName(),
isStaged: true,
status: schemaStatusString(sd),
})
}
for _, sd := range unstagedSchemas {
rows = append(rows, statusTableRow{
tableName: sd.CurName(),
isStaged: false,
status: schemaStatusString(sd),
})
}
return &StatusItr{rows: rows}, nil
}
func schemaStatusString(sd diff.DatabaseSchemaDelta) string {
if sd.IsAdd() {
return "new schema"
} else if sd.IsDrop() {
return "deleted schema"
} else {
panic("unexpected schema delta")
}
}
func tableName(td diff.TableDelta) string {
if td.IsRename() {
return fmt.Sprintf("%s -> %s", td.FromName, td.ToName)
@@ -108,9 +108,35 @@ func TestSingleScript(t *testing.T) {
var scripts = []queries.ScriptTest{
{
Name: "",
SetUpScript: []string{},
Assertions: []queries.ScriptTestAssertion{},
Name: "create database in a transaction",
SetUpScript: []string{
"START TRANSACTION",
"CREATE DATABASE test",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "USE test",
SkipResultsCheck: true,
},
{
Query: "CREATE TABLE foo (bar INT)",
SkipResultsCheck: true,
},
{
Query: "USE mydb",
SkipResultsCheck: true,
},
{
Query: "INSERT INTO test.foo VALUES (1)",
SkipResultsCheck: true,
},
{
Query: "SELECT * FROM test.foo",
Expected: []sql.Row{
{1},
},
},
},
},
}
@@ -284,9 +310,9 @@ func TestSingleQueryPrepared(t *testing.T) {
t.Skip()
harness := newDoltHarness(t)
//engine := enginetest.NewEngine(t, harness)
//enginetest.CreateIndexes(t, harness, engine)
//engine := enginetest.NewSpatialEngine(t, harness)
// engine := enginetest.NewEngine(t, harness)
// enginetest.CreateIndexes(t, harness, engine)
// engine := enginetest.NewSpatialEngine(t, harness)
engine, err := harness.NewEngine(t)
if err != nil {
panic(err)
@@ -305,8 +331,8 @@ func TestSingleQueryPrepared(t *testing.T) {
enginetest.RunQueryWithContext(t, engine, harness, nil, q)
}
//engine.Analyzer.Debug = true
//engine.Analyzer.Verbose = true
// engine.Analyzer.Debug = true
// engine.Analyzer.Verbose = true
var test queries.QueryTest
test = queries.QueryTest{
@@ -842,8 +868,7 @@ func TestDropColumn(t *testing.T) {
func TestCreateDatabase(t *testing.T) {
h := newDoltHarness(t)
defer h.Close()
enginetest.TestCreateDatabase(t, h)
RunCreateDatabaseTest(t, h)
}
func TestBlobs(t *testing.T) {
@@ -1986,7 +2011,7 @@ func TestStatsAutoRefreshConcurrency(t *testing.T) {
_, iter, _, err := engine.Query(ctx, q)
require.NoError(t, err)
_, err = sql.RowIterToRows(ctx, iter)
//fmt.Printf("%s %d\n", tag, id)
// fmt.Printf("%s %d\n", tag, id)
require.NoError(t, err)
}
@@ -639,22 +639,22 @@ func RunMultiDbTransactionsTest(t *testing.T, h DoltEnginetestHarness) {
func RunMultiDbTransactionsPreparedTest(t *testing.T, h DoltEnginetestHarness) {
for _, script := range MultiDbTransactionTests {
//func() {
// func() {
h := h.NewHarness(t)
defer h.Close()
enginetest.TestScriptPrepared(t, h, script)
//}()
// }()
}
}
func RunDoltScriptsTest(t *testing.T, harness DoltEnginetestHarness) {
for _, script := range DoltScripts {
//go func() {
// go func() {
harness := harness.NewHarness(t)
enginetest.TestScript(t, harness, script)
harness.Close()
//}()
// }()
}
}
@@ -824,6 +824,17 @@ func RunShowCreateTableTests(t *testing.T, h DoltEnginetestHarness) {
}
}
func RunCreateDatabaseTest(t *testing.T, h *DoltHarness) {
enginetest.TestCreateDatabase(t, h)
h.Close()
for _, script := range DoltCreateDatabaseScripts {
h := h.NewHarness(t)
enginetest.TestScript(t, h, script)
h.Close()
}
}
func RunShowCreateTablePreparedTests(t *testing.T, h DoltEnginetestHarness) {
for _, script := range ShowCreateTableScriptTests {
func() {
@@ -0,0 +1,116 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package enginetest
import (
"github.com/dolthub/go-mysql-server/enginetest/queries"
"github.com/dolthub/go-mysql-server/sql"
)
var DoltCreateDatabaseScripts = []queries.ScriptTest{
{
Name: "create database simple",
SetUpScript: []string{
"CREATE DATABASE if not exists mydb", // TODO: this is an artifact of how we run the tests
"CREATE DATABASE test",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "SHOW DATABASES",
Expected: []sql.Row{
{"information_schema"},
{"mydb"},
{"mysql"},
{"test"},
},
},
{
Query: "USE test",
SkipResultsCheck: true,
},
{
Query: "CREATE TABLE foo (bar INT)",
SkipResultsCheck: true,
},
{
Query: "USE mydb",
SkipResultsCheck: true,
},
{
Query: "INSERT INTO test.foo VALUES (1)",
SkipResultsCheck: true,
},
{
Query: "SELECT * FROM test.foo",
Expected: []sql.Row{
{1},
},
},
},
},
{
Name: "create database with non standard collation, create branch",
SetUpScript: []string{
"CREATE DATABASE test CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "use test",
SkipResultsCheck: true,
},
{
Query: "call dolt_branch('b1')",
SkipResultsCheck: true,
},
{
Query: "show create database test",
Expected: []sql.Row{
{"test", "CREATE DATABASE `test` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci */"},
},
},
},
},
{
Name: "create database in a transaction",
SetUpScript: []string{
"START TRANSACTION",
"CREATE DATABASE test",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "USE test",
SkipResultsCheck: true,
},
{
Query: "CREATE TABLE foo (bar INT)",
SkipResultsCheck: true,
},
{
Query: "USE mydb",
SkipResultsCheck: true,
},
{
Query: "INSERT INTO test.foo VALUES (1)",
SkipResultsCheck: true,
},
{
Query: "SELECT * FROM test.foo",
Expected: []sql.Row{
{1},
},
},
},
},
}
@@ -76,7 +76,7 @@ func GenerateSqlPatchSchemaStatements(ctx *sql.Context, toRoot doltdb.RootValue,
if td.IsDrop() {
ddlStatements = append(ddlStatements, DropTableStmt(td.FromName.Name))
} else if td.IsAdd() {
stmt, err := GenerateCreateTableStatement(td.ToName.Name, td.ToSch, td.ToFks, nameMapFromTableNameMap(td.ToFksParentSch))
stmt, err := GenerateCreateTableStatement(td.ToName.Name, td.ToSch, td.ToFks, td.ToFksParentSch)
if err != nil {
return nil, errhand.VerboseErrorFromError(err)
}
@@ -92,17 +92,9 @@ func GenerateSqlPatchSchemaStatements(ctx *sql.Context, toRoot doltdb.RootValue,
return ddlStatements, nil
}
func nameMapFromTableNameMap(tableNameMap map[doltdb.TableName]schema.Schema) map[string]schema.Schema {
nameMap := make(map[string]schema.Schema)
for name := range tableNameMap {
nameMap[name.Name] = tableNameMap[name]
}
return nameMap
}
// generateNonCreateNonDropTableSqlSchemaDiff returns any schema diff in SQL statements that is NEITHER 'CREATE TABLE' NOR 'DROP TABLE' statements.
// TODO: schema names
func generateNonCreateNonDropTableSqlSchemaDiff(td diff.TableDelta, toSchemas map[string]schema.Schema, fromSch, toSch schema.Schema) ([]string, error) {
func generateNonCreateNonDropTableSqlSchemaDiff(td diff.TableDelta, toSchemas map[doltdb.TableName]schema.Schema, fromSch, toSch schema.Schema) ([]string, error) {
if td.IsAdd() || td.IsDrop() {
// use add and drop specific methods
return nil, nil
@@ -167,7 +159,8 @@ func generateNonCreateNonDropTableSqlSchemaDiff(td diff.TableDelta, toSchemas ma
switch fkDiff.DiffType {
case diff.SchDiffNone:
case diff.SchDiffAdded:
parentSch := toSchemas[fkDiff.To.ReferencedTableName]
// TODO: schema name
parentSch := toSchemas[doltdb.TableName{Name: fkDiff.To.ReferencedTableName}]
ddlStatements = append(ddlStatements, AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch))
case diff.SchDiffRemoved:
from := fkDiff.From
@@ -175,8 +168,8 @@ func generateNonCreateNonDropTableSqlSchemaDiff(td diff.TableDelta, toSchemas ma
case diff.SchDiffModified:
from := fkDiff.From
ddlStatements = append(ddlStatements, AlterTableDropForeignKeyStmt(from.TableName, from.Name))
parentSch := toSchemas[fkDiff.To.ReferencedTableName]
// TODO: schema name
parentSch := toSchemas[doltdb.TableName{Name: fkDiff.To.ReferencedTableName}]
ddlStatements = append(ddlStatements, AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch))
}
}
@@ -449,12 +442,7 @@ func AlterTableDropForeignKeyStmt(tableName, fkName string) string {
// `SHOW CREATE TABLE` in the engine, but may have some differences. Callers are advised to use the engine when
// possible.
// TODO: schema names
func GenerateCreateTableStatement(
tblName string,
sch schema.Schema,
fks []doltdb.ForeignKey,
fksParentSch map[string]schema.Schema,
) (string, error) {
func GenerateCreateTableStatement(tblName string, sch schema.Schema, fks []doltdb.ForeignKey, fksParentSch map[doltdb.TableName]schema.Schema) (string, error) {
colStmts := make([]string, sch.GetAllCols().Size())
// Statement creation parts for each column
@@ -478,7 +466,8 @@ func GenerateCreateTableStatement(
}
for _, fk := range fks {
colStmts = append(colStmts, GenerateCreateTableForeignKeyDefinition(fk, sch, fksParentSch[fk.ReferencedTableName]))
// TODO: schema name
colStmts = append(colStmts, GenerateCreateTableForeignKeyDefinition(fk, sch, fksParentSch[doltdb.TableName{Name: fk.ReferencedTableName}]))
}
for _, check := range sch.Checks().AllChecks() {
@@ -40,16 +40,7 @@ const (
// BuildProllyIndexExternal builds unique and non-unique indexes with a
// single prolly tree materialization by presorting the index keys in an
// intermediate file format.
func BuildProllyIndexExternal(
ctx *sql.Context,
vrw types.ValueReadWriter,
ns tree.NodeStore,
sch schema.Schema,
tableName string,
idx schema.Index,
primary prolly.Map,
uniqCb DupEntryCb,
) (durable.Index, error) {
func BuildProllyIndexExternal(ctx *sql.Context, vrw types.ValueReadWriter, ns tree.NodeStore, sch schema.Schema, tableName string, idx schema.Index, primary prolly.Map, uniqCb DupEntryCb) (durable.Index, error) {
empty, err := durable.NewEmptyIndex(ctx, vrw, ns, idx.Schema())
if err != nil {
return nil, err
@@ -45,7 +45,8 @@ type CreateIndexReturn struct {
func CreateIndex(
ctx *sql.Context,
table *doltdb.Table,
tableName, indexName string,
tableName string,
indexName string,
columns []string,
prefixLengths []uint16,
props schema.IndexProperties,
@@ -37,7 +37,7 @@ const batchSize = 10000
type BatchSqlExportWriter struct {
tableName string
sch schema.Schema
parentSchs map[string]schema.Schema
parentSchs map[doltdb.TableName]schema.Schema
foreignKeys []doltdb.ForeignKey
wr io.WriteCloser
root doltdb.RootValue
@@ -49,7 +49,15 @@ type BatchSqlExportWriter struct {
}
// OpenBatchedSQLExportWriter returns a new SqlWriter for the table with the writer given.
func OpenBatchedSQLExportWriter(ctx context.Context, wr io.WriteCloser, root doltdb.RootValue, tableName string, autocommitOff bool, sch schema.Schema, editOpts editor.Options) (*BatchSqlExportWriter, error) {
func OpenBatchedSQLExportWriter(
ctx context.Context,
wr io.WriteCloser,
root doltdb.RootValue,
tableName string,
autocommitOff bool,
sch schema.Schema,
editOpts editor.Options,
) (*BatchSqlExportWriter, error) {
allSchemas, err := doltdb.GetAllSchemas(ctx, root)
if err != nil {
@@ -35,7 +35,7 @@ import (
type SqlExportWriter struct {
tableName string
sch schema.Schema
parentSchs map[string]schema.Schema
parentSchs map[doltdb.TableName]schema.Schema
foreignKeys []doltdb.ForeignKey
wr io.WriteCloser
root doltdb.RootValue