diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go index 4f2c9e7d6f..2795e78e14 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go @@ -208,6 +208,68 @@ func TestScripts(t *testing.T) { enginetest.TestScripts(t, newDoltHarness(t).WithSkippedQueries(skipped)) } +// TestDoltUserPrivileges tests Dolt-specific code that needs to handle user privilege checking +func TestDoltUserPrivileges(t *testing.T) { + harness := newDoltHarness(t) + for _, script := range DoltUserPrivTests { + t.Run(script.Name, func(t *testing.T) { + myDb := harness.NewDatabase("mydb") + databases := []sql.Database{myDb} + engine := enginetest.NewEngineWithDbs(t, harness, databases) + defer engine.Close() + + ctx := enginetest.NewContextWithClient(harness, sql.Client{ + User: "root", + Address: "localhost", + }) + engine.Analyzer.Catalog.GrantTables.AddRootAccount() + + for _, statement := range script.SetUpScript { + if sh, ok := interface{}(harness).(enginetest.SkippingHarness); ok { + if sh.SkipQueryTest(statement) { + t.Skip() + } + } + enginetest.RunQueryWithContext(t, engine, ctx, statement) + } + for _, assertion := range script.Assertions { + if sh, ok := interface{}(harness).(enginetest.SkippingHarness); ok { + if sh.SkipQueryTest(assertion.Query) { + t.Skipf("Skipping query %s", assertion.Query) + } + } + + user := assertion.User + host := assertion.Host + if user == "" { + user = "root" + } + if host == "" { + host = "localhost" + } + ctx := enginetest.NewContextWithClient(harness, sql.Client{ + User: user, + Address: host, + }) + + if assertion.ExpectedErr != nil { + t.Run(assertion.Query, func(t *testing.T) { + enginetest.AssertErrWithCtx(t, engine, ctx, assertion.Query, assertion.ExpectedErr) + }) + } else if assertion.ExpectedErrStr != "" { + t.Run(assertion.Query, func(t *testing.T) { + enginetest.AssertErrWithCtx(t, engine, ctx, assertion.Query, nil, assertion.ExpectedErrStr) + }) + } else { + t.Run(assertion.Query, func(t *testing.T) { + enginetest.TestQueryWithContext(t, ctx, engine, assertion.Query, assertion.Expected, nil, nil) + }) + } + } + }) + } +} + func TestUserPrivileges(t *testing.T) { enginetest.TestUserPrivileges(t, newDoltHarness(t)) } diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go index 88000181cd..92f5931ceb 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go @@ -94,6 +94,121 @@ var DoltScripts = []enginetest.ScriptTest{ }, } +// DoltUserPrivTests are tests for Dolt-specific functionality that includes privilege checking logic. +var DoltUserPrivTests = []enginetest.UserPrivilegeTest{ + { + Name: "dolt_diff table function privilege checking", + SetUpScript: []string{ + "CREATE TABLE mydb.test (pk BIGINT PRIMARY KEY);", + "CREATE TABLE mydb.test2 (pk BIGINT PRIMARY KEY);", + "SELECT DOLT_COMMIT('-am', 'creating tables test and test2');", + "INSERT INTO mydb.test VALUES (1);", + "SELECT DOLT_COMMIT('-am', 'inserting into test');", + "CREATE USER tester@localhost;", + }, + Assertions: []enginetest.UserPrivilegeTestAssertion{ + { + // Without access to the database, dolt_diff should fail with a database access error + User: "tester", + Host: "localhost", + Query: "SELECT * FROM dolt_diff('test', 'main~', 'main');", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + // Grant single-table access to the underlying user table + User: "root", + Host: "localhost", + Query: "GRANT SELECT ON mydb.test TO tester@localhost;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + // After granting access to mydb.test, dolt_diff should work + User: "tester", + Host: "localhost", + Query: "SELECT COUNT(*) FROM dolt_diff('test', 'main~', 'main');", + Expected: []sql.Row{{1}}, + }, + { + // With access to the db, but not the table, dolt_diff should fail + User: "tester", + Host: "localhost", + Query: "SELECT * FROM dolt_diff('test2', 'main~', 'main');", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + // Revoke select on mydb.test + User: "root", + Host: "localhost", + Query: "REVOKE SELECT ON mydb.test from tester@localhost;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + // After revoking access, dolt_diff should fail + User: "tester", + Host: "localhost", + Query: "SELECT * FROM dolt_diff('test', 'main~', 'main');", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + // Grant multi-table access for all of mydb + User: "root", + Host: "localhost", + Query: "GRANT SELECT ON mydb.* to tester@localhost;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + // After granting access to the entire db, dolt_diff should work + User: "tester", + Host: "localhost", + Query: "SELECT COUNT(*) FROM dolt_diff('test', 'main~', 'main');", + Expected: []sql.Row{{1}}, + }, + { + // Revoke multi-table access + User: "root", + Host: "localhost", + Query: "REVOKE SELECT ON mydb.* from tester@localhost;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + // After revoking access, dolt_diff should fail + User: "tester", + Host: "localhost", + Query: "SELECT * FROM dolt_diff('test', 'main~', 'main');", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + { + // Grant global access to *.* + User: "root", + Host: "localhost", + Query: "GRANT SELECT ON *.* to tester@localhost;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + // After granting global access to *.*, dolt_diff should work + User: "tester", + Host: "localhost", + Query: "SELECT COUNT(*) FROM dolt_diff('test', 'main~', 'main');", + Expected: []sql.Row{{1}}, + }, + { + // Revoke global access + User: "root", + Host: "localhost", + Query: "REVOKE ALL ON *.* from tester@localhost;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + // After revoking global access, dolt_diff should fail + User: "tester", + Host: "localhost", + Query: "SELECT * FROM dolt_diff('test', 'main~', 'main');", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, + }, + }, +} + var HistorySystemTableScriptTests = []enginetest.ScriptTest{ { Name: "empty table",