diff --git a/.github/scripts/fuzzer/get-fuzzer-job-json.sh b/.github/scripts/fuzzer/get-fuzzer-job-json.sh index 3f58131fb7..9373a7488a 100755 --- a/.github/scripts/fuzzer/get-fuzzer-job-json.sh +++ b/.github/scripts/fuzzer/get-fuzzer-job-json.sh @@ -42,7 +42,7 @@ echo ' "--bucket=dolt-fuzzer-runs", "--region=us-west-2", "--version-gate-job", - "--fuzzer-args=merge, --cycles=5" + "--fuzzer-args=basic, --cycles=5" ] } ], diff --git a/.github/workflows/ci-format-repo.yaml b/.github/workflows/ci-format-repo.yaml index a4569b825a..eb0d988c76 100644 --- a/.github/workflows/ci-format-repo.yaml +++ b/.github/workflows/ci-format-repo.yaml @@ -3,7 +3,7 @@ name: Format PR on: pull_request: branches: [ main ] - + workflow_dispatch: jobs: format: name: Format PR diff --git a/.github/workflows/ci-fuzzer.yaml b/.github/workflows/ci-fuzzer.yaml index b47c054e49..c0a42bc111 100644 --- a/.github/workflows/ci-fuzzer.yaml +++ b/.github/workflows/ci-fuzzer.yaml @@ -1,6 +1,7 @@ name: Fuzzer on: + workflow_dispatch: push: paths: - 'go/**' diff --git a/go/cmd/dolt/cli/documentation_helper.go b/go/cmd/dolt/cli/documentation_helper.go index 0da828cd50..6244ff7aad 100644 --- a/go/cmd/dolt/cli/documentation_helper.go +++ b/go/cmd/dolt/cli/documentation_helper.go @@ -90,7 +90,8 @@ func (cmdDoc CommandDocumentation) CmdDocToMd() (string, error) { return "", err } - return templBuffer.String(), nil + ret := strings.Replace(templBuffer.String(), "HEAD~", "HEAD\\~", -1) + return ret, nil } // A struct that represents all the data structures required to create the documentation for a command. diff --git a/go/cmd/dolt/commands/merge.go b/go/cmd/dolt/commands/merge.go index ceeb088c50..32f51b6faa 100644 --- a/go/cmd/dolt/commands/merge.go +++ b/go/cmd/dolt/commands/merge.go @@ -127,9 +127,6 @@ func (cmd MergeCmd) Exec(ctx context.Context, commandStr string, args []string, } } - var root *doltdb.RootValue - root, verr = GetWorkingWithVErr(dEnv) - if verr == nil { mergeActive, err := dEnv.IsMergeActive(ctx) if err != nil { @@ -137,37 +134,7 @@ func (cmd MergeCmd) Exec(ctx context.Context, commandStr string, args []string, return 1 } - // If there are any conflicts or constraint violations then we disallow the merge - hasCnf, err := root.HasConflicts(ctx) - if err != nil { - verr = errhand.BuildDError("error: failed to get conflicts").AddCause(err).Build() - } - hasCV, err := root.HasConstraintViolations(ctx) - if err != nil { - verr = errhand.BuildDError("error: failed to get constraint violations").AddCause(err).Build() - } - - unmergedCnt, err := getUnmergedTableCount(ctx, root) - if err != nil { - cli.PrintErrln(err.Error()) - return 1 - } - - if hasCnf || hasCV { - cli.Printf("error: A merge is already in progress, %d table(s) are unmerged due to conflicts or constraint violations.\n", unmergedCnt) - cli.Println("hint: Fix them up in the working tree, and then use 'dolt add '") - cli.Println("hint: as appropriate to mark resolution and make a commit.") - if hasCnf && hasCV { - cli.Println("fatal: Exiting because of an unresolved conflict and constraint violation.\n" + - "fatal: Use 'dolt conflicts' to investigate and resolve conflicts.") - } else if hasCnf { - cli.Println("fatal: Exiting because of an unresolved conflict.\n" + - "fatal: Use 'dolt conflicts' to investigate and resolve conflicts.") - } else { - cli.Println("fatal: Exiting because of an unresolved constraint violation.") - } - return 1 - } else if mergeActive { + if mergeActive { cli.Println("error: Merging is not possible because you have not committed an active merge.") cli.Println("hint: add affected tables using 'dolt add
' and commit using 'dolt commit -m '") cli.Println("fatal: Exiting because of active merge") @@ -201,7 +168,7 @@ func (cmd MergeCmd) Exec(ctx context.Context, commandStr string, args []string, } spec.Msg = msg - err = mergePrinting(ctx, dEnv, spec) + err = validateMergeSpec(ctx, spec) if err != nil { return handleCommitErr(ctx, dEnv, err, usage) } @@ -213,7 +180,7 @@ func (cmd MergeCmd) Exec(ctx context.Context, commandStr string, args []string, cli.PrintErrln(err.Error()) return 1 } - unmergedCnt, err = getUnmergedTableCount(ctx, wRoot) + unmergedCnt, err := getUnmergedTableCount(ctx, wRoot) if err != nil { cli.PrintErrln(err.Error()) return 1 @@ -294,7 +261,7 @@ func getCommitMessage(ctx context.Context, apr *argparser.ArgParseResults, dEnv return "", nil } -func mergePrinting(ctx context.Context, dEnv *env.DoltEnv, spec *merge.MergeSpec) errhand.VerboseError { +func validateMergeSpec(ctx context.Context, spec *merge.MergeSpec) errhand.VerboseError { if spec.HeadH == spec.MergeH { //TODO - why is this different for merge/pull? // cli.Println("Already up to date.") @@ -307,9 +274,9 @@ func mergePrinting(ctx context.Context, dEnv *env.DoltEnv, spec *merge.MergeSpec if spec.Squash { cli.Println("Squash commit -- not updating HEAD") } - if len(spec.TblNames) != 0 { + if len(spec.StompedTblNames) != 0 { bldr := errhand.BuildDError("error: Your local changes to the following tables would be overwritten by merge:") - for _, tName := range spec.TblNames { + for _, tName := range spec.StompedTblNames { bldr.AddDetails(tName) } bldr.AddDetails("Please commit your changes before you merge.") @@ -336,6 +303,7 @@ func mergePrinting(ctx context.Context, dEnv *env.DoltEnv, spec *merge.MergeSpec } return nil } + func abortMerge(ctx context.Context, doltEnv *env.DoltEnv) errhand.VerboseError { roots, err := doltEnv.Roots(ctx) if err != nil { diff --git a/go/cmd/dolt/commands/pull.go b/go/cmd/dolt/commands/pull.go index bf2bf0ef83..5b4f789aab 100644 --- a/go/cmd/dolt/commands/pull.go +++ b/go/cmd/dolt/commands/pull.go @@ -177,7 +177,7 @@ func pullHelper(ctx context.Context, dEnv *env.DoltEnv, pullSpec *env.PullSpec) } } - err = mergePrinting(ctx, dEnv, mergeSpec) + err = validateMergeSpec(ctx, mergeSpec) if !ok { return nil } diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index 046e4c59bd..3f93a5bcec 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -25,6 +25,7 @@ import ( gms "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/mysql_db" "github.com/dolthub/vitess/go/mysql" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" @@ -33,7 +34,7 @@ import ( "github.com/dolthub/dolt/go/cmd/dolt/commands/engine" "github.com/dolthub/dolt/go/libraries/doltcore/env" _ "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dfunctions" - "github.com/dolthub/dolt/go/libraries/doltcore/sqle/privileges" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/mysql_file_handler" ) // Serve starts a MySQL-compatible server. Returns any errors that were encountered. @@ -165,32 +166,65 @@ func Serve( serverConf.TLSConfig = tlsConfig serverConf.RequireSecureTransport = serverConfig.RequireSecureTransport() - if serverConfig.PrivilegeFilePath() != "" { - privileges.SetFilePath(serverConfig.PrivilegeFilePath()) + // Set mysql.db file path from server + if serverConfig.MySQLDbFilePath() != "" { + mysql_file_handler.SetMySQLDbFilePath(serverConfig.MySQLDbFilePath()) } - users, roles, err := privileges.LoadPrivileges() + + // Load in MySQL Db from file, if it exists + data, err := mysql_file_handler.LoadData() if err != nil { - return err, nil + return nil, err } + + // Use privilege file iff mysql.db file DNE + var users []*mysql_db.User + var roles []*mysql_db.RoleEdge var tempUsers []gms.TemporaryUser - if len(users) == 0 && len(serverConfig.User()) > 0 { - tempUsers = append(tempUsers, gms.TemporaryUser{ - Username: serverConfig.User(), - Password: serverConfig.Password(), - }) + if len(data) == 0 { + // Set privilege file path from server + if serverConfig.PrivilegeFilePath() != "" { + mysql_file_handler.SetPrivilegeFilePath(serverConfig.PrivilegeFilePath()) + } + + // Load privileges from privilege file + users, roles, err = mysql_file_handler.LoadPrivileges() + if err != nil { + return err, nil + } + + // Create temporary users if no privileges in config + if len(users) == 0 && len(serverConfig.User()) > 0 { + tempUsers = append(tempUsers, gms.TemporaryUser{ + Username: serverConfig.User(), + Password: serverConfig.Password(), + }) + } } + + // Create SQL Engine with users sqlEngine, err := engine.NewSqlEngine(ctx, mrEnv, engine.FormatTabular, "", isReadOnly, tempUsers, serverConfig.AutoCommit()) if err != nil { return err, nil } defer sqlEngine.Close() - sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.GrantTables.SetPersistCallback(privileges.SavePrivileges) - err = sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.GrantTables.LoadData(sql.NewEmptyContext(), users, roles) + // Load in MySQL DB information + err = sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb.LoadData(sql.NewEmptyContext(), data) if err != nil { return err, nil } + // Load in Privilege data iff mysql db didn't exist + if len(data) == 0 { + err = sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb.LoadPrivilegeData(sql.NewEmptyContext(), users, roles) + if err != nil { + return err, nil + } + } + + // Set persist callbacks + sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb.SetPersistCallback(mysql_file_handler.SaveData) labels := serverConfig.MetricsLabels() listener := newMetricsListener(labels) defer listener.Close() diff --git a/go/cmd/dolt/commands/sqlserver/serverconfig.go b/go/cmd/dolt/commands/sqlserver/serverconfig.go index 591b95447b..eb92dcf01c 100644 --- a/go/cmd/dolt/commands/sqlserver/serverconfig.go +++ b/go/cmd/dolt/commands/sqlserver/serverconfig.go @@ -49,6 +49,8 @@ const ( defaultDataDir = "." defaultMetricsHost = "" defaultMetricsPort = -1 + defaultMySQLDbFilePath = "mysql.db" + defaultPrivilegeFilePath = "privs.json" ) const ( @@ -125,6 +127,8 @@ type ServerConfig interface { // PrivilegeFilePath returns the path to the file which contains all needed privilege information in the form of a // JSON string. PrivilegeFilePath() string + // MySQLDbFilePath returns the path to the file which contains the information for a MySQL db. + MySQLDbFilePath() string } type commandLineServerConfig struct { @@ -145,6 +149,7 @@ type commandLineServerConfig struct { requireSecureTransport bool persistenceBehavior string privilegeFilePath string + mysqlDbFilePath string } var _ ServerConfig = (*commandLineServerConfig)(nil) @@ -241,6 +246,10 @@ func (cfg *commandLineServerConfig) PrivilegeFilePath() string { return cfg.privilegeFilePath } +func (cfg *commandLineServerConfig) MySQLDbFilePath() string { + return cfg.mysqlDbFilePath +} + // DatabaseNamesAndPaths returns an array of env.EnvNameAndPathObjects corresponding to the databases to be loaded in // a multiple db configuration. If nil is returned the server will look for a database in the current directory and // give it a name automatically. @@ -342,6 +351,8 @@ func DefaultServerConfig() *commandLineServerConfig { queryParallelism: defaultQueryParallelism, persistenceBehavior: defaultPersistenceBahavior, dataDir: defaultDataDir, + privilegeFilePath: defaultPrivilegeFilePath, + mysqlDbFilePath: defaultMySQLDbFilePath, } } diff --git a/go/cmd/dolt/commands/sqlserver/yaml_config.go b/go/cmd/dolt/commands/sqlserver/yaml_config.go index d38537950c..2a9d63716a 100644 --- a/go/cmd/dolt/commands/sqlserver/yaml_config.go +++ b/go/cmd/dolt/commands/sqlserver/yaml_config.go @@ -118,6 +118,7 @@ type YAMLConfig struct { DataDirStr *string `yaml:"data_dir"` MetricsConfig MetricsYAMLConfig `yaml:"metrics"` PrivilegeFile *string `yaml:"privilege_file"` + MySQLDbFile *string `yaml:"mysql_db_file"` } var _ ServerConfig = YAMLConfig{} @@ -324,6 +325,13 @@ func (cfg YAMLConfig) PrivilegeFilePath() string { return "" } +func (cfg YAMLConfig) MySQLDbFilePath() string { + if cfg.MySQLDbFile != nil { + return *cfg.MySQLDbFile + } + return "" +} + // QueryParallelism returns the parallelism that should be used by the go-mysql-server analyzer func (cfg YAMLConfig) QueryParallelism() int { if cfg.PerformanceConfig.QueryParallelism == nil { diff --git a/go/gen/fb/serial/encoding.go b/go/gen/fb/serial/encoding.go index cb5f6626e8..19d0345757 100644 --- a/go/gen/fb/serial/encoding.go +++ b/go/gen/fb/serial/encoding.go @@ -34,11 +34,14 @@ const ( EncodingUint64 Encoding = 10 EncodingFloat32 Encoding = 11 EncodingFloat64 Encoding = 12 - EncodingHash128 Encoding = 13 - EncodingYear Encoding = 14 - EncodingDate Encoding = 15 - EncodingTime Encoding = 16 - EncodingDatetime Encoding = 17 + EncodingBit64 Encoding = 13 + EncodingHash128 Encoding = 14 + EncodingYear Encoding = 15 + EncodingDate Encoding = 16 + EncodingTime Encoding = 17 + EncodingDatetime Encoding = 18 + EncodingEnum Encoding = 19 + EncodingSet Encoding = 20 EncodingString Encoding = 128 EncodingBytes Encoding = 129 EncodingDecimal Encoding = 130 @@ -58,11 +61,14 @@ var EnumNamesEncoding = map[Encoding]string{ EncodingUint64: "Uint64", EncodingFloat32: "Float32", EncodingFloat64: "Float64", + EncodingBit64: "Bit64", EncodingHash128: "Hash128", EncodingYear: "Year", EncodingDate: "Date", EncodingTime: "Time", EncodingDatetime: "Datetime", + EncodingEnum: "Enum", + EncodingSet: "Set", EncodingString: "String", EncodingBytes: "Bytes", EncodingDecimal: "Decimal", @@ -82,11 +88,14 @@ var EnumValuesEncoding = map[string]Encoding{ "Uint64": EncodingUint64, "Float32": EncodingFloat32, "Float64": EncodingFloat64, + "Bit64": EncodingBit64, "Hash128": EncodingHash128, "Year": EncodingYear, "Date": EncodingDate, "Time": EncodingTime, "Datetime": EncodingDatetime, + "Enum": EncodingEnum, + "Set": EncodingSet, "String": EncodingString, "Bytes": EncodingBytes, "Decimal": EncodingDecimal, diff --git a/go/go.mod b/go/go.mod index 624214ae62..67076e4603 100644 --- a/go/go.mod +++ b/go/go.mod @@ -19,7 +19,7 @@ require ( github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371 github.com/dolthub/mmap-go v1.0.4-0.20201107010347-f9f2a9588a66 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20220517011201-8f50d80eae58 + github.com/dolthub/vitess v0.0.0-20220525003637-9c94a4060dd1 github.com/dustin/go-humanize v1.0.0 github.com/fatih/color v1.9.0 github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 @@ -68,8 +68,8 @@ require ( ) require ( - github.com/dolthub/go-mysql-server v0.11.1-0.20220520215413-e432fd42d22f - github.com/google/flatbuffers v2.0.5+incompatible + github.com/dolthub/go-mysql-server v0.11.1-0.20220531182937-257f07bd27e5 + github.com/google/flatbuffers v2.0.6+incompatible github.com/gosuri/uilive v0.0.4 github.com/kch42/buzhash v0.0.0-20160816060738-9bdec3dec7c6 github.com/prometheus/client_golang v1.11.0 @@ -135,7 +135,7 @@ require ( golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20210506142907-4a47615972c2 // indirect - gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 // indirect + gopkg.in/yaml.v3 v3.0.0 // indirect ) replace ( diff --git a/go/go.sum b/go/go.sum index 369fa4fd20..97e73ee99b 100755 --- a/go/go.sum +++ b/go/go.sum @@ -178,8 +178,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= -github.com/dolthub/go-mysql-server v0.11.1-0.20220520215413-e432fd42d22f h1:TeqkrPthoXq/CZ5o6cM7AaIIxzMmiHzgK2H+16dOcg0= -github.com/dolthub/go-mysql-server v0.11.1-0.20220520215413-e432fd42d22f/go.mod h1:h0gpkn07YqshhXbeNkOfII0uV+I37SJYyvccH77+FOk= +github.com/dolthub/go-mysql-server v0.11.1-0.20220531182937-257f07bd27e5 h1:EuTulidBelA0x5c3OqwkC4yuNfnodxJGsGnjSPghPVQ= +github.com/dolthub/go-mysql-server v0.11.1-0.20220531182937-257f07bd27e5/go.mod h1:t8kUmFCl4oCVkMkRxgf7qROSn+5lQsFAUU5TZdoleI8= github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371 h1:oyPHJlzumKta1vnOQqUnfdz+pk3EmnHS3Nd0cCT0I2g= github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371/go.mod h1:dhGBqcCEfK5kuFmeO5+WOx3hqc1k3M29c1oS/R7N4ms= github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0NvhiEsctylXinUMFhhsqaEcl414p8= @@ -188,8 +188,8 @@ github.com/dolthub/mmap-go v1.0.4-0.20201107010347-f9f2a9588a66 h1:WRPDbpJWEnPxP github.com/dolthub/mmap-go v1.0.4-0.20201107010347-f9f2a9588a66/go.mod h1:N5ZIbMGuDUpTpOFQ7HcsN6WSIpTGQjHP+Mz27AfmAgk= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20220517011201-8f50d80eae58 h1:v7uMbJKhb9zi2Nz3pxDOUVfWO30E5wbSckVq7AjgXRw= -github.com/dolthub/vitess v0.0.0-20220517011201-8f50d80eae58/go.mod h1:jxgvpEvrTNw2i4BKlwT75E775eUXBeMv5MPeQkIb9zI= +github.com/dolthub/vitess v0.0.0-20220525003637-9c94a4060dd1 h1:lwzjI/92DnlmpgNqK+KV0oC31BQ/r6VE6RqDJAcb3GY= +github.com/dolthub/vitess v0.0.0-20220525003637-9c94a4060dd1/go.mod h1:jxgvpEvrTNw2i4BKlwT75E775eUXBeMv5MPeQkIb9zI= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= @@ -357,8 +357,8 @@ github.com/golangci/revgrep v0.0.0-20180526074752-d9c87f5ffaf0/go.mod h1:qOQCunE github.com/golangci/unconvert v0.0.0-20180507085042-28b1c447d1f4/go.mod h1:Izgrg8RkN3rCIMLGE9CyYmU9pY2Jer6DgANEnZ/L/cQ= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/flatbuffers v2.0.5+incompatible h1:ANsW0idDAXIY+mNHzIHxWRfabV2x5LUEEIIWcwsYgB8= -github.com/google/flatbuffers v2.0.5+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/flatbuffers v2.0.6+incompatible h1:XHFReMv7nFFusa+CEokzWbzaYocKXI6C7hdU5Kgh9Lw= +github.com/google/flatbuffers v2.0.6+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -1271,8 +1271,8 @@ gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 h1:tQIYjPdBoyREyB9XMu+nnTclpTYkz2zFM+lzLJFO4gQ= -gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0 h1:hjy8E9ON/egN1tAYqKb61G10WtihqetD4sz2H+8nIeA= +gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/go/libraries/doltcore/doltdb/durable/table.go b/go/libraries/doltcore/doltdb/durable/table.go index dd9f78262f..0290dcac1e 100644 --- a/go/libraries/doltcore/doltdb/durable/table.go +++ b/go/libraries/doltcore/doltdb/durable/table.go @@ -31,6 +31,7 @@ import ( "github.com/dolthub/dolt/go/store/hash" "github.com/dolthub/dolt/go/store/pool" "github.com/dolthub/dolt/go/store/prolly" + "github.com/dolthub/dolt/go/store/prolly/tree" "github.com/dolthub/dolt/go/store/types" ) @@ -201,6 +202,7 @@ func RefFromNomsTable(ctx context.Context, table Table) (types.Ref, error) { return refFromNomsValue(ctx, nt.vrw, nt.tableStruct) } ddt := table.(doltDevTable) + return refFromNomsValue(ctx, ddt.vrw, ddt.nomsValue()) } @@ -574,6 +576,17 @@ func (t nomsTable) DebugString(ctx context.Context) string { } } + buf.WriteString("\ndata:\n") + data, err := t.GetTableRows(ctx) + if err != nil { + panic(err) + } + + err = types.WriteEncodedValue(ctx, &buf, NomsMapFromIndex(data)) + if err != nil { + panic(err) + } + return buf.String() } @@ -625,7 +638,24 @@ type doltDevTable struct { } func (t doltDevTable) DebugString(ctx context.Context) string { - return "doltDevTable has no DebugString" // TODO: fill in + rows, err := t.GetTableRows(ctx) + if err != nil { + panic(err) + } + + if t.vrw.Format() == types.Format_DOLT_DEV { + m := NomsMapFromIndex(rows) + var b bytes.Buffer + _ = types.WriteEncodedValue(ctx, &b, m) + return b.String() + } else { + m := ProllyMapFromIndex(rows) + var b bytes.Buffer + m.WalkNodes(ctx, func(ctx context.Context, nd tree.Node) error { + return tree.OutputProllyNode(&b, nd) + }) + return b.String() + } } var _ Table = doltDevTable{} diff --git a/go/libraries/doltcore/doltdb/errors.go b/go/libraries/doltcore/doltdb/errors.go index 2cf51f950a..ecb806bf1d 100644 --- a/go/libraries/doltcore/doltdb/errors.go +++ b/go/libraries/doltcore/doltdb/errors.go @@ -46,8 +46,7 @@ var ErrUpToDate = errors.New("up to date") var ErrIsAhead = errors.New("current fast forward from a to b. a is ahead of b already") var ErrIsBehind = errors.New("cannot reverse from b to a. b is a is behind a already") -var ErrUnresolvedConflicts = errors.New("merge has unresolved conflicts. please use the dolt_conflicts table to resolve") -var ErrUnresolvedConstraintViolations = errors.New("merge has unresolved constraint violations. please use the dolt_constraint_violations table to resolve") +var ErrUnresolvedConflictsOrViolations = errors.New("merge has unresolved conflicts or constraint violations") var ErrMergeActive = errors.New("merging is not possible because you have not committed an active merge") type ErrClientOutOfDate struct { diff --git a/go/libraries/doltcore/doltdb/foreign_key_coll.go b/go/libraries/doltcore/doltdb/foreign_key_coll.go index 9296c1b22d..54bf1ed177 100644 --- a/go/libraries/doltcore/doltdb/foreign_key_coll.go +++ b/go/libraries/doltcore/doltdb/foreign_key_coll.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "encoding/binary" + "errors" "fmt" "io" "sort" @@ -316,6 +317,37 @@ func (fkc *ForeignKeyCollection) GetByNameCaseInsensitive(foreignKeyName string) return ForeignKey{}, false } +type FkIndexUpdate struct { + FkName string + FromIdx string + ToIdx string +} + +// UpdateIndexes updates the indexes used by the foreign keys as outlined by the update structs given. All foreign +// keys updated in this manner must belong to the same table, whose schema is provided. +func (fkc *ForeignKeyCollection) UpdateIndexes(ctx context.Context, tableSchema schema.Schema, updates []FkIndexUpdate) error { + for _, u := range updates { + fk, ok := fkc.GetByNameCaseInsensitive(u.FkName) + if !ok { + return errors.New("foreign key not found") + } + fkc.RemoveKeys(fk) + fk.ReferencedTableIndex = u.ToIdx + + err := fkc.AddKeys(fk) + if err != nil { + return err + } + + err = fk.ValidateReferencedTableSchema(tableSchema) + if err != nil { + return err + } + } + + return nil +} + // GetByTags gets the ForeignKey defined over the parent and child columns corresponding to their tags. func (fkc *ForeignKeyCollection) GetByTags(childTags, parentTags []uint64) (ForeignKey, bool) { if len(childTags) == 0 || len(parentTags) == 0 { diff --git a/go/libraries/doltcore/doltdb/root_val.go b/go/libraries/doltcore/doltdb/root_val.go index 8446e9f18f..998b6ea384 100644 --- a/go/libraries/doltcore/doltdb/root_val.go +++ b/go/libraries/doltcore/doltdb/root_val.go @@ -1326,21 +1326,10 @@ func (root *RootValue) DebugString(ctx context.Context, transitive bool) string root.IterTables(ctx, func(name string, table *Table, sch schema.Schema) (stop bool, err error) { buf.WriteString("\nTable ") buf.WriteString(name) - buf.WriteString("\n") + buf.WriteString(":\n") - buf.WriteString("Struct:\n") buf.WriteString(table.DebugString(ctx)) - buf.WriteString("\ndata:\n") - data, err := table.GetNomsRowData(ctx) - if err != nil { - panic(err) - } - - err = types.WriteEncodedValue(ctx, &buf, data) - if err != nil { - panic(err) - } return false, nil }) } diff --git a/go/libraries/doltcore/env/actions/commit.go b/go/libraries/doltcore/env/actions/commit.go index 11214fdf4b..1b26c0ec61 100644 --- a/go/libraries/doltcore/env/actions/commit.go +++ b/go/libraries/doltcore/env/actions/commit.go @@ -16,14 +16,12 @@ package actions import ( "context" - "sort" "time" "github.com/dolthub/dolt/go/libraries/doltcore/diff" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/store/datas" - "github.com/dolthub/dolt/go/store/hash" ) type CommitStagedProps struct { @@ -229,85 +227,3 @@ func GetCommitStaged( return ddb.NewPendingCommit(ctx, roots, rsr.CWBHeadRef(), mergeParents, meta) } - -// TimeSortedCommits returns a reverse-chronological (latest-first) list of the most recent `n` ancestors of `commit`. -// Passing a negative value for `n` will result in all ancestors being returned. -func TimeSortedCommits(ctx context.Context, ddb *doltdb.DoltDB, commit *doltdb.Commit, n int) ([]*doltdb.Commit, error) { - hashToCommit := make(map[hash.Hash]*doltdb.Commit) - err := AddCommits(ctx, ddb, commit, hashToCommit, n) - - if err != nil { - return nil, err - } - - idx := 0 - uniqueCommits := make([]*doltdb.Commit, len(hashToCommit)) - for _, v := range hashToCommit { - uniqueCommits[idx] = v - idx++ - } - - var sortErr error - var metaI, metaJ *datas.CommitMeta - sort.Slice(uniqueCommits, func(i, j int) bool { - if sortErr != nil { - return false - } - - metaI, sortErr = uniqueCommits[i].GetCommitMeta(ctx) - - if sortErr != nil { - return false - } - - metaJ, sortErr = uniqueCommits[j].GetCommitMeta(ctx) - - if sortErr != nil { - return false - } - - return metaI.UserTimestamp > metaJ.UserTimestamp - }) - - if sortErr != nil { - return nil, sortErr - } - - return uniqueCommits, nil -} - -func AddCommits(ctx context.Context, ddb *doltdb.DoltDB, commit *doltdb.Commit, hashToCommit map[hash.Hash]*doltdb.Commit, n int) error { - hash, err := commit.HashOf() - - if err != nil { - return err - } - - if _, ok := hashToCommit[hash]; ok { - return nil - } - - hashToCommit[hash] = commit - - numParents, err := commit.NumParents() - - if err != nil { - return err - } - - for i := 0; i < numParents && len(hashToCommit) != n; i++ { - parentCommit, err := ddb.ResolveParent(ctx, commit, i) - - if err != nil { - return err - } - - err = AddCommits(ctx, ddb, parentCommit, hashToCommit, n) - - if err != nil { - return err - } - } - - return nil -} diff --git a/go/libraries/doltcore/merge/action.go b/go/libraries/doltcore/merge/action.go index 0c7ceb112a..c2db042448 100644 --- a/go/libraries/doltcore/merge/action.go +++ b/go/libraries/doltcore/merge/action.go @@ -35,73 +35,73 @@ var ErrMergeFailedToUpdateRepoState = errors.New("unable to execute repo state u var ErrFailedToDetermineMergeability = errors.New("failed to determine mergeability") type MergeSpec struct { - HeadH hash.Hash - MergeH hash.Hash - HeadC *doltdb.Commit - MergeC *doltdb.Commit - TblNames []string - WorkingDiffs map[string]hash.Hash - Squash bool - Msg string - Noff bool - Force bool - AllowEmpty bool - Email string - Name string - Date time.Time + HeadH hash.Hash + MergeH hash.Hash + HeadC *doltdb.Commit + MergeC *doltdb.Commit + StompedTblNames []string + WorkingDiffs map[string]hash.Hash + Squash bool + Msg string + Noff bool + Force bool + AllowEmpty bool + Email string + Name string + Date time.Time } func NewMergeSpec(ctx context.Context, rsr env.RepoStateReader, ddb *doltdb.DoltDB, roots doltdb.Roots, name, email, msg string, commitSpecStr string, squash bool, noff bool, force bool, date time.Time) (*MergeSpec, bool, error) { - cs1, err := doltdb.NewCommitSpec("HEAD") + headCS, err := doltdb.NewCommitSpec("HEAD") if err != nil { return nil, false, err } - cm1, err := ddb.Resolve(context.TODO(), cs1, rsr.CWBHeadRef()) + headCM, err := ddb.Resolve(context.TODO(), headCS, rsr.CWBHeadRef()) if err != nil { return nil, false, err } - cs2, err := doltdb.NewCommitSpec(commitSpecStr) + mergeCS, err := doltdb.NewCommitSpec(commitSpecStr) if err != nil { return nil, false, err } - cm2, err := ddb.Resolve(context.TODO(), cs2, rsr.CWBHeadRef()) + mergeCM, err := ddb.Resolve(context.TODO(), mergeCS, rsr.CWBHeadRef()) if err != nil { return nil, false, err } - h1, err := cm1.HashOf() + headH, err := headCM.HashOf() if err != nil { return nil, false, err } - h2, err := cm2.HashOf() + mergeH, err := mergeCM.HashOf() if err != nil { return nil, false, err } - tblNames, workingDiffs, err := MergeWouldStompChanges(ctx, roots, cm2) + stompedTblNames, workingDiffs, err := MergeWouldStompChanges(ctx, roots, mergeCM) if err != nil { return nil, false, fmt.Errorf("%w; %s", ErrFailedToDetermineMergeability, err.Error()) } return &MergeSpec{ - HeadH: h1, - MergeH: h2, - HeadC: cm1, - MergeC: cm2, - TblNames: tblNames, - WorkingDiffs: workingDiffs, - Squash: squash, - Msg: msg, - Noff: noff, - Force: force, - Email: email, - Name: name, - Date: date, + HeadH: headH, + MergeH: mergeH, + HeadC: headCM, + MergeC: mergeCM, + StompedTblNames: stompedTblNames, + WorkingDiffs: workingDiffs, + Squash: squash, + Msg: msg, + Noff: noff, + Force: force, + Email: email, + Name: name, + Date: date, }, true, nil } diff --git a/go/libraries/doltcore/merge/merge.go b/go/libraries/doltcore/merge/merge.go index 8eee4e3acd..d1fa989f19 100644 --- a/go/libraries/doltcore/merge/merge.go +++ b/go/libraries/doltcore/merge/merge.go @@ -43,6 +43,13 @@ var ErrFastForward = errors.New("fast forward") var ErrSameTblAddedTwice = errors.New("table with same name added in 2 commits can't be merged") var ErrTableDeletedAndModified = errors.New("conflict: table with same name deleted and modified ") +// ErrCantOverwriteConflicts is returned when there are unresolved conflicts +// and the merge produces new conflicts. Because we currently don't have a model +// to merge sets of conflicts together, we need to abort the merge at this +// point. +var ErrCantOverwriteConflicts = errors.New("existing unresolved conflicts would be" + + " overridden by new conflicts produced by merge. Please resolve them and try again") + type Merger struct { root *doltdb.RootValue mergeRoot *doltdb.RootValue @@ -893,7 +900,26 @@ func MergeCommits(ctx context.Context, commit, mergeCommit *doltdb.Commit, opts return MergeRoots(ctx, ourRoot, theirRoot, ancRoot, opts) } +// MergeRoots three-way merges |ourRoot|, |theirRoot|, and |ancRoot| and returns +// the merged root. If any conflicts or constraint violations are produced they +// are stored in the merged root. If |ourRoot| already contains conflicts they +// are stashed before the merge is performed. We abort the merge if the stash +// contains conflicts and we produce new conflicts. We currently don't have a +// model to merge conflicts together. +// +// Constraint violations that exist in ancestor are stashed and merged with the +// violations we detect when we diff the ancestor and the newly merged root. func MergeRoots(ctx context.Context, ourRoot, theirRoot, ancRoot *doltdb.RootValue, opts editor.Options) (*doltdb.RootValue, map[string]*MergeStats, error) { + ourRoot, conflictStash, err := stashConflicts(ctx, ourRoot) + if err != nil { + return nil, nil, err + } + + ancRoot, violationStash, err := stashViolations(ctx, ancRoot) + if err != nil { + return nil, nil, err + } + tblNames, err := doltdb.UnionTableNames(ctx, ourRoot, theirRoot) if err != nil { @@ -902,7 +928,7 @@ func MergeRoots(ctx context.Context, ourRoot, theirRoot, ancRoot *doltdb.RootVal tblToStats := make(map[string]*MergeStats) - newRoot := ourRoot + mergedRoot := ourRoot optsWithFKChecks := opts optsWithFKChecks.ForeignKeyChecksDisabled = true @@ -919,14 +945,14 @@ func MergeRoots(ctx context.Context, ourRoot, theirRoot, ancRoot *doltdb.RootVal if mergedTable != nil { tblToStats[tblName] = stats - newRoot, err = newRoot.PutTable(ctx, tblName, mergedTable) + mergedRoot, err = mergedRoot.PutTable(ctx, tblName, mergedTable) if err != nil { return nil, nil, err } continue } - newRootHasTable, err := newRoot.HasTable(ctx, tblName) + newRootHasTable, err := mergedRoot.HasTable(ctx, tblName) if err != nil { return nil, nil, err } @@ -935,7 +961,7 @@ func MergeRoots(ctx context.Context, ourRoot, theirRoot, ancRoot *doltdb.RootVal // Merge root deleted this table tblToStats[tblName] = &MergeStats{Operation: TableRemoved} - newRoot, err = newRoot.RemoveTables(ctx, false, false, tblName) + mergedRoot, err = mergedRoot.RemoveTables(ctx, false, false, tblName) if err != nil { return nil, nil, err } @@ -949,7 +975,7 @@ func MergeRoots(ctx context.Context, ourRoot, theirRoot, ancRoot *doltdb.RootVal } } - mergedFKColl, conflicts, err := ForeignKeysMerge(ctx, newRoot, ourRoot, theirRoot, ancRoot) + mergedFKColl, conflicts, err := ForeignKeysMerge(ctx, mergedRoot, ourRoot, theirRoot, ancRoot) if err != nil { return nil, nil, err } @@ -957,31 +983,94 @@ func MergeRoots(ctx context.Context, ourRoot, theirRoot, ancRoot *doltdb.RootVal return nil, nil, fmt.Errorf("foreign key conflicts") } - newRoot, err = newRoot.PutForeignKeyCollection(ctx, mergedFKColl) + mergedRoot, err = mergedRoot.PutForeignKeyCollection(ctx, mergedFKColl) if err != nil { return nil, nil, err } - newRoot, err = newRoot.UpdateSuperSchemasFromOther(ctx, tblNames, theirRoot) + mergedRoot, err = mergedRoot.UpdateSuperSchemasFromOther(ctx, tblNames, theirRoot) if err != nil { return nil, nil, err } - newRoot, _, err = AddConstraintViolations(ctx, newRoot, ancRoot, nil) + mergedRoot, _, err = AddConstraintViolations(ctx, mergedRoot, ancRoot, nil) if err != nil { return nil, nil, err } - err = calculateViolationStats(ctx, newRoot, tblToStats) + mergedRoot, err = mergeCVsWithStash(ctx, mergedRoot, violationStash) if err != nil { return nil, nil, err } - return newRoot, tblToStats, nil + err = calculateViolationStats(ctx, mergedRoot, tblToStats) + if err != nil { + return nil, nil, err + } + + mergedHasConflicts := checkForConflicts(tblToStats) + if !conflictStash.Empty() && mergedHasConflicts { + return nil, nil, ErrCantOverwriteConflicts + } else if !conflictStash.Empty() { + mergedRoot, err = applyConflictStash(ctx, conflictStash.Stash, mergedRoot) + if err != nil { + return nil, nil, err + } + } + + return mergedRoot, tblToStats, nil } -func calculateViolationStats(ctx context.Context, root *doltdb.RootValue, tblToStats map[string]*MergeStats) error { +// mergeCVsWithStash merges the table constraint violations in |stash| with |root|. +// Returns an updated root with all the merged CVs. +func mergeCVsWithStash(ctx context.Context, root *doltdb.RootValue, stash *violationStash) (*doltdb.RootValue, error) { + updatedRoot := root + for name, stashed := range stash.Stash { + tbl, ok, err := root.GetTable(ctx, name) + if err != nil { + return nil, err + } + if !ok { + // the table with the CVs was deleted + continue + } + curr, err := tbl.GetConstraintViolations(ctx) + if err != nil { + return nil, err + } + unioned, err := types.UnionMaps(ctx, curr, stashed, func(key types.Value, currV types.Value, stashV types.Value) (types.Value, error) { + if !currV.Equals(stashV) { + panic(fmt.Sprintf("encountered conflict when merging constraint violations, conflicted key: %v\ncurrent value: %v\nstashed value: %v\n", key, currV, stashV)) + } + return currV, nil + }) + if err != nil { + return nil, err + } + tbl, err = tbl.SetConstraintViolations(ctx, unioned) + if err != nil { + return nil, err + } + updatedRoot, err = root.PutTable(ctx, name, tbl) + if err != nil { + return nil, err + } + } + return updatedRoot, nil +} +// checks if a conflict occurred during the merge +func checkForConflicts(tblToStats map[string]*MergeStats) bool { + for _, stat := range tblToStats { + if stat.Conflicts > 0 { + return true + } + } + return false +} + +// populates tblToStats with violation statistics +func calculateViolationStats(ctx context.Context, root *doltdb.RootValue, tblToStats map[string]*MergeStats) error { for tblName, stats := range tblToStats { tbl, ok, err := root.GetTable(ctx, tblName) if err != nil { diff --git a/go/libraries/doltcore/merge/stash.go b/go/libraries/doltcore/merge/stash.go new file mode 100644 index 0000000000..ca7b6abd7d --- /dev/null +++ b/go/libraries/doltcore/merge/stash.go @@ -0,0 +1,182 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package merge + +import ( + "context" + + "github.com/dolthub/dolt/go/libraries/doltcore/conflict" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable" + "github.com/dolthub/dolt/go/store/types" +) + +type conflictStash struct { + Stash map[string]*conflictData +} + +type conflictData struct { + HasConflicts bool + Sch conflict.ConflictSchema + ConfIdx durable.ConflictIndex +} + +// Empty returns false if any table has a conflict. +// True otherwise. +func (s *conflictStash) Empty() bool { + for _, data := range s.Stash { + if data.HasConflicts { + return false + } + } + return true +} + +type violationStash struct { + // todo: durable + Stash map[string]types.Map +} + +// Empty returns false if any table has constraint violations. +// True otherwise. +func (s *violationStash) Empty() bool { + for _, data := range s.Stash { + if data.Len() > 0 { + return false + } + } + return true +} + +func stashConflicts(ctx context.Context, root *doltdb.RootValue) (*doltdb.RootValue, *conflictStash, error) { + names, err := root.GetTableNames(ctx) + if err != nil { + return nil, nil, err + } + + updatedRoot := root + stash := make(map[string]*conflictData, len(names)) + for _, name := range names { + tbl, _, err := root.GetTable(ctx, name) + if err != nil { + return nil, nil, err + } + d, err := getConflictData(ctx, tbl) + if err != nil { + return nil, nil, err + } + stash[name] = d + tbl, err = tbl.ClearConflicts(ctx) + if err != nil { + return nil, nil, err + } + updatedRoot, err = updatedRoot.PutTable(ctx, name, tbl) + if err != nil { + return nil, nil, err + } + } + + return updatedRoot, &conflictStash{stash}, nil +} + +func stashViolations(ctx context.Context, root *doltdb.RootValue) (*doltdb.RootValue, *violationStash, error) { + names, err := root.GetTableNames(ctx) + if err != nil { + return nil, nil, err + } + + updatedRoot := root + stash := make(map[string]types.Map, len(names)) + for _, name := range names { + tbl, _, err := root.GetTable(ctx, name) + if err != nil { + return nil, nil, err + } + v, err := tbl.GetConstraintViolations(ctx) + stash[name] = v + tbl, err = tbl.SetConstraintViolations(ctx, types.EmptyMap) + if err != nil { + return nil, nil, err + } + updatedRoot, err = updatedRoot.PutTable(ctx, name, tbl) + if err != nil { + return nil, nil, err + } + } + + return updatedRoot, &violationStash{stash}, nil +} + +// applyConflictStash applies the data in |stash| to the root value. Missing +// tables will be skipped. This function will override any previous conflict +// data. +func applyConflictStash(ctx context.Context, stash map[string]*conflictData, root *doltdb.RootValue) (*doltdb.RootValue, error) { + updatedRoot := root + for name, data := range stash { + tbl, ok, err := root.GetTable(ctx, name) + if err != nil { + return nil, err + } + if !ok { + continue + } + tbl, err = setConflictData(ctx, tbl, data) + if err != nil { + return nil, err + } + updatedRoot, err = updatedRoot.PutTable(ctx, name, tbl) + if err != nil { + return nil, err + } + } + + return updatedRoot, nil +} + +func getConflictData(ctx context.Context, tbl *doltdb.Table) (*conflictData, error) { + var sch conflict.ConflictSchema + var confIdx durable.ConflictIndex + + hasCnf, err := tbl.HasConflicts(ctx) + if err != nil { + return nil, err + } + if hasCnf { + sch, confIdx, err = tbl.GetConflicts(ctx) + if err != nil { + return nil, err + } + } + + return &conflictData{ + HasConflicts: hasCnf, + Sch: sch, + ConfIdx: confIdx, + }, nil +} + +func setConflictData(ctx context.Context, tbl *doltdb.Table, data *conflictData) (*doltdb.Table, error) { + var err error + if !data.HasConflicts { + tbl, err = tbl.ClearConflicts(ctx) + } else { + tbl, err = tbl.SetConflicts(ctx, data.Sch, data.ConfIdx) + } + if err != nil { + return nil, err + } + + return tbl, nil +} diff --git a/go/libraries/doltcore/sqle/alterschema.go b/go/libraries/doltcore/sqle/alterschema.go index e33a658d15..9f2997782c 100755 --- a/go/libraries/doltcore/sqle/alterschema.go +++ b/go/libraries/doltcore/sqle/alterschema.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "sort" "strings" "github.com/dolthub/go-mysql-server/sql" @@ -778,6 +779,67 @@ func dropColumn(ctx context.Context, tbl *doltdb.Table, colName string) (*doltdb return tbl.UpdateSchema(ctx, newSch) } +// backupFkcIndexesForKeyDrop finds backup indexes to cover foreign key references during a primary +// key drop. If multiple indexes are valid, we sort by unique and select the first. +// This will not work with a non-pk index drop without an additional index filter argument. +func backupFkcIndexesForPkDrop(ctx *sql.Context, sch schema.Schema, fkc *doltdb.ForeignKeyCollection) ([]doltdb.FkIndexUpdate, error) { + indexes := sch.Indexes().AllIndexes() + + // pkBackups is a mapping from the table's PK tags to potentially compensating indexes + pkBackups := make(map[uint64][]schema.Index, len(sch.GetPKCols().TagToIdx)) + for tag, _ := range sch.GetPKCols().TagToIdx { + pkBackups[tag] = nil + } + + // prefer unique key backups + sort.Slice(indexes[:], func(i, j int) bool { + return indexes[i].IsUnique() && !indexes[j].IsUnique() + }) + + for _, idx := range indexes { + if !idx.IsUserDefined() { + continue + } + + for _, tag := range idx.AllTags() { + if _, ok := pkBackups[tag]; ok { + pkBackups[tag] = append(pkBackups[tag], idx) + } + } + } + + fkUpdates := make([]doltdb.FkIndexUpdate, 0) + for _, fk := range fkc.AllKeys() { + // check if this FK references a parent PK tag we are trying to change + if backups, ok := pkBackups[fk.ReferencedTableColumns[0]]; ok { + covered := false + for _, idx := range backups { + idxTags := idx.AllTags() + if len(fk.TableColumns) > len(idxTags) { + continue + } + failed := false + for i := 0; i < len(fk.ReferencedTableColumns); i++ { + if idxTags[i] != fk.ReferencedTableColumns[i] { + failed = true + break + } + } + if failed { + continue + } + fkUpdates = append(fkUpdates, doltdb.FkIndexUpdate{FkName: fk.Name, FromIdx: fk.ReferencedTableIndex, ToIdx: idx.Name()}) + covered = true + break + } + if !covered { + return nil, sql.ErrCantDropIndex.New("PRIMARY") + } + } + } + return fkUpdates, nil +} + func dropPrimaryKeyFromTable(ctx context.Context, table *doltdb.Table, nbf *types.NomsBinFormat, opts editor.Options) (*doltdb.Table, error) { sch, err := table.GetSchema(ctx) if err != nil { diff --git a/go/libraries/doltcore/sqle/alterschema_test.go b/go/libraries/doltcore/sqle/alterschema_test.go index 95174b1c9a..2265d6314b 100755 --- a/go/libraries/doltcore/sqle/alterschema_test.go +++ b/go/libraries/doltcore/sqle/alterschema_test.go @@ -16,7 +16,7 @@ package sqle import ( "context" - "errors" + goerrors "errors" "fmt" "testing" @@ -25,6 +25,7 @@ import ( "github.com/dolthub/vitess/go/sqltypes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils" @@ -445,10 +446,10 @@ func TestDropColumnUsedByIndex(t *testing.T) { func TestDropPks(t *testing.T) { var dropTests = []struct { - name string - setup []string - exit int - fkIdxName string + name string + setup []string + expectedErr *errors.Kind + fkIdxName string }{ { name: "no error on drop pk", @@ -456,7 +457,6 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, primary key (id))", "insert into parent values (1,1,1),(2,2,2)", }, - exit: 0, }, { name: "no error if backup index", @@ -464,7 +464,6 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, primary key (id), key `backup` (id))", "create table child (id int, name varchar(1), age int, primary key (id), constraint `fk` foreign key (id) references parent (id))", }, - exit: 0, fkIdxName: "backup", }, { @@ -473,7 +472,6 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, primary key (id, age), key `backup` (age))", "create table child (id int, name varchar(1), age int, primary key (id), constraint `fk` foreign key (age) references parent (age))", }, - exit: 0, fkIdxName: "backup", }, { @@ -482,7 +480,6 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, primary key (id, age), key `backup` (id, age))", "create table child (id int, name varchar(1), age int, constraint `fk` foreign key (id, age) references parent (id, age))", }, - exit: 0, fkIdxName: "backup", }, { @@ -491,7 +488,6 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, primary key (id, age, name), key `backup` (id, age))", "create table child (id int, name varchar(1), age int, constraint `fk` foreign key (id, age) references parent (id, age))", }, - exit: 0, fkIdxName: "backup", }, { @@ -500,7 +496,6 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, primary key (id, age), key `backup` (id))", "create table child (id int, name varchar(1), age int, constraint `fk` foreign key (id) references parent (id))", }, - exit: 0, fkIdxName: "backup", }, { @@ -509,7 +504,6 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, primary key (id, age), key `bad_backup1` (age, id), key `bad_backup2` (age), key `backup` (id, age, name))", "create table child (id int, name varchar(1), age int, constraint `fk` foreign key (id) references parent (id))", }, - exit: 0, fkIdxName: "backup", }, { @@ -518,7 +512,6 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, primary key (id, age), key `bad_backup` (age, id), key `backup1` (id), key `backup2` (id, age, name))", "create table child (id int, name varchar(1), age int, constraint `fk` foreign key (id) references parent (id))", }, - exit: 0, fkIdxName: "backup1", }, { @@ -527,7 +520,6 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, primary key (id, age), key `bad_backup` (age, id), key `backup1` (id, age, name), unique key `backup2` (id, age))", "create table child (id int, name varchar(1), age int, constraint `fk` foreign key (id) references parent (id))", }, - exit: 0, fkIdxName: "backup2", }, { @@ -536,7 +528,6 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, other int, primary key (id, age, name), key `backup` (id, age, other))", "create table child (id int, name varchar(1), age int, constraint `fk` foreign key (id, age) references parent (id, age))", }, - exit: 0, fkIdxName: "backup", }, { @@ -545,8 +536,8 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, primary key (id))", "create table child (id int, name varchar(1), age int, primary key (id), constraint `fk` foreign key (id) references parent (id))", }, - exit: 1, - fkIdxName: "id", + expectedErr: sql.ErrCantDropIndex, + fkIdxName: "id", }, { name: "error if FK ref but bad backup index", @@ -554,8 +545,8 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, primary key (id), key `bad_backup2` (age))", "create table child (id int, name varchar(1), age int, constraint `fk` foreign key (id) references parent (id))", }, - exit: 1, - fkIdxName: "id", + expectedErr: sql.ErrCantDropIndex, + fkIdxName: "id", }, { name: "error if misordered compound backup index for FK", @@ -563,8 +554,8 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, constraint `primary` primary key (id), key `backup` (age, id))", "create table child (id int, name varchar(1), age int, constraint `fk` foreign key (id) references parent (id))", }, - exit: 1, - fkIdxName: "id", + expectedErr: sql.ErrCantDropIndex, + fkIdxName: "id", }, { name: "error if incomplete compound backup index for FK", @@ -572,8 +563,8 @@ func TestDropPks(t *testing.T) { "create table parent (id int, name varchar(1), age int, constraint `primary` primary key (age, id), key `backup` (age, name))", "create table child (id int, name varchar(1), age int, constraint `fk` foreign key (age, id) references parent (age, id))", }, - exit: 1, - fkIdxName: "ageid", + expectedErr: sql.ErrCantDropIndex, + fkIdxName: "ageid", }, } @@ -598,11 +589,14 @@ func TestDropPks(t *testing.T) { } drop := "alter table parent drop primary key" - _, _, err = engine.Query(sqlCtx, drop) - switch tt.exit { - case 1: + _, iter, err := engine.Query(sqlCtx, drop) + require.NoError(t, err) + + err = drainIter(sqlCtx, iter) + if tt.expectedErr != nil { require.Error(t, err) - default: + assert.True(t, tt.expectedErr.Is(err), "Expected error of type %s but got %s", tt.expectedErr, err) + } else { require.NoError(t, err) } @@ -771,7 +765,7 @@ func TestNewPkOrdinals(t *testing.T) { t.Run(tt.name, func(t *testing.T) { res, err := modifyPkOrdinals(oldSch, tt.newSch) if tt.err != nil { - require.True(t, errors.Is(err, tt.err)) + require.True(t, goerrors.Is(err, tt.err)) } else { require.Equal(t, res, tt.expPkOrdinals) } diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index 273bbfe8c7..db8e8ba849 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -23,7 +23,7 @@ import ( "time" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/grant_tables" + "github.com/dolthub/go-mysql-server/sql/mysql_db" "github.com/dolthub/vitess/go/vt/proto/query" "gopkg.in/src-d/go-errors.v1" @@ -61,7 +61,7 @@ func DbsAsDSQLDBs(dbs []sql.Database) []SqlDatabase { var sqlDb SqlDatabase if sqlDatabase, ok := db.(SqlDatabase); ok { sqlDb = sqlDatabase - } else if privDatabase, ok := db.(grant_tables.PrivilegedDatabase); ok { + } else if privDatabase, ok := db.(mysql_db.PrivilegedDatabase); ok { if sqlDatabase, ok := privDatabase.Unwrap().(SqlDatabase); ok { sqlDb = sqlDatabase } diff --git a/go/libraries/doltcore/sqle/dfunctions/commit.go b/go/libraries/doltcore/sqle/dfunctions/commit.go index ba1ff481b1..b8e2838212 100644 --- a/go/libraries/doltcore/sqle/dfunctions/commit.go +++ b/go/libraries/doltcore/sqle/dfunctions/commit.go @@ -38,7 +38,7 @@ func NewCommitFunc(args ...sql.Expression) (sql.Expression, error) { func (cf *CommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { args, err := getDoltArgs(ctx, row, cf.Children()) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } return DoDoltCommit(ctx, args) } diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_commit.go b/go/libraries/doltcore/sqle/dfunctions/dolt_commit.go index d9df5a3fc2..dfa6a9e0aa 100644 --- a/go/libraries/doltcore/sqle/dfunctions/dolt_commit.go +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_commit.go @@ -45,7 +45,7 @@ func NewDoltCommitFunc(args ...sql.Expression) (sql.Expression, error) { func (d DoltCommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { args, err := getDoltArgs(ctx, row, d.Children()) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } return DoDoltCommit(ctx, args) } diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go b/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go index b79d3b9e4b..497bdf7bf7 100644 --- a/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go @@ -47,14 +47,14 @@ type DoltMergeFunc struct { const DoltMergeWarningCode int = 1105 // Since this our own custom warning we'll use 1105, the code for an unknown error const ( - hasConflicts int = 0 - noConflicts int = 1 + hasConflictsOrViolations int = 0 + noConflictsOrViolations int = 1 ) func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { args, err := getDoltArgs(ctx, row, d.Children()) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } return DoDoltMerge(ctx, args) } @@ -63,57 +63,57 @@ func DoDoltMerge(ctx *sql.Context, args []string) (int, error) { dbName := ctx.GetCurrentDatabase() if len(dbName) == 0 { - return noConflicts, fmt.Errorf("Empty database name.") + return noConflictsOrViolations, fmt.Errorf("Empty database name.") } sess := dsess.DSessFromSess(ctx.Session) apr, err := cli.CreateMergeArgParser().Parse(args) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } if apr.ContainsAll(cli.SquashParam, cli.NoFFParam) { - return noConflicts, fmt.Errorf("error: Flags '--%s' and '--%s' cannot be used together.\n", cli.SquashParam, cli.NoFFParam) + return noConflictsOrViolations, fmt.Errorf("error: Flags '--%s' and '--%s' cannot be used together.\n", cli.SquashParam, cli.NoFFParam) } ws, err := sess.WorkingSet(ctx, dbName) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } roots, ok := sess.GetRoots(ctx, dbName) if !ok { - return noConflicts, sql.ErrDatabaseNotFound.New(dbName) + return noConflictsOrViolations, sql.ErrDatabaseNotFound.New(dbName) } if apr.Contains(cli.AbortParam) { if !ws.MergeActive() { - return noConflicts, fmt.Errorf("fatal: There is no merge to abort") + return noConflictsOrViolations, fmt.Errorf("fatal: There is no merge to abort") } ws, err = abortMerge(ctx, ws, roots) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } err := sess.SetWorkingSet(ctx, dbName, ws) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } err = sess.CommitWorkingSet(ctx, dbName, sess.GetTransaction()) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } - return noConflicts, nil + return noConflictsOrViolations, nil } branchName := apr.Arg(0) mergeSpec, err := createMergeSpec(ctx, sess, dbName, apr, branchName) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } ws, conflicts, err := mergeIntoWorkingSet(ctx, sess, roots, ws, dbName, mergeSpec) if err != nil { @@ -128,30 +128,20 @@ func DoDoltMerge(ctx *sql.Context, args []string) (int, error) { // persists merge commits in the database, but expects the caller to update the working set. // TODO FF merging commit with constraint violations requires `constraint verify` func mergeIntoWorkingSet(ctx *sql.Context, sess *dsess.DoltSession, roots doltdb.Roots, ws *doltdb.WorkingSet, dbName string, spec *merge.MergeSpec) (*doltdb.WorkingSet, int, error) { - if conflicts, err := roots.Working.HasConflicts(ctx); err != nil { - return ws, noConflicts, err - } else if conflicts { - return ws, hasConflicts, doltdb.ErrUnresolvedConflicts - } - - if hasConstraintViolations, err := roots.Working.HasConstraintViolations(ctx); err != nil { - return ws, hasConflicts, err - } else if hasConstraintViolations { - return ws, hasConflicts, doltdb.ErrUnresolvedConstraintViolations - } + // todo: allow merges even when an existing merge is uncommitted if ws.MergeActive() { - return ws, noConflicts, doltdb.ErrMergeActive + return ws, noConflictsOrViolations, doltdb.ErrMergeActive } err := checkForUncommittedChanges(ctx, roots.Working, roots.Head) if err != nil { - return ws, noConflicts, err + return ws, noConflictsOrViolations, err } dbData, ok := sess.GetDbData(ctx, dbName) if !ok { - return ws, noConflicts, fmt.Errorf("failed to get dbData") + return ws, noConflictsOrViolations, fmt.Errorf("failed to get dbData") } canFF, err := spec.HeadC.CanFastForwardTo(ctx, spec.MergeC) @@ -160,62 +150,61 @@ func mergeIntoWorkingSet(ctx *sql.Context, sess *dsess.DoltSession, roots doltdb case doltdb.ErrIsAhead, doltdb.ErrUpToDate: ctx.Warn(DoltMergeWarningCode, err.Error()) default: - return ws, noConflicts, err + return ws, noConflictsOrViolations, err } } if canFF { if spec.Noff { ws, err = executeNoFFMerge(ctx, sess, spec, dbName, ws, dbData) - if err == doltdb.ErrUnresolvedConflicts { + if err == doltdb.ErrUnresolvedConflictsOrViolations { // if there are unresolved conflicts, write the resulting working set back to the session and return an // error message wsErr := sess.SetWorkingSet(ctx, dbName, ws) if wsErr != nil { - return ws, hasConflicts, wsErr + return ws, hasConflictsOrViolations, wsErr } ctx.Warn(DoltMergeWarningCode, err.Error()) - // Return 0 indicating there are conflicts - return ws, hasConflicts, nil + return ws, hasConflictsOrViolations, nil } - return ws, noConflicts, err + return ws, noConflictsOrViolations, err } ws, err = executeFFMerge(ctx, dbName, spec.Squash, ws, dbData, spec.MergeC) - return ws, noConflicts, err + return ws, noConflictsOrViolations, err } dbState, ok, err := sess.LookupDbState(ctx, dbName) if err != nil { - return ws, noConflicts, err + return ws, noConflictsOrViolations, err } else if !ok { - return ws, noConflicts, sql.ErrDatabaseNotFound.New(dbName) + return ws, noConflictsOrViolations, sql.ErrDatabaseNotFound.New(dbName) } ws, err = executeMerge(ctx, spec.Squash, spec.HeadC, spec.MergeC, ws, dbState.EditOpts()) - if err == doltdb.ErrUnresolvedConflicts || err == doltdb.ErrUnresolvedConstraintViolations { + if err == doltdb.ErrUnresolvedConflictsOrViolations { // if there are unresolved conflicts, write the resulting working set back to the session and return an // error message wsErr := sess.SetWorkingSet(ctx, dbName, ws) if wsErr != nil { - return ws, hasConflicts, wsErr + return ws, hasConflictsOrViolations, wsErr } ctx.Warn(DoltMergeWarningCode, err.Error()) - return ws, hasConflicts, nil + return ws, hasConflictsOrViolations, nil } else if err != nil { - return ws, noConflicts, err + return ws, noConflictsOrViolations, err } err = sess.SetWorkingSet(ctx, dbName, ws) if err != nil { - return ws, noConflicts, err + return ws, noConflictsOrViolations, err } - return ws, noConflicts, nil + return ws, noConflictsOrViolations, nil } func abortMerge(ctx *sql.Context, workingSet *doltdb.WorkingSet, roots doltdb.Roots) (*doltdb.WorkingSet, error) { @@ -400,9 +389,9 @@ func mergeRootToWorking( } ws = ws.WithWorkingRoot(workingRoot).WithStagedRoot(workingRoot) - if checkForConflicts(mergeStats) { + if checkForConflicts(mergeStats) || checkForViolations(mergeStats) { // this error is recoverable in-session, so we return the new ws along with the error - return ws, doltdb.ErrUnresolvedConflicts + return ws, doltdb.ErrUnresolvedConflictsOrViolations } return ws, nil @@ -437,6 +426,15 @@ func checkForConflicts(tblToStats map[string]*merge.MergeStats) bool { return false } +func checkForViolations(tblToStats map[string]*merge.MergeStats) bool { + for _, stats := range tblToStats { + if stats.ConstraintViolations > 0 { + return true + } + } + return false +} + func (d DoltMergeFunc) String() string { childrenStrings := make([]string, len(d.Children())) diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_pull.go b/go/libraries/doltcore/sqle/dfunctions/dolt_pull.go index e8ddb5d554..c1c9e30a9e 100644 --- a/go/libraries/doltcore/sqle/dfunctions/dolt_pull.go +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_pull.go @@ -66,7 +66,7 @@ func (d DoltPullFunc) WithChildren(children ...sql.Expression) (sql.Expression, func (d DoltPullFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { args, err := getDoltArgs(ctx, row, d.Children()) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } return DoDoltPull(ctx, args) } @@ -75,22 +75,22 @@ func DoDoltPull(ctx *sql.Context, args []string) (int, error) { dbName := ctx.GetCurrentDatabase() if len(dbName) == 0 { - return noConflicts, fmt.Errorf("empty database name.") + return noConflictsOrViolations, fmt.Errorf("empty database name.") } sess := dsess.DSessFromSess(ctx.Session) dbData, ok := sess.GetDbData(ctx, dbName) if !ok { - return noConflicts, sql.ErrDatabaseNotFound.New(dbName) + return noConflictsOrViolations, sql.ErrDatabaseNotFound.New(dbName) } apr, err := cli.CreatePullArgParser().Parse(args) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } if apr.NArg() > 1 { - return noConflicts, actions.ErrInvalidPullArgs + return noConflictsOrViolations, actions.ErrInvalidPullArgs } var remoteName string @@ -100,17 +100,17 @@ func DoDoltPull(ctx *sql.Context, args []string) (int, error) { pullSpec, err := env.NewPullSpec(ctx, dbData.Rsr, remoteName, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag)) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } srcDB, err := pullSpec.Remote.GetRemoteDBWithoutCaching(ctx, dbData.Ddb.ValueReadWriter().Format()) if err != nil { - return noConflicts, fmt.Errorf("failed to get remote db; %w", err) + return noConflictsOrViolations, fmt.Errorf("failed to get remote db; %w", err) } ws, err := sess.WorkingSet(ctx, dbName) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } var conflicts int @@ -122,23 +122,23 @@ func DoDoltPull(ctx *sql.Context, args []string) (int, error) { // todo: can we pass nil for either of the channels? srcDBCommit, err := actions.FetchRemoteBranch(ctx, dbData.Rsw.TempTableFilesDir(), pullSpec.Remote, srcDB, dbData.Ddb, pullSpec.Branch, runProgFuncs, stopProgFuncs) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } // TODO: this could be replaced with a canFF check to test for error err = dbData.Ddb.FastForward(ctx, remoteTrackRef, srcDBCommit) if err != nil { - return noConflicts, fmt.Errorf("fetch failed; %w", err) + return noConflictsOrViolations, fmt.Errorf("fetch failed; %w", err) } roots, ok := sess.GetRoots(ctx, dbName) if !ok { - return noConflicts, sql.ErrDatabaseNotFound.New(dbName) + return noConflictsOrViolations, sql.ErrDatabaseNotFound.New(dbName) } mergeSpec, err := createMergeSpec(ctx, sess, dbName, apr, remoteTrackRef.String()) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } ws, conflicts, err = mergeIntoWorkingSet(ctx, sess, roots, ws, dbName, mergeSpec) if err != nil && !errors.Is(doltdb.ErrUpToDate, err) { @@ -154,10 +154,10 @@ func DoDoltPull(ctx *sql.Context, args []string) (int, error) { err = actions.FetchFollowTags(ctx, dbData.Rsw.TempTableFilesDir(), srcDB, dbData.Ddb, runProgFuncs, stopProgFuncs) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } - return noConflicts, nil + return noConflictsOrViolations, nil } func pullerProgFunc(ctx context.Context, statsCh <-chan pull.Stats) { diff --git a/go/libraries/doltcore/sqle/dfunctions/merge.go b/go/libraries/doltcore/sqle/dfunctions/merge.go index c7ef694f01..56ee98819e 100644 --- a/go/libraries/doltcore/sqle/dfunctions/merge.go +++ b/go/libraries/doltcore/sqle/dfunctions/merge.go @@ -41,7 +41,7 @@ func NewMergeFunc(args ...sql.Expression) (sql.Expression, error) { func (mf *MergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { args, err := getDoltArgs(ctx, row, mf.Children()) if err != nil { - return noConflicts, err + return noConflictsOrViolations, err } return DoDoltMerge(ctx, args) } diff --git a/go/libraries/doltcore/sqle/dsess/transactions.go b/go/libraries/doltcore/sqle/dsess/transactions.go index 34a32f4c3f..4a0a54d48f 100644 --- a/go/libraries/doltcore/sqle/dsess/transactions.go +++ b/go/libraries/doltcore/sqle/dsess/transactions.go @@ -191,7 +191,7 @@ func (tx *DoltTransaction) doCommit( return nil, nil, err } - wsHash, err := existingWs.HashOf() + existingWSHash, err := existingWs.HashOf() if err != nil { return nil, nil, err } @@ -204,7 +204,7 @@ func (tx *DoltTransaction) doCommit( } var newCommit *doltdb.Commit - workingSet, newCommit, err = writeFn(ctx, tx, commit, workingSet, wsHash) + workingSet, newCommit, err = writeFn(ctx, tx, commit, workingSet, existingWSHash) if err == datas.ErrOptimisticLockFailed { // this is effectively a `continue` in the loop return nil, nil, nil @@ -234,7 +234,7 @@ func (tx *DoltTransaction) doCommit( } var newCommit *doltdb.Commit - mergedWorkingSet, newCommit, err = writeFn(ctx, tx, commit, mergedWorkingSet, wsHash) + mergedWorkingSet, newCommit, err = writeFn(ctx, tx, commit, mergedWorkingSet, existingWSHash) if err == datas.ErrOptimisticLockFailed { // this is effectively a `continue` in the loop return nil, nil, nil diff --git a/go/libraries/doltcore/sqle/dtables/log_table.go b/go/libraries/doltcore/sqle/dtables/log_table.go index 3b07177a0b..7e213315c3 100644 --- a/go/libraries/doltcore/sqle/dtables/log_table.go +++ b/go/libraries/doltcore/sqle/dtables/log_table.go @@ -15,12 +15,11 @@ package dtables import ( - "io" - "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/dolt/go/libraries/doltcore/env/actions/commitwalk" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" - "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/index" ) @@ -72,41 +71,33 @@ func (dt *LogTable) PartitionRows(ctx *sql.Context, _ sql.Partition) (sql.RowIte // LogItr is a sql.RowItr implementation which iterates over each commit as if it's a row in the table. type LogItr struct { - commits []*doltdb.Commit - idx int + child doltdb.CommitItr } // NewLogItr creates a LogItr from the current environment. func NewLogItr(ctx *sql.Context, ddb *doltdb.DoltDB, head *doltdb.Commit) (*LogItr, error) { - commits, err := actions.TimeSortedCommits(ctx, ddb, head, -1) - + hash, err := head.HashOf() if err != nil { return nil, err } - return &LogItr{commits, 0}, nil + child, err := commitwalk.GetTopologicalOrderIterator(ctx, ddb, hash) + if err != nil { + return nil, err + } + + return &LogItr{child}, nil } // Next retrieves the next row. It will return io.EOF if it's the last row. // After retrieving the last row, Close will be automatically closed. func (itr *LogItr) Next(ctx *sql.Context) (sql.Row, error) { - if itr.idx >= len(itr.commits) { - return nil, io.EOF - } - - defer func() { - itr.idx++ - }() - - cm := itr.commits[itr.idx] - meta, err := cm.GetCommitMeta(ctx) - + h, cm, err := itr.child.Next(ctx) if err != nil { return nil, err } - h, err := cm.HashOf() - + meta, err := cm.GetCommitMeta(ctx) if err != nil { return nil, err } diff --git a/go/libraries/doltcore/sqle/dtables/unscoped_diff_table.go b/go/libraries/doltcore/sqle/dtables/unscoped_diff_table.go index 5f101c39e3..9661c868ab 100644 --- a/go/libraries/doltcore/sqle/dtables/unscoped_diff_table.go +++ b/go/libraries/doltcore/sqle/dtables/unscoped_diff_table.go @@ -17,12 +17,13 @@ package dtables import ( "context" "errors" - "io" "github.com/dolthub/dolt/go/libraries/doltcore/diff" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" - "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" + "github.com/dolthub/dolt/go/libraries/doltcore/env/actions/commitwalk" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/index" + "github.com/dolthub/dolt/go/store/datas" + "github.com/dolthub/dolt/go/store/hash" "github.com/dolthub/go-mysql-server/sql" ) @@ -87,29 +88,27 @@ func (dt *UnscopedDiffTable) PartitionRows(ctx *sql.Context, _ sql.Partition) (s type UnscopedDiffTableItr struct { ctx *sql.Context ddb *doltdb.DoltDB - commits []*doltdb.Commit - commitIdx int + child doltdb.CommitItr + meta *datas.CommitMeta + hash hash.Hash tableChanges []tableChange tableChangesIdx int } // NewUnscopedDiffTableItr creates a UnscopedDiffTableItr from the current environment. func NewUnscopedDiffTableItr(ctx *sql.Context, ddb *doltdb.DoltDB, head *doltdb.Commit) (*UnscopedDiffTableItr, error) { - commits, err := actions.TimeSortedCommits(ctx, ddb, head, -1) - + hash, err := head.HashOf() if err != nil { return nil, err } + child, err := commitwalk.GetTopologicalOrderIterator(ctx, ddb, hash) - return &UnscopedDiffTableItr{ctx, ddb, commits, 0, nil, -1}, nil -} - -// HasNext returns true if this UnscopedDiffItr has more elements left. -func (itr *UnscopedDiffTableItr) HasNext() bool { - // There are more diff records to iterate over if: - // 1) there is more than one commit left to process, or - // 2) the tableChanges array isn't nilled out and has data left to process - return itr.commitIdx+1 < len(itr.commits) || itr.tableChanges != nil + return &UnscopedDiffTableItr{ + ctx: ctx, + ddb: ddb, + child: child, + tableChangesIdx: -1, + }, nil } // incrementIndexes increments the table changes index, and if it's the end of the table changes array, moves @@ -119,69 +118,72 @@ func (itr *UnscopedDiffTableItr) incrementIndexes() { if itr.tableChangesIdx >= len(itr.tableChanges) { itr.tableChangesIdx = -1 itr.tableChanges = nil - itr.commitIdx++ } } // Next retrieves the next row. It will return io.EOF if it's the last row. // After retrieving the last row, Close will be automatically closed. func (itr *UnscopedDiffTableItr) Next(ctx *sql.Context) (sql.Row, error) { - if !itr.HasNext() { - return nil, io.EOF - } defer itr.incrementIndexes() - // Load table changes if we don't have them for this commit yet for itr.tableChanges == nil { - err := itr.loadTableChanges(ctx, itr.commits[itr.commitIdx]) + err := itr.loadTableChanges(ctx) if err != nil { return nil, err } } - commit := itr.commits[itr.commitIdx] - hash, err := commit.HashOf() - if err != nil { - return nil, err - } - - meta, err := commit.GetCommitMeta(ctx) - if err != nil { - return nil, err - } - tableChange := itr.tableChanges[itr.tableChangesIdx] + meta := itr.meta + h := itr.hash - return sql.NewRow(hash.String(), tableChange.tableName, meta.Name, meta.Email, meta.Time(), - meta.Description, tableChange.dataChange, tableChange.schemaChange), nil + return sql.NewRow( + h.String(), + tableChange.tableName, + meta.Name, + meta.Email, + meta.Time(), + meta.Description, + tableChange.dataChange, + tableChange.schemaChange, + ), nil } -// loadTableChanges loads the set of table changes for the current commit into this iterator, taking -// care of advancing the iterator if that commit didn't mutate any tables and checking for EOF condition. -func (itr *UnscopedDiffTableItr) loadTableChanges(ctx context.Context, commit *doltdb.Commit) error { +// loadTableChanges loads the next commit's table changes and metadata +// into the iterator. +func (itr *UnscopedDiffTableItr) loadTableChanges(ctx context.Context) error { + hash, commit, err := itr.child.Next(ctx) + if err != nil { + return err + } + tableChanges, err := itr.calculateTableChanges(ctx, commit) if err != nil { return err } - // If there are no table deltas for this commit (e.g. a "dolt doc" commit), - // advance to the next commit, checking for EOF condition. + itr.tableChanges = tableChanges + itr.tableChangesIdx = 0 if len(tableChanges) == 0 { - itr.commitIdx++ - if !itr.HasNext() { - return io.EOF - } - } else { - itr.tableChanges = tableChanges - itr.tableChangesIdx = 0 + return nil } + meta, err := commit.GetCommitMeta(ctx) + if err != nil { + return err + } + itr.meta = meta + itr.hash = hash return nil } // calculateTableChanges calculates the tables that changed in the specified commit, by comparing that // commit with its immediate ancestor commit. func (itr *UnscopedDiffTableItr) calculateTableChanges(ctx context.Context, commit *doltdb.Commit) ([]tableChange, error) { + if len(commit.DatasParents()) == 0 { + return nil, nil + } + toRootValue, err := commit.GetRootValue(ctx) if err != nil { return nil, err diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go index f959df44ca..baa3d87a6c 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go @@ -76,51 +76,29 @@ func TestSingleQuery(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { - t.Skip() - var scripts = []queries.ScriptTest{ { - Name: "Multialter DDL with ADD/DROP Primary Key", + Name: "Create table with TIME type", SetUpScript: []string{ - "CREATE TABLE t(pk int primary key, v1 int)", + "create table my_types (pk int primary key, c0 time);", }, Assertions: []queries.ScriptTestAssertion{ { - Query: "ALTER TABLE t ADD COLUMN (v2 int), drop primary key, add primary key (v2)", - Expected: []sql.Row{{sql.NewOkResult(0)}}, + Query: "INSERT INTO my_types VALUES (1, '11:22:33.444444');", + Expected: []sql.Row{{sql.OkResult{RowsAffected: 1, InsertID: 0}}}, }, { - Query: "DESCRIBE t", - Expected: []sql.Row{ - {"pk", "int", "NO", "", "", ""}, - {"v1", "int", "YES", "", "", ""}, - {"v2", "int", "NO", "PRI", "", ""}, - }, - }, - { - Query: "ALTER TABLE t ADD COLUMN (v3 int), drop primary key, add primary key (notacolumn)", - ExpectedErr: sql.ErrKeyColumnDoesNotExist, - }, - { - Query: "DESCRIBE t", - Expected: []sql.Row{ - {"pk", "int", "NO", "", "", ""}, - {"v1", "int", "YES", "", "", ""}, - {"v2", "int", "NO", "PRI", "", ""}, - }, + Query: "UPDATE my_types SET c0='11:22' WHERE pk=1;", + Expected: []sql.Row{{sql.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1, Warnings: 0}}}}, }, }, }, } harness := newDoltHarness(t) + harness.Setup(setup.MydbData) for _, test := range scripts { - myDb := harness.NewDatabase("mydb") - databases := []sql.Database{myDb} - engine := enginetest.NewEngineWithDbs(t, harness, databases) - //engine.Analyzer.Debug = true - //engine.Analyzer.Verbose = true - enginetest.TestScriptWithEngine(t, engine, harness, test) + enginetest.TestScript(t, harness, test) } } @@ -302,16 +280,16 @@ func TestDoltUserPrivileges(t *testing.T) { harness := newDoltHarness(t) for _, script := range DoltUserPrivTests { t.Run(script.Name, func(t *testing.T) { - myDb := harness.NewDatabase("mydb") - databases := []sql.Database{myDb} - engine := enginetest.NewEngineWithDbs(t, harness, databases) + harness.Setup(setup.MydbData) + engine, err := harness.NewEngine(t) + require.NoError(t, err) defer engine.Close() ctx := enginetest.NewContextWithClient(harness, sql.Client{ User: "root", Address: "localhost", }) - engine.Analyzer.Catalog.GrantTables.AddRootAccount() + engine.Analyzer.Catalog.MySQLDb.AddRootAccount() for _, statement := range script.SetUpScript { if sh, ok := interface{}(harness).(enginetest.SkippingHarness); ok { @@ -587,7 +565,9 @@ func TestStoredProcedures(t *testing.T) { func TestTransactions(t *testing.T) { skipNewFormat(t) - enginetest.TestTransactionScripts(t, newDoltHarness(t)) + for _, script := range queries.TransactionTests { + enginetest.TestTransactionScript(t, newDoltHarness(t), script) + } for _, script := range DoltTransactionTests { enginetest.TestTransactionScript(t, newDoltHarness(t), script) @@ -640,20 +620,17 @@ func TestShowCreateTableAsOf(t *testing.T) { func TestDoltMerge(t *testing.T) { skipNewFormat(t) - harness := newDoltHarness(t) - harness.Setup(setup.MydbData) for _, script := range MergeScripts { - harness.engine = nil - enginetest.TestScript(t, harness, script) + // dolt versioning conflicts with reset harness -- use new harness every time + enginetest.TestScript(t, newDoltHarness(t), script) } } func TestDoltReset(t *testing.T) { skipNewFormat(t) - harness := newDoltHarness(t) for _, script := range DoltReset { - harness.engine = nil - enginetest.TestScript(t, harness, script) + // dolt versioning conflicts with reset harness -- use new harness every time + enginetest.TestScript(t, newDoltHarness(t), script) } } @@ -763,22 +740,19 @@ func TestBrokenSystemTableQueries(t *testing.T) { func TestHistorySystemTable(t *testing.T) { skipNewFormat(t) harness := newDoltHarness(t) + harness.Setup(setup.MydbData) for _, test := range HistorySystemTableScriptTests { - databases := harness.NewDatabases("mydb") - engine := enginetest.NewEngineWithDbs(t, harness, databases) + harness.engine = nil t.Run(test.Name, func(t *testing.T) { - enginetest.TestScriptWithEngine(t, engine, harness, test) + enginetest.TestScript(t, harness, test) }) } } func TestUnscopedDiffSystemTable(t *testing.T) { - harness := newDoltHarness(t) for _, test := range UnscopedDiffSystemTableScriptTests { - databases := harness.NewDatabases("mydb") - engine := enginetest.NewEngineWithDbs(t, harness, databases) t.Run(test.Name, func(t *testing.T) { - enginetest.TestScriptWithEngine(t, engine, harness, test) + enginetest.TestScript(t, newDoltHarness(t), test) }) } } @@ -786,12 +760,11 @@ func TestUnscopedDiffSystemTable(t *testing.T) { func TestDiffTableFunction(t *testing.T) { skipNewFormat(t) harness := newDoltHarness(t) - + harness.Setup(setup.MydbData) for _, test := range DiffTableFunctionScriptTests { - databases := harness.NewDatabases("mydb") - engine := enginetest.NewEngineWithDbs(t, harness, databases) + harness.engine = nil t.Run(test.Name, func(t *testing.T) { - enginetest.TestScriptWithEngine(t, engine, harness, test) + enginetest.TestScript(t, harness, test) }) } } @@ -799,11 +772,11 @@ func TestDiffTableFunction(t *testing.T) { func TestCommitDiffSystemTable(t *testing.T) { skipNewFormat(t) harness := newDoltHarness(t) + harness.Setup(setup.MydbData) for _, test := range CommitDiffSystemTableScriptTests { - databases := harness.NewDatabases("mydb") - engine := enginetest.NewEngineWithDbs(t, harness, databases) + harness.engine = nil t.Run(test.Name, func(t *testing.T) { - enginetest.TestScriptWithEngine(t, engine, harness, test) + enginetest.TestScript(t, harness, test) }) } } @@ -811,11 +784,11 @@ func TestCommitDiffSystemTable(t *testing.T) { func TestDiffSystemTable(t *testing.T) { skipNewFormat(t) harness := newDoltHarness(t) + harness.Setup(setup.MydbData) for _, test := range DiffSystemTableScriptTests { - databases := harness.NewDatabases("mydb") - engine := enginetest.NewEngineWithDbs(t, harness, databases) + harness.engine = nil t.Run(test.Name, func(t *testing.T) { - enginetest.TestScriptWithEngine(t, engine, harness, test) + enginetest.TestScript(t, harness, test) }) } } @@ -825,7 +798,6 @@ func TestTestReadOnlyDatabases(t *testing.T) { } func TestAddDropPks(t *testing.T) { - skipNewFormat(t) enginetest.TestAddDropPks(t, newDoltHarness(t)) } @@ -1112,8 +1084,8 @@ func TestAddDropPrimaryKeys(t *testing.T) { }, Assertions: []queries.ScriptTestAssertion{ { - Query: "ALTER TABLE test ADD PRIMARY KEY (id, c1, c2)", - ExpectedErrStr: "primary key cannot have NULL values", + Query: "ALTER TABLE test ADD PRIMARY KEY (id, c1, c2)", + ExpectedErr: sql.ErrInsertIntoNonNullableProvidedNull, }, }, } diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_harness.go b/go/libraries/doltcore/sqle/enginetest/dolt_harness.go index 31af15298a..31c98dab57 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_harness.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_harness.go @@ -25,6 +25,8 @@ import ( "github.com/dolthub/go-mysql-server/enginetest" "github.com/dolthub/go-mysql-server/enginetest/scriptgen/setup" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/information_schema" + "github.com/dolthub/go-mysql-server/sql/mysql_db" "github.com/stretchr/testify/require" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" @@ -54,6 +56,8 @@ type DoltHarness struct { skippedQueries []string setupData []setup.SetupScript resetData []setup.SetupScript + initDbs map[string]struct{} + autoInc bool engine *gms.Engine } @@ -89,6 +93,7 @@ func newDoltHarness(t *testing.T) *DoltHarness { if types.IsFormat_DOLT_1(dEnv.DoltDB.Format()) { dh = dh.WithSkippedQueries([]string{ "SHOW CREATE TABLE child", // todo(andy): "TestForeignKeys - ALTER TABLE RENAME COLUMN" + "typestable", }) } @@ -99,10 +104,10 @@ var defaultSkippedQueries = []string{ "show variables", // we set extra variables "show create table fk_tbl", // we create an extra key for the FK that vanilla gms does not "show indexes from", // we create / expose extra indexes (for foreign keys) - "typestable", // Bit type isn't working? "show global variables like", // we set extra variables } +// Setup sets the setup scripts for this DoltHarness's engine func (d *DoltHarness) Setup(setupData ...[]setup.SetupScript) { d.engine = nil d.setupData = nil @@ -111,6 +116,9 @@ func (d *DoltHarness) Setup(setupData ...[]setup.SetupScript) { } } +// resetScripts returns a set of queries that will reset the given database +// names. If [autoInc], the queries for resetting autoincrement tables are +// included. func resetScripts(dbs []string, autoInc bool) []setup.SetupScript { var resetCmds setup.SetupScript for i := range dbs { @@ -126,47 +134,63 @@ func resetScripts(dbs []string, autoInc bool) []setup.SetupScript { return []setup.SetupScript{resetCmds} } +// commitScripts returns a set of queries that will commit the workingsets +// of the given database names func commitScripts(dbs []string) []setup.SetupScript { var commitCmds setup.SetupScript for i := range dbs { db := dbs[i] commitCmds = append(commitCmds, fmt.Sprintf("use %s", db)) - commitCmds = append(commitCmds, fmt.Sprintf("call dcommit('--allow-empty', '-am', 'checkpoint enginetest database %s')", db)) + commitCmds = append(commitCmds, fmt.Sprintf("call dolt_commit('--allow-empty', '-am', 'checkpoint enginetest database %s')", db)) } commitCmds = append(commitCmds, "use mydb") return []setup.SetupScript{commitCmds} } +// NewEngine creates a new *gms.Engine or calls reset and clear scripts on the existing +// engine for reuse. func (d *DoltHarness) NewEngine(t *testing.T) (*gms.Engine, error) { if d.engine == nil { - e, err := enginetest.NewEngineWithSetup(t, d, d.setupData) + pro := d.NewDatabaseProvider(information_schema.NewInformationSchemaDatabase()) + e, err := enginetest.NewEngineWithProviderSetup(t, d, pro, d.setupData) if err != nil { return nil, err } d.engine = e - ctx := enginetest.NewContext(d) - res := enginetest.MustQuery(ctx, e, "select schema_name from information_schema.schemata where schema_name not in ('information_schema');") + var res []sql.Row + // todo(max): need better way to reset autoincrement regardless of test type + ctx := enginetest.NewContext(d) + res = enginetest.MustQuery(ctx, e, "select count(*) from information_schema.tables where table_name = 'auto_increment_tbl';") + d.autoInc = res[0][0].(int64) > 0 + + res = enginetest.MustQuery(ctx, e, "select schema_name from information_schema.schemata where schema_name not in ('information_schema');") var dbs []string for i := range res { dbs = append(dbs, res[i][0].(string)) } - res = enginetest.MustQuery(ctx, e, "select count(*) from information_schema.tables where table_name = 'auto_increment_tbl';") - autoInc := res[0][0].(int64) > 0 - e, err = enginetest.RunEngineScripts(ctx, e, commitScripts(dbs), d.SupportsNativeIndexCreation()) if err != nil { return nil, err } - d.resetData = resetScripts(dbs, autoInc) - return e, nil } + // grants are files that can only be manually reset + d.engine.Analyzer.Catalog.MySQLDb = mysql_db.CreateEmptyMySQLDb() + d.engine.Analyzer.Catalog.MySQLDb.AddRootAccount() + + //todo(max): easier if tests specify their databases ahead of time ctx := enginetest.NewContext(d) - return enginetest.RunEngineScripts(ctx, d.engine, d.resetData, d.SupportsNativeIndexCreation()) + res := enginetest.MustQuery(ctx, d.engine, "select schema_name from information_schema.schemata where schema_name not in ('information_schema');") + var dbs []string + for i := range res { + dbs = append(dbs, res[i][0].(string)) + } + + return enginetest.RunEngineScripts(ctx, d.engine, resetScripts(dbs, d.autoInc), d.SupportsNativeIndexCreation()) } // WithParallelism returns a copy of the harness with parallelism set to the given number of threads. A value of 0 or @@ -317,7 +341,9 @@ func (d *DoltHarness) NewDatabaseProvider(dbs ...sql.Database) sql.MutableDataba require.NoError(d.t, err) d.multiRepoEnv = mrEnv for _, db := range dbs { - d.multiRepoEnv.AddEnv(db.Name(), d.createdEnvs[db.Name()]) + if db.Name() != information_schema.InformationSchemaDatabaseName { + d.multiRepoEnv.AddEnv(db.Name(), d.createdEnvs[db.Name()]) + } } b := env.GetDefaultInitBranch(d.multiRepoEnv.Config()) diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go index 600801ba9c..2b2759f2b8 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go @@ -985,6 +985,316 @@ var MergeScripts = []queries.ScriptTest{ }, }, }, + { + Name: "Drop and add primary key on two branches converges to same schema", + SetUpScript: []string{ + "create table t1 (i int);", + "call dolt_commit('-am', 't1 table')", + "call dolt_checkout('-b', 'b1')", + "alter table t1 add primary key(i)", + "alter table t1 drop primary key", + "alter table t1 add primary key(i)", + "alter table t1 drop primary key", + "alter table t1 add primary key(i)", + "call dolt_commit('-am', 'b1 primary key changes')", + "call dolt_checkout('main')", + "alter table t1 add primary key(i)", + "call dolt_commit('-am', 'main primary key change')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "call dolt_merge('b1')", + Expected: []sql.Row{{1}}, + }, + { + Query: "select count(*) from dolt_conflicts", + Expected: []sql.Row{{0}}, + }, + }, + }, + { + Name: "merging branches into a constraint violated head. Any new violations are appended", + SetUpScript: []string{ + "CREATE table parent (pk int PRIMARY KEY, col1 int);", + "CREATE table child (pk int PRIMARY KEY, parent_fk int, FOREIGN KEY (parent_fk) REFERENCES parent(pk));", + "CREATE table other (pk int);", + "INSERT INTO parent VALUES (1, 1), (2, 2);", + "CALL DOLT_COMMIT('-am', 'setup');", + "CALL DOLT_BRANCH('branch1');", + "CALL DOLT_BRANCH('branch2');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + // we need dolt_force_transaction_commit because we want to + // transaction commit constraint violations that occur as a + // result of a merge. + Query: "set autocommit = off, dolt_force_transaction_commit = on", + Expected: []sql.Row{{}}, + }, + { + Query: "DELETE FROM parent where pk = 1;", + Expected: []sql.Row{{sql.NewOkResult(1)}}, + }, + { + Query: "CALL DOLT_COMMIT('-am', 'delete parent 1');", + SkipResultsCheck: true, + }, + { + Query: "CALL DOLT_CHECKOUT('branch1');", + Expected: []sql.Row{{0}}, + }, + { + Query: "INSERT INTO CHILD VALUES (1, 1);", + Expected: []sql.Row{{sql.NewOkResult(1)}}, + }, + { + Query: "CALL DOLT_COMMIT('-am', 'insert child of parent 1');", + SkipResultsCheck: true, + }, + { + Query: "CALL DOLT_CHECKOUT('main');", + Expected: []sql.Row{{0}}, + }, + { + Query: "CALL DOLT_MERGE('branch1');", + Expected: []sql.Row{{0}}, + }, + { + Query: "SELECT violation_type, pk, parent_fk from dolt_constraint_violations_child;", + Expected: []sql.Row{{"foreign key", 1, 1}}, + }, + { + Query: "COMMIT;", + Expected: []sql.Row{}, + }, + { + Query: "CALL DOLT_COMMIT('-am', 'commit constraint violations');", + ExpectedErrStr: "error: the table(s) child has constraint violations", + }, + { + Query: "CALL DOLT_COMMIT('-afm', 'commit constraint violations');", + SkipResultsCheck: true, + }, + { + Query: "CALL DOLT_BRANCH('branch3');", + Expected: []sql.Row{{0}}, + }, + { + Query: "DELETE FROM parent where pk = 2;", + Expected: []sql.Row{{sql.NewOkResult(1)}}, + }, + { + Query: "CALL DOLT_COMMIT('-afm', 'remove parent 2');", + SkipResultsCheck: true, + }, + { + Query: "CALL DOLT_CHECKOUT('branch2');", + Expected: []sql.Row{{0}}, + }, + { + Query: "INSERT INTO OTHER VALUES (1);", + Expected: []sql.Row{{sql.NewOkResult(1)}}, + }, + { + Query: "CALL DOLT_COMMIT('-am', 'non-fk insert');", + SkipResultsCheck: true, + }, + { + Query: "CALL DOLT_CHECKOUT('main');", + Expected: []sql.Row{{0}}, + }, + { + Query: "CALL DOLT_MERGE('branch2');", + Expected: []sql.Row{{0}}, + }, + { + Query: "SELECT violation_type, pk, parent_fk from dolt_constraint_violations_child;", + Expected: []sql.Row{{"foreign key", 1, 1}}, + }, + { + Query: "COMMIT;", + Expected: []sql.Row{}, + }, + { + Query: "CALL DOLT_COMMIT('-am', 'commit non-conflicting merge');", + ExpectedErrStr: "error: the table(s) child has constraint violations", + }, + { + Query: "CALL DOLT_COMMIT('-afm', 'commit non-conflicting merge');", + SkipResultsCheck: true, + }, + { + Query: "CALL DOLT_CHECKOUT('branch3');", + Expected: []sql.Row{{0}}, + }, + { + Query: "INSERT INTO CHILD VALUES (2, 2);", + Expected: []sql.Row{{sql.NewOkResult(1)}}, + }, + { + Query: "CALL DOLT_COMMIT('-afm', 'add child of parent 2');", + SkipResultsCheck: true, + }, + { + Query: "CALL DOLT_CHECKOUT('main');", + Expected: []sql.Row{{0}}, + }, + { + Query: "CALL DOLT_MERGE('branch3');", + Expected: []sql.Row{{0}}, + }, + { + Query: "SELECT violation_type, pk, parent_fk from dolt_constraint_violations_child;", + Expected: []sql.Row{{"foreign key", 1, 1}, {"foreign key", 2, 2}}, + }, + }, + }, + { + Name: "conflicting merge aborts when conflicts and violations already exist", + SetUpScript: []string{ + "CREATE table parent (pk int PRIMARY KEY, col1 int);", + "CREATE table child (pk int PRIMARY KEY, parent_fk int, FOREIGN KEY (parent_fk) REFERENCES parent(pk));", + "INSERT INTO parent VALUES (1, 1), (2, 1);", + "CALL DOLT_COMMIT('-am', 'create table with data');", + "CALL DOLT_BRANCH('other');", + "CALL DOLT_BRANCH('other2');", + "UPDATE parent SET col1 = 2 where pk = 1;", + "DELETE FROM parent where pk = 2;", + "CALL DOLT_COMMIT('-am', 'updating col1 to 2 and remove pk = 2');", + "CALL DOLT_CHECKOUT('other');", + "UPDATE parent SET col1 = 3 where pk = 1;", + "INSERT into child VALUEs (1, 2);", + "CALL DOLT_COMMIT('-am', 'updating col1 to 3 and adding child of pk 2');", + "CALL DOLT_CHECKOUT('other2')", + "UPDATE parent SET col1 = 4 where pk = 1", + "CALL DOLT_COMMIT('-am', 'updating col1 to 4');", + "CALL DOLT_CHECKOUT('main');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SET dolt_force_transaction_commit = 1", + Expected: []sql.Row{{}}, + }, + { + Query: "CALL DOLT_MERGE('other');", + Expected: []sql.Row{{0}}, + }, + { + Query: "SELECT * from parent;", + Expected: []sql.Row{{1, 2}}, + }, + { + Query: "SELECT * from child;", + Expected: []sql.Row{{1, 2}}, + }, + { + Query: "SELECT base_col1, base_pk, our_col1, our_pk, their_col1, their_pk from dolt_conflicts_parent;", + Expected: []sql.Row{{1, 1, 2, 1, 3, 1}}, + }, + { + Query: "SELECT violation_type, pk, parent_fk from dolt_constraint_violations_child;", + Expected: []sql.Row{{"foreign key", 1, 2}}, + }, + // commit so we can merge again + { + Query: "CALL DOLT_COMMIT('-afm', 'committing merge conflicts');", + SkipResultsCheck: true, + }, + { + Query: "CALL DOLT_MERGE('other2');", + ExpectedErrStr: "existing unresolved conflicts would be overridden by new conflicts produced by merge. Please resolve them and try again", + }, + { + Query: "SELECT * from parent;", + Expected: []sql.Row{{1, 2}}, + }, + { + Query: "SELECT * from child;", + Expected: []sql.Row{{1, 2}}, + }, + { + Query: "SELECT base_col1, base_pk, our_col1, our_pk, their_col1, their_pk from dolt_conflicts_parent;", + Expected: []sql.Row{{1, 1, 2, 1, 3, 1}}, + }, + { + Query: "SELECT violation_type, pk, parent_fk from dolt_constraint_violations_child;", + Expected: []sql.Row{{"foreign key", 1, 2}}, + }, + }, + }, + { + Name: "non-conflicting / non-violating merge succeeds when conflicts and violations already exist", + SetUpScript: []string{ + "CREATE table parent (pk int PRIMARY KEY, col1 int);", + "CREATE table child (pk int PRIMARY KEY, parent_fk int, FOREIGN KEY (parent_fk) REFERENCES parent(pk));", + "INSERT INTO parent VALUES (1, 1), (2, 1);", + "CALL DOLT_COMMIT('-am', 'create table with data');", + "CALL DOLT_BRANCH('other');", + "CALL DOLT_BRANCH('other2');", + "UPDATE parent SET col1 = 2 where pk = 1;", + "DELETE FROM parent where pk = 2;", + "CALL DOLT_COMMIT('-am', 'updating col1 to 2 and remove pk = 2');", + "CALL DOLT_CHECKOUT('other');", + "UPDATE parent SET col1 = 3 where pk = 1;", + "INSERT into child VALUES (1, 2);", + "CALL DOLT_COMMIT('-am', 'updating col1 to 3 and adding child of pk 2');", + "CALL DOLT_CHECKOUT('other2')", + "INSERT INTO parent values (3, 1);", + "CALL DOLT_COMMIT('-am', 'insert parent with pk 3');", + "CALL DOLT_CHECKOUT('main');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SET dolt_force_transaction_commit = 1;", + Expected: []sql.Row{{}}, + }, + { + Query: "CALL DOLT_MERGE('other');", + Expected: []sql.Row{{0}}, + }, + { + Query: "SELECT * from parent;", + Expected: []sql.Row{{1, 2}}, + }, + { + Query: "SELECT * from child;", + Expected: []sql.Row{{1, 2}}, + }, + { + Query: "SELECT base_col1, base_pk, our_col1, our_pk, their_col1, their_pk from dolt_conflicts_parent;", + Expected: []sql.Row{{1, 1, 2, 1, 3, 1}}, + }, + { + Query: "SELECT violation_type, pk, parent_fk from dolt_constraint_violations_child;", + Expected: []sql.Row{{"foreign key", 1, 2}}, + }, + // commit so we can merge again + { + Query: "CALL DOLT_COMMIT('-afm', 'committing merge conflicts');", + SkipResultsCheck: true, + }, + { + Query: "CALL DOLT_MERGE('other2');", + Expected: []sql.Row{{0}}, + }, + { + Query: "SELECT * from parent;", + Expected: []sql.Row{{1, 2}, {3, 1}}, + }, + { + Query: "SELECT * from child;", + Expected: []sql.Row{{1, 2}}, + }, + { + Query: "SELECT base_col1, base_pk, our_col1, our_pk, their_col1, their_pk from dolt_conflicts_parent;", + Expected: []sql.Row{{1, 1, 2, 1, 3, 1}}, + }, + { + Query: "SELECT violation_type, pk, parent_fk from dolt_constraint_violations_child;", + Expected: []sql.Row{{"foreign key", 1, 2}}, + }, + }, + }, } var DoltReset = []queries.ScriptTest{ diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_transaction_commit_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_transaction_commit_test.go index 0493542b86..41f2d6ca8d 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_transaction_commit_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_transaction_commit_test.go @@ -20,10 +20,12 @@ import ( "github.com/dolthub/go-mysql-server/enginetest" "github.com/dolthub/go-mysql-server/enginetest/queries" + "github.com/dolthub/go-mysql-server/enginetest/scriptgen/setup" "github.com/dolthub/go-mysql-server/sql" "github.com/stretchr/testify/require" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" "github.com/dolthub/dolt/go/store/types" ) @@ -37,6 +39,7 @@ func TestDoltTransactionCommitOneClient(t *testing.T) { // In this test, we're setting only one client to match transaction commits to dolt commits. // Autocommit is disabled for the enabled client, as it's the recommended way to use this feature. harness := newDoltHarness(t) + harness.Setup(setup.MydbData) enginetest.TestTransactionScript(t, harness, queries.TransactionTest{ Name: "dolt commit on transaction commit one client", SetUpScript: []string{ @@ -147,8 +150,13 @@ func TestDoltTransactionCommitOneClient(t *testing.T) { }, }, }) + _, err := harness.NewEngine(t) - db := harness.databases[0].GetDoltDB() + ctx := enginetest.NewContext(harness) + db, ok := ctx.Session.(*dsess.DoltSession).GetDoltDB(ctx, "mydb") + if !ok { + t.Fatal("'mydb' database not found") + } cs, err := doltdb.NewCommitSpec("HEAD") require.NoError(t, err) headRefs, err := db.GetHeadRefs(context.Background()) @@ -165,7 +173,7 @@ func TestDoltTransactionCommitOneClient(t *testing.T) { require.NoError(t, err) icm, err := initialCommit.GetCommitMeta(context.Background()) require.NoError(t, err) - require.Equal(t, "Initialize data repository", icm.Description) + require.Equal(t, "checkpoint enginetest database mydb", icm.Description) } func TestDoltTransactionCommitTwoClients(t *testing.T) { @@ -274,7 +282,13 @@ func TestDoltTransactionCommitTwoClients(t *testing.T) { }, }, }) - db := harness.databases[0].GetDoltDB() + _, err := harness.NewEngine(t) + + ctx := enginetest.NewContext(harness) + db, ok := ctx.Session.(*dsess.DoltSession).GetDoltDB(ctx, "mydb") + if !ok { + t.Fatal("'mydb' database not found") + } cs, err := doltdb.NewCommitSpec("HEAD") require.NoError(t, err) headRefs, err := db.GetHeadRefs(context.Background()) @@ -297,7 +311,7 @@ func TestDoltTransactionCommitTwoClients(t *testing.T) { require.NoError(t, err) cm0, err := commit0.GetCommitMeta(context.Background()) require.NoError(t, err) - require.Equal(t, "Initialize data repository", cm0.Description) + require.Equal(t, "checkpoint enginetest database mydb", cm0.Description) } func TestDoltTransactionCommitAutocommit(t *testing.T) { @@ -346,7 +360,13 @@ func TestDoltTransactionCommitAutocommit(t *testing.T) { }, }, }) - db := harness.databases[0].GetDoltDB() + _, err := harness.NewEngine(t) + + ctx := enginetest.NewContext(harness) + db, ok := ctx.Session.(*dsess.DoltSession).GetDoltDB(ctx, "mydb") + if !ok { + t.Fatal("'mydb' database not found") + } cs, err := doltdb.NewCommitSpec("HEAD") require.NoError(t, err) headRefs, err := db.GetHeadRefs(context.Background()) @@ -375,7 +395,7 @@ func TestDoltTransactionCommitAutocommit(t *testing.T) { require.NoError(t, err) cm0, err := commit0.GetCommitMeta(context.Background()) require.NoError(t, err) - require.Equal(t, "Initialize data repository", cm0.Description) + require.Equal(t, "checkpoint enginetest database mydb", cm0.Description) } func TestDoltTransactionCommitLateFkResolution(t *testing.T) { diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_transaction_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_transaction_queries.go index 44420389ca..6db24647f0 100755 --- a/go/libraries/doltcore/sqle/enginetest/dolt_transaction_queries.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_transaction_queries.go @@ -851,6 +851,10 @@ var DoltConflictHandlingTests = []queries.TransactionTest{ Query: "/* client b */ select * from test order by 1", Expected: []sql.Row{{0, 0}, {1, 1}}, }, + { // no conflicts, transaction got rolled back + Query: "/* client b */ select count(*) from dolt_conflicts", + Expected: []sql.Row{{0}}, + }, }, }, { @@ -897,6 +901,10 @@ var DoltConflictHandlingTests = []queries.TransactionTest{ Query: "/* client b */ select * from test order by 1", Expected: []sql.Row{{0, 0}, {1, 1}}, }, + { // no conflicts, transaction got rolled back + Query: "/* client b */ select count(*) from dolt_conflicts", + Expected: []sql.Row{{0}}, + }, }, }, { diff --git a/go/libraries/doltcore/sqle/index/index_lookup.go b/go/libraries/doltcore/sqle/index/index_lookup.go index d2c8ca1dc6..db41fae784 100644 --- a/go/libraries/doltcore/sqle/index/index_lookup.go +++ b/go/libraries/doltcore/sqle/index/index_lookup.go @@ -66,7 +66,7 @@ func RowIterForProllyRange(ctx *sql.Context, idx DoltIndex, ranges prolly.Range, if covers { return newProllyCoveringIndexIter(ctx, idx, ranges, pkSch, secondary) } else { - return newProllyIndexIter(ctx, idx, ranges, primary, secondary) + return newProllyIndexIter(ctx, idx, ranges, pkSch, primary, secondary) } } diff --git a/go/libraries/doltcore/sqle/index/prolly_fields.go b/go/libraries/doltcore/sqle/index/prolly_fields.go index ea86c44a89..5555444574 100644 --- a/go/libraries/doltcore/sqle/index/prolly_fields.go +++ b/go/libraries/doltcore/sqle/index/prolly_fields.go @@ -28,6 +28,54 @@ import ( "github.com/dolthub/dolt/go/store/val" ) +// todo(andy): this should go in GMS +func DenormalizeRow(sch sql.Schema, row sql.Row) (sql.Row, error) { + var err error + for i := range row { + if row[i] == nil { + continue + } + switch typ := sch[i].Type.(type) { + case sql.DecimalType: + row[i] = row[i].(decimal.Decimal).String() + case sql.EnumType: + row[i], err = typ.Unmarshal(int64(row[i].(uint16))) + case sql.SetType: + row[i], err = typ.Unmarshal(row[i].(uint64)) + default: + } + if err != nil { + return nil, err + } + } + return row, nil +} + +// todo(andy): this should go in GMS +func NormalizeRow(sch sql.Schema, row sql.Row) (sql.Row, error) { + var err error + for i := range row { + if row[i] == nil { + continue + } + switch typ := sch[i].Type.(type) { + case sql.DecimalType: + row[i], err = decimal.NewFromString(row[i].(string)) + case sql.EnumType: + var v int64 + v, err = typ.Marshal(row[i]) + row[i] = uint16(v) + case sql.SetType: + row[i], err = typ.Marshal(row[i]) + default: + } + if err != nil { + return nil, err + } + } + return row, nil +} + // GetField reads the value from the ith field of the Tuple as an interface{}. func GetField(td val.TupleDesc, i int, tup val.Tuple) (v interface{}, err error) { var ok bool @@ -52,12 +100,10 @@ func GetField(td val.TupleDesc, i int, tup val.Tuple) (v interface{}, err error) v, ok = td.GetFloat32(i, tup) case val.Float64Enc: v, ok = td.GetFloat64(i, tup) + case val.Bit64Enc: + v, ok = td.GetBit(i, tup) case val.DecimalEnc: - var d decimal.Decimal - d, ok = td.GetDecimal(i, tup) - if ok { - v = deserializeDecimal(d) - } + v, ok = td.GetDecimal(i, tup) case val.YearEnc: v, ok = td.GetYear(i, tup) case val.DateEnc: @@ -70,6 +116,10 @@ func GetField(td val.TupleDesc, i int, tup val.Tuple) (v interface{}, err error) } case val.DatetimeEnc: v, ok = td.GetDatetime(i, tup) + case val.EnumEnc: + v, ok = td.GetEnum(i, tup) + case val.SetEnc: + v, ok = td.GetSet(i, tup) case val.StringEnc: v, ok = td.GetString(i, tup) case val.ByteStringEnc: @@ -127,12 +177,10 @@ func PutField(tb *val.TupleBuilder, i int, v interface{}) error { tb.PutFloat32(i, v.(float32)) case val.Float64Enc: tb.PutFloat64(i, v.(float64)) + case val.Bit64Enc: + tb.PutBit(i, uint64(convUint(v))) case val.DecimalEnc: - d, err := serializeDecimal(v.(string)) - if err != nil { - return nil - } - tb.PutDecimal(i, d) + tb.PutDecimal(i, v.(decimal.Decimal)) case val.YearEnc: tb.PutYear(i, v.(int16)) case val.DateEnc: @@ -145,6 +193,10 @@ func PutField(tb *val.TupleBuilder, i int, v interface{}) error { tb.PutSqlTime(i, t) case val.DatetimeEnc: tb.PutDatetime(i, v.(time.Time)) + case val.EnumEnc: + tb.PutEnum(i, v.(uint16)) + case val.SetEnc: + tb.PutSet(i, v.(uint64)) case val.StringEnc: tb.PutString(i, v.(string)) case val.ByteStringEnc: @@ -220,14 +272,6 @@ func convUint(v interface{}) uint { } } -func convJson(v interface{}) (buf []byte, err error) { - v, err = sql.JSON.Convert(v) - if err != nil { - return nil, err - } - return json.Marshal(v.(sql.JSONDocument).Val) -} - func deserializeGeometry(buf []byte) (v interface{}) { srid, _, typ := geo.ParseEWKBHeader(buf) buf = buf[geo.EWKBHeaderSize:] @@ -257,6 +301,18 @@ func serializeGeometry(v interface{}) []byte { } } +func convJson(v interface{}) (buf []byte, err error) { + v, err = sql.JSON.Convert(v) + if err != nil { + return nil, err + } + return json.Marshal(v.(sql.JSONDocument).Val) +} + +func deserializeTime(v int64) (interface{}, error) { + return typeinfo.TimeType.ConvertNomsValueToValue(types.Int(v)) +} + func serializeTime(v interface{}) (int64, error) { i, err := typeinfo.TimeType.ConvertValueToNomsValue(nil, nil, v) if err != nil { @@ -264,15 +320,3 @@ func serializeTime(v interface{}) (int64, error) { } return int64(i.(types.Int)), nil } - -func deserializeTime(v int64) (interface{}, error) { - return typeinfo.TimeType.ConvertNomsValueToValue(types.Int(v)) -} - -func serializeDecimal(v interface{}) (decimal.Decimal, error) { - return decimal.NewFromString(v.(string)) -} - -func deserializeDecimal(v decimal.Decimal) interface{} { - return v.String() -} diff --git a/go/libraries/doltcore/sqle/index/prolly_fields_test.go b/go/libraries/doltcore/sqle/index/prolly_fields_test.go index 72778600f5..faac6c9234 100644 --- a/go/libraries/doltcore/sqle/index/prolly_fields_test.go +++ b/go/libraries/doltcore/sqle/index/prolly_fields_test.go @@ -22,6 +22,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression/function" + "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -95,10 +96,15 @@ func TestRoundTripProllyFields(t *testing.T) { typ: val.Type{Enc: val.Float64Enc}, value: float64(-math.Pi), }, + { + name: "bit", + typ: val.Type{Enc: val.Bit64Enc}, + value: uint64(42), + }, { name: "decimal", typ: val.Type{Enc: val.DecimalEnc}, - value: "0.263419374632932747932030573792", + value: mustParseDecimal("0.263419374632932747932030573792"), }, { name: "string", @@ -120,11 +126,11 @@ func TestRoundTripProllyFields(t *testing.T) { typ: val.Type{Enc: val.DateEnc}, value: dateFromTime(time.Now().UTC()), }, - //{ - // name: "time", - // typ: val.Type{Enc: val.DateEnc}, - // value: dateFromTime(time.Now().UTC()), - //}, + { + name: "time", + typ: val.Type{Enc: val.TimeEnc}, + value: "11:22:00", + }, { name: "datetime", typ: val.Type{Enc: val.DatetimeEnc}, @@ -207,6 +213,14 @@ func mustParseJson(t *testing.T, s string) sql.JSONDocument { return sql.JSONDocument{Val: v} } +func mustParseDecimal(s string) decimal.Decimal { + d, err := decimal.NewFromString(s) + if err != nil { + panic(err) + } + return d +} + func dateFromTime(t time.Time) time.Time { y, m, d := t.Year(), t.Month(), t.Day() return time.Date(y, m, d, 0, 0, 0, 0, time.UTC) diff --git a/go/libraries/doltcore/sqle/index/prolly_index_iter.go b/go/libraries/doltcore/sqle/index/prolly_index_iter.go index 424ff11cbf..36db838a32 100644 --- a/go/libraries/doltcore/sqle/index/prolly_index_iter.go +++ b/go/libraries/doltcore/sqle/index/prolly_index_iter.go @@ -46,13 +46,20 @@ type prollyIndexIter struct { // keyMap and valMap transform tuples from // primary row storage into sql.Row's keyMap, valMap val.OrdinalMapping + sqlSch sql.Schema } var _ sql.RowIter = prollyIndexIter{} var _ sql.RowIter2 = prollyIndexIter{} // NewProllyIndexIter returns a new prollyIndexIter. -func newProllyIndexIter(ctx *sql.Context, idx DoltIndex, rng prolly.Range, dprimary, dsecondary durable.Index) (prollyIndexIter, error) { +func newProllyIndexIter( + ctx *sql.Context, + idx DoltIndex, + rng prolly.Range, + pkSch sql.PrimaryKeySchema, + dprimary, dsecondary durable.Index, +) (prollyIndexIter, error) { secondary := durable.ProllyMapFromIndex(dsecondary) indexIter, err := secondary.IterRange(ctx, rng) if err != nil { @@ -79,6 +86,7 @@ func newProllyIndexIter(ctx *sql.Context, idx DoltIndex, rng prolly.Range, dprim rowChan: make(chan sql.Row, indexLookupBufSize), keyMap: km, valMap: vm, + sqlSch: pkSch.Schema, } eg.Go(func() error { @@ -95,7 +103,7 @@ func (p prollyIndexIter) Next(ctx *sql.Context) (r sql.Row, err error) { select { case r, ok = <-p.rowChan: if ok { - return r, nil + return DenormalizeRow(p.sqlSch, r) } } if !ok { @@ -222,6 +230,7 @@ type prollyCoveringIndexIter struct { // |keyMap| and |valMap| are both of len == keyMap, valMap val.OrdinalMapping + sqlSch sql.Schema } var _ sql.RowIter = prollyCoveringIndexIter{} @@ -251,6 +260,7 @@ func newProllyCoveringIndexIter(ctx *sql.Context, idx DoltIndex, rng prolly.Rang valDesc: valDesc, keyMap: keyMap, valMap: valMap, + sqlSch: pkSch.Schema, } return iter, nil @@ -268,7 +278,7 @@ func (p prollyCoveringIndexIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, err } - return r, nil + return DenormalizeRow(p.sqlSch, r) } func (p prollyCoveringIndexIter) Next2(ctx *sql.Context, f *sql.RowFrame) error { diff --git a/go/libraries/doltcore/sqle/index/prolly_row_iter.go b/go/libraries/doltcore/sqle/index/prolly_row_iter.go index 93f7571cf6..70cd24e840 100644 --- a/go/libraries/doltcore/sqle/index/prolly_row_iter.go +++ b/go/libraries/doltcore/sqle/index/prolly_row_iter.go @@ -15,7 +15,6 @@ package index import ( - "context" "strings" "github.com/dolthub/go-mysql-server/sql" @@ -52,6 +51,7 @@ var encodingToType [256]query.Type type prollyRowIter struct { iter prolly.MapIter + sqlSch sql.Schema keyDesc val.TupleDesc valDesc val.TupleDesc keyProj []int @@ -63,8 +63,8 @@ var _ sql.RowIter = prollyRowIter{} var _ sql.RowIter2 = prollyRowIter{} func NewProllyRowIter( - ctx context.Context, sch schema.Schema, + schSch sql.Schema, rows prolly.Map, iter prolly.MapIter, projections []string, @@ -91,6 +91,7 @@ func NewProllyRowIter( return prollyRowIter{ iter: iter, + sqlSch: schSch, keyDesc: kd, valDesc: vd, keyProj: keyProj, @@ -159,8 +160,7 @@ func (it prollyRowIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, err } } - - return row, nil + return DenormalizeRow(it.sqlSch, row) } func (it prollyRowIter) Next2(ctx *sql.Context, frame *sql.RowFrame) error { diff --git a/go/libraries/doltcore/sqle/privileges/file_handler.go b/go/libraries/doltcore/sqle/mysql_file_handler/file_handler.go similarity index 52% rename from go/libraries/doltcore/sqle/privileges/file_handler.go rename to go/libraries/doltcore/sqle/mysql_file_handler/file_handler.go index 4d3a701a31..2cae804ef8 100644 --- a/go/libraries/doltcore/sqle/privileges/file_handler.go +++ b/go/libraries/doltcore/sqle/mysql_file_handler/file_handler.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package privileges +package mysql_file_handler import ( "encoding/json" @@ -22,22 +22,46 @@ import ( "sync" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/grant_tables" + "github.com/dolthub/go-mysql-server/sql/mysql_db" ) -var ( - filePath string - fileMutex = &sync.Mutex{} -) +var fileMutex = &sync.Mutex{} +var mysqlDbFilePath string +var privsFilePath string // privDataJson is used to marshal/unmarshal the privilege data to/from JSON. type privDataJson struct { - Users []*grant_tables.User - Roles []*grant_tables.RoleEdge + Users []*mysql_db.User + Roles []*mysql_db.RoleEdge } -// SetFilePath sets the file path that will be used for saving and loading privileges. -func SetFilePath(fp string) { +// SetPrivilegeFilePath sets the file path that will be used for loading privileges. +func SetPrivilegeFilePath(fp string) { + // do nothing for empty file path + if len(fp) == 0 { + return + } + + fileMutex.Lock() + defer fileMutex.Unlock() + + _, err := os.Stat(fp) + if err != nil { + // Some strange unknown failure, okay to panic here + if !errors.Is(err, os.ErrNotExist) { + panic(err) + } + } + privsFilePath = fp +} + +// SetMySQLDbFilePath sets the file path that will be used for saving and loading MySQL Db tables. +func SetMySQLDbFilePath(fp string) { + // do nothing for empty file path + if len(fp) == 0 { + return + } + fileMutex.Lock() defer fileMutex.Unlock() @@ -53,21 +77,26 @@ func SetFilePath(fp string) { panic(err) } } - filePath = fp + mysqlDbFilePath = fp } // LoadPrivileges reads the file previously set on the file path and returns the privileges and role connections. If the // file path has not been set, returns an empty slice for both, but does not error. This is so that the logic path can // retain the calls regardless of whether a user wants privileges to be loaded or persisted. -func LoadPrivileges() ([]*grant_tables.User, []*grant_tables.RoleEdge, error) { - fileMutex.Lock() - defer fileMutex.Unlock() - if filePath == "" { +func LoadPrivileges() ([]*mysql_db.User, []*mysql_db.RoleEdge, error) { + // return nil for empty path + if len(privsFilePath) == 0 { return nil, nil, nil } - fileContents, err := ioutil.ReadFile(filePath) + fileMutex.Lock() + defer fileMutex.Unlock() + + fileContents, err := ioutil.ReadFile(privsFilePath) if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil, nil + } return nil, nil, err } if len(fileContents) == 0 { @@ -81,25 +110,34 @@ func LoadPrivileges() ([]*grant_tables.User, []*grant_tables.RoleEdge, error) { return data.Users, data.Roles, nil } -var _ grant_tables.PersistCallback = SavePrivileges +// LoadData reads the mysql.db file, returns nil if empty or not found +func LoadData() ([]byte, error) { + // return nil for empty path + if len(mysqlDbFilePath) == 0 { + return nil, nil + } -// SavePrivileges implements the interface grant_tables.PersistCallback. This is used to save privileges to disk. If the -// file path has not been previously set, this returns without error. This is so that the logic path can retain the -// calls regardless of whether a user wants privileges to be loaded or persisted. -func SavePrivileges(ctx *sql.Context, users []*grant_tables.User, roles []*grant_tables.RoleEdge) error { fileMutex.Lock() defer fileMutex.Unlock() - if filePath == "" { - return nil + + buf, err := ioutil.ReadFile(mysqlDbFilePath) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, err } - data := &privDataJson{ - Users: users, - Roles: roles, + if len(buf) == 0 { + return nil, nil } - jsonData, err := json.Marshal(data) - if err != nil { - return err - } - return ioutil.WriteFile(filePath, jsonData, 0777) + + return buf, nil +} + +var _ mysql_db.PersistCallback = SaveData + +// SaveData writes the provided []byte (in valid flatbuffer format) to the mysql db file +func SaveData(ctx *sql.Context, data []byte) error { + fileMutex.Lock() + defer fileMutex.Unlock() + + return ioutil.WriteFile(mysqlDbFilePath, data, 0777) } diff --git a/go/libraries/doltcore/sqle/rows.go b/go/libraries/doltcore/sqle/rows.go index 3df8743a2d..6df78f44dc 100644 --- a/go/libraries/doltcore/sqle/rows.go +++ b/go/libraries/doltcore/sqle/rows.go @@ -68,7 +68,7 @@ type doltTableRowIter struct { } // Returns a new row iterator for the table given -func newRowIterator(ctx context.Context, tbl *doltdb.Table, projCols []string, partition doltTablePartition) (sql.RowIter, error) { +func newRowIterator(ctx context.Context, tbl *doltdb.Table, sqlSch sql.Schema, projCols []string, partition doltTablePartition) (sql.RowIter, error) { sch, err := tbl.GetSchema(ctx) if err != nil { @@ -76,7 +76,7 @@ func newRowIterator(ctx context.Context, tbl *doltdb.Table, projCols []string, p } if types.IsFormat_DOLT_1(tbl.Format()) { - return ProllyRowIterFromPartition(ctx, tbl, projCols, partition) + return ProllyRowIterFromPartition(ctx, tbl, sqlSch, projCols, partition) } if schema.IsKeyless(sch) { @@ -168,7 +168,13 @@ func (itr *doltTableRowIter) Close(*sql.Context) error { return nil } -func ProllyRowIterFromPartition(ctx context.Context, tbl *doltdb.Table, projections []string, partition doltTablePartition) (sql.RowIter, error) { +func ProllyRowIterFromPartition( + ctx context.Context, + tbl *doltdb.Table, + sqlSch sql.Schema, + projections []string, + partition doltTablePartition, +) (sql.RowIter, error) { rows := durable.ProllyMapFromIndex(partition.rowData) sch, err := tbl.GetSchema(ctx) if err != nil { @@ -183,7 +189,7 @@ func ProllyRowIterFromPartition(ctx context.Context, tbl *doltdb.Table, projecti return nil, err } - return index.NewProllyRowIter(ctx, sch, rows, iter, projections) + return index.NewProllyRowIter(sch, sqlSch, rows, iter, projections) } // TableToRowIter returns a |sql.RowIter| for a full table scan for the given |table|. If @@ -208,6 +214,7 @@ func TableToRowIter(ctx *sql.Context, table *WritableDoltTable, columns []string end: NoUpperBound, rowData: data, } + sqlSch := table.sqlSch.Schema - return newRowIterator(ctx, t, columns, p) + return newRowIterator(ctx, t, sqlSch, columns, p) } diff --git a/go/libraries/doltcore/sqle/sqlutil/convert.go b/go/libraries/doltcore/sqle/sqlutil/convert.go index 541c9975ec..b19e55a8d7 100644 --- a/go/libraries/doltcore/sqle/sqlutil/convert.go +++ b/go/libraries/doltcore/sqle/sqlutil/convert.go @@ -124,7 +124,7 @@ func ToDoltSchema( // ToDoltCol returns the dolt column corresponding to the SQL column given func ToDoltCol(tag uint64, col *sql.Column) (schema.Column, error) { var constraints []schema.ColConstraint - if !col.Nullable { + if !col.Nullable || col.PrimaryKey { constraints = append(constraints, schema.NotNullConstraint{}) } typeInfo, err := typeinfo.FromSqlType(col.Type) diff --git a/go/libraries/doltcore/sqle/tables.go b/go/libraries/doltcore/sqle/tables.go index 70a78b2226..a46da0c369 100644 --- a/go/libraries/doltcore/sqle/tables.go +++ b/go/libraries/doltcore/sqle/tables.go @@ -325,7 +325,7 @@ func (t *DoltTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sq return nil, err } - return partitionRows(ctx, table, t.projectedCols, partition) + return partitionRows(ctx, table, t.sqlSch.Schema, t.projectedCols, partition) } func (t DoltTable) PartitionRows2(ctx *sql.Context, part sql.Partition) (sql.RowIter2, error) { @@ -334,7 +334,7 @@ func (t DoltTable) PartitionRows2(ctx *sql.Context, part sql.Partition) (sql.Row return nil, err } - iter, err := partitionRows(ctx, table, t.projectedCols, part) + iter, err := partitionRows(ctx, table, t.sqlSch.Schema, t.projectedCols, part) if err != nil { return nil, err } @@ -342,12 +342,12 @@ func (t DoltTable) PartitionRows2(ctx *sql.Context, part sql.Partition) (sql.Row return iter.(sql.RowIter2), err } -func partitionRows(ctx *sql.Context, t *doltdb.Table, projCols []string, partition sql.Partition) (sql.RowIter, error) { +func partitionRows(ctx *sql.Context, t *doltdb.Table, sqlSch sql.Schema, projCols []string, partition sql.Partition) (sql.RowIter, error) { switch typedPartition := partition.(type) { case doltTablePartition: - return newRowIterator(ctx, t, projCols, typedPartition) + return newRowIterator(ctx, t, sqlSch, projCols, typedPartition) case index.SinglePartition: - return newRowIterator(ctx, t, projCols, doltTablePartition{rowData: typedPartition.RowData, end: NoUpperBound}) + return newRowIterator(ctx, t, sqlSch, projCols, doltTablePartition{rowData: typedPartition.RowData, end: NoUpperBound}) } return nil, errors.New("unsupported partition type") @@ -961,10 +961,17 @@ func (t *AlterableDoltTable) ShouldRewriteTable( modifiedColumn *sql.Column, ) bool { // TODO: this could be a lot more specific, we don't always need to rewrite on schema changes in the new format - return types.IsFormat_DOLT_1(t.nbf) || len(oldSchema.Schema) < len(newSchema.Schema) + return types.IsFormat_DOLT_1(t.nbf) || + len(oldSchema.Schema) < len(newSchema.Schema) || + (len(newSchema.PkOrdinals) != len(oldSchema.PkOrdinals)) } -func (t *AlterableDoltTable) RewriteInserter(ctx *sql.Context, newSchema sql.PrimaryKeySchema) (sql.RowInserter, error) { +func (t *AlterableDoltTable) RewriteInserter( + ctx *sql.Context, + oldSchema sql.PrimaryKeySchema, + newSchema sql.PrimaryKeySchema, + modifiedColumn *sql.Column, +) (sql.RowInserter, error) { sess := dsess.DSessFromSess(ctx.Session) // Begin by creating a new table with the same name and the new schema, then removing all its existing rows @@ -1009,7 +1016,7 @@ func (t *AlterableDoltTable) RewriteInserter(ctx *sql.Context, newSchema sql.Pri return nil, err } - // If we have an auto increment column, we need to set it here before we begin the rewrite process + // If we have an auto increment column, we need to set it here before we begin the rewrite process (it may have changed) if schema.HasAutoIncrement(newSch) { newSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { if col.AutoIncrement { @@ -1035,6 +1042,13 @@ func (t *AlterableDoltTable) RewriteInserter(ctx *sql.Context, newSchema sql.Pri return nil, err } + if len(oldSchema.PkOrdinals) > 0 && len(newSchema.PkOrdinals) == 0 { + newRoot, err = t.adjustForeignKeysForDroppedPk(ctx, newRoot) + if err != nil { + return nil, err + } + } + newWs := ws.WithWorkingRoot(newRoot) // TODO: figure out locking. Other DBs automatically lock a table during this kind of operation, we should probably @@ -1053,6 +1067,34 @@ func (t *AlterableDoltTable) RewriteInserter(ctx *sql.Context, newSchema sql.Pri return ed, nil } +func (t *AlterableDoltTable) adjustForeignKeysForDroppedPk(ctx *sql.Context, root *doltdb.RootValue) (*doltdb.RootValue, error) { + if t.autoIncCol.AutoIncrement { + return nil, sql.ErrWrongAutoKey.New() + } + + fkc, err := root.GetForeignKeyCollection(ctx) + if err != nil { + return nil, err + } + + fkcUpdates, err := backupFkcIndexesForPkDrop(ctx, t.sch, fkc) + if err != nil { + return nil, err + } + + err = fkc.UpdateIndexes(ctx, t.sch, fkcUpdates) + if err != nil { + return nil, err + } + + root, err = root.PutForeignKeyCollection(ctx, fkc) + if err != nil { + return nil, err + } + + return root, nil +} + // DropColumn implements sql.AlterableTable func (t *AlterableDoltTable) DropColumn(ctx *sql.Context, columnName string) error { if types.IsFormat_DOLT_1(t.nbf) { @@ -1232,7 +1274,7 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c // Note that we aren't calling the public PartitionRows, because it always gets the table data from the session // root, which hasn't been updated yet - rowIter, err := partitionRows(ctx, updatedTable, t.projectedCols, index.SinglePartition{RowData: rowData}) + rowIter, err := partitionRows(ctx, updatedTable, t.sqlSch.Schema, t.projectedCols, index.SinglePartition{RowData: rowData}) if err != nil { return err } @@ -2184,12 +2226,22 @@ func (t *AlterableDoltTable) DropPrimaryKey(ctx *sql.Context) error { return err } - fkcUpdates, err := t.backupFkcIndexesForPkDrop(ctx, root) + fkc, err := root.GetForeignKeyCollection(ctx) if err != nil { return err } - newRoot, err := t.updateFkcIndex(ctx, root, fkcUpdates) + fkcUpdates, err := backupFkcIndexesForPkDrop(ctx, t.sch, fkc) + if err != nil { + return err + } + + err = fkc.UpdateIndexes(ctx, t.sch, fkcUpdates) + if err != nil { + return err + } + + newRoot, err := root.PutForeignKeyCollection(ctx, fkc) if err != nil { return err } @@ -2218,109 +2270,6 @@ func (t *AlterableDoltTable) DropPrimaryKey(ctx *sql.Context) error { return t.updateFromRoot(ctx, newRoot) } -type fkIndexUpdate struct { - fkName string - fromIdx string - toIdx string -} - -// updateFkcIndex applies a list of fkIndexUpdates to a ForeignKeyCollection and returns a new root value -func (t *AlterableDoltTable) updateFkcIndex(ctx *sql.Context, root *doltdb.RootValue, updates []fkIndexUpdate) (*doltdb.RootValue, error) { - fkc, err := root.GetForeignKeyCollection(ctx) - if err != nil { - return nil, err - } - - for _, u := range updates { - fk, ok := fkc.GetByNameCaseInsensitive(u.fkName) - if !ok { - return nil, errors.New("foreign key not found") - } - fkc.RemoveKeys(fk) - fk.ReferencedTableIndex = u.toIdx - fkc.AddKeys(fk) - err := fk.ValidateReferencedTableSchema(t.sch) - if err != nil { - return nil, err - } - } - - root, err = root.PutForeignKeyCollection(ctx, fkc) - if err != nil { - return nil, err - } - return root, nil -} - -// backupFkcIndexesForKeyDrop finds backup indexes to cover foreign key references during a primary -// key drop. If multiple indexes are valid, we sort by unique and select the first. -// This will not work with a non-pk index drop without an additional index filter argument. -func (t *AlterableDoltTable) backupFkcIndexesForPkDrop(ctx *sql.Context, root *doltdb.RootValue) ([]fkIndexUpdate, error) { - fkc, err := root.GetForeignKeyCollection(ctx) - if err != nil { - return nil, err - } - - indexes := t.sch.Indexes().AllIndexes() - if err != nil { - return nil, err - } - - // pkBackups is a mapping from the table's PK tags to potentially compensating indexes - pkBackups := make(map[uint64][]schema.Index, len(t.sch.GetPKCols().TagToIdx)) - for tag, _ := range t.sch.GetPKCols().TagToIdx { - pkBackups[tag] = nil - } - - // prefer unique key backups - sort.Slice(indexes[:], func(i, j int) bool { - return indexes[i].IsUnique() && !indexes[j].IsUnique() - }) - - for _, idx := range indexes { - if !idx.IsUserDefined() { - continue - } - - for _, tag := range idx.AllTags() { - if _, ok := pkBackups[tag]; ok { - pkBackups[tag] = append(pkBackups[tag], idx) - } - } - } - - fkUpdates := make([]fkIndexUpdate, 0) - for _, fk := range fkc.AllKeys() { - // check if this FK references a parent PK tag we are trying to change - if backups, ok := pkBackups[fk.ReferencedTableColumns[0]]; ok { - covered := false - for _, idx := range backups { - idxTags := idx.AllTags() - if len(fk.TableColumns) > len(idxTags) { - continue - } - failed := false - for i := 0; i < len(fk.ReferencedTableColumns); i++ { - if idxTags[i] != fk.ReferencedTableColumns[i] { - failed = true - break - } - } - if failed { - continue - } - fkUpdates = append(fkUpdates, fkIndexUpdate{fk.Name, fk.ReferencedTableIndex, idx.Name()}) - covered = true - break - } - if !covered { - return nil, sql.ErrCantDropIndex.New("PRIMARY") - } - } - } - return fkUpdates, nil -} - func findIndexWithPrefix(sch schema.Schema, prefixCols []string) (schema.Index, bool, error) { type idxWithLen struct { schema.Index diff --git a/go/libraries/doltcore/sqle/temp_table.go b/go/libraries/doltcore/sqle/temp_table.go index c0d7c2c5b7..ef2c06bb72 100644 --- a/go/libraries/doltcore/sqle/temp_table.go +++ b/go/libraries/doltcore/sqle/temp_table.go @@ -152,7 +152,7 @@ func (t *TempTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sq if t.lookup != nil { return index.RowIterForIndexLookup(ctx, t.table, t.lookup, t.pkSch, nil) } else { - return partitionRows(ctx, t.table, nil, partition) + return partitionRows(ctx, t.table, t.sqlSchema().Schema, nil, partition) } } diff --git a/go/libraries/doltcore/sqle/writer/prolly_fk_indexer.go b/go/libraries/doltcore/sqle/writer/prolly_fk_indexer.go index 4584aa1d81..8db0ceb29f 100644 --- a/go/libraries/doltcore/sqle/writer/prolly_fk_indexer.go +++ b/go/libraries/doltcore/sqle/writer/prolly_fk_indexer.go @@ -86,6 +86,7 @@ func (n prollyFkIndexer) PartitionRows(ctx *sql.Context, _ sql.Partition) (sql.R rangeIter: rangeIter, idxToPkMap: idxToPkMap, primary: primary, + sqlSch: n.writer.sqlSch, }, nil } else { rangeIter, err := idxWriter.(prollyKeylessSecondaryWriter).mut.IterRange(ctx, n.pRange) @@ -95,6 +96,7 @@ func (n prollyFkIndexer) PartitionRows(ctx *sql.Context, _ sql.Partition) (sql.R return &prollyFkKeylessRowIter{ rangeIter: rangeIter, primary: n.writer.primary.(prollyKeylessWriter), + sqlSch: n.writer.sqlSch, }, nil } } @@ -104,6 +106,7 @@ type prollyFkPkRowIter struct { rangeIter prolly.MapIter idxToPkMap map[int]int primary prollyIndexWriter + sqlSch sql.Schema } var _ sql.RowIter = prollyFkPkRowIter{} @@ -140,7 +143,10 @@ func (iter prollyFkPkRowIter) Next(ctx *sql.Context) (sql.Row, error) { } return nil }) - return nextRow, err + if err != nil { + return nil, err + } + return index.DenormalizeRow(iter.sqlSch, nextRow) } // Close implements the interface sql.RowIter. @@ -152,6 +158,7 @@ func (iter prollyFkPkRowIter) Close(ctx *sql.Context) error { type prollyFkKeylessRowIter struct { rangeIter prolly.MapIter primary prollyKeylessWriter + sqlSch sql.Schema } var _ sql.RowIter = prollyFkKeylessRowIter{} @@ -179,7 +186,10 @@ func (iter prollyFkKeylessRowIter) Next(ctx *sql.Context) (sql.Row, error) { } return nil }) - return nextRow, err + if err != nil { + return nil, err + } + return index.DenormalizeRow(iter.sqlSch, nextRow) } // Close implements the interface sql.RowIter. diff --git a/go/libraries/doltcore/sqle/writer/prolly_table_writer.go b/go/libraries/doltcore/sqle/writer/prolly_table_writer.go index 12aec71750..98aec958f0 100644 --- a/go/libraries/doltcore/sqle/writer/prolly_table_writer.go +++ b/go/libraries/doltcore/sqle/writer/prolly_table_writer.go @@ -122,7 +122,14 @@ func getSecondaryKeylessProllyWriters(ctx context.Context, t *doltdb.Table, sqlS } // Insert implements TableWriter. -func (w *prollyTableWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error { +func (w *prollyTableWriter) Insert(ctx *sql.Context, sqlRow sql.Row) (err error) { + if sqlRow, err = index.NormalizeRow(w.sqlSch, sqlRow); err != nil { + return err + } + + if err := w.primary.Insert(ctx, sqlRow); err != nil { + return err + } for _, wr := range w.secondary { if err := wr.Insert(ctx, sqlRow); err != nil { if sql.ErrUniqueKeyViolation.Is(err) { @@ -131,14 +138,15 @@ func (w *prollyTableWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error { return err } } - if err := w.primary.Insert(ctx, sqlRow); err != nil { - return err - } return nil } // Delete implements TableWriter. -func (w *prollyTableWriter) Delete(ctx *sql.Context, sqlRow sql.Row) error { +func (w *prollyTableWriter) Delete(ctx *sql.Context, sqlRow sql.Row) (err error) { + if sqlRow, err = index.NormalizeRow(w.sqlSch, sqlRow); err != nil { + return err + } + for _, wr := range w.secondary { if err := wr.Delete(ctx, sqlRow); err != nil { return err @@ -152,6 +160,13 @@ func (w *prollyTableWriter) Delete(ctx *sql.Context, sqlRow sql.Row) error { // Update implements TableWriter. func (w *prollyTableWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) (err error) { + if oldRow, err = index.NormalizeRow(w.sqlSch, oldRow); err != nil { + return err + } + if newRow, err = index.NormalizeRow(w.sqlSch, newRow); err != nil { + return err + } + for _, wr := range w.secondary { if err := wr.Update(ctx, oldRow, newRow); err != nil { if sql.ErrUniqueKeyViolation.Is(err) { @@ -244,8 +259,9 @@ func (w *prollyTableWriter) Reset(ctx context.Context, sess *prollyWriteSession, } aiCol := autoIncrementColFromSchema(sch) var newPrimary indexWriter + var newSecondaries []indexWriter - if _, ok := w.primary.(prollyKeylessWriter); ok { + if schema.IsKeyless(sch) { newPrimary, err = getPrimaryKeylessProllyWriter(ctx, tbl, sqlSch.Schema, sch) if err != nil { return err @@ -324,9 +340,6 @@ func (w *prollyTableWriter) table(ctx context.Context) (t *doltdb.Table, err err } func (w *prollyTableWriter) flush(ctx *sql.Context) error { - if !w.primary.HasEdits(ctx) { - return nil - } ws, err := w.flusher.Flush(ctx) if err != nil { return err diff --git a/go/libraries/doltcore/sqle/writer/prolly_write_session.go b/go/libraries/doltcore/sqle/writer/prolly_write_session.go index dc526384e7..0ff8f18e26 100644 --- a/go/libraries/doltcore/sqle/writer/prolly_write_session.go +++ b/go/libraries/doltcore/sqle/writer/prolly_write_session.go @@ -189,6 +189,7 @@ func (s *prollyWriteSession) flush(ctx context.Context) (*doltdb.WorkingSet, err return nil, err } } + s.workingSet = s.workingSet.WithWorkingRoot(flushed) return s.workingSet, nil diff --git a/go/serial/encoding.fbs b/go/serial/encoding.fbs index 04f29d8443..0179b0f65f 100644 --- a/go/serial/encoding.fbs +++ b/go/serial/encoding.fbs @@ -27,11 +27,14 @@ enum Encoding : uint8 { Uint64 = 10, Float32 = 11, Float64 = 12, - Hash128 = 13, - Year = 14, - Date = 15, - Time = 16, - Datetime = 17, + Bit64 = 13, + Hash128 = 14, + Year = 15, + Date = 16, + Time = 17, + Datetime = 18, + Enum = 19, + Set = 20, // variable width String = 128, @@ -40,4 +43,3 @@ enum Encoding : uint8 { JSON = 131, Geometry = 133, } - diff --git a/go/store/cmd/noms/noms_show.go b/go/store/cmd/noms/noms_show.go index 498f55c891..6e6023a17f 100644 --- a/go/store/cmd/noms/noms_show.go +++ b/go/store/cmd/noms/noms_show.go @@ -31,8 +31,10 @@ import ( flag "github.com/juju/gnuflag" + "github.com/dolthub/dolt/go/gen/fb/serial" "github.com/dolthub/dolt/go/store/cmd/noms/util" "github.com/dolthub/dolt/go/store/config" + "github.com/dolthub/dolt/go/store/hash" "github.com/dolthub/dolt/go/store/prolly" "github.com/dolthub/dolt/go/store/prolly/tree" "github.com/dolthub/dolt/go/store/types" @@ -139,9 +141,35 @@ func outputType(value interface{}) { case tree.Node: typeString = "prolly.Node" case types.Value: - t, err := types.TypeOf(value) - typeString = t.HumanReadableString() - util.CheckError(err) + switch value := value.(type) { + case types.SerialMessage: + switch serial.GetFileID(value) { + case serial.StoreRootFileID: + typeString = "StoreRoot" + case serial.TagFileID: + typeString = "Tag" + case serial.WorkingSetFileID: + typeString = "WorkingSet" + case serial.CommitFileID: + typeString = "Commit" + case serial.RootValueFileID: + typeString = "RootValue" + case serial.TableFileID: + typeString = "TableFile" + case serial.ProllyTreeNodeFileID: + typeString = "ProllyTreeNode" + case serial.AddressMapFileID: + typeString = "AddressMap" + default: + t, err := types.TypeOf(value) + typeString = t.HumanReadableString() + util.CheckError(err) + } + default: + t, err := types.TypeOf(value) + typeString = t.HumanReadableString() + util.CheckError(err) + } default: typeString = fmt.Sprintf("unknown type %T", value) } @@ -156,7 +184,46 @@ func outputEncodedValue(ctx context.Context, w io.Writer, value interface{}) err case tree.Node: return tree.OutputProllyNode(w, value) case types.Value: - return types.WriteEncodedValue(ctx, w, value) + switch value := value.(type) { + // Some types of serial message need to be output here because of dependency cycles between type / tree package + case types.SerialMessage: + switch serial.GetFileID(value) { + case serial.TableFileID: + msg := serial.GetRootAsTable(value, 0) + + fmt.Fprintf(w, "{\n") + fmt.Fprintf(w, "\tSchema: #%s\n", hash.New(msg.SchemaBytes()).String()) + fmt.Fprintf(w, "\tViolations: #%s\n", hash.New(msg.ViolationsBytes()).String()) + // TODO: merge conflicts, not stable yet + + fmt.Fprintf(w, "\tAutoinc: %d\n", msg.AutoIncrementValue()) + + fmt.Fprintf(w, "\tPrimary index: {\n") + node := tree.NodeFromBytes(msg.PrimaryIndexBytes()) + tree.OutputProllyNode(w, node) + fmt.Fprintf(w, "\t}\n") + + fmt.Fprintf(w, "\tSecondary indexes: {\n") + idxRefs := msg.SecondaryIndexes(nil) + hashes := idxRefs.RefArrayBytes() + for i := 0; i < idxRefs.NamesLength(); i++ { + name := idxRefs.Names(i) + addr := hash.New(hashes[i*20 : (i+1)*20]) + fmt.Fprintf(w, "\t\t%s: #%s\n", name, addr.String()) + } + fmt.Fprintf(w, "\t}\n") + fmt.Fprintf(w, "}") + + return nil + case serial.ProllyTreeNodeFileID: + node := prolly.NodeFromValue(value) + return tree.OutputProllyNode(w, node) + default: + return types.WriteEncodedValue(ctx, w, value) + } + default: + return types.WriteEncodedValue(ctx, w, value) + } default: _, err := w.Write([]byte(fmt.Sprintf("unknown value type %T: %v", value, value))) return err diff --git a/go/store/cmd/noms/splunk.pl b/go/store/cmd/noms/splunk.pl index 1f2afacbe1..251e135d79 100755 --- a/go/store/cmd/noms/splunk.pl +++ b/go/store/cmd/noms/splunk.pl @@ -85,10 +85,15 @@ sub print_show { my $noms_show_output = show($hash); for my $line (split /\n/, $noms_show_output) { - if ($line =~ /#([a-z0-9]{32})/) { - $hashes{$label} = $1; - print "$label) $line\n"; - $label++; + if ($line =~ /#([a-z0-9]{32})/ ) { + $h = $1; + if ( $1 =~ /[a-z1-9]/ ) { + $hashes{$label} = $h; + print "$label) $line\n"; + $label++; + } else { + print " $line\n"; + } } else { print " $line\n"; } diff --git a/go/store/prolly/shim.go b/go/store/prolly/shim.go index 01dc88701b..56e5e9a1fd 100644 --- a/go/store/prolly/shim.go +++ b/go/store/prolly/shim.go @@ -132,20 +132,11 @@ func encodingFromSqlType(typ query.Type) val.Encoding { // todo(andy): replace temp encodings switch typ { - case query.Type_DECIMAL: - return val.DecimalEnc - case query.Type_GEOMETRY: - return val.GeometryEnc - case query.Type_BIT: - return val.Uint64Enc case query.Type_BLOB: - return val.ByteStringEnc + // todo: temporary hack for enginetests + return val.StringEnc case query.Type_TEXT: return val.StringEnc - case query.Type_ENUM: - return val.StringEnc - case query.Type_SET: - return val.StringEnc case query.Type_JSON: return val.JSONEnc } @@ -175,6 +166,10 @@ func encodingFromSqlType(typ query.Type) val.Encoding { return val.Float32Enc case query.Type_FLOAT64: return val.Float64Enc + case query.Type_BIT: + return val.Uint64Enc + case query.Type_DECIMAL: + return val.DecimalEnc case query.Type_YEAR: return val.YearEnc case query.Type_DATE: @@ -185,6 +180,10 @@ func encodingFromSqlType(typ query.Type) val.Encoding { return val.DatetimeEnc case query.Type_DATETIME: return val.DatetimeEnc + case query.Type_ENUM: + return val.EnumEnc + case query.Type_SET: + return val.SetEnc case query.Type_BINARY: return val.ByteStringEnc case query.Type_VARBINARY: @@ -193,6 +192,8 @@ func encodingFromSqlType(typ query.Type) val.Encoding { return val.StringEnc case query.Type_VARCHAR: return val.StringEnc + case query.Type_GEOMETRY: + return val.GeometryEnc default: panic(fmt.Sprintf("unknown encoding %v", typ)) } diff --git a/go/store/types/serial_message.go b/go/store/types/serial_message.go index fac0182a85..6d7de572d9 100644 --- a/go/store/types/serial_message.go +++ b/go/store/types/serial_message.go @@ -20,6 +20,7 @@ import ( "fmt" "math" "strings" + "time" "github.com/dolthub/dolt/go/gen/fb/serial" "github.com/dolthub/dolt/go/store/hash" @@ -51,7 +52,8 @@ func (sm SerialMessage) Hash(nbf *NomsBinFormat) (hash.Hash, error) { } func (sm SerialMessage) HumanReadableString() string { - if serial.GetFileID([]byte(sm)) == serial.StoreRootFileID { + switch serial.GetFileID(sm) { + case serial.StoreRootFileID: msg := serial.GetRootAsStoreRoot([]byte(sm), 0) ret := &strings.Builder{} refs := msg.Refs(nil) @@ -59,13 +61,103 @@ func (sm SerialMessage) HumanReadableString() string { hashes := refs.RefArrayBytes() for i := 0; i < refs.NamesLength(); i++ { name := refs.Names(i) - addr := hash.New(hashes[:20]) - fmt.Fprintf(ret, " %s: %s\n", name, addr.String()) + addr := hash.New(hashes[i*20 : (i+1)*20]) + fmt.Fprintf(ret, "\t%s: #%s\n", name, addr.String()) } fmt.Fprintf(ret, "}") return ret.String() + case serial.TagFileID: + return "Tag" + case serial.WorkingSetFileID: + msg := serial.GetRootAsWorkingSet(sm, 0) + ret := &strings.Builder{} + fmt.Fprintf(ret, "{\n") + fmt.Fprintf(ret, "\tName: %s\n", msg.Name()) + fmt.Fprintf(ret, "\tDesc: %s\n", msg.Desc()) + fmt.Fprintf(ret, "\tEmail: %s\n", msg.Email()) + fmt.Fprintf(ret, "\tTime: %s\n", time.UnixMilli((int64)(msg.TimestampMillis())).String()) + fmt.Fprintf(ret, "\tWorkingRootAddr: #%s\n", hash.New(msg.WorkingRootAddrBytes()).String()) + fmt.Fprintf(ret, "\tStagedRootAddr: #%s\n", hash.New(msg.StagedRootAddrBytes()).String()) + fmt.Fprintf(ret, "}") + return ret.String() + case serial.CommitFileID: + msg := serial.GetRootAsCommit(sm, 0) + ret := &strings.Builder{} + fmt.Fprintf(ret, "{\n") + fmt.Fprintf(ret, "\tName: %s\n", msg.Name()) + fmt.Fprintf(ret, "\tDesc: %s\n", msg.Description()) + fmt.Fprintf(ret, "\tEmail: %s\n", msg.Email()) + fmt.Fprintf(ret, "\tTime: %s\n", time.UnixMilli((int64)(msg.TimestampMillis())).String()) + fmt.Fprintf(ret, "\tHeight: %d\n", msg.Height()) + + fmt.Fprintf(ret, "\tParents: {\n") + hashes := msg.ParentAddrsBytes() + for i := 0; i < msg.ParentAddrsLength()/hash.ByteLen; i++ { + addr := hash.New(hashes[i*20 : (i+1)*20]) + fmt.Fprintf(ret, "\t\t#%s\n", addr.String()) + } + fmt.Fprintf(ret, "\t}\n") + + fmt.Fprintf(ret, "\tParentClosure: {\n") + hashes = msg.ParentClosureBytes() + for i := 0; i < msg.ParentClosureLength()/hash.ByteLen; i++ { + addr := hash.New(hashes[i*20 : (i+1)*20]) + fmt.Fprintf(ret, "\t\t#%s\n", addr.String()) + } + fmt.Fprintf(ret, "\t}\n") + + fmt.Fprintf(ret, "}") + return ret.String() + case serial.RootValueFileID: + msg := serial.GetRootAsRootValue(sm, 0) + ret := &strings.Builder{} + fmt.Fprintf(ret, "{\n") + fmt.Fprintf(ret, "\tFeatureVersion: %d\n", msg.FeatureVersion()) + fmt.Fprintf(ret, "\tForeignKeys: #%s\n", hash.New(msg.ForeignKeyAddrBytes()).String()) + fmt.Fprintf(ret, "\tSuperSchema: #%s\n", hash.New(msg.SuperSchemasAddrBytes()).String()) + fmt.Fprintf(ret, "\tTables: {\n") + tableRefs := msg.Tables(nil) + hashes := tableRefs.RefArrayBytes() + for i := 0; i < tableRefs.NamesLength(); i++ { + name := tableRefs.Names(i) + addr := hash.New(hashes[i*20 : (i+1)*20]) + fmt.Fprintf(ret, "\t\t%s: #%s\n", name, addr.String()) + } + fmt.Fprintf(ret, "\t}\n") + fmt.Fprintf(ret, "}") + return ret.String() + case serial.TableFileID: + msg := serial.GetRootAsTable(sm, 0) + ret := &strings.Builder{} + + fmt.Fprintf(ret, "{\n") + fmt.Fprintf(ret, "\tSchema: #%s\n", hash.New(msg.SchemaBytes()).String()) + fmt.Fprintf(ret, "\tViolations: #%s\n", hash.New(msg.ViolationsBytes()).String()) + // TODO: merge conflicts, not stable yet + + fmt.Fprintf(ret, "\tAutoinc: %d\n", msg.AutoIncrementValue()) + + // TODO: can't use tree package to print here, creates a cycle + fmt.Fprintf(ret, "\tPrimary index: prolly tree\n") + + fmt.Fprintf(ret, "\tSecondary indexes: {\n") + idxRefs := msg.SecondaryIndexes(nil) + hashes := idxRefs.RefArrayBytes() + for i := 0; i < idxRefs.NamesLength(); i++ { + name := idxRefs.Names(i) + addr := hash.New(hashes[i*20 : (i+1)*20]) + fmt.Fprintf(ret, "\t\t%s: #%s\n", name, addr.String()) + } + fmt.Fprintf(ret, "\t}\n") + fmt.Fprintf(ret, "}") + return ret.String() + case serial.ProllyTreeNodeFileID: + return "ProllyTreeNode" + case serial.AddressMapFileID: + return "AddressMap" + default: + return "SerialMessage (HumanReadableString not implemented)" } - return "SerialMessage" } func (sm SerialMessage) Less(nbf *NomsBinFormat, other LesserValuable) (bool, error) { diff --git a/go/store/val/codec.go b/go/store/val/codec.go index bd79be50e3..2d942a600f 100644 --- a/go/store/val/codec.go +++ b/go/store/val/codec.go @@ -50,11 +50,14 @@ const ( uint64Size ByteSize = 8 float32Size ByteSize = 4 float64Size ByteSize = 8 + bit64Size ByteSize = 8 hash128Size ByteSize = 16 yearSize ByteSize = 1 dateSize ByteSize = 4 timeSize ByteSize = 8 datetimeSize ByteSize = 8 + enumSize ByteSize = 2 + setSize ByteSize = 8 ) type Encoding byte @@ -72,11 +75,14 @@ const ( Uint64Enc = Encoding(serial.EncodingUint64) Float32Enc = Encoding(serial.EncodingFloat32) Float64Enc = Encoding(serial.EncodingFloat64) + Bit64Enc = Encoding(serial.EncodingBit64) Hash128Enc = Encoding(serial.EncodingHash128) YearEnc = Encoding(serial.EncodingYear) DateEnc = Encoding(serial.EncodingDate) TimeEnc = Encoding(serial.EncodingTime) DatetimeEnc = Encoding(serial.EncodingDatetime) + EnumEnc = Encoding(serial.EncodingEnum) + SetEnc = Encoding(serial.EncodingSet) sentinel Encoding = 127 ) @@ -121,16 +127,22 @@ func sizeFromType(t Type) (ByteSize, bool) { return float32Size, true case Float64Enc: return float64Size, true + case Hash128Enc: + return hash128Size, true case YearEnc: return yearSize, true case DateEnc: return dateSize, true - //case TimeEnc: - // return timeSize, true + case TimeEnc: + return timeSize, true case DatetimeEnc: return datetimeSize, true - case Hash128Enc: - return hash128Size, true + case EnumEnc: + return enumSize, true + case SetEnc: + return setSize, true + case Bit64Enc: + return bit64Size, true default: return 0, false } @@ -361,6 +373,18 @@ func compareFloat64(l, r float64) int { } } +func readBit64(val []byte) uint64 { + return readUint64(val) +} + +func writeBit64(buf []byte, val uint64) { + writeUint64(buf, val) +} + +func compareBit64(l, r uint64) int { + return compareUint64(l, r) +} + func readDecimal(val []byte) decimal.Decimal { e := readInt32(val[:int32Size]) s := readInt8(val[int32Size : int32Size+int8Size]) @@ -469,6 +493,30 @@ func compareDatetime(l, r time.Time) int { } } +func readEnum(val []byte) uint16 { + return readUint16(val) +} + +func writeEnum(buf []byte, val uint16) { + writeUint16(buf, val) +} + +func compareEnum(l, r uint16) int { + return compareUint16(l, r) +} + +func readSet(val []byte) uint64 { + return readUint64(val) +} + +func writeSet(buf []byte, val uint64) { + writeUint64(buf, val) +} + +func compareSet(l, r uint64) int { + return compareUint64(l, r) +} + func readString(val []byte) string { return stringFromBytes(readByteString(val)) } diff --git a/go/store/val/codec_test.go b/go/store/val/codec_test.go index 7c54906e55..a204d813fd 100644 --- a/go/store/val/codec_test.go +++ b/go/store/val/codec_test.go @@ -78,6 +78,22 @@ func TestCompare(t *testing.T) { l: encFloat(1), r: encFloat(0), cmp: 1, }, + // bit + { + typ: Type{Enc: Bit64Enc}, + l: encBit(0), r: encBit(0), + cmp: 0, + }, + { + typ: Type{Enc: Bit64Enc}, + l: encBit(0), r: encBit(1), + cmp: -1, + }, + { + typ: Type{Enc: Bit64Enc}, + l: encBit(1), r: encBit(0), + cmp: 1, + }, // decimal { typ: Type{Enc: DecimalEnc}, @@ -161,6 +177,38 @@ func TestCompare(t *testing.T) { r: encDatetime(time.Date(2000, 11, 01, 01, 01, 01, 00, time.UTC)), cmp: -1, }, + // enum + { + typ: Type{Enc: EnumEnc}, + l: encEnum(0), r: encEnum(0), + cmp: 0, + }, + { + typ: Type{Enc: EnumEnc}, + l: encEnum(0), r: encEnum(1), + cmp: -1, + }, + { + typ: Type{Enc: EnumEnc}, + l: encEnum(1), r: encEnum(0), + cmp: 1, + }, + // set + { + typ: Type{Enc: SetEnc}, + l: encSet(0), r: encSet(0), + cmp: 0, + }, + { + typ: Type{Enc: SetEnc}, + l: encSet(0), r: encSet(1), + cmp: -1, + }, + { + typ: Type{Enc: SetEnc}, + l: encSet(1), r: encSet(0), + cmp: 1, + }, // string { typ: Type{Enc: StringEnc}, @@ -231,6 +279,12 @@ func encFloat(f float64) []byte { return buf } +func encBit(u uint64) []byte { + buf := make([]byte, bit64Size) + writeBit64(buf, u) + return buf +} + func encDecimal(d decimal.Decimal) []byte { buf := make([]byte, sizeOfDecimal(d)) writeDecimal(buf, d) @@ -268,6 +322,18 @@ func encDatetime(dt time.Time) []byte { return buf } +func encEnum(u uint16) []byte { + buf := make([]byte, enumSize) + writeEnum(buf, u) + return buf +} + +func encSet(u uint64) []byte { + buf := make([]byte, setSize) + writeSet(buf, u) + return buf +} + func TestCodecRoundTrip(t *testing.T) { t.Run("round trip bool", func(t *testing.T) { roundTripBools(t) @@ -365,6 +431,14 @@ func roundTripUints(t *testing.T) { zero(buf) } + buf = make([]byte, enumSize) + for _, value := range uintegers { + exp := uint16(value) + writeEnum(buf, exp) + assert.Equal(t, exp, readEnum(buf)) + zero(buf) + } + buf = make([]byte, uint32Size) uintegers = append(uintegers, math.MaxUint32) for _, value := range uintegers { @@ -382,6 +456,22 @@ func roundTripUints(t *testing.T) { assert.Equal(t, exp, readUint64(buf)) zero(buf) } + + buf = make([]byte, bit64Size) + for _, value := range uintegers { + exp := uint64(value) + writeBit64(buf, exp) + assert.Equal(t, exp, readBit64(buf)) + zero(buf) + } + + buf = make([]byte, setSize) + for _, value := range uintegers { + exp := uint64(value) + writeSet(buf, exp) + assert.Equal(t, exp, readSet(buf)) + zero(buf) + } } func roundTripFloats(t *testing.T) { @@ -467,7 +557,9 @@ func roundTripDatetimes(t *testing.T) { func roundTripDecimal(t *testing.T) { decimals := []decimal.Decimal{ + decimalFromString("0"), decimalFromString("1"), + decimalFromString("-1"), decimalFromString("-3.7e0"), decimalFromString("0.00000000000000000003e20"), decimalFromString(".22"), diff --git a/go/store/val/tuple_builder.go b/go/store/val/tuple_builder.go index de83615e1a..5e27a9a171 100644 --- a/go/store/val/tuple_builder.go +++ b/go/store/val/tuple_builder.go @@ -159,6 +159,21 @@ func (tb *TupleBuilder) PutFloat64(i int, v float64) { tb.pos += float64Size } +func (tb *TupleBuilder) PutBit(i int, v uint64) { + tb.Desc.expectEncoding(i, Bit64Enc) + tb.fields[i] = tb.buf[tb.pos : tb.pos+bit64Size] + writeBit64(tb.fields[i], v) + tb.pos += bit64Size +} + +func (tb *TupleBuilder) PutDecimal(i int, v decimal.Decimal) { + tb.Desc.expectEncoding(i, DecimalEnc) + sz := sizeOfDecimal(v) + tb.fields[i] = tb.buf[tb.pos : tb.pos+sz] + writeDecimal(tb.fields[i], v) + tb.pos += sz +} + // PutYear writes an int16-encoded year to the ith field of the Tuple being built. func (tb *TupleBuilder) PutYear(i int, v int16) { tb.Desc.expectEncoding(i, YearEnc) @@ -189,12 +204,18 @@ func (tb *TupleBuilder) PutDatetime(i int, v time.Time) { tb.pos += datetimeSize } -func (tb *TupleBuilder) PutDecimal(i int, v decimal.Decimal) { - tb.Desc.expectEncoding(i, DecimalEnc) - sz := sizeOfDecimal(v) - tb.fields[i] = tb.buf[tb.pos : tb.pos+sz] - writeDecimal(tb.fields[i], v) - tb.pos += sz +func (tb *TupleBuilder) PutEnum(i int, v uint16) { + tb.Desc.expectEncoding(i, EnumEnc) + tb.fields[i] = tb.buf[tb.pos : tb.pos+enumSize] + writeEnum(tb.fields[i], v) + tb.pos += enumSize +} + +func (tb *TupleBuilder) PutSet(i int, v uint64) { + tb.Desc.expectEncoding(i, SetEnc) + tb.fields[i] = tb.buf[tb.pos : tb.pos+setSize] + writeSet(tb.fields[i], v) + tb.pos += setSize } // PutString writes a string to the ith field of the Tuple being built. diff --git a/go/store/val/tuple_compare.go b/go/store/val/tuple_compare.go index 773255e15e..4d8e204077 100644 --- a/go/store/val/tuple_compare.go +++ b/go/store/val/tuple_compare.go @@ -90,6 +90,10 @@ func compare(typ Type, left, right []byte) int { return compareFloat32(readFloat32(left), readFloat32(right)) case Float64Enc: return compareFloat64(readFloat64(left), readFloat64(right)) + case Bit64Enc: + return compareBit64(readBit64(left), readBit64(right)) + case DecimalEnc: + return compareDecimal(readDecimal(left), readDecimal(right)) case YearEnc: return compareYear(readYear(left), readYear(right)) case DateEnc: @@ -98,8 +102,10 @@ func compare(typ Type, left, right []byte) int { return compareTime(readTime(left), readTime(right)) case DatetimeEnc: return compareDatetime(readDatetime(left), readDatetime(right)) - case DecimalEnc: - return compareDecimal(readDecimal(left), readDecimal(right)) + case EnumEnc: + return compareEnum(readEnum(left), readEnum(right)) + case SetEnc: + return compareSet(readSet(left), readSet(right)) case StringEnc: return compareString(readString(left), readString(right)) case ByteStringEnc: diff --git a/go/store/val/tuple_descriptor.go b/go/store/val/tuple_descriptor.go index 0ef1c019ad..7751b5e174 100644 --- a/go/store/val/tuple_descriptor.go +++ b/go/store/val/tuple_descriptor.go @@ -241,6 +241,17 @@ func (td TupleDesc) GetFloat64(i int, tup Tuple) (v float64, ok bool) { return } +// GetBit reads a uint64 from the ith field of the Tuple. +// If the ith field is NULL, |ok| is set to false. +func (td TupleDesc) GetBit(i int, tup Tuple) (v uint64, ok bool) { + td.expectEncoding(i, Bit64Enc) + b := td.GetField(i, tup) + if b != nil { + v, ok = readBit64(b), true + } + return +} + // GetDecimal reads a float64 from the ith field of the Tuple. // If the ith field is NULL, |ok| is set to false. func (td TupleDesc) GetDecimal(i int, tup Tuple) (v decimal.Decimal, ok bool) { @@ -296,6 +307,28 @@ func (td TupleDesc) GetDatetime(i int, tup Tuple) (v time.Time, ok bool) { return } +// GetEnum reads a uin16 from the ith field of the Tuple. +// If the ith field is NULL, |ok| is set to false. +func (td TupleDesc) GetEnum(i int, tup Tuple) (v uint16, ok bool) { + td.expectEncoding(i, EnumEnc) + b := td.GetField(i, tup) + if b != nil { + v, ok = readEnum(b), true + } + return +} + +// GetSet reads a uint64 from the ith field of the Tuple. +// If the ith field is NULL, |ok| is set to false. +func (td TupleDesc) GetSet(i int, tup Tuple) (v uint64, ok bool) { + td.expectEncoding(i, SetEnc) + b := td.GetField(i, tup) + if b != nil { + v, ok = readSet(b), true + } + return +} + // GetString reads a string from the ith field of the Tuple. // If the ith field is NULL, |ok| is set to false. func (td TupleDesc) GetString(i int, tup Tuple) (v string, ok bool) { @@ -423,19 +456,30 @@ func formatValue(enc Encoding, value []byte) string { case Float64Enc: v := readFloat64(value) return fmt.Sprintf("%f", v) + case Bit64Enc: + v := readUint64(value) + return strconv.FormatUint(v, 10) + case DecimalEnc: + v := readDecimal(value) + return v.String() case YearEnc: v := readYear(value) return strconv.Itoa(int(v)) case DateEnc: v := readDate(value) return v.Format("2006-01-02") - //case TimeEnc: - // // todo(andy) - // v := readTime(value) - // return v + case TimeEnc: + v := readTime(value) + return strconv.FormatInt(v, 10) case DatetimeEnc: v := readDatetime(value) return v.Format(time.RFC3339) + case EnumEnc: + v := readEnum(value) + return strconv.Itoa(int(v)) + case SetEnc: + v := readSet(value) + return strconv.FormatUint(v, 10) case StringEnc: return readString(value) case ByteStringEnc: diff --git a/integration-tests/MySQLDockerfile b/integration-tests/MySQLDockerfile index 1e6c47e1e2..a732771c79 100644 --- a/integration-tests/MySQLDockerfile +++ b/integration-tests/MySQLDockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:18.04 +FROM --platform=linux/amd64 ubuntu:18.04 # install python, libmysqlclient-dev, java, bats, git ruby, perl, cpan ENV DEBIAN_FRONTEND=noninteractive @@ -74,8 +74,13 @@ RUN curl -LO https://download.visualstudio.microsoft.com/download/pr/13b9d84c-a3 tar -C /usr/local/bin -xzf dotnet-sdk-5.0.400-linux-x64.tar.gz && \ dotnet --version +# install pip for python3.8 +RUN curl -LO https://bootstrap.pypa.io/get-pip.py && \ + python3.8 get-pip.py && \ + pip --version + # install mysql connector and pymsql -RUN pip3 install mysql-connector-python PyMySQL sqlalchemy +RUN pip install mysql-connector-python PyMySQL sqlalchemy # Setup JAVA_HOME -- useful for docker commandline ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/ diff --git a/integration-tests/ORMDockerfile b/integration-tests/ORMDockerfile index 3134dd5308..e49f4b4ac1 100644 --- a/integration-tests/ORMDockerfile +++ b/integration-tests/ORMDockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:18.04 +FROM --platform=linux/amd64 ubuntu:18.04 # install peewee ENV DEBIAN_FRONTEND=noninteractive diff --git a/integration-tests/bats/constraint-violations.bats b/integration-tests/bats/constraint-violations.bats index 52bbcca17a..14acd83178 100644 --- a/integration-tests/bats/constraint-violations.bats +++ b/integration-tests/bats/constraint-violations.bats @@ -42,7 +42,7 @@ SQL [[ "$output" =~ "test" ]] || false run dolt merge other [ "$status" -eq "1" ] - [[ "$output" =~ "constraint violation" ]] || false + [[ "$output" =~ "Merging is not possible because you have not committed an active merge" ]] || false # we can stage conflicts, but not commit them dolt add test diff --git a/integration-tests/bats/merge.bats b/integration-tests/bats/merge.bats index 5eab253578..00f78140b2 100644 --- a/integration-tests/bats/merge.bats +++ b/integration-tests/bats/merge.bats @@ -655,3 +655,177 @@ SQL [[ "$output" =~ "test1" ]] || false [[ ! "$output" =~ "test2" ]] || false } + +@test "merge: non-violating merge succeeds when violations already exist" { + skip_nbf_dolt_1 + dolt sql < " { send "select user from mysql.user;\r"; } +} + +# look for only root user +expect { + "root" +} + +# quit +expect { + "mysql> " { exit 0 } +} \ No newline at end of file diff --git a/integration-tests/bats/sql-client.bats b/integration-tests/bats/sql-client.bats index a1f37beab2..58c2de2fc3 100644 --- a/integration-tests/bats/sql-client.bats +++ b/integration-tests/bats/sql-client.bats @@ -25,7 +25,7 @@ SQL } show_tables() { - dolt sql-client --host=0.0.0.0 --port=$PORT --user=dolt =<