diff --git a/go/cmd/dolt/commands/engine/sqlengine.go b/go/cmd/dolt/commands/engine/sqlengine.go index ea843ce39d..3a1606d9f7 100644 --- a/go/cmd/dolt/commands/engine/sqlengine.go +++ b/go/cmd/dolt/commands/engine/sqlengine.go @@ -287,7 +287,7 @@ func getDbStates(ctx context.Context, dbs []dsqle.SqlDatabase) ([]dsess.InitialD var init dsess.InitialDbState var err error - _, val, ok := sql.SystemVariables.GetGlobal(dsqle.DefaultBranchKey) + _, val, ok := sql.SystemVariables.GetGlobal(dsess.DefaultBranchKey(db.Name())) if ok && val != "" { init, err = getInitialDBStateWithDefaultBranch(ctx, db, val.(string)) } else { @@ -314,7 +314,7 @@ func getInitialDBStateWithDefaultBranch(ctx context.Context, db dsqle.SqlDatabas head, err := ddb.ResolveCommitRef(ctx, r) if err != nil { - init.Err = fmt.Errorf("@@GLOBAL.dolt_default_branch (%s) is not a valid branch", branch) + init.Err = fmt.Errorf("failed to connect to database default branch: '%s/%s'; %w", db.Name(), branch, err) } else { init.Err = nil } diff --git a/go/cmd/dolt/commands/sqlserver/logformat.go b/go/cmd/dolt/commands/sqlserver/logformat.go index 50dbe03688..2be31d461f 100755 --- a/go/cmd/dolt/commands/sqlserver/logformat.go +++ b/go/cmd/dolt/commands/sqlserver/logformat.go @@ -20,7 +20,7 @@ import ( "strings" "time" - sqle "github.com/dolthub/go-mysql-server" + "github.com/dolthub/go-mysql-server/sql" "github.com/sirupsen/logrus" ) @@ -45,8 +45,8 @@ func (l LogFormat) Format(entry *logrus.Entry) ([]byte, error) { lvl = "TRACE" } - connectionId := entry.Data[sqle.ConnectionIdLogField] - delete(entry.Data, sqle.ConnectionIdLogField) + connectionId := entry.Data[sql.ConnectionIdLogField] + delete(entry.Data, sql.ConnectionIdLogField) var dataFormat strings.Builder var i int diff --git a/go/cmd/dolt/commands/tblcmds/import.go b/go/cmd/dolt/commands/tblcmds/import.go index 312224b432..68175e423d 100644 --- a/go/cmd/dolt/commands/tblcmds/import.go +++ b/go/cmd/dolt/commands/tblcmds/import.go @@ -62,6 +62,7 @@ const ( 
fileTypeParam = "file-type" delimParam = "delim" ignoreSkippedRows = "ignore-skipped-rows" + disableFkChecks = "disable-fk-checks" ) var importDocs = cli.CommandDocumentationContent{ @@ -85,7 +86,7 @@ A mapping file can be used to map fields between the file being imported and the In create, update, and replace scenarios the file's extension is used to infer the type of the file. If a file does not have the expected extension then the {{.EmphasisLeft}}--file-type{{.EmphasisRight}} parameter should be used to explicitly define the format of the file in one of the supported formats (csv, psv, json, xlsx). For files separated by a delimiter other than a ',' (type csv) or a '|' (type psv), the --delim parameter can be used to specify a delimiter`, Synopsis: []string{ - "-c [-f] [--pk {{.LessThan}}field{{.GreaterThan}}] [--schema {{.LessThan}}file{{.GreaterThan}}] [--map {{.LessThan}}file{{.GreaterThan}}] [--continue] [--ignore-skipped-rows] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}", + "-c [-f] [--pk {{.LessThan}}field{{.GreaterThan}}] [--schema {{.LessThan}}file{{.GreaterThan}}] [--map {{.LessThan}}file{{.GreaterThan}}] [--continue] [--ignore-skipped-rows] [--disable-fk-checks] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}", "-u [--map {{.LessThan}}file{{.GreaterThan}}] [--continue] [--ignore-skipped-rows] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}", "-r [--map {{.LessThan}}file{{.GreaterThan}}] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}", }, @@ -102,6 +103,7 @@ type importOptions struct { src mvdata.DataLocation srcOptions interface{} ignoreSkippedRows bool + disableFkChecks bool } func (m importOptions) IsBatched() bool { @@ -164,6 +166,7 @@ func getImportMoveOptions(ctx 
context.Context, apr *argparser.ArgParseResults, d force := apr.Contains(forceParam) contOnErr := apr.Contains(contOnErrParam) ignore := apr.Contains(ignoreSkippedRows) + disableFks := apr.Contains(disableFkChecks) val, _ := apr.GetValue(primaryKeyParam) pks := funcitr.MapStrings(strings.Split(val, ","), strings.TrimSpace) @@ -229,15 +232,6 @@ func getImportMoveOptions(ctx context.Context, apr *argparser.ArgParseResults, d if !exists { return nil, errhand.BuildDError("The following table could not be found: %s", tableName).Build() } - fkc, err := root.GetForeignKeyCollection(ctx) - if err != nil { - return nil, errhand.VerboseErrorFromError(err) - } - decFks, refFks := fkc.KeysForTable(tableName) - if len(decFks) > 0 || len(refFks) > 0 { - return nil, errhand.BuildDError("The following table is used in a foreign key and does not work "+ - "with import: %s\nThe recommended alternative is LOAD DATA", tableName).Build() - } } return &importOptions{ @@ -251,6 +245,7 @@ func getImportMoveOptions(ctx context.Context, apr *argparser.ArgParseResults, d src: srcLoc, srcOptions: srcOpts, ignoreSkippedRows: ignore, + disableFkChecks: disableFks, }, nil } @@ -341,6 +336,7 @@ func (cmd ImportCmd) ArgParser() *argparser.ArgParser { ap.SupportsFlag(replaceParam, "r", "Replace existing table with imported data while preserving the original schema.") ap.SupportsFlag(contOnErrParam, "", "Continue importing when row import errors are encountered.") ap.SupportsFlag(ignoreSkippedRows, "", "Ignore the skipped rows printed by the --continue flag.") + ap.SupportsFlag(disableFkChecks, "", "Disables foreign key checks.") ap.SupportsString(schemaParam, "s", "schema_file", "The schema for the output data.") ap.SupportsString(mappingFileParam, "m", "mapping_file", "A file that lays out how fields should be mapped from input data to output data.") ap.SupportsString(primaryKeyParam, "pk", "primary_key", "Explicitly define the name of the field in the schema which should be used as the primary 
key.") @@ -462,7 +458,7 @@ func newImportDataReader(ctx context.Context, root *doltdb.RootValue, dEnv *env. } func newImportSqlEngineMover(ctx context.Context, dEnv *env.DoltEnv, rdSchema schema.Schema, imOpts *importOptions) (*mvdata.SqlEngineTableWriter, *mvdata.DataMoverCreationError) { - moveOps := &mvdata.MoverOptions{Force: imOpts.force, TableToWriteTo: imOpts.destTableName, ContinueOnErr: imOpts.contOnErr, Operation: imOpts.operation} + moveOps := &mvdata.MoverOptions{Force: imOpts.force, TableToWriteTo: imOpts.destTableName, ContinueOnErr: imOpts.contOnErr, Operation: imOpts.operation, DisableFks: imOpts.disableFkChecks} // Returns the schema of the table to be created or the existing schema tableSchema, dmce := getImportSchema(ctx, dEnv, imOpts) diff --git a/go/cmd/dolt/dolt.go b/go/cmd/dolt/dolt.go index 3f4713f4f4..35685711b6 100644 --- a/go/cmd/dolt/dolt.go +++ b/go/cmd/dolt/dolt.go @@ -54,7 +54,7 @@ import ( ) const ( - Version = "0.40.0" + Version = "0.40.1" ) var dumpDocsCommand = &commands.DumpDocsCmd{} diff --git a/go/go.mod b/go/go.mod index 3b2adf7494..b04d636e54 100644 --- a/go/go.mod +++ b/go/go.mod @@ -19,7 +19,7 @@ require ( github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371 github.com/dolthub/mmap-go v1.0.4-0.20201107010347-f9f2a9588a66 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20220506214606-1a0fb4aab742 + github.com/dolthub/vitess v0.0.0-20220517011201-8f50d80eae58 github.com/dustin/go-humanize v1.0.0 github.com/fatih/color v1.9.0 github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 @@ -68,7 +68,7 @@ require ( ) require ( - github.com/dolthub/go-mysql-server v0.11.1-0.20220512212424-2c1ee84d49ec + github.com/dolthub/go-mysql-server v0.11.1-0.20220517180350-eb55834c15cb github.com/google/flatbuffers v2.0.5+incompatible github.com/gosuri/uilive v0.0.4 github.com/kch42/buzhash v0.0.0-20160816060738-9bdec3dec7c6 diff --git a/go/go.sum b/go/go.sum index 
5b6ded0aaa..a17210716f 100755 --- a/go/go.sum +++ b/go/go.sum @@ -178,8 +178,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= -github.com/dolthub/go-mysql-server v0.11.1-0.20220512212424-2c1ee84d49ec h1:sBJRQQSPDr+d2mbuNYGSsh8vbfYkXrm1e4lgUSOy7r4= -github.com/dolthub/go-mysql-server v0.11.1-0.20220512212424-2c1ee84d49ec/go.mod h1:jfc/rO3guNfSQbyElNepEHBZ4rO3AaxKk9LMhDZa09I= +github.com/dolthub/go-mysql-server v0.11.1-0.20220517180350-eb55834c15cb h1:rynUl+BTPJ+lonOOVAZjqsI8S/8xrRJSFZAYTehcoPw= +github.com/dolthub/go-mysql-server v0.11.1-0.20220517180350-eb55834c15cb/go.mod h1:h0gpkn07YqshhXbeNkOfII0uV+I37SJYyvccH77+FOk= github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371 h1:oyPHJlzumKta1vnOQqUnfdz+pk3EmnHS3Nd0cCT0I2g= github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371/go.mod h1:dhGBqcCEfK5kuFmeO5+WOx3hqc1k3M29c1oS/R7N4ms= github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0NvhiEsctylXinUMFhhsqaEcl414p8= @@ -188,8 +188,8 @@ github.com/dolthub/mmap-go v1.0.4-0.20201107010347-f9f2a9588a66 h1:WRPDbpJWEnPxP github.com/dolthub/mmap-go v1.0.4-0.20201107010347-f9f2a9588a66/go.mod h1:N5ZIbMGuDUpTpOFQ7HcsN6WSIpTGQjHP+Mz27AfmAgk= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20220506214606-1a0fb4aab742 h1:hlRT6htzhXA2CBfsQrXb24aUkT4JTJVMcD+RPCzGrmY= -github.com/dolthub/vitess v0.0.0-20220506214606-1a0fb4aab742/go.mod h1:jxgvpEvrTNw2i4BKlwT75E775eUXBeMv5MPeQkIb9zI= 
+github.com/dolthub/vitess v0.0.0-20220517011201-8f50d80eae58 h1:v7uMbJKhb9zi2Nz3pxDOUVfWO30E5wbSckVq7AjgXRw= +github.com/dolthub/vitess v0.0.0-20220517011201-8f50d80eae58/go.mod h1:jxgvpEvrTNw2i4BKlwT75E775eUXBeMv5MPeQkIb9zI= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= diff --git a/go/libraries/doltcore/mvdata/data_mover.go b/go/libraries/doltcore/mvdata/data_mover.go index ee1cd18c54..03e06ccd78 100644 --- a/go/libraries/doltcore/mvdata/data_mover.go +++ b/go/libraries/doltcore/mvdata/data_mover.go @@ -53,6 +53,7 @@ type MoverOptions struct { Force bool TableToWriteTo string Operation TableImportOp + DisableFks bool } type DataMoverOptions interface { diff --git a/go/libraries/doltcore/mvdata/engine_table_writer.go b/go/libraries/doltcore/mvdata/engine_table_writer.go index b118b689f0..2e51c00b79 100644 --- a/go/libraries/doltcore/mvdata/engine_table_writer.go +++ b/go/libraries/doltcore/mvdata/engine_table_writer.go @@ -44,15 +44,16 @@ const ( tableWriterStatUpdateRate = 64 * 1024 ) -// type SqlEngineTableWriter is a utility for importing a set of rows through the sql engine. +// SqlEngineTableWriter is a utility for importing a set of rows through the sql engine. 
type SqlEngineTableWriter struct { se *engine.SqlEngine sqlCtx *sql.Context - tableName string - database string - contOnErr bool - force bool + tableName string + database string + contOnErr bool + force bool + disableFks bool statsCB noms.StatsCB stats types.AppliedEditStats @@ -76,6 +77,7 @@ func NewSqlEngineTableWriter(ctx context.Context, dEnv *env.DoltEnv, createTable return true, nil }) + // Simplest path would have our import path be a layer over load data se, err := engine.NewSqlEngine(ctx, mrEnv, engine.FormatCsv, dbName, false, nil, false) if err != nil { return nil, err @@ -105,10 +107,11 @@ func NewSqlEngineTableWriter(ctx context.Context, dEnv *env.DoltEnv, createTable } return &SqlEngineTableWriter{ - se: se, - sqlCtx: sqlCtx, - contOnErr: options.ContinueOnErr, - force: options.Force, + se: se, + sqlCtx: sqlCtx, + contOnErr: options.ContinueOnErr, + force: options.Force, + disableFks: options.DisableFks, database: dbName, tableName: options.TableToWriteTo, @@ -144,10 +147,11 @@ func NewSqlEngineTableWriterWithEngine(ctx *sql.Context, eng *sqle.Engine, db ds } return &SqlEngineTableWriter{ - se: engine.NewRebasedSqlEngine(eng, map[string]dsqle.SqlDatabase{db.Name(): db}), - sqlCtx: ctx, - contOnErr: options.ContinueOnErr, - force: options.Force, + se: engine.NewRebasedSqlEngine(eng, map[string]dsqle.SqlDatabase{db.Name(): db}), + sqlCtx: ctx, + contOnErr: options.ContinueOnErr, + force: options.Force, + disableFks: options.DisableFks, database: db.Name(), tableName: options.TableToWriteTo, @@ -170,6 +174,13 @@ func (s *SqlEngineTableWriter) WriteRows(ctx context.Context, inputChannel chan return err } + if s.disableFks { + _, _, err = s.se.Query(s.sqlCtx, fmt.Sprintf("SET FOREIGN_KEY_CHECKS = 0")) + if err != nil { + return err + } + } + err = s.createOrEmptyTableIfNeeded() if err != nil { return err diff --git a/go/libraries/doltcore/sqle/dolt_diff_table_function.go b/go/libraries/doltcore/sqle/dolt_diff_table_function.go index 
250e596919..9da0f61078 100644 --- a/go/libraries/doltcore/sqle/dolt_diff_table_function.go +++ b/go/libraries/doltcore/sqle/dolt_diff_table_function.go @@ -18,7 +18,6 @@ import ( "fmt" "io" - "github.com/dolthub/dolt/go/libraries/doltcore/diff" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/rowconv" "github.com/dolthub/dolt/go/libraries/doltcore/schema" @@ -42,8 +41,9 @@ type DiffTableFunction struct { database sql.Database sqlSch sql.Schema joiner *rowconv.Joiner - toSch schema.Schema fromSch schema.Schema + toSch schema.Schema + diffTableSch schema.Schema } // NewInstance implements the TableFunction interface @@ -104,7 +104,7 @@ func (dtf *DiffTableFunction) WithExpressions(expression ...sql.Expression) (sql return nil, err } - dtf.sqlSch, err = dtf.generateSchema(tableName, fromCommitVal, toCommitVal) + err = dtf.generateSchema(tableName, fromCommitVal, toCommitVal) if err != nil { return nil, err } @@ -127,10 +127,6 @@ func (dtf *DiffTableFunction) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, return nil, err } - if dtf.joiner == nil { - panic("schema and joiner haven't been initialized") - } - sqledb, ok := dtf.database.(Database) if !ok { panic("unable to get dolt database") @@ -258,9 +254,9 @@ func (dtf *DiffTableFunction) evaluateArguments() (string, interface{}, interfac return tableName, fromCommitVal, toCommitVal, nil } -func (dtf *DiffTableFunction) generateSchema(tableName string, fromCommitVal, toCommitVal interface{}) (sql.Schema, error) { +func (dtf *DiffTableFunction) generateSchema(tableName string, fromCommitVal, toCommitVal interface{}) error { if !dtf.Resolved() { - return nil, nil + return nil } sqledb, ok := dtf.database.(Database) @@ -270,81 +266,62 @@ func (dtf *DiffTableFunction) generateSchema(tableName string, fromCommitVal, to fromRoot, err := sqledb.rootAsOf(dtf.ctx, fromCommitVal) if err != nil { - return nil, err + return err } fromTable, _, ok, err := 
fromRoot.GetTableInsensitive(dtf.ctx, tableName) if err != nil { - return nil, err + return err } if !ok { - return nil, sql.ErrTableNotFound.New(tableName) + return sql.ErrTableNotFound.New(tableName) } toRoot, err := sqledb.rootAsOf(dtf.ctx, toCommitVal) if err != nil { - return nil, err + return err } toTable, _, ok, err := toRoot.GetTableInsensitive(dtf.ctx, tableName) if err != nil { - return nil, err + return err } if !ok { - return nil, sql.ErrTableNotFound.New(tableName) + return sql.ErrTableNotFound.New(tableName) } fromSchema, err := fromTable.GetSchema(dtf.ctx) if err != nil { - return nil, err + return err } toSchema, err := toTable.GetSchema(dtf.ctx) if err != nil { - return nil, err + return err } - fromSchema = schema.MustSchemaFromCols( - fromSchema.GetAllCols().Append( - schema.NewColumn("commit", schema.DiffCommitTag, types.StringKind, false), - schema.NewColumn("commit_date", schema.DiffCommitDateTag, types.TimestampKind, false))) dtf.fromSch = fromSchema - - toSchema = schema.MustSchemaFromCols( - toSchema.GetAllCols().Append( - schema.NewColumn("commit", schema.DiffCommitTag, types.StringKind, false), - schema.NewColumn("commit_date", schema.DiffCommitDateTag, types.TimestampKind, false))) dtf.toSch = toSchema - joiner, err := rowconv.NewJoiner( - []rowconv.NamedSchema{{Name: diff.To, Sch: toSchema}, {Name: diff.From, Sch: fromSchema}}, - map[string]rowconv.ColNamingFunc{ - diff.To: diff.ToColNamer, - diff.From: diff.FromColNamer, - }) + diffTableSch, j, err := dtables.GetDiffTableSchemaAndJoiner(toTable.Format(), fromSchema, toSchema) if err != nil { - return nil, err + return err } - - sch := joiner.GetSchema() - - sch = schema.MustSchemaFromCols( - sch.GetAllCols().Append( - schema.NewColumn("diff_type", schema.DiffTypeTag, types.StringKind, false))) + dtf.joiner = j // TODO: sql.Columns include a Source that indicates the table it came from, but we don't have a real table // when the column comes from a table function, so we omit the table 
name when we create these columns. // This allows column projections to work correctly with table functions, but we will need to add a // unique id (e.g. hash generated from method arguments) when we add support for aliasing and joining // table functions in order for the analyzer to determine which table function result a column comes from. - sqlSchema, err := sqlutil.FromDoltSchema("", sch) + sqlSchema, err := sqlutil.FromDoltSchema("", diffTableSch) if err != nil { - return nil, err + return err } - dtf.joiner = joiner + dtf.sqlSch = sqlSchema.Schema - return sqlSchema.Schema, nil + return nil } // Schema implements the sql.Node interface diff --git a/go/libraries/doltcore/sqle/dsess/session.go b/go/libraries/doltcore/sqle/dsess/session.go index c50198aacb..186d66469f 100644 --- a/go/libraries/doltcore/sqle/dsess/session.go +++ b/go/libraries/doltcore/sqle/dsess/session.go @@ -171,7 +171,7 @@ func (sess *Session) Flush(ctx *sql.Context, dbName string) error { return sess.SetRoot(ctx, dbName, ws.WorkingRoot()) } -// CommitTransaction commits the in-progress transaction for the database named +// StartTransaction refreshes the state of this session and starts a new transaction. 
func (sess *Session) StartTransaction(ctx *sql.Context, dbName string, tCharacteristic sql.TransactionCharacteristic) (sql.Transaction, error) { if TransactionsDisabled(ctx) { return DisabledTransaction{}, nil diff --git a/go/libraries/doltcore/sqle/dsess/variables.go b/go/libraries/doltcore/sqle/dsess/variables.go index c5ca981113..466f9e5e35 100644 --- a/go/libraries/doltcore/sqle/dsess/variables.go +++ b/go/libraries/doltcore/sqle/dsess/variables.go @@ -21,10 +21,11 @@ import ( ) const ( - HeadKeySuffix = "_head" - HeadRefKeySuffix = "_head_ref" - WorkingKeySuffix = "_working" - StagedKeySuffix = "_staged" + HeadKeySuffix = "_head" + HeadRefKeySuffix = "_head_ref" + WorkingKeySuffix = "_working" + StagedKeySuffix = "_staged" + DefaultBranchKeySuffix = "_default_branch" ) const ( @@ -118,6 +119,14 @@ func defineSystemVariables(name string) { Type: sql.NewSystemStringType(StagedKey(name)), Default: "", }, + { + Name: DefaultBranchKey(name), + Scope: sql.SystemVariableScope_Global, + Dynamic: true, + SetVarHintApplies: false, + Type: sql.NewSystemStringType(DefaultBranchKey(name)), + Default: "", + }, }) } } @@ -138,6 +147,10 @@ func StagedKey(dbName string) string { return dbName + StagedKeySuffix } +func DefaultBranchKey(dbName string) string { + return dbName + DefaultBranchKeySuffix +} + func IsHeadKey(key string) (bool, string) { if strings.HasSuffix(key, HeadKeySuffix) { return true, key[:len(key)-len(HeadKeySuffix)] @@ -162,6 +175,14 @@ func IsWorkingKey(key string) (bool, string) { return false, "" } +func IsDefaultBranchKey(key string) (bool, string) { + if strings.HasSuffix(key, DefaultBranchKeySuffix) { + return true, key[:len(key)-len(DefaultBranchKeySuffix)] + } + + return false, "" +} + func IsReadOnlyVersionKey(key string) bool { return strings.HasSuffix(key, HeadKeySuffix) || strings.HasSuffix(key, StagedKeySuffix) || diff --git a/go/libraries/doltcore/sqle/dtables/commit_diff_table.go b/go/libraries/doltcore/sqle/dtables/commit_diff_table.go index 
ded41a4932..e896c1af94 100644 --- a/go/libraries/doltcore/sqle/dtables/commit_diff_table.go +++ b/go/libraries/doltcore/sqle/dtables/commit_diff_table.go @@ -24,7 +24,6 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" - "github.com/dolthub/dolt/go/libraries/doltcore/diff" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/rowconv" "github.com/dolthub/dolt/go/libraries/doltcore/schema" @@ -66,36 +65,16 @@ func NewCommitDiffTable(ctx *sql.Context, tblName string, ddb *doltdb.DoltDB, ro return nil, err } - sch = schema.MustSchemaFromCols(sch.GetAllCols().Append( - schema.NewColumn("commit", schema.DiffCommitTag, types.StringKind, false), - schema.NewColumn("commit_date", schema.DiffCommitDateTag, types.TimestampKind, false))) - - if sch.GetAllCols().Size() <= 1 { - return nil, sql.ErrTableNotFound.New(diffTblName) - } - - j, err := rowconv.NewJoiner( - []rowconv.NamedSchema{{Name: diff.To, Sch: sch}, {Name: diff.From, Sch: sch}}, - map[string]rowconv.ColNamingFunc{ - diff.To: diff.ToColNamer, - diff.From: diff.FromColNamer, - }) + diffTableSchema, j, err := GetDiffTableSchemaAndJoiner(ddb.Format(), sch, sch) if err != nil { return nil, err } - sqlSch, err := sqlutil.FromDoltSchema(diffTblName, j.GetSchema()) + sqlSch, err := sqlutil.FromDoltSchema(diffTblName, diffTableSchema) if err != nil { return nil, err } - sqlSch.Schema = append(sqlSch.Schema, &sql.Column{ - Name: diffTypeColName, - Type: sql.Text, - Nullable: false, - Source: diffTblName, - }) - return &CommitDiffTable{ name: tblName, ddb: ddb, diff --git a/go/libraries/doltcore/sqle/dtables/diff_iter.go b/go/libraries/doltcore/sqle/dtables/diff_iter.go new file mode 100644 index 0000000000..8312c9cb07 --- /dev/null +++ b/go/libraries/doltcore/sqle/dtables/diff_iter.go @@ -0,0 +1,372 @@ +// Copyright 2022 Dolthub, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dtables + +import ( + "context" + "errors" + "io" + "time" + + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/dolt/go/libraries/doltcore/diff" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable" + "github.com/dolthub/dolt/go/libraries/doltcore/rowconv" + "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" + "github.com/dolthub/dolt/go/store/prolly" + "github.com/dolthub/dolt/go/store/prolly/tree" + "github.com/dolthub/dolt/go/store/types" + "github.com/dolthub/dolt/go/store/val" +) + +type diffRowItr struct { + ad diff.RowDiffer + diffSrc *diff.RowDiffSource + joiner *rowconv.Joiner + sch schema.Schema + fromCommitInfo commitInfo + toCommitInfo commitInfo +} + +var _ sql.RowIter = &diffRowItr{} + +type commitInfo struct { + name types.String + date *types.Timestamp + nameTag uint64 + dateTag uint64 +} + +func newNomsDiffIter(ctx *sql.Context, ddb *doltdb.DoltDB, joiner *rowconv.Joiner, dp DiffPartition) (*diffRowItr, error) { + fromData, fromSch, err := tableData(ctx, dp.from, ddb) + + if err != nil { + return nil, err + } + + toData, toSch, err := tableData(ctx, dp.to, ddb) + + if err != nil { + return nil, err + } + + fromConv, err := dp.rowConvForSchema(ctx, ddb.ValueReadWriter(), *dp.fromSch, fromSch) + + if err != nil { + return nil, err + } + 
+ toConv, err := dp.rowConvForSchema(ctx, ddb.ValueReadWriter(), *dp.toSch, toSch) + + if err != nil { + return nil, err + } + + sch := joiner.GetSchema() + toCol, _ := sch.GetAllCols().GetByName(toCommit) + fromCol, _ := sch.GetAllCols().GetByName(fromCommit) + toDateCol, _ := sch.GetAllCols().GetByName(toCommitDate) + fromDateCol, _ := sch.GetAllCols().GetByName(fromCommitDate) + + fromCmInfo := commitInfo{types.String(dp.fromName), dp.fromDate, fromCol.Tag, fromDateCol.Tag} + toCmInfo := commitInfo{types.String(dp.toName), dp.toDate, toCol.Tag, toDateCol.Tag} + + rd := diff.NewRowDiffer(ctx, fromSch, toSch, 1024) + // TODO (dhruv) don't cast to noms map + rd.Start(ctx, durable.NomsMapFromIndex(fromData), durable.NomsMapFromIndex(toData)) + + warnFn := func(code int, message string, args ...string) { + ctx.Warn(code, message, args) + } + + src := diff.NewRowDiffSource(rd, joiner, warnFn) + src.AddInputRowConversion(fromConv, toConv) + + return &diffRowItr{ + ad: rd, + diffSrc: src, + joiner: joiner, + sch: joiner.GetSchema(), + fromCommitInfo: fromCmInfo, + toCommitInfo: toCmInfo, + }, nil +} + +// Next returns the next row +func (itr *diffRowItr) Next(*sql.Context) (sql.Row, error) { + r, _, err := itr.diffSrc.NextDiff() + + if err != nil { + return nil, err + } + + toAndFromRows, err := itr.joiner.Split(r) + if err != nil { + return nil, err + } + _, hasTo := toAndFromRows[diff.To] + _, hasFrom := toAndFromRows[diff.From] + + r, err = r.SetColVal(itr.toCommitInfo.nameTag, types.String(itr.toCommitInfo.name), itr.sch) + if err != nil { + return nil, err + } + + r, err = r.SetColVal(itr.fromCommitInfo.nameTag, types.String(itr.fromCommitInfo.name), itr.sch) + + if err != nil { + return nil, err + } + + if itr.toCommitInfo.date != nil { + r, err = r.SetColVal(itr.toCommitInfo.dateTag, *itr.toCommitInfo.date, itr.sch) + + if err != nil { + return nil, err + } + } + + if itr.fromCommitInfo.date != nil { + r, err = r.SetColVal(itr.fromCommitInfo.dateTag, 
*itr.fromCommitInfo.date, itr.sch) + + if err != nil { + return nil, err + } + } + + sqlRow, err := sqlutil.DoltRowToSqlRow(r, itr.sch) + + if err != nil { + return nil, err + } + + if hasTo && hasFrom { + sqlRow = append(sqlRow, diffTypeModified) + } else if hasTo && !hasFrom { + sqlRow = append(sqlRow, diffTypeAdded) + } else { + sqlRow = append(sqlRow, diffTypeRemoved) + } + + return sqlRow, nil +} + +// Close closes the iterator +func (itr *diffRowItr) Close(*sql.Context) (err error) { + defer itr.ad.Close() + defer func() { + closeErr := itr.diffSrc.Close() + + if err == nil { + err = closeErr + } + }() + + return nil +} + +type commitInfo2 struct { + name string + ts *time.Time +} + +type prollyDiffIter struct { + from, to prolly.Map + fromSch, toSch schema.Schema + targetFromSch, targetToSch schema.Schema + fromConverter, toConverter ProllyRowConverter + + fromCm commitInfo2 + toCm commitInfo2 + + rows chan sql.Row + errChan chan error + cancel context.CancelFunc +} + +var _ sql.RowIter = prollyDiffIter{} + +// newProllyDiffIter produces dolt_diff system table and dolt_diff table +// function rows. The rows first have the "to" columns on the left and the +// "from" columns on the right. After the "to" and "from" columns, a commit +// name, and commit date is also present. The final column is the diff_type +// column. +// +// An example: to_pk, to_col1, to_commit, to_commit_date, from_pk, from_col1, from_commit, from_commit_date, diff_type +// +// |targetFromSchema| and |targetToSchema| defines what the schema should be for +// the row data on the "from" or "to" side. In the above example, both schemas are +// identical with two columns "pk" and "col1". The dolt diff table function for +// example can provide two different schemas. +// +// The |from| and |to| tables in the DiffPartition may have different schemas +// than |targetFromSchema| or |targetToSchema|. 
We convert the rows from the +// schema of |from| to |targetFromSchema| and the schema of |to| to +// |targetToSchema|. See the tablediff_prolly package. +func newProllyDiffIter(ctx *sql.Context, dp DiffPartition, ddb *doltdb.DoltDB, targetFromSchema, targetToSchema schema.Schema) (prollyDiffIter, error) { + if schema.IsKeyless(targetToSchema) { + return prollyDiffIter{}, errors.New("diffs with keyless schema have not been implemented yet") + } + + fromCm := commitInfo2{ + name: dp.fromName, + ts: (*time.Time)(dp.fromDate), + } + toCm := commitInfo2{ + name: dp.toName, + ts: (*time.Time)(dp.toDate), + } + + // dp.from may be nil + f, fSch, err := tableData(ctx, dp.from, ddb) + if err != nil { + return prollyDiffIter{}, nil + } + from := durable.ProllyMapFromIndex(f) + + t, tSch, err := tableData(ctx, dp.to, ddb) + if err != nil { + return prollyDiffIter{}, nil + } + to := durable.ProllyMapFromIndex(t) + + fromConverter, err := NewProllyRowConverter(fSch, targetFromSchema) + if err != nil { + return prollyDiffIter{}, err + } + + toConverter, err := NewProllyRowConverter(tSch, targetToSchema) + if err != nil { + return prollyDiffIter{}, err + } + + child, cancel := context.WithCancel(ctx) + + iter := prollyDiffIter{ + from: from, + to: to, + fromSch: fSch, + toSch: tSch, + targetFromSch: targetFromSchema, + targetToSch: targetToSchema, + fromConverter: fromConverter, + toConverter: toConverter, + fromCm: fromCm, + toCm: toCm, + rows: make(chan sql.Row, 64), + errChan: make(chan error), + cancel: cancel, + } + + go func() { + iter.queueRows(child) + }() + + return iter, nil +} + +func (itr prollyDiffIter) Next(ctx *sql.Context) (sql.Row, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-itr.errChan: + return nil, err + case r, ok := <-itr.rows: + if !ok { + return nil, io.EOF + } + return r, nil + } +} + +func (itr prollyDiffIter) Close(ctx *sql.Context) error { + itr.cancel() + return nil +} + +func (itr prollyDiffIter) queueRows(ctx 
context.Context) { + err := prolly.DiffMaps(ctx, itr.from, itr.to, func(ctx context.Context, d tree.Diff) error { + r, err := itr.makeDiffRow(d) + if err != nil { + return err + } + select { + case <-ctx.Done(): + return ctx.Err() + case itr.rows <- r: + return nil + } + }) + if err != nil && err != io.EOF { + select { + case <-ctx.Done(): + case itr.errChan <- err: + } + return + } + // we need to drain itr.rows before returning io.EOF + close(itr.rows) +} + +// todo(andy): copy string fields +func (itr prollyDiffIter) makeDiffRow(d tree.Diff) (r sql.Row, err error) { + + n := itr.targetFromSch.GetAllCols().Size() + m := itr.targetToSch.GetAllCols().Size() + // 2 commit names, 2 commit dates, 1 diff_type + r = make(sql.Row, n+m+5) + + // todo (dhruv): implement warnings for row column value coercions. + + if d.Type != tree.RemovedDiff { + err = itr.toConverter.PutConverted(val.Tuple(d.Key), val.Tuple(d.To), r[0:n]) + if err != nil { + return nil, err + } + } + + o := n + r[o] = itr.toCm.name + r[o+1] = itr.toCm.ts + + if d.Type != tree.AddedDiff { + err = itr.fromConverter.PutConverted(val.Tuple(d.Key), val.Tuple(d.From), r[n+2:n+2+m]) + if err != nil { + return nil, err + } + } + + o = n + 2 + m + r[o] = itr.fromCm.name + r[o+1] = itr.fromCm.ts + r[o+2] = diffTypeString(d) + + return r, nil +} + +func diffTypeString(d tree.Diff) (s string) { + switch d.Type { + case tree.AddedDiff: + s = diffTypeAdded + case tree.ModifiedDiff: + s = diffTypeModified + case tree.RemovedDiff: + s = diffTypeRemoved + } + return +} diff --git a/go/libraries/doltcore/sqle/dtables/diff_table.go b/go/libraries/doltcore/sqle/dtables/diff_table.go index d545dbf16e..393241e6fc 100644 --- a/go/libraries/doltcore/sqle/dtables/diff_table.go +++ b/go/libraries/doltcore/sqle/dtables/diff_table.go @@ -56,11 +56,19 @@ type DiffTable struct { workingRoot *doltdb.RootValue head *doltdb.Commit - targetSch schema.Schema - joiner *rowconv.Joiner + // from and to need to be mapped to this schema + 
targetSch schema.Schema + + // the schema for the diff table itself. Once from and to are converted to + // targetSch, the commit names and dates are inserted. + diffTableSch schema.Schema + sqlSch sql.PrimaryKeySchema partitionFilters []sql.Expression rowFilters []sql.Expression + + // noms only + joiner *rowconv.Joiner } var PrimaryKeyChangeWarning = "cannot render full diff between commits %s and %s due to primary key set change" @@ -82,48 +90,27 @@ func NewDiffTable(ctx *sql.Context, tblName string, ddb *doltdb.DoltDB, root *do return nil, err } - colCollection := sch.GetAllCols() - colCollection = colCollection.Append( - schema.NewColumn("commit", schema.DiffCommitTag, types.StringKind, false), - schema.NewColumn("commit_date", schema.DiffCommitDateTag, types.TimestampKind, false)) - sch = schema.MustSchemaFromCols(colCollection) - - if sch.GetAllCols().Size() <= 1 { - return nil, sql.ErrTableNotFound.New(diffTblName) - } - - j, err := rowconv.NewJoiner( - []rowconv.NamedSchema{{Name: diff.To, Sch: sch}, {Name: diff.From, Sch: sch}}, - map[string]rowconv.ColNamingFunc{ - diff.To: diff.ToColNamer, - diff.From: diff.FromColNamer, - }) + diffTableSchema, j, err := GetDiffTableSchemaAndJoiner(ddb.Format(), sch, sch) if err != nil { return nil, err } - sqlSch, err := sqlutil.FromDoltSchema(diffTblName, j.GetSchema()) + sqlSch, err := sqlutil.FromDoltSchema(diffTblName, diffTableSchema) if err != nil { return nil, err } - sqlSch.Schema = append(sqlSch.Schema, &sql.Column{ - Name: diffTypeColName, - Type: sql.Text, - Nullable: false, - Source: diffTblName, - }) - return &DiffTable{ name: tblName, ddb: ddb, workingRoot: root, head: head, targetSch: sch, - joiner: j, + diffTableSch: diffTableSchema, sqlSch: sqlSch, partitionFilters: nil, rowFilters: nil, + joiner: j, }, nil } @@ -247,97 +234,6 @@ func tableData(ctx *sql.Context, tbl *doltdb.Table, ddb *doltdb.DoltDB) (durable return data, sch, nil } -var _ sql.RowIter = (*diffRowItr)(nil) - -type diffRowItr struct { - 
ad diff.RowDiffer - diffSrc *diff.RowDiffSource - joiner *rowconv.Joiner - sch schema.Schema - fromCommitInfo commitInfo - toCommitInfo commitInfo -} - -type commitInfo struct { - name types.String - date *types.Timestamp - nameTag uint64 - dateTag uint64 -} - -// Next returns the next row -func (itr *diffRowItr) Next(*sql.Context) (sql.Row, error) { - r, _, err := itr.diffSrc.NextDiff() - - if err != nil { - return nil, err - } - - toAndFromRows, err := itr.joiner.Split(r) - if err != nil { - return nil, err - } - _, hasTo := toAndFromRows[diff.To] - _, hasFrom := toAndFromRows[diff.From] - - r, err = r.SetColVal(itr.toCommitInfo.nameTag, types.String(itr.toCommitInfo.name), itr.sch) - if err != nil { - return nil, err - } - - r, err = r.SetColVal(itr.fromCommitInfo.nameTag, types.String(itr.fromCommitInfo.name), itr.sch) - - if err != nil { - return nil, err - } - - if itr.toCommitInfo.date != nil { - r, err = r.SetColVal(itr.toCommitInfo.dateTag, *itr.toCommitInfo.date, itr.sch) - - if err != nil { - return nil, err - } - } - - if itr.fromCommitInfo.date != nil { - r, err = r.SetColVal(itr.fromCommitInfo.dateTag, *itr.fromCommitInfo.date, itr.sch) - - if err != nil { - return nil, err - } - } - - sqlRow, err := sqlutil.DoltRowToSqlRow(r, itr.sch) - - if err != nil { - return nil, err - } - - if hasTo && hasFrom { - sqlRow = append(sqlRow, diffTypeModified) - } else if hasTo && !hasFrom { - sqlRow = append(sqlRow, diffTypeAdded) - } else { - sqlRow = append(sqlRow, diffTypeRemoved) - } - - return sqlRow, nil -} - -// Close closes the iterator -func (itr *diffRowItr) Close(*sql.Context) (err error) { - defer itr.ad.Close() - defer func() { - closeErr := itr.diffSrc.Close() - - if err == nil { - err = closeErr - } - }() - - return nil -} - type TblInfoAtCommit struct { name string date *types.Timestamp @@ -361,8 +257,9 @@ type DiffPartition struct { fromName string toDate *types.Timestamp fromDate *types.Timestamp - toSch *schema.Schema - fromSch *schema.Schema + 
// fromSch and toSch are usually identical. It is the schema of the table at head. + toSch *schema.Schema + fromSch *schema.Schema } func NewDiffPartition(to, from *doltdb.Table, toName, fromName string, toDate, fromDate *types.Timestamp, toSch, fromSch *schema.Schema) *DiffPartition { @@ -383,58 +280,11 @@ func (dp DiffPartition) Key() []byte { } func (dp DiffPartition) GetRowIter(ctx *sql.Context, ddb *doltdb.DoltDB, joiner *rowconv.Joiner) (sql.RowIter, error) { - fromData, fromSch, err := tableData(ctx, dp.from, ddb) - - if err != nil { - return nil, err + if types.IsFormat_DOLT_1(ddb.Format()) { + return newProllyDiffIter(ctx, dp, ddb, *dp.fromSch, *dp.toSch) + } else { + return newNomsDiffIter(ctx, ddb, joiner, dp) } - - toData, toSch, err := tableData(ctx, dp.to, ddb) - - if err != nil { - return nil, err - } - - fromConv, err := dp.rowConvForSchema(ctx, ddb.ValueReadWriter(), *dp.fromSch, fromSch) - - if err != nil { - return nil, err - } - - toConv, err := dp.rowConvForSchema(ctx, ddb.ValueReadWriter(), *dp.toSch, toSch) - - if err != nil { - return nil, err - } - - sch := joiner.GetSchema() - toCol, _ := sch.GetAllCols().GetByName(toCommit) - fromCol, _ := sch.GetAllCols().GetByName(fromCommit) - toDateCol, _ := sch.GetAllCols().GetByName(toCommitDate) - fromDateCol, _ := sch.GetAllCols().GetByName(fromCommitDate) - - fromCmInfo := commitInfo{types.String(dp.fromName), dp.fromDate, fromCol.Tag, fromDateCol.Tag} - toCmInfo := commitInfo{types.String(dp.toName), dp.toDate, toCol.Tag, toDateCol.Tag} - - rd := diff.NewRowDiffer(ctx, fromSch, toSch, 1024) - // TODO (dhruv) don't cast to noms map - rd.Start(ctx, durable.NomsMapFromIndex(fromData), durable.NomsMapFromIndex(toData)) - - warnFn := func(code int, message string, args ...string) { - ctx.Warn(code, message, args) - } - - src := diff.NewRowDiffSource(rd, joiner, warnFn) - src.AddInputRowConversion(fromConv, toConv) - - return &diffRowItr{ - ad: rd, - diffSrc: src, - joiner: joiner, - sch: 
joiner.GetSchema(), - fromCommitInfo: fromCmInfo, - toCommitInfo: toCmInfo, - }, nil } // isDiffablePartition checks if the commit pair for this partition is "diffable". @@ -551,7 +401,16 @@ func (dps *DiffPartitions) processCommit(ctx *sql.Context, cmHash hash.Hash, cm var nextPartition *DiffPartition if tblHash != toInfoForCommit.tblHash { - partition := DiffPartition{toInfoForCommit.tbl, tbl, toInfoForCommit.name, cmHashStr, toInfoForCommit.date, &ts, &dps.toSch, &dps.fromSch} + partition := DiffPartition{ + to: toInfoForCommit.tbl, + from: tbl, + toName: toInfoForCommit.name, + fromName: cmHashStr, + toDate: toInfoForCommit.date, + fromDate: &ts, + fromSch: &dps.fromSch, + toSch: &dps.toSch, + } selected, err := dps.selectFunc(ctx, partition) if err != nil { @@ -636,3 +495,96 @@ func (dp DiffPartition) rowConvForSchema(ctx context.Context, vrw types.ValueRea return rowconv.NewRowConverter(ctx, vrw, fm) } + +// GetDiffTableSchemaAndJoiner returns the schema for the diff table given a +// target schema for a row |sch|. In the old storage format, it also returns the +// associated joiner. 
+func GetDiffTableSchemaAndJoiner(format *types.NomsBinFormat, fromTargetSch, toTargetSch schema.Schema) (diffTableSchema schema.Schema, j *rowconv.Joiner, err error) {
+	if types.IsFormat_DOLT_1(format) { // same predicate GetRowIter uses to select the prolly diff iterator
+		diffTableSchema, err = CalculateDiffSchema(fromTargetSch, toTargetSch)
+		if err != nil {
+			return nil, nil, err
+		}
+	} else {
+		colCollection := toTargetSch.GetAllCols()
+		colCollection = colCollection.Append(
+			schema.NewColumn("commit", schema.DiffCommitTag, types.StringKind, false),
+			schema.NewColumn("commit_date", schema.DiffCommitDateTag, types.TimestampKind, false))
+		toTargetSch = schema.MustSchemaFromCols(colCollection)
+
+		colCollection = fromTargetSch.GetAllCols()
+		colCollection = colCollection.Append(
+			schema.NewColumn("commit", schema.DiffCommitTag, types.StringKind, false),
+			schema.NewColumn("commit_date", schema.DiffCommitDateTag, types.TimestampKind, false))
+		fromTargetSch = schema.MustSchemaFromCols(colCollection)
+
+		j, err = rowconv.NewJoiner(
+			[]rowconv.NamedSchema{{Name: diff.To, Sch: toTargetSch}, {Name: diff.From, Sch: fromTargetSch}},
+			map[string]rowconv.ColNamingFunc{
+				diff.To:   diff.ToColNamer,
+				diff.From: diff.FromColNamer,
+			})
+		if err != nil {
+			return nil, nil, err
+		}
+		diffTableSchema = j.GetSchema() // joined to/from schema; diff_type is appended below
+		colCollection = diffTableSchema.GetAllCols()
+		colCollection = colCollection.Append(
+			schema.NewColumn(diffTypeColName, schema.DiffTypeTag, types.StringKind, false),
+		)
+		diffTableSchema = schema.MustSchemaFromCols(colCollection)
+	}
+
+	return diffTableSchema, j, nil // j stays nil on the prolly path; every error has already returned above
+}
+
+// CalculateDiffSchema returns the schema for the dolt_diff table based on the
+// schemas from the from and to tables. 
+func CalculateDiffSchema(fromSch schema.Schema, toSch schema.Schema) (schema.Schema, error) {
+	colCollection := fromSch.GetAllCols()
+	colCollection = colCollection.Append(
+		schema.NewColumn("commit", schema.DiffCommitTag, types.StringKind, false),
+		schema.NewColumn("commit_date", schema.DiffCommitDateTag, types.TimestampKind, false))
+	fromSch = schema.MustSchemaFromCols(colCollection)
+
+	colCollection = toSch.GetAllCols()
+	colCollection = colCollection.Append(
+		schema.NewColumn("commit", schema.DiffCommitTag, types.StringKind, false),
+		schema.NewColumn("commit_date", schema.DiffCommitDateTag, types.TimestampKind, false))
+	toSch = schema.MustSchemaFromCols(colCollection)
+
+	cols := make([]schema.Column, toSch.GetAllCols().Size()+fromSch.GetAllCols().Size()+1) // to_* + from_* + diff_type
+
+	i := 0
+	err := toSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
+		toCol, err := schema.NewColumnWithTypeInfo("to_"+col.Name, uint64(i), col.TypeInfo, false, col.Default, false, col.Comment)
+		if err != nil {
+			return true, err
+		}
+		cols[i] = toCol
+		i++
+		return false, nil
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	j := toSch.GetAllCols().Size()
+	err = fromSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
+		fromCol, err := schema.NewColumnWithTypeInfo("from_"+col.Name, uint64(j), col.TypeInfo, false, col.Default, false, col.Comment) // tag follows j: i no longer advances, so using it gave every from_* column the same tag
+		if err != nil {
+			return true, err
+		}
+		cols[j] = fromCol
+
+		j++
+		return false, nil
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	cols[len(cols)-1] = schema.NewColumn("diff_type", schema.DiffTypeTag, types.StringKind, false)
+
+	return schema.UnkeyedSchemaFromCols(schema.NewColCollection(cols...)), nil
+}
diff --git a/go/libraries/doltcore/sqle/dtables/prolly_row_conv.go b/go/libraries/doltcore/sqle/dtables/prolly_row_conv.go new file mode 100644 index 0000000000..e05d0fb3d2 --- /dev/null +++ b/go/libraries/doltcore/sqle/dtables/prolly_row_conv.go @@ -0,0 +1,168 @@ +// Copyright 
2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dtables + +import ( + "fmt" + + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/index" + "github.com/dolthub/dolt/go/store/prolly" + "github.com/dolthub/dolt/go/store/val" +) + +// ProllyRowConverter can be used to convert key, value val.Tuple's from |inSchema| +// to |outSchema|. Columns are matched based on names and primary key +// membership. The output of the conversion process is a sql.Row. 
+type ProllyRowConverter struct {
+	inSchema         schema.Schema
+	outSchema        schema.Schema
+	keyProj, valProj val.OrdinalMapping // produced by MapSchemaBasedOnName(inSchema, outSchema)
+	keyDesc          val.TupleDesc // key descriptor derived from inSchema
+	valDesc          val.TupleDesc // value descriptor derived from inSchema
+	pkTargetTypes    []sql.Type // entry i is non-nil only when key field i must be converted to a different SQL type
+	nonPkTargetTypes []sql.Type // entry i is non-nil only when value field i must be converted to a different SQL type
+}
+
+func NewProllyRowConverter(inSch, outSch schema.Schema) (ProllyRowConverter, error) {
+	keyProj, valProj, err := MapSchemaBasedOnName(inSch, outSch)
+	if err != nil {
+		return ProllyRowConverter{}, err
+	}
+
+	pkTargetTypes := make([]sql.Type, inSch.GetPKCols().Size())
+	nonPkTargetTypes := make([]sql.Type, inSch.GetNonPKCols().Size())
+
+	// Populate pkTargetTypes and nonPkTargetTypes with non-nil sql.Type if we need to do a type conversion
+	for i, j := range keyProj {
+		if j == -1 { // -1 marks an input column with no counterpart in outSch
+			continue
+		}
+		inColType := inSch.GetPKCols().GetByIndex(i).TypeInfo.ToSqlType()
+		outColType := outSch.GetAllCols().GetByIndex(j).TypeInfo.ToSqlType() // j indexes outSch's full column collection
+		if !inColType.Equals(outColType) {
+			pkTargetTypes[i] = outColType
+		}
+	}
+
+	for i, j := range valProj {
+		if j == -1 { // -1 marks an input column with no counterpart in outSch
+			continue
+		}
+		inColType := inSch.GetNonPKCols().GetByIndex(i).TypeInfo.ToSqlType()
+		outColType := outSch.GetAllCols().GetByIndex(j).TypeInfo.ToSqlType() // j indexes outSch's full column collection
+		if !inColType.Equals(outColType) {
+			nonPkTargetTypes[i] = outColType
+		}
+	}
+
+	kd, vd := prolly.MapDescriptorsFromScheam(inSch) // helper name is spelled this way where it is declared
+	return ProllyRowConverter{
+		inSchema:         inSch,
+		outSchema:        outSch,
+		keyProj:          keyProj,
+		valProj:          valProj,
+		keyDesc:          kd,
+		valDesc:          vd,
+		pkTargetTypes:    pkTargetTypes,
+		nonPkTargetTypes: nonPkTargetTypes,
+	}, nil
+}
+
+// PutConverted converts the |key| and |value| val.Tuple from |inSchema| to |outSchema|
+// and places the converted row in |dstRow|. 
+func (c ProllyRowConverter) PutConverted(key, value val.Tuple, dstRow []interface{}) error {
+	for i, j := range c.keyProj {
+		if j == -1 { // no destination slot for this key field
+			continue
+		}
+		f, err := index.GetField(c.keyDesc, i, key)
+		if err != nil {
+			return err
+		}
+		if t := c.pkTargetTypes[i]; t != nil { // non-nil means the in and out SQL types differ
+			dstRow[j], err = t.Convert(f)
+			if err != nil {
+				return err
+			}
+		} else {
+			dstRow[j] = f
+		}
+	}
+
+	for i, j := range c.valProj {
+		if j == -1 { // input column absent from the output schema
+			continue
+		}
+		f, err := index.GetField(c.valDesc, i, value)
+		if err != nil {
+			return err
+		}
+		if t := c.nonPkTargetTypes[i]; t != nil { // non-nil means the in and out SQL types differ
+			dstRow[j], err = t.Convert(f)
+			if err != nil {
+				return err
+			}
+		} else {
+			dstRow[j] = f
+		}
+	}
+
+	return nil
+}
+
+// MapSchemaBasedOnName can be used to map column values from one schema to
+// another schema. A column in |inSch| is mapped to |outSch| if they share the
+// same name and primary key membership status. It returns ordinal mappings that
+// can be used to map key, value val.Tuple's of schema |inSch| to a sql.Row of
+// |outSch|. The first ordinal map is for keys, and the second is for values. If
+// a column of |inSch| is missing in |outSch| then that column's index in the
+// ordinal map holds -1. 
+func MapSchemaBasedOnName(inSch, outSch schema.Schema) (val.OrdinalMapping, val.OrdinalMapping, error) {
+	keyMapping := make(val.OrdinalMapping, inSch.GetPKCols().Size())
+	valMapping := make(val.OrdinalMapping, inSch.GetNonPKCols().Size())
+
+	err := inSch.GetPKCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
+		i := inSch.GetPKCols().TagToIdx[tag]
+		if outCol, ok := outSch.GetPKCols().GetByName(col.Name); ok { // distinct name: shadowing col made the error below print the failed lookup's name
+			j := outSch.GetAllCols().TagToIdx[outCol.Tag] // position within outSch's full column list
+			keyMapping[i] = j
+		} else {
+			return true, fmt.Errorf("could not map primary key column %s", col.Name)
+		}
+		return false, nil
+	})
+	if err != nil {
+		return nil, nil, err
+	}
+
+	err = inSch.GetNonPKCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
+		i := inSch.GetNonPKCols().TagToIdx[col.Tag]
+		if outCol, ok := outSch.GetNonPKCols().GetByName(col.Name); ok {
+			j := outSch.GetAllCols().TagToIdx[outCol.Tag] // position within outSch's full column list
+			valMapping[i] = j
+		} else {
+			valMapping[i] = -1 // unlike key columns, a missing value column is not an error
+		}
+		return false, nil
+	})
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return keyMapping, valMapping, nil
+}
diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go index c1fe8229f9..725c519d76 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go @@ -743,7 +743,6 @@ func TestHistorySystemTable(t *testing.T) { } func TestUnscopedDiffSystemTable(t *testing.T) { - skipNewFormat(t) harness := newDoltHarness(t) for _, test := range UnscopedDiffSystemTableScriptTests { databases := harness.NewDatabases("mydb") @@ -831,6 +830,10 @@ func TestQueriesPrepared(t *testing.T) { enginetest.TestQueriesPrepared(t, newDoltHarness(t)) } +func TestPreparedStaticIndexQuery(t *testing.T) { + enginetest.TestPreparedStaticIndexQuery(t, newDoltHarness(t)) +} + func TestSpatialQueriesPrepared(t *testing.T) { skipNewFormat(t) skipPreparedTests(t) diff --git 
a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go index 78bc8ab10a..c71c9fa346 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go @@ -1414,6 +1414,21 @@ var DiffSystemTableScriptTests = []enginetest.ScriptTest{ }, }, }, + { + Name: "table with commit column should maintain its data in diff", + SetUpScript: []string{ + "CREATE TABLE t (pk int PRIMARY KEY, commit text);", + "CALL dolt_commit('-am', 'creating table t');", + "INSERT INTO t VALUES (1, 'hi');", + "CALL dolt_commit('-am', 'insert data');", + }, + Assertions: []enginetest.ScriptTestAssertion{ + { + Query: "SELECT to_pk, to_commit, from_pk, from_commit, diff_type from dolt_diff_t;", + Expected: []sql.Row{{1, "hi", nil, nil, "added"}}, + }, + }, + }, } var DiffTableFunctionScriptTests = []enginetest.ScriptTest{ @@ -1774,6 +1789,21 @@ var DiffTableFunctionScriptTests = []enginetest.ScriptTest{ }, }, }, + { + Name: "table with commit column should maintain its data in diff", + SetUpScript: []string{ + "CREATE TABLE t (pk int PRIMARY KEY, commit text);", + "set @Commit1 = dolt_commit('-am', 'creating table t');", + "INSERT INTO t VALUES (1, 'hi');", + "set @Commit2 = dolt_commit('-am', 'insert data');", + }, + Assertions: []enginetest.ScriptTestAssertion{ + { + Query: "SELECT to_pk, to_commit, from_pk, from_commit, diff_type from dolt_diff('t', @Commit1, @Commit2);", + Expected: []sql.Row{{1, "hi", nil, nil, "added"}}, + }, + }, + }, } var UnscopedDiffSystemTableScriptTests = []enginetest.ScriptTest{ diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_transaction_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_transaction_queries.go index 5576430e50..c9c175464d 100755 --- a/go/libraries/doltcore/sqle/enginetest/dolt_transaction_queries.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_transaction_queries.go @@ -25,6 +25,36 @@ import ( ) var DoltTransactionTests 
= []enginetest.TransactionTest{ + { + // Repro for https://github.com/dolthub/dolt/issues/3402 + Name: "DDL changes from transactions are available before analyzing statements in other sessions (autocommit on)", + Assertions: []enginetest.ScriptTestAssertion{ + { + Query: "/* client a */ select @@autocommit;", + Expected: []sql.Row{{1}}, + }, + { + Query: "/* client b */ select @@autocommit;", + Expected: []sql.Row{{1}}, + }, + { + Query: "/* client a */ select * from t;", + ExpectedErr: sql.ErrTableNotFound, + }, + { + Query: "/* client b */ select * from t;", + ExpectedErr: sql.ErrTableNotFound, + }, + { + Query: "/* client a */ create table t(pk int primary key);", + Expected: []sql.Row{{sql.OkResult{}}}, + }, + { + Query: "/* client b */ select count(*) from t;", + Expected: []sql.Row{{0}}, + }, + }, + }, { Name: "duplicate inserts, autocommit on", SetUpScript: []string{ diff --git a/go/libraries/doltcore/sqle/index/dolt_index.go b/go/libraries/doltcore/sqle/index/dolt_index.go index 51823e7262..b33eee13ec 100644 --- a/go/libraries/doltcore/sqle/index/dolt_index.go +++ b/go/libraries/doltcore/sqle/index/dolt_index.go @@ -34,9 +34,8 @@ type DoltIndex interface { sql.FilteredIndex Schema() schema.Schema IndexSchema() schema.Schema - TableData() durable.Index - IndexRowData() durable.Index Format() *types.NomsBinFormat + GetDurableIndexes(*sql.Context, *doltdb.Table) (durable.Index, durable.Index, error) } func DoltIndexesFromTable(ctx context.Context, db, tbl string, t *doltdb.Table) (indexes []sql.Index, err error) { @@ -69,23 +68,21 @@ func getPrimaryKeyIndex(ctx context.Context, db, tbl string, t *doltdb.Table, sc if err != nil { return nil, err } - - cols := sch.GetPKCols().GetColumns() keyBld := maybeGetKeyBuilder(tableRows) + cols := sch.GetPKCols().GetColumns() + return doltIndex{ - id: "PRIMARY", - tblName: tbl, - dbName: db, - columns: cols, - indexSch: sch, - tableSch: sch, - unique: true, - comment: "", - indexRows: tableRows, - tableRows: tableRows, - 
vrw: t.ValueReadWriter(), - keyBld: keyBld, + id: "PRIMARY", + tblName: tbl, + dbName: db, + columns: cols, + indexSch: sch, + tableSch: sch, + unique: true, + comment: "", + vrw: t.ValueReadWriter(), + keyBld: keyBld, }, nil } @@ -94,32 +91,24 @@ func getSecondaryIndex(ctx context.Context, db, tbl string, t *doltdb.Table, sch if err != nil { return nil, err } - - tableRows, err := t.GetRowData(ctx) - if err != nil { - return nil, err - } + keyBld := maybeGetKeyBuilder(indexRows) cols := make([]schema.Column, idx.Count()) for i, tag := range idx.IndexedColumnTags() { cols[i], _ = idx.GetColumn(tag) } - keyBld := maybeGetKeyBuilder(indexRows) - return doltIndex{ - id: idx.Name(), - tblName: tbl, - dbName: db, - columns: cols, - indexSch: idx.Schema(), - tableSch: sch, - unique: idx.IsUnique(), - comment: idx.Comment(), - indexRows: indexRows, - tableRows: tableRows, - vrw: t.ValueReadWriter(), - keyBld: keyBld, + id: idx.Name(), + tblName: tbl, + dbName: db, + columns: cols, + indexSch: idx.Schema(), + tableSch: sch, + unique: idx.IsUnique(), + comment: idx.Comment(), + vrw: t.ValueReadWriter(), + keyBld: keyBld, }, nil } @@ -130,12 +119,10 @@ type doltIndex struct { columns []schema.Column - indexSch schema.Schema - tableSch schema.Schema - indexRows durable.Index - tableRows durable.Index - unique bool - comment string + indexSch schema.Schema + tableSch schema.Schema + unique bool + comment string vrw types.ValueReadWriter keyBld *val.TupleBuilder @@ -168,6 +155,22 @@ func (di doltIndex) NewLookup(ctx *sql.Context, ranges ...sql.Range) (sql.IndexL return di.newNomsLookup(ctx, ranges...) 
} +func (di doltIndex) GetDurableIndexes(ctx *sql.Context, t *doltdb.Table) (primary, secondary durable.Index, err error) { + primary, err = t.GetRowData(ctx) + if err != nil { + return nil, nil, err + } + if di.ID() == "PRIMARY" { + secondary = primary + } else { + secondary, err = t.GetIndexRowData(ctx, di.ID()) + if err != nil { + return nil, nil, err + } + } + return +} + func (di doltIndex) newProllyLookup(ctx *sql.Context, ranges ...sql.Range) (sql.IndexLookup, error) { var err error sqlRanges, err := pruneEmptyRanges(ranges) @@ -357,16 +360,6 @@ func (di doltIndex) Table() string { return di.tblName } -// TableData returns the map of Table data for this index (the map of the target Table, not the index storage Table) -func (di doltIndex) TableData() durable.Index { - return di.tableRows -} - -// IndexRowData returns the map of index row data. -func (di doltIndex) IndexRowData() durable.Index { - return di.indexRows -} - func (di doltIndex) Format() *types.NomsBinFormat { return di.vrw.Format() } diff --git a/go/libraries/doltcore/sqle/index/dolt_index_test.go b/go/libraries/doltcore/sqle/index/dolt_index_test.go index b04e0a011e..6945aee308 100644 --- a/go/libraries/doltcore/sqle/index/dolt_index_test.go +++ b/go/libraries/doltcore/sqle/index/dolt_index_test.go @@ -27,6 +27,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/stretchr/testify/require" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils" "github.com/dolthub/dolt/go/libraries/doltcore/sqle" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" @@ -180,7 +181,7 @@ var ( ) func TestDoltIndexEqual(t *testing.T) { - indexMap := doltIndexSetup(t) + ctx, root, indexMap := doltIndexSetup(t) tests := []doltIndexTestCase{ { @@ -296,13 +297,13 @@ func TestDoltIndexEqual(t *testing.T) { t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) { idx, ok := indexMap[test.indexName] require.True(t, ok) - 
testDoltIndex(t, test.keys, test.expectedRows, idx, indexComp_Eq) + testDoltIndex(t, ctx, root, test.keys, test.expectedRows, idx, indexComp_Eq) }) } } func TestDoltIndexGreaterThan(t *testing.T) { - indexMap := doltIndexSetup(t) + ctx, root, indexMap := doltIndexSetup(t) tests := []struct { indexName string @@ -437,13 +438,13 @@ func TestDoltIndexGreaterThan(t *testing.T) { t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) { index, ok := indexMap[test.indexName] require.True(t, ok) - testDoltIndex(t, test.keys, test.expectedRows, index, indexComp_Gt) + testDoltIndex(t, ctx, root, test.keys, test.expectedRows, index, indexComp_Gt) }) } } func TestDoltIndexGreaterThanOrEqual(t *testing.T) { - indexMap := doltIndexSetup(t) + ctx, root, indexMap := doltIndexSetup(t) tests := []struct { indexName string @@ -574,13 +575,13 @@ func TestDoltIndexGreaterThanOrEqual(t *testing.T) { t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) { index, ok := indexMap[test.indexName] require.True(t, ok) - testDoltIndex(t, test.keys, test.expectedRows, index, indexComp_GtE) + testDoltIndex(t, ctx, root, test.keys, test.expectedRows, index, indexComp_GtE) }) } } func TestDoltIndexLessThan(t *testing.T) { - indexMap := doltIndexSetup(t) + ctx, root, indexMap := doltIndexSetup(t) tests := []struct { indexName string @@ -720,13 +721,13 @@ func TestDoltIndexLessThan(t *testing.T) { t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) { index, ok := indexMap[test.indexName] require.True(t, ok) - testDoltIndex(t, test.keys, test.expectedRows, index, indexComp_Lt) + testDoltIndex(t, ctx, root, test.keys, test.expectedRows, index, indexComp_Lt) }) } } func TestDoltIndexLessThanOrEqual(t *testing.T) { - indexMap := doltIndexSetup(t) + ctx, root, indexMap := doltIndexSetup(t) tests := []struct { indexName string @@ -867,13 +868,13 @@ func TestDoltIndexLessThanOrEqual(t *testing.T) { t.Run(fmt.Sprintf("%s|%v", test.indexName, 
test.keys), func(t *testing.T) { index, ok := indexMap[test.indexName] require.True(t, ok) - testDoltIndex(t, test.keys, test.expectedRows, index, indexComp_LtE) + testDoltIndex(t, ctx, root, test.keys, test.expectedRows, index, indexComp_LtE) }) } } func TestDoltIndexBetween(t *testing.T) { - indexMap := doltIndexSetup(t) + ctx, root, indexMap := doltIndexSetup(t) tests := []doltIndexBetweenTestCase{ { @@ -1043,7 +1044,6 @@ func TestDoltIndexBetween(t *testing.T) { for _, test := range tests { t.Run(fmt.Sprintf("%s|%v%v", test.indexName, test.greaterThanOrEqual, test.lessThanOrEqual), func(t *testing.T) { - ctx := NewTestSQLCtx(context.Background()) idx, ok := indexMap[test.indexName] require.True(t, ok) @@ -1060,7 +1060,11 @@ func TestDoltIndexBetween(t *testing.T) { pkSch, err := sqlutil.FromDoltSchema("fake_table", idx.Schema()) require.NoError(t, err) - indexIter, err := index.RowIterForIndexLookup(ctx, indexLookup, pkSch, nil) + dt, ok, err := root.GetTable(ctx, idx.Table()) + require.NoError(t, err) + require.True(t, ok) + + indexIter, err := index.RowIterForIndexLookup(ctx, dt, indexLookup, pkSch, nil) require.NoError(t, err) // If this is a primary index assert that a covering index was used @@ -1260,8 +1264,7 @@ func requireUnorderedRowsEqual(t *testing.T, rows1, rows2 []sql.Row) { require.Equal(t, rows1, rows2) } -func testDoltIndex(t *testing.T, keys []interface{}, expectedRows []sql.Row, idx index.DoltIndex, cmp indexComp) { - ctx := NewTestSQLCtx(context.Background()) +func testDoltIndex(t *testing.T, ctx *sql.Context, root *doltdb.RootValue, keys []interface{}, expectedRows []sql.Row, idx index.DoltIndex, cmp indexComp) { exprs := idx.Expressions() builder := sql.NewIndexBuilder(sql.NewEmptyContext(), idx) for i, key := range keys { @@ -1285,10 +1288,14 @@ func testDoltIndex(t *testing.T, keys []interface{}, expectedRows []sql.Row, idx indexLookup, err := builder.Build(ctx) require.NoError(t, err) + dt, ok, err := root.GetTable(ctx, idx.Table()) + 
require.NoError(t, err) + require.True(t, ok) + pkSch, err := sqlutil.FromDoltSchema("fake_table", idx.Schema()) require.NoError(t, err) - indexIter, err := index.RowIterForIndexLookup(ctx, indexLookup, pkSch, nil) + indexIter, err := index.RowIterForIndexLookup(ctx, dt, indexLookup, pkSch, nil) require.NoError(t, err) var readRows []sql.Row @@ -1301,7 +1308,7 @@ func testDoltIndex(t *testing.T, keys []interface{}, expectedRows []sql.Row, idx requireUnorderedRowsEqual(t, convertSqlRowToInt64(expectedRows), readRows) } -func doltIndexSetup(t *testing.T) map[string]index.DoltIndex { +func doltIndexSetup(t *testing.T) (*sql.Context, *doltdb.RootValue, map[string]index.DoltIndex) { ctx := NewTestSQLCtx(context.Background()) dEnv := dtestutils.CreateTestEnv() root, err := dEnv.WorkingRoot(ctx) @@ -1374,7 +1381,7 @@ INSERT INTO types VALUES (1, 4, '2020-05-14 12:00:03', 1.1, 'd', 1.1, 'a,c', '00 } } - return indexMap + return ctx, root, indexMap } func NewTestSQLCtx(ctx context.Context) *sql.Context { diff --git a/go/libraries/doltcore/sqle/index/index_lookup.go b/go/libraries/doltcore/sqle/index/index_lookup.go index 1bad6de420..d2c8ca1dc6 100644 --- a/go/libraries/doltcore/sqle/index/index_lookup.go +++ b/go/libraries/doltcore/sqle/index/index_lookup.go @@ -23,6 +23,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable" "github.com/dolthub/dolt/go/libraries/doltcore/row" "github.com/dolthub/dolt/go/libraries/doltcore/schema" @@ -35,44 +36,49 @@ func PartitionIndexedTableRows(ctx *sql.Context, idx sql.Index, part sql.Partiti rp := part.(rangePartition) doltIdx := idx.(DoltIndex) - if types.IsFormat_DOLT_1(rp.rows.Format()) { - return RowIterForProllyRange(ctx, doltIdx, rp.prollyRange, pkSch, columns) + if types.IsFormat_DOLT_1(rp.primary.Format()) { + return RowIterForProllyRange(ctx, doltIdx, rp.prollyRange, pkSch, columns, rp.primary, rp.secondary) } 
ranges := []*noms.ReadRange{rp.nomsRange} - return RowIterForNomsRanges(ctx, doltIdx, ranges, rp.rows, columns) + return RowIterForNomsRanges(ctx, doltIdx, ranges, columns, rp.primary, rp.secondary) } -func RowIterForIndexLookup(ctx *sql.Context, ilu sql.IndexLookup, pkSch sql.PrimaryKeySchema, columns []string) (sql.RowIter, error) { +func RowIterForIndexLookup(ctx *sql.Context, t *doltdb.Table, ilu sql.IndexLookup, pkSch sql.PrimaryKeySchema, columns []string) (sql.RowIter, error) { lookup := ilu.(*doltIndexLookup) idx := lookup.idx + primary, secondary, err := idx.GetDurableIndexes(ctx, t) + if err != nil { + return nil, err + } + if types.IsFormat_DOLT_1(idx.Format()) { // todo(andy) - return RowIterForProllyRange(ctx, idx, lookup.prollyRanges[0], pkSch, columns) + return RowIterForProllyRange(ctx, idx, lookup.prollyRanges[0], pkSch, columns, primary, secondary) } else { - return RowIterForNomsRanges(ctx, idx, lookup.nomsRanges, lookup.IndexRowData(), columns) + return RowIterForNomsRanges(ctx, idx, lookup.nomsRanges, columns, primary, secondary) } } -func RowIterForProllyRange(ctx *sql.Context, idx DoltIndex, ranges prolly.Range, pkSch sql.PrimaryKeySchema, columns []string) (sql.RowIter2, error) { +func RowIterForProllyRange(ctx *sql.Context, idx DoltIndex, ranges prolly.Range, pkSch sql.PrimaryKeySchema, columns []string, primary, secondary durable.Index) (sql.RowIter2, error) { covers := indexCoversCols(idx, columns) if covers { - return newProllyCoveringIndexIter(ctx, idx, ranges, pkSch) + return newProllyCoveringIndexIter(ctx, idx, ranges, pkSch, secondary) } else { - return newProllyIndexIter(ctx, idx, ranges) + return newProllyIndexIter(ctx, idx, ranges, primary, secondary) } } -func RowIterForNomsRanges(ctx *sql.Context, idx DoltIndex, ranges []*noms.ReadRange, rowData durable.Index, columns []string) (sql.RowIter, error) { - m := durable.NomsMapFromIndex(rowData) +func RowIterForNomsRanges(ctx *sql.Context, idx DoltIndex, ranges []*noms.ReadRange, 
columns []string, primary, secondary durable.Index) (sql.RowIter, error) { + m := durable.NomsMapFromIndex(secondary) nrr := noms.NewNomsRangeReader(idx.IndexSchema(), m, ranges) covers := indexCoversCols(idx, columns) if covers || idx.ID() == "PRIMARY" { return NewCoveringIndexRowIterAdapter(ctx, idx, nrr, columns), nil } else { - return NewIndexLookupRowIterAdapter(ctx, idx, nrr) + return NewIndexLookupRowIterAdapter(ctx, idx, primary, nrr) } } @@ -114,15 +120,20 @@ func DoltIndexFromLookup(lookup sql.IndexLookup) DoltIndex { return lookup.(*doltIndexLookup).idx } -func NewRangePartitionIter(lookup sql.IndexLookup) sql.PartitionIter { +func NewRangePartitionIter(ctx *sql.Context, t *doltdb.Table, lookup sql.IndexLookup) (sql.PartitionIter, error) { dlu := lookup.(*doltIndexLookup) + primary, secondary, err := dlu.idx.GetDurableIndexes(ctx, t) + if err != nil { + return nil, err + } return &rangePartitionIter{ nomsRanges: dlu.nomsRanges, prollyRanges: dlu.prollyRanges, curr: 0, mu: &sync.Mutex{}, - rowData: dlu.IndexRowData(), - } + secondary: secondary, + primary: primary, + }, nil } type rangePartitionIter struct { @@ -130,7 +141,10 @@ type rangePartitionIter struct { prollyRanges []prolly.Range curr int mu *sync.Mutex - rowData durable.Index + // the rows of the table the index references + primary durable.Index + // the rows of the index itself + secondary durable.Index } // Close is required by the sql.PartitionIter interface. Does nothing. 
@@ -143,7 +157,7 @@ func (itr *rangePartitionIter) Next(_ *sql.Context) (sql.Partition, error) { itr.mu.Lock() defer itr.mu.Unlock() - if types.IsFormat_DOLT_1(itr.rowData.Format()) { + if types.IsFormat_DOLT_1(itr.secondary.Format()) { return itr.nextProllyPartition() } return itr.nextNomsPartition() @@ -162,7 +176,8 @@ func (itr *rangePartitionIter) nextProllyPartition() (sql.Partition, error) { return rangePartition{ prollyRange: pr, key: bytes[:], - rows: itr.rowData, + primary: itr.primary, + secondary: itr.secondary, }, nil } @@ -179,7 +194,8 @@ func (itr *rangePartitionIter) nextNomsPartition() (sql.Partition, error) { return rangePartition{ nomsRange: nr, key: bytes[:], - rows: itr.rowData, + primary: itr.primary, + secondary: itr.secondary, }, nil } @@ -187,7 +203,10 @@ type rangePartition struct { nomsRange *noms.ReadRange prollyRange prolly.Range key []byte - rows durable.Index + // the rows of the table the index refers to + primary durable.Index + // the index entries + secondary durable.Index } func (rp rangePartition) Key() []byte { @@ -237,10 +256,6 @@ func (il *doltIndexLookup) String() string { return fmt.Sprintf("doltIndexLookup:%s", il.idx.ID()) } -func (il *doltIndexLookup) IndexRowData() durable.Index { - return il.idx.IndexRowData() -} - // Index implements the interface sql.IndexLookup func (il *doltIndexLookup) Index() sql.Index { return il.idx diff --git a/go/libraries/doltcore/sqle/index/noms_index_iter.go b/go/libraries/doltcore/sqle/index/noms_index_iter.go index 8388a7741c..526ac503c9 100644 --- a/go/libraries/doltcore/sqle/index/noms_index_iter.go +++ b/go/libraries/doltcore/sqle/index/noms_index_iter.go @@ -55,7 +55,7 @@ type indexLookupRowIterAdapter struct { } // NewIndexLookupRowIterAdapter returns a new indexLookupRowIterAdapter. 
-func NewIndexLookupRowIterAdapter(ctx *sql.Context, idx DoltIndex, keyIter nomsKeyIter) (*indexLookupRowIterAdapter, error) { +func NewIndexLookupRowIterAdapter(ctx *sql.Context, idx DoltIndex, tableData durable.Index, keyIter nomsKeyIter) (*indexLookupRowIterAdapter, error) { lookupTags := make(map[uint64]int) for i, tag := range idx.Schema().GetPKCols().Tags { lookupTags[tag] = i @@ -66,7 +66,7 @@ func NewIndexLookupRowIterAdapter(ctx *sql.Context, idx DoltIndex, keyIter nomsK lookupTags[schema.KeylessRowIdTag] = 0 } - rows := durable.NomsMapFromIndex(idx.TableData()) + rows := durable.NomsMapFromIndex(tableData) conv := NewKVToSqlRowConverterForCols(idx.Format(), idx.Schema()) resBuf := resultBufferPool.Get().(*async.RingBuffer) diff --git a/go/libraries/doltcore/sqle/index/prolly_index_iter.go b/go/libraries/doltcore/sqle/index/prolly_index_iter.go index 12908e0b96..959f5afac6 100644 --- a/go/libraries/doltcore/sqle/index/prolly_index_iter.go +++ b/go/libraries/doltcore/sqle/index/prolly_index_iter.go @@ -50,14 +50,14 @@ var _ sql.RowIter = prollyIndexIter{} var _ sql.RowIter2 = prollyIndexIter{} // NewProllyIndexIter returns a new prollyIndexIter. 
-func newProllyIndexIter(ctx *sql.Context, idx DoltIndex, rng prolly.Range) (prollyIndexIter, error) { - secondary := durable.ProllyMapFromIndex(idx.IndexRowData()) +func newProllyIndexIter(ctx *sql.Context, idx DoltIndex, rng prolly.Range, dprimary, dsecondary durable.Index) (prollyIndexIter, error) { + secondary := durable.ProllyMapFromIndex(dsecondary) indexIter, err := secondary.IterRange(ctx, rng) if err != nil { return prollyIndexIter{}, err } - primary := durable.ProllyMapFromIndex(idx.TableData()) + primary := durable.ProllyMapFromIndex(dprimary) kd, _ := primary.Descriptors() pkBld := val.NewTupleBuilder(kd) pkMap := ordinalMappingFromIndex(idx) @@ -211,8 +211,8 @@ type prollyCoveringIndexIter struct { var _ sql.RowIter = prollyCoveringIndexIter{} var _ sql.RowIter2 = prollyCoveringIndexIter{} -func newProllyCoveringIndexIter(ctx *sql.Context, idx DoltIndex, rng prolly.Range, pkSch sql.PrimaryKeySchema) (prollyCoveringIndexIter, error) { - secondary := durable.ProllyMapFromIndex(idx.IndexRowData()) +func newProllyCoveringIndexIter(ctx *sql.Context, idx DoltIndex, rng prolly.Range, pkSch sql.PrimaryKeySchema, indexdata durable.Index) (prollyCoveringIndexIter, error) { + secondary := durable.ProllyMapFromIndex(indexdata) indexIter, err := secondary.IterRange(ctx, rng) if err != nil { return prollyCoveringIndexIter{}, err diff --git a/go/libraries/doltcore/sqle/indexed_dolt_table.go b/go/libraries/doltcore/sqle/indexed_dolt_table.go index 794f36916a..c783317d0d 100644 --- a/go/libraries/doltcore/sqle/indexed_dolt_table.go +++ b/go/libraries/doltcore/sqle/indexed_dolt_table.go @@ -52,17 +52,28 @@ func (idt *IndexedDoltTable) Schema() sql.Schema { } func (idt *IndexedDoltTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) { - rows := index.DoltIndexFromLookup(idt.indexLookup).IndexRowData() - return index.SinglePartitionIterFromNomsMap(rows), nil + dt, err := idt.table.doltTable(ctx) + if err != nil { + return nil, err + } + return 
index.NewRangePartitionIter(ctx, dt, idt.indexLookup) } func (idt *IndexedDoltTable) PartitionRows(ctx *sql.Context, part sql.Partition) (sql.RowIter, error) { // todo(andy): only used by 'AS OF` queries - return index.RowIterForIndexLookup(ctx, idt.indexLookup, idt.table.sqlSch, nil) + dt, err := idt.table.doltTable(ctx) + if err != nil { + return nil, err + } + return index.RowIterForIndexLookup(ctx, dt, idt.indexLookup, idt.table.sqlSch, nil) } func (idt *IndexedDoltTable) PartitionRows2(ctx *sql.Context, part sql.Partition) (sql.RowIter, error) { - return index.RowIterForIndexLookup(ctx, idt.indexLookup, idt.table.sqlSch, nil) + dt, err := idt.table.doltTable(ctx) + if err != nil { + return nil, err + } + return index.RowIterForIndexLookup(ctx, dt, idt.indexLookup, idt.table.sqlSch, nil) } func (idt *IndexedDoltTable) IsTemporary() bool { @@ -84,7 +95,11 @@ type WritableIndexedDoltTable struct { var _ sql.Table2 = (*WritableIndexedDoltTable)(nil) func (t *WritableIndexedDoltTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) { - return index.NewRangePartitionIter(t.indexLookup), nil + dt, err := t.doltTable(ctx) + if err != nil { + return nil, err + } + return index.NewRangePartitionIter(ctx, dt, t.indexLookup) } func (t *WritableIndexedDoltTable) PartitionRows(ctx *sql.Context, part sql.Partition) (sql.RowIter, error) { diff --git a/go/libraries/doltcore/sqle/procedures_table.go b/go/libraries/doltcore/sqle/procedures_table.go index 381f2049b1..2c32ef4227 100644 --- a/go/libraries/doltcore/sqle/procedures_table.go +++ b/go/libraries/doltcore/sqle/procedures_table.go @@ -141,7 +141,12 @@ func DoltProceduresGetAll(ctx *sql.Context, db Database) ([]sql.StoredProcedureD return nil, err } - iter, err := index.RowIterForIndexLookup(ctx, lookup, tbl.sqlSch, nil) + dt, err := tbl.doltTable(ctx) + if err != nil { + return nil, err + } + + iter, err := index.RowIterForIndexLookup(ctx, dt, lookup, tbl.sqlSch, nil) if err != nil { return nil, err } @@ -265,7 
+270,12 @@ func DoltProceduresGetDetails(ctx *sql.Context, tbl *WritableDoltTable, name str return sql.StoredProcedureDetails{}, false, err } - rowIter, err := index.RowIterForIndexLookup(ctx, indexLookup, tbl.sqlSch, nil) + dt, err := tbl.doltTable(ctx) + if err != nil { + return sql.StoredProcedureDetails{}, false, err + } + + rowIter, err := index.RowIterForIndexLookup(ctx, dt, indexLookup, tbl.sqlSch, nil) if err != nil { return sql.StoredProcedureDetails{}, false, err } diff --git a/go/libraries/doltcore/sqle/schema_table.go b/go/libraries/doltcore/sqle/schema_table.go index 8314fa4a66..988bfa89b5 100644 --- a/go/libraries/doltcore/sqle/schema_table.go +++ b/go/libraries/doltcore/sqle/schema_table.go @@ -279,7 +279,12 @@ func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType str return nil, false, err } - iter, err := index.RowIterForIndexLookup(ctx, lookup, tbl.sqlSch, nil) + dt, err := tbl.doltTable(ctx) + if err != nil { + return nil, false, err + } + + iter, err := index.RowIterForIndexLookup(ctx, dt, lookup, tbl.sqlSch, nil) if err != nil { return nil, false, err } diff --git a/go/libraries/doltcore/sqle/sqlselect_test.go b/go/libraries/doltcore/sqle/sqlselect_test.go index da22b3bd38..bff83d1904 100644 --- a/go/libraries/doltcore/sqle/sqlselect_test.go +++ b/go/libraries/doltcore/sqle/sqlselect_test.go @@ -795,7 +795,7 @@ var sqlDiffSchema = sql.Schema{ &sql.Column{Name: "from_first_name", Type: typeinfo.StringDefaultType.ToSqlType()}, &sql.Column{Name: "from_last_name", Type: typeinfo.StringDefaultType.ToSqlType()}, &sql.Column{Name: "from_addr", Type: typeinfo.StringDefaultType.ToSqlType()}, - &sql.Column{Name: "diff_type", Type: sql.Text}, + &sql.Column{Name: "diff_type", Type: typeinfo.StringDefaultType.ToSqlType()}, } var SelectDiffTests = []SelectTest{ diff --git a/go/libraries/doltcore/sqle/system_variables.go b/go/libraries/doltcore/sqle/system_variables.go index 2ab369db27..a310e100b9 100644 --- 
a/go/libraries/doltcore/sqle/system_variables.go +++ b/go/libraries/doltcore/sqle/system_variables.go @@ -19,7 +19,6 @@ import ( ) const ( - DefaultBranchKey = "dolt_default_branch" ReplicateToRemoteKey = "dolt_replicate_to_remote" ReadReplicaRemoteKey = "dolt_read_replica_remote" SkipReplicationErrorsKey = "dolt_skip_replication_errors" @@ -39,14 +38,6 @@ func init() { func AddDoltSystemVariables() { sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{ - { - Name: DefaultBranchKey, - Scope: sql.SystemVariableScope_Global, - Dynamic: true, - SetVarHintApplies: false, - Type: sql.NewSystemStringType(DefaultBranchKey), - Default: "", - }, { Name: ReplicateToRemoteKey, Scope: sql.SystemVariableScope_Global, diff --git a/go/libraries/doltcore/sqle/temp_table.go b/go/libraries/doltcore/sqle/temp_table.go index b9155274dc..c0d7c2c5b7 100644 --- a/go/libraries/doltcore/sqle/temp_table.go +++ b/go/libraries/doltcore/sqle/temp_table.go @@ -150,7 +150,7 @@ func (t *TempTable) DataLength(ctx *sql.Context) (uint64, error) { func (t *TempTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) { if t.lookup != nil { - return index.RowIterForIndexLookup(ctx, t.lookup, t.pkSch, nil) + return index.RowIterForIndexLookup(ctx, t.table, t.lookup, t.pkSch, nil) } else { return partitionRows(ctx, t.table, nil, partition) } diff --git a/go/store/prolly/benchmark/benchmark_read_test.go b/go/store/prolly/benchmark/benchmark_read_test.go index 83c0bde1b9..a1351955f2 100644 --- a/go/store/prolly/benchmark/benchmark_read_test.go +++ b/go/store/prolly/benchmark/benchmark_read_test.go @@ -38,7 +38,19 @@ func BenchmarkMapGet(b *testing.B) { }) } -func BenchmarkMapGetParallel(b *testing.B) { +func BenchmarkStepMapGet(b *testing.B) { + b.Skip() + step := uint64(100_000) + for sz := step; sz < step*20; sz += step { + nm := fmt.Sprintf("benchmark maps %d", sz) + b.Run(nm, func(b *testing.B) { + benchmarkProllyMapGet(b, sz) + benchmarkTypesMapGet(b, sz) + }) + } 
+} + +func BenchmarkParallelMapGet(b *testing.B) { b.Run("benchmark maps 10k", func(b *testing.B) { benchmarkProllyMapGetParallel(b, 10_000) benchmarkTypesMapGetParallel(b, 10_000) @@ -53,6 +65,18 @@ func BenchmarkMapGetParallel(b *testing.B) { }) } +func BenchmarkStepParallelMapGet(b *testing.B) { + b.Skip() + step := uint64(100_000) + for sz := step; sz < step*20; sz += step { + nm := fmt.Sprintf("benchmark maps parallel %d", sz) + b.Run(nm, func(b *testing.B) { + benchmarkProllyMapGetParallel(b, sz) + benchmarkTypesMapGetParallel(b, sz) + }) + } +} + func BenchmarkProllyGetLarge(b *testing.B) { benchmarkProllyMapGet(b, 1_000_000) } @@ -61,6 +85,14 @@ func BenchmarkNomsGetLarge(b *testing.B) { benchmarkTypesMapGet(b, 1_000_000) } +func BenchmarkProllyParallelGetLarge(b *testing.B) { + benchmarkProllyMapGetParallel(b, 1_000_000) +} + +func BenchmarkNomsParallelGetLarge(b *testing.B) { + benchmarkTypesMapGetParallel(b, 1_000_000) +} + func benchmarkProllyMapGet(b *testing.B, size uint64) { bench := generateProllyBench(b, size) b.Run(fmt.Sprintf("benchmark prolly map %d", size), func(b *testing.B) { @@ -94,8 +126,9 @@ func benchmarkProllyMapGetParallel(b *testing.B, size uint64) { b.Run(fmt.Sprintf("benchmark prolly map %d", size), func(b *testing.B) { b.RunParallel(func(b *testing.PB) { ctx := context.Background() + rnd := rand.NewSource(0) for b.Next() { - idx := rand.Uint64() % uint64(len(bench.tups)) + idx := int(rnd.Int63()) % len(bench.tups) key := bench.tups[idx][0] _ = bench.m.Get(ctx, key, func(_, _ val.Tuple) (e error) { return @@ -111,8 +144,9 @@ func benchmarkTypesMapGetParallel(b *testing.B, size uint64) { b.Run(fmt.Sprintf("benchmark types map %d", size), func(b *testing.B) { b.RunParallel(func(pb *testing.PB) { ctx := context.Background() + rnd := rand.NewSource(0) for pb.Next() { - idx := rand.Uint64() % uint64(len(bench.tups)) + idx := int(rnd.Int63()) % len(bench.tups) _, _, _ = bench.m.MaybeGet(ctx, bench.tups[idx][0]) } }) diff --git 
a/go/store/prolly/benchmark/benchmark_write_test.go b/go/store/prolly/benchmark/benchmark_write_test.go index 885c5f731b..d1be622293 100644 --- a/go/store/prolly/benchmark/benchmark_write_test.go +++ b/go/store/prolly/benchmark/benchmark_write_test.go @@ -18,8 +18,6 @@ import ( "context" "math/rand" "testing" - - "github.com/stretchr/testify/require" ) func BenchmarkMapUpdate(b *testing.B) { @@ -38,11 +36,11 @@ func BenchmarkMapUpdate(b *testing.B) { } func BenchmarkProllySmallWrites(b *testing.B) { - benchmarkProllyMapUpdate(b, 10_000, 10) + benchmarkProllyMapUpdate(b, 10_000, 1) } func BenchmarkTypesSmallWrites(b *testing.B) { - benchmarkTypesMapUpdate(b, 10_000, 10) + benchmarkTypesMapUpdate(b, 10_000, 1) } func BenchmarkProllyMediumWrites(b *testing.B) { @@ -72,8 +70,7 @@ func benchmarkProllyMapUpdate(b *testing.B, size, k uint64) { idx = rand.Uint64() % uint64(len(bench.tups)) value := bench.tups[idx][0] - err := mut.Put(ctx, key, value) - require.NoError(b, err) + _ = mut.Put(ctx, key, value) } _, _ = mut.Map(ctx) } diff --git a/go/store/prolly/message/prolly_map.go b/go/store/prolly/message/prolly_map.go index 79b074fc1b..095337d34a 100644 --- a/go/store/prolly/message/prolly_map.go +++ b/go/store/prolly/message/prolly_map.go @@ -197,11 +197,13 @@ func estimateProllyMapSize(keys, values [][]byte, subtrees []uint64) (keySz, val panic(fmt.Sprintf("value vector exceeds Size limit ( %d > %d )", valSz, MaxVectorOffset)) } + // todo(andy): better estimates bufSz += keySz + valSz // tuples bufSz += refCntSz // subtree counts bufSz += len(keys)*2 + len(values)*2 // offsets bufSz += 8 + 1 + 1 + 1 // metadata bufSz += 72 // vtable (approx) + bufSz += 100 // padding? 
return } diff --git a/go/store/prolly/message/serialize.go b/go/store/prolly/message/serialize.go index 286fa83200..fa3c9fbc49 100644 --- a/go/store/prolly/message/serialize.go +++ b/go/store/prolly/message/serialize.go @@ -28,8 +28,8 @@ const ( func getFlatbufferBuilder(pool pool.BuffPool, sz int) (b *fb.Builder) { b = fb.NewBuilder(0) - buf := pool.Get(uint64(sz)) - b.Bytes = buf[:0] + b.Bytes = pool.Get(uint64(sz)) + b.Reset() return } diff --git a/go/store/prolly/tree/chunker.go b/go/store/prolly/tree/chunker.go index 2f2bceac8c..6965760196 100644 --- a/go/store/prolly/tree/chunker.go +++ b/go/store/prolly/tree/chunker.go @@ -61,7 +61,6 @@ func newChunker[S message.Serializer](ctx context.Context, cur *Cursor, level in splitter := defaultSplitterFactory(uint8(level % 256)) builder := newNodeBuilder(serializer, level) - builder.startNode() sc := &chunker[S]{ cur: cur, @@ -324,7 +323,6 @@ func (tc *chunker[S]) handleChunkBoundary(ctx context.Context) error { } tc.splitter.Reset() - tc.builder.startNode() return nil } diff --git a/go/store/prolly/tree/chunker_test.go b/go/store/prolly/tree/chunker_test.go index c3095bc5fb..8b1691c03a 100644 --- a/go/store/prolly/tree/chunker_test.go +++ b/go/store/prolly/tree/chunker_test.go @@ -71,9 +71,7 @@ func validateTreeItems(t *testing.T, ns NodeStore, nd Node, expected [][2]Item) i := 0 ctx := context.Background() err := iterTree(ctx, ns, nd, func(actual Item) (err error) { - if !assert.Equal(t, expected[i/2][i%2], actual) { - panic("here") - } + assert.Equal(t, expected[i/2][i%2], actual) i++ return }) diff --git a/go/store/prolly/tree/node_builder.go b/go/store/prolly/tree/node_builder.go index 12b1df680f..93c1adc382 100644 --- a/go/store/prolly/tree/node_builder.go +++ b/go/store/prolly/tree/node_builder.go @@ -16,16 +16,13 @@ package tree import ( "context" + "sync" "github.com/dolthub/dolt/go/store/prolly/message" "github.com/dolthub/dolt/go/store/hash" ) -const ( - nodeBuilderListSize = 256 -) - type novelNode struct 
{ node Node addr hash.Hash @@ -58,11 +55,12 @@ func writeNewNode[S message.Serializer](ctx context.Context, ns NodeStore, bld * }, nil } -func newNodeBuilder[S message.Serializer](serializer S, level int) *nodeBuilder[S] { - return &nodeBuilder[S]{ +func newNodeBuilder[S message.Serializer](serializer S, level int) (nb *nodeBuilder[S]) { + nb = &nodeBuilder[S]{ level: level, serializer: serializer, } + return } type nodeBuilder[S message.Serializer] struct { @@ -72,16 +70,17 @@ type nodeBuilder[S message.Serializer] struct { serializer S } -func (nb *nodeBuilder[S]) startNode() { - nb.reset() -} - func (nb *nodeBuilder[S]) hasCapacity(key, value Item) bool { sum := nb.size + len(key) + len(value) return sum <= int(message.MaxVectorOffset) } func (nb *nodeBuilder[S]) addItems(key, value Item, subtree uint64) { + if nb.keys == nil { + nb.keys = getItemSlices() + nb.values = getItemSlices() + nb.subtrees = getSubtreeSlice() + } nb.keys = append(nb.keys, key) nb.values = append(nb.values, value) nb.size += len(key) + len(value) @@ -94,14 +93,49 @@ func (nb *nodeBuilder[S]) count() int { func (nb *nodeBuilder[S]) build() (node Node) { msg := nb.serializer.Serialize(nb.keys, nb.values, nb.subtrees, nb.level) - nb.reset() + nb.recycleBuffers() + nb.size = 0 return NodeFromBytes(msg) } -func (nb *nodeBuilder[S]) reset() { - // buffers are copied, it's safe to re-use the memory. 
- nb.keys = nb.keys[:0] - nb.values = nb.values[:0] - nb.size = 0 - nb.subtrees = nb.subtrees[:0] +func (nb *nodeBuilder[S]) recycleBuffers() { + putItemSlices(nb.keys[:0]) + putItemSlices(nb.values[:0]) + putSubtreeSlice(nb.subtrees[:0]) + nb.keys = nil + nb.values = nil + nb.subtrees = nil +} + +// todo(andy): replace with NodeStore.Pool() +const nodeBuilderListSize = 256 + +var itemsPool = sync.Pool{ + New: func() any { + return make([][]byte, 0, nodeBuilderListSize) + }, +} + +func getItemSlices() [][]byte { + sl := itemsPool.Get().([][]byte) + return sl[:0] +} + +func putItemSlices(sl [][]byte) { + itemsPool.Put(sl[:0]) +} + +var subtreePool = sync.Pool{ + New: func() any { + return make([]uint64, 0, nodeBuilderListSize) + }, +} + +func getSubtreeSlice() []uint64 { + sl := subtreePool.Get().([]uint64) + return sl[:0] +} + +func putSubtreeSlice(sl []uint64) { + subtreePool.Put(sl[:0]) } diff --git a/go/store/prolly/tree/node_cache.go b/go/store/prolly/tree/node_cache.go new file mode 100644 index 0000000000..202089ba55 --- /dev/null +++ b/go/store/prolly/tree/node_cache.go @@ -0,0 +1,189 @@ +// Copyright 2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package tree + +import ( + "encoding/binary" + "fmt" + "sync" + + "github.com/dolthub/dolt/go/store/chunks" + "github.com/dolthub/dolt/go/store/hash" +) + +const ( + numStripes = 256 +) + +func newChunkCache(maxSize int) (c chunkCache) { + sz := maxSize / numStripes + for i := range c.stripes { + c.stripes[i] = newStripe(sz) + } + return +} + +type chunkCache struct { + stripes [numStripes]*stripe +} + +func (c chunkCache) get(addr hash.Hash) (chunks.Chunk, bool) { + return c.pickStripe(addr).get(addr) +} + +func (c chunkCache) insert(ch chunks.Chunk) { + c.pickStripe(ch.Hash()).insert(ch) +} + +func (c chunkCache) pickStripe(addr hash.Hash) *stripe { + i := binary.LittleEndian.Uint32(addr[:4]) % numStripes + return c.stripes[i] +} + +type centry struct { + c chunks.Chunk + i int + prev *centry + next *centry +} + +type stripe struct { + mu *sync.Mutex + chunks map[hash.Hash]*centry + head *centry + sz int + maxSz int + rev int +} + +func newStripe(maxSize int) *stripe { + return &stripe{ + &sync.Mutex{}, + make(map[hash.Hash]*centry), + nil, + 0, + maxSize, + 0, + } +} + +func removeFromList(e *centry) { + e.prev.next = e.next + e.next.prev = e.prev + e.prev = e + e.next = e +} + +func (s *stripe) moveToFront(e *centry) { + e.i = s.rev + s.rev++ + if s.head == e { + return + } + if s.head != nil { + removeFromList(e) + e.next = s.head + e.prev = s.head.prev + s.head.prev = e + e.prev.next = e + } + s.head = e +} + +func (s *stripe) get(h hash.Hash) (chunks.Chunk, bool) { + s.mu.Lock() + defer s.mu.Unlock() + if e, ok := s.chunks[h]; ok { + s.moveToFront(e) + return e.c, true + } else { + return chunks.EmptyChunk, false + } +} + +func (s *stripe) insert(c chunks.Chunk) { + s.mu.Lock() + defer s.mu.Unlock() + s.addIfAbsent(c) +} + +func (s *stripe) addIfAbsent(c chunks.Chunk) { + if e, ok := s.chunks[c.Hash()]; !ok { + e = &centry{c, 0, nil, nil} + e.next = e + e.prev = e + s.moveToFront(e) + s.chunks[c.Hash()] = e + s.sz += c.Size() + s.shrinkToMaxSz() + } else { + 
s.moveToFront(e) + } +} + +func (s *stripe) shrinkToMaxSz() { + for s.sz > s.maxSz { + if s.head != nil { + t := s.head.prev + removeFromList(t) + if t == s.head { + s.head = nil + } + delete(s.chunks, t.c.Hash()) + s.sz -= t.c.Size() + } else { + panic("cache is empty but cache Size is > than max Size") + } + } +} + +func (s *stripe) sanityCheck() { + if s.head != nil { + p := s.head.next + i := 1 + sz := s.head.c.Size() + lasti := s.head.i + for p != s.head { + i++ + sz += p.c.Size() + if p.i >= lasti { + panic("encountered lru list entry with higher rev later in the list.") + } + p = p.next + } + if i != len(s.chunks) { + panic(fmt.Sprintf("cache lru list has different Size than cache.chunks. %d vs %d", i, len(s.chunks))) + } + if sz != s.sz { + panic("entries reachable from lru list have different Size than cache.sz.") + } + j := 1 + p = s.head.prev + for p != s.head { + j++ + p = p.prev + } + if j != i { + panic("length of list backwards is not equal to length of list forward") + } + } else { + if len(s.chunks) != 0 { + panic("lru list is empty but s.chunks is not") + } + if s.sz != 0 { + panic("lru list is empty but s.sz is not 0") + } + } +} diff --git a/go/store/prolly/tree/node_splitter.go b/go/store/prolly/tree/node_splitter.go index a66b7fc90b..5dceca1c6b 100644 --- a/go/store/prolly/tree/node_splitter.go +++ b/go/store/prolly/tree/node_splitter.go @@ -24,7 +24,6 @@ package tree import ( "crypto/sha512" "encoding/binary" - "fmt" "math" "github.com/kch42/buzhash" @@ -64,7 +63,7 @@ type nodeSplitter interface { // Append provides more nodeItems to the splitter. Splitter's make chunk // boundary decisions based on the Item contents. Upon return, callers // can use CrossedBoundary() to see if a chunk boundary has crossed. - Append(items ...Item) error + Append(key, values Item) error // CrossedBoundary returns true if the provided nodeItems have caused a chunk // boundary to be crossed. 
@@ -113,11 +112,12 @@ func newRollingHashSplitter(salt uint8) nodeSplitter { var _ splitterFactory = newRollingHashSplitter // Append implements NodeSplitter -func (sns *rollingHashSplitter) Append(items ...Item) (err error) { - for _, it := range items { - for _, byt := range it { - _ = sns.hashByte(byt) - } +func (sns *rollingHashSplitter) Append(key, value Item) (err error) { + for _, byt := range key { + _ = sns.hashByte(byt) + } + for _, byt := range value { + _ = sns.hashByte(byt) } return nil } @@ -189,13 +189,9 @@ func newKeySplitter(level uint8) nodeSplitter { var _ splitterFactory = newKeySplitter -func (ks *keySplitter) Append(items ...Item) error { - if len(items) != 2 { - return fmt.Errorf("expected 2 nodeItems, %d were passed", len(items)) - } - +func (ks *keySplitter) Append(key, value Item) error { // todo(andy): account for key/value offsets, vtable, etc. - thisSize := uint32(len(items[0]) + len(items[1])) + thisSize := uint32(len(key) + len(value)) ks.size += thisSize if ks.size < minChunkSize { @@ -206,7 +202,7 @@ func (ks *keySplitter) Append(items ...Item) error { return nil } - h := xxHash32(items[0], ks.salt) + h := xxHash32(key, ks.salt) ks.crossedBoundary = weibullCheck(ks.size, thisSize, h) return nil } diff --git a/go/store/prolly/tree/node_store.go b/go/store/prolly/tree/node_store.go index 7e955409c0..680206f9be 100644 --- a/go/store/prolly/tree/node_store.go +++ b/go/store/prolly/tree/node_store.go @@ -16,14 +16,11 @@ package tree import ( "context" - "fmt" - "sync" - - "github.com/dolthub/dolt/go/store/types" "github.com/dolthub/dolt/go/store/chunks" "github.com/dolthub/dolt/go/store/hash" "github.com/dolthub/dolt/go/store/pool" + "github.com/dolthub/dolt/go/store/types" ) const ( @@ -47,7 +44,7 @@ type NodeStore interface { type nodeStore struct { store chunks.ChunkStore - cache *chunkCache + cache chunkCache bp pool.BuffPool } @@ -105,175 +102,3 @@ func (ns nodeStore) Format() *types.NomsBinFormat { // todo(andy): read from 
|ns.store| return types.Format_DOLT_1 } - -type centry struct { - c chunks.Chunk - i int - prev *centry - next *centry -} - -type chunkCache struct { - mu *sync.Mutex - chunks map[hash.Hash]*centry - head *centry - sz int - maxSz int - rev int -} - -func newChunkCache(maxSize int) *chunkCache { - return &chunkCache{ - &sync.Mutex{}, - make(map[hash.Hash]*centry), - nil, - 0, - maxSize, - 0, - } -} - -func removeFromCList(e *centry) { - e.prev.next = e.next - e.next.prev = e.prev - e.prev = e - e.next = e -} - -func (mc *chunkCache) moveToFront(e *centry) { - e.i = mc.rev - mc.rev++ - if mc.head == e { - return - } - if mc.head != nil { - removeFromCList(e) - e.next = mc.head - e.prev = mc.head.prev - mc.head.prev = e - e.prev.next = e - } - mc.head = e -} - -func (mc *chunkCache) get(h hash.Hash) (chunks.Chunk, bool) { - mc.mu.Lock() - defer mc.mu.Unlock() - if e, ok := mc.chunks[h]; ok { - mc.moveToFront(e) - return e.c, true - } else { - return chunks.EmptyChunk, false - } -} - -func (mc *chunkCache) getMany(hs hash.HashSet) ([]chunks.Chunk, hash.HashSet) { - mc.mu.Lock() - defer mc.mu.Unlock() - absent := make(map[hash.Hash]struct{}) - var found []chunks.Chunk - for h, _ := range hs { - if e, ok := mc.chunks[h]; ok { - mc.moveToFront(e) - found = append(found, e.c) - } else { - absent[h] = struct{}{} - } - } - return found, absent -} - -func (mc *chunkCache) insert(c chunks.Chunk) { - mc.mu.Lock() - defer mc.mu.Unlock() - mc.addIfAbsent(c) -} - -func (mc *chunkCache) insertMany(cs []chunks.Chunk) { - mc.mu.Lock() - defer mc.mu.Unlock() - for _, c := range cs { - mc.addIfAbsent(c) - } -} - -func (mc *chunkCache) addIfAbsent(c chunks.Chunk) { - if e, ok := mc.chunks[c.Hash()]; !ok { - e := &centry{c, 0, nil, nil} - e.next = e - e.prev = e - mc.moveToFront(e) - mc.chunks[c.Hash()] = e - mc.sz += c.Size() - mc.shrinkToMaxSz() - } else { - mc.moveToFront(e) - } -} - -func (mc *chunkCache) Len() int { - mc.mu.Lock() - defer mc.mu.Unlock() - return len(mc.chunks) -} - 
-func (mc *chunkCache) Size() int { - mc.mu.Lock() - defer mc.mu.Unlock() - return mc.sz -} - -func (mc *chunkCache) shrinkToMaxSz() { - for mc.sz > mc.maxSz { - if mc.head != nil { - t := mc.head.prev - removeFromCList(t) - if t == mc.head { - mc.head = nil - } - delete(mc.chunks, t.c.Hash()) - mc.sz -= t.c.Size() - } else { - panic("cache is empty but cache Size is > than max Size") - } - } -} - -func (mc *chunkCache) sanityCheck() { - if mc.head != nil { - p := mc.head.next - i := 1 - sz := mc.head.c.Size() - lasti := mc.head.i - for p != mc.head { - i++ - sz += p.c.Size() - if p.i >= lasti { - panic("encountered lru list entry with higher rev later in the list.") - } - p = p.next - } - if i != len(mc.chunks) { - panic(fmt.Sprintf("cache lru list has different Size than cache.chunks. %d vs %d", i, len(mc.chunks))) - } - if sz != mc.sz { - panic("entries reachable from lru list have different Size than cache.sz.") - } - j := 1 - p = mc.head.prev - for p != mc.head { - j++ - p = p.prev - } - if j != i { - panic("length of list backwards is not equal to length of list forward") - } - } else { - if len(mc.chunks) != 0 { - panic("lru list is empty but mc.chunks is not") - } - if mc.sz != 0 { - panic("lru list is empty but mc.sz is not 0") - } - } -} diff --git a/go/store/skip/list.go b/go/store/skip/list.go index 64d6192cb5..878d124c0b 100644 --- a/go/store/skip/list.go +++ b/go/store/skip/list.go @@ -17,6 +17,8 @@ package skip import ( "math" "math/rand" + + "github.com/zeebo/xxh3" ) const ( @@ -33,9 +35,8 @@ type List struct { count uint32 checkpoint nodeId - - cmp ValueCmp - src rand.Source + cmp ValueCmp + salt uint64 } type ValueCmp func(left, right []byte) int @@ -56,23 +57,22 @@ type skipNode struct { } func NewSkipList(cmp ValueCmp) *List { - // todo(andy): buffer pool - nodes := make([]skipNode, 1, 128) + nodes := make([]skipNode, 0, 8) // initialize sentinel node - nodes[sentinelId] = skipNode{ + nodes = append(nodes, skipNode{ id: sentinelId, key: nil, 
val: nil, height: maxHeight, next: skipPointer{}, prev: sentinelId, - } + }) return &List{ nodes: nodes, checkpoint: nodeId(1), cmp: cmp, - src: rand.NewSource(0), + salt: rand.Uint64(), } } @@ -190,7 +190,7 @@ func (l *List) insert(key, value []byte, path skipPointer) { key: key, val: value, id: l.nextNodeId(), - height: rollHeight(l.src), + height: rollHeight(key, l.salt), } l.nodes = append(l.nodes, novel) @@ -358,8 +358,8 @@ const ( pattern3 = uint64(1<<12 - 1) ) -func rollHeight(r rand.Source) (h uint8) { - roll := r.Int63() +func rollHeight(key []byte, salt uint64) (h uint8) { + roll := xxh3.HashSeed(key, salt) patterns := []uint64{ pattern0, pattern1, @@ -376,9 +376,3 @@ func rollHeight(r rand.Source) (h uint8) { return } - -func assertTrue(b bool) { - if !b { - panic("expected true") - } -} diff --git a/go/utils/concurrency/runner.go b/go/utils/concurrency/runner.go new file mode 100644 index 0000000000..964171a26c --- /dev/null +++ b/go/utils/concurrency/runner.go @@ -0,0 +1,78 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "os" + + _ "github.com/go-sql-driver/mysql" + "golang.org/x/sync/errgroup" +) + +const clients = 16 +const iters = 10 + +var sqlScript = []string{ + "call dolt_checkout('main');", + "select * from dolt_log order by date desc limit 10;", +} + +var ( + database = "SHAQ" + user = "root" + pass = "" + host = "127.0.0.1" + port = "3306" +) + +// Runs |sqlScript| concurrently on multiple clients. +// Useful for reproducing concurrency bugs. 
+func main() { + connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", + user, pass, host, port, database) + + db, err := sql.Open("mysql", connStr) + maybeExit(err) + + eg, ctx := errgroup.WithContext(context.Background()) + + for i := 0; i < clients; i++ { + eg.Go(func() (err error) { + conn, err := db.Conn(ctx) + if err != nil { + return err + } + defer func() { + cerr := conn.Close() + if err != nil { + err = cerr + } + }() + for j := 0; j < iters; j++ { + if err = query(ctx, conn); err != nil { + return err + } + } + return + }) + } + maybeExit(eg.Wait()) +} + +func query(ctx context.Context, conn *sql.Conn) error { + for i := range sqlScript { + _, err := conn.ExecContext(ctx, sqlScript[i]) + if err != nil { + return err + } + } + return nil +} + +func maybeExit(err error) { + if err != nil { + fmt.Println(err.Error()) + os.Exit(1) + } +} diff --git a/integration-tests/bats/deleted-branches.bats b/integration-tests/bats/deleted-branches.bats index 018ba9e57c..926e375c07 100644 --- a/integration-tests/bats/deleted-branches.bats +++ b/integration-tests/bats/deleted-branches.bats @@ -36,7 +36,7 @@ make_it() { start_sql_server "dolt_repo_$$" - server_query "dolt_repo_$$" 1 "SET @@GLOBAL.dolt_default_branch = 'to_keep'" + server_query "dolt_repo_$$" 1 "SET @@GLOBAL.dolt_repo_$$_default_branch = 'to_keep'" server_query "dolt_repo_$$" 1 'delete from dolt_branches where name = "main"' "" @@ -63,7 +63,7 @@ make_it() { start_sql_server "dolt_repo_$$" - server_query "dolt_repo_$$" 1 "SET @@GLOBAL.dolt_default_branch = 'this_branch_does_not_exist'" + server_query "dolt_repo_$$" 1 "SET @@GLOBAL.dolt_repo_$$_default_branch = 'this_branch_does_not_exist'" # Against the default branch it fails run server_query "dolt_repo_$$" 1 "SELECT * FROM test" "" @@ -78,7 +78,7 @@ make_it() { start_sql_server "dolt_repo_$$" - server_query "dolt_repo_$$" 1 "SET @@GLOBAL.dolt_default_branch = 'this_branch_does_not_exist'" + server_query "dolt_repo_$$" 1 "SET @@GLOBAL.dolt_repo_$$_default_branch = 
'this_branch_does_not_exist'" multi_query "dolt_repo_$$/main" 1 " SELECT * FROM test; @@ -101,14 +101,14 @@ SELECT DOLT_CHECKOUT('to_checkout'); SELECT * FROM test;" } -@test "deleted-branches: can DOLT_CHECKOUT on SQL connecttion with dolt_default_branch set to existing branch when checked out branch is deleted" { +@test "deleted-branches: can DOLT_CHECKOUT on SQL connection with dolt_default_branch set to existing branch when checked out branch is deleted" { make_it dolt branch -c to_keep to_checkout start_sql_server "dolt_repo_$$" - server_query "dolt_repo_$$" 1 "SET @@GLOBAL.dolt_default_branch = 'to_keep'" + server_query "dolt_repo_$$" 1 "SET @@GLOBAL.dolt_repo_$$_default_branch = 'to_keep'" server_query "dolt_repo_$$" 1 'delete from dolt_branches where name = "main"' "" diff --git a/integration-tests/bats/foreign-keys.bats b/integration-tests/bats/foreign-keys.bats index 89aa2f5539..9b272694d3 100644 --- a/integration-tests/bats/foreign-keys.bats +++ b/integration-tests/bats/foreign-keys.bats @@ -803,7 +803,6 @@ SQL } @test "foreign-keys: dolt table import" { - # Foreign key processing was moved to the engine, therefore you must import data through LOAD DATA now dolt sql < update_parent.csv run dolt table import -u parent update_parent.csv - [ "$status" -eq "1" ] - [[ "$output" =~ "foreign key" ]] || false + [ "$status" -eq "0" ] + [[ "$output" =~ "Rows Processed: 2, Additions: 0, Modifications: 2, Had No Effect: 0" ]] || false + + run dolt sql -r csv -q "select * from parent order by id" + [ "$status" -eq "0" ] + [[ "$output" =~ "id,v1,v2" ]] || false + [[ "$output" =~ "1,3,3" ]] || false + [[ "$output" =~ "2,4,4" ]] || false + + run dolt sql -r csv -q "select * from child order by id" + [ "$status" -eq "0" ] + [[ "$output" =~ "id,v1,v2" ]] || false + [[ "$output" =~ "1,3,1" ]] || false + [[ "$output" =~ "2,4,2" ]] || false } @test "foreign-keys: Commit all" { @@ -1616,7 +1627,6 @@ SQL } @test "foreign-keys: dolt table import with null in nullable FK field 
should work (issue #2108)" { - # Foreign key processing was moved to the engine, therefore you must import data through LOAD DATA now dolt sql < fk_test.csv run dolt table import -u businesses fk_test.csv - [ "$status" -eq "1" ] + [ "$status" -eq "0" ] + [[ "$output" =~ 'Rows Processed: 2, Additions: 2, Modifications: 0, Had No Effect: 0' ]] || false + + run dolt sql -r csv -q "SELECT * FROM businesses order by name" + [ "$status" -eq "0" ] + [[ "$output" =~ 'name,naics_2017' ]] || false + [[ "$output" =~ 'test,' ]] || false + [[ "$output" =~ 'test2,100' ]] || false } @test "foreign-keys: Delayed foreign key resolution" { diff --git a/integration-tests/bats/import-replace-tables.bats b/integration-tests/bats/import-replace-tables.bats index bc13d70b13..46ebcfc3b2 100644 --- a/integration-tests/bats/import-replace-tables.bats +++ b/integration-tests/bats/import-replace-tables.bats @@ -369,3 +369,34 @@ DELIM [ "${#lines[@]}" -eq 2 ] [ "${lines[1]}" = "0,1,2,3" ] } + +@test "import-replace-tables: Replace that breaks fk constraints correctly errors" { + dolt sql < colors-bad.csv +id,name +1,'red' +DELIM + + run dolt table import -r colors colors-bad.csv + [ "$status" -eq 1 ] + [[ "$output" =~ "cannot truncate table colors as it is referenced in foreign key" ]] || false +} diff --git a/integration-tests/bats/import-update-tables.bats b/integration-tests/bats/import-update-tables.bats index 3944bc6485..fe65b54737 100644 --- a/integration-tests/bats/import-update-tables.bats +++ b/integration-tests/bats/import-update-tables.bats @@ -85,6 +85,26 @@ SQL INC_DATA_YEAR,NIBRS_MONTH_ID,AGENCY_ID,MONTH_NUM,DATA_YEAR,REPORTED_STATUS,REPORT_DATE,UPDATE_FLAG,ORIG_FORMAT,DATA_HOME,DDOCNAME,DID,MONTH_PUB_STATUS,STATE_ID,AGENCY_TABLE_TYPE_ID 2019,9128595,9305,3,2019,I,2019-07-18,Y,F,C,2019_03_MN0510000_NIBRS,49502383,0,27,2 CSV + + dolt sql < objects-good.csv +id,name,color +4,laptop,blue +5,dollar,green +6,bottle,red +DELIM + + run dolt table import -u objects objects-good.csv + [ 
"$status" -eq 0 ] + [[ "$output" =~ "Rows Processed: 3, Additions: 3, Modifications: 0, Had No Effect: 0" ]] || false + + run dolt sql -r csv -q "SELECT * FROM objects where id >= 4" + [ $status -eq 0 ] + [[ "$output" =~ "id,name,color" ]] || false + [[ "$output" =~ "4,laptop,blue" ]] || false + [[ "$output" =~ "5,dollar,green" ]] || false + [[ "$output" =~ "6,bottle,red" ]] || false +} + +@test "import-update-tables: unsuccessfully update child table in fk relationship" { + cat < objects-bad.csv +id,name,color +4,laptop,blue +5,dollar,green +6,bottle,gray +DELIM + + run dolt table import -u objects objects-bad.csv + [ "$status" -eq 1 ] + [[ "$output" =~ "A bad row was encountered while moving data" ]] || false + [[ "$output" =~ "Bad Row: [6,bottle,gray]" ]] || false + [[ "$output" =~ "cannot add or update a child row - Foreign key violation" ]] || false + + run dolt table import -u objects objects-bad.csv --continue + [ "$status" -eq 0 ] + [[ "$output" =~ "The following rows were skipped:" ]] || false + [[ "$output" =~ "[6,bottle,gray]" ]] || false + [[ "$output" =~ "Rows Processed: 2, Additions: 2, Modifications: 0, Had No Effect: 0" ]] || false + + run dolt sql -r csv -q "SELECT * FROM objects where id >= 4" + [ $status -eq 0 ] + [[ "$output" =~ "id,name,color" ]] || false + [[ "$output" =~ "4,laptop,blue" ]] || false + [[ "$output" =~ "5,dollar,green" ]] || false + ! 
[[ "$output" =~ "6,bottle,red" ]] || false +} + +@test "import-update-tables: successfully update child table in multi-key fk relationship " { + skip_nbf_dolt_1 + dolt sql -q "drop table objects" + dolt sql -q "drop table colors" + + dolt sql < multi-key-good.csv +id,name,color,material +4,laptop,red,steel +5,dollar,green,rubber +6,bottle,blue,leather +DELIM + + run dolt table import -u objects multi-key-good.csv + [ "$status" -eq 0 ] + [[ "$output" =~ "Rows Processed: 3, Additions: 3, Modifications: 0, Had No Effect: 0" ]] || false + + run dolt sql -r csv -q "SELECT * FROM objects where id >= 4 ORDER BY id" + [ $status -eq 0 ] + [[ "$output" =~ "id,name,color,material" ]] || false + [[ "$output" =~ "4,laptop,red,steel" ]] || false + [[ "$output" =~ "5,dollar,green,rubber" ]] || false + [[ "$output" =~ "6,bottle,blue,leather" ]] || false + + cat < multi-key-bad.csv +id,name,color,material +4,laptop,red,steel +5,dollar,green,rubber +6,bottle,blue,steel +DELIM + + run dolt table import -u objects multi-key-bad.csv + [ "$status" -eq 1 ] + [[ "$output" =~ "A bad row was encountered while moving data" ]] || false + [[ "$output" =~ "Bad Row: [6,bottle,blue,steel]" ]] || false + [[ "$output" =~ "cannot add or update a child row - Foreign key violation" ]] || false + + run dolt table import -u objects multi-key-bad.csv --continue + [ "$status" -eq 0 ] + [[ "$output" =~ "The following rows were skipped:" ]] || false + [[ "$output" =~ "[6,bottle,blue,steel]" ]] || false + [[ "$output" =~ "Rows Processed: 2, Additions: 0, Modifications: 0, Had No Effect: 2" ]] || false + + run dolt sql -r csv -q "SELECT * FROM objects where id >= 4 ORDER BY id" + [ $status -eq 0 ] + [[ "$output" =~ "id,name,color,material" ]] || false + [[ "$output" =~ "4,laptop,red,steel" ]] || false + [[ "$output" =~ "5,dollar,green,rubber" ]] || false + ! 
[[ "$output" =~ "6,bottle,blue,steel" ]] || false +} + +@test "import-update-tables: import update with CASCADE ON UPDATE" { + skip_nbf_dolt_1 + dolt sql < table-one.csv +pk,v1,v2 +1,2,2 +DELIM + + run dolt table import -u one table-one.csv + [ $status -eq 0 ] + [[ "$output" =~ "Rows Processed: 1, Additions: 0, Modifications: 1, Had No Effect: 0" ]] || false + + run dolt sql -r csv -q "select * from two where pk = 2" + [ $status -eq 0 ] + [[ "$output" =~ "pk,v1,v2" ]] || false + [[ "$output" =~ "2,2,1" ]] || false + + run dolt sql -r csv -q "select * from three where pk = 3" + [ $status -eq 0 ] + [[ "$output" =~ "pk,v1,v2" ]] || false + [[ "$output" =~ "3,2,1" ]] || false +} + +@test "import-update-tables: unsuccessfully update parent table in fk relationship" { + cat < colors-bad.csv +id,color +3,dsadasda +5,yellow +DELIM + + run dolt table import -u colors colors-bad.csv + [ "$status" -eq 1 ] + [[ "$output" =~ "A bad row was encountered while moving data" ]] || false + [[ "$output" =~ "cannot delete or update a parent row" ]] || false + + run dolt table import -u colors colors-bad.csv --continue + [ "$status" -eq 0 ] + + run dolt sql -r csv -q "SELECT * from colors where id in (3,5)" + [ "$status" -eq 0 ] + [[ "$output" =~ "id,color" ]] || false + [[ "$output" =~ "3,blue" ]] || false + [[ "$output" =~ "5,yellow" ]] || false +} + +@test "import-update-tables: circular foreign keys" { + dolt sql < circular-keys-good.csv +id,v1,v2 +4,4,2 +DELIM + + run dolt table import -u tbl circular-keys-good.csv + [ $status -eq 0 ] + [[ "$output" =~ "Rows Processed: 1, Additions: 1, Modifications: 0, Had No Effect: 0" ]] || false + + cat < circular-keys-bad.csv +id,v1,v2 +5,5,1 +6,6,1000 +DELIM + + run dolt table import -u tbl circular-keys-bad.csv + [ $status -eq 1 ] + [[ "$output" =~ "A bad row was encountered while moving data" ]] || false + [[ "$output" =~ "cannot add or update a child row" ]] || false +} + +@test "import-update-tables: disable foreign key checks" { + 
skip_nbf_dolt_1 + cat < objects-bad.csv +id,name,color +4,laptop,blue +5,dollar,green +6,bottle,gray +DELIM + + run dolt table import -u objects objects-bad.csv --disable-fk-checks + [ "$status" -eq 0 ] + [[ "$output" =~ "Rows Processed: 3, Additions: 3, Modifications: 0, Had No Effect: 0" ]] || false + + run dolt sql -r csv -q "select * from objects where id = 6" + [ "$status" -eq 0 ] + [[ "$output" =~ "6,bottle,gray" ]] || false + + run dolt constraints verify objects + [ "$status" -eq 1 ] + [[ "$output" =~ "All constraints are not satisfied" ]] || false +} diff --git a/integration-tests/bats/multidb.bats b/integration-tests/bats/multidb.bats new file mode 100644 index 0000000000..98fbd3642c --- /dev/null +++ b/integration-tests/bats/multidb.bats @@ -0,0 +1,35 @@ +#!/usr/bin/env bats +load $BATS_TEST_DIRNAME/helper/common.bash +load $BATS_TEST_DIRNAME/helper/query-server-common.bash + +setup() { + setup_common + TMPDIRS=$(pwd)/tmpdirs + + init_helper $TMPDIRS + cd $TMPDIRS +} + +init_helper() { + TMPDIRS=$1 + mkdir -p "${TMPDIRS}/dbs1" + for i in {1..2}; do + mkdir "${TMPDIRS}/dbs1/repo${i}" + cd "${TMPDIRS}/dbs1/repo${i}" + dolt init + done +} + +teardown() { + stop_sql_server + teardown_common + rm -rf $TMPDIRS + cd $BATS_TMPDIR +} + +@test "multidb: database default branches" { + cd dbs1 + start_multi_db_server repo1 + multi_query repo1 1 "create database new; use new; call dcheckout('-b', 'feat'); create table t (x int); call dcommit('-am', 'cm'); set @@global.new_default_branch='feat'" + server_query repo1 1 "use repo1" +} diff --git a/integration-tests/bats/sql-server.bats b/integration-tests/bats/sql-server.bats index 4590975c1d..8e27ebeddb 100644 --- a/integration-tests/bats/sql-server.bats +++ b/integration-tests/bats/sql-server.bats @@ -755,8 +755,8 @@ SQL INSERT INTO t VALUES (2,2),(3,3);' "" server_query repo1 1 "SHOW tables" "" # no tables on main - server_query repo1 1 "set GLOBAL dolt_default_branch = 'refs/heads/new';" "" - server_query repo1 1 
"select @@GLOBAL.dolt_default_branch;" "@@GLOBAL.dolt_default_branch\nrefs/heads/new" + server_query repo1 1 "set GLOBAL repo1_default_branch = 'refs/heads/new';" "" + server_query repo1 1 "select @@GLOBAL.repo1_default_branch;" "@@GLOBAL.repo1_default_branch\nrefs/heads/new" server_query repo1 1 "select active_branch()" "active_branch()\nnew" server_query repo1 1 "SHOW tables" "Tables_in_repo1\nt" } @@ -775,8 +775,8 @@ SQL INSERT INTO t VALUES (2,2),(3,3);' "" server_query repo1 1 "SHOW tables" "" # no tables on main - server_query repo1 1 "set GLOBAL dolt_default_branch = 'new';" "" - server_query repo1 1 "select @@GLOBAL.dolt_default_branch;" "@@GLOBAL.dolt_default_branch\nnew" + server_query repo1 1 "set GLOBAL repo1_default_branch = 'new';" "" + server_query repo1 1 "select @@GLOBAL.repo1_default_branch;" "@@GLOBAL.repo1_default_branch\nnew" server_query repo1 1 "select active_branch()" "active_branch()\nnew" server_query repo1 1 "SHOW tables" "Tables_in_repo1\nt" } diff --git a/integration-tests/mysql-client-tests/r/rmariadb-test.r b/integration-tests/mysql-client-tests/r/rmariadb-test.r index 60a938efab..7001f80dfb 100644 --- a/integration-tests/mysql-client-tests/r/rmariadb-test.r +++ b/integration-tests/mysql-client-tests/r/rmariadb-test.r @@ -76,3 +76,14 @@ if (!ret) { print("Number of commits is incorrect") quit(1) } + +# Add a failing query and ensure that the connection does not quit. +# cc. https://github.com/dolthub/dolt/issues/3418 +try(dbExecute(conn, "insert into test values (0, 1)"), silent = TRUE) +one <- dbGetQuery(conn, "select 1 as pk") +ret <- one == data.frame(pk=1) +if (!ret) { + print("Number of commits is incorrect") + quit(1) +} +