diff --git a/go/cmd/dolt/commands/sqlserver/server_test.go b/go/cmd/dolt/commands/sqlserver/server_test.go index 0897540733..2e4fafefea 100644 --- a/go/cmd/dolt/commands/sqlserver/server_test.go +++ b/go/cmd/dolt/commands/sqlserver/server_test.go @@ -50,7 +50,7 @@ type testPerson struct { Title string } -type testBranch struct { +type testResult struct { Branch string } @@ -297,6 +297,12 @@ func TestServerFailsIfPortInUse(t *testing.T) { wg.Wait() } +type defaultBranchTest struct { + query *dbr.SelectStmt + expectedRes []testResult + expectedErrStr string +} + func TestServerSetDefaultBranch(t *testing.T) { dEnv, err := sqle.CreateEnvWithSeedData() require.NoError(t, err) @@ -316,114 +322,126 @@ func TestServerSetDefaultBranch(t *testing.T) { const dbName = "dolt" + defaultBranch := env.DefaultInitBranch + conn, err := dbr.Open("mysql", ConnectionString(serverConfig, dbName), nil) require.NoError(t, err) sess := conn.NewSession(nil) - defaultBranch := env.DefaultInitBranch - - tests := []struct { - query *dbr.SelectStmt - expectedRes []testBranch - }{ + tests := []defaultBranchTest{ { - query: sess.Select("active_branch() as branch"), - expectedRes: []testBranch{{defaultBranch}}, - }, - { - query: sess.SelectBySql("set GLOBAL dolt_default_branch = 'refs/heads/new'"), - expectedRes: []testBranch{}, - }, - { - query: sess.Select("active_branch() as branch"), - expectedRes: []testBranch{{defaultBranch}}, + query: sess.SelectBySql("select active_branch() as branch"), + expectedRes: []testResult{{defaultBranch}}, }, { query: sess.SelectBySql("call dolt_checkout('-b', 'new')"), - expectedRes: []testBranch{{""}}, + expectedRes: []testResult{{""}}, + }, + { + query: sess.SelectBySql("call dolt_checkout('-b', 'new2')"), + expectedRes: []testResult{{""}}, + }, + } + + runDefaultBranchTests(t, tests, conn) + + conn, err = dbr.Open("mysql", ConnectionString(serverConfig, dbName), nil) + require.NoError(t, err) + sess = conn.NewSession(nil) + + tests = []defaultBranchTest{ + { + query: sess.SelectBySql("select active_branch() as branch"), + expectedRes: []testResult{{defaultBranch}}, + }, + { + query: sess.SelectBySql("set GLOBAL dolt_default_branch = 'refs/heads/new'"), + expectedRes: nil, + }, + { + query: sess.SelectBySql("select active_branch() as branch"), + expectedRes: []testResult{{"main"}}, }, { query: sess.SelectBySql("call dolt_checkout('main')"), - expectedRes: []testBranch{{""}}, + expectedRes: []testResult{{""}}, }, } - for _, test := range tests { - t.Run(test.query.Query, func(t *testing.T) { - var branch []testBranch - _, err := test.query.LoadContext(context.Background(), &branch) - assert.NoError(t, err) - assert.ElementsMatch(t, branch, test.expectedRes) - }) - } - conn.Close() + runDefaultBranchTests(t, tests, conn) conn, err = dbr.Open("mysql", ConnectionString(serverConfig, dbName), nil) require.NoError(t, err) - defer conn.Close() - sess = conn.NewSession(nil) - tests = []struct { - query *dbr.SelectStmt - expectedRes []testBranch - }{ + tests = []defaultBranchTest{ { - query: sess.Select("active_branch() as branch"), - expectedRes: []testBranch{{"new"}}, + query: sess.SelectBySql("select active_branch() as branch"), + expectedRes: []testResult{{"new"}}, }, { - query: sess.SelectBySql("set GLOBAL dolt_default_branch = 'new'"), - expectedRes: []testBranch{}, + query: sess.SelectBySql("set GLOBAL dolt_default_branch = 'new2'"), + expectedRes: nil, }, } - defer func(sess *dbr.Session) { - var res []struct { - int - } - sess.SelectBySql("set GLOBAL dolt_default_branch = ''").LoadContext(context.Background(), &res) - }(sess) - - for _, test := range tests { - t.Run(test.query.Query, func(t *testing.T) { - var branch []testBranch - _, err := test.query.LoadContext(context.Background(), &branch) - assert.NoError(t, err) - assert.ElementsMatch(t, branch, test.expectedRes) - }) - } - conn.Close() + runDefaultBranchTests(t, tests, conn) conn, err = dbr.Open("mysql", ConnectionString(serverConfig, dbName), nil) require.NoError(t, err) - defer conn.Close() - sess = conn.NewSession(nil) - tests = []struct { - query *dbr.SelectStmt - expectedRes []testBranch - }{ + tests = []defaultBranchTest{ { - query: sess.Select("active_branch() as branch"), - expectedRes: []testBranch{{"new"}}, + query: sess.SelectBySql("select active_branch() as branch"), + expectedRes: []testResult{{"new2"}}, }, } - for _, test := range tests { - t.Run(test.query.Query, func(t *testing.T) { - var branch []testBranch - _, err := test.query.LoadContext(context.Background(), &branch) - assert.NoError(t, err) - assert.ElementsMatch(t, branch, test.expectedRes) - }) + runDefaultBranchTests(t, tests, conn) + + conn, err = dbr.Open("mysql", ConnectionString(serverConfig, dbName), nil) + require.NoError(t, err) + sess = conn.NewSession(nil) + + tests = []defaultBranchTest{ + { + query: sess.SelectBySql("set GLOBAL dolt_default_branch = 'doesNotExist'"), + expectedRes: nil, + }, } - var res []struct { - int + runDefaultBranchTests(t, tests, conn) + + conn, err = dbr.Open("mysql", ConnectionString(serverConfig, dbName), nil) + require.NoError(t, err) + sess = conn.NewSession(nil) + + tests = []defaultBranchTest{ + { + query: sess.SelectBySql("select active_branch() as branch"), + expectedErrStr: "cannot resolve default branch head", // TODO: should be a better error message + }, } - sess.SelectBySql("set GLOBAL dolt_default_branch = ''").LoadContext(context.Background(), &res) + + runDefaultBranchTests(t, tests, conn) +} + +func runDefaultBranchTests(t *testing.T, tests []defaultBranchTest, conn *dbr.Connection) { + for _, test := range tests { + t.Run(test.query.Query, func(t *testing.T) { + var branch []testResult + _, err := test.query.LoadContext(context.Background(), &branch) + if test.expectedErrStr != "" { + require.Error(t, err) + assert.Containsf(t, err.Error(), test.expectedErrStr, "expected error string not found") + } else { + require.NoError(t, err) + assert.Equal(t, test.expectedRes, branch) + } + }) + } + require.NoError(t, conn.Close()) } func TestReadReplica(t *testing.T) { diff --git a/go/go.mod b/go/go.mod index 23b41577ba..ef1f825e64 100644 --- a/go/go.mod +++ b/go/go.mod @@ -59,7 +59,7 @@ require ( github.com/cespare/xxhash v1.1.0 github.com/creasty/defaults v1.6.0 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.15.1-0.20230606174340-20dde39da840 + github.com/dolthub/go-mysql-server v0.15.1-0.20230607160120-febad34dabd4 github.com/dolthub/swiss v0.1.0 github.com/goccy/go-json v0.10.2 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 diff --git a/go/go.sum b/go/go.sum index d45196d791..1988b14a72 100644 --- a/go/go.sum +++ b/go/go.sum @@ -168,8 +168,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw1+y/N5SSCkma7FhAPw7KeGmD6c9PBZW9Y= github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168= -github.com/dolthub/go-mysql-server v0.15.1-0.20230606174340-20dde39da840 h1:LNrH1zxCioDcpTMkTV/cr3LI4AlyxrmILI17OUbijUU= -github.com/dolthub/go-mysql-server v0.15.1-0.20230606174340-20dde39da840/go.mod h1:TP8QrAsULBEK7nP0BHRSfZ9l8oiAeaRzIREHVk2wnz0= +github.com/dolthub/go-mysql-server v0.15.1-0.20230607160120-febad34dabd4 h1:NT1otNLjn/eIWU4p4VV3jGUSiGrVGwc9qAdcSeQoLPg= +github.com/dolthub/go-mysql-server v0.15.1-0.20230607160120-febad34dabd4/go.mod h1:TP8QrAsULBEK7nP0BHRSfZ9l8oiAeaRzIREHVk2wnz0= github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514= github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto= github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 h1:NfWmngMi1CYUWU4Ix8wM+USEhjc+mhPlT9JUR/anvbQ= diff --git a/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go b/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go index 7572f79a89..4245cac830 100644 --- a/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go +++ b/go/libraries/doltcore/dtestutils/sql_server_driver/cmd.go @@ -31,6 +31,7 @@ import ( ) var DoltPath string +var DelvePath string const TestUserName = "Bats Tests" const TestEmailAddress = "bats@email.fake" @@ -45,10 +46,13 @@ func init() { } path = filepath.Clean(path) var err error + DoltPath, err = exec.LookPath(path) if err != nil { log.Printf("did not find dolt binary: %v\n", err.Error()) } + + DelvePath, _ = exec.LookPath("dlv") } // DoltUser is an abstraction for a user account that calls `dolt` CLI @@ -66,8 +70,11 @@ type DoltUser struct { tmpdir string } +var _ DoltCmdable = DoltUser{} +var _ DoltDebuggable = DoltUser{} + func NewDoltUser() (DoltUser, error) { - tmpdir, err := os.MkdirTemp("", "go-sql-server-dirver-") + tmpdir, err := os.MkdirTemp("", "go-sql-server-driver-") if err != nil { return DoltUser{}, err } @@ -91,9 +98,31 @@ func (u DoltUser) DoltCmd(args ...string) *exec.Cmd { cmd := exec.Command(DoltPath, args...) cmd.Dir = u.tmpdir cmd.Env = append(os.Environ(), "DOLT_ROOT_PATH="+u.tmpdir) + ApplyCmdAttributes(cmd) return cmd } +func (u DoltUser) DoltDebug(debuggerPort int, args ...string) *exec.Cmd { + if DelvePath != "" { + dlvArgs := []string{ + fmt.Sprintf("--listen=:%d", debuggerPort), + "--headless", + "--api-version=2", + "--accept-multiclient", + "exec", + DoltPath, + "--", + } + cmd := exec.Command(DelvePath, append(dlvArgs, args...)...) + cmd.Dir = u.tmpdir + cmd.Env = append(os.Environ(), "DOLT_ROOT_PATH="+u.tmpdir) + ApplyCmdAttributes(cmd) + return cmd + } else { + panic("dlv not found") + } +} + func (u DoltUser) DoltExec(args ...string) error { cmd := u.DoltCmd(args...) return cmd.Run() @@ -116,6 +145,9 @@ type RepoStore struct { Dir string } +var _ DoltCmdable = RepoStore{} +var _ DoltDebuggable = RepoStore{} + func (rs RepoStore) MakeRepo(name string) (Repo, error) { path := filepath.Join(rs.Dir, name) err := os.Mkdir(path, 0750) @@ -136,6 +168,12 @@ func (rs RepoStore) DoltCmd(args ...string) *exec.Cmd { return cmd } +func (rs RepoStore) DoltDebug(debuggerPort int, args ...string) *exec.Cmd { + cmd := rs.user.DoltDebug(debuggerPort, args...) + cmd.Dir = rs.Dir + return cmd +} + type Repo struct { user DoltUser Dir string @@ -165,6 +203,7 @@ type SqlServer struct { Done chan struct{} Cmd *exec.Cmd Port int + DebugPort int Output *bytes.Buffer DBName string RecreateCmd func(args ...string) *exec.Cmd @@ -190,12 +229,36 @@ func WithPort(port int) SqlServerOpt { } } +func WithDebugPort(port int) SqlServerOpt { + return func(s *SqlServer) { + s.DebugPort = port + } +} + type DoltCmdable interface { - DoltCmd(...string) *exec.Cmd + DoltCmd(args ...string) *exec.Cmd +} + +type DoltDebuggable interface { + DoltDebug(debuggerPort int, args ...string) *exec.Cmd } func StartSqlServer(dc DoltCmdable, opts ...SqlServerOpt) (*SqlServer, error) { cmd := dc.DoltCmd("sql-server") + return runSqlServerCommand(dc, opts, cmd) +} + +func DebugSqlServer(dc DoltCmdable, debuggerPort int, opts ...SqlServerOpt) (*SqlServer, error) { + ddb, ok := dc.(DoltDebuggable) + if !ok { + return nil, fmt.Errorf("%T does not implement DoltDebuggable", dc) + } + + cmd := ddb.DoltDebug(debuggerPort, "sql-server") + return runSqlServerCommand(dc, append(opts, WithDebugPort(debuggerPort)), cmd) +} + +func runSqlServerCommand(dc DoltCmdable, opts []SqlServerOpt, cmd *exec.Cmd) (*SqlServer, error) { stdout, err := cmd.StdoutPipe() if err != nil { return nil, err @@ -213,27 +276,34 @@ func StartSqlServer(dc DoltCmdable, opts ...SqlServerOpt) (*SqlServer, error) { wg.Wait() close(done) }() - ret := &SqlServer{ + + server := &SqlServer{ Done: done, Cmd: cmd, Port: 3306, Output: output, - RecreateCmd: func(args ...string) *exec.Cmd { - return dc.DoltCmd(args...) - }, } for _, o := range opts { - o(ret) + o(server) } - err = ret.Cmd.Start() + + server.RecreateCmd = func(args ...string) *exec.Cmd { + if server.DebugPort > 0 { + ddb, ok := dc.(DoltDebuggable) + if !ok { + panic(fmt.Sprintf("%T does not implement DoltDebuggable", dc)) + } + return ddb.DoltDebug(server.DebugPort, args...) + } else { + return dc.DoltCmd(args...) + } + } + + err = server.Cmd.Start() if err != nil { return nil, err } - return ret, nil -} - -func (r Repo) StartSqlServer(opts ...SqlServerOpt) (*SqlServer, error) { - return StartSqlServer(r, opts...) + return server, nil } func (s *SqlServer) ErrorStop() error { diff --git a/go/libraries/doltcore/dtestutils/sql_server_driver/cmd_unix.go b/go/libraries/doltcore/dtestutils/sql_server_driver/cmd_unix.go index b8b6ed6973..ddb3bc3423 100644 --- a/go/libraries/doltcore/dtestutils/sql_server_driver/cmd_unix.go +++ b/go/libraries/doltcore/dtestutils/sql_server_driver/cmd_unix.go @@ -17,7 +17,14 @@ package sql_server_driver -import "syscall" +import ( + "os/exec" + "syscall" +) + +func ApplyCmdAttributes(cmd *exec.Cmd) { + // nothing to do on unix / darwin +} func (s *SqlServer) GracefulStop() error { err := s.Cmd.Process.Signal(syscall.SIGTERM) diff --git a/go/libraries/doltcore/dtestutils/sql_server_driver/cmd_windows.go b/go/libraries/doltcore/dtestutils/sql_server_driver/cmd_windows.go old mode 100644 new mode 100755 index b365f0e972..649c895d32 --- a/go/libraries/doltcore/dtestutils/sql_server_driver/cmd_windows.go +++ b/go/libraries/doltcore/dtestutils/sql_server_driver/cmd_windows.go @@ -15,61 +15,32 @@ package sql_server_driver import ( + "os/exec" "syscall" "golang.org/x/sys/windows" ) +func ApplyCmdAttributes(cmd *exec.Cmd) { + // Creating a new process group for the process will allow GracefulStop to send the break signal to that process + // without also killing the parent process + cmd.SysProcAttr = &syscall.SysProcAttr{ + CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP, + } +} + func (s *SqlServer) GracefulStop() error { - dll, err := windows.LoadDLL("kernel32.dll") - if err != nil { - return err - } - defer dll.Release() - - pid := s.Cmd.Process.Pid - - f, err := dll.FindProc("AttachConsole") - if err != nil { - return err - } - r1, _, err := f.Call(uintptr(pid)) - if r1 == 0 && err != syscall.ERROR_ACCESS_DENIED { - return err - } - - set, err := dll.FindProc("SetConsoleCtrlHandler") - if err != nil { - return err - } - r1, _, err = set.Call(0, 1) - if r1 == 0 { - return err - } - f, err = dll.FindProc("GenerateConsoleCtrlEvent") - if err != nil { - return err - } - r1, _, err = f.Call(windows.CTRL_BREAK_EVENT, uintptr(pid)) - if r1 == 0 { - return err - } - - f, err = dll.FindProc("FreeConsole") - if err != nil { - return err - } - _, _, err = f.Call() + err := windows.GenerateConsoleCtrlEvent(windows.CTRL_BREAK_EVENT, uint32(s.Cmd.Process.Pid)) if err != nil { return err } <-s.Done - r1, _, err = set.Call(0, 0) - if r1 == 0 { + _, err = s.Cmd.Process.Wait() + if err != nil { return err } - return s.Cmd.Wait() + return nil } diff --git a/go/libraries/doltcore/dtestutils/sql_server_driver/server.go b/go/libraries/doltcore/dtestutils/sql_server_driver/server.go index 766e117476..00178c2937 100644 --- a/go/libraries/doltcore/dtestutils/sql_server_driver/server.go +++ b/go/libraries/doltcore/dtestutils/sql_server_driver/server.go @@ -162,6 +162,10 @@ type Server struct { // the |Args| to make sure this is true. Defaults to 3308. Port int `yaml:"port"` + // DebugPort if set to a non-zero value will cause this server to be started with |dlv| listening for a debugger + // connection on the port given. + DebugPort int `yaml:"debug_port"` + // Assertions to be run against the log output of the server process // after the server process successfully terminates. LogMatches []string `yaml:"log_matches"` diff --git a/go/libraries/doltcore/merge/merge.go b/go/libraries/doltcore/merge/merge.go index 794db05949..2082072d3e 100644 --- a/go/libraries/doltcore/merge/merge.go +++ b/go/libraries/doltcore/merge/merge.go @@ -158,9 +158,6 @@ func MergeRoots( mergedRoot := ourRoot - optsWithFKChecks := opts - optsWithFKChecks.ForeignKeyChecksDisabled = true - // Merge tables one at a time. This is done based on name. With table names from ourRoot being merged first, // renaming a table will return delete/modify conflict error consistently. // TODO: merge based on a more durable table identity that persists across renames diff --git a/go/libraries/doltcore/rebase/filter_branch_test.go b/go/libraries/doltcore/rebase/filter_branch_test.go index 5f15eabd44..3b1088910e 100644 --- a/go/libraries/doltcore/rebase/filter_branch_test.go +++ b/go/libraries/doltcore/rebase/filter_branch_test.go @@ -231,12 +231,13 @@ func testFilterBranch(t *testing.T, test filterBranchTest) { require.Equal(t, 0, exitCode) } - root, err := dEnv.WorkingRoot(ctx) - require.NoError(t, err) + t.Run(a.query, func(t *testing.T) { + root, err := dEnv.WorkingRoot(ctx) + require.NoError(t, err) - actRows, err := sqle.ExecuteSelect(dEnv, root, a.query) - require.NoError(t, err) - - require.Equal(t, a.rows, actRows) + actRows, err := sqle.ExecuteSelect(dEnv, root, a.query) + require.NoError(t, err) + require.Equal(t, a.rows, actRows) + }) } } diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_controller.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_controller.go index 8091650a69..f884d61453 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_controller.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_controller.go @@ -443,7 +443,7 @@ func getOptionValueAsTableNames(option binlogreplication.ReplicationOption) ([]s func verifyAllTablesAreQualified(urts []sql.UnresolvedTable) error { for _, urt := range urts { - if urt.Database() == "" { + if urt.Database().Name() == "" { return fmt.Errorf("no database specified for table '%s'; "+ "all filter table names must be qualified with a database name", urt.Name()) } diff --git a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_filtering.go b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_filtering.go index 7ee5277b5c..5f115bc614 100644 --- a/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_filtering.go +++ b/go/libraries/doltcore/sqle/binlogreplication/binlog_replica_filtering.go @@ -59,7 +59,7 @@ func (fc *filterConfiguration) setDoTables(urts []sql.UnresolvedTable) error { for _, urt := range urts { table := strings.ToLower(urt.Name()) - db := strings.ToLower(urt.Database()) + db := strings.ToLower(urt.Database().Name()) if fc.doTables[db] == nil { fc.doTables[db] = make(map[string]struct{}) } @@ -86,7 +86,7 @@ func (fc *filterConfiguration) setIgnoreTables(urts []sql.UnresolvedTable) error for _, urt := range urts { table := strings.ToLower(urt.Name()) - db := strings.ToLower(urt.Database()) + db := strings.ToLower(urt.Database().Name()) if fc.ignoreTables[db] == nil { fc.ignoreTables[db] = make(map[string]struct{}) } diff --git a/go/libraries/doltcore/sqle/clusterdb/database.go b/go/libraries/doltcore/sqle/clusterdb/database.go index 501199dc89..4005cb6921 100644 --- a/go/libraries/doltcore/sqle/clusterdb/database.go +++ b/go/libraries/doltcore/sqle/clusterdb/database.go @@ -100,7 +100,7 @@ func (database) IsReadOnly() bool { return true } -func (db database) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) { +func (db database) InitialDBState(ctx *sql.Context) (dsess.InitialDbState, error) { // TODO: almost none of this state is actually used, but is necessary because the current session setup requires a // repo state writer return dsess.InitialDbState{ @@ -132,23 +132,25 @@ func (db database) Revision() string { return "" } +func (db database) Versioned() bool { + return false +} + func (db database) RevisionType() dsess.RevisionType { return dsess.RevisionTypeNone } -func (db database) BaseName() string { +func (db database) RevisionQualifiedName() string { + return db.Name() +} + +func (db database) RequestedName() string { return db.Name() } type noopRepoStateWriter struct{} -func (n noopRepoStateWriter) UpdateStagedRoot(ctx context.Context, newRoot *doltdb.RootValue) error { - return nil -} - -func (n noopRepoStateWriter) UpdateWorkingRoot(ctx context.Context, newRoot *doltdb.RootValue) error { - return nil -} +var _ env.RepoStateWriter = noopRepoStateWriter{} func (n noopRepoStateWriter) SetCWBHeadRef(ctx context.Context, marshalableRef ref.MarshalableRef) error { return nil diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index 53cfccfc68..3ef7e44392 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -48,14 +48,15 @@ var ErrSystemTableAlter = errors.NewKind("Cannot alter table %s: system tables c // Database implements sql.Database for a dolt DB. type Database struct { - name string - ddb *doltdb.DoltDB - rsr env.RepoStateReader - rsw env.RepoStateWriter - gs globalstate.GlobalState - editOpts editor.Options - revision string - revType dsess.RevisionType + baseName string + requestedName string + ddb *doltdb.DoltDB + rsr env.RepoStateReader + rsw env.RepoStateWriter + gs globalstate.GlobalState + editOpts editor.Options + revision string + revType dsess.RevisionType } var _ dsess.SqlDatabase = Database{} @@ -74,6 +75,7 @@ var _ sql.TriggerDatabase = Database{} var _ sql.VersionedDatabase = Database{} var _ sql.ViewDatabase = Database{} var _ sql.EventDatabase = Database{} +var _ sql.AliasedDatabase = Database{} type ReadOnlyDatabase struct { Database @@ -86,8 +88,8 @@ func (r ReadOnlyDatabase) IsReadOnly() bool { return true } -func (r ReadOnlyDatabase) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) { - return initialDBState(ctx, r, branch) +func (r ReadOnlyDatabase) InitialDBState(ctx *sql.Context) (dsess.InitialDbState, error) { + return initialDBState(ctx, r, r.revision) } // Revision implements dsess.RevisionDatabase @@ -95,13 +97,12 @@ func (db Database) Revision() string { return db.revision } -func (db Database) RevisionType() dsess.RevisionType { - return db.revType +func (db Database) Versioned() bool { + return true } -func (db Database) BaseName() string { - base, _ := dsess.SplitRevisionDbName(db) - return base +func (db Database) RevisionType() dsess.RevisionType { + return db.revType } func (db Database) EditOptions() editor.Options { @@ -116,12 +117,13 @@ func NewDatabase(ctx context.Context, name string, dbData env.DbData, editOpts e } return Database{ - name: name, - ddb: dbData.Ddb, - rsr: dbData.Rsr, - rsw: dbData.Rsw, - gs: globalState, - editOpts: editOpts, + baseName: name, + requestedName: name, + ddb: dbData.Ddb, + rsr: dbData.Rsr, + rsw: dbData.Rsw, + gs: globalState, + editOpts: editOpts, }, nil } @@ -135,13 +137,32 @@ func initialDBState(ctx *sql.Context, db dsess.SqlDatabase, branch string) (dses return initialDbState(ctx, db, branch) } -func (db Database) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) { - return initialDBState(ctx, db, branch) +func (db Database) InitialDBState(ctx *sql.Context) (dsess.InitialDbState, error) { + return initialDBState(ctx, db, db.revision) } // Name returns the name of this database, set at creation time. func (db Database) Name() string { - return db.name + return db.RequestedName() +} + +// AliasedName is what allows databases named e.g. `mydb/b1` to work with the grant and info schema tables that expect +// a base (no revision qualifier) db name +func (db Database) AliasedName() string { + return db.baseName +} + +// RevisionQualifiedName returns the name of this database including its revision qualifier, if any. This method should +// be used whenever accessing internal state of a database and its tables. +func (db Database) RevisionQualifiedName() string { + if db.revision == "" { + return db.baseName + } + return db.baseName + dsess.DbRevisionDelimiter + db.revision +} + +func (db Database) RequestedName() string { + return db.requestedName } // GetDoltDB gets the underlying DoltDB of the Database @@ -177,7 +198,7 @@ func (db Database) GetTableInsensitive(ctx *sql.Context, tblName string) (sql.Ta // We start by first checking whether the input table is a temporary table. Temporary tables with name `x` take // priority over persisted tables of name `x`. ds := dsess.DSessFromSess(ctx.Session) - if tbl, ok := ds.GetTemporaryTable(ctx, db.Name(), tblName); ok { + if tbl, ok := ds.GetTemporaryTable(ctx, db.RevisionQualifiedName(), tblName); ok { return tbl, ok, nil } @@ -256,15 +277,15 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds case strings.HasPrefix(lwrName, doltdb.DoltDiffTablePrefix): if head == nil { var err error - head, err = ds.GetHeadCommit(ctx, db.Name()) + head, err = ds.GetHeadCommit(ctx, db.RevisionQualifiedName()) if err != nil { return nil, false, err } } - suffix := tblName[len(doltdb.DoltDiffTablePrefix):] - dt, err := dtables.NewDiffTable(ctx, suffix, db.ddb, root, head) + tableName := tblName[len(doltdb.DoltDiffTablePrefix):] + dt, err := dtables.NewDiffTable(ctx, tableName, db.ddb, root, head) if err != nil { return nil, false, err } @@ -290,7 +311,7 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds if head == nil { var err error - head, err = ds.GetHeadCommit(ctx, db.Name()) + head, err = ds.GetHeadCommit(ctx, db.RevisionQualifiedName()) if err != nil { return nil, false, err } @@ -327,7 +348,7 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds case doltdb.LogTableName: if head == nil { var err error - head, err = ds.GetHeadCommit(ctx, db.Name()) + head, err = ds.GetHeadCommit(ctx, db.RevisionQualifiedName()) if err != nil { return nil, false, err } @@ -337,29 +358,29 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds case doltdb.DiffTableName: if head == nil { var err error - head, err = ds.GetHeadCommit(ctx, db.Name()) + head, err = ds.GetHeadCommit(ctx, db.RevisionQualifiedName()) if err != nil { return nil, false, err } } - dt, found = dtables.NewUnscopedDiffTable(ctx, db.name, db.ddb, head), true + dt, found = dtables.NewUnscopedDiffTable(ctx, db.RevisionQualifiedName(), db.ddb, head), true case doltdb.ColumnDiffTableName: if head == nil { var err error - head, err = ds.GetHeadCommit(ctx, db.Name()) + head, err = ds.GetHeadCommit(ctx, db.RevisionQualifiedName()) if err != nil { return nil, false, err } } - dt, found = dtables.NewColumnDiffTable(ctx, db.name, db.ddb, head), true + dt, found = dtables.NewColumnDiffTable(ctx, db.RevisionQualifiedName(), db.ddb, head), true case doltdb.TableOfTablesInConflictName: - dt, found = dtables.NewTableOfTablesInConflict(ctx, db.name, db.ddb), true + dt, found = dtables.NewTableOfTablesInConflict(ctx, db.RevisionQualifiedName(), db.ddb), true case doltdb.TableOfTablesWithViolationsName: dt, found = dtables.NewTableOfTablesConstraintViolations(ctx, root), true case doltdb.SchemaConflictsTableName: - dt, found = dtables.NewSchemaConflictsTable(ctx, db.name, db.ddb), true + dt, found = dtables.NewSchemaConflictsTable(ctx, db.RevisionQualifiedName(), db.ddb), true case doltdb.BranchesTableName: dt, found = dtables.NewBranchesTable(ctx, db), true case doltdb.RemoteBranchesTableName: @@ -373,17 +394,17 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds case doltdb.StatusTableName: sess := dsess.DSessFromSess(ctx.Session) adapter := dsess.NewSessionStateAdapter( - sess, db.name, + sess, db.RevisionQualifiedName(), map[string]env.Remote{}, map[string]env.BranchConfig{}, map[string]env.Remote{}) - ws, err := sess.WorkingSet(ctx, db.name) + ws, err := sess.WorkingSet(ctx, db.RevisionQualifiedName()) if err != nil { return nil, false, err } - dt, found = dtables.NewStatusTable(ctx, db.name, db.ddb, ws, adapter), true + dt, found = dtables.NewStatusTable(ctx, db.ddb, ws, adapter), true case doltdb.MergeStatusTableName: - dt, found = dtables.NewMergeStatusTable(db.name), true + dt, found = dtables.NewMergeStatusTable(db.RevisionQualifiedName()), true case doltdb.TagsTableName: dt, found = dtables.NewTagsTable(ctx, db.ddb), true case dtables.AccessTableName: @@ -412,6 +433,7 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds return dt, found, nil } + // TODO: this should reuse the root, not lookup the db state again return db.getTable(ctx, root, tblName) } @@ -533,15 +555,15 @@ func (db Database) GetTableNamesAsOf(ctx *sql.Context, time interface{}) ([]stri return filterDoltInternalTables(tblNames), nil } -// getTable returns the user table with the given name from the root given +// getTable returns the user table with the given baseName from the root given func (db Database) getTable(ctx *sql.Context, root *doltdb.RootValue, tableName string) (sql.Table, bool, error) { sess := dsess.DSessFromSess(ctx.Session) - dbState, ok, err := sess.LookupDbState(ctx, db.name) + dbState, ok, err := sess.LookupDbState(ctx, db.RevisionQualifiedName()) if err != nil { return nil, false, err } if !ok { - return nil, false, fmt.Errorf("no state for database %s", db.name) + return nil, false, fmt.Errorf("no state for database %s", db.RevisionQualifiedName()) } key, err := doltdb.NewDataCacheKey(root) @@ -637,7 +659,7 @@ func filterDoltInternalTables(tblNames []string) []string { // GetRoot returns the root value for this database session func (db Database) GetRoot(ctx *sql.Context) (*doltdb.RootValue, error) { sess := dsess.DSessFromSess(ctx.Session) - dbState, ok, err := sess.LookupDbState(ctx, db.Name()) + dbState, ok, err := sess.LookupDbState(ctx, db.RevisionQualifiedName()) if err != nil { return nil, err } @@ -645,7 +667,7 @@ func (db Database) GetRoot(ctx *sql.Context) (*doltdb.RootValue, error) { return nil, fmt.Errorf("no root value found in session") } - return dbState.GetRoots().Working, nil + return dbState.WorkingRoot(), nil } // GetWorkingSet gets the current working set for the database. @@ -656,30 +678,30 @@ func (db Database) GetRoot(ctx *sql.Context) (*doltdb.RootValue, error) { // where users avoid handling the WorkingSet directly. func (db Database) GetWorkingSet(ctx *sql.Context) (*doltdb.WorkingSet, error) { sess := dsess.DSessFromSess(ctx.Session) - dbState, ok, err := sess.LookupDbState(ctx, db.Name()) + dbState, ok, err := sess.LookupDbState(ctx, db.RevisionQualifiedName()) if err != nil { return nil, err } if !ok { return nil, fmt.Errorf("no root value found in session") } - if dbState.WorkingSet == nil { + if dbState.WorkingSet() == nil { return nil, doltdb.ErrOperationNotSupportedInDetachedHead } - return dbState.WorkingSet, nil + return dbState.WorkingSet(), nil } // SetRoot should typically be called on the Session, which is where this state lives. But it's available here as a // convenience. func (db Database) SetRoot(ctx *sql.Context, newRoot *doltdb.RootValue) error { sess := dsess.DSessFromSess(ctx.Session) - return sess.SetRoot(ctx, db.name, newRoot) + return sess.SetRoot(ctx, db.RevisionQualifiedName(), newRoot) } // GetHeadRoot returns root value for the current session head func (db Database) GetHeadRoot(ctx *sql.Context) (*doltdb.RootValue, error) { sess := dsess.DSessFromSess(ctx.Session) - head, err := sess.GetHeadCommit(ctx, db.name) + head, err := sess.GetHeadCommit(ctx, db.RevisionQualifiedName()) if err != nil { return nil, err } @@ -689,7 +711,7 @@ func (db Database) GetHeadRoot(ctx *sql.Context) (*doltdb.RootValue, error) { // DropTable drops the table with the name given. // The planner returns the correct case sensitive name in tableName func (db Database) DropTable(ctx *sql.Context, tableName string) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, db, branch_control.Permissions_Write); err != nil { return err } if doltdb.IsNonAlterableSystemTable(tableName) { @@ -699,7 +721,7 @@ func (db Database) DropTable(ctx *sql.Context, tableName string) error { return db.dropTable(ctx, tableName) } -// dropTable drops the table with the name given, without any business logic checks +// dropTable drops the table with the baseName given, without any business logic checks func (db Database) dropTable(ctx *sql.Context, tableName string) error { ds := dsess.DSessFromSess(ctx.Session) if _, ok := ds.GetTemporaryTable(ctx, db.Name(), tableName); ok { @@ -733,7 +755,7 @@ func (db Database) dropTable(ctx *sql.Context, tableName string) error { } if schema.HasAutoIncrement(sch) { - ddb, _ := ds.GetDoltDB(ctx, db.name) + ddb, _ := ds.GetDoltDB(ctx, db.RevisionQualifiedName()) err = db.removeTableFromAutoIncrementTracker(ctx, tableName, ddb, ws.Ref()) if err != nil { return err @@ -796,7 +818,7 @@ func (db Database) removeTableFromAutoIncrementTracker( // CreateTable creates a table with the name and schema given. func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.PrimaryKeySchema, collation sql.CollationID) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, db, branch_control.Permissions_Write); err != nil { return err } if strings.ToLower(tableName) == doltdb.DocTableName { @@ -817,7 +839,7 @@ func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.Prima // CreateIndexedTable creates a table with the name and schema given. func (db Database) CreateIndexedTable(ctx *sql.Context, tableName string, sch sql.PrimaryKeySchema, idxDef sql.IndexDef, collation sql.CollationID) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, db, branch_control.Permissions_Write); err != nil { return err } if strings.ToLower(tableName) == doltdb.DocTableName { @@ -836,7 +858,7 @@ func (db Database) CreateIndexedTable(ctx *sql.Context, tableName string, sch sq return db.createIndexedSqlTable(ctx, tableName, sch, idxDef, collation) } -// Unlike the exported version CreateTable, createSqlTable doesn't enforce any table name checks. +// createSqlTable is the private version of CreateTable. It doesn't enforce any table name checks. func (db Database) createSqlTable(ctx *sql.Context, tableName string, sch sql.PrimaryKeySchema, collation sql.CollationID) error { ws, err := db.GetWorkingSet(ctx) if err != nil { @@ -878,7 +900,7 @@ func (db Database) createSqlTable(ctx *sql.Context, tableName string, sch sql.Pr return db.createDoltTable(ctx, tableName, root, doltSch) } -// Unlike the exported version CreateTable, createSqlTable doesn't enforce any table name checks. +// createIndexedSqlTable is the private version of createSqlTable. It doesn't enforce any table name checks. func (db Database) createIndexedSqlTable(ctx *sql.Context, tableName string, sch sql.PrimaryKeySchema, idxDef sql.IndexDef, collation sql.CollationID) error { ws, err := db.GetWorkingSet(ctx) if err != nil { @@ -926,7 +948,7 @@ func (db Database) createIndexedSqlTable(ctx *sql.Context, tableName string, sch return db.createDoltTable(ctx, tableName, root, doltSch) } -// createDoltTable creates a table on the database using the given dolt schema while not enforcing table name checks. +// createDoltTable creates a table on the database using the given dolt schema while not enforcing table baseName checks. func (db Database) createDoltTable(ctx *sql.Context, tableName string, root *doltdb.RootValue, doltSch schema.Schema) error { if exists, err := root.HasTable(ctx, tableName); err != nil { return err @@ -969,19 +991,19 @@ func (db Database) CreateTemporaryTable(ctx *sql.Context, tableName string, pkSc return ErrInvalidTableName.New(tableName) } - tmp, err := NewTempTable(ctx, db.ddb, pkSch, tableName, db.name, db.editOpts, collation) + tmp, err := NewTempTable(ctx, db.ddb, pkSch, tableName, db.RevisionQualifiedName(), db.editOpts, collation) if err != nil { return err } ds := dsess.DSessFromSess(ctx.Session) - ds.AddTemporaryTable(ctx, db.Name(), tmp) + ds.AddTemporaryTable(ctx, db.RevisionQualifiedName(), tmp) return nil } // RenameTable implements sql.TableRenamer func (db Database) RenameTable(ctx *sql.Context, oldName, newName string) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, db, branch_control.Permissions_Write); err != nil { return err } root, err := db.GetRoot(ctx) @@ -1018,11 +1040,11 @@ func (db Database) RenameTable(ctx *sql.Context, oldName, newName string) error // Flush flushes the current batch of outstanding changes and returns any errors. func (db Database) Flush(ctx *sql.Context) error { sess := dsess.DSessFromSess(ctx.Session) - dbState, _, err := sess.LookupDbState(ctx, db.Name()) + dbState, _, err := sess.LookupDbState(ctx, db.RevisionQualifiedName()) if err != nil { return err } - editSession := dbState.WriteSession + editSession := dbState.WriteSession() ws, err := editSession.Flush(ctx) if err != nil { @@ -1056,7 +1078,7 @@ func (db Database) GetViewDefinition(ctx *sql.Context, viewName string) (sql.Vie } ds := dsess.DSessFromSess(ctx.Session) - dbState, _, err := ds.LookupDbState(ctx, db.name) + dbState, _, err := ds.LookupDbState(ctx, db.RevisionQualifiedName()) if err != nil { return sql.ViewDefinition{}, false, err } @@ -1138,7 +1160,7 @@ func (db Database) AllViews(ctx *sql.Context) ([]sql.ViewDefinition, error) { // it can exist in a sql session later. Returns sql.ErrExistingView if a view // with that name already exists. func (db Database) CreateView(ctx *sql.Context, name string, selectStatement, createViewStmt string) error { - err := sql.ErrExistingView.New(db.name, name) + err := sql.ErrExistingView.New(db.Name(), name) return db.addFragToSchemasTable(ctx, "view", name, createViewStmt, time.Unix(0, 0).UTC(), err) } @@ -1146,7 +1168,7 @@ func (db Database) CreateView(ctx *sql.Context, name string, selectStatement, cr // dolt database. Returns sql.ErrNonExistingView if the view did not // exist. func (db Database) DropView(ctx *sql.Context, name string) error { - err := sql.ErrViewDoesNotExist.New(db.name, name) + err := sql.ErrViewDoesNotExist.New(db.baseName, name) return db.dropFragFromSchemasTable(ctx, "view", name, err) } @@ -1295,7 +1317,7 @@ func (db Database) GetStoredProcedures(ctx *sql.Context) ([]sql.StoredProcedureD // SaveStoredProcedure implements sql.StoredProcedureDatabase. func (db Database) SaveStoredProcedure(ctx *sql.Context, spd sql.StoredProcedureDetails) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, db, branch_control.Permissions_Write); err != nil { return err } return DoltProceduresAddProcedure(ctx, db, spd) @@ -1303,14 +1325,14 @@ func (db Database) SaveStoredProcedure(ctx *sql.Context, spd sql.StoredProcedure // DropStoredProcedure implements sql.StoredProcedureDatabase. func (db Database) DropStoredProcedure(ctx *sql.Context, name string) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, db, branch_control.Permissions_Write); err != nil { return err } return DoltProceduresDropProcedure(ctx, db, name) } func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, definition string, created time.Time, existingErr error) (err error) { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, db, branch_control.Permissions_Write); err != nil { return err } tbl, err := getOrCreateDoltSchemasTable(ctx, db) @@ -1347,7 +1369,7 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin } func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name string, missingErr error) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, db, branch_control.Permissions_Write); err != nil { return err } @@ -1436,7 +1458,7 @@ func (db Database) GetCollation(ctx *sql.Context) sql.CollationID { // SetCollation implements the interface sql.CollatedDatabase. func (db Database) SetCollation(ctx *sql.Context, collation sql.CollationID) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, db, branch_control.Permissions_Write); err != nil { return err } if collation == sql.Collation_Unspecified { diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index 2b3a50d16e..30bd0dc30c 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -171,7 +171,9 @@ func (p DoltDatabaseProvider) FileSystemForDatabase(dbname string) (filesys.File p.mu.Lock() defer p.mu.Unlock() - dbLocation, ok := p.dbLocations[dbname] + baseName, _ := dsess.SplitRevisionDbName(dbname) + + dbLocation, ok := p.dbLocations[strings.ToLower(baseName)] if !ok { return nil, sql.ErrDatabaseNotFound.New(dbname) } @@ -255,17 +257,27 @@ func (p DoltDatabaseProvider) HasDatabase(ctx *sql.Context, name string) bool { } func (p DoltDatabaseProvider) AllDatabases(ctx *sql.Context) (all []sql.Database) { - p.mu.RLock() + currentDb := ctx.GetCurrentDatabase() + currBase, currRev := dsess.SplitRevisionDbName(currentDb) + p.mu.RLock() showBranches, _ := dsess.GetBooleanSystemVar(ctx, dsess.ShowBranchDatabases) all = make([]sql.Database, 0, len(p.databases)) - var foundDatabase bool - currDb := strings.ToLower(ctx.GetCurrentDatabase()) for _, db := range p.databases { - if strings.ToLower(db.Name()) == currDb { - foundDatabase = true + base, _ := dsess.SplitRevisionDbName(db.Name()) + + // If there's a revision database in use, swap that one in for its base db, but keep the same name + if currRev != "" && strings.ToLower(currBase) == strings.ToLower(base) { + rdb, ok, err := p.databaseForRevision(ctx, currentDb, currBase) + if err != nil || !ok { + // TODO: this interface is wrong, needs to return errors + ctx.GetLogger().Warnf("error fetching revision databases: %s", err.Error()) + } else { + db = rdb + } } + all = append(all, db) if showBranches { @@ -276,33 +288,10 @@ func (p DoltDatabaseProvider) AllDatabases(ctx *sql.Context) (all []sql.Database continue } all = append(all, revisionDbs...) - - // if one of the revisions we just expanded matches the curr db, mark it so we don't double-include that - // revision db - if !foundDatabase && currDb != "" { - for _, revisionDb := range revisionDbs { - if strings.ToLower(revisionDb.Name()) == currDb { - foundDatabase = true - } - } - } } } p.mu.RUnlock() - // If the current database is not one of the primary databases, it must be a transitory revision database - if !foundDatabase && currDb != "" { - revDb, ok, err := p.databaseForRevision(ctx, currDb) - if err != nil { - // We can't return an error from this interface function, so just log a message - ctx.GetLogger().Warnf("unable to load %q as a database revision: %s", ctx.GetCurrentDatabase(), err.Error()) - } else if !ok { - ctx.GetLogger().Warnf("unable to load %q as a database revision", ctx.GetCurrentDatabase()) - } else { - all = append(all, revDb) - } - } - // Because we store databases in a map, sort to get a consistent ordering sort.Slice(all, func(i, j int) bool { return strings.ToLower(all[i].Name()) < strings.ToLower(all[j].Name()) @@ -339,7 +328,7 @@ func (p DoltDatabaseProvider) allRevisionDbs(ctx *sql.Context, db dsess.SqlDatab revDbs := make([]sql.Database, len(branches)) for i, branch := range branches { - revDb, ok, err := p.databaseForRevision(ctx, fmt.Sprintf("%s/%s", db.Name(), branch.GetPath())) + revDb, ok, err := p.databaseForRevision(ctx, fmt.Sprintf("%s/%s", db.Name(), branch.GetPath()), db.Name()) if err != nil { return nil, err } @@ -388,16 +377,7 @@ func (p DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name stri sess := dsess.DSessFromSess(ctx.Session) newEnv := env.Load(ctx, env.GetCurrentUserHomeDir, newFs, p.dbFactoryUrl, "TODO") - // if currentDB is empty, it will create the database with the default format which is the old format newDbStorageFormat := types.Format_Default - if curDB := sess.GetCurrentDatabase(); curDB != "" { - if sess.HasDB(ctx, curDB) { - if ddb, ok := sess.GetDoltDB(ctx, curDB); ok { - newDbStorageFormat = ddb.ValueReadWriter().Format() - } - } - } - err = newEnv.InitRepo(ctx, newDbStorageFormat, sess.Username(), sess.Email(), p.defaultBranch) if err != nil { return err @@ -616,11 +596,8 @@ func (p DoltDatabaseProvider) cloneDatabaseFromRemote( // DropDatabase implements the sql.MutableDatabaseProvider interface func (p DoltDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error { - isRevisionDatabase, err := p.isRevisionDatabase(ctx, name) - if err != nil { - return err - } - if isRevisionDatabase { + _, revision := dsess.SplitRevisionDbName(name) + if revision != "" { return fmt.Errorf("unable to drop revision database: %s", name) } @@ -635,7 +612,7 @@ func (p DoltDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error db := p.databases[dbKey] ddb := db.(Database).ddb - err = ddb.Close() + err := ddb.Close() if err != nil { return err } @@ -736,27 +713,30 @@ func (p DoltDatabaseProvider) invalidateDbStateInAllSessions(ctx *sql.Context, n return nil } -func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string) (dsess.SqlDatabase, bool, error) { - if !strings.Contains(revDB, dsess.DbRevisionDelimiter) { +func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revisionQualifiedName string, requestedName string) (dsess.SqlDatabase, bool, error) { + if !strings.Contains(revisionQualifiedName, dsess.DbRevisionDelimiter) { return nil, false, nil } - parts := strings.SplitN(revDB, dsess.DbRevisionDelimiter, 2) - dbName, revSpec := parts[0], parts[1] + parts := strings.SplitN(revisionQualifiedName, dsess.DbRevisionDelimiter, 2) + baseName, rev := parts[0], parts[1] + + // Look in the session cache for this DB before doing any IO to figure out what's being asked for + sess := dsess.DSessFromSess(ctx.Session) + dbCache := sess.DatabaseCache(ctx) + db, ok := dbCache.GetCachedRevisionDb(revisionQualifiedName, requestedName) + if ok { + return db, true, nil + } p.mu.RLock() - candidate, ok := p.databases[formatDbMapKeyName(dbName)] + srcDb, ok := p.databases[formatDbMapKeyName(baseName)] p.mu.RUnlock() if !ok { return nil, false, nil } - srcDb, ok := candidate.(dsess.SqlDatabase) - if !ok { - return nil, false, nil - } - - dbType, resolvedRevSpec, err := revisionDbType(ctx, srcDb, revSpec) + dbType, resolvedRevSpec, err := revisionDbType(ctx, srcDb, rev) if err != nil { return nil, false, err } @@ -764,7 +744,8 @@ func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string switch dbType { case dsess.RevisionTypeBranch: // fetch the upstream head if this is a replicated db - if replicaDb, ok := srcDb.(ReadReplicaDatabase); ok { + replicaDb, ok := srcDb.(ReadReplicaDatabase) + if ok && replicaDb.ValidReplicaState(ctx) { // TODO move this out of analysis phase, should only happen at read time, when the transaction begins (like is // the case with a branch that already exists locally) err := p.ensureReplicaHeadExists(ctx, resolvedRevSpec, replicaDb) @@ -773,14 +754,15 @@ func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string } } - db, err := revisionDbForBranch(ctx, srcDb, resolvedRevSpec) + db, err := revisionDbForBranch(ctx, srcDb, resolvedRevSpec, requestedName) // preserve original user case in the case of not found if sql.ErrDatabaseNotFound.Is(err) { - return nil, false, sql.ErrDatabaseNotFound.New(revDB) + return nil, false, sql.ErrDatabaseNotFound.New(revisionQualifiedName) } else if err != nil { return nil, false, err } + dbCache.CacheRevisionDb(db) return db, true, nil case dsess.RevisionTypeTag: // TODO: this should be an interface, not a struct @@ -795,10 +777,12 @@ func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string return nil, false, nil } - db, err := revisionDbForTag(ctx, srcDb.(Database), resolvedRevSpec) + db, err := revisionDbForTag(ctx, srcDb.(Database), resolvedRevSpec, requestedName) if err != nil { return nil, false, err } + + dbCache.CacheRevisionDb(db) return db, true, nil case dsess.RevisionTypeCommit: // TODO: this should be an interface, not a struct @@ -811,16 +795,19 @@ func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string if !ok { return nil, false, nil } - db, err := revisionDbForCommit(ctx, srcDb.(Database), revSpec) + db, err := revisionDbForCommit(ctx, srcDb.(Database), rev, requestedName) if err != nil { return nil, false, err } + + dbCache.CacheRevisionDb(db) return db, true, nil case dsess.RevisionTypeNone: - // not an error, ok = false will get handled as a not found error in a layer above as appropriate - return nil, false, nil + // Returning an error with the fully qualified db name here is our only opportunity to do so in some cases (such + // as when a branch is deleted by another client) + return nil, false, sql.ErrDatabaseNotFound.New(revisionQualifiedName) default: - return nil, false, fmt.Errorf("unrecognized revision type for revision spec %s", revSpec) + return nil, false, fmt.Errorf("unrecognized revision type for revision spec %s", rev) } } @@ -1051,26 +1038,47 @@ func resolveAncestorSpec(ctx *sql.Context, revSpec string, ddb *doltdb.DoltDB) ( return hash.String(), nil } -// SessionDatabase implements dsess.SessionDatabaseProvider -func (p DoltDatabaseProvider) SessionDatabase(ctx *sql.Context, name string) (dsess.SqlDatabase, bool, error) { +// BaseDatabase returns the base database for the specified database name. Meant for informational purposes when +// managing the session initialization only. Use SessionDatabase for normal database retrieval. +func (p DoltDatabaseProvider) BaseDatabase(ctx *sql.Context, name string) (dsess.SqlDatabase, bool) { + baseName := name + isRevisionDbName := strings.Contains(name, dsess.DbRevisionDelimiter) + + if isRevisionDbName { + parts := strings.SplitN(name, dsess.DbRevisionDelimiter, 2) + baseName = parts[0] + } + var ok bool p.mu.RLock() - db, ok := p.databases[formatDbMapKeyName(name)] + db, ok := p.databases[strings.ToLower(baseName)] + p.mu.RUnlock() + + return db, ok +} + +// SessionDatabase implements dsess.SessionDatabaseProvider +func (p DoltDatabaseProvider) SessionDatabase(ctx *sql.Context, name string) (dsess.SqlDatabase, bool, error) { + baseName := name + isRevisionDbName := strings.Contains(name, dsess.DbRevisionDelimiter) + + if isRevisionDbName { + // TODO: formalize and enforce this rule (can't allow DBs with / in the name) + // TODO: some connectors will take issue with the /, we need other mechanisms to support them + parts := strings.SplitN(name, dsess.DbRevisionDelimiter, 2) + baseName = parts[0] + } + + var ok bool + p.mu.RLock() + db, ok := p.databases[strings.ToLower(baseName)] standby := *p.isStandby p.mu.RUnlock() - if ok { - return wrapForStandby(db, standby), true, nil - } - // Revision databases aren't tracked in the map, just instantiated on demand - db, ok, err := p.databaseForRevision(ctx, name) - if err != nil { - return nil, false, err - } - - // A final check: if the database doesn't exist and this is a read replica, attempt to clone it from the remote + // If the database doesn't exist and this is a read replica, attempt to clone it from the remote if !ok { - db, err = p.databaseForClone(ctx, name) + var err error + db, err = p.databaseForClone(ctx, baseName) if err != nil { return nil, false, err @@ -1081,6 +1089,54 @@ func (p DoltDatabaseProvider) SessionDatabase(ctx *sql.Context, name string) (ds } } + // Some DB implementations don't support addressing by versioned names, so return directly if we have one of those + if !db.Versioned() { + return wrapForStandby(db, standby), true, nil + } + + // Convert to a revision database before returning. If we got a non-qualified name, convert it to a qualified name + // using the session's current head + revisionQualifiedName := name + usingDefaultBranch := false + head := "" + sess := dsess.DSessFromSess(ctx.Session) + if !isRevisionDbName { + var err error + head, ok, err = sess.CurrentHead(ctx, baseName) + if err != nil { + return nil, false, err + } + + // A newly created session may not have any info on current head stored yet, in which case we get the default + // branch for the db itself instead. + if !ok { + usingDefaultBranch = true + + head, err = dsess.DefaultHead(baseName, db) + if err != nil { + return nil, false, err + } + } + + revisionQualifiedName = baseName + dsess.DbRevisionDelimiter + head + } + + db, ok, err := p.databaseForRevision(ctx, revisionQualifiedName, name) + if err != nil { + if sql.ErrDatabaseNotFound.Is(err) && usingDefaultBranch { + // We can return a better error message here in some cases + // TODO: this better error message doesn't always get returned to clients because the code path is doesn't + // return an error, only a boolean result (HasDB) + return nil, false, fmt.Errorf("cannot resolve default branch head for database '%s': '%s'", baseName, head) + } else { + return nil, false, err + } + } + + if !ok { + return nil, false, nil + } + return wrapForStandby(db, standby), true, nil } @@ -1131,21 +1187,6 @@ func (p DoltDatabaseProvider) TableFunction(_ *sql.Context, name string) (sql.Ta return nil, sql.ErrTableFunctionNotFound.New(name) } -// isRevisionDatabase returns true if the specified dbName represents a database that is tied to a specific -// branch or commit from a database (e.g. "dolt/branch1"). -func (p DoltDatabaseProvider) isRevisionDatabase(ctx *sql.Context, dbName string) (bool, error) { - db, ok, err := p.SessionDatabase(ctx, dbName) - if err != nil { - return false, err - } - if !ok { - return false, sql.ErrDatabaseNotFound.New(dbName) - } - - _, rev := dsess.SplitRevisionDbName(db) - return rev != "", nil -} - // ensureReplicaHeadExists tries to pull the latest version of a remote branch. Will fail if the branch // does not exist on the ReadReplicaDatabase's remote. func (p DoltDatabaseProvider) ensureReplicaHeadExists(ctx *sql.Context, branch string, db ReadReplicaDatabase) error { @@ -1246,9 +1287,8 @@ func isTag(ctx context.Context, db dsess.SqlDatabase, tagName string) (bool, err } // revisionDbForBranch returns a new database that is tied to the branch named by revSpec -func revisionDbForBranch(ctx context.Context, srcDb dsess.SqlDatabase, revSpec string) (dsess.SqlDatabase, error) { +func revisionDbForBranch(ctx context.Context, srcDb dsess.SqlDatabase, revSpec string, requestedName string) (dsess.SqlDatabase, error) { branch := ref.NewBranchRef(revSpec) - dbName := srcDb.Name() + dsess.DbRevisionDelimiter + revSpec static := staticRepoState{ branch: branch, @@ -1256,47 +1296,62 @@ func revisionDbForBranch(ctx context.Context, srcDb dsess.SqlDatabase, revSpec s RepoStateReader: srcDb.DbData().Rsr, } - var db dsess.SqlDatabase + baseName, _ := dsess.SplitRevisionDbName(srcDb.Name()) + // TODO: we need a base name method here switch v := srcDb.(type) { - case Database: - db = Database{ - name: dbName, - ddb: v.ddb, - rsw: static, - rsr: static, - gs: v.gs, - editOpts: v.editOpts, - revision: revSpec, - revType: dsess.RevisionTypeBranch, + case ReadOnlyDatabase: + db := Database{ + baseName: baseName, + requestedName: requestedName, + ddb: v.ddb, + rsw: static, + rsr: static, + gs: v.gs, + editOpts: v.editOpts, + revision: revSpec, + revType: dsess.RevisionTypeBranch, } + return ReadOnlyDatabase{db}, nil + case Database: + return Database{ + baseName: baseName, + requestedName: requestedName, + ddb: v.ddb, + rsw: static, + rsr: static, + gs: v.gs, + editOpts: v.editOpts, + revision: revSpec, + revType: dsess.RevisionTypeBranch, + }, nil case ReadReplicaDatabase: - db = ReadReplicaDatabase{ + return ReadReplicaDatabase{ Database: Database{ - name: dbName, - ddb: v.ddb, - rsw: static, - rsr: static, - gs: v.gs, - editOpts: v.editOpts, - revision: revSpec, - revType: dsess.RevisionTypeBranch, + baseName: baseName, + requestedName: requestedName, + ddb: v.ddb, + rsw: static, + rsr: static, + gs: v.gs, + editOpts: v.editOpts, + revision: revSpec, + revType: dsess.RevisionTypeBranch, }, remote: v.remote, srcDB: v.srcDB, tmpDir: v.tmpDir, limiter: newLimiter(), - } + }, nil + default: + panic(fmt.Sprintf("unrecognized type of database %T", srcDb)) } - - return db, nil } func initialStateForBranchDb(ctx *sql.Context, srcDb dsess.SqlDatabase) (dsess.InitialDbState, error) { - _, revSpec := dsess.SplitRevisionDbName(srcDb) + revSpec := srcDb.Revision() // TODO: this may be a disabled transaction, need to kill those - rootHash, err := dsess.TransactionRoot(ctx, srcDb) if err != nil { return dsess.InitialDbState{}, err @@ -1356,23 +1411,22 @@ func initialStateForBranchDb(ctx *sql.Context, srcDb dsess.SqlDatabase) (dsess.I return init, nil } -func revisionDbForTag(ctx context.Context, srcDb Database, revSpec string) (ReadOnlyDatabase, error) { - name := srcDb.Name() + dsess.DbRevisionDelimiter + revSpec - db := ReadOnlyDatabase{Database: Database{ - name: name, - ddb: srcDb.DbData().Ddb, - rsw: srcDb.DbData().Rsw, - rsr: srcDb.DbData().Rsr, - editOpts: srcDb.editOpts, - revision: revSpec, - revType: dsess.RevisionTypeTag, - }} - - return db, nil +func revisionDbForTag(ctx context.Context, srcDb Database, revSpec string, requestedName string) (ReadOnlyDatabase, error) { + baseName, _ := dsess.SplitRevisionDbName(srcDb.Name()) + return ReadOnlyDatabase{Database: Database{ + baseName: baseName, + requestedName: requestedName, + ddb: srcDb.DbData().Ddb, + rsw: srcDb.DbData().Rsw, + rsr: srcDb.DbData().Rsr, + editOpts: srcDb.editOpts, + revision: revSpec, + revType: dsess.RevisionTypeTag, + }}, nil } func initialStateForTagDb(ctx context.Context, srcDb ReadOnlyDatabase) (dsess.InitialDbState, error) { - _, revSpec := dsess.SplitRevisionDbName(srcDb) + revSpec := srcDb.Revision() tag := ref.NewTagRef(revSpec) cm, err := srcDb.DbData().Ddb.ResolveCommitRef(ctx, tag) @@ -1399,23 +1453,22 @@ func initialStateForTagDb(ctx context.Context, srcDb ReadOnlyDatabase) (dsess.In return init, nil } -func revisionDbForCommit(ctx context.Context, srcDb Database, revSpec string) (ReadOnlyDatabase, error) { - name := srcDb.Name() + dsess.DbRevisionDelimiter + revSpec - db := ReadOnlyDatabase{Database: Database{ - name: name, - ddb: srcDb.DbData().Ddb, - rsw: srcDb.DbData().Rsw, - rsr: srcDb.DbData().Rsr, - editOpts: srcDb.editOpts, - revision: revSpec, - revType: dsess.RevisionTypeCommit, - }} - - return db, nil +func revisionDbForCommit(ctx context.Context, srcDb Database, revSpec string, requestedName string) (ReadOnlyDatabase, error) { + baseName, _ := dsess.SplitRevisionDbName(srcDb.Name()) + return ReadOnlyDatabase{Database: Database{ + baseName: baseName, + requestedName: requestedName, + ddb: srcDb.DbData().Ddb, + rsw: srcDb.DbData().Rsw, + rsr: srcDb.DbData().Rsr, + editOpts: srcDb.editOpts, + revision: revSpec, + revType: dsess.RevisionTypeCommit, + }}, nil } func initialStateForCommit(ctx context.Context, srcDb ReadOnlyDatabase) (dsess.InitialDbState, error) { - _, revSpec := dsess.SplitRevisionDbName(srcDb) + revSpec := srcDb.Revision() spec, err := doltdb.NewCommitSpec(revSpec) if err != nil { @@ -1462,6 +1515,7 @@ func (s staticRepoState) CWBHeadRef() (ref.DoltRef, error) { // formatDbMapKeyName returns formatted string of database name and/or branch name. Database name is case-insensitive, // so it's stored in lower case name. Branch name is case-sensitive, so not changed. +// TODO: branch names should be case-insensitive too func formatDbMapKeyName(name string) string { if !strings.Contains(name, dsess.DbRevisionDelimiter) { return strings.ToLower(name) diff --git a/go/libraries/doltcore/sqle/dfunctions/hashof.go b/go/libraries/doltcore/sqle/dfunctions/hashof.go index 296f39f7bd..2ab5b94425 100644 --- a/go/libraries/doltcore/sqle/dfunctions/hashof.go +++ b/go/libraries/doltcore/sqle/dfunctions/hashof.go @@ -74,6 +74,7 @@ func (t *HashOf) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if strings.ToUpper(name) == "HEAD" { sess := dsess.DSessFromSess(ctx.Session) + // TODO: this should resolve the current DB through the analyzer so it can use the revision qualified name here cm, err = sess.GetHeadCommit(ctx, dbName) if err != nil { return nil, err diff --git a/go/libraries/doltcore/sqle/dolt_log_table_function.go b/go/libraries/doltcore/sqle/dolt_log_table_function.go index 103aafbfd0..0957083e78 100644 --- a/go/libraries/doltcore/sqle/dolt_log_table_function.go +++ b/go/libraries/doltcore/sqle/dolt_log_table_function.go @@ -386,7 +386,7 @@ func (ltf *LogTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter } } else { // If revisionExpr not defined, use session head - commit, err = sess.GetHeadCommit(ctx, sqledb.Name()) + commit, err = sess.GetHeadCommit(ctx, sqledb.RevisionQualifiedName()) if err != nil { return nil, err } diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_branch.go b/go/libraries/doltcore/sqle/dprocedures/dolt_branch.go index 5002f8d872..c5e98c7e62 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_branch.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_branch.go @@ -94,10 +94,6 @@ func commitTransaction(ctx *sql.Context, dSess *dsess.DoltSession, rsc *doltdb.R dsess.WaitForReplicationController(ctx, *rsc) } - // Because this transaction manipulation is happening outside the engine's awareness, we need to set it to nil here - // to get a fresh transaction started on the next statement. - // TODO: put this under engine control - ctx.SetTransaction(nil) return nil } @@ -221,11 +217,8 @@ func shouldAllowDefaultBranchDeletion(ctx *sql.Context) bool { // validateBranchNotActiveInAnySessions returns an error if the specified branch is currently // selected as the active branch for any active server sessions. func validateBranchNotActiveInAnySession(ctx *sql.Context, branchName string) error { - currentDbName, _, err := getRevisionForRevisionDatabase(ctx, ctx.GetCurrentDatabase()) - if err != nil { - return err - } - + currentDbName := ctx.GetCurrentDatabase() + currentDbName, _ = dsess.SplitRevisionDbName(currentDbName) if currentDbName == "" { return nil } @@ -242,24 +235,22 @@ func validateBranchNotActiveInAnySession(ctx *sql.Context, branchName string) er branchRef := ref.NewBranchRef(branchName) return sessionManager.Iter(func(session sql.Session) (bool, error) { - dsess, ok := session.(*dsess.DoltSession) + sess, ok := session.(*dsess.DoltSession) if !ok { return false, fmt.Errorf("unexpected session type: %T", session) } - sessionDatabase := dsess.Session.GetCurrentDatabase() - sessionDbName, _, err := getRevisionForRevisionDatabase(ctx, dsess.GetCurrentDatabase()) - if err != nil { - return false, err - } - - if len(sessionDatabase) == 0 || sessionDbName != currentDbName { + sessionDbName := sess.Session.GetCurrentDatabase() + baseName, _ := dsess.SplitRevisionDbName(sessionDbName) + if len(baseName) == 0 || baseName != currentDbName { return false, nil } - activeBranchRef, err := dsess.CWBHeadRef(ctx, sessionDatabase) + activeBranchRef, err := sess.CWBHeadRef(ctx, sessionDbName) if err != nil { - return false, err + // The above will throw an error if the current DB doesn't have a head ref, in which case we don't need to + // consider it + return false, nil } if ref.Equals(branchRef, activeBranchRef) { diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_checkout.go b/go/libraries/doltcore/sqle/dprocedures/dolt_checkout.go index cfbc7eb81a..1863a5e933 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_checkout.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_checkout.go @@ -48,12 +48,6 @@ func doDoltCheckout(ctx *sql.Context, args []string) (int, error) { return 1, fmt.Errorf("Empty database name.") } - // non-revision database branchName is used to check out a branch on it. - dbName, _, err := getRevisionForRevisionDatabase(ctx, currentDbName) - if err != nil { - return -1, err - } - apr, err := cli.CreateCheckoutArgParser().Parse(args) if err != nil { return 1, err @@ -65,7 +59,6 @@ func doDoltCheckout(ctx *sql.Context, args []string) (int, error) { } dSess := dsess.DSessFromSess(ctx.Session) - // dbData should use the current database data, which can be at revision database. dbData, ok := dSess.GetDbData(ctx, currentDbName) if !ok { return 1, fmt.Errorf("Could not load database %s", currentDbName) @@ -75,7 +68,7 @@ func doDoltCheckout(ctx *sql.Context, args []string) (int, error) { // Checking out new branch. if branchOrTrack { - err = checkoutNewBranch(ctx, dbName, dbData, apr, &rsc) + err = checkoutNewBranch(ctx, currentDbName, dbData, apr, &rsc) if err != nil { return 1, err } else { @@ -92,7 +85,7 @@ func doDoltCheckout(ctx *sql.Context, args []string) (int, error) { if isBranch, err := actions.IsBranch(ctx, dbData.Ddb, branchName); err != nil { return 1, err } else if isBranch { - err = checkoutBranch(ctx, dbName, branchName) + err = checkoutBranch(ctx, currentDbName, branchName) if errors.Is(err, doltdb.ErrWorkingSetNotFound) { // If there is a branch but there is no working set, // somehow the local branch ref was created without a @@ -108,7 +101,7 @@ func doDoltCheckout(ctx *sql.Context, args []string) (int, error) { return 1, err } - err = checkoutBranch(ctx, dbName, branchName) + err = checkoutBranch(ctx, currentDbName, branchName) } if err != nil { return 1, err @@ -116,14 +109,14 @@ func doDoltCheckout(ctx *sql.Context, args []string) (int, error) { return 0, nil } - roots, ok := dSess.GetRoots(ctx, dbName) + roots, ok := dSess.GetRoots(ctx, currentDbName) if !ok { - return 1, fmt.Errorf("Could not load database %s", dbName) + return 1, fmt.Errorf("Could not load database %s", currentDbName) } - err = checkoutTables(ctx, roots, dbName, args) + err = checkoutTables(ctx, roots, currentDbName, args) if err != nil && apr.NArg() == 1 { - err = checkoutRemoteBranch(ctx, dbName, dbData, branchName, apr, &rsc) + err = checkoutRemoteBranch(ctx, currentDbName, dbData, branchName, apr, &rsc) } if err != nil { @@ -175,29 +168,6 @@ func createWorkingSetForLocalBranch(ctx *sql.Context, ddb *doltdb.DoltDB, branch return ddb.UpdateWorkingSet(ctx, wsRef, ws, hash.Hash{} /* current hash... */, doltdb.TodoWorkingSetMeta(), nil) } -// getRevisionForRevisionDatabase returns the root database name and revision for a database, or just the root database name if the specified db name is not a revision database. -func getRevisionForRevisionDatabase(ctx *sql.Context, dbName string) (string, string, error) { - doltsess, ok := ctx.Session.(*dsess.DoltSession) - if !ok { - return "", "", fmt.Errorf("unexpected session type: %T", ctx.Session) - } - - db, ok, err := doltsess.Provider().SessionDatabase(ctx, dbName) - if err != nil { - return "", "", err - } - if !ok { - return "", "", sql.ErrDatabaseNotFound.New(dbName) - } - - rdb, ok := db.(dsess.RevisionDatabase) - if !ok { - return dbName, "", nil - } - - return rdb.BaseName(), rdb.Revision(), nil -} - // checkoutRemoteBranch checks out a remote branch creating a new local branch with the same name as the remote branch // and set its upstream. The upstream persists out of sql session. func checkoutRemoteBranch(ctx *sql.Context, dbName string, dbData env.DbData, branchName string, apr *argparser.ArgParseResults, rsc *doltdb.ReplicationStatusController) error { @@ -272,10 +242,6 @@ func checkoutNewBranch(ctx *sql.Context, dbName string, dbData env.DbData, apr * if err != nil { return err } - err = checkoutBranch(ctx, dbName, newBranchName) - if err != nil { - return err - } if setTrackUpstream { err = env.SetRemoteUpstreamForRefSpec(dbData.Rsw, refSpec, remoteName, ref.NewBranchRef(remoteBranchName)) @@ -285,16 +251,28 @@ func checkoutNewBranch(ctx *sql.Context, dbName string, dbData env.DbData, apr * } else if autoSetupMerge, err := loadConfig(ctx).GetString("branch.autosetupmerge"); err != nil || autoSetupMerge != "false" { remoteName, remoteBranchName = actions.ParseRemoteBranchName(startPt) refSpec, err = ref.ParseRefSpecForRemote(remoteName, remoteBranchName) - if err != nil { - return nil - } - err = env.SetRemoteUpstreamForRefSpec(dbData.Rsw, refSpec, remoteName, ref.NewBranchRef(remoteBranchName)) - if err != nil { - return err + if err == nil { + err = env.SetRemoteUpstreamForRefSpec(dbData.Rsw, refSpec, remoteName, ref.NewBranchRef(remoteBranchName)) + if err != nil { + return err + } } } - return nil + // We need to commit the transaction here or else the branch we just created isn't visible to the current transaction, + // and we are about to switch to it. So set the new branch head for the new transaction, then commit this one + sess := dsess.DSessFromSess(ctx.Session) + err = commitTransaction(ctx, sess, rsc) + if err != nil { + return err + } + + wsRef, err := ref.WorkingSetRefForHead(ref.NewBranchRef(newBranchName)) + if err != nil { + return err + } + + return sess.SwitchWorkingSet(ctx, dbName, wsRef) } func checkoutBranch(ctx *sql.Context, dbName string, branchName string) error { diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_conflicts_resolve.go b/go/libraries/doltcore/sqle/dprocedures/dolt_conflicts_resolve.go index e26aa829f0..9b782ed608 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_conflicts_resolve.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_conflicts_resolve.go @@ -419,7 +419,7 @@ func ResolveDataConflicts(ctx *sql.Context, dSess *dsess.DoltSession, root *dolt if err != nil { return err } - opts := state.WriteSession.GetOptions() + opts := state.WriteSession().GetOptions() tbl, err = resolveNomsConflicts(ctx, opts, tbl, tblName, sch) } if err != nil { diff --git a/go/libraries/doltcore/sqle/dsess/branch_control.go b/go/libraries/doltcore/sqle/dsess/branch_control.go new file mode 100755 index 0000000000..c878065896 --- /dev/null +++ b/go/libraries/doltcore/sqle/dsess/branch_control.go @@ -0,0 +1,59 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dsess + +import ( + "context" + + "github.com/dolthub/dolt/go/libraries/doltcore/branch_control" +) + +// CheckAccessForDb checks whether the current user has the given permissions for the given database. +// This has to live here, rather than in the branch_control package, to prevent a dependency cycle with that package. +// We could also avoid this by defining branchController as an interface used by dsess. +func CheckAccessForDb(ctx context.Context, db SqlDatabase, flags branch_control.Permissions) error { + branchAwareSession := branch_control.GetBranchAwareSession(ctx) + // A nil session means we're not in the SQL context, so we allow all operations + if branchAwareSession == nil { + return nil + } + + controller := branchAwareSession.GetController() + // Any context that has a non-nil session should always have a non-nil controller, so this is an error + if controller == nil { + return branch_control.ErrMissingController.New() + } + + controller.Access.RWMutex.RLock() + defer controller.Access.RWMutex.RUnlock() + + user := branchAwareSession.GetUser() + host := branchAwareSession.GetHost() + + if db.RevisionType() != RevisionTypeBranch { + // not a branch db, no check necessary + return nil + } + + dbName, branch := SplitRevisionDbName(db.RevisionQualifiedName()) + + // Get the permissions for the branch, user, and host combination + _, perms := controller.Access.Match(dbName, branch, user, host) + // If either the flags match or the user is an admin for this branch, then we allow access + if (perms&flags == flags) || (perms&branch_control.Permissions_Admin == branch_control.Permissions_Admin) { + return nil + } + return branch_control.ErrIncorrectPermissions.New(user, host, branch) +} diff --git a/go/libraries/doltcore/sqle/dsess/database_session_state.go b/go/libraries/doltcore/sqle/dsess/database_session_state.go index bcb9a680d5..8259cb1e3c 100644 --- a/go/libraries/doltcore/sqle/dsess/database_session_state.go +++ b/go/libraries/doltcore/sqle/dsess/database_session_state.go @@ -35,13 +35,12 @@ type InitialDbState struct { // RootValue must be set. HeadCommit *doltdb.Commit // HeadRoot is the root value for databases without a HeadCommit. Nil for databases with a HeadCommit. - HeadRoot *doltdb.RootValue - ReadOnly bool - DbData env.DbData - ReadReplica *env.Remote - Remotes map[string]env.Remote - Branches map[string]env.BranchConfig - Backups map[string]env.Remote + HeadRoot *doltdb.RootValue + ReadOnly bool + DbData env.DbData + Remotes map[string]env.Remote + Branches map[string]env.BranchConfig + Backups map[string]env.Remote // If err is set, this InitialDbState is partially invalid, but may be // usable to initialize a database at a revision specifier, for @@ -54,24 +53,33 @@ type InitialDbState struct { // order for the session to manage it. type SessionDatabase interface { sql.Database - InitialDBState(ctx *sql.Context, branch string) (InitialDbState, error) + InitialDBState(ctx *sql.Context) (InitialDbState, error) } +// DatabaseSessionState is the set of all information for a given database in this session. type DatabaseSessionState struct { - dbName string - db SqlDatabase - headCommit *doltdb.Commit - headRoot *doltdb.RootValue - WorkingSet *doltdb.WorkingSet - dbData env.DbData - WriteSession writer.WriteSession - globalState globalstate.GlobalState - readOnly bool - dirty bool - readReplica *env.Remote - tmpFileDir string - - sessionCache *SessionCache + // dbName is the name of the database this state applies to. This is always the base name of the database, without + // a revision qualifier. + dbName string + // currRevSpec is the current revision spec of the database when referred to by its base name. Changes when a + // `dolt_checkout` or `use` statement is executed. + currRevSpec string + // currRevType is the current revision type of the database when referred to by its base name. Changes when a + // `dolt_checkout` or `use` statement is executed. + currRevType RevisionType + // checkedOutRevSpec is the checked out revision specifier of the database. Changes only when a `dolt_checkout` + // occurs. `USE mydb` without a revision qualifier will get this revision. + checkedOutRevSpec string + // heads records the in-memory DB state for every branch head accessed by the session + heads map[string]*branchState + // headCache records the session-caches for every branch head accessed by the session + // This is managed separately from the branch states themselves because it persists across transactions (which is + // safe because it's keyed by immutable hashes) + headCache map[string]*SessionCache + // globalState is the global state of this session (shared by all sessions for a particular db) + globalState globalstate.GlobalState + // tmpFileDir is the directory to use for temporary files for this database + tmpFileDir string // Same as InitialDbState.Err, this signifies that this // DatabaseSessionState is invalid. LookupDbState returning a @@ -79,31 +87,97 @@ type DatabaseSessionState struct { Err error } -func NewEmptyDatabaseSessionState() *DatabaseSessionState { +func newEmptyDatabaseSessionState() *DatabaseSessionState { return &DatabaseSessionState{ - sessionCache: newSessionCache(), + heads: make(map[string]*branchState), + headCache: make(map[string]*SessionCache), } } -func (d DatabaseSessionState) GetRoots() doltdb.Roots { - if d.WorkingSet == nil { +// SessionState is the public interface for dealing with session state outside this package. Session-state is always +// branch-specific. +type SessionState interface { + WorkingSet() *doltdb.WorkingSet + WorkingRoot() *doltdb.RootValue + WriteSession() writer.WriteSession + EditOpts() editor.Options + SessionCache() *SessionCache +} + +// branchState records all the in-memory session state for a particular branch head +type branchState struct { + // dbState is the parent database state for this branch head state + dbState *DatabaseSessionState + // head is the name of the branch head for this state + head string + // headCommit is the head commit for this database. May be nil for databases tied to a detached root value, in which + // case headRoot must be set. + headCommit *doltdb.Commit + // HeadRoot is the root value for databases without a headCommit. Nil for databases with a headCommit. + headRoot *doltdb.RootValue + // workingSet is the working set for this database. May be nil for databases tied to a detached root value, in which + // case headCommit must be set + workingSet *doltdb.WorkingSet + // dbData is an accessor for the underlying doltDb + dbData env.DbData + // writeSession is this head's write session + writeSession writer.WriteSession + // readOnly is true if this database is read only + readOnly bool + // dirty is true if this branch state has uncommitted changes + dirty bool +} + +// NewEmptyBranchState creates a new branch state for the given head name with the head provided, adds it to the db +// state, and returns it. The state returned is empty except for its identifiers and must be filled in by the caller. +func (dbState *DatabaseSessionState) NewEmptyBranchState(head string) *branchState { + b := &branchState{ + dbState: dbState, + head: head, + } + + dbState.heads[head] = b + _, ok := dbState.headCache[head] + if !ok { + dbState.headCache[head] = newSessionCache() + } + + return b +} + +func (bs *branchState) WorkingRoot() *doltdb.RootValue { + return bs.roots().Working +} + +var _ SessionState = (*branchState)(nil) + +func (bs *branchState) WorkingSet() *doltdb.WorkingSet { + return bs.workingSet +} + +func (bs *branchState) WriteSession() writer.WriteSession { + return bs.writeSession +} + +func (bs *branchState) SessionCache() *SessionCache { + return bs.dbState.headCache[bs.head] +} + +func (bs branchState) EditOpts() editor.Options { + return bs.WriteSession().GetOptions() +} + +func (bs *branchState) roots() doltdb.Roots { + if bs.WorkingSet() == nil { return doltdb.Roots{ - Head: d.headRoot, - Working: d.headRoot, - Staged: d.headRoot, + Head: bs.headRoot, + Working: bs.headRoot, + Staged: bs.headRoot, } } return doltdb.Roots{ - Head: d.headRoot, - Working: d.WorkingSet.WorkingRoot(), - Staged: d.WorkingSet.StagedRoot(), + Head: bs.headRoot, + Working: bs.WorkingSet().WorkingRoot(), + Staged: bs.WorkingSet().StagedRoot(), } } - -func (d *DatabaseSessionState) SessionCache() *SessionCache { - return d.sessionCache -} - -func (d DatabaseSessionState) EditOpts() editor.Options { - return d.WriteSession.GetOptions() -} diff --git a/go/libraries/doltcore/sqle/dsess/doc.go b/go/libraries/doltcore/sqle/dsess/doc.go new file mode 100755 index 0000000000..b1c1e6bb19 --- /dev/null +++ b/go/libraries/doltcore/sqle/dsess/doc.go @@ -0,0 +1,71 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dsess + +/* + +The dsess package is responsible for storing the state of every database in each session. + +The major players in this process are: + +* sqle.Database: The database implementation we provide to integrate with go-mysql-server's interface is mostly a + wrapper to provide access to the actual storage of tables and rows that are held by dsess.Session. +* sqle.DatabaseProvider: Responsible for creating a new sqle.Database for each database name asked for by the engine, + as well as for managing the details of replication on the databases it returns. +* dsess.Session: Responsible for maintaining the state of each session, including the data access for any row data. + Each physical dolt database in the provider can have the state of multiple branch heads managed by a session. This + state is loaded on demand from the provider as the client asks for different databases by name, as `dolt_checkout` + is called, etc. +* dsess.DoltTransaction: Records a start state (noms root) for each database managed by the transaction. Responsible + for committing new data as the result of a COMMIT or dolt_commit() by merging this start state with session changes + as appropriate. + +The rough flow of data between the engine and this package: + +1) START TRANSACTION calls dsess.Session.StartTransaction() to create a new dsess.DoltTransaction. This transaction + takes a snapshot of the current noms root for each database known to the provider and records these as part of the + transaction. This method clears out all cached state. +2) The engine calls DatabaseProvider.Database() to get a sqle.Database for each database name included in a query, + including statements like `USE db`. +3) Databases have access to tables, views, and other schema elements that they provide to the engine upon request as + part of query analysis, row iteration, etc. As a rule, this data is loaded from the session when asked for. Databases, + tables, views, and other structures in the sqle package are best understood as pass-through entities that always + defer to the session for their actual data. +4) When actual data is required, a table or other schema element asks the session for the data. The primary interface + for this exchange is Session.LookupDbState(), which takes a database name. +5) Eventually, the client session issues a COMMIT or DOLT_COMMIT() statement. This calls Session.CommitTransaction(), + which enforces business logic rules around the commit and then calls DoltTransaction.Commit() to persist the changes. + +Databases managed by the provider and the session can be referred to by either a base name (myDb) or a fully qualified +name (myDb/myBranch). The details of this are a little bit subtle: + +* Database names that aren't qualified with a revision specifier resolve to either a) the default branch head, or + b) whatever branch head was last checked out with dolt_checkout(). Changing the branch head referred to by an + unqualified database name is the primary purpose of dolt_checkout(). +* `mydb/branch` always resolves to that branch head, as it existed at transaction start +* Database names exposed to the engine are always `mydb`, never `mydb/branch`. This includes the result of + `select database()`. This is because the engine expects base database names when e.g. checking GRANTs, returning + information the information schema table, etc. +* sqle.Database has an external name it exposes to the engine via Name(), as well as an internal name that includes a + revision qualifier, RevisionQualifiedName(). The latter should always be used internally when accessing session data, + including rows and all other table data. It's only appropriate to use an unqualified database name when you want + the current checked out HEAD. + +It's possible to alter the data on multiple HEADS in a single session, but we currently restrict the users to +committing a single one. It doesn't need to be the checked out head -- we simply look for a single dirty branch head +state and commit that one. If there is more than one, it's an error. We may allow multiple branch heads to be updated +in a single transaction in the future. + +*/ diff --git a/go/libraries/doltcore/sqle/dsess/dolt_session_test.go b/go/libraries/doltcore/sqle/dsess/dolt_session_test.go index 805237c5ec..8b3fd1f923 100644 --- a/go/libraries/doltcore/sqle/dsess/dolt_session_test.go +++ b/go/libraries/doltcore/sqle/dsess/dolt_session_test.go @@ -251,6 +251,10 @@ type emptyRevisionDatabaseProvider struct { sql.DatabaseProvider } +func (e emptyRevisionDatabaseProvider) BaseDatabase(ctx *sql.Context, dbName string) (SqlDatabase, bool) { + return nil, false +} + func (e emptyRevisionDatabaseProvider) SessionDatabase(ctx *sql.Context, dbName string) (SqlDatabase, bool, error) { return nil, false, sql.ErrDatabaseNotFound.New(dbName) } diff --git a/go/libraries/doltcore/sqle/dsess/session.go b/go/libraries/doltcore/sqle/dsess/session.go index aec019701b..c3d7c7feb2 100644 --- a/go/libraries/doltcore/sqle/dsess/session.go +++ b/go/libraries/doltcore/sqle/dsess/session.go @@ -48,10 +48,13 @@ const ( Batched ) +const ( + DbRevisionDelimiter = "/" +) + var ErrWorkingSetChanges = goerrors.NewKind("Cannot switch working set, session state is dirty. " + "Rollback or commit changes before changing working sets.") var ErrSessionNotPeristable = errors.New("session is not persistable") -var ErrCurrentBranchDeleted = errors.New("current branch has been force deleted. run 'USE /' to checkout a different branch, or reconnect to the server") // DoltSession is the sql.Session implementation used by dolt. It is accessible through a *sql.Context instance type DoltSession struct { @@ -60,6 +63,7 @@ type DoltSession struct { username string email string dbStates map[string]*DatabaseSessionState + dbCache *DatabaseCache provider DoltDatabaseProvider tempTables map[string][]sql.Table globalsConf config.ReadWriteConfig @@ -83,6 +87,7 @@ func DefaultSession(pro DoltDatabaseProvider) *DoltSession { username: "", email: "", dbStates: make(map[string]*DatabaseSessionState), + dbCache: newDatabaseCache(), provider: pro, tempTables: make(map[string][]sql.Table), globalsConf: config.NewMapConfig(make(map[string]string)), @@ -107,6 +112,7 @@ func NewDoltSession( username: username, email: email, dbStates: make(map[string]*DatabaseSessionState), + dbCache: newDatabaseCache(), provider: pro, tempTables: make(map[string][]sql.Table), globalsConf: globals, @@ -134,50 +140,92 @@ func DSessFromSess(sess sql.Session) *DoltSession { return sess.(*DoltSession) } -// LookupDbState returns the session state for the database named -func (d *DoltSession) lookupDbState(ctx *sql.Context, dbName string) (*DatabaseSessionState, bool, error) { +// lookupDbState is the private version of LookupDbState, returning a struct that has more information available than +// the interface returned by the public method. +func (d *DoltSession) lookupDbState(ctx *sql.Context, dbName string) (*branchState, bool, error) { dbName = strings.ToLower(dbName) + + var baseName, rev string + baseName, rev = SplitRevisionDbName(dbName) + d.mu.Lock() - dbState, ok := d.dbStates[dbName] + dbState, dbStateFound := d.dbStates[baseName] d.mu.Unlock() - if ok { - return dbState, ok, nil + + if dbStateFound { + // If we got an unqualified name, use the current working set head + if rev == "" { + rev = dbState.currRevSpec + } + + branchState, ok := dbState.heads[strings.ToLower(rev)] + + if ok { + if dbState.Err != nil { + return nil, false, dbState.Err + } + + return branchState, ok, nil + } } - // TODO: this needs to include the transaction's snapshot of the DB at tx start time + // No state for this db / branch combination yet, look it up from the provider. We use the unqualified DB name (no + // branch) if the current DB has not yet been loaded into this session. It will resolve to that DB's default branch + // in that case. + revisionQualifiedName := dbName + if rev != "" { + revisionQualifiedName = revisionDbName(baseName, rev) + } - database, ok, err := d.provider.SessionDatabase(ctx, dbName) + database, ok, err := d.provider.SessionDatabase(ctx, revisionQualifiedName) if err != nil { return nil, false, err } - if !ok { return nil, false, nil } // Add the initial state to the session for future reuse - if err = d.addDB(ctx, database); err != nil { + if err := d.addDB(ctx, database); err != nil { return nil, false, err } d.mu.Lock() - dbState, ok = d.dbStates[dbName] + dbState, dbStateFound = d.dbStates[baseName] d.mu.Unlock() - if !ok { + if !dbStateFound { + // should be impossible return nil, false, sql.ErrDatabaseNotFound.New(dbName) } - return dbState, true, nil + return dbState.heads[strings.ToLower(database.Revision())], true, nil } -func (d *DoltSession) LookupDbState(ctx *sql.Context, dbName string) (*DatabaseSessionState, bool, error) { +func revisionDbName(baseName string, rev string) string { + return baseName + DbRevisionDelimiter + rev +} + +func SplitRevisionDbName(dbName string) (string, string) { + var baseName, rev string + parts := strings.SplitN(dbName, DbRevisionDelimiter, 2) + baseName = parts[0] + if len(parts) > 1 { + rev = parts[1] + } + return baseName, rev +} + +// LookupDbState returns the session state for the database named. Unqualified database names, e.g. `mydb` get resolved +// to the currently checked out HEAD, which could be a branch, a commit, a tag, etc. Revision-qualified database names, +// e.g. `mydb/branch1` get resolved to the session state for the revision named. +// A note on unqualified database names: unqualified names will resolve to a) the head last checked out with +// `dolt_checkout`, or b) the database's default branch, if this session hasn't called `dolt_checkout` yet. +// Also returns a bool indicating whether the database was found, and an error if one occurred. +func (d *DoltSession) LookupDbState(ctx *sql.Context, dbName string) (SessionState, bool, error) { s, ok, err := d.lookupDbState(ctx, dbName) if err != nil { return nil, false, err } - if ok && s.Err != nil { - return nil, false, s.Err - } return s, ok, nil } @@ -187,6 +235,8 @@ func (d *DoltSession) RemoveDbState(_ *sql.Context, dbName string) error { d.mu.Lock() defer d.mu.Unlock() delete(d.dbStates, strings.ToLower(dbName)) + // also clear out any db-level caches for this db + d.dbCache.Clear() return nil } @@ -194,12 +244,12 @@ func (d *DoltSession) RemoveDbState(_ *sql.Context, dbName string) error { // happens automatically as part of statement execution, and is only necessary when the session is manually batched (as // for bulk SQL import) func (d *DoltSession) Flush(ctx *sql.Context, dbName string) error { - dbState, _, err := d.LookupDbState(ctx, dbName) + branchState, _, err := d.lookupDbState(ctx, dbName) if err != nil { return err } - ws, err := dbState.WriteSession.Flush(ctx) + ws, err := branchState.WriteSession().Flush(ctx) if err != nil { return err } @@ -211,7 +261,7 @@ func (d *DoltSession) Flush(ctx *sql.Context, dbName string) error { // to ValidateSession. This is effectively a way to disable a session. // // Used by sql/cluster logic to make sessions on a server which has -// transitioned roles termainlly error. +// transitioned roles terminally error. func (d *DoltSession) SetValidateErr(err error) { d.validateErr = err } @@ -220,152 +270,91 @@ func (d *DoltSession) SetValidateErr(err error) { // If there is no sessionState or its current working set not defined, then no need for validation, // so no error is returned. func (d *DoltSession) ValidateSession(ctx *sql.Context, dbName string) error { - if d.validateErr != nil { - return d.validateErr - } - sessionState, ok, err := d.LookupDbState(ctx, dbName) - if err != nil { - return err - } - if !ok { - return nil - } - if sessionState.WorkingSet == nil { - return nil - } - wsRef := sessionState.WorkingSet.Ref() - _, err = sessionState.dbData.Ddb.ResolveWorkingSet(ctx, wsRef) - if err == doltdb.ErrWorkingSetNotFound { - _, err = d.newWorkingSetForHead(ctx, wsRef, dbName) - // if the current head is not found, the branch was force deleted, so use nil working set. - if errors.Is(err, doltdb.ErrBranchNotFound) { - return ErrCurrentBranchDeleted - } else if err != nil { - return err - } - } else if err != nil { - return err - } - return nil + return d.validateErr } // StartTransaction refreshes the state of this session and starts a new transaction. func (d *DoltSession) StartTransaction(ctx *sql.Context, tCharacteristic sql.TransactionCharacteristic) (sql.Transaction, error) { + // TODO: this is only necessary to support filter-branch, which needs to set a root directly and not have the + // session state altered when a transaction begins if TransactionsDisabled(ctx) { return DisabledTransaction{}, nil } - // TODO: remove this when we have true multi-db transaction support - dbName := ctx.GetTransactionDatabase() - if isNoOpTransactionDatabase(dbName) { - return DisabledTransaction{}, nil - } - // New transaction, clear all session state - d.clearRevisionDbState() + d.clear() - sessionState, ok, err := d.LookupDbState(ctx, dbName) - if err != nil { - return nil, err - } - - if !ok { - return nil, sql.ErrDatabaseNotFound.New(dbName) - } - - // There are both valid and invalid ways that a working set for the session state can be nil (e.g. connected to a - // commit hash revision DB, or the DB contents cannot be loaded). Either way this transaction is defunct. - // TODO: with multi-db transactions, such DBs should be ignored - if sessionState.WorkingSet == nil { - return DisabledTransaction{}, nil - } - - // TODO: this needs to happen for every DB in the database, not just the one named in the transaction - if sessionState != nil && sessionState.db != nil { - rrd, ok := sessionState.db.(RemoteReadReplicaDatabase) - if ok && rrd.ValidReplicaState(ctx) { - err := rrd.PullFromRemote(ctx) - if err != nil && !IgnoreReplicationErrors() { - return nil, fmt.Errorf("replication error: %w", err) - } else if err != nil { - WarnReplicationError(ctx, err) - } - } - } - - if sessionState.readOnly { - return DisabledTransaction{}, nil - } - - nomsRoots := make(map[string]hash.Hash) - for _, db := range d.provider.DoltDatabases() { + // Take a snapshot of the current noms root for every database under management + doltDatabases := d.provider.DoltDatabases() + txDbs := make([]SqlDatabase, 0, len(doltDatabases)) + for _, db := range doltDatabases { // TODO: this nil check is only necessary to support UserSpaceDatabase and clusterDatabase, come up with a better set of // interfaces to capture these capabilities ddb := db.DbData().Ddb if ddb != nil { - nomsRoot, err := ddb.NomsRoot(ctx) - if err != nil { - return nil, err + rrd, ok := db.(RemoteReadReplicaDatabase) + if ok && rrd.ValidReplicaState(ctx) { + err := rrd.PullFromRemote(ctx) + if err != nil && !IgnoreReplicationErrors() { + return nil, fmt.Errorf("replication error: %w", err) + } else if err != nil { + WarnReplicationError(ctx, err) + } } - nomsRoots[strings.ToLower(db.Name())] = nomsRoot + + // TODO: this check is relatively expensive, we should cache this value when it changes instead of looking it + // up on each transaction start + if _, v, ok := sql.SystemVariables.GetGlobal(ReadReplicaRemote); ok && v != "" { + err := ddb.Rebase(ctx) + if err != nil && !IgnoreReplicationErrors() { + return nil, err + } else if err != nil { + WarnReplicationError(ctx, err) + } + } + + txDbs = append(txDbs, db) } } - if _, v, ok := sql.SystemVariables.GetGlobal(ReadReplicaRemote); ok && v != "" { - err = sessionState.dbData.Ddb.Rebase(ctx) - if err != nil && !IgnoreReplicationErrors() { - return nil, err - } else if err != nil { - WarnReplicationError(ctx, err) - } - } - - wsRef := sessionState.WorkingSet.Ref() - ws, err := sessionState.dbData.Ddb.ResolveWorkingSet(ctx, wsRef) - // TODO: every HEAD needs a working set created when it is. We can get rid of this in a 1.0 release when this is fixed - if err == doltdb.ErrWorkingSetNotFound { - ws, err = d.newWorkingSetForHead(ctx, wsRef, dbName) - if err != nil { - return nil, err - } - } else if err != nil { + tx, err := NewDoltTransaction(ctx, txDbs, tCharacteristic) + if err != nil { return nil, err } - // logrus.Tracef("starting transaction with working root %s", ws.WorkingRoot().DebugString(ctx, true)) + // The engine sets the transaction after this call as well, but since we begin accessing data below, we need to set + // this now to avoid seeding the session state with stale data in some cases. The duplication is harmless since the + // code below cannot error. + ctx.SetTransaction(tx) - // TODO: this is going to do 2 resolves to get the head root, not ideal - err = d.SetWorkingSet(ctx, dbName, ws) + // Set session vars for every DB in this session using their current branch head + for _, db := range doltDatabases { + // faulty settings can make it impossible to load particular DB branch states, so we ignore any errors in this + // loop and just decline to set the session vars. Throwing an error on transaction start in these cases makes it + // impossible for the user to correct any problems. + bs, ok, err := d.lookupDbState(ctx, db.Name()) + if err != nil || !ok { + continue + } - // SetWorkingSet always sets the dirty bit, but by definition we are clean at transaction start - sessionState.dirty = false + _ = d.setDbSessionVars(ctx, bs, false) + } - return NewDoltTransaction(dbName, nomsRoots, ws, wsRef, sessionState.dbData, sessionState.WriteSession.GetOptions(), tCharacteristic), nil + return tx, nil } -// clearRevisionDbState clears all revision DB states for this session. This is necessary on transaction start, -// because they will be re-initialized with the current branch head / working set. -// TODO: this should happen with every dbstate, not just revision DBs. The problem is that we track the current working -// -// set *only* in the session state. We need to disentangle the metadata about a state (working ref, persists across -// transactions) from its data (re-initialized on every transaction start) -func (d *DoltSession) clearRevisionDbState() { +// clear clears all DB state for this session +func (d *DoltSession) clear() { d.mu.Lock() defer d.mu.Unlock() for _, dbState := range d.dbStates { - if len(dbState.db.Revision()) > 0 { - delete(d.dbStates, strings.ToLower(dbState.db.Name())) + for head := range dbState.heads { + delete(dbState.heads, head) } } } -// isNoOpTransactionDatabase returns whether the database name given is a non-Dolt database that shouldn't have -// transaction logic performed on it -func isNoOpTransactionDatabase(dbName string) bool { - return len(dbName) == 0 || dbName == "information_schema" || dbName == "mysql" -} - func (d *DoltSession) newWorkingSetForHead(ctx *sql.Context, wsRef ref.WorkingSetRef, dbName string) (*doltdb.WorkingSet, error) { dbData, _ := d.GetDbData(nil, dbName) @@ -388,18 +377,26 @@ func (d *DoltSession) newWorkingSetForHead(ctx *sql.Context, wsRef ref.WorkingSe return doltdb.EmptyWorkingSet(wsRef).WithWorkingRoot(headRoot).WithStagedRoot(headRoot), nil } -// CommitTransaction commits the in-progress transaction for the database named. Depending on session settings, this -// may write only a new working set, or may additionally create a new dolt commit for the current HEAD. -func (d *DoltSession) CommitTransaction(ctx *sql.Context, tx sql.Transaction) error { - dbName := ctx.GetTransactionDatabase() - if isNoOpTransactionDatabase(dbName) { - return nil - } +// CommitTransaction commits the in-progress transaction. Depending on session settings, this may write only a new +// working set, or may additionally create a new dolt commit for the current HEAD. If more than one branch head has +// changes, the transaction is rejected. +func (d *DoltSession) CommitTransaction(ctx *sql.Context, tx sql.Transaction) (err error) { + // Any non-error path must set the ctx's transaction to nil even if no work was done, because the engine only clears + // out transaction state in some cases. Changes to only branch heads (creating a new branch, reset, etc.) have no + // changes to commit visible to the transaction logic, but they still need a new transaction on the next statement. + // See comment in |commitBranchState| + defer func() { + if err == nil { + ctx.SetTransaction(nil) + } + }() if d.BatchMode() == Batched { - err := d.Flush(ctx, dbName) - if err != nil { - return err + for _, db := range d.provider.DoltDatabases() { + err = d.Flush(ctx, db.Name()) + if err != nil { + return err + } } } @@ -407,12 +404,15 @@ func (d *DoltSession) CommitTransaction(ctx *sql.Context, tx sql.Transaction) er return nil } - // This is triggered when certain commands are sent to the server (ex. commit) when a database is not selected. - // These commands should not error. - if dbName == "" { + dirties := d.dirtyWorkingSets() + if len(dirties) == 0 { return nil } + if len(dirties) > 1 { + return ErrDirtyWorkingSets + } + performDoltCommitVar, err := d.Session.GetSessionVariable(ctx, DoltCommitOnTransactionCommit) if err != nil { return err @@ -423,8 +423,16 @@ func (d *DoltSession) CommitTransaction(ctx *sql.Context, tx sql.Transaction) er return fmt.Errorf(fmt.Sprintf("Unexpected type for var %s: %T", DoltCommitOnTransactionCommit, performDoltCommitVar)) } + dirtyBranchState := dirties[0] if peformDoltCommitInt == 1 { - pendingCommit, err := d.PendingCommitAllStaged(ctx, dbName, actions.CommitStagedProps{ + // if the dirty working set doesn't belong to the currently checked out branch, that's an error + err = d.validateDoltCommit(ctx, dirtyBranchState) + if err != nil { + return err + } + + var pendingCommit *doltdb.PendingCommit + pendingCommit, err = d.PendingCommitAllStaged(ctx, dirtyBranchState, actions.CommitStagedProps{ Message: "Transaction commit", Date: ctx.QueryTime(), AllowEmpty: false, @@ -438,45 +446,83 @@ func (d *DoltSession) CommitTransaction(ctx *sql.Context, tx sql.Transaction) er // Nothing to stage, so fall back to CommitWorkingSet logic instead if pendingCommit == nil { - return d.CommitWorkingSet(ctx, dbName, tx) + return d.commitWorkingSet(ctx, dirtyBranchState, tx) } - _, err = d.DoltCommit(ctx, dbName, tx, pendingCommit) + _, err = d.DoltCommit(ctx, dirtyBranchState.dbState.dbName, tx, pendingCommit) return err } else { - return d.CommitWorkingSet(ctx, dbName, tx) + return d.commitWorkingSet(ctx, dirtyBranchState, tx) } } -// isDirty returns whether the working set for the database named is dirty -// TODO: remove the dbname parameter, return a global dirty bit -func (d *DoltSession) isDirty(ctx *sql.Context, dbName string) (bool, error) { - dbState, _, err := d.LookupDbState(ctx, dbName) - if err != nil { - return false, err +func (d *DoltSession) validateDoltCommit(ctx *sql.Context, dirtyBranchState *branchState) error { + currDb := ctx.GetCurrentDatabase() + if currDb == "" { + return fmt.Errorf("cannot dolt_commit with no database selected") + } + currDbBaseName, _ := SplitRevisionDbName(currDb) + dirtyDbBaseName, _ := SplitRevisionDbName(dirtyBranchState.dbState.dbName) + + if strings.ToLower(currDbBaseName) != strings.ToLower(dirtyDbBaseName) { + return fmt.Errorf("no changes to dolt_commit on database %s", currDbBaseName) } - return dbState.dirty, nil + d.mu.Lock() + dbState, ok := d.dbStates[strings.ToLower(currDbBaseName)] + d.mu.Unlock() + + if !ok { + return fmt.Errorf("no database state found for %s", currDbBaseName) + } + + dirtyBranch, err := dirtyBranchState.workingSet.Ref().ToHeadRef() + if err != nil { + return err + } + if dbState.currRevSpec != dirtyBranch.GetPath() { + return fmt.Errorf("no changes to dolt_commit on branch %s", dbState.currRevSpec) + } + + return nil +} + +var ErrDirtyWorkingSets = errors.New("Cannot commit changes on more than one branch / database") + +// dirtyWorkingSets returns all dirty working sets for this session +func (d *DoltSession) dirtyWorkingSets() []*branchState { + var dirtyStates []*branchState + for _, state := range d.dbStates { + for _, branchState := range state.heads { + if branchState.dirty { + dirtyStates = append(dirtyStates, branchState) + } + } + } + + return dirtyStates } // CommitWorkingSet commits the working set for the transaction given, without creating a new dolt commit. // Clients should typically use CommitTransaction, which performs additional checks, instead of this method. func (d *DoltSession) CommitWorkingSet(ctx *sql.Context, dbName string, tx sql.Transaction) error { - dirty, err := d.isDirty(ctx, dbName) - if err != nil { - return err - } - - if !dirty { - return nil - } - commitFunc := func(ctx *sql.Context, dtx *DoltTransaction, workingSet *doltdb.WorkingSet) (*doltdb.WorkingSet, *doltdb.Commit, error) { - ws, err := dtx.Commit(ctx, workingSet) + ws, err := dtx.Commit(ctx, workingSet, dbName) return ws, nil, err } - _, err = d.doCommit(ctx, dbName, tx, commitFunc) + _, err := d.commitCurrentHead(ctx, dbName, tx, commitFunc) + return err +} + +// commitWorkingSet commits the working set for the branch state given, without creating a new dolt commit. +func (d *DoltSession) commitWorkingSet(ctx *sql.Context, branchState *branchState, tx sql.Transaction) error { + commitFunc := func(ctx *sql.Context, dtx *DoltTransaction, workingSet *doltdb.WorkingSet) (*doltdb.WorkingSet, *doltdb.Commit, error) { + ws, err := dtx.Commit(ctx, workingSet, branchState.dbState.dbName) + return ws, nil, err + } + + _, err := d.commitBranchState(ctx, branchState, tx, commitFunc) return err } @@ -492,93 +538,100 @@ func (d *DoltSession) DoltCommit( ws, commit, err := dtx.DoltCommit( ctx, workingSet.WithWorkingRoot(commit.Roots.Working).WithStagedRoot(commit.Roots.Staged), - commit) + commit, + dbName) if err != nil { return nil, nil, err } - // Unlike normal COMMIT statements, CALL DOLT_COMMIT() doesn't get the current transaction cleared out by the query - // engine, so we do it here. - // TODO: the engine needs to manage this - ctx.SetTransaction(nil) - return ws, commit, err } - return d.doCommit(ctx, dbName, tx, commitFunc) + return d.commitCurrentHead(ctx, dbName, tx, commitFunc) } // doCommitFunc is a function to write to the database, which involves updating the working set and potentially // updating HEAD with a new commit type doCommitFunc func(ctx *sql.Context, dtx *DoltTransaction, workingSet *doltdb.WorkingSet) (*doltdb.WorkingSet, *doltdb.Commit, error) -// doCommit exercise the business logic for a particular doCommitFunc -func (d *DoltSession) doCommit(ctx *sql.Context, dbName string, tx sql.Transaction, commitFunc doCommitFunc) (*doltdb.Commit, error) { - dbState, ok, err := d.LookupDbState(ctx, dbName) - if err != nil { - return nil, err - } else if !ok { - // It's possible that we don't have dbstate if the user has created an in-Memory database. Moreover, - // the analyzer will check for us whether a db exists or not. - // TODO: fix this - return nil, nil - } - - // TODO: validate that the transaction belongs to the DB named +// commitBranchState performs a commit for the branch state given, using the doCommitFunc provided +func (d *DoltSession) commitBranchState( + ctx *sql.Context, + branchState *branchState, + tx sql.Transaction, + commitFunc doCommitFunc, +) (*doltdb.Commit, error) { dtx, ok := tx.(*DoltTransaction) if !ok { return nil, fmt.Errorf("expected a DoltTransaction") } - mergedWorkingSet, newCommit, err := commitFunc(ctx, dtx, dbState.WorkingSet) + _, newCommit, err := commitFunc(ctx, dtx, branchState.WorkingSet()) if err != nil { return nil, err } - err = d.SetWorkingSet(ctx, dbName, mergedWorkingSet) - if err != nil { - return nil, err - } - - dbState.dirty = false + // Anything that commits a transaction needs its current transaction state cleared so that the next statement starts + // a new transaction. This should in principle be done by the engine, but it currently only understands explicit + // COMMIT statements. Any other statements that commit a transaction, including stored procedures, needs to do this + // themselves. + ctx.SetTransaction(nil) return newCommit, nil } -// PendingCommitAllStaged returns a pending commit with all tables staged. Returns nil if there are no changes to stage. -func (d *DoltSession) PendingCommitAllStaged(ctx *sql.Context, dbName string, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) { - roots, ok := d.GetRoots(ctx, dbName) - if !ok { - return nil, fmt.Errorf("Couldn't get info for database %s", dbName) +// commitCurrentHead commits the current HEAD for the database given, using the doCommitFunc provided +func (d *DoltSession) commitCurrentHead(ctx *sql.Context, dbName string, tx sql.Transaction, commitFunc doCommitFunc) (*doltdb.Commit, error) { + branchState, ok, err := d.lookupDbState(ctx, dbName) + if err != nil { + return nil, err + } else if !ok { + return nil, sql.ErrDatabaseNotFound.New(dbName) } + return d.commitBranchState(ctx, branchState, tx, commitFunc) +} + +// PendingCommitAllStaged returns a pending commit with all tables staged. Returns nil if there are no changes to stage. +func (d *DoltSession) PendingCommitAllStaged(ctx *sql.Context, branchState *branchState, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) { + roots := branchState.roots() + var err error roots, err = actions.StageAllTables(ctx, roots, true) if err != nil { return nil, err } - return d.NewPendingCommit(ctx, dbName, roots, props) + return d.newPendingCommit(ctx, branchState, roots, props) } // NewPendingCommit returns a new |doltdb.PendingCommit| for the database named, using the roots given, adding any // merge parent from an in progress merge as appropriate. The session working set is not updated with these new roots, // but they are set in the returned |doltdb.PendingCommit|. If there are no changes staged, this method returns nil. func (d *DoltSession) NewPendingCommit(ctx *sql.Context, dbName string, roots doltdb.Roots, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) { - sessionState, _, err := d.LookupDbState(ctx, dbName) + branchState, ok, err := d.lookupDbState(ctx, dbName) if err != nil { return nil, err } + if !ok { + return nil, fmt.Errorf("session state for database %s not found", dbName) + } - headCommit := sessionState.headCommit + return d.newPendingCommit(ctx, branchState, roots, props) +} + +// newPendingCommit returns a new |doltdb.PendingCommit| for the database and head named by |branchState| +// See NewPendingCommit +func (d *DoltSession) newPendingCommit(ctx *sql.Context, branchState *branchState, roots doltdb.Roots, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) { + headCommit := branchState.headCommit headHash, _ := headCommit.HashOf() - if sessionState.WorkingSet == nil { + if branchState.WorkingSet() == nil { return nil, doltdb.ErrOperationNotSupportedInDetachedHead } var mergeParentCommits []*doltdb.Commit - if sessionState.WorkingSet.MergeActive() { - mergeParentCommits = []*doltdb.Commit{sessionState.WorkingSet.MergeState().Commit()} + if branchState.WorkingSet().MergeActive() { + mergeParentCommits = []*doltdb.Commit{branchState.WorkingSet().MergeState().Commit()} } else if props.Amend { numParentsHeadForAmend := headCommit.NumParents() for i := 0; i < numParentsHeadForAmend; i++ { @@ -591,12 +644,12 @@ func (d *DoltSession) NewPendingCommit(ctx *sql.Context, dbName string, roots do // TODO: This is not the correct way to write this commit as an amend. While this commit is running // the branch head moves backwards and concurrency control here is not principled. - newRoots, err := actions.ResetSoftToRef(ctx, sessionState.dbData, "HEAD~1") + newRoots, err := actions.ResetSoftToRef(ctx, branchState.dbData, "HEAD~1") if err != nil { return nil, err } - err = d.SetWorkingSet(ctx, dbName, sessionState.WorkingSet.WithStagedRoot(newRoots.Staged)) + err = d.SetWorkingSet(ctx, branchState.dbState.dbName, branchState.WorkingSet().WithStagedRoot(newRoots.Staged)) if err != nil { return nil, err } @@ -604,10 +657,10 @@ func (d *DoltSession) NewPendingCommit(ctx *sql.Context, dbName string, roots do roots.Head = newRoots.Head } - pendingCommit, err := actions.GetCommitStaged(ctx, roots, sessionState.WorkingSet, mergeParentCommits, sessionState.dbData.Ddb, props) + pendingCommit, err := actions.GetCommitStaged(ctx, roots, branchState.WorkingSet(), mergeParentCommits, branchState.dbData.Ddb, props) if err != nil { if props.Amend { - _, err = actions.ResetSoftToRef(ctx, sessionState.dbData, headHash.String()) + _, err = actions.ResetSoftToRef(ctx, branchState.dbData, headHash.String()) if err != nil { return nil, err } @@ -622,40 +675,8 @@ func (d *DoltSession) NewPendingCommit(ctx *sql.Context, dbName string, roots do // Rollback rolls the given transaction back func (d *DoltSession) Rollback(ctx *sql.Context, tx sql.Transaction) error { - dbName := ctx.GetTransactionDatabase() - - if TransactionsDisabled(ctx) || dbName == "" { - return nil - } - - dirty, err := d.isDirty(ctx, dbName) - if err != nil { - return err - } - - if !dirty { - return nil - } - - dbState, ok, err := d.LookupDbState(ctx, dbName) - if err != nil { - return err - } - - dtx, ok := tx.(*DoltTransaction) - if !ok { - return fmt.Errorf("expected a DoltTransaction") - } - - // This operation usually doesn't matter, because the engine will process a `rollback` statement by first calling - // this logic, then discarding any current transaction. So the next statement will get a fresh transaction regardless, - // and this is throwaway work. It only matters if this method is used outside a standalone `rollback` statement. - err = d.SetRoot(ctx, dbName, dtx.startState.WorkingRoot()) - if err != nil { - return err - } - - dbState.dirty = false + // Nothing to do here, we just throw away all our work and let a new transaction begin next statement + d.clear() return nil } @@ -673,12 +694,12 @@ func (d *DoltSession) CreateSavepoint(ctx *sql.Context, tx sql.Transaction, save return fmt.Errorf("expected a DoltTransaction") } - dbState, ok, err := d.LookupDbState(ctx, dbName) + branchState, ok, err := d.lookupDbState(ctx, dbName) if err != nil { return err } - dtx.CreateSavepoint(savepointName, dbState.GetRoots().Working) + dtx.CreateSavepoint(savepointName, branchState.roots().Working) return nil } @@ -733,7 +754,7 @@ func (d *DoltSession) ReleaseSavepoint(ctx *sql.Context, tx sql.Transaction, sav // GetDoltDB returns the *DoltDB for a given database by name func (d *DoltSession) GetDoltDB(ctx *sql.Context, dbName string) (*doltdb.DoltDB, bool) { - dbState, ok, err := d.LookupDbState(ctx, dbName) + branchState, ok, err := d.lookupDbState(ctx, dbName) if err != nil { return nil, false } @@ -741,11 +762,11 @@ func (d *DoltSession) GetDoltDB(ctx *sql.Context, dbName string) (*doltdb.DoltDB return nil, false } - return dbState.dbData.Ddb, true + return branchState.dbData.Ddb, true } func (d *DoltSession) GetDbData(ctx *sql.Context, dbName string) (env.DbData, bool) { - dbState, ok, err := d.LookupDbState(ctx, dbName) + branchState, ok, err := d.lookupDbState(ctx, dbName) if err != nil { return env.DbData{}, false } @@ -753,12 +774,12 @@ func (d *DoltSession) GetDbData(ctx *sql.Context, dbName string) (env.DbData, bo return env.DbData{}, false } - return dbState.dbData, true + return branchState.dbData, true } // GetRoots returns the current roots for a given database associated with the session func (d *DoltSession) GetRoots(ctx *sql.Context, dbName string) (doltdb.Roots, bool) { - dbState, ok, err := d.LookupDbState(ctx, dbName) + branchState, ok, err := d.lookupDbState(ctx, dbName) if err != nil { return doltdb.Roots{}, false } @@ -766,7 +787,7 @@ func (d *DoltSession) GetRoots(ctx *sql.Context, dbName string) (doltdb.Roots, b return doltdb.Roots{}, false } - return dbState.GetRoots(), true + return branchState.roots(), true } // ResolveRootForRef returns the root value for the ref given, which refers to either a commit spec or is one of the @@ -832,101 +853,74 @@ func (d *DoltSession) ResolveRootForRef(ctx *sql.Context, dbName, refStr string) // SetRoot sets a new root value for the session for the database named. This is the primary mechanism by which data // changes are communicated to the engine and persisted back to disk. All data changes should be followed by a call to // update the session's root value via this method. +// The dbName given should generally be a revision-qualified database name. // Data changes contained in the |newRoot| aren't persisted until this session is committed. // TODO: rename to SetWorkingRoot func (d *DoltSession) SetRoot(ctx *sql.Context, dbName string, newRoot *doltdb.RootValue) error { - // TODO: this is redundant with work done in setRoot - sessionState, _, err := d.LookupDbState(ctx, dbName) + branchState, _, err := d.lookupDbState(ctx, dbName) if err != nil { return err } - if sessionState.WorkingSet == nil { + if branchState.WorkingSet() == nil { return doltdb.ErrOperationNotSupportedInDetachedHead } - if rootsEqual(sessionState.GetRoots().Working, newRoot) { + if rootsEqual(branchState.roots().Working, newRoot) { return nil } - if sessionState.readOnly { - // TODO: Return an error here? - return nil + if branchState.readOnly { + return fmt.Errorf("cannot set root on read-only session") } - sessionState.WorkingSet = sessionState.WorkingSet.WithWorkingRoot(newRoot) + branchState.workingSet = branchState.WorkingSet().WithWorkingRoot(newRoot) - return d.SetWorkingSet(ctx, dbName, sessionState.WorkingSet) + return d.SetWorkingSet(ctx, dbName, branchState.WorkingSet()) } // SetRoots sets new roots for the session for the database named. Typically clients should only set the working root, // via setRoot. This method is for clients that need to update more of the session state, such as the dolt_ functions. // Unlike setting the working root, this method always marks the database state dirty. func (d *DoltSession) SetRoots(ctx *sql.Context, dbName string, roots doltdb.Roots) error { - // TODO: handle HEAD here? sessionState, _, err := d.LookupDbState(ctx, dbName) if err != nil { return err } - if sessionState.WorkingSet == nil { + if sessionState.WorkingSet() == nil { return doltdb.ErrOperationNotSupportedInDetachedHead } - workingSet := sessionState.WorkingSet.WithWorkingRoot(roots.Working).WithStagedRoot(roots.Staged) + workingSet := sessionState.WorkingSet().WithWorkingRoot(roots.Working).WithStagedRoot(roots.Staged) return d.SetWorkingSet(ctx, dbName, workingSet) } // SetWorkingSet sets the working set for this session. -// Unlike setting the working root alone, this method always marks the session dirty. func (d *DoltSession) SetWorkingSet(ctx *sql.Context, dbName string, ws *doltdb.WorkingSet) error { if ws == nil { panic("attempted to set a nil working set for the session") } - sessionState, _, err := d.LookupDbState(ctx, dbName) + branchState, _, err := d.lookupDbState(ctx, dbName) if err != nil { return err } - if ws.Ref() != sessionState.WorkingSet.Ref() { + if ws.Ref() != branchState.WorkingSet().Ref() { return fmt.Errorf("must switch working sets with SwitchWorkingSet") } - sessionState.WorkingSet = ws + branchState.workingSet = ws - cs, err := doltdb.NewCommitSpec(ws.Ref().GetPath()) + err = d.setDbSessionVars(ctx, branchState, true) if err != nil { return err } - branchRef, err := ws.Ref().ToHeadRef() + err = branchState.WriteSession().SetWorkingSet(ctx, ws) if err != nil { return err } - cm, err := sessionState.dbData.Ddb.Resolve(ctx, cs, branchRef) - if err != nil { - return err - } - sessionState.headCommit = cm - - headRoot, err := cm.GetRootValue(ctx) - if err != nil { - return err - } - - sessionState.headRoot = headRoot - - err = d.setSessionVarsForDb(ctx, dbName) - if err != nil { - return err - } - - err = sessionState.WriteSession.SetWorkingSet(ctx, ws) - if err != nil { - return err - } - - sessionState.dirty = true - + branchState.dirty = true return nil } @@ -940,122 +934,88 @@ func (d *DoltSession) SwitchWorkingSet( dbName string, wsRef ref.WorkingSetRef, ) error { - sessionState, _, err := d.LookupDbState(ctx, dbName) + headRef, err := wsRef.ToHeadRef() if err != nil { return err } - // TODO: should this be an error if any database in the transaction is dirty, or just this one? - if sessionState.dirty { - return ErrWorkingSetChanges.New() - } + d.mu.Lock() - // TODO: this should call session.StartTransaction once that has been cleaned up a bit - nomsRoots := make(map[string]hash.Hash) - for _, db := range d.provider.DoltDatabases() { - // TODO: this nil check is only necessary to support UserSpaceDatabase and clusterDatabase, come up with a better set of - // interfaces to capture these capabilities - ddb := db.DbData().Ddb - if ddb != nil { - nomsRoot, err := ddb.NomsRoot(ctx) - if err != nil { - return err - } - nomsRoots[strings.ToLower(db.Name())] = nomsRoot - } + baseName, _ := SplitRevisionDbName(dbName) + dbState, ok := d.dbStates[strings.ToLower(baseName)] + if !ok { + d.mu.Unlock() + return sql.ErrDatabaseNotFound.New(dbName) } + dbState.checkedOutRevSpec = headRef.GetPath() + dbState.currRevSpec = headRef.GetPath() + dbState.currRevType = RevisionTypeBranch - // TODO: resolve the working set ref with the root above - ws, err := sessionState.dbData.Ddb.ResolveWorkingSet(ctx, wsRef) + d.mu.Unlock() + + // bootstrap the db state as necessary + branchState, ok, err := d.lookupDbState(ctx, baseName+DbRevisionDelimiter+headRef.GetPath()) if err != nil { return err } - // TODO: just call SetWorkingSet? - sessionState.WorkingSet = ws + if !ok { + return sql.ErrDatabaseNotFound.New(dbName) + } - cs, err := doltdb.NewCommitSpec(ws.Ref().GetPath()) + ctx.SetCurrentDatabase(baseName) + + return d.setDbSessionVars(ctx, branchState, false) +} + +func (d *DoltSession) UseDatabase(ctx *sql.Context, db sql.Database) error { + sdb, ok := db.(SqlDatabase) + if !ok { + return fmt.Errorf("expected a SqlDatabase, got %T", db) + } + + branchState, ok, err := d.lookupDbState(ctx, sdb.RevisionQualifiedName()) if err != nil { return err } - - branchRef, err := ws.Ref().ToHeadRef() - if err != nil { - return err + if !ok { + return sql.ErrDatabaseNotFound.New(db.Name()) } - cm, err := sessionState.dbData.Ddb.Resolve(ctx, cs, branchRef) - if err != nil { - return err - } + d.mu.Lock() + defer d.mu.Unlock() - sessionState.headCommit = cm - sessionState.headRoot, err = cm.GetRootValue(ctx) - if err != nil { - return err + // Set the session state for this database according to what database name was USEd + // In the case of a revision qualified name, that will be the revision specified + // In the case of an unqualified name (USE mydb), this will be the last checked out head in this session. + _, rev := SplitRevisionDbName(sdb.RequestedName()) + dbState := branchState.dbState + if rev == "" { + dbState.currRevSpec = dbState.checkedOutRevSpec + dbState.currRevType = RevisionTypeBranch + } else { + dbState.currRevSpec = sdb.Revision() + dbState.currRevType = sdb.RevisionType() } - err = d.setSessionVarsForDb(ctx, dbName) - if err != nil { - return err - } - - h, err := ws.WorkingRoot().HashOf() - if err != nil { - return err - } - - err = d.Session.SetSessionVariable(ctx, WorkingKey(dbName), h.String()) - if err != nil { - return err - } - - // make a fresh WriteSession, discard existing WriteSession - opts := sessionState.WriteSession.GetOptions() - nbf := ws.WorkingRoot().VRW().Format() - tracker, err := sessionState.globalState.GetAutoIncrementTracker(ctx) - if err != nil { - return err - } - sessionState.WriteSession = writer.NewWriteSession(nbf, ws, tracker, opts) - - // After switching to a new working set, we are by definition clean - sessionState.dirty = false - - // the current transaction, if there is one, needs to be restarted - tCharacteristic := sql.ReadWrite - if t := ctx.GetTransaction(); t != nil { - if t.IsReadOnly() { - tCharacteristic = sql.ReadOnly - } - } - ctx.SetTransaction(NewDoltTransaction( - dbName, - nomsRoots, - ws, - wsRef, - sessionState.dbData, - sessionState.WriteSession.GetOptions(), - tCharacteristic, - )) - return nil } func (d *DoltSession) WorkingSet(ctx *sql.Context, dbName string) (*doltdb.WorkingSet, error) { + // TODO: need to make sure we use a revision qualified DB name here sessionState, _, err := d.LookupDbState(ctx, dbName) if err != nil { return nil, err } - if sessionState.WorkingSet == nil { + if sessionState.WorkingSet() == nil { return nil, doltdb.ErrOperationNotSupportedInDetachedHead } - return sessionState.WorkingSet, nil + return sessionState.WorkingSet(), nil } // GetHeadCommit returns the parent commit of the current session. func (d *DoltSession) GetHeadCommit(ctx *sql.Context, dbName string) (*doltdb.Commit, error) { - dbState, ok, err := d.LookupDbState(ctx, dbName) + branchState, ok, err := d.lookupDbState(ctx, dbName) if err != nil { return nil, err } @@ -1063,7 +1023,7 @@ func (d *DoltSession) GetHeadCommit(ctx *sql.Context, dbName string) (*doltdb.Co return nil, sql.ErrDatabaseNotFound.New(dbName) } - return dbState.headCommit, nil + return branchState.headCommit, nil } // SetSessionVariable is defined on sql.Session. We intercept it here to interpret the special semantics of the system @@ -1118,17 +1078,22 @@ func (d *DoltSession) setForeignKeyChecksSessionVar(ctx *sql.Context, key string if convertedVal != nil { intVal = convertedVal.(int64) } + if intVal == 0 { for _, dbState := range d.dbStates { - opts := dbState.WriteSession.GetOptions() - opts.ForeignKeyChecksDisabled = true - dbState.WriteSession.SetOptions(opts) + for _, branchState := range dbState.heads { + opts := branchState.WriteSession().GetOptions() + opts.ForeignKeyChecksDisabled = true + branchState.WriteSession().SetOptions(opts) + } } } else if intVal == 1 { for _, dbState := range d.dbStates { - opts := dbState.WriteSession.GetOptions() - opts.ForeignKeyChecksDisabled = false - dbState.WriteSession.SetOptions(opts) + for _, branchState := range dbState.heads { + opts := branchState.WriteSession().GetOptions() + opts.ForeignKeyChecksDisabled = false + branchState.WriteSession().SetOptions(opts) + } } } else { return fmt.Errorf("variable 'foreign_key_checks' can't be set to the value of '%d'", intVal) @@ -1137,68 +1102,100 @@ func (d *DoltSession) setForeignKeyChecksSessionVar(ctx *sql.Context, key string return d.Session.SetSessionVariable(ctx, key, value) } -// HasDB returns true if |sess| is tracking state for this database. -func (d *DoltSession) HasDB(_ *sql.Context, dbName string) bool { - d.mu.Lock() - defer d.mu.Unlock() - _, ok := d.dbStates[strings.ToLower(dbName)] - return ok -} - // addDB adds the database given to this session. This establishes a starting root value for this session, as well as // other state tracking metadata. func (d *DoltSession) addDB(ctx *sql.Context, db SqlDatabase) error { - DefineSystemVariablesForDB(db.Name()) + revisionQualifiedName := strings.ToLower(db.RevisionQualifiedName()) + baseName, rev := SplitRevisionDbName(revisionQualifiedName) + + DefineSystemVariablesForDB(baseName) + + tx, usingDoltTransaction := d.GetTransaction().(*DoltTransaction) - sessionState := NewEmptyDatabaseSessionState() d.mu.Lock() - d.dbStates[strings.ToLower(db.Name())] = sessionState - d.mu.Unlock() - sessionState.dbName = db.Name() - sessionState.db = db + defer d.mu.Unlock() + sessionState, sessionStateExists := d.dbStates[baseName] - _, val, ok := sql.SystemVariables.GetGlobal(DefaultBranchKey(db.Name())) - initialBranch := "" - if ok { - initialBranch = val.(string) + // Before computing initial state for the DB, check to see if we have it in the cache + var dbState InitialDbState + var dbStateCached bool + if usingDoltTransaction { + nomsRoot, ok := tx.GetInitialRoot(baseName) + if ok && sessionStateExists { + dbState, dbStateCached = d.dbCache.GetCachedInitialDbState(doltdb.DataCacheKey{Hash: nomsRoot}, revisionQualifiedName) + } } - // TODO: the branch should be already set if the DB was specified with a branch revision string - dbState, err := db.InitialDBState(ctx, initialBranch) - if err != nil { - return err + if !dbStateCached { + var err error + dbState, err = db.InitialDBState(ctx) + if err != nil { + return err + } } + if !sessionStateExists { + sessionState = newEmptyDatabaseSessionState() + d.dbStates[baseName] = sessionState + + var err error + sessionState.tmpFileDir, err = dbState.DbData.Rsw.TempTableFilesDir() + if err != nil { + if errors.Is(err, env.ErrDoltRepositoryNotFound) { + return env.ErrFailedToAccessDB.New(dbState.Db.Name()) + } + return err + } + + sessionState.dbName = baseName + + baseDb, ok := d.provider.BaseDatabase(ctx, baseName) + if !ok { + return fmt.Errorf("unable to find database %s, this is a bug", baseName) + } + + // The checkedOutRevSpec should be the checked out branch of the database if available, or the revision + // string otherwise + sessionState.checkedOutRevSpec, err = DefaultHead(baseName, baseDb) + if err != nil { + return err + } + + sessionState.currRevType = db.RevisionType() + sessionState.currRevSpec = db.Revision() + } + + if !dbStateCached && usingDoltTransaction { + nomsRoot, ok := tx.GetInitialRoot(baseName) + if ok { + d.dbCache.CacheInitialDbState(doltdb.DataCacheKey{Hash: nomsRoot}, revisionQualifiedName, dbState) + } + } + + branchState := sessionState.NewEmptyBranchState(rev) + // TODO: get rid of all repo state reader / writer stuff. Until we do, swap out the reader with one of our own, and // the writer with one that errors out // TODO: this no longer gets called at session creation time, so the error handling below never occurs when a // database is deleted out from under a running server - sessionState.dbData = dbState.DbData - tmpDir, err := dbState.DbData.Rsw.TempTableFilesDir() - if err != nil { - if errors.Is(err, env.ErrDoltRepositoryNotFound) { - return env.ErrFailedToAccessDB.New(dbState.Db.Name()) - } - return err - } - sessionState.tmpFileDir = tmpDir + branchState.dbData = dbState.DbData adapter := NewSessionStateAdapter(d, db.Name(), dbState.Remotes, dbState.Branches, dbState.Backups) - sessionState.dbData.Rsr = adapter - sessionState.dbData.Rsw = adapter - sessionState.readOnly, sessionState.readReplica = dbState.ReadOnly, dbState.ReadReplica + branchState.dbData.Rsr = adapter + branchState.dbData.Rsw = adapter + branchState.readOnly = dbState.ReadOnly // TODO: figure out how to cast this to dsqle.SqlDatabase without creating import cycles // Or better yet, get rid of EditOptions from the database, it's a session setting nbf := types.Format_Default - if sessionState.dbData.Ddb != nil { - nbf = sessionState.dbData.Ddb.Format() + if branchState.dbData.Ddb != nil { + nbf = branchState.dbData.Ddb.Format() } editOpts := db.(interface{ EditOptions() editor.Options }).EditOptions() if dbState.Err != nil { sessionState.Err = dbState.Err } else if dbState.WorkingSet != nil { - sessionState.WorkingSet = dbState.WorkingSet + branchState.workingSet = dbState.WorkingSet // TODO: this is pretty clunky, there is a silly dependency between InitialDbState and globalstate.StateProvider // that's hard to express with the current types @@ -1212,51 +1209,45 @@ func (d *DoltSession) addDB(ctx *sql.Context, db SqlDatabase) error { if err != nil { return err } - sessionState.WriteSession = writer.NewWriteSession(nbf, sessionState.WorkingSet, tracker, editOpts) - if err = d.SetWorkingSet(ctx, db.Name(), dbState.WorkingSet); err != nil { - return err - } - } else if dbState.HeadCommit != nil { - // WorkingSet is nil in the case of a read only, detached head DB + branchState.writeSession = writer.NewWriteSession(nbf, branchState.WorkingSet(), tracker, editOpts) + } + + // WorkingSet is nil in the case of a read only, detached head DB + if dbState.HeadCommit != nil { headRoot, err := dbState.HeadCommit.GetRootValue(ctx) if err != nil { return err } - sessionState.headRoot = headRoot + branchState.headRoot = headRoot } else if dbState.HeadRoot != nil { - sessionState.headRoot = dbState.HeadRoot + branchState.headRoot = dbState.HeadRoot } - // This has to happen after SetRoot above, since it does a stale check before its work - // TODO: this needs to be kept up to date as the working set ref changes - sessionState.headCommit = dbState.HeadCommit - - // After setting the initial root we have no state to commit - sessionState.dirty = false - - if sessionState.Err == nil { - return d.setSessionVarsForDb(ctx, db.Name()) - } + branchState.headCommit = dbState.HeadCommit return nil } +func (d *DoltSession) DatabaseCache(ctx *sql.Context) *DatabaseCache { + return d.dbCache +} + func (d *DoltSession) AddTemporaryTable(ctx *sql.Context, db string, tbl sql.Table) { - d.tempTables[db] = append(d.tempTables[db], tbl) + d.tempTables[strings.ToLower(db)] = append(d.tempTables[strings.ToLower(db)], tbl) } func (d *DoltSession) DropTemporaryTable(ctx *sql.Context, db, name string) { - tables := d.tempTables[db] - for i, tbl := range d.tempTables[db] { + tables := d.tempTables[strings.ToLower(db)] + for i, tbl := range d.tempTables[strings.ToLower(db)] { if strings.ToLower(tbl.Name()) == strings.ToLower(name) { tables = append(tables[:i], tables[i+1:]...) break } } - d.tempTables[db] = tables + d.tempTables[strings.ToLower(db)] = tables } func (d *DoltSession) GetTemporaryTable(ctx *sql.Context, db, name string) (sql.Table, bool) { - for _, tbl := range d.tempTables[db] { + for _, tbl := range d.tempTables[strings.ToLower(db)] { if strings.ToLower(tbl.Name()) == strings.ToLower(name) { return tbl, true } @@ -1266,21 +1257,47 @@ func (d *DoltSession) GetTemporaryTable(ctx *sql.Context, db, name string) (sql. // GetAllTemporaryTables returns all temp tables for this session. func (d *DoltSession) GetAllTemporaryTables(ctx *sql.Context, db string) ([]sql.Table, error) { - return d.tempTables[db], nil + return d.tempTables[strings.ToLower(db)], nil } // CWBHeadRef returns the branch ref for this session HEAD for the database named func (d *DoltSession) CWBHeadRef(ctx *sql.Context, dbName string) (ref.DoltRef, error) { - dbState, _, err := d.LookupDbState(ctx, dbName) + branchState, ok, err := d.lookupDbState(ctx, dbName) if err != nil { return nil, err } + if !ok { + return nil, sql.ErrDatabaseNotFound.New(dbName) + } - if dbState.WorkingSet == nil { + if branchState.dbState.currRevType != RevisionTypeBranch { return nil, doltdb.ErrOperationNotSupportedInDetachedHead } - return dbState.WorkingSet.Ref().ToHeadRef() + return ref.NewBranchRef(branchState.dbState.currRevSpec), nil +} + +// CurrentHead returns the current head for the db named, which must be unqualifed. Used for bootstrap resolving the +// correct session head when a database name from the client is unqualified. +// TODO: audit uses, see if basename can be removed +func (d *DoltSession) CurrentHead(ctx *sql.Context, dbName string) (string, bool, error) { + dbName = strings.ToLower(dbName) + + var baseName, rev string + baseName, rev = SplitRevisionDbName(dbName) + if rev != "" { + return "", false, fmt.Errorf("invalid database name: %s", dbName) + } + + d.mu.Lock() + dbState, ok := d.dbStates[baseName] + d.mu.Unlock() + + if ok { + return dbState.currRevSpec, true, nil + } + + return "", false, nil } func (d *DoltSession) Username() string { @@ -1295,35 +1312,38 @@ func (d *DoltSession) BatchMode() batchMode { return d.batchMode } -// setSessionVarsForDb updates the three session vars that track the value of the session root hashes -func (d *DoltSession) setSessionVarsForDb(ctx *sql.Context, dbName string) error { - state, _, err := d.lookupDbState(ctx, dbName) - if err != nil { - return err +// setDbSessionVars updates the three session vars that track the value of the session root hashes +func (d *DoltSession) setDbSessionVars(ctx *sql.Context, state *branchState, force bool) error { + // This check is important even when we are forcing an update, because it updates the idea of staleness + varsStale := d.dbSessionVarsStale(ctx, state) + if !varsStale && !force { + return nil } + baseName := state.dbState.dbName + // Different DBs have different requirements for what state is set, so we are maximally permissive on what's expected // in the state object here - if state.WorkingSet != nil { - headRef, err := state.WorkingSet.Ref().ToHeadRef() + if state.WorkingSet() != nil { + headRef, err := state.WorkingSet().Ref().ToHeadRef() if err != nil { return err } - err = d.Session.SetSessionVariable(ctx, HeadRefKey(dbName), headRef.String()) + err = d.Session.SetSessionVariable(ctx, HeadRefKey(baseName), headRef.String()) if err != nil { return err } } - roots := state.GetRoots() + roots := state.roots() if roots.Working != nil { h, err := roots.Working.HashOf() if err != nil { return err } - err = d.Session.SetSessionVariable(ctx, WorkingKey(dbName), h.String()) + err = d.Session.SetSessionVariable(ctx, WorkingKey(baseName), h.String()) if err != nil { return err } @@ -1334,7 +1354,7 @@ func (d *DoltSession) setSessionVarsForDb(ctx *sql.Context, dbName string) error if err != nil { return err } - err = d.Session.SetSessionVariable(ctx, StagedKey(dbName), h.String()) + err = d.Session.SetSessionVariable(ctx, StagedKey(baseName), h.String()) if err != nil { return err } @@ -1345,7 +1365,7 @@ func (d *DoltSession) setSessionVarsForDb(ctx *sql.Context, dbName string) error if err != nil { return err } - err = d.Session.SetSessionVariable(ctx, HeadKey(dbName), h.String()) + err = d.Session.SetSessionVariable(ctx, HeadKey(baseName), h.String()) if err != nil { return err } @@ -1354,6 +1374,17 @@ func (d *DoltSession) setSessionVarsForDb(ctx *sql.Context, dbName string) error return nil } +// dbSessionVarsStale returns whether the session vars for the database with the state provided need to be updated in +// the session +func (d *DoltSession) dbSessionVarsStale(ctx *sql.Context, state *branchState) bool { + dtx, ok := ctx.GetTransaction().(*DoltTransaction) + if !ok { + return true + } + + return d.dbCache.CacheSessionVars(state, dtx) +} + func (d DoltSession) WithGlobals(conf config.ReadWriteConfig) *DoltSession { d.globalsConf = conf return &d @@ -1433,6 +1464,7 @@ func (d *DoltSession) SystemVariablesInConfig() ([]sql.SystemVariable, error) { // GetBranch implements the interface branch_control.Context. func (d *DoltSession) GetBranch() (string, error) { + // TODO: creating a new SQL context here is expensive ctx := sql.NewContext(context.Background(), sql.WithSession(d)) currentDb := d.Session.GetCurrentDatabase() @@ -1441,13 +1473,13 @@ func (d *DoltSession) GetBranch() (string, error) { return "", nil } - dbState, _, err := d.LookupDbState(ctx, currentDb) + branchState, _, err := d.LookupDbState(ctx, currentDb) if err != nil { return "", err } - if dbState.WorkingSet != nil { - branchRef, err := dbState.WorkingSet.Ref().ToHeadRef() + if branchState.WorkingSet() != nil { + branchRef, err := branchState.WorkingSet().Ref().ToHeadRef() if err != nil { return "", err } @@ -1620,22 +1652,6 @@ func InitPersistedSystemVars(dEnv *env.DoltEnv) error { return nil } -// SplitRevisionDbName splits the given database name into its base and revision parts and returns them. Non-revision -// DBs use their full name as the base name, and empty string as the revision. -func SplitRevisionDbName(db SqlDatabase) (string, string) { - sqldb, ok := db.(SqlDatabase) - if !ok { - return db.Name(), "" - } - - dbName := db.Name() - if sqldb.Revision() != "" { - dbName = strings.TrimSuffix(dbName, DbRevisionDelimiter+sqldb.Revision()) - } - - return dbName, sqldb.Revision() -} - // TransactionRoot returns the noms root for the given database in the current transaction func TransactionRoot(ctx *sql.Context, db SqlDatabase) (hash.Hash, error) { tx, ok := ctx.GetTransaction().(*DoltTransaction) @@ -1644,8 +1660,7 @@ func TransactionRoot(ctx *sql.Context, db SqlDatabase) (hash.Hash, error) { return db.DbData().Ddb.NomsRoot(ctx) } - baseName, _ := SplitRevisionDbName(db) - nomsRoot, ok := tx.GetInitialRoot(baseName) + nomsRoot, ok := tx.GetInitialRoot(db.Name()) if !ok { return hash.Hash{}, fmt.Errorf("could not resolve initial root for database %s", db.Name()) } @@ -1653,6 +1668,38 @@ func TransactionRoot(ctx *sql.Context, db SqlDatabase) (hash.Hash, error) { return nomsRoot, nil } -const ( - DbRevisionDelimiter = "/" -) +// DefaultHead returns the head for the database given when one isn't specified +func DefaultHead(baseName string, db SqlDatabase) (string, error) { + head := "" + + // First check the global variable for the default branch + _, val, ok := sql.SystemVariables.GetGlobal(DefaultBranchKey(baseName)) + if ok { + head = val.(string) + branchRef, err := ref.Parse(head) + if err == nil { + head = branchRef.GetPath() + } else { + head = "" + // continue to below + } + } + + // Fall back to the database's initially checked out branch + if head == "" { + rsr := db.DbData().Rsr + if rsr != nil { + headRef, err := rsr.CWBHeadRef() + if err != nil { + return "", err + } + head = headRef.GetPath() + } + } + + if head == "" { + head = db.Revision() + } + + return head, nil +} diff --git a/go/libraries/doltcore/sqle/dsess/session_cache.go b/go/libraries/doltcore/sqle/dsess/session_cache.go index 37fc55b08b..7bd320ea56 100755 --- a/go/libraries/doltcore/sqle/dsess/session_cache.go +++ b/go/libraries/doltcore/sqle/dsess/session_cache.go @@ -24,7 +24,6 @@ import ( ) // SessionCache caches various pieces of expensive to compute information to speed up future lookups in the session. -// No methods are thread safe. type SessionCache struct { indexes map[doltdb.DataCacheKey]map[string][]sql.Index tables map[doltdb.DataCacheKey]map[string]sql.Table @@ -33,12 +32,43 @@ type SessionCache struct { mu sync.RWMutex } +// DatabaseCache stores databases and their initial states, offloading the compute / IO involved in resolving a +// database name to a particular database. This is safe only because the database objects themselves don't have any +// handles to data or state, but always defer to the session. Keys in the secondary map are revision specifier strings +type DatabaseCache struct { + // revisionDbs caches databases by name. The name is always lower case and revision qualified + revisionDbs map[revisionDbCacheKey]SqlDatabase + // initialDbStates caches the initial state of databases by name for a given noms root, which is the primary key. + // The secondary key is the lower-case revision-qualified database name. + initialDbStates map[doltdb.DataCacheKey]map[string]InitialDbState + // sessionVars records a key for the most recently used session vars for each database in the session + sessionVars map[string]sessionVarCacheKey + + mu sync.RWMutex +} + +type revisionDbCacheKey struct { + dbName string + requestedName string +} + +type sessionVarCacheKey struct { + root doltdb.DataCacheKey + head string +} + const maxCachedKeys = 64 func newSessionCache() *SessionCache { return &SessionCache{} } +func newDatabaseCache() *DatabaseCache { + return &DatabaseCache{ + sessionVars: make(map[string]sessionVarCacheKey), + } +} + // CacheTableIndexes caches all indexes for the table with the name given func (c *SessionCache) CacheTableIndexes(key doltdb.DataCacheKey, table string, indexes []sql.Index) { c.mu.Lock() @@ -193,3 +223,115 @@ func (c *SessionCache) GetCachedViewDefinition(key doltdb.DataCacheKey, viewName table, ok := viewsForKey[viewName] return table, ok } + +// GetCachedRevisionDb returns the cached revision database named, and whether the cache was present +func (c *DatabaseCache) GetCachedRevisionDb(revisionDbName string, requestedName string) (SqlDatabase, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.revisionDbs == nil { + return nil, false + } + + db, ok := c.revisionDbs[revisionDbCacheKey{ + dbName: revisionDbName, + requestedName: requestedName, + }] + return db, ok +} + +// CacheRevisionDb caches the revision database named +func (c *DatabaseCache) CacheRevisionDb(database SqlDatabase) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.revisionDbs == nil { + c.revisionDbs = make(map[revisionDbCacheKey]SqlDatabase) + } + + if len(c.revisionDbs) > maxCachedKeys { + for k := range c.revisionDbs { + delete(c.revisionDbs, k) + } + } + + c.revisionDbs[revisionDbCacheKey{ + dbName: strings.ToLower(database.RevisionQualifiedName()), + requestedName: database.RequestedName(), + }] = database +} + +// GetCachedInitialDbState returns the cached initial state for the revision database named, and whether the cache +// was present +func (c *DatabaseCache) GetCachedInitialDbState(key doltdb.DataCacheKey, revisionDbName string) (InitialDbState, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.initialDbStates == nil { + return InitialDbState{}, false + } + + dbsForKey, ok := c.initialDbStates[key] + if !ok { + return InitialDbState{}, false + } + + db, ok := dbsForKey[revisionDbName] + return db, ok +} + +// CacheInitialDbState caches the initials state for the revision database named +func (c *DatabaseCache) CacheInitialDbState(key doltdb.DataCacheKey, revisionDbName string, state InitialDbState) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.initialDbStates == nil { + c.initialDbStates = make(map[doltdb.DataCacheKey]map[string]InitialDbState) + } + + if len(c.initialDbStates) > maxCachedKeys { + for k := range c.initialDbStates { + delete(c.initialDbStates, k) + } + } + + dbsForKey, ok := c.initialDbStates[key] + if !ok { + dbsForKey = make(map[string]InitialDbState) + c.initialDbStates[key] = dbsForKey + } + + dbsForKey[revisionDbName] = state +} + +// CacheSessionVars updates the session var cache for the given branch state and transaction and returns whether it +// was updated. If it was updated, session vars need to be set for the state and transaction given. Otherwise they +// haven't changed and can be reused. +func (c *DatabaseCache) CacheSessionVars(branchState *branchState, transaction *DoltTransaction) bool { + c.mu.Lock() + defer c.mu.Unlock() + + dbBaseName := branchState.dbState.dbName + + existingKey, found := c.sessionVars[dbBaseName] + root, hasRoot := transaction.GetInitialRoot(dbBaseName) + if !hasRoot { + return true + } + + newKey := sessionVarCacheKey{ + root: doltdb.DataCacheKey{Hash: root}, + head: branchState.head, + } + + c.sessionVars[dbBaseName] = newKey + return !found || existingKey != newKey +} + +func (c *DatabaseCache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + c.sessionVars = make(map[string]sessionVarCacheKey) + c.revisionDbs = make(map[revisionDbCacheKey]SqlDatabase) + c.initialDbStates = make(map[doltdb.DataCacheKey]map[string]InitialDbState) +} diff --git a/go/libraries/doltcore/sqle/dsess/session_db_provider.go b/go/libraries/doltcore/sqle/dsess/session_db_provider.go index 7aa947dc59..a3d34ce4aa 100644 --- a/go/libraries/doltcore/sqle/dsess/session_db_provider.go +++ b/go/libraries/doltcore/sqle/dsess/session_db_provider.go @@ -38,9 +38,17 @@ type RevisionDatabase interface { Revision() string // RevisionType returns the type of revision this database is pinned to. RevisionType() RevisionType - // BaseName returns the name of the database without the revision specifier. E.g.if the database is named - // "myDB/master", BaseName returns "myDB". - BaseName() string + // RevisionQualifiedName returns the fully qualified name of the database, which includes the revision if one is + // specified. + RevisionQualifiedName() string + // RequestedName returns the name of the database as requested by the user when the name was resolved to this + // database. + RequestedName() string + // Versioned returns whether this database implementation supports more than a single revision. + // TODO: This shouldn't be a necessary part of the interface, but it's required to differentiate between dolt-backed + // databases and others that we serve for custom purposes with similar pieces of functionality, and the session + // management logic intermixes these concerns. + Versioned() bool } // RevisionType represents the type of revision a database is pinned to. For branches and tags, the revision is a @@ -83,6 +91,9 @@ type DoltDatabaseProvider interface { // SessionDatabase returns the SessionDatabase for the specified database, which may name a revision of a base // database. SessionDatabase(ctx *sql.Context, dbName string) (SqlDatabase, bool, error) + // BaseDatabase returns the base database for the specified database name. Meant for informational purposes when + // managing the session initialization only. Use SessionDatabase for normal database retrieval. + BaseDatabase(ctx *sql.Context, dbName string) (SqlDatabase, bool) // DoltDatabases returns all databases known to this provider. DoltDatabases() []SqlDatabase } diff --git a/go/libraries/doltcore/sqle/dsess/session_state_adapter.go b/go/libraries/doltcore/sqle/dsess/session_state_adapter.go index 700480d008..a9333dd434 100644 --- a/go/libraries/doltcore/sqle/dsess/session_state_adapter.go +++ b/go/libraries/doltcore/sqle/dsess/session_state_adapter.go @@ -53,12 +53,12 @@ func NewSessionStateAdapter(session *DoltSession, dbName string, remotes map[str func (s SessionStateAdapter) GetRoots(ctx context.Context) (doltdb.Roots, error) { sqlCtx := sql.NewContext(ctx) - state, _, err := s.session.LookupDbState(sqlCtx, s.dbName) + state, _, err := s.session.lookupDbState(sqlCtx, s.dbName) if err != nil { return doltdb.Roots{}, err } - return state.GetRoots(), nil + return state.roots(), nil } func (s SessionStateAdapter) CWBHeadRef() (ref.DoltRef, error) { @@ -227,10 +227,10 @@ func (s SessionStateAdapter) RemoveBackup(_ context.Context, name string) error } func (s SessionStateAdapter) TempTableFilesDir() (string, error) { - state, _, err := s.session.LookupDbState(sql.NewContext(context.Background()), s.dbName) + branchState, _, err := s.session.lookupDbState(sql.NewContext(context.Background()), s.dbName) if err != nil { return "", err } - return state.tmpFileDir, nil + return branchState.dbState.tmpFileDir, nil } diff --git a/go/libraries/doltcore/sqle/dsess/transactions.go b/go/libraries/doltcore/sqle/dsess/transactions.go index 6c0c2986bf..8a26994156 100644 --- a/go/libraries/doltcore/sqle/dsess/transactions.go +++ b/go/libraries/doltcore/sqle/dsess/transactions.go @@ -29,9 +29,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable" - "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/merge" - "github.com/dolthub/dolt/go/libraries/doltcore/ref" "github.com/dolthub/dolt/go/libraries/doltcore/table/editor" "github.com/dolthub/dolt/go/store/datas" "github.com/dolthub/dolt/go/store/hash" @@ -78,39 +76,48 @@ func (d DisabledTransaction) IsReadOnly() bool { } type DoltTransaction struct { - sourceDbName string - startRootHash map[string]hash.Hash - startState *doltdb.WorkingSet - workingSetRef ref.WorkingSetRef - dbData env.DbData + dbStartPoints map[string]dbRoot savepoints []savepoint - mergeEditOpts editor.Options tCharacteristic sql.TransactionCharacteristic } +type dbRoot struct { + dbName string + rootHash hash.Hash + db *doltdb.DoltDB +} + type savepoint struct { name string + // TODO: we need a root value per DB here root *doltdb.RootValue } func NewDoltTransaction( - dbName string, - startingRoots map[string]hash.Hash, - startState *doltdb.WorkingSet, - workingSet ref.WorkingSetRef, - dbData env.DbData, - mergeEditOpts editor.Options, + ctx *sql.Context, + dbs []SqlDatabase, tCharacteristic sql.TransactionCharacteristic, -) *DoltTransaction { - return &DoltTransaction{ - sourceDbName: dbName, - startRootHash: startingRoots, - startState: startState, - workingSetRef: workingSet, - dbData: dbData, - mergeEditOpts: mergeEditOpts, - tCharacteristic: tCharacteristic, +) (*DoltTransaction, error) { + + startPoints := make(map[string]dbRoot) + for _, db := range dbs { + nomsRoot, err := db.DbData().Ddb.NomsRoot(ctx) + if err != nil { + return nil, err + } + + baseName, _ := SplitRevisionDbName(db.Name()) + startPoints[strings.ToLower(baseName)] = dbRoot{ + dbName: baseName, + rootHash: nomsRoot, + db: db.DbData().Ddb, + } } + + return &DoltTransaction{ + dbStartPoints: startPoints, + tCharacteristic: tCharacteristic, + }, nil } func (tx DoltTransaction) String() string { @@ -122,9 +129,12 @@ func (tx DoltTransaction) IsReadOnly() bool { return tx.tCharacteristic == sql.ReadOnly } +// GetInitialRoot returns the noms root hash for the db named, established when the transaction began. The dbName here +// is always the base name of the database, not the revision qualified one. func (tx DoltTransaction) GetInitialRoot(dbName string) (hash.Hash, bool) { - h, ok := tx.startRootHash[strings.ToLower(dbName)] - return h, ok + dbName, _ = SplitRevisionDbName(dbName) + startPoint, ok := tx.dbStartPoints[strings.ToLower(dbName)] + return startPoint.rootHash, ok } var txLock sync.Mutex @@ -137,25 +147,31 @@ var txLock sync.Mutex // if workingSet.workingRoot == ancRoot, attempt a fast-forward merge // TODO: Non-working roots aren't merged into the working set and just stomp any changes made there. We need merge // strategies for staged as well as merge state. -func (tx *DoltTransaction) Commit(ctx *sql.Context, workingSet *doltdb.WorkingSet) (*doltdb.WorkingSet, error) { - ws, _, err := tx.doCommit(ctx, workingSet, nil, txCommit) +func (tx *DoltTransaction) Commit(ctx *sql.Context, workingSet *doltdb.WorkingSet, dbName string) (*doltdb.WorkingSet, error) { + ws, _, err := tx.doCommit(ctx, workingSet, nil, txCommit, dbName) return ws, err } // transactionWrite is the logic to write an updated working set (and optionally a commit) to the database type transactionWrite func(ctx *sql.Context, tx *DoltTransaction, // the transaction being written + doltDb *doltdb.DoltDB, // the database to write to + startState *doltdb.WorkingSet, // the starting working set commit *doltdb.PendingCommit, // optional workingSet *doltdb.WorkingSet, // must be provided hash hash.Hash, // hash of the current working set to be written + mergeOps editor.Options, // editor options for merges ) (*doltdb.WorkingSet, *doltdb.Commit, error) // doltCommit is a transactionWrite function that updates the working set and commits a pending commit atomically func doltCommit(ctx *sql.Context, - tx *DoltTransaction, - commit *doltdb.PendingCommit, - workingSet *doltdb.WorkingSet, - currHash hash.Hash, + tx *DoltTransaction, // the transaction being written + doltDb *doltdb.DoltDB, // the database to write to + startState *doltdb.WorkingSet, // the starting working set + commit *doltdb.PendingCommit, // optional + workingSet *doltdb.WorkingSet, // must be provided + currHash hash.Hash, // hash of the current working set to be written + mergeOpts editor.Options, // editor options for merges ) (*doltdb.WorkingSet, *doltdb.Commit, error) { pending := *commit @@ -165,7 +181,7 @@ func doltCommit(ctx *sql.Context, } headSpec, _ := doltdb.NewCommitSpec("HEAD") - curHead, err := tx.dbData.Ddb.Resolve(ctx, headSpec, headRef) + curHead, err := doltDb.Resolve(ctx, headSpec, headRef) if err != nil { return nil, nil, err } @@ -198,14 +214,15 @@ func doltCommit(ctx *sql.Context, // updates). The merged root value becomes our new Staged root value which // is the value which we are trying to commit. start := time.Now() + result, err := merge.MergeRoots( ctx, pending.Roots.Staged, curRootVal, pending.Roots.Head, curHead, - tx.startState, - tx.mergeEditOpts, + startState, + mergeOpts, merge.MergeOpts{}) if err != nil { return nil, nil, err @@ -222,27 +239,35 @@ func doltCommit(ctx *sql.Context, workingSet = workingSet.ClearMerge() var rsc doltdb.ReplicationStatusController - newCommit, err := tx.dbData.Ddb.CommitWithWorkingSet(ctx, headRef, tx.workingSetRef, &pending, workingSet, currHash, tx.getWorkingSetMeta(ctx), &rsc) + newCommit, err := doltDb.CommitWithWorkingSet(ctx, headRef, workingSet.Ref(), &pending, workingSet, currHash, tx.getWorkingSetMeta(ctx), &rsc) WaitForReplicationController(ctx, rsc) return workingSet, newCommit, err } // txCommit is a transactionWrite function that updates the working set func txCommit(ctx *sql.Context, - tx *DoltTransaction, - _ *doltdb.PendingCommit, - workingSet *doltdb.WorkingSet, - hash hash.Hash, + tx *DoltTransaction, // the transaction being written + doltDb *doltdb.DoltDB, // the database to write to + _ *doltdb.WorkingSet, // the starting working set + _ *doltdb.PendingCommit, // optional + workingSet *doltdb.WorkingSet, // must be provided + hash hash.Hash, // hash of the current working set to be written + _ editor.Options, // editor options for merges ) (*doltdb.WorkingSet, *doltdb.Commit, error) { var rsc doltdb.ReplicationStatusController - err := tx.dbData.Ddb.UpdateWorkingSet(ctx, tx.workingSetRef, workingSet, hash, tx.getWorkingSetMeta(ctx), &rsc) + err := doltDb.UpdateWorkingSet(ctx, workingSet.Ref(), workingSet, hash, tx.getWorkingSetMeta(ctx), &rsc) WaitForReplicationController(ctx, rsc) return workingSet, nil, err } // DoltCommit commits the working set and creates a new DoltCommit as specified, in one atomic write -func (tx *DoltTransaction) DoltCommit(ctx *sql.Context, workingSet *doltdb.WorkingSet, commit *doltdb.PendingCommit) (*doltdb.WorkingSet, *doltdb.Commit, error) { - return tx.doCommit(ctx, workingSet, commit, doltCommit) +func (tx *DoltTransaction) DoltCommit( + ctx *sql.Context, + workingSet *doltdb.WorkingSet, + commit *doltdb.PendingCommit, + dbName string, +) (*doltdb.WorkingSet, *doltdb.Commit, error) { + return tx.doCommit(ctx, workingSet, commit, doltCommit, dbName) } func WaitForReplicationController(ctx *sql.Context, rsc doltdb.ReplicationStatusController) { @@ -316,7 +341,32 @@ func (tx *DoltTransaction) doCommit( workingSet *doltdb.WorkingSet, commit *doltdb.PendingCommit, writeFn transactionWrite, + dbName string, ) (*doltdb.WorkingSet, *doltdb.Commit, error) { + sess := DSessFromSess(ctx.Session) + branchState, ok, err := sess.lookupDbState(ctx, dbName) + if err != nil { + return nil, nil, err + } + if !ok { + return nil, nil, fmt.Errorf("database %s unknown to transaction, this is a bug", dbName) + } + + // Load the start state for this working set from the noms root at tx start + // Get the base DB name from the db state, not the branch state + startPoint, ok := tx.dbStartPoints[strings.ToLower(branchState.dbState.dbName)] + if !ok { + return nil, nil, fmt.Errorf("database %s unknown to transaction, this is a bug", dbName) + } + + startState, err := startPoint.db.ResolveWorkingSetAtRoot(ctx, workingSet.Ref(), startPoint.rootHash) + if err != nil { + return nil, nil, err + } + + // TODO: no-op if the working set hasn't changed since the transaction started + + mergeOpts := branchState.EditOpts() for i := 0; i < maxTxCommitRetries; i++ { updatedWs, newCommit, err := func() (*doltdb.WorkingSet, *doltdb.Commit, error) { @@ -326,11 +376,11 @@ func (tx *DoltTransaction) doCommit( newWorkingSet := false - existingWs, err := tx.dbData.Ddb.ResolveWorkingSet(ctx, tx.workingSetRef) + existingWs, err := startPoint.db.ResolveWorkingSet(ctx, workingSet.Ref()) if err == doltdb.ErrWorkingSetNotFound { // This is to handle the case where an existing DB pre working sets is committing to this HEAD for the // first time. Can be removed and called an error post 1.0 - existingWs = doltdb.EmptyWorkingSet(tx.workingSetRef) + existingWs = doltdb.EmptyWorkingSet(workingSet.Ref()) newWorkingSet = true } else if err != nil { return nil, nil, err @@ -341,7 +391,7 @@ func (tx *DoltTransaction) doCommit( return nil, nil, err } - if newWorkingSet || workingAndStagedEqual(existingWs, tx.startState) { + if newWorkingSet || workingAndStagedEqual(existingWs, startState) { // ff merge err = tx.validateWorkingSetForCommit(ctx, workingSet, isFfMerge) if err != nil { @@ -349,7 +399,7 @@ func (tx *DoltTransaction) doCommit( } var newCommit *doltdb.Commit - workingSet, newCommit, err = writeFn(ctx, tx, commit, workingSet, existingWSHash) + workingSet, newCommit, err = writeFn(ctx, tx, startPoint.db, startState, commit, workingSet, existingWSHash, mergeOpts) if err == datas.ErrOptimisticLockFailed { // this is effectively a `continue` in the loop return nil, nil, nil @@ -362,7 +412,7 @@ func (tx *DoltTransaction) doCommit( // otherwise (not a ff), merge the working sets together start := time.Now() - mergedWorkingSet, err := tx.mergeRoots(ctx, existingWs, workingSet) + mergedWorkingSet, err := tx.mergeRoots(ctx, startState, existingWs, workingSet, mergeOpts) if err != nil { return nil, nil, err } @@ -374,7 +424,7 @@ func (tx *DoltTransaction) doCommit( } var newCommit *doltdb.Commit - mergedWorkingSet, newCommit, err = writeFn(ctx, tx, commit, mergedWorkingSet, existingWSHash) + mergedWorkingSet, newCommit, err = writeFn(ctx, tx, startPoint.db, startState, commit, mergedWorkingSet, existingWSHash, mergeOpts) if err == datas.ErrOptimisticLockFailed { // this is effectively a `continue` in the loop return nil, nil, nil @@ -401,8 +451,10 @@ func (tx *DoltTransaction) doCommit( // Currently merges working and staged roots as necessary. HEAD root is only handled by the DoltCommit function. func (tx *DoltTransaction) mergeRoots( ctx *sql.Context, + startState *doltdb.WorkingSet, existingWorkingSet *doltdb.WorkingSet, workingSet *doltdb.WorkingSet, + mergeOpts editor.Options, ) (*doltdb.WorkingSet, error) { if !rootsEqual(existingWorkingSet.WorkingRoot(), workingSet.WorkingRoot()) { @@ -410,10 +462,10 @@ func (tx *DoltTransaction) mergeRoots( ctx, existingWorkingSet.WorkingRoot(), workingSet.WorkingRoot(), - tx.startState.WorkingRoot(), + startState.WorkingRoot(), workingSet, - tx.startState, - tx.mergeEditOpts, + startState, + mergeOpts, merge.MergeOpts{}) if err != nil { return nil, err @@ -426,10 +478,10 @@ func (tx *DoltTransaction) mergeRoots( ctx, existingWorkingSet.StagedRoot(), workingSet.StagedRoot(), - tx.startState.StagedRoot(), + startState.StagedRoot(), workingSet, - tx.startState, - tx.mergeEditOpts, + startState, + mergeOpts, merge.MergeOpts{}) if err != nil { return nil, err diff --git a/go/libraries/doltcore/sqle/dsess/variables.go b/go/libraries/doltcore/sqle/dsess/variables.go index 07b18431e7..033aace245 100644 --- a/go/libraries/doltcore/sqle/dsess/variables.go +++ b/go/libraries/doltcore/sqle/dsess/variables.go @@ -61,6 +61,8 @@ const URLTemplateDatabasePlaceholder = "{database}" // DefineSystemVariablesForDB defines per database dolt-session variables in the engine as necessary func DefineSystemVariablesForDB(name string) { + name, _ = SplitRevisionDbName(name) + if _, _, ok := sql.SystemVariables.GetGlobal(name + HeadKeySuffix); !ok { sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{ { diff --git a/go/libraries/doltcore/sqle/dtables/ignore_table.go b/go/libraries/doltcore/sqle/dtables/ignore_table.go index d9ff8009c2..bc461f0fea 100644 --- a/go/libraries/doltcore/sqle/dtables/ignore_table.go +++ b/go/libraries/doltcore/sqle/dtables/ignore_table.go @@ -159,6 +159,7 @@ func (iw *ignoreWriter) StatementBegin(ctx *sql.Context) { dbName := ctx.GetCurrentDatabase() dSess := dsess.DSessFromSess(ctx.Session) + // TODO: this needs to use a revision qualified name roots, _ := dSess.GetRoots(ctx, dbName) dbState, ok, err := dSess.LookupDbState(ctx, dbName) if err != nil { @@ -227,7 +228,7 @@ func (iw *ignoreWriter) StatementBegin(ctx *sql.Context) { return } - if dbState.WorkingSet == nil { + if dbState.WorkingSet() == nil { iw.errDuringStatementBegin = doltdb.ErrOperationNotSupportedInDetachedHead return } @@ -235,7 +236,7 @@ func (iw *ignoreWriter) StatementBegin(ctx *sql.Context) { // We use WriteSession.SetWorkingSet instead of DoltSession.SetRoot because we want to avoid modifying the root // until the end of the transaction, but we still want the WriteSession to be able to find the newly // created table. - err = dbState.WriteSession.SetWorkingSet(ctx, dbState.WorkingSet.WithWorkingRoot(newRootValue)) + err = dbState.WriteSession().SetWorkingSet(ctx, dbState.WorkingSet().WithWorkingRoot(newRootValue)) if err != nil { iw.errDuringStatementBegin = err return @@ -244,7 +245,7 @@ func (iw *ignoreWriter) StatementBegin(ctx *sql.Context) { dSess.SetRoot(ctx, dbName, newRootValue) } - tableWriter, err := dbState.WriteSession.GetTableWriter(ctx, doltdb.IgnoreTableName, dbName, dSess.SetRoot, false) + tableWriter, err := dbState.WriteSession().GetTableWriter(ctx, doltdb.IgnoreTableName, dbName, dSess.SetRoot, false) if err != nil { iw.errDuringStatementBegin = err return diff --git a/go/libraries/doltcore/sqle/dtables/status_table.go b/go/libraries/doltcore/sqle/dtables/status_table.go index 50aa35d91b..fb6ccec869 100644 --- a/go/libraries/doltcore/sqle/dtables/status_table.go +++ b/go/libraries/doltcore/sqle/dtables/status_table.go @@ -32,7 +32,6 @@ type StatusTable struct { ddb *doltdb.DoltDB workingSet *doltdb.WorkingSet rootsProvider env.RootsProvider - dbName string } func (s StatusTable) Name() string { @@ -64,16 +63,15 @@ func (s StatusTable) PartitionRows(context *sql.Context, _ sql.Partition) (sql.R } // NewStatusTable creates a StatusTable -func NewStatusTable(_ *sql.Context, dbName string, ddb *doltdb.DoltDB, ws *doltdb.WorkingSet, rp env.RootsProvider) sql.Table { +func NewStatusTable(_ *sql.Context, ddb *doltdb.DoltDB, ws *doltdb.WorkingSet, rp env.RootsProvider) sql.Table { return &StatusTable{ ddb: ddb, - dbName: dbName, workingSet: ws, rootsProvider: rp, } } -// StatusIter is a sql.RowItr implementation which iterates over each commit as if it's a row in the table. +// StatusItr is a sql.RowIter implementation which iterates over each commit as if it's a row in the table. type StatusItr struct { rows []statusTableRow } diff --git a/go/libraries/doltcore/sqle/enginetest/branch_control_test.go b/go/libraries/doltcore/sqle/enginetest/branch_control_test.go index 37c65438fb..99b666361f 100644 --- a/go/libraries/doltcore/sqle/enginetest/branch_control_test.go +++ b/go/libraries/doltcore/sqle/enginetest/branch_control_test.go @@ -56,6 +56,7 @@ type BranchControlBlockTest struct { SetUpScript []string Query string ExpectedErr *errors.Kind + SkipMessage string } // TestUserSetUpScripts creates a user named "testuser@localhost", and grants them privileges on all databases and @@ -80,6 +81,11 @@ var BranchControlBlockTests = []BranchControlBlockTest{ Query: "INSERT INTO test VALUES (2, 2);", ExpectedErr: branch_control.ErrIncorrectPermissions, }, + { + Name: "INSERT on branch db", + Query: "INSERT INTO `mydb/other`.test VALUES (2, 2);", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, { Name: "REPLACE", Query: "REPLACE INTO test VALUES (2, 2);", @@ -90,11 +96,21 @@ var BranchControlBlockTests = []BranchControlBlockTest{ Query: "UPDATE test SET pk = 2;", ExpectedErr: branch_control.ErrIncorrectPermissions, }, + { + Name: "UPDATE on branch db", + Query: "UPDATE `mydb/other`.test SET pk = 2;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, { Name: "DELETE", Query: "DELETE FROM test WHERE pk >= 0;", ExpectedErr: branch_control.ErrIncorrectPermissions, }, + { + Name: "DELETE from branch table", + Query: "DELETE FROM `mydb/other`.test WHERE pk >= 0;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, { Name: "TRUNCATE", Query: "TRUNCATE TABLE test;", @@ -367,6 +383,221 @@ var BranchControlBlockTests = []BranchControlBlockTest{ }, } +var BranchControlOtherDbBlockTests = []BranchControlBlockTest{ + { + Name: "INSERT", + Query: "INSERT INTO `mydb/other`.test VALUES (2, 2);", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "REPLACE", + Query: "REPLACE INTO `mydb/other`.test VALUES (2, 2);", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "UPDATE", + Query: "UPDATE `mydb/other`.test SET pk = 2;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "DELETE", + Query: "DELETE FROM `mydb/other`.test WHERE pk >= 0;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "TRUNCATE", + Query: "TRUNCATE TABLE `mydb/other`.test;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE AUTO_INCREMENT", + SetUpScript: []string{ + "CREATE TABLE `mydb/other`.test2(pk BIGINT PRIMARY KEY AUTO_INCREMENT);", + }, + Query: "ALTER TABLE `mydb/other`.test2 AUTO_INCREMENT = 20;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE ADD CHECK", + Query: "ALTER TABLE `mydb/other`.test ADD CONSTRAINT check_1 CHECK (pk > 0);", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE DROP CHECK", + SetUpScript: []string{ + "ALTER TABLE `mydb/other`.test ADD CONSTRAINT check_1 CHECK (pk > 0);", + }, + Query: "ALTER TABLE `mydb/other`.test DROP CHECK check_1;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE ALTER COLUMN SET DEFAULT", + Query: "ALTER TABLE `mydb/other`.test ALTER COLUMN v1 SET DEFAULT (5);", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE ALTER COLUMN DROP DEFAULT", + SetUpScript: []string{ + "ALTER TABLE `mydb/other`.test ALTER COLUMN v1 SET DEFAULT (5);", + }, + Query: "ALTER TABLE `mydb/other`.test ALTER COLUMN v1 DROP DEFAULT;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE ADD FOREIGN KEY", + SetUpScript: []string{ + "ALTER TABLE `mydb/other`.test ADD INDEX idx_v1 (v1);", + "CREATE TABLE `mydb/other`.test2 (pk BIGINT PRIMARY KEY, v1 BIGINT UNIQUE);", + }, + Query: "ALTER TABLE `mydb/other`.test2 ADD CONSTRAINT fk_1 FOREIGN KEY (v1) REFERENCES `mydb/other`.test (v1);", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE DROP FOREIGN KEY", + SetUpScript: []string{ + "ALTER TABLE `mydb/other`.test ADD INDEX idx_v1 (v1);", + "CREATE TABLE `mydb/other`.test2 (pk BIGINT PRIMARY KEY, v1 BIGINT UNIQUE, CONSTRAINT fk_1 FOREIGN KEY (v1) REFERENCES `mydb/other`.test (v1));", + }, + Query: "ALTER TABLE `mydb/other`.test2 DROP FOREIGN KEY fk_1;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE ADD INDEX", + Query: "ALTER TABLE `mydb/other`.test ADD INDEX idx_v1 (v1);", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE DROP INDEX", + SetUpScript: []string{ + "ALTER TABLE `mydb/other`.test ADD INDEX idx_v1 (v1);", + }, + Query: "ALTER TABLE `mydb/other`.test DROP INDEX idx_v1;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE RENAME INDEX", + SetUpScript: []string{ + "ALTER TABLE `mydb/other`.test ADD INDEX idx_v1 (v1);", + }, + Query: "ALTER TABLE `mydb/other`.test RENAME INDEX idx_v1 TO idx_v1_new;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE ADD PRIMARY KEY", + SetUpScript: []string{ + "CREATE TABLE `mydb/other`.test2 (v1 BIGINT, v2 BIGINT);", + }, + Query: "ALTER TABLE `mydb/other`.test2 ADD PRIMARY KEY (v1, v2);", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE DROP PRIMARY KEY", + Query: "ALTER TABLE `mydb/other`.test DROP PRIMARY KEY;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE RENAME", + Query: "ALTER TABLE `mydb/other`.test RENAME TO test_new;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + SkipMessage: "https://github.com/dolthub/dolt/issues/6078", + }, + { + Name: "RENAME TABLE", + Query: "RENAME TABLE `mydb/other`.test TO test_new;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + SkipMessage: "https://github.com/dolthub/dolt/issues/6078", + }, + { + Name: "ALTER TABLE ADD COLUMN", + Query: "ALTER TABLE `mydb/other`.test ADD COLUMN v2 BIGINT;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE DROP COLUMN", + Query: "ALTER TABLE `mydb/other`.test DROP COLUMN v1;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE CHANGE COLUMN", + Query: "ALTER TABLE `mydb/other`.test CHANGE COLUMN v1 v1_new BIGINT;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE MODIFY COLUMN", + Query: "ALTER TABLE `mydb/other`.test MODIFY COLUMN v1 TINYINT;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "ALTER TABLE RENAME COLUMN", + Query: "ALTER TABLE `mydb/other`.test RENAME COLUMN v1 TO v1_new;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "CREATE INDEX", + Query: "CREATE INDEX idx_v1 ON `mydb/other`.test (v1);", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "DROP INDEX", + SetUpScript: []string{ + "CREATE INDEX idx_v1 ON `mydb/other`.test (v1);", + }, + Query: "DROP INDEX idx_v1 ON `mydb/other`.test;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "CREATE VIEW", + Query: "CREATE VIEW view_1 AS SELECT * FROM `mydb/other`.test;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + SkipMessage: "https://github.com/dolthub/dolt/issues/6078", + }, + { + Name: "DROP VIEW", + SetUpScript: []string{ + "CREATE VIEW view_1 AS SELECT * FROM `mydb/other`.test;", + }, + Query: "DROP VIEW view_1;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + SkipMessage: "https://github.com/dolthub/dolt/issues/6078", + }, + { + Name: "CREATE TRIGGER", + Query: "CREATE TRIGGER trigger_1 BEFORE INSERT ON `mydb/other`.test FOR EACH ROW SET NEW.v1 = 4;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + SkipMessage: "https://github.com/dolthub/dolt/issues/6078", + }, + { + Name: "DROP TRIGGER", + SetUpScript: []string{ + "CREATE TRIGGER trigger_1 BEFORE INSERT ON `mydb/other`.test FOR EACH ROW SET NEW.v1 = 4;", + }, + Query: "DROP TRIGGER `mydb/other`.trigger_1;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + SkipMessage: "https://github.com/dolthub/dolt/issues/6078", + }, + { + Name: "CREATE TABLE", + Query: "CREATE TABLE `mydb/other`.test2 (pk BIGINT PRIMARY KEY, v1 BIGINT);", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "CREATE TABLE LIKE", + Query: "CREATE TABLE `mydb/other`.test2 LIKE `mydb/other`.test;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + SkipMessage: "https://github.com/dolthub/dolt/issues/6078", + }, + { + Name: "CREATE TABLE AS SELECT", + Query: "CREATE TABLE `mydb/other`.test2 AS SELECT * FROM `mydb/other`.test;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, + { + Name: "DROP TABLE", + Query: "DROP TABLE `mydb/other`.test;", + ExpectedErr: branch_control.ErrIncorrectPermissions, + }, +} + var BranchControlTests = []BranchControlTest{ { Name: "Namespace entries block", @@ -1187,9 +1418,14 @@ func TestBranchControl(t *testing.T) { func TestBranchControlBlocks(t *testing.T) { for _, test := range BranchControlBlockTests { - harness := newDoltHarness(t) - defer harness.Close() t.Run(test.Name, func(t *testing.T) { + if test.SkipMessage != "" { + t.Skip(test.SkipMessage) + } + + harness := newDoltHarness(t) + defer harness.Close() + engine, err := harness.NewEngine(t) require.NoError(t, err) defer engine.Close() @@ -1211,9 +1447,59 @@ func TestBranchControlBlocks(t *testing.T) { Address: "localhost", }) enginetest.AssertErrWithCtx(t, engine, harness, userCtx, test.Query, test.ExpectedErr) + addUserQuery := "INSERT INTO dolt_branch_control VALUES ('%', 'main', 'testuser', 'localhost', 'write'), ('%', 'other', 'testuser', 'localhost', 'write');" addUserQueryResults := []sql.Row{{types.NewOkResult(2)}} enginetest.TestQueryWithContext(t, rootCtx, engine, harness, addUserQuery, addUserQueryResults, nil, nil) + + sch, iter, err := engine.Query(userCtx, test.Query) + if err == nil { + _, err = sql.RowIterToRows(userCtx, sch, iter) + } + assert.NoError(t, err) + }) + } + + // These tests are run with permission on main but not other + for _, test := range BranchControlOtherDbBlockTests { + t.Run("OtherDB_"+test.Name, func(t *testing.T) { + if test.SkipMessage != "" { + t.Skip(test.SkipMessage) + } + + harness := newDoltHarness(t) + defer harness.Close() + + engine, err := harness.NewEngine(t) + require.NoError(t, err) + defer engine.Close() + + rootCtx := enginetest.NewContext(harness) + rootCtx.NewCtxWithClient(sql.Client{ + User: "root", + Address: "localhost", + }) + engine.Analyzer.Catalog.MySQLDb.AddRootAccount() + engine.Analyzer.Catalog.MySQLDb.SetPersister(&mysql_db.NoopPersister{}) + + for _, statement := range append(TestUserSetUpScripts, test.SetUpScript...) { + enginetest.RunQueryWithContext(t, engine, harness, rootCtx, statement) + } + + addUserQuery := "INSERT INTO dolt_branch_control VALUES ('%', 'main', 'testuser', 'localhost', 'write');" + addUserQueryResults := []sql.Row{{types.NewOkResult(1)}} + enginetest.TestQueryWithContext(t, rootCtx, engine, harness, addUserQuery, addUserQueryResults, nil, nil) + + userCtx := enginetest.NewContextWithClient(harness, sql.Client{ + User: "testuser", + Address: "localhost", + }) + enginetest.AssertErrWithCtx(t, engine, harness, userCtx, test.Query, test.ExpectedErr) + + addUserQuery = "INSERT INTO dolt_branch_control VALUES ('%', 'other', 'testuser', 'localhost', 'write');" + addUserQueryResults = []sql.Row{{types.NewOkResult(1)}} + enginetest.TestQueryWithContext(t, rootCtx, engine, harness, addUserQuery, addUserQueryResults, nil, nil) + sch, iter, err := engine.Query(userCtx, test.Query) if err == nil { _, err = sql.RowIterToRows(userCtx, sch, iter) diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_branch_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_branch_queries.go new file mode 100755 index 0000000000..119b2b86ce --- /dev/null +++ b/go/libraries/doltcore/sqle/enginetest/dolt_branch_queries.go @@ -0,0 +1,632 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package enginetest + +import ( + "github.com/dolthub/go-mysql-server/enginetest/queries" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" +) + +var ForeignKeyBranchTests = []queries.ScriptTest{ + { + Name: "create fk on branch", + SetUpScript: []string{ + "call dolt_branch('b1')", + "use mydb/b1", + "ALTER TABLE child ADD CONSTRAINT fk_named FOREIGN KEY (v1) REFERENCES parent(v1);", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "use mydb/b1", + SkipResultsCheck: true, + }, + { + Query: "SHOW CREATE TABLE child;", + Expected: []sql.Row{{"child", "CREATE TABLE `child` (\n" + + " `id` int NOT NULL,\n" + + " `v1` int,\n" + + " `v2` int,\n" + + " PRIMARY KEY (`id`),\n" + + " KEY `v1` (`v1`),\n" + + " CONSTRAINT `fk_named` FOREIGN KEY (`v1`) REFERENCES `parent` (`v1`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "insert into child values (1, 1, 1)", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "use mydb/main", + SkipResultsCheck: true, + }, + { + Query: "SHOW CREATE TABLE child;", + Expected: []sql.Row{{"child", "CREATE TABLE `child` (\n" + + " `id` int NOT NULL,\n" + + " `v1` int,\n" + + " `v2` int,\n" + + " PRIMARY KEY (`id`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "insert into child values (1, 1, 1)", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "insert into `mydb/b1`.child values (1, 1, 1)", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + }, + }, + { + Name: "create fk with branch checkout", + SetUpScript: []string{ + "call dolt_branch('b1')", + "call dolt_checkout('b1')", + "ALTER TABLE child ADD CONSTRAINT fk_named FOREIGN KEY (v1) REFERENCES parent(v1);", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "call dolt_checkout('b1')", + SkipResultsCheck: true, + }, + { + Query: "SHOW CREATE TABLE child;", + Expected: []sql.Row{{"child", "CREATE TABLE `child` (\n" + + " `id` int NOT NULL,\n" + + " `v1` int,\n" + + " `v2` int,\n" + + " PRIMARY KEY (`id`),\n" + + " KEY `v1` (`v1`),\n" + + " CONSTRAINT `fk_named` FOREIGN KEY (`v1`) REFERENCES `parent` (`v1`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "insert into child values (1, 1, 1)", + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "call dolt_checkout('main')", + SkipResultsCheck: true, + }, + { + Query: "SHOW CREATE TABLE child;", + Expected: []sql.Row{{"child", "CREATE TABLE `child` (\n" + + " `id` int NOT NULL,\n" + + " `v1` int,\n" + + " `v2` int,\n" + + " PRIMARY KEY (`id`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "insert into child values (1, 1, 1)", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + }, + }, + { + Name: "create fk on branch not being used", + SetUpScript: []string{ + "call dolt_branch('b1')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "ALTER TABLE `mydb/b1`.child ADD CONSTRAINT fk_named FOREIGN KEY (v1) REFERENCES parent(v1);", + Skip: true, // Incorrectly flagged as a cross-DB foreign key relation + }, + { + Query: "SHOW CREATE TABLE `mydb/b1`.child;", + Skip: true, + Expected: []sql.Row{{"child", "CREATE TABLE `child` (\n" + + " `id` int NOT NULL,\n" + + " `v1` int,\n" + + " `v2` int,\n" + + " PRIMARY KEY (`id`),\n" + + " KEY `v1` (`v1`),\n" + + " CONSTRAINT `fk_named` FOREIGN KEY (`v1`) REFERENCES `parent` (`v1`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "insert into `mydb/b1`.child values (1, 1, 1)", + Skip: true, + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "SHOW CREATE TABLE child;", + Skip: true, + Expected: []sql.Row{{"child", "CREATE TABLE `child` (\n" + + " `id` int NOT NULL,\n" + + " `v1` int,\n" + + " `v2` int,\n" + + " PRIMARY KEY (`id`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "insert into child values (1, 1, 1)", + Skip: true, + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + }, + }, +} + +var ViewBranchTests = []queries.ScriptTest{ + { + Name: "create view on branch", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "insert into t1 values (1, 1), (2, 2), (3, 3)", + "call dolt_commit('-Am', 'first commit')", + "call dolt_branch('b1')", + "use mydb/b1", + "create view v1 as select * from t1 where a > 2", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "use mydb/b1", + SkipResultsCheck: true, + }, + { + Query: "select * from v1", + Expected: []sql.Row{{3, 3}}, + }, + { + Query: "use mydb/main", + SkipResultsCheck: true, + }, + { + Query: "select * from v1", + ExpectedErr: sql.ErrTableNotFound, + }, + { + Query: "select * from `mydb/b1`.v1", + Expected: []sql.Row{{3, 3}}, + }, + }, + }, + { + Name: "create view on different branch", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "insert into t1 values (1, 1), (2, 2), (3, 3)", + "call dolt_commit('-Am', 'first commit')", + "call dolt_branch('b1')", + "create view `mydb/b1`.v1 as select * from t1 where a > 2", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "use mydb/b1", + SkipResultsCheck: true, + }, + { + Query: "select * from v1", + Expected: []sql.Row{{3, 3}}, + Skip: true, // https://github.com/dolthub/dolt/issues/6078 + }, + { + Query: "use mydb/main", + SkipResultsCheck: true, + }, + { + Query: "select * from v1", + ExpectedErr: sql.ErrTableNotFound, + Skip: true, // https://github.com/dolthub/dolt/issues/6078 + }, + { + Query: "select * from `mydb/b1`.v1", + Expected: []sql.Row{{3, 3}}, + Skip: true, // https://github.com/dolthub/dolt/issues/6078 + }, + }, + }, +} + +var DdlBranchTests = []queries.ScriptTest{ + { + Name: "create table on branch", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "insert into t1 values (1, 1), (2, 2), (3, 3)", + "call dolt_commit('-Am', 'first commit')", + "call dolt_branch('b1')", + "use mydb/b1", + "create table t2 (a int primary key, b int)", + "insert into t2 values (4, 4)", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "use mydb/b1", + SkipResultsCheck: true, + }, + { + Query: "select * from t2", + Expected: []sql.Row{{4, 4}}, + }, + { + Query: "use mydb/main", + SkipResultsCheck: true, + }, + { + Query: "select * from t2", + ExpectedErr: sql.ErrTableNotFound, + }, + { + Query: "select * from `mydb/b1`.t2", + Expected: []sql.Row{{4, 4}}, + }, + }, + }, + { + Name: "create table on different branch", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "insert into t1 values (1, 1), (2, 2), (3, 3)", + "call dolt_commit('-Am', 'first commit')", + "call dolt_branch('b1')", + "create table `mydb/b1`.t2 (a int primary key, b int)", + "insert into `mydb/b1`.t2 values (4,4)", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "use mydb/b1", + SkipResultsCheck: true, + }, + { + Query: "select * from t2", + Expected: []sql.Row{{4, 4}}, + }, + { + Query: "use mydb/main", + SkipResultsCheck: true, + }, + { + Query: "select * from t2", + ExpectedErr: sql.ErrTableNotFound, + }, + { + Query: "select * from `mydb/b1`.t2", + Expected: []sql.Row{{4, 4}}, + }, + }, + }, + { + Name: "create table on different branch, autocommit off", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "insert into t1 values (1, 1), (2, 2), (3, 3)", + "call dolt_commit('-Am', 'first commit')", + "call dolt_branch('b1')", + "set autocommit = off", + "create table `mydb/b1`.t2 (a int primary key, b int)", + "insert into `mydb/b1`.t2 values (4,4)", + "commit", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "use mydb/b1", + SkipResultsCheck: true, + }, + { + Query: "select * from t2", + Expected: []sql.Row{{4, 4}}, + }, + { + Query: "use mydb/main", + SkipResultsCheck: true, + }, + { + Query: "select * from t2", + ExpectedErr: sql.ErrTableNotFound, + }, + { + Query: "select * from `mydb/b1`.t2", + Expected: []sql.Row{{4, 4}}, + }, + }, + }, + { + Name: "alter table on different branch, add column", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "insert into t1 values (1, 1), (2, 2), (3, 3)", + "call dolt_commit('-Am', 'first commit')", + "call dolt_branch('b1')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "alter table `mydb/b1`.t1 add column c int", + Expected: []sql.Row{{types.OkResult{RowsAffected: 0}}}, + }, + { + Query: "select * from `mydb/b1`.t1", + Expected: []sql.Row{{1, 1, nil}, {2, 2, nil}, {3, 3, nil}}, + }, + { + Query: "select * from t1", + Expected: []sql.Row{{1, 1}, {2, 2}, {3, 3}}, + }, + }, + }, + { + Name: "alter table on different branch, drop column", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "insert into t1 values (1, 1), (2, 2), (3, 3)", + "call dolt_commit('-Am', 'first commit')", + "call dolt_branch('b1')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "alter table `mydb/b1`.t1 drop column b", + Expected: []sql.Row{{types.OkResult{RowsAffected: 0}}}, + }, + { + Query: "select * from `mydb/b1`.t1", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + Query: "select * from t1", + Expected: []sql.Row{{1, 1}, {2, 2}, {3, 3}}, + }, + }, + }, + { + Name: "alter table on different branch, modify column", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "insert into t1 values (1, 1), (2, 2), (3, 3)", + "call dolt_commit('-Am', 'first commit')", + "call dolt_branch('b1')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "alter table `mydb/b1`.t1 modify column b varchar(1) first", + Expected: []sql.Row{{types.OkResult{RowsAffected: 0}}}, + }, + { + Query: "select * from `mydb/b1`.t1", + Expected: []sql.Row{{"1", 1}, {"2", 2}, {"3", 3}}, + }, + { + Query: "select * from t1", + Expected: []sql.Row{{1, 1}, {2, 2}, {3, 3}}, + }, + }, + }, + { + Name: "alter table on different branch, create and drop index", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "insert into t1 values (1, 1), (2, 2), (3, 3)", + "call dolt_commit('-Am', 'first commit')", + "call dolt_branch('b1')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "create index idx on `mydb/b1`.t1 (b)", + Expected: []sql.Row{{types.OkResult{RowsAffected: 0}}}, + }, + { + Query: "show create table `mydb/b1`.t1", + Expected: []sql.Row{{"t1", "CREATE TABLE `t1` (\n" + + " `a` int NOT NULL,\n" + + " `b` int,\n" + + " PRIMARY KEY (`a`),\n" + + " KEY `idx` (`b`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "show create table t1", + Expected: []sql.Row{{"t1", "CREATE TABLE `t1` (\n" + + " `a` int NOT NULL,\n" + + " `b` int,\n" + + " PRIMARY KEY (`a`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "alter table `mydb/b1`.t1 drop index idx", + Expected: []sql.Row{{types.OkResult{RowsAffected: 0}}}, + }, + { + Query: "show create table `mydb/b1`.t1", + Expected: []sql.Row{{"t1", "CREATE TABLE `t1` (\n" + + " `a` int NOT NULL,\n" + + " `b` int,\n" + + " PRIMARY KEY (`a`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + }, + }, + { + Name: "alter table on different branch, add and drop constraint", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "insert into t1 values (1, 1), (2, 2), (3, 3)", + "call dolt_commit('-Am', 'first commit')", + "call dolt_branch('b1')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "alter table `mydb/b1`.t1 add constraint chk1 check (b < 4)", + Expected: []sql.Row{}, + }, + { + Query: "show create table `mydb/b1`.t1", + Expected: []sql.Row{{"t1", "CREATE TABLE `t1` (\n" + + " `a` int NOT NULL,\n" + + " `b` int,\n" + + " PRIMARY KEY (`a`),\n" + + " CONSTRAINT `chk1` CHECK ((`b` < 4))\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "insert into `mydb/b1`.t1 values (4, 4)", + ExpectedErr: sql.ErrCheckConstraintViolated, + }, + { + Query: "show create table t1", + Expected: []sql.Row{{"t1", "CREATE TABLE `t1` (\n" + + " `a` int NOT NULL,\n" + + " `b` int,\n" + + " PRIMARY KEY (`a`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + Query: "insert into t1 values (4, 4)", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "alter table `mydb/b1`.t1 drop constraint chk1", + Expected: []sql.Row{}, + }, + { + Query: "show create table `mydb/b1`.t1", + Expected: []sql.Row{{"t1", "CREATE TABLE `t1` (\n" + + " `a` int NOT NULL,\n" + + " `b` int,\n" + + " PRIMARY KEY (`a`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + }, + }, +} + +var BranchPlanTests = []queries.ScriptTest{ + { + Name: "use index on branch database", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "insert into t1 values (1, 1), (2, 2), (3, 3)", + "call dolt_commit('-Am', 'first commit')", + "call dolt_branch('b1')", + "use mydb/b1", + "create index idx on t1 (b)", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "explain select * from t1 where b = 1", + Expected: []sql.Row{ + {"IndexedTableAccess(t1)"}, + {" ├─ index: [t1.b]"}, + {" ├─ filters: [{[1, 1]}]"}, + {" └─ columns: [a b]"}, + }, + }, + { + Query: "use mydb/main", + SkipResultsCheck: true, + }, + { + Query: "explain select * from `mydb/b1`.t1 where b = 1", + Expected: []sql.Row{ + {"IndexedTableAccess(t1)"}, + {" ├─ index: [t1.b]"}, + {" ├─ filters: [{[1, 1]}]"}, + {" └─ columns: [a b]"}, + }, + }, + }, + }, + { + Name: "use index on branch database join", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "insert into t1 values (1, 1), (2, 2), (3, 3)", + "call dolt_commit('-Am', 'first commit')", + "call dolt_branch('b1')", + "use mydb/b1", + "create index idx on t1 (b)", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "explain select * from t1 t1a join t1 t1b on t1a.b = t1b.b order by 1", + Expected: []sql.Row{ + {"Sort(t1a.a ASC)"}, + {" └─ Project"}, + {" ├─ columns: [t1a.a, t1a.b, t1b.a, t1b.b]"}, + {" └─ MergeJoin"}, + {" ├─ cmp: (t1b.b = t1a.b)"}, + {" ├─ TableAlias(t1b)"}, + {" │ └─ IndexedTableAccess(t1)"}, + {" │ ├─ index: [t1.b]"}, + {" │ ├─ filters: [{[NULL, ∞)}]"}, + {" │ └─ columns: [a b]"}, + {" └─ TableAlias(t1a)"}, + {" └─ IndexedTableAccess(t1)"}, + {" ├─ index: [t1.b]"}, + {" ├─ filters: [{[NULL, ∞)}]"}, + {" └─ columns: [a b]"}, + }, + }, + { + Query: "explain select * from `mydb/main`.t1 t1a join `mydb/main`.t1 t1b on t1a.b = t1b.b order by 1", + Expected: []sql.Row{ + {"Sort(t1a.a ASC)"}, + {" └─ InnerJoin"}, + {" ├─ (t1a.b = t1b.b)"}, + {" ├─ TableAlias(t1a)"}, + {" │ └─ Table"}, + {" │ ├─ name: t1"}, + {" │ └─ columns: [a b]"}, + {" └─ TableAlias(t1b)"}, + {" └─ Table"}, + {" ├─ name: t1"}, + {" └─ columns: [a b]"}, + }, + }, + { + Query: "use mydb/main", + SkipResultsCheck: true, + }, + { + Query: "explain select * from t1 t1a join t1 t1b on t1a.b = t1b.b order by 1", + Expected: []sql.Row{ + {"Sort(t1a.a ASC)"}, + {" └─ InnerJoin"}, + {" ├─ (t1a.b = t1b.b)"}, + {" ├─ TableAlias(t1a)"}, + {" │ └─ Table"}, + {" │ ├─ name: t1"}, + {" │ └─ columns: [a b]"}, + {" └─ TableAlias(t1b)"}, + {" └─ Table"}, + {" ├─ name: t1"}, + {" └─ columns: [a b]"}, + }, + }, + { + Query: "explain select * from `mydb/b1`.t1 t1a join `mydb/b1`.t1 t1b on t1a.b = t1b.b order by 1", + Expected: []sql.Row{ + {"Sort(t1a.a ASC)"}, + {" └─ Project"}, + {" ├─ columns: [t1a.a, t1a.b, t1b.a, t1b.b]"}, + {" └─ MergeJoin"}, + {" ├─ cmp: (t1b.b = t1a.b)"}, + {" ├─ TableAlias(t1b)"}, + {" │ └─ IndexedTableAccess(t1)"}, + {" │ ├─ index: [t1.b]"}, + {" │ ├─ filters: [{[NULL, ∞)}]"}, + {" │ └─ columns: [a b]"}, + {" └─ TableAlias(t1a)"}, + {" └─ IndexedTableAccess(t1)"}, + {" ├─ index: [t1.b]"}, + {" ├─ filters: [{[NULL, ∞)}]"}, + {" └─ columns: [a b]"}, + }, + }, + }, + }, +} diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go index 75d41206bb..6014d6a12d 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go @@ -49,7 +49,7 @@ var skipPrepared bool // SkipPreparedsCount is used by the "ci-check-repo CI workflow // as a reminder to consider prepareds when adding a new // enginetest suite. -const SkipPreparedsCount = 84 +const SkipPreparedsCount = 83 const skipPreparedFlag = "DOLT_SKIP_PREPARED_ENGINETESTS" @@ -116,45 +116,51 @@ func TestSingleQuery(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { t.Skip() + var scripts = []queries.ScriptTest{ { - Name: "parallel column updates (repro issue #4547)", + Name: "ALTER TABLE RENAME COLUMN", SetUpScript: []string{ - "SET dolt_allow_commit_conflicts = on;", - "create table t (rowId int not null, col1 varchar(255), col2 varchar(255), keyCol varchar(60), dataA varchar(255), dataB varchar(255), PRIMARY KEY (rowId), UNIQUE KEY uniqKey (col1, col2, keyCol));", - "insert into t (rowId, col1, col2, keyCol, dataA, dataB) values (1, '1', '2', 'key-a', 'test1', 'test2')", - "CALL DOLT_COMMIT('-Am', 'new table');", - - "CALL DOLT_CHECKOUT('-b', 'other');", - "update t set dataA = 'other'", - "CALL DOLT_COMMIT('-am', 'update data other');", - - "CALL DOLT_CHECKOUT('main');", - "update t set dataB = 'main'", - "CALL DOLT_COMMIT('-am', 'update on main');", + "ALTER TABLE child ADD CONSTRAINT fk1 FOREIGN KEY (v1) REFERENCES parent(v1);", + "ALTER TABLE parent RENAME COLUMN v1 TO v1_new;", + "ALTER TABLE child RENAME COLUMN v1 TO v1_new;", }, Assertions: []queries.ScriptTestAssertion{ { - Query: "CALL DOLT_MERGE('other')", - Expected: []sql.Row{{"child", uint64(1)}}, + Query: "SHOW CREATE TABLE child;", + Expected: []sql.Row{{"child", "CREATE TABLE `child` (\n `id` int NOT NULL,\n `v1_new` int,\n `v2` int,\n PRIMARY KEY (`id`),\n KEY `v1` (`v1_new`),\n CONSTRAINT `fk1` FOREIGN KEY (`v1_new`) REFERENCES `parent` (`v1_new`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + }, + }, + { + Name: "ALTER TABLE MODIFY COLUMN type change not allowed", + SetUpScript: []string{ + "ALTER TABLE child ADD CONSTRAINT fk1 FOREIGN KEY (v1) REFERENCES parent(v1);", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "ALTER TABLE parent MODIFY v1 MEDIUMINT;", + ExpectedErr: sql.ErrForeignKeyTypeChange, }, { - Query: "SELECT * from dolt_constraint_violations_t", - Expected: []sql.Row{}, - }, - { - Query: "SELECT * from t", - Expected: []sql.Row{ - {1, "1", "2", "key-a", "other", "main"}, - }, + Query: "ALTER TABLE child MODIFY v1 MEDIUMINT;", + ExpectedErr: sql.ErrForeignKeyTypeChange, }, }, }, } + tcc := &testCommitClock{} + cleanup := installTestCommitClock(tcc) + defer cleanup() + harness := newDoltHarness(t) - for _, test := range scripts { - enginetest.TestScript(t, harness, test) + harness.Setup(setup.MydbData, setup.Parent_childData) + for _, script := range scripts { + sql.RunWithNowFunc(tcc.Now, func() error { + enginetest.TestScript(t, harness, script) + return nil + }) } } @@ -250,24 +256,52 @@ func TestSingleQueryPrepared(t *testing.T) { func TestSingleScriptPrepared(t *testing.T) { t.Skip() + var script = queries.ScriptTest{ - Name: "table with commit column should maintain its data in diff", + Name: "dolt_history table filter correctness", SetUpScript: []string{ - "CREATE TABLE t (pk int PRIMARY KEY, commit varchar(20));", - "CALL DOLT_ADD('.');", - "CALL dolt_commit('-am', 'creating table t');", - "INSERT INTO t VALUES (1, '123456');", - "CALL dolt_commit('-am', 'insert data');", + "create table xy (x int primary key, y int);", + "call dolt_add('.');", + "call dolt_commit('-m', 'creating table');", + "insert into xy values (0, 1);", + "call dolt_commit('-am', 'add data');", + "insert into xy values (2, 3);", + "call dolt_commit('-am', 'add data');", + "insert into xy values (4, 5);", + "call dolt_commit('-am', 'add data');", }, Assertions: []queries.ScriptTestAssertion{ { - Query: "SELECT to_pk, char_length(to_commit), from_pk, char_length(from_commit), diff_type from dolt_diff_t;", - Expected: []sql.Row{{1, 32, nil, 32, "added"}}, + Query: "select * from dolt_history_xy where commit_hash = (select dolt_log.commit_hash from dolt_log limit 1 offset 1) order by 1", + Expected: []sql.Row{ + sql.Row{0, 1, "itt2nrlkbl7jis4gt9aov2l32ctt08th", "billy bob", time.Date(1970, time.January, 1, 19, 0, 0, 0, time.Local)}, + sql.Row{2, 3, "itt2nrlkbl7jis4gt9aov2l32ctt08th", "billy bob", time.Date(1970, time.January, 1, 19, 0, 0, 0, time.Local)}, + }, + }, + { + Query: "select count(*) from dolt_history_xy where commit_hash = (select dolt_log.commit_hash from dolt_log limit 1 offset 1)", + Expected: []sql.Row{ + {2}, + }, + }, + { + Query: "select count(*) from dolt_history_xy where commit_hash = 'itt2nrlkbl7jis4gt9aov2l32ctt08th'", + Expected: []sql.Row{ + {2}, + }, }, }, } - harness := newDoltHarness(t) - enginetest.TestScriptPrepared(t, harness, script) + + tcc := &testCommitClock{} + cleanup := installTestCommitClock(tcc) + defer cleanup() + + sql.RunWithNowFunc(tcc.Now, func() error { + harness := newDoltHarness(t) + enginetest.TestScriptPrepared(t, harness, script) + return nil + }) } func TestVersionedQueries(t *testing.T) { @@ -340,6 +374,16 @@ func TestDoltDiffQueryPlans(t *testing.T) { } } +func TestBranchPlans(t *testing.T) { + for _, script := range BranchPlanTests { + func() { + harness := newDoltHarness(t).WithParallelism(1) + defer harness.Close() + enginetest.TestScript(t, harness, script) + }() + } +} + func TestQueryErrors(t *testing.T) { h := newDoltHarness(t) defer h.Close() @@ -350,6 +394,14 @@ func TestInfoSchema(t *testing.T) { h := newDoltHarness(t) defer h.Close() enginetest.TestInfoSchema(t, h) + + for _, script := range DoltInfoSchemaScripts { + func() { + harness := newDoltHarness(t) + defer harness.Close() + enginetest.TestScript(t, harness, script) + }() + } } func TestColumnAliases(t *testing.T) { @@ -708,6 +760,26 @@ func TestCreateTable(t *testing.T) { enginetest.TestCreateTable(t, h) } +func TestBranchDdl(t *testing.T) { + for _, script := range DdlBranchTests { + func() { + h := newDoltHarness(t) + defer h.Close() + enginetest.TestScript(t, h, script) + }() + } +} + +func TestBranchDdlPrepared(t *testing.T) { + for _, script := range DdlBranchTests { + func() { + h := newDoltHarness(t) + defer h.Close() + enginetest.TestScriptPrepared(t, h, script) + }() + } +} + func TestPkOrdinalsDDL(t *testing.T) { h := newDoltHarness(t) defer h.Close() @@ -855,6 +927,64 @@ func TestForeignKeys(t *testing.T) { enginetest.TestForeignKeys(t, h) } +func TestForeignKeyBranches(t *testing.T) { + setupPrefix := []string{ + "call dolt_branch('b1')", + "use mydb/b1", + } + assertionsPrefix := []queries.ScriptTestAssertion{ + { + Query: "use mydb/b1", + SkipResultsCheck: true, + }, + } + for _, script := range queries.ForeignKeyTests { + // New harness for every script because we create branches + h := newDoltHarness(t) + h.Setup(setup.MydbData, setup.Parent_childData) + modifiedScript := script + modifiedScript.SetUpScript = append(setupPrefix, modifiedScript.SetUpScript...) + modifiedScript.Assertions = append(assertionsPrefix, modifiedScript.Assertions...) + enginetest.TestScript(t, h, modifiedScript) + } + + for _, script := range ForeignKeyBranchTests { + // New harness for every script because we create branches + h := newDoltHarness(t) + h.Setup(setup.MydbData, setup.Parent_childData) + enginetest.TestScript(t, h, script) + } +} + +func TestForeignKeyBranchesPrepared(t *testing.T) { + setupPrefix := []string{ + "call dolt_branch('b1')", + "use mydb/b1", + } + assertionsPrefix := []queries.ScriptTestAssertion{ + { + Query: "use mydb/b1", + SkipResultsCheck: true, + }, + } + for _, script := range queries.ForeignKeyTests { + // New harness for every script because we create branches + h := newDoltHarness(t) + h.Setup(setup.MydbData, setup.Parent_childData) + modifiedScript := script + modifiedScript.SetUpScript = append(setupPrefix, modifiedScript.SetUpScript...) + modifiedScript.Assertions = append(assertionsPrefix, modifiedScript.Assertions...) + enginetest.TestScriptPrepared(t, h, modifiedScript) + } + + for _, script := range ForeignKeyBranchTests { + // New harness for every script because we create branches + h := newDoltHarness(t) + h.Setup(setup.MydbData, setup.Parent_childData) + enginetest.TestScriptPrepared(t, h, script) + } +} + func TestCreateCheckConstraints(t *testing.T) { h := newDoltHarness(t) defer h.Close() @@ -897,6 +1027,26 @@ func TestViews(t *testing.T) { enginetest.TestViews(t, h) } +func TestBranchViews(t *testing.T) { + for _, script := range ViewBranchTests { + func() { + h := newDoltHarness(t) + defer h.Close() + enginetest.TestScript(t, h, script) + }() + } +} + +func TestBranchViewsPrepared(t *testing.T) { + for _, script := range ViewBranchTests { + func() { + h := newDoltHarness(t) + defer h.Close() + enginetest.TestScriptPrepared(t, h, script) + }() + } +} + func TestVersionedViews(t *testing.T) { h := newDoltHarness(t) defer h.Close() @@ -1110,7 +1260,6 @@ func TestBranchTransactions(t *testing.T) { } func TestMultiDbTransactions(t *testing.T) { - t.Skip() for _, script := range MultiDbTransactionTests { func() { h := newDoltHarness(t) @@ -1120,6 +1269,16 @@ func TestMultiDbTransactions(t *testing.T) { } } +func TestMultiDbTransactionsPrepared(t *testing.T) { + for _, script := range MultiDbTransactionTests { + func() { + h := newDoltHarness(t) + defer h.Close() + enginetest.TestScriptPrepared(t, h, script) + }() + } +} + func TestConcurrentTransactions(t *testing.T) { h := newDoltHarness(t) defer h.Close() @@ -1194,7 +1353,7 @@ func TestDoltRevisionDbScripts(t *testing.T) { }, { Query: "show databases;", - Expected: []sql.Row{{"mydb"}, {"information_schema"}, {"mydb/" + commithash}, {"mysql"}}, + Expected: []sql.Row{{"mydb"}, {"information_schema"}, {"mysql"}}, }, { Query: "select * from t01", @@ -1320,7 +1479,8 @@ func TestViewsWithAsOfPrepared(t *testing.T) { func TestDoltMerge(t *testing.T) { for _, script := range MergeScripts { - // dolt versioning conflicts with reset harness -- use new harness every time + // harness can't reset effectively when there are new commits / branches created, so use a new harness for + // each script func() { h := newDoltHarness(t).WithParallelism(1) defer h.Close() @@ -1339,6 +1499,28 @@ func TestDoltMerge(t *testing.T) { } } +func TestDoltMergePrepared(t *testing.T) { + for _, script := range MergeScripts { + // harness can't reset effectively when there are new commits / branches created, so use a new harness for + // each script + func() { + h := newDoltHarness(t).WithParallelism(1) + defer h.Close() + enginetest.TestScriptPrepared(t, h, script) + }() + } + + if types.IsFormat_DOLT(types.Format_Default) { + for _, script := range Dolt1MergeScripts { + func() { + h := newDoltHarness(t).WithParallelism(1) + defer h.Close() + enginetest.TestScriptPrepared(t, h, script) + }() + } + } +} + func TestDoltAutoIncrement(t *testing.T) { for _, script := range DoltAutoIncrementTests { // doing commits on different branches is antagonistic to engine reuse, use a new engine on each script @@ -1469,6 +1651,26 @@ func TestDoltGC(t *testing.T) { } } +func TestDoltCheckout(t *testing.T) { + for _, script := range DoltCheckoutScripts { + func() { + h := newDoltHarness(t) + defer h.Close() + enginetest.TestScript(t, h, script) + }() + } +} + +func TestDoltCheckoutPrepared(t *testing.T) { + for _, script := range DoltCheckoutScripts { + func() { + h := newDoltHarness(t) + defer h.Close() + enginetest.TestScriptPrepared(t, h, script) + }() + } +} + func TestDoltBranch(t *testing.T) { for _, script := range DoltBranchScripts { func() { @@ -1532,112 +1734,65 @@ func TestSingleTransactionScript(t *testing.T) { sql.RunWithNowFunc(tcc.Now, func() error { script := queries.TransactionTest{ - Name: "committed conflicts are seen by other sessions", + Name: "READ ONLY Transactions", SetUpScript: []string{ - "CREATE TABLE test (pk int primary key, val int)", - "CALL DOLT_ADD('.')", - "INSERT INTO test VALUES (0, 0)", - "CALL DOLT_COMMIT('-a', '-m', 'Step 1');", - "CALL DOLT_CHECKOUT('-b', 'feature-branch')", - "INSERT INTO test VALUES (1, 1);", - "UPDATE test SET val=1000 WHERE pk=0;", - "CALL DOLT_COMMIT('-a', '-m', 'this is a normal commit');", - "CALL DOLT_CHECKOUT('main');", - "UPDATE test SET val=1001 WHERE pk=0;", - "CALL DOLT_COMMIT('-a', '-m', 'update a value');", + "create table t2 (pk int primary key, val int)", + "insert into t2 values (0,0)", + "commit", }, Assertions: []queries.ScriptTestAssertion{ { - Query: "/* client a */ start transaction", + Query: "/* client a */ set autocommit = off", + Expected: []sql.Row{{}}, + }, + { + Query: "/* client a */ create temporary table tmp(pk int primary key)", + Expected: []sql.Row{{gmstypes.NewOkResult(0)}}, + }, + { + Query: "/* client a */ START TRANSACTION READ ONLY", Expected: []sql.Row{}, }, { - Query: "/* client b */ start transaction", - Expected: []sql.Row{}, + Query: "/* client a */ INSERT INTO tmp VALUES (1)", + Expected: []sql.Row{{gmstypes.NewOkResult(1)}}, + }, + { + Query: "/* client a */ insert into t2 values (1, 1)", + ExpectedErr: sql.ErrReadOnlyTransaction, + }, + { + Query: "/* client a */ insert into t2 values (2, 2)", + ExpectedErr: sql.ErrReadOnlyTransaction, + }, + { + Query: "/* client a */ delete from t2 where pk = 0", + ExpectedErr: sql.ErrReadOnlyTransaction, }, { - Query: "/* client a */ select * from dolt_log order by date", - Expected: - // existing transaction logic - []sql.Row{ - {"j131v1r3cf6mrdjjjuqgkv4t33oa0l54", "billy bob", "bigbillieb@fake.horse", time.Date(1969, time.December, 31, 21, 0, 0, 0, time.Local), "Initialize data repository"}, - {"kcg4345ir3tjfb13mr0on1bv1m56h9if", "billy bob", "bigbillieb@fake.horse", time.Date(1970, time.January, 1, 4, 0, 0, 0, time.Local), "checkpoint enginetest database mydb"}, - {"9jtjpggd4t5nso3mefilbde3tkfosdna", "billy bob", "bigbillieb@fake.horse", time.Date(1970, time.January, 1, 12, 0, 0, 0, time.Local), "Step 1"}, - {"559f6kdh0mm5i1o40hs3t8dr43bkerav", "billy bob", "bigbillieb@fake.horse", time.Date(1970, time.January, 2, 3, 0, 0, 0, time.Local), "update a value"}, - }, - // new tx logic - // []sql.Row{ - // sql.Row{"j131v1r3cf6mrdjjjuqgkv4t33oa0l54", "billy bob", "bigbillieb@fake.horse", time.Date(1969, time.December, 31, 21, 0, 0, 0, time.Local), "Initialize data repository"}, - // sql.Row{"kcg4345ir3tjfb13mr0on1bv1m56h9if", "billy bob", "bigbillieb@fake.horse", time.Date(1970, time.January, 1, 4, 0, 0, 0, time.Local), "checkpoint enginetest database mydb"}, - // sql.Row{"pifio95ccefa03qstm1g3s1sivj1sm1d", "billy bob", "bigbillieb@fake.horse", time.Date(1970, time.January, 1, 11, 0, 0, 0, time.Local), "Step 1"}, - // sql.Row{"rdrgqfcml1hfgj8clr0caabgu014v2g9", "billy bob", "bigbillieb@fake.horse", time.Date(1970, time.January, 1, 20, 0, 0, 0, time.Local), "this is a normal commit"}, - // sql.Row{"shhv61eiefo9c4m9lvo5bt23i3om1ft4", "billy bob", "bigbillieb@fake.horse", time.Date(1970, time.January, 2, 2, 0, 0, 0, time.Local), "update a value"}, - // }, + Query: "/* client a */ alter table t2 add val2 int", + Expected: []sql.Row{{gmstypes.NewOkResult(0)}}, }, { - Query: "/* client a */ CALL DOLT_MERGE('feature-branch')", - Expected: []sql.Row{{0, 1}}, + Query: "/* client a */ select * from t2", + Expected: []sql.Row{{0, 0, nil}}, }, { - Query: "/* client a */ SELECT count(*) from dolt_conflicts_test", - Expected: []sql.Row{{1}}, + Query: "/* client a */ create temporary table tmp2(pk int primary key)", + ExpectedErr: sql.ErrReadOnlyTransaction, }, { - Query: "/* client b */ SELECT count(*) from dolt_conflicts_test", - Expected: []sql.Row{{0}}, - }, - { - Query: "/* client a */ set dolt_allow_commit_conflicts = 1", - Expected: []sql.Row{{}}, - }, - { - Query: "/* client a */ commit", + Query: "/* client a */ COMMIT", Expected: []sql.Row{}, }, { - Query: "/* client b */ start transaction", + Query: "/* client b */ START TRANSACTION READ ONLY", Expected: []sql.Row{}, }, { - Query: "/* client b */ SELECT count(*) from dolt_conflicts_test", - Expected: []sql.Row{{1}}, - }, - { - Query: "/* client a */ start transaction", - Expected: []sql.Row{}, - }, - { - Query: "/* client a */ CALL DOLT_MERGE('--abort')", - Expected: []sql.Row{{0, 0}}, - }, - { - Query: "/* client a */ commit", - Expected: []sql.Row{}, - }, - { - Query: "/* client b */ start transaction", - Expected: []sql.Row{}, - }, - { - Query: "/* client a */ SET @@dolt_allow_commit_conflicts = 0", - Expected: []sql.Row{{}}, - }, - { - Query: "/* client a */ CALL DOLT_MERGE('feature-branch')", - ExpectedErrStr: dsess.ErrUnresolvedConflictsCommit.Error(), - }, - { // client rolled back on merge with conflicts - Query: "/* client a */ SELECT count(*) from dolt_conflicts_test", - Expected: []sql.Row{{0}}, - }, - { - Query: "/* client a */ commit", - Expected: []sql.Row{}, - }, - { - Query: "/* client b */ SELECT count(*) from dolt_conflicts_test", - Expected: []sql.Row{{0}}, + Query: "/* client b */ SELECT * FROM t2", + Expected: []sql.Row{{0, 0, nil}}, }, }, } diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_harness.go b/go/libraries/doltcore/sqle/enginetest/dolt_harness.go index b429c70ce2..1499852ba1 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_harness.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_harness.go @@ -163,7 +163,8 @@ func commitScripts(dbs []string) []setup.SetupScript { // NewEngine creates a new *gms.Engine or calls reset and clear scripts on the existing // engine for reuse. func (d *DoltHarness) NewEngine(t *testing.T) (*gms.Engine, error) { - if d.engine == nil { + initializeEngine := d.engine == nil + if initializeEngine { d.branchControl = branch_control.CreateDefaultController() pro := d.newProvider() @@ -172,7 +173,7 @@ func (d *DoltHarness) NewEngine(t *testing.T) (*gms.Engine, error) { d.provider = doltProvider var err error - d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), doltProvider, d.multiRepoEnv.Config(), d.branchControl) + d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl) require.NoError(t, err) e, err := enginetest.NewEngine(t, d, d.provider, d.setupData) @@ -203,6 +204,13 @@ func (d *DoltHarness) NewEngine(t *testing.T) (*gms.Engine, error) { d.engine.Analyzer.Catalog.MySQLDb = mysql_db.CreateEmptyMySQLDb() d.engine.Analyzer.Catalog.MySQLDb.AddRootAccount() + // Get a fresh session if we are reusing the engine + if !initializeEngine { + var err error + d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl) + require.NoError(t, err) + } + ctx := enginetest.NewContext(d) e, err := enginetest.RunSetupScripts(ctx, d.engine, d.resetScripts(), d.SupportsNativeIndexCreation()) @@ -356,6 +364,10 @@ func (d *DoltHarness) NewReadOnlyEngine(provider sql.DatabaseProvider) (*gms.Eng return nil, err } + // reset the session as well since we have swapped out the database provider, which invalidates caching assumptions + d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), readOnlyProvider, d.multiRepoEnv.Config(), d.branchControl) + require.NoError(d.t, err) + return enginetest.NewEngineWithProvider(nil, d, readOnlyProvider), nil } diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go index 9e631c5ac3..81b59e9492 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go @@ -313,9 +313,14 @@ var DoltRevisionDbScripts = []queries.ScriptTest{ Expected: []sql.Row{}, }, { - // The database name should be the revision spec we started with, not its resolved hash - Query: "select database() regexp '^mydb/[0-9a-v]{32}$', database() = 'mydb/tag1~';", - Expected: []sql.Row{{false, true}}, + // The database name is always the base name, never the revision specifier + Query: "select database()", + Expected: []sql.Row{{"mydb/tag1~"}}, + }, + { + // The branch is nil in the case of a non-branch revision DB + Query: "select active_branch()", + Expected: []sql.Row{{nil}}, }, { Query: "select * from t01;", @@ -386,12 +391,18 @@ var DoltRevisionDbScripts = []queries.ScriptTest{ Expected: []sql.Row{}, }, { - Query: "select database();", + // The database name is always the base name, never the revision specifier + Query: "select database()", Expected: []sql.Row{{"mydb/tag1"}}, }, + { + // The branch is nil in the case of a non-branch revision DB + Query: "select active_branch()", + Expected: []sql.Row{{nil}}, + }, { Query: "show databases;", - Expected: []sql.Row{{"mydb"}, {"information_schema"}, {"mydb/tag1"}, {"mysql"}}, + Expected: []sql.Row{{"mydb"}, {"information_schema"}, {"mysql"}}, }, { Query: "select * from t01;", @@ -446,12 +457,17 @@ var DoltRevisionDbScripts = []queries.ScriptTest{ }, { Query: "show databases;", - Expected: []sql.Row{{"mydb"}, {"information_schema"}, {"mydb/branch1"}, {"mysql"}}, + Expected: []sql.Row{{"mydb"}, {"information_schema"}, {"mysql"}}, }, { - Query: "select database();", + // The database name is always the base name, never the revision specifier + Query: "select database()", Expected: []sql.Row{{"mydb/branch1"}}, }, + { + Query: "select active_branch()", + Expected: []sql.Row{{"branch1"}}, + }, { Query: "select * from t01", Expected: []sql.Row{{1, 1}, {2, 2}}, @@ -482,14 +498,14 @@ var DoltRevisionDbScripts = []queries.ScriptTest{ }, { Query: "show databases;", - Expected: []sql.Row{{"mydb"}, {"information_schema"}, {"mydb/branch1"}, {"mysql"}}, + Expected: []sql.Row{{"mydb"}, {"information_schema"}, {"mysql"}}, }, { + // Create a table in the working set to verify the main db Query: "create table working_set_table(pk int primary key);", Expected: []sql.Row{{types.NewOkResult(0)}}, }, { - // Create a table in the working set to verify the main db Query: "select table_name from dolt_diff where commit_hash='WORKING';", Expected: []sql.Row{{"working_set_table"}}, }, @@ -1930,6 +1946,277 @@ var BrokenHistorySystemTableScriptTests = []queries.ScriptTest{ }, } +var DoltCheckoutScripts = []queries.ScriptTest{ + { + Name: "dolt_checkout changes working set", + SetUpScript: []string{ + "create table t (a int primary key, b int);", + "call dolt_commit('-Am', 'creating table t');", + "call dolt_branch('b2');", + "call dolt_branch('b3');", + "insert into t values (1, 1);", + "call dolt_commit('-Am', 'added values on main');", + "call dolt_checkout('b2');", + "insert into t values (2, 2);", + "call dolt_commit('-am', 'added values on b2');", + "call dolt_checkout('b3');", + "insert into t values (3, 3);", + "call dolt_commit('-am', 'added values on b3');", + "call dolt_checkout('main');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select active_branch();", + Expected: []sql.Row{{"main"}}, + }, + { + Query: "select * from t;", + Expected: []sql.Row{{1, 1}}, + }, + { + Query: "call dolt_checkout('b2');", + SkipResultsCheck: true, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"b2"}}, + }, + { + Query: "select * from t;", + Expected: []sql.Row{{2, 2}}, + }, + { + Query: "call dolt_checkout('b3');", + SkipResultsCheck: true, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"b3"}}, + }, + { + Query: "select * from t;", + Expected: []sql.Row{{3, 3}}, + }, + { + Query: "call dolt_checkout('main');", + SkipResultsCheck: true, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"main"}}, + }, + { + Query: "select * from t;", + Expected: []sql.Row{{1, 1}}, + }, + }, + }, + { + Name: "dolt_checkout mixed with USE statements", + SetUpScript: []string{ + "create table t (a int primary key, b int);", + "call dolt_commit('-Am', 'creating table t');", + "call dolt_branch('b2');", + "call dolt_branch('b3');", + "insert into t values (1, 1);", + "call dolt_commit('-Am', 'added values on main');", + "call dolt_checkout('b2');", + "insert into t values (2, 2);", + "call dolt_commit('-am', 'added values on b2');", + "call dolt_checkout('b3');", + "insert into t values (3, 3);", + "call dolt_commit('-am', 'added values on b3');", + "call dolt_checkout('main');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select active_branch();", + Expected: []sql.Row{{"main"}}, + }, + { + Query: "select * from t;", + Expected: []sql.Row{{1, 1}}, + }, + { + Query: "use `mydb/b2`;", + SkipResultsCheck: true, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"b2"}}, + }, + { + Query: "select * from t;", + Expected: []sql.Row{{2, 2}}, + }, + { + Query: "use `mydb/b3`;", + SkipResultsCheck: true, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"b3"}}, + }, + { + Query: "select * from t;", + Expected: []sql.Row{{3, 3}}, + }, + { + Query: "use `mydb/main`", + SkipResultsCheck: true, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"main"}}, + }, + { + Query: "select * from t;", + Expected: []sql.Row{{1, 1}}, + }, + { + Query: "use `mydb`", + SkipResultsCheck: true, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"main"}}, + }, + { + Query: "select * from t;", + Expected: []sql.Row{{1, 1}}, + }, + { + Query: "call dolt_checkout('b2');", + SkipResultsCheck: true, + }, + { + Query: "use `mydb/b3`", + SkipResultsCheck: true, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"b3"}}, + }, + // Since b2 was the last branch checked out with dolt_checkout, it's what mydb resolves to + { + Query: "use `mydb`", + SkipResultsCheck: true, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"b2"}}, + }, + { + Query: "select * from t;", + Expected: []sql.Row{{2, 2}}, + }, + }, + }, +} + +var DoltInfoSchemaScripts = []queries.ScriptTest{ + { + Name: "info_schema changes with dolt_checkout", + SetUpScript: []string{ + "create table t (a int primary key, b int);", + "call dolt_commit('-Am', 'creating table t');", + "call dolt_branch('b2');", + "call dolt_branch('b3');", + "call dolt_checkout('b2');", + "alter table t add column c int;", + "call dolt_commit('-am', 'added column c on branch b2');", + "call dolt_checkout('b3');", + "alter table t add column d int;", + "call dolt_commit('-am', 'added column d on branch b3');", + "call dolt_checkout('main');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select active_branch();", + Expected: []sql.Row{{"main"}}, + }, + { + Query: "select column_name from information_schema.columns where table_schema = 'mydb' and table_name = 't' order by 1;", + Expected: []sql.Row{{"a"}, {"b"}}, + }, + { + Query: "call dolt_checkout('b2');", + SkipResultsCheck: true, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"b2"}}, + }, + { + Query: "select column_name from information_schema.columns where table_schema = 'mydb' and table_name = 't' order by 1;", + Expected: []sql.Row{{"a"}, {"b"}, {"c"}}, + }, + { + Query: "call dolt_checkout('b3');", + SkipResultsCheck: true, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"b3"}}, + }, + { + Query: "select column_name from information_schema.columns where table_schema = 'mydb' and table_name = 't' order by 1;", + Expected: []sql.Row{{"a"}, {"b"}, {"d"}}, + }, + }, + }, + { + Name: "info_schema changes with USE", + SetUpScript: []string{ + "create table t (a int primary key, b int);", + "call dolt_commit('-Am', 'creating table t');", + "call dolt_branch('b2');", + "call dolt_branch('b3');", + "call dolt_checkout('b2');", + "alter table t add column c int;", + "call dolt_commit('-am', 'added column c on branch b2');", + "call dolt_checkout('b3');", + "alter table t add column d int;", + "call dolt_commit('-am', 'added column d on branch b3');", + "use mydb/main;", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select active_branch();", + Expected: []sql.Row{{"main"}}, + }, + { + Query: "select column_name from information_schema.columns where table_schema = 'mydb' and table_name = 't' order by 1;", + Expected: []sql.Row{{"a"}, {"b"}}, + }, + { + Query: "use mydb/b2;", + SkipResultsCheck: true, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"b2"}}, + }, + { + Query: "select column_name from information_schema.columns where table_schema = 'mydb' and table_name = 't' order by 1;", + Expected: []sql.Row{{"a"}, {"b"}, {"c"}}, + }, + { + Query: "use mydb/b3;", + SkipResultsCheck: true, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"b3"}}, + }, + { + Query: "select column_name from information_schema.columns where table_schema = 'mydb' and table_name = 't' order by 1;", + Expected: []sql.Row{{"a"}, {"b"}, {"d"}}, + }, + }, + }, +} + var DoltBranchScripts = []queries.ScriptTest{ { Name: "Create branches from HEAD with dolt_branch procedure", diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries_merge.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries_merge.go index 9027c5f7b2..d330b2f2b3 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries_merge.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries_merge.go @@ -89,6 +89,44 @@ var MergeScripts = []queries.ScriptTest{ }, }, }, + { + Name: "CALL DOLT_MERGE ff correctly works with autocommit off, no checkout", + SetUpScript: []string{ + "CREATE TABLE test (pk int primary key)", + "call DOLT_ADD('.')", + "INSERT INTO test VALUES (0),(1),(2);", + "SET autocommit = 0", + "CALL DOLT_COMMIT('-a', '-m', 'Step 1');", + "CALL DOLT_BRANCH('feature-branch')", + "use `mydb/feature-branch`", + "INSERT INTO test VALUES (3);", + "UPDATE test SET pk=1000 WHERE pk=0;", + "CALL DOLT_ADD('.');", + "CALL DOLT_COMMIT('-a', '-m', 'this is a ff');", + "use mydb/main;", + }, + Assertions: []queries.ScriptTestAssertion{ + { + // FF-Merge + Query: "CALL DOLT_MERGE('feature-branch')", + Expected: []sql.Row{{1, 0}}, + }, + { + Query: "SELECT is_merging, source, target, unmerged_tables FROM DOLT_MERGE_STATUS;", + Expected: []sql.Row{{false, nil, nil, nil}}, + }, + { + Query: "SELECT * from dolt_status", + Expected: []sql.Row{}, + }, + { + Query: "select * from test order by 1", + Expected: []sql.Row{ + {1}, {2}, {3}, {1000}, + }, + }, + }, + }, { Name: "CALL DOLT_MERGE no-ff correctly works with autocommit off", SetUpScript: []string{ @@ -131,6 +169,51 @@ var MergeScripts = []queries.ScriptTest{ }, }, }, + { + Name: "CALL DOLT_MERGE no-ff correctly works with autocommit off, no checkout", + SetUpScript: []string{ + "CREATE TABLE test (pk int primary key)", + "call DOLT_ADD('.')", + "INSERT INTO test VALUES (0),(1),(2);", + "SET autocommit = 0", + "CALL DOLT_COMMIT('-a', '-m', 'Step 1', '--date', '2022-08-06T12:00:00');", + "CALL DOLT_BRANCH('feature-branch')", + "USE `mydb/feature-branch`", + "INSERT INTO test VALUES (3);", + "UPDATE test SET pk=1000 WHERE pk=0;", + "CALL DOLT_COMMIT('-a', '-m', 'this is a ff', '--date', '2022-08-06T12:00:01');", + "use `mydb/main`", + }, + Assertions: []queries.ScriptTestAssertion{ + { + // No-FF-Merge + Query: "CALL DOLT_MERGE('feature-branch', '-no-ff', '-m', 'this is a no-ff')", + Expected: []sql.Row{{1, 0}}, + }, + { + Query: "SELECT is_merging, source, target, unmerged_tables FROM DOLT_MERGE_STATUS;", + Expected: []sql.Row{{false, nil, nil, nil}}, + }, + { + Query: "SELECT * from dolt_status", + Expected: []sql.Row{}, + }, + { + Query: "SELECT COUNT(*) FROM dolt_log", + Expected: []sql.Row{{5}}, // includes the merge commit created by no-ff and setup commits + }, + { + Query: "select message from dolt_log order by date DESC LIMIT 1;", + Expected: []sql.Row{{"this is a no-ff"}}, // includes the merge commit created by no-ff + }, + { + Query: "select * from test order by 1", + Expected: []sql.Row{ + {1}, {2}, {3}, {1000}, + }, + }, + }, + }, { Name: "CALL DOLT_MERGE without conflicts correctly works with autocommit off with commit flag", SetUpScript: []string{ @@ -209,8 +292,12 @@ var MergeScripts = []queries.ScriptTest{ Expected: []sql.Row{{"add some more values"}}, }, { - Query: "CALL DOLT_CHECKOUT('-b', 'other-branch')", - ExpectedErr: dsess.ErrWorkingSetChanges, + Query: "CALL DOLT_CHECKOUT('-b', 'other')", + Expected: []sql.Row{{0}}, + }, + { + Query: "CALL DOLT_CHECKOUT('main')", + Expected: []sql.Row{{0}}, }, }, }, @@ -251,10 +338,6 @@ var MergeScripts = []queries.ScriptTest{ Query: "select message from dolt_log where date < '2022-08-08' order by date DESC LIMIT 1;", Expected: []sql.Row{{"update a value"}}, }, - { - Query: "CALL DOLT_CHECKOUT('-b', 'other-branch')", - ExpectedErr: dsess.ErrWorkingSetChanges, - }, { Query: "SELECT COUNT(*) FROM dolt_conflicts", Expected: []sql.Row{{1}}, @@ -273,6 +356,38 @@ var MergeScripts = []queries.ScriptTest{ }, }, }, + { + Name: "merge conflicts prevent new branch creation", + SetUpScript: []string{ + "CREATE TABLE test (pk int primary key, val int)", + "call DOLT_ADD('.')", + "INSERT INTO test VALUES (0, 0)", + "SET autocommit = 0", + "CALL DOLT_COMMIT('-a', '-m', 'Step 1', '--date', '2022-08-06T12:00:01');", + "CALL DOLT_CHECKOUT('-b', 'feature-branch')", + "INSERT INTO test VALUES (1, 1);", + "UPDATE test SET val=1000 WHERE pk=0;", + "CALL DOLT_COMMIT('-a', '-m', 'this is a normal commit', '--date', '2022-08-06T12:00:02');", + "CALL DOLT_CHECKOUT('main');", + "UPDATE test SET val=1001 WHERE pk=0;", + "CALL DOLT_COMMIT('-a', '-m', 'update a value', '--date', '2022-08-06T12:00:03');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL DOLT_MERGE('feature-branch', '-m', 'this is a merge')", + Expected: []sql.Row{{0, 1}}, + }, + { + Query: "SELECT is_merging, source, target, unmerged_tables FROM DOLT_MERGE_STATUS;", + Expected: []sql.Row{{true, "feature-branch", "refs/heads/main", "test"}}, + }, + { + // errors because creating a new branch implicitly commits the current transaction + Query: "CALL DOLT_CHECKOUT('-b', 'other-branch')", + ExpectedErrStr: "Merge conflict detected, transaction rolled back. Merge conflicts must be resolved using the dolt_conflicts tables before committing a transaction. To commit transactions with merge conflicts, set @@dolt_allow_commit_conflicts = 1", + }, + }, + }, { Name: "CALL DOLT_MERGE ff & squash correctly works with autocommit off", SetUpScript: []string{ @@ -326,8 +441,12 @@ var MergeScripts = []queries.ScriptTest{ Expected: []sql.Row{{1, 0}}, }, { - Query: "CALL DOLT_CHECKOUT('-b', 'other')", - ExpectedErr: dsess.ErrWorkingSetChanges, + Query: "CALL DOLT_CHECKOUT('-b', 'other')", + Expected: []sql.Row{{0}}, + }, + { + Query: "CALL DOLT_CHECKOUT('main')", + Expected: []sql.Row{{0}}, }, { Query: "SELECT * FROM test order by pk", @@ -372,6 +491,64 @@ var MergeScripts = []queries.ScriptTest{ }, }, }, + { + Name: "CALL DOLT_MERGE ff no checkout", + SetUpScript: []string{ + "CREATE TABLE test (pk int primary key)", + "CALL DOLT_ADD('.')", + "INSERT INTO test VALUES (0),(1),(2);", + "CALL DOLT_COMMIT('-a', '-m', 'Step 1');", + "CALL dolt_branch('feature-branch')", + "use `mydb/feature-branch`", + "INSERT INTO test VALUES (3);", + "UPDATE test SET pk=1000 WHERE pk=0;", + "CALL DOLT_COMMIT('-a', '-m', 'this is a ff');", + "use mydb/main;", + }, + Assertions: []queries.ScriptTestAssertion{ + { + // FF-Merge + Query: "CALL DOLT_MERGE('feature-branch')", + Expected: []sql.Row{{1, 0}}, + }, + { + Query: "SELECT is_merging, source, target, unmerged_tables FROM DOLT_MERGE_STATUS;", + Expected: []sql.Row{{false, nil, nil, nil}}, + }, + { + Query: "SELECT * from dolt_status", + Expected: []sql.Row{}, + }, + { + Query: "CALL DOLT_CHECKOUT('-b', 'new-branch')", + Expected: []sql.Row{{0}}, + }, + { + Query: "select active_branch()", + Expected: []sql.Row{{"new-branch"}}, + }, + { + Query: "INSERT INTO test VALUES (4)", + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + { + Query: "SELECT * FROM test order by pk", + Expected: []sql.Row{{1}, {2}, {3}, {4}, {1000}}, + }, + { + Query: "use `mydb/main`", + SkipResultsCheck: true, + }, + { + Query: "select active_branch()", + Expected: []sql.Row{{"main"}}, + }, + { + Query: "SELECT * FROM test order by pk", + Expected: []sql.Row{{1}, {2}, {3}, {1000}}, + }, + }, + }, { Name: "CALL DOLT_MERGE no-ff", SetUpScript: []string{ @@ -451,6 +628,61 @@ var MergeScripts = []queries.ScriptTest{ }, }, }, + { + Name: "CALL DOLT_MERGE with no conflicts works, no checkout", + SetUpScript: []string{ + "CREATE TABLE test (pk int primary key)", + "CALL DOLT_ADD('.')", + "INSERT INTO test VALUES (0),(1),(2);", + "CALL DOLT_COMMIT('-a', '-m', 'Step 1', '--date', '2022-08-06T12:00:00');", + "CALL dolt_branch('feature-branch')", + "use `mydb/feature-branch`", + "INSERT INTO test VALUES (3);", + "UPDATE test SET pk=1000 WHERE pk=0;", + "CALL DOLT_COMMIT('-a', '-m', 'this is a normal commit', '--date', '2022-08-06T12:00:01');", + "use mydb/main", + "INSERT INTO test VALUES (5),(6),(7);", + "CALL DOLT_COMMIT('-a', '-m', 'add some more values', '--date', '2022-08-06T12:00:02');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL DOLT_MERGE('feature-branch', '--no-commit', '--commit')", + ExpectedErrStr: "cannot define both 'commit' and 'no-commit' flags at the same time", + }, + { + Query: "CALL DOLT_MERGE('feature-branch', '-m', 'this is a merge')", + Expected: []sql.Row{{0, 0}}, + }, + { + Query: "SELECT COUNT(*) from dolt_status", + Expected: []sql.Row{{0}}, + }, + { + Query: "SELECT COUNT(*) FROM dolt_log", + Expected: []sql.Row{{6}}, // includes the merge commit and a new commit created by successful merge + }, + { + Query: "select message from dolt_log where date > '2022-08-08' order by date DESC LIMIT 1;", + Expected: []sql.Row{{"this is a merge"}}, + }, + { + Query: "select * from test order by pk", + Expected: []sql.Row{ + {1}, {2}, {3}, {5}, {6}, {7}, {1000}, + }, + }, + { + Query: "use `mydb/feature-branch`", + SkipResultsCheck: true, + }, + { + Query: "select * from test order by pk", + Expected: []sql.Row{ + {1}, {2}, {3}, {1000}, + }, + }, + }, + }, { Name: "CALL DOLT_MERGE with no conflicts works with no-commit flag", SetUpScript: []string{ @@ -1208,6 +1440,40 @@ var Dolt1MergeScripts = []queries.ScriptTest{ }, }, }, + { + Name: "dropping constraint from one branch drops from both, no checkout", + SetUpScript: []string{ + "create table t (i int)", + "alter table t add constraint c check (i > 0)", + "call dolt_commit('-Am', 'initial commit')", + + "call dolt_branch('other')", + "use mydb/other", + "insert into t values (1)", + "alter table t drop constraint c", + "call dolt_commit('-Am', 'changes to other')", + + "use mydb/main", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into t values (-1)", + ExpectedErr: sql.ErrCheckConstraintViolated, + }, + { + Query: "CALL DOLT_MERGE('other');", + Expected: []sql.Row{{1, 0}}, + }, + { + Query: "select * from t", + Expected: []sql.Row{{1}}, + }, + { + Query: "insert into t values (-1)", + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + }, + }, { Name: "merge constraint with valid data on different branches", SetUpScript: []string{ diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_server_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_server_test.go index c3922e04e7..1f8a44ee2c 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_server_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_server_test.go @@ -133,7 +133,7 @@ var DoltBranchMultiSessionScriptTests = []queries.ScriptTest{ }, { Query: "/* client a */ SHOW DATABASES;", - Expected: []sql.Row{{"dolt"}, {"dolt/branch1"}, {"information_schema"}, {"mysql"}}, + Expected: []sql.Row{{"dolt"}, {"information_schema"}, {"mysql"}}, }, { Query: "/* client a */ CALL DOLT_BRANCH('-d', 'branch2');", @@ -145,7 +145,7 @@ var DoltBranchMultiSessionScriptTests = []queries.ScriptTest{ }, { Query: "/* client a */ SHOW DATABASES;", - Expected: []sql.Row{{"dolt"}, {"dolt/branch1"}, {"information_schema"}, {"mysql"}}, + Expected: []sql.Row{{"dolt"}, {"information_schema"}, {"mysql"}}, }, { // Call a stored procedure since this searches across all databases and will @@ -180,7 +180,7 @@ var DoltBranchMultiSessionScriptTests = []queries.ScriptTest{ }, { Query: "/* client a */ SHOW DATABASES;", - Expected: []sql.Row{{"dolt"}, {"dolt/branch1"}, {"information_schema"}, {"mysql"}}, + Expected: []sql.Row{{"dolt"}, {"information_schema"}, {"mysql"}}, }, { Query: "/* client a */ CALL DOLT_BRANCH('-m', 'branch2', 'newName');", @@ -192,7 +192,7 @@ var DoltBranchMultiSessionScriptTests = []queries.ScriptTest{ }, { Query: "/* client a */ SHOW DATABASES;", - Expected: []sql.Row{{"dolt"}, {"dolt/branch1"}, {"information_schema"}, {"mysql"}}, + Expected: []sql.Row{{"dolt"}, {"information_schema"}, {"mysql"}}, }, { // Call a stored procedure since this searches across all databases and will @@ -235,11 +235,11 @@ var DoltBranchMultiSessionScriptTests = []queries.ScriptTest{ }, { Query: "/* client a */ select name from dolt_branches;", - ExpectedErrStr: "Error 1105: current branch has been force deleted. run 'USE /' to checkout a different branch, or reconnect to the server", + ExpectedErrStr: "Error 1105: branch not found", }, { Query: "/* client a */ CALL DOLT_CHECKOUT('main');", - ExpectedErrStr: "Error 1105: current branch has been force deleted. run 'USE /' to checkout a different branch, or reconnect to the server", + ExpectedErrStr: "Error 1105: Could not load database dolt", }, { Query: "/* client a */ USE dolt/main;", @@ -283,12 +283,13 @@ var DoltBranchMultiSessionScriptTests = []queries.ScriptTest{ Expected: []sql.Row{{"main"}}, }, { - Query: "/* client a */ select name from dolt_branches;", - ExpectedErrStr: "Error 1105: current branch has been force deleted. run 'USE /' to checkout a different branch, or reconnect to the server", + // client a still sees the branches and can use them because it's in a transaction + Query: "/* client a */ select name from dolt_branches;", + Expected: []sql.Row{{"branch1"}, {"main"}}, }, { - Query: "/* client a */ CALL DOLT_CHECKOUT('main');", - ExpectedErrStr: "Error 1105: current branch has been force deleted. run 'USE /' to checkout a different branch, or reconnect to the server", + Query: "/* client a */ CALL DOLT_CHECKOUT('main');", + Expected: []sql.Row{{0}}, }, { Query: "/* client a */ USE dolt/main;", @@ -405,7 +406,6 @@ var DropDatabaseMultiSessionScriptTests = []queries.ScriptTest{ Expected: []sql.Row{}, }, { - // At this point, this is an invalid revision database, and any queries against it will fail. Query: "/* client b */ select database();", Expected: []sql.Row{{"db01/branch1"}}, }, @@ -462,7 +462,7 @@ func testMultiSessionScriptTests(t *testing.T, tests []queries.ScriptTest) { if len(assertion.ExpectedErrStr) > 0 { require.EqualError(t, err, assertion.ExpectedErrStr) } else if assertion.ExpectedErr != nil { - require.True(t, assertion.ExpectedErr.Is(err)) + require.True(t, assertion.ExpectedErr.Is(err), "expected error %v, got %v", assertion.ExpectedErr, err) } else if assertion.Expected != nil { require.NoError(t, err) assertResultsEqual(t, assertion.Expected, rows) diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_transaction_commit_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_transaction_commit_test.go index 74757d14f1..ae5970c4be 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_transaction_commit_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_transaction_commit_test.go @@ -384,35 +384,30 @@ func TestDoltTransactionCommitAutocommit(t *testing.T) { if !ok { t.Fatal("'mydb' database not found") } - cs, err := doltdb.NewCommitSpec("HEAD") + + headSpec, err := doltdb.NewCommitSpec("HEAD") require.NoError(t, err) headRefs, err := db.GetHeadRefs(context.Background()) require.NoError(t, err) - commit3, err := db.Resolve(context.Background(), cs, headRefs[0]) + head, err := db.Resolve(context.Background(), headSpec, headRefs[0]) require.NoError(t, err) - cm3, err := commit3.GetCommitMeta(context.Background()) + headMeta, err := head.GetCommitMeta(context.Background()) require.NoError(t, err) - require.Contains(t, cm3.Description, "Transaction commit") + require.Contains(t, headMeta.Description, "Transaction commit") - as, err := doltdb.NewAncestorSpec("~1") + ancestorSpec, err := doltdb.NewAncestorSpec("~1") require.NoError(t, err) - commit2, err := commit3.GetAncestor(context.Background(), as) + parent, err := head.GetAncestor(context.Background(), ancestorSpec) require.NoError(t, err) - cm2, err := commit2.GetCommitMeta(context.Background()) + parentMeta, err := parent.GetCommitMeta(context.Background()) require.NoError(t, err) - require.Contains(t, cm2.Description, "Transaction commit") + require.Contains(t, parentMeta.Description, "Transaction commit") - commit1, err := commit2.GetAncestor(context.Background(), as) + grandParent, err := parent.GetAncestor(context.Background(), ancestorSpec) require.NoError(t, err) - cm1, err := commit1.GetCommitMeta(context.Background()) + grandparentMeta, err := grandParent.GetCommitMeta(context.Background()) require.NoError(t, err) - require.Equal(t, "Transaction commit", cm1.Description) - - commit0, err := commit1.GetAncestor(context.Background(), as) - require.NoError(t, err) - cm0, err := commit0.GetCommitMeta(context.Background()) - require.NoError(t, err) - require.Equal(t, "checkpoint enginetest database mydb", cm0.Description) + require.Equal(t, "checkpoint enginetest database mydb", grandparentMeta.Description) } func TestDoltTransactionCommitLateFkResolution(t *testing.T) { diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_transaction_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_transaction_queries.go index ebd07816d6..57a9e9cf95 100755 --- a/go/libraries/doltcore/sqle/enginetest/dolt_transaction_queries.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_transaction_queries.go @@ -2368,6 +2368,16 @@ var MultiDbTransactionTests = []queries.ScriptTest{ {types.OkResult{RowsAffected: 1}}, }, }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{}, + }, + { + Query: "select * from `mydb/b1`.t1 order by a", + Expected: []sql.Row{ + {1}, {2}, + }, + }, { Query: "commit", Expected: []sql.Row{}, @@ -2388,4 +2398,550 @@ var MultiDbTransactionTests = []queries.ScriptTest{ }, }, }, + { + Name: "committing to another branch with autocommit", + SetUpScript: []string{ + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "set autocommit = on", // unnecessary but make it explicit + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into `mydb/b1`.t1 values (1)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{}, + }, + { + Query: "call dolt_checkout('b1')", + SkipResultsCheck: true, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{{1}}, + }, + }, + }, + { + Name: "committing to another branch with dolt_transaction_commit", + SetUpScript: []string{ + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "set autocommit = 0", + "set dolt_transaction_commit = on", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into `mydb/b1`.t1 values (1)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "insert into `mydb/b1`.t1 values (2)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{}, + }, + { + Query: "select * from `mydb/b1`.t1 order by a", + Expected: []sql.Row{ + {1}, {2}, + }, + }, + { + Query: "commit", + ExpectedErrStr: "no changes to dolt_commit on branch main", + }, + { + Query: "use mydb/b1", + Expected: []sql.Row{}, + }, + { + Query: "commit", + Expected: []sql.Row{}, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{ + {1}, {2}, + }, + }, + }, + }, + { + Name: "committing to another branch with dolt_commit", + SetUpScript: []string{ + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "set autocommit = off", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into `mydb/b1`.t1 values (1)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{}, + }, + { + Query: "call dolt_commit('-am', 'changes on b1')", + ExpectedErrStr: "nothing to commit", // this error is different from what you get with @@dolt_transaction_commit + }, + { + Query: "use mydb/b1", + Expected: []sql.Row{}, + }, + { + Query: "call dolt_commit('-am', 'other changes on b1')", + SkipResultsCheck: true, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{{1}}, + }, + { + Query: "select message from dolt_log order by date desc limit 1", + Expected: []sql.Row{{"other changes on b1"}}, + }, + }, + }, + { + Name: "committing to another branch with autocommit and dolt_transaction_commit", + SetUpScript: []string{ + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "set autocommit = on", // unnecessary but make it explicit + "set dolt_transaction_commit = on", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into `mydb/b1`.t1 values (1)", + ExpectedErrStr: "no changes to dolt_commit on branch main", + }, + { + Query: "use mydb/b1", + Expected: []sql.Row{}, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: "commit", + Expected: []sql.Row{}, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{ + {1}, + }, + }, + }, + }, + { + Name: "active_branch with dolt_checkout and use", + SetUpScript: []string{ + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "set autocommit = 0", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into `mydb/b1`.t1 values (1)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "insert into `mydb/b1`.t1 values (2)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{}, + }, + { + Query: "call dolt_checkout('b1')", + SkipResultsCheck: true, + }, + { + Query: "select active_branch()", + Expected: []sql.Row{{"b1"}}, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{ + {1}, {2}, + }, + }, + { + Query: "call dolt_checkout('main')", + SkipResultsCheck: true, + }, + { + Query: "select active_branch()", + Expected: []sql.Row{{"main"}}, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{}, + }, + { + Query: "use `mydb/b1`", + Expected: []sql.Row{}, + }, + { + Query: "select active_branch()", + Expected: []sql.Row{{"b1"}}, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{{1}, {2}}, + }, + { + Query: "use mydb", + Expected: []sql.Row{}, + }, + { + Query: "select active_branch()", + Expected: []sql.Row{{"main"}}, + }, + { + Query: "commit", + Expected: []sql.Row{}, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{}, + }, + { + Query: "call dolt_checkout('b1')", + SkipResultsCheck: true, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{ + {1}, {2}, + }, + }, + }, + }, + { + Name: "committing to another database", + SetUpScript: []string{ + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "create database db1", + "use db1", + "create table t1 (a int)", + "use mydb", + "set autocommit = 0", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into db1.t1 values (1)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "insert into db1.t1 values (2)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{}, + }, + { + Query: "select * from db1.t1 order by a", + Expected: []sql.Row{{1}, {2}}, + }, + { + Query: "commit", + Expected: []sql.Row{}, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{}, + }, + { + Query: "select * from db1.t1 order by a", + Expected: []sql.Row{{1}, {2}}, + }, + }, + }, + { + Name: "committing to another database with dolt_commit", + SetUpScript: []string{ + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "create database db1", + "use db1", + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "use mydb/b1", + "set autocommit = off", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into `db1/b1`.t1 values (1)", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "call dolt_commit('-am', 'changes on b1')", + ExpectedErrStr: "nothing to commit", // this error is different from what you get with @@dolt_transaction_commit + }, + { + Query: "use db1/b1", + Expected: []sql.Row{}, + }, + { + Query: "call dolt_commit('-am', 'other changes on b1')", + SkipResultsCheck: true, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{{1}}, + }, + { + Query: "select message from dolt_log order by date desc limit 1", + Expected: []sql.Row{{"other changes on b1"}}, + }, + }, + }, + { + Name: "committing to another branch on another database", + SetUpScript: []string{ + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "create database db1", + "use db1", + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "use mydb", + "set autocommit = 0", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into `db1/b1`.t1 values (1)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "insert into `db1/b1`.t1 values (2)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{}, + }, + { + Query: "select * from db1.t1 order by a", + Expected: []sql.Row{}, + }, + { + Query: "select * from `db1/b1`.t1 order by a", + Expected: []sql.Row{{1}, {2}}, + }, + { + Query: "commit", + Expected: []sql.Row{}, + }, + { + Query: "select * from t1 order by a", + Expected: []sql.Row{}, + }, + { + Query: "select * from db1.t1 order by a", + Expected: []sql.Row{}, + }, + { + Query: "select * from `db1/b1`.t1 order by a", + Expected: []sql.Row{{1}, {2}}, + }, + }, + }, + { + Name: "committing to another branch on another database with dolt_transaction_commit and autocommit", + SetUpScript: []string{ + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "create database db1", + "use db1", + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "use mydb/b1", + "set autocommit = 1", + "set dolt_transaction_commit = 1", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into `db1/b1`.t1 values (1)", + ExpectedErrStr: "no changes to dolt_commit on database mydb", + }, + }, + }, + { + Name: "committing to another branch on another database with dolt_transaction_commit, no autocommit", + SetUpScript: []string{ + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "create database db1", + "use db1", + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "commit", + "use mydb/b1", + "set autocommit = off", + "set dolt_transaction_commit = 1", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into `db1/b1`.t1 values (1)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "commit", + ExpectedErrStr: "no changes to dolt_commit on database mydb", + }, + }, + }, + { + Name: "committing to more than one branch at a time", + SetUpScript: []string{ + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "set autocommit = 0", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into t1 values (1)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "insert into `mydb/b1`.t1 values (2)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "commit", + ExpectedErrStr: "Cannot commit changes on more than one branch / database", + }, + }, + }, + { + Name: "committing to more than one branch at a time with checkout", + SetUpScript: []string{ + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "call dolt_branch('b1')", + "set autocommit = 0", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into t1 values (1)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "call dolt_checkout('b1')", + SkipResultsCheck: true, + }, + { + Query: "insert into t1 values (2)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "commit", + ExpectedErrStr: "Cannot commit changes on more than one branch / database", + }, + }, + }, + { + Name: "committing to more than one database at a time", + SetUpScript: []string{ + "create table t1 (a int)", + "call dolt_add('.')", + "call dolt_commit('-am', 'new table')", + "create database db2", + "set autocommit = 0", + "create table db2.t1 (a int)", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into t1 values (1)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "insert into db2.t1 values (2)", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "commit", + ExpectedErrStr: "Cannot commit changes on more than one branch / database", + }, + }, + }, } diff --git a/go/libraries/doltcore/sqle/enginetest/privilege_test.go b/go/libraries/doltcore/sqle/enginetest/privilege_test.go new file mode 100755 index 0000000000..7baab9c43e --- /dev/null +++ b/go/libraries/doltcore/sqle/enginetest/privilege_test.go @@ -0,0 +1,1022 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package enginetest + +import ( + "testing" + + "github.com/dolthub/go-mysql-server/enginetest" + "github.com/dolthub/go-mysql-server/enginetest/queries" + "github.com/dolthub/go-mysql-server/enginetest/scriptgen/setup" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/mysql_db" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/types" + "github.com/stretchr/testify/require" +) + +var revisionDatabasePrivsSetupPrefix = []string{ + "call dolt_branch('b1')", + "use mydb/b1", +} + +// The subset of tests in priv_auth_queries.go to run with alternate branch logic. Not all of them are suitable +// because they use non-qualified database names in their queries +var revisionDatabasePrivilegeScriptNames = []string{ + "Binlog replication privileges", + "Valid users without privileges may use the dual table", + "Basic SELECT and INSERT privilege checking", + "Basic revoke SELECT privilege", + "Basic revoke all global static privileges", + "Grant Role with SELECT Privilege", + "Revoke role currently granted to a user", + "Drop role currently granted to a user", + "Show grants on a user from the root account", + "information_schema.columns table 'privileges' column gets correct values", + "information_schema.column_statistics shows columns with privileges only", + "information_schema.statistics shows tables with privileges only", +} + +// TestRevisionDatabasePrivileges is a spot-check of privilege checking on the original privilege test scripts, +// but with a revisioned database as the current db +func TestRevisionDatabasePrivileges(t *testing.T) { + testsToRun := make(map[string]bool) + for _, name := range revisionDatabasePrivilegeScriptNames { + testsToRun[name] = true + } + + var scripts []queries.UserPrivilegeTest + for _, script := range queries.UserPrivTests { + if testsToRun[script.Name] { + scripts = append(scripts, script) + } + } + + require.Equal(t, len(scripts), len(testsToRun), + "Error in test setup: one or more expected tests not found. "+ + "Did the name of a test change?") + + for _, script := range scripts { + harness := newDoltHarness(t) + harness.Setup(setup.MydbData, setup.MytableData) + t.Run(script.Name, func(t *testing.T) { + engine := mustNewEngine(t, harness) + defer engine.Close() + + ctx := enginetest.NewContext(harness) + ctx.NewCtxWithClient(sql.Client{ + User: "root", + Address: "localhost", + }) + engine.Analyzer.Catalog.MySQLDb.AddRootAccount() + engine.Analyzer.Catalog.MySQLDb.SetPersister(&mysql_db.NoopPersister{}) + + for _, statement := range append(revisionDatabasePrivsSetupPrefix, script.SetUpScript...) { + if harness.SkipQueryTest(statement) { + t.Skip() + } + enginetest.RunQueryWithContext(t, engine, harness, ctx, statement) + } + + for _, assertion := range script.Assertions { + if harness.SkipQueryTest(assertion.Query) { + t.Skipf("Skipping query %s", assertion.Query) + } + + user := assertion.User + host := assertion.Host + if user == "" { + user = "root" + } + if host == "" { + host = "localhost" + } + ctx := enginetest.NewContextWithClient(harness, sql.Client{ + User: user, + Address: host, + }) + ctx.SetCurrentDatabase("mydb/b1") + + if assertion.ExpectedErr != nil { + t.Run(assertion.Query, func(t *testing.T) { + enginetest.AssertErrWithCtx(t, engine, harness, ctx, assertion.Query, assertion.ExpectedErr) + }) + } else if assertion.ExpectedErrStr != "" { + t.Run(assertion.Query, func(t *testing.T) { + enginetest.AssertErrWithCtx(t, engine, harness, ctx, assertion.Query, nil, assertion.ExpectedErrStr) + }) + } else { + t.Run(assertion.Query, func(t *testing.T) { + enginetest.TestQueryWithContext(t, ctx, engine, harness, assertion.Query, assertion.Expected, nil, nil) + }) + } + } + }) + } +} + +// Privilege test scripts for revision databases. Due to limitations in test construction, test assertions are always +// performed with current db = mydb/b1, write scripts accordingly +var DoltOnlyRevisionDbPrivilegeTests = []queries.UserPrivilegeTest{ + { + Name: "Basic database and table name visibility", + SetUpScript: []string{ + "use mydb", + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1);", + "call dolt_commit('-Am', 'first commit')", + "call dolt_branch('b1')", + "use mydb/b1", + "CREATE USER tester@localhost;", + "CREATE ROLE test_role;", + "GRANT SELECT ON mydb.* TO test_role;", + }, + Assertions: []queries.UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;/*1*/", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test2;/*1*/", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT SELECT ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;/*2*/", + Expected: []sql.Row{{1}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test2;/*2*/", + ExpectedErr: sql.ErrTableNotFound, + }, + { + User: "root", + Host: "localhost", + Query: "REVOKE SELECT ON mydb.* FROM tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { // Ensure we've reverted to initial state (all SELECTs after REVOKEs are doing this) + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;/*3*/", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test2;/*3*/", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT SELECT ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM mydb.test;/*4*/", + Expected: []sql.Row{{1}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test2;/*4*/", + ExpectedErr: sql.ErrTableNotFound, + }, + { + User: "root", + Host: "localhost", + Query: "REVOKE SELECT ON mydb.* FROM tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;/*5*/", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test2;/*5*/", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT SELECT ON mydb.test TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;/*6*/", + Expected: []sql.Row{{1}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test2;/*6*/", + ExpectedErr: sql.ErrTableAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "REVOKE SELECT ON mydb.test FROM tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;/*7*/", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test2;/*7*/", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT SELECT ON mydb.test2 TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;/*8*/", + ExpectedErr: sql.ErrTableAccessDeniedForUser, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test2;/*8*/", + ExpectedErr: sql.ErrTableNotFound, + }, + { + User: "root", + Host: "localhost", + Query: "REVOKE SELECT ON mydb.test2 FROM tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;/*9*/", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test2;/*9*/", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT test_role TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;/*10*/", + Expected: []sql.Row{{1}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test2;/*10*/", + ExpectedErr: sql.ErrTableNotFound, + }, + }, + }, + { + Name: "Basic SELECT and INSERT privilege checking", + SetUpScript: []string{ + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "call dolt_commit('-Am', 'first commit');", + "call dolt_branch('b1')", + "use mydb/b1;", + "CREATE USER tester@localhost;", + }, + Assertions: []queries.UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "INSERT INTO test VALUES (4);", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT INSERT ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "INSERT INTO test VALUES (4);", + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT * FROM test;", + Expected: []sql.Row{{1}, {2}, {3}, {4}}, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT SELECT ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + Expected: []sql.Row{{1}, {2}, {3}, {4}}, + }, + }, + }, + { + Name: "Basic UPDATE privilege checking", + SetUpScript: []string{ + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "call dolt_commit('-Am', 'first commit');", + "call dolt_branch('b1')", + "use mydb/b1;", + "CREATE USER tester@localhost;", + }, + Assertions: []queries.UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "UPDATE test set pk = 4 where pk = 3;", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT UPDATE ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "INSERT INTO test VALUES (4);", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "tester", + Host: "localhost", + Query: "UPDATE test set pk = 4 where pk = 3;", + Expected: []sql.Row{{types.OkResult{ + RowsAffected: 1, + Info: plan.UpdateInfo{ + Matched: 1, + Updated: 1, + }, + }}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT * FROM test;", + Expected: []sql.Row{{1}, {2}, {4}}, + }, + }, + }, + { + Name: "Basic DELETE privilege checking", + SetUpScript: []string{ + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "call dolt_commit('-Am', 'first commit');", + "call dolt_branch('b1')", + "use mydb/b1;", + "CREATE USER tester@localhost;", + }, + Assertions: []queries.UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "DELETE from test where pk = 3;", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT DELETE ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "INSERT INTO test VALUES (4);", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "tester", + Host: "localhost", + Query: "DELETE from test where pk = 3;", + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT * FROM test;", + Expected: []sql.Row{{1}, {2}}, + }, + }, + }, + { + Name: "Basic CREATE TABLE privilege checking", + SetUpScript: []string{ + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "call dolt_commit('-Am', 'first commit');", + "call dolt_branch('b1')", + "use mydb/b1;", + "CREATE USER tester@localhost;", + }, + Assertions: []queries.UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "CREATE TABLE t2 (a int primary key);", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT CREATE ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "CREATE TABLE t2 (a int primary key);", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "show tables;", + Expected: []sql.Row{{"mytable"}, {"test"}, {"t2"}}, + }, + }, + }, + { + Name: "Basic DROP TABLE privilege checking", + SetUpScript: []string{ + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "call dolt_commit('-Am', 'first commit');", + "call dolt_branch('b1')", + "use mydb/b1;", + "CREATE USER tester@localhost;", + }, + Assertions: []queries.UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "DROP TABLE test;", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT DROP ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "DROP TABLE TEST", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "show tables;", + Expected: []sql.Row{{"mytable"}}, + }, + }, + }, + { + Name: "Basic ALTER TABLE privilege checking", + SetUpScript: []string{ + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "call dolt_commit('-Am', 'first commit');", + "call dolt_branch('b1')", + "use mydb/b1;", + "CREATE USER tester@localhost;", + }, + Assertions: []queries.UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "ALTER TABLE test add column a int;", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT ALTER ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "ALTER TABLE test add column a int;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "desc test;", + Expected: []sql.Row{ + {"pk", "bigint", "NO", "PRI", "NULL", ""}, + {"a", "int", "YES", "", "NULL", ""}, + }, + }, + }, + }, + { + Name: "Basic INDEX privilege checking", + SetUpScript: []string{ + "CREATE TABLE test (pk BIGINT PRIMARY KEY, a int);", + "INSERT INTO test VALUES (1,1), (2,2), (3,3);", + "call dolt_commit('-Am', 'first commit');", + "call dolt_branch('b1')", + "use mydb/b1;", + "CREATE USER tester@localhost;", + }, + Assertions: []queries.UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "create index t1 on test(a) ;", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT select ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "create index t1 on test(a) ;", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT index ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "create index t1 on test(a) ;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "desc test;", + Expected: []sql.Row{ + {"pk", "bigint", "NO", "PRI", "NULL", ""}, + {"a", "int", "YES", "MUL", "NULL", ""}, + }, + }, + { + User: "root", + Host: "localhost", + Query: "REVOKE index ON mydb.* FROM tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "drop index t1 on test;", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT index ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "drop index t1 on test;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "desc test;", + Expected: []sql.Row{ + {"pk", "bigint", "NO", "PRI", "NULL", ""}, + {"a", "int", "YES", "", "NULL", ""}, + }, + }, + }, + }, + { + Name: "Basic constraint privilege checking", + SetUpScript: []string{ + "CREATE TABLE test (pk BIGINT PRIMARY KEY, a int);", + "INSERT INTO test VALUES (1,1), (2,2), (3,3);", + "call dolt_commit('-Am', 'first commit');", + "call dolt_branch('b1')", + "use mydb/b1;", + "CREATE USER tester@localhost;", + }, + Assertions: []queries.UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "alter table test add constraint CHECK (NULL = NULL);", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT select ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "alter table test add constraint CHECK (NULL = NULL);", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT alter ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "alter table test add constraint chk1 CHECK (a < 10);", + Expected: []sql.Row{}, + }, + { + User: "tester", + Host: "localhost", + Query: "show create table test;", + Expected: []sql.Row{ + {"test", "CREATE TABLE `test` (\n" + + " `pk` bigint NOT NULL,\n" + + " `a` int,\n" + + " PRIMARY KEY (`pk`),\n" + + " CONSTRAINT `chk1` CHECK ((`a` < 10))\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + User: "root", + Host: "localhost", + Query: "REVOKE alter ON mydb.* FROM tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "alter table test drop check chk1;", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT alter ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "alter table test drop check chk1;", + Expected: []sql.Row{}, + }, + { + User: "tester", + Host: "localhost", + Query: "show create table test;", + Expected: []sql.Row{ + {"test", "CREATE TABLE `test` (\n" + + " `pk` bigint NOT NULL,\n" + + " `a` int,\n" + + " PRIMARY KEY (`pk`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + User: "tester", + Host: "localhost", + Query: "alter table test add constraint chk1 CHECK (a < 10);", + Expected: []sql.Row{}, + }, + { + User: "root", + Host: "localhost", + Query: "REVOKE alter ON mydb.* FROM tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "alter table test drop constraint chk1;", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT alter ON mydb.* TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "alter table test drop constraint chk1;", + Expected: []sql.Row{}, + }, + { + User: "tester", + Host: "localhost", + Query: "show create table test;", + Expected: []sql.Row{ + {"test", "CREATE TABLE `test` (\n" + + " `pk` bigint NOT NULL,\n" + + " `a` int,\n" + + " PRIMARY KEY (`pk`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + }, + }, + { + Name: "Basic revoke SELECT privilege", + SetUpScript: []string{ + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "call dolt_commit('-Am', 'first commit');", + "call dolt_branch('b1')", + "use mydb/b1;", + "CREATE USER tester@localhost;", + "GRANT SELECT ON mydb.* TO tester@localhost;", + }, + Assertions: []queries.UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT User, Host, Select_priv FROM mysql.user WHERE User = 'tester';", + Expected: []sql.Row{{"tester", "localhost", uint16(1)}}, + }, + { + User: "root", + Host: "localhost", + Query: "REVOKE SELECT ON mydb.* FROM tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT User, Host, Select_priv FROM mysql.user WHERE User = 'tester';", + Expected: []sql.Row{{"tester", "localhost", uint16(1)}}, + }, + }, + }, + { + Name: "Grant Role with SELECT Privilege", + SetUpScript: []string{ + "SET @@GLOBAL.activate_all_roles_on_login = true;", + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "call dolt_commit('-Am', 'first commit');", + "call dolt_branch('b1')", + "use mydb/b1;", + "CREATE USER tester@localhost;", + "CREATE ROLE test_role;", + "GRANT SELECT ON mydb.* TO test_role;", + }, + Assertions: []queries.UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.role_edges;", + Expected: []sql.Row{{0}}, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT test_role TO tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT * FROM mysql.role_edges;", + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", uint16(1)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT User, Host, Select_priv FROM mysql.user WHERE User = 'tester';", + Expected: []sql.Row{{"tester", "localhost", uint16(1)}}, + }, + }, + }, + { + Name: "Revoke role currently granted to a user", + SetUpScript: []string{ + "SET @@GLOBAL.activate_all_roles_on_login = true;", + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "call dolt_commit('-Am', 'first commit');", + "call dolt_branch('b1')", + "use mydb/b1;", + "CREATE USER tester@localhost;", + "CREATE ROLE test_role;", + "GRANT SELECT ON mydb.* TO test_role;", + "GRANT test_role TO tester@localhost;", + }, + Assertions: []queries.UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT * FROM mysql.role_edges;", + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", uint16(1)}}, + }, + { + User: "root", + Host: "localhost", + Query: "REVOKE test_role FROM tester@localhost;", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.role_edges;", + Expected: []sql.Row{{0}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.user WHERE User = 'test_role';", + Expected: []sql.Row{{1}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.user WHERE User = 'tester';", + Expected: []sql.Row{{1}}, + }, + }, + }, +} + +func TestDoltOnlyRevisionDatabasePrivileges(t *testing.T) { + for _, script := range DoltOnlyRevisionDbPrivilegeTests { + harness := newDoltHarness(t) + harness.Setup(setup.MydbData, setup.MytableData) + t.Run(script.Name, func(t *testing.T) { + engine := mustNewEngine(t, harness) + defer engine.Close() + + ctx := enginetest.NewContext(harness) + ctx.NewCtxWithClient(sql.Client{ + User: "root", + Address: "localhost", + }) + engine.Analyzer.Catalog.MySQLDb.AddRootAccount() + engine.Analyzer.Catalog.MySQLDb.SetPersister(&mysql_db.NoopPersister{}) + + for _, statement := range script.SetUpScript { + enginetest.RunQueryWithContext(t, engine, harness, ctx, statement) + } + + for _, assertion := range script.Assertions { + user := assertion.User + host := assertion.Host + if user == "" { + user = "root" + } + if host == "" { + host = "localhost" + } + ctx := enginetest.NewContextWithClient(harness, sql.Client{ + User: user, + Address: host, + }) + ctx.SetCurrentDatabase("mydb/b1") + + if assertion.ExpectedErr != nil { + t.Run(assertion.Query, func(t *testing.T) { + enginetest.AssertErrWithCtx(t, engine, harness, ctx, assertion.Query, assertion.ExpectedErr) + }) + } else if assertion.ExpectedErrStr != "" { + t.Run(assertion.Query, func(t *testing.T) { + enginetest.AssertErrWithCtx(t, engine, harness, ctx, assertion.Query, nil, assertion.ExpectedErrStr) + }) + } else { + t.Run(assertion.Query, func(t *testing.T) { + enginetest.TestQueryWithContext(t, ctx, engine, harness, assertion.Query, assertion.Expected, nil, nil) + }) + } + } + }) + } +} diff --git a/go/libraries/doltcore/sqle/read_replica_database.go b/go/libraries/doltcore/sqle/read_replica_database.go index fc4b3485ca..58592c4060 100644 --- a/go/libraries/doltcore/sqle/read_replica_database.go +++ b/go/libraries/doltcore/sqle/read_replica_database.go @@ -53,8 +53,6 @@ var _ dsess.RemoteReadReplicaDatabase = ReadReplicaDatabase{} var ErrFailedToLoadReplicaDB = errors.New("failed to load replica database") var ErrInvalidReplicateHeadsSetting = errors.New("invalid replicate heads setting") -var ErrFailedToCastToReplicaDb = errors.New("failed to cast to ReadReplicaDatabase") -var ErrCannotCreateReplicaRevisionDbForCommit = errors.New("cannot create replica revision db for commit") var EmptyReadReplica = ReadReplicaDatabase{} @@ -96,8 +94,8 @@ func (rrd ReadReplicaDatabase) ValidReplicaState(ctx *sql.Context) bool { // InitialDBState implements dsess.SessionDatabase // This seems like a pointless override from the embedded Database implementation, but it's necessary to pass the // correct pointer type to the session initializer. -func (rrd ReadReplicaDatabase) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) { - return initialDBState(ctx, rrd, branch) +func (rrd ReadReplicaDatabase) InitialDBState(ctx *sql.Context) (dsess.InitialDbState, error) { + return initialDBState(ctx, rrd, rrd.revision) } func (rrd ReadReplicaDatabase) PullFromRemote(ctx *sql.Context) error { @@ -117,7 +115,7 @@ func (rrd ReadReplicaDatabase) PullFromRemote(ctx *sql.Context) error { } dSess := dsess.DSessFromSess(ctx.Session) - currentBranchRef, err := dSess.CWBHeadRef(ctx, rrd.name) + currentBranchRef, err := dSess.CWBHeadRef(ctx, rrd.baseName) if err != nil && !dsess.IgnoreReplicationErrors() { return err } else if err != nil { diff --git a/go/libraries/doltcore/sqle/tables.go b/go/libraries/doltcore/sqle/tables.go index b53e381d0c..2650d6af23 100644 --- a/go/libraries/doltcore/sqle/tables.go +++ b/go/libraries/doltcore/sqle/tables.go @@ -222,8 +222,7 @@ func (t *DoltTable) workingRoot(ctx *sql.Context) (*doltdb.RootValue, error) { return root, nil } -// getRoot returns the appropriate root value for this session. The only controlling factor -// is whether this is a temporary table or not. +// getRoot returns the current root value for this session, to be used for all table data access. func (t *DoltTable) getRoot(ctx *sql.Context) (*doltdb.RootValue, error) { return t.db.GetRoot(ctx) } @@ -244,7 +243,7 @@ func (t *DoltTable) GetIndexes(ctx *sql.Context) ([]sql.Index, error) { } sess := dsess.DSessFromSess(ctx.Session) - dbState, ok, err := sess.LookupDbState(ctx, t.db.Name()) + dbState, ok, err := sess.LookupDbState(ctx, t.db.RevisionQualifiedName()) if err != nil { return nil, err } @@ -498,7 +497,7 @@ func (t *WritableDoltTable) WithProjections(colNames []string) sql.Table { // Inserter implements sql.InsertableTable func (t *WritableDoltTable) Inserter(ctx *sql.Context) sql.RowInserter { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return sqlutil.NewStaticErrorEditor(err) } te, err := t.getTableEditor(ctx) @@ -520,13 +519,13 @@ func (t *WritableDoltTable) getTableEditor(ctx *sql.Context) (ed writer.TableWri } } - state, _, err := ds.LookupDbState(ctx, t.db.name) + state, _, err := ds.LookupDbState(ctx, t.db.RevisionQualifiedName()) if err != nil { return nil, err } setter := ds.SetRoot - ed, err = state.WriteSession.GetTableWriter(ctx, t.tableName, t.db.Name(), setter, batched) + ed, err = state.WriteSession().GetTableWriter(ctx, t.tableName, t.db.RevisionQualifiedName(), setter, batched) if err != nil { return nil, err @@ -540,7 +539,7 @@ func (t *WritableDoltTable) getTableEditor(ctx *sql.Context) (ed writer.TableWri // Deleter implements sql.DeletableTable func (t *WritableDoltTable) Deleter(ctx *sql.Context) sql.RowDeleter { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return sqlutil.NewStaticErrorEditor(err) } te, err := t.getTableEditor(ctx) @@ -552,7 +551,7 @@ func (t *WritableDoltTable) Deleter(ctx *sql.Context) sql.RowDeleter { // Replacer implements sql.ReplaceableTable func (t *WritableDoltTable) Replacer(ctx *sql.Context) sql.RowReplacer { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return sqlutil.NewStaticErrorEditor(err) } te, err := t.getTableEditor(ctx) @@ -564,7 +563,7 @@ func (t *WritableDoltTable) Replacer(ctx *sql.Context) sql.RowReplacer { // Truncate implements sql.TruncateableTable func (t *WritableDoltTable) Truncate(ctx *sql.Context) (int, error) { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return 0, err } table, err := t.DoltTable.DoltTable(ctx) @@ -635,13 +634,13 @@ func (t *WritableDoltTable) truncate( } } - ws, err := sess.WorkingSet(ctx, t.db.name) + ws, err := sess.WorkingSet(ctx, t.db.RevisionQualifiedName()) if err != nil { return nil, err } if schema.HasAutoIncrement(sch) { - ddb, _ := sess.GetDoltDB(ctx, t.db.name) + ddb, _ := sess.GetDoltDB(ctx, t.db.RevisionQualifiedName()) err = t.db.removeTableFromAutoIncrementTracker(ctx, t.Name(), ddb, ws.Ref()) if err != nil { return nil, err @@ -701,7 +700,7 @@ func copyConstraintViolationsAndConflicts(ctx context.Context, from, to *doltdb. // Updater implements sql.UpdatableTable func (t *WritableDoltTable) Updater(ctx *sql.Context) sql.RowUpdater { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return sqlutil.NewStaticErrorEditor(err) } te, err := t.getTableEditor(ctx) @@ -713,7 +712,7 @@ func (t *WritableDoltTable) Updater(ctx *sql.Context) sql.RowUpdater { // AutoIncrementSetter implements sql.AutoIncrementTable func (t *WritableDoltTable) AutoIncrementSetter(ctx *sql.Context) sql.AutoIncrementSetter { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return sqlutil.NewStaticErrorEditor(err) } te, err := t.getTableEditor(ctx) @@ -1100,7 +1099,7 @@ func (t *AlterableDoltTable) WithProjections(colNames []string) sql.Table { // AddColumn implements sql.AlterableTable func (t *AlterableDoltTable) AddColumn(ctx *sql.Context, column *sql.Column, order *sql.ColumnOrder) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return err } root, err := t.getRoot(ctx) @@ -1237,7 +1236,7 @@ func (t *AlterableDoltTable) RewriteInserter( newColumn *sql.Column, idxCols []sql.IndexColumn, ) (sql.RowInserter, error) { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return nil, err } err := validateSchemaChange(t.Name(), oldSchema, newSchema, oldColumn, newColumn, idxCols) @@ -1248,7 +1247,7 @@ func (t *AlterableDoltTable) RewriteInserter( sess := dsess.DSessFromSess(ctx.Session) // Begin by creating a new table with the same name and the new schema, then removing all its existing rows - dbState, ok, err := sess.LookupDbState(ctx, t.db.Name()) + dbState, ok, err := sess.LookupDbState(ctx, t.db.RevisionQualifiedName()) if err != nil { return nil, err } @@ -1257,12 +1256,12 @@ func (t *AlterableDoltTable) RewriteInserter( return nil, fmt.Errorf("database %s not found in session", t.db.Name()) } - ws := dbState.WorkingSet + ws := dbState.WorkingSet() if ws == nil { return nil, doltdb.ErrOperationNotSupportedInDetachedHead } - head, err := sess.GetHeadCommit(ctx, t.db.Name()) + head, err := sess.GetHeadCommit(ctx, t.db.RevisionQualifiedName()) if err != nil { return nil, err } @@ -1366,7 +1365,7 @@ func (t *AlterableDoltTable) RewriteInserter( // We can't just call getTableEditor here because it uses the session state, which we can't update until after the // rewrite operation - opts := dbState.WriteSession.GetOptions() + opts := dbState.WriteSession().GetOptions() opts.ForeignKeyChecksDisabled = true newRoot, err := ws.WorkingRoot().PutTable(ctx, t.Name(), dt) @@ -1391,7 +1390,7 @@ func (t *AlterableDoltTable) RewriteInserter( } writeSession := writer.NewWriteSession(dt.Format(), newWs, ait, opts) - ed, err := writeSession.GetTableWriter(ctx, t.Name(), t.db.Name(), sess.SetRoot, false) + ed, err := writeSession.GetTableWriter(ctx, t.Name(), t.db.RevisionQualifiedName(), sess.SetRoot, false) if err != nil { return nil, err } @@ -1614,7 +1613,7 @@ func (t *AlterableDoltTable) dropColumnData(ctx *sql.Context, updatedTable *dolt // ModifyColumn implements sql.AlterableTable. ModifyColumn operations are only used for operations that change only // the schema of a table, not the data. For those operations, |RewriteInserter| is used. func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, column *sql.Column, order *sql.ColumnOrder) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return err } ws, err := t.db.GetWorkingSet(ctx) @@ -1681,7 +1680,7 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c if existingCol.AutoIncrement && !col.AutoIncrement { // TODO: this isn't transactional, and it should be sess := dsess.DSessFromSess(ctx.Session) - ddb, _ := sess.GetDoltDB(ctx, t.db.name) + ddb, _ := sess.GetDoltDB(ctx, t.db.RevisionQualifiedName()) err = t.db.removeTableFromAutoIncrementTracker(ctx, t.Name(), ddb, ws.Ref()) if err != nil { return err @@ -1781,7 +1780,7 @@ func allocatePrefixLengths(idxCols []sql.IndexColumn) []uint16 { // CreateIndex implements sql.IndexAlterableTable func (t *AlterableDoltTable) CreateIndex(ctx *sql.Context, idx sql.IndexDef) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return err } if idx.Constraint != sql.IndexConstraint_None && idx.Constraint != sql.IndexConstraint_Unique && idx.Constraint != sql.IndexConstraint_Spatial { @@ -1856,7 +1855,7 @@ func (t *AlterableDoltTable) CreateIndex(ctx *sql.Context, idx sql.IndexDef) err // DropIndex implements sql.IndexAlterableTable func (t *AlterableDoltTable) DropIndex(ctx *sql.Context, indexName string) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return err } // We disallow removing internal dolt_ tables from SQL directly @@ -1884,7 +1883,7 @@ func (t *AlterableDoltTable) DropIndex(ctx *sql.Context, indexName string) error // RenameIndex implements sql.IndexAlterableTable func (t *AlterableDoltTable) RenameIndex(ctx *sql.Context, fromIndexName string, toIndexName string) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return err } // RenameIndex will error if there is a name collision or an index does not exist @@ -2022,7 +2021,7 @@ func (t *AlterableDoltTable) createForeignKey( // AddForeignKey implements sql.ForeignKeyTable func (t *AlterableDoltTable) AddForeignKey(ctx *sql.Context, sqlFk sql.ForeignKeyConstraint) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return err } if sqlFk.Name != "" && !doltdb.IsValidForeignKeyName(sqlFk.Name) { @@ -2080,7 +2079,7 @@ func (t *AlterableDoltTable) AddForeignKey(ctx *sql.Context, sqlFk sql.ForeignKe // DropForeignKey implements sql.ForeignKeyTable func (t *AlterableDoltTable) DropForeignKey(ctx *sql.Context, fkName string) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return err } root, err := t.getRoot(ctx) @@ -2108,7 +2107,7 @@ func (t *AlterableDoltTable) DropForeignKey(ctx *sql.Context, fkName string) err // UpdateForeignKey implements sql.ForeignKeyTable func (t *AlterableDoltTable) UpdateForeignKey(ctx *sql.Context, fkName string, sqlFk sql.ForeignKeyConstraint) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return err } root, err := t.getRoot(ctx) @@ -2384,9 +2383,9 @@ func (t *AlterableDoltTable) updateFromRoot(ctx *sql.Context, root *doltdb.RootV // When we update this table we need to also clear any cached versions of the object, since they may now have // incorrect schema information sess := dsess.DSessFromSess(ctx.Session) - dbState, ok, err := sess.LookupDbState(ctx, t.db.name) + dbState, ok, err := sess.LookupDbState(ctx, t.db.RevisionQualifiedName()) if !ok { - return fmt.Errorf("no db state found for %s", t.db.name) + return fmt.Errorf("no db state found for %s", t.db.RevisionQualifiedName()) } dbState.SessionCache().ClearTableCache() @@ -2395,7 +2394,7 @@ func (t *AlterableDoltTable) updateFromRoot(ctx *sql.Context, root *doltdb.RootV } func (t *AlterableDoltTable) CreateCheck(ctx *sql.Context, check *sql.CheckDefinition) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return err } root, err := t.getRoot(ctx) @@ -2451,7 +2450,7 @@ func (t *AlterableDoltTable) CreateCheck(ctx *sql.Context, check *sql.CheckDefin } func (t *AlterableDoltTable) DropCheck(ctx *sql.Context, chName string) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return err } root, err := t.getRoot(ctx) @@ -2502,7 +2501,7 @@ func (t *AlterableDoltTable) ModifyStoredCollation(ctx *sql.Context, collation s } func (t *AlterableDoltTable) ModifyDefaultCollation(ctx *sql.Context, collation sql.CollationID) error { - if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { + if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil { return err } root, err := t.getRoot(ctx) diff --git a/go/libraries/doltcore/sqle/temp_table.go b/go/libraries/doltcore/sqle/temp_table.go index 2795e309e3..652f808903 100644 --- a/go/libraries/doltcore/sqle/temp_table.go +++ b/go/libraries/doltcore/sqle/temp_table.go @@ -78,7 +78,7 @@ func NewTempTable( return nil, fmt.Errorf("database %s not found in session", db) } - ws := dbState.WorkingSet + ws := dbState.WorkingSet() if ws == nil { return nil, doltdb.ErrOperationNotSupportedInDetachedHead } @@ -155,7 +155,7 @@ func setTempTableRoot(t *TempTable) func(ctx *sql.Context, dbName string, newRoo return fmt.Errorf("database %s not found in session", t.dbName) } - ws := dbState.WorkingSet + ws := dbState.WorkingSet() if ws == nil { return doltdb.ErrOperationNotSupportedInDetachedHead } diff --git a/go/libraries/doltcore/sqle/user_space_database.go b/go/libraries/doltcore/sqle/user_space_database.go index 285e090767..721429a787 100644 --- a/go/libraries/doltcore/sqle/user_space_database.go +++ b/go/libraries/doltcore/sqle/user_space_database.go @@ -76,7 +76,7 @@ func (db *UserSpaceDatabase) GetTableNames(ctx *sql.Context) ([]string, error) { return resultingTblNames, nil } -func (db *UserSpaceDatabase) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) { +func (db *UserSpaceDatabase) InitialDBState(ctx *sql.Context) (dsess.InitialDbState, error) { return dsess.InitialDbState{ Db: db, ReadOnly: true, @@ -111,10 +111,18 @@ func (db *UserSpaceDatabase) Revision() string { return "" } +func (db *UserSpaceDatabase) Versioned() bool { + return false +} + func (db *UserSpaceDatabase) RevisionType() dsess.RevisionType { return dsess.RevisionTypeNone } -func (db *UserSpaceDatabase) BaseName() string { +func (db *UserSpaceDatabase) RevisionQualifiedName() string { + return db.Name() +} + +func (db *UserSpaceDatabase) RequestedName() string { return db.Name() } diff --git a/integration-tests/bats/config.bats b/integration-tests/bats/config.bats index 118e4f23dd..9b815cb515 100644 --- a/integration-tests/bats/config.bats +++ b/integration-tests/bats/config.bats @@ -189,7 +189,7 @@ teardown() { } @test "config: SQL can create databases with no user and email set" { - dolt sql -b -q " + dolt sql -q " CREATE DATABASE testdb; use testdb; CREATE TABLE test (pk int primary key, c1 varchar(1));" diff --git a/integration-tests/bats/db-revision-specifiers.bats b/integration-tests/bats/db-revision-specifiers.bats index d27f8de8eb..7342bc1428 100644 --- a/integration-tests/bats/db-revision-specifiers.bats +++ b/integration-tests/bats/db-revision-specifiers.bats @@ -44,7 +44,7 @@ use $database_name/branch1; show databases; SQL [ "$status" -eq "0" ] - [[ "$output" =~ "$database_name/branch1" ]] || false + [[ "$output" =~ "$database_name" ]] || false # Can be used as part of a fully qualified table name run dolt sql -q "SELECT * FROM \`$database_name/branch1\`.test" -r=csv @@ -80,7 +80,7 @@ use $database_name/v1; show databases; SQL [ "$status" -eq "0" ] - [[ "$output" =~ "$database_name/v1" ]] || false + [[ "$output" =~ "$database_name" ]] || false # Can be used as part of a fully qualified table name run dolt sql -q "SELECT * FROM \`$database_name/v1\`.test" -r=csv @@ -95,7 +95,7 @@ use $database_name/v1; insert into test values (100, 'beige'); SQL [ "$status" -ne "0" ] - [[ "$output" =~ "$database_name/v1 is read-only" ]] || false + [[ "$output" =~ "$database_name is read-only" ]] || false } @test "db-revision-specifiers: commit-qualified database revisions" { @@ -118,7 +118,7 @@ use $database_name/$commit; show databases; SQL [ "$status" -eq "0" ] - [[ "$output" =~ "$database_name/$commit" ]] || false + [[ "$output" =~ "$database_name" ]] || false # Can be used as part of a fully qualified table name run dolt sql -q "SELECT * FROM \`$database_name/$commit\`.test" -r=csv @@ -133,5 +133,5 @@ use $database_name/$commit; insert into test values (100, 'beige'); SQL [ "$status" -ne "0" ] - [[ "$output" =~ "$database_name/$commit is read-only" ]] || false + [[ "$output" =~ "$database_name is read-only" ]] || false } diff --git a/integration-tests/bats/deleted-branches.bats b/integration-tests/bats/deleted-branches.bats index 25cb41bf34..adf9bc3a4e 100644 --- a/integration-tests/bats/deleted-branches.bats +++ b/integration-tests/bats/deleted-branches.bats @@ -62,7 +62,7 @@ force_delete_main_branch_on_sqlserver() { run dolt sql-client --use-db "dolt_repo_$$" -u dolt -P $PORT \ -q "call dolt_checkout('to_keep');" [ $status -ne 0 ] - [[ "$output" =~ "branch not found" ]] || false + [[ "$output" =~ "database not found" ]] || false } @test "deleted-branches: dolt branch from the CLI does not allow deleting the last branch" { @@ -144,10 +144,11 @@ force_delete_main_branch_on_sqlserver() { # We are able to use a database branch revision in the connection string dolt sql-client --use-db "dolt_repo_$$/main" -u dolt -P $PORT -q "SELECT * FROM test;" - # Trying to checkout a new branch throws an error, but doesn't panic - run dolt sql-client --use-db "dolt_repo_$$/main" -u dolt -P $PORT -q "CALL DOLT_CHECKOUT('to_keep');" - [ $status -ne 0 ] - [[ "$output" =~ "branch not found" ]] || false + # Trying to checkout a new branch works + dolt sql-client --use-db "dolt_repo_$$/main" -u dolt -P $PORT -q "CALL DOLT_CHECKOUT('to_keep');" + + run dolt branch + [[ "$output" =~ "to_keep" ]] || false } @test "deleted-branches: dolt_checkout() from sql-server doesn't panic when connected to a revision db and the db's default branch is invalid" { @@ -159,10 +160,11 @@ force_delete_main_branch_on_sqlserver() { # We are able to use a database branch revision in the connection string dolt sql-client --use-db "dolt_repo_$$/to_keep" -u dolt -P $PORT -q "SELECT * FROM test;" - # Trying to checkout a new branch throws an error, but doesn't panic - run dolt sql-client --use-db "dolt_repo_$$/to_keep" -u dolt -P $PORT -q "CALL DOLT_CHECKOUT('to_checkout');" - [ $status -ne 0 ] - [[ "$output" =~ "branch not found" ]] || false + # Trying to checkout a new branch works + dolt sql-client --use-db "dolt_repo_$$/to_keep" -u dolt -P $PORT -q "CALL DOLT_CHECKOUT('to_checkout');" + + run dolt branch + [[ "$output" =~ "to_checkout" ]] || false } @test "deleted-branches: dolt_checkout() from sql-server works when the db's default branch is invalid, but the global default_branch var is valid" { diff --git a/integration-tests/bats/remotes-sql-server.bats b/integration-tests/bats/remotes-sql-server.bats index 4adfa20726..bd43f97191 100644 --- a/integration-tests/bats/remotes-sql-server.bats +++ b/integration-tests/bats/remotes-sql-server.bats @@ -332,7 +332,8 @@ teardown() { cd repo1 dolt checkout -b feature-branch - dolt commit -am "new commit" + dolt sql -q "create table newTable (a int primary key)" + dolt commit -Am "new commit" dolt push remote1 feature-branch cd ../repo2 @@ -347,8 +348,7 @@ teardown() { # Can't connect to a specific branch with dolt sql-client run dolt sql-client --use-db "repo2/feature-branch" -u dolt -P $PORT -q "SHOW Tables" [ $status -eq 0 ] - [[ $output =~ "feature-branch" ]] || false - [[ $output =~ "test" ]] || false + [[ $output =~ "newTable" ]] || false } @test "remotes-sql-server: connect to hash works" { @@ -502,7 +502,7 @@ teardown() { run dolt sql-client --use-db repo2/feature -P $PORT -u dolt -q "select active_branch()" [ $status -eq 1 ] - [[ "$output" =~ "database not found: repo2/feature" ]] || false + [[ "$output" =~ "'feature' matched multiple remote tracking branches" ]] || false run grep "'feature' matched multiple remote tracking branches" server_log.txt [ "${#lines[@]}" -ne 0 ] diff --git a/integration-tests/go-sql-server-driver/main_test.go b/integration-tests/go-sql-server-driver/main_test.go index cf89dd562c..f81d6084ab 100644 --- a/integration-tests/go-sql-server-driver/main_test.go +++ b/integration-tests/go-sql-server-driver/main_test.go @@ -28,6 +28,13 @@ func TestCluster(t *testing.T) { RunTestsFile(t, "tests/sql-server-cluster.yaml") } +// TestSingle is a convenience method for running a single test from within an IDE. Unskip and set to the file and name +// of the test you want to debug. See README.md in the `tests` directory for more debugging info. +func TestSingle(t *testing.T) { + // t.Skip() + RunSingleTest(t, "tests/sql-server-cluster.yaml", "primary comes up and replicates to standby") +} + func TestClusterTLS(t *testing.T) { RunTestsFile(t, "tests/sql-server-cluster-tls.yaml") } diff --git a/integration-tests/go-sql-server-driver/testdef.go b/integration-tests/go-sql-server-driver/testdef.go index e0ff885098..aa93b8b277 100644 --- a/integration-tests/go-sql-server-driver/testdef.go +++ b/integration-tests/go-sql-server-driver/testdef.go @@ -22,7 +22,7 @@ import ( "time" "database/sql" - + driver "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/sql_server_driver" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -46,6 +46,17 @@ type Test struct { Skip string `yaml:"skip"` } +// Set this environment variable to effectively disable timeouts for debugging. +const debugEnvKey = "DOLT_SQL_SERVER_TEST_DEBUG" +var timeout = 20 * time.Second + +func init() { + _, ok := os.LookupEnv(debugEnvKey) + if ok { + timeout = 1000 * time.Hour + } +} + func ParseTestsFile(path string) (TestDef, error) { contents, err := os.ReadFile(path) if err != nil { @@ -78,7 +89,15 @@ func MakeServer(t *testing.T, dc driver.DoltCmdable, s *driver.Server) *driver.S if s.Port != 0 { opts = append(opts, driver.WithPort(s.Port)) } - server, err := driver.StartSqlServer(dc, opts...) + + var server *driver.SqlServer + var err error + if s.DebugPort != 0 { + server, err = driver.DebugSqlServer(dc, s.DebugPort, opts...) + } else { + server, err = driver.StartSqlServer(dc, opts...) + } + require.NoError(t, err) if len(s.ErrorMatches) > 0 { err := server.ErrorStop() @@ -198,6 +217,16 @@ func RunTestsFile(t *testing.T, path string) { } } +func RunSingleTest(t *testing.T, path string, testName string) { + def, err := ParseTestsFile(path) + require.NoError(t, err) + for _, test := range def.Tests { + if test.Name == testName { + t.Run(test.Name, test.Run) + } + } +} + type retryTestingT struct { *testing.T errorfStrings []string @@ -268,7 +297,7 @@ func RunQueryAttempt(t require.TestingT, conn *sql.Conn, q driver.Query) { args[i] = q.Args[i] } if q.Query != "" { - ctx, c := context.WithTimeout(context.Background(), 20*time.Second) + ctx, c := context.WithTimeout(context.Background(), timeout) defer c() rows, err := conn.QueryContext(ctx, q.Query, args...) if err == nil { @@ -291,7 +320,7 @@ func RunQueryAttempt(t require.TestingT, conn *sql.Conn, q driver.Query) { require.Contains(t, *q.Result.Rows.Or, rowstrings) } } else if q.Exec != "" { - ctx, c := context.WithTimeout(context.Background(), 20*time.Second) + ctx, c := context.WithTimeout(context.Background(), timeout) defer c() _, err := conn.ExecContext(ctx, q.Exec, args...) if q.ErrorMatch == "" { diff --git a/integration-tests/go-sql-server-driver/tests/README.md b/integration-tests/go-sql-server-driver/tests/README.md new file mode 100644 index 0000000000..a4f2776c74 --- /dev/null +++ b/integration-tests/go-sql-server-driver/tests/README.md @@ -0,0 +1,78 @@ +# SQL server tests + +These are the definitions for tests that spin up multiple sql-server instances +to test things aren't easily testable otherwise, such as replication. They're +defined as YAML and run with a custom test runner, but they're just golang unit +tests you can run in your IDE. + +These are difficult to debug because they start multiple separate processes, but +there's support for attaching a debugger to make it possible. + +# Debugging a test + +First set `DOLT_SQL_SERVER_TEST_DEBUG` in your environment. This will increase +the timeouts for queries and other processes to give you time to debug. (Don't +set this when not debugging, as it will lead to failing tests that hang instead +of promptly failing). + +Next, find the test you want to debug and reference it in `TestSingle`, like so: + +```go +func TestSingle(t *testing.T) { +// t.Skip() + RunSingleTest(t, "tests/sql-server-cluster.yaml", "primary comes up and replicates to standby") +} +``` + +Then edit the test to add a `debug_port` for any servers you want to connect to. + +```yaml +- name: primary comes up and replicates to standby + multi_repos: + - name: server1 + ... + server: + args: ["--port", "3309"] + port: 3309 + debug_port: 4009 +``` + +When the test is run, the `sql-server` process will wait for the remote debugger +to connect before starting. You probably want to enable this on every server in +the test definition. Use a different port for each. + +In your IDE, set up N+1 run configurations: one for each of the N servers in the +test, and 1 to run the test itself. Follow the instructions here to create a new +remote debug configuration for each server, using the ports you defined in the +YAML file. + +https://www.jetbrains.com/help/go/attach-to-running-go-processes-with-debugger.html#step-3-create-the-remote-run-debug-configuration-on-the-client-computer + +The main test should be something like +`github.com/dolthub/dolt/integration-tests/go-sql-server-driver#TestSingle`, and +this is where you want to set the `DOLT_SQL_SERVER_TEST_DEBUG` environment +variable if you don't have it set in your main environment. + +Then Run or Debug the main test (either works fine), and wait for the console +output that indicates the server is waiting for the debugger to attach: + +``` +API server listening at: [::]:4009 +``` + +Then Debug the remote-debug configuration(s) you have set up. They should +connect to one of the running server processes, at which point they will +continue execution. Breakpoints and other debugger features should work as +normal. + +# Caveats and gotchas + +* The `dolt` binary run by these tests is whatever is found on your `$PATH`. If + you make changes locally, you need to rebuild that binary to see them + reflected. +* For debugging support, `dlv` needs to be on your `$PATH` as well. +* Some tests restart the server. When this happens, they will once again wait + for a debugger to connect. You'll need to re-invoke the appropriate + remote-debugger connection for the process to continue. +* These tests are expected to work on Windows as well. Just have `dolt.exe` and + `dlv.exe` on your windows `%PATH%` and it should all work. diff --git a/integration-tests/mysql-client-tests/node/workbenchTests/databases.js b/integration-tests/mysql-client-tests/node/workbenchTests/databases.js index ef08912a4f..4d78d06757 100644 --- a/integration-tests/mysql-client-tests/node/workbenchTests/databases.js +++ b/integration-tests/mysql-client-tests/node/workbenchTests/databases.js @@ -19,7 +19,6 @@ export const databaseTests = [ q: `SHOW DATABASES`, res: [ { Database: `${dbName}` }, - { Database: `${dbName}/main` }, { Database: "information_schema" }, { Database: "mysql" }, ], @@ -40,7 +39,6 @@ export const databaseTests = [ q: `SHOW DATABASES`, res: [ { Database: `${dbName}` }, - { Database: `${dbName}/main` }, { Database: "information_schema" }, { Database: "mysql" }, { Database: "new_db" }, diff --git a/integration-tests/mysql-client-tests/node/workbenchTests/table.js b/integration-tests/mysql-client-tests/node/workbenchTests/table.js index c9972e5738..332aec248b 100644 --- a/integration-tests/mysql-client-tests/node/workbenchTests/table.js +++ b/integration-tests/mysql-client-tests/node/workbenchTests/table.js @@ -74,7 +74,7 @@ export const tableTests = [ FROM information_schema.statistics WHERE table_schema=:tableSchema AND table_name=:tableName AND index_name!="PRIMARY" GROUP BY index_name;`, - p: { tableSchema: `${dbName}/main`, tableName: "test" }, + p: { tableSchema: `${dbName}`, tableName: "test" }, res: [ { TABLE_NAME: "test", @@ -122,19 +122,19 @@ export const tableTests = [ }, { q: `SELECT * FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE WHERE table_name=:tableName AND table_schema=:tableSchema AND referenced_table_schema IS NOT NULL`, - p: { tableName: "test_info", tableSchema: `${dbName}/main` }, + p: { tableName: "test_info", tableSchema: `${dbName}` }, res: [ { CONSTRAINT_CATALOG: "def", - CONSTRAINT_SCHEMA: `${dbName}/main`, + CONSTRAINT_SCHEMA: `${dbName}`, CONSTRAINT_NAME: "s7utamh8", TABLE_CATALOG: "def", - TABLE_SCHEMA: `${dbName}/main`, + TABLE_SCHEMA: `${dbName}`, TABLE_NAME: "test_info", COLUMN_NAME: "test_pk", ORDINAL_POSITION: 1, POSITION_IN_UNIQUE_CONSTRAINT: 1, - REFERENCED_TABLE_SCHEMA: `${dbName}/main`, + REFERENCED_TABLE_SCHEMA: `${dbName}`, REFERENCED_TABLE_NAME: "test", REFERENCED_COLUMN_NAME: "pk", },