From b02fcd2fe984a7ccf19e30e4f763c129b255ba6c Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Thu, 23 May 2024 12:55:28 -0700 Subject: [PATCH] Moved search path functionality to its package --- go/libraries/doltcore/doltdb/root_val.go | 19 +++ go/libraries/doltcore/sqle/database.go | 136 ++---------------- .../doltcore/sqle/database_provider.go | 7 +- .../doltcore/sqle/dprocedures/dolt_add.go | 4 + .../doltcore/sqle/search_path/search_path.go | 117 +++++++++++++++ 5 files changed, 152 insertions(+), 131 deletions(-) create mode 100755 go/libraries/doltcore/sqle/search_path/search_path.go diff --git a/go/libraries/doltcore/doltdb/root_val.go b/go/libraries/doltcore/doltdb/root_val.go index 698047f9be..4050ca8832 100644 --- a/go/libraries/doltcore/doltdb/root_val.go +++ b/go/libraries/doltcore/doltdb/root_val.go @@ -23,6 +23,7 @@ import ( "strings" flatbuffers "github.com/dolthub/flatbuffers/v23/go" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/dolt/go/gen/fb/serial" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable" @@ -1211,3 +1212,21 @@ func NewDataCacheKey(rv RootValue) (DataCacheKey, error) { return DataCacheKey{hash}, nil } + +// ResolveDatabaseSchema returns the case-sensitive name for the schema requested, if it exists, and an error if +// schemas could not be loaded. +func ResolveDatabaseSchema(ctx *sql.Context, root RootValue, schemaName string) (string, bool, error) { + schemas, err := root.GetDatabaseSchemas(ctx) + if err != nil { + return "", false, err + } + + for _, databaseSchema := range schemas { + if strings.EqualFold(databaseSchema.Name, schemaName) { + return databaseSchema.Name, true, nil + } + } + + return "", false, nil +} + diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index c778ee30e9..ba9ba21b17 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -22,6 +22,7 @@ import ( "strings" "time" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/search_path" sqle "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/analyzer" @@ -677,9 +678,9 @@ func (db Database) getTable(ctx *sql.Context, root doltdb.RootValue, tableName s } var tbl *doltdb.Table - if UseSearchPath && db.schemaName == "" { + if search_path.UseSearchPath && db.schemaName == "" { var schemaName string - tableName, schemaName, tbl, ok, err = db.resolveTableWithSearchPath(ctx, root, tableName) + tableName, schemaName, tbl, ok, err = search_path.ResolveTableWithSearchPath(ctx, root, tableName) if err != nil { return nil, false, err } else if !ok { @@ -750,65 +751,6 @@ func (db Database) resolveTable(ctx *sql.Context, root doltdb.RootValue, tableNa return tableName, tbl, true, nil } -var defaultSearchPath = "doltgres,public" - -func (db Database) resolveTableWithSearchPath(ctx *sql.Context, root doltdb.RootValue, tableName string) (string, string, *doltdb.Table, bool, error) { - schemasToSearch, err := searchPath(ctx) - if err != nil { - return "", "", nil, false, err - } - - for _, schemaName := range schemasToSearch { - tablesInSchema, err := root.GetTableNames(ctx, schemaName) - if err != nil { - return "", "", nil, false, err - } - - correctedTableName, ok := sql.GetTableNameInsensitive(tableName, tablesInSchema) - if !ok { - continue - } - - // TODO: what schema name do we use for system tables? - tbl, ok, err := root.GetTable(ctx, doltdb.TableName{Name: correctedTableName, Schema: schemaName}) - if err != nil { - return "", "", nil, false, err - } else if !ok { - // Should be impossible - return "", "", nil, false, doltdb.ErrTableNotFound - } - - return correctedTableName, schemaName, tbl, true, nil - } - - return "", "", nil, false, nil -} - -// searchPath returns all the schemas in the search_path setting, with elements like "$user" expanded -func searchPath(ctx *sql.Context) ([]string, error) { - searchPathVar, err := ctx.GetSessionVariable(ctx, "search_path") - if err != nil { - return nil, err - } - - pathElems := strings.Split(searchPathVar.(string), ",") - path := make([]string, len(pathElems)) - for i, pathElem := range pathElems { - path[i] = normalizeSearchPathSchema(ctx, pathElem) - } - - return path, nil -} - -func normalizeSearchPathSchema(ctx *sql.Context, schemaName string) string { - schemaName = strings.Trim(schemaName, " ") - if schemaName == "\"$user\"" { - client := ctx.Session.Client() - return client.User - } - return schemaName -} - // newDoltTable returns a sql.Table wrapping the given underlying dolt table func (db Database) newDoltTable(tableName string, sch schema.Schema, tbl *doltdb.Table) (sql.Table, error) { readonlyTable, err := NewDoltTable(tableName, sch, tbl, db, db.editOpts) @@ -1115,8 +1057,8 @@ func (db Database) createSqlTable(ctx *sql.Context, tableName string, schemaName } root := ws.WorkingRoot() - if UseSearchPath && db.schemaName == "" { - schemaName, err = firstExistingSchemaOnSearchPath(ctx, root) + if search_path.UseSearchPath && db.schemaName == "" { + schemaName, err = search_path.FirstExistingSchemaOnSearchPath(ctx, root) if err != nil { return err } @@ -1158,49 +1100,6 @@ func (db Database) createSqlTable(ctx *sql.Context, tableName string, schemaName return db.createDoltTable(ctx, tableName, schemaName, root, doltSch) } -// firstExistingSchemaOnSearchPath returns the first schema in the search path that exists in the database. -func firstExistingSchemaOnSearchPath(ctx *sql.Context, root doltdb.RootValue) (string, error) { - schemas, err := searchPath(ctx) - if err != nil { - return "", err - } - - schemaName := "" - for _, s := range schemas { - var exists bool - schemaName, exists, err = resolveDatabaseSchema(ctx, root, s) - if err != nil { - return "", err - } - - if exists { - break - } - } - - // No existing schema found in the search_path and none specified in the statement means we can't create the table - if schemaName == "" { - return "", sql.ErrDatabaseNoDatabaseSchemaSelectedCreate.New() - } - - return schemaName, nil -} - -func hasDatabaseSchema(ctx context.Context, root doltdb.RootValue, schemaName string) (bool, error) { - schemas, err := root.GetDatabaseSchemas(ctx) - if err != nil { - return false, err - } - - for _, schema := range schemas { - if strings.EqualFold(schema.Name, schemaName) { - return true, nil - } - } - - return false, nil -} - // createIndexedSqlTable is the private version of createSqlTable. It doesn't enforce any table name checks. func (db Database) createIndexedSqlTable(ctx *sql.Context, tableName string, schemaName string, sch sql.PrimaryKeySchema, idxDef sql.IndexDef, collation sql.CollationID) error { ws, err := db.GetWorkingSet(ctx) @@ -1209,8 +1108,8 @@ func (db Database) createIndexedSqlTable(ctx *sql.Context, tableName string, sch } root := ws.WorkingRoot() - if UseSearchPath && db.schemaName == "" { - schemaName, err = firstExistingSchemaOnSearchPath(ctx, root) + if search_path.UseSearchPath && db.schemaName == "" { + schemaName, err = search_path.FirstExistingSchemaOnSearchPath(ctx, root) if err != nil { return err } @@ -1324,7 +1223,7 @@ func (db Database) CreateSchema(ctx *sql.Context, schemaName string) error { return err } - _, exists, err := resolveDatabaseSchema(ctx, root, schemaName) + _, exists, err := doltdb.ResolveDatabaseSchema(ctx, root, schemaName) if err != nil { return err } @@ -1343,23 +1242,6 @@ func (db Database) CreateSchema(ctx *sql.Context, schemaName string) error { return db.SetRoot(ctx, root) } -// resolveDatabaseSchema returns the case-sensitive name for the schema requested, if it exists, and an error if -// schemas could not be loaded. -func resolveDatabaseSchema(ctx *sql.Context, root doltdb.RootValue, schemaName string) (string, bool, error) { - schemas, err := root.GetDatabaseSchemas(ctx) - if err != nil { - return "", false, err - } - - for _, databaseSchema := range schemas { - if strings.EqualFold(databaseSchema.Name, schemaName) { - return databaseSchema.Name, true, nil - } - } - - return "", false, nil -} - // GetSchema implements sql.SchemaDatabase func (db Database) GetSchema(ctx *sql.Context, schemaName string) (sql.DatabaseSchema, bool, error) { ws, err := db.GetWorkingSet(ctx) @@ -1381,7 +1263,7 @@ func (db Database) GetSchema(ctx *sql.Context, schemaName string) (sql.DatabaseS } // For a temporary backwards compatibility solution, always pretend the public schema exists. - // Should create it explicitly when we create a new db in future. + // We create it explicitly for new databases. if strings.EqualFold(schemaName, "public") { db.schemaName = "public" return db, true, nil diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index 1fe3e96c2b..e7ab140a59 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -23,6 +23,7 @@ import ( "strings" "sync" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/search_path" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/dolt/go/cmd/dolt/cli" @@ -66,8 +67,6 @@ type DoltDatabaseProvider struct { isStandby *bool } -var UseSearchPath = false - var _ sql.DatabaseProvider = (*DoltDatabaseProvider)(nil) var _ sql.SchemaDatabaseProvider = (*DoltDatabaseProvider)(nil) var _ sql.FunctionProvider = (*DoltDatabaseProvider)(nil) @@ -252,7 +251,7 @@ func (p *DoltDatabaseProvider) Database(ctx *sql.Context, name string) (sql.Data func (p *DoltDatabaseProvider) SchemaDatabase(ctx *sql.Context, dbName, schemaName string) (sql.DatabaseSchema, bool, error) { // If search path isn't enabled, this becomes a simple DB lookup on the schema name, which is the qualifier specified // in the query - if !UseSearchPath { + if !search_path.UseSearchPath { database, err := p.Database(ctx, schemaName) return database, err == nil, err } @@ -485,7 +484,7 @@ func (p *DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name str } // If the search path is enabled, we need to create our initial schema object (public is available by default) - if UseSearchPath { + if search_path.UseSearchPath { workingRoot, err := newEnv.WorkingRoot(ctx) if err != nil { return err diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_add.go b/go/libraries/doltcore/sqle/dprocedures/dolt_add.go index 9068f244c0..3d2292a6cb 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_add.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_add.go @@ -24,6 +24,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/search_path" "github.com/dolthub/go-mysql-server/sql" ) @@ -92,6 +93,9 @@ func doDoltAdd(ctx *sql.Context, args []string) (int, error) { } // TODO: schema name + if search_path.UseSearchPath { + + } roots, err = actions.StageTables(ctx, roots, doltdb.ToTableNames(apr.Args, doltdb.DefaultSchemaName), !apr.Contains(cli.ForceFlag)) if err != nil { return 1, err diff --git a/go/libraries/doltcore/sqle/search_path/search_path.go b/go/libraries/doltcore/sqle/search_path/search_path.go new file mode 100755 index 0000000000..a302c4d285 --- /dev/null +++ b/go/libraries/doltcore/sqle/search_path/search_path.go @@ -0,0 +1,117 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package search_path + +import ( + "strings" + + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/go-mysql-server/sql" +) + +// UseSearchPath is a global variable that determines whether or not to use the search path when resolving table names. +// Currently used by Doltgres +var UseSearchPath = false + +// ResolveTableWithSearchPath resolves a table name to a table in the root value, searching through the schemas in the +func ResolveTableWithSearchPath( + ctx *sql.Context, + root doltdb.RootValue, + tableName string, +) (string, string, *doltdb.Table, bool, error) { + schemasToSearch, err := SearchPath(ctx) + if err != nil { + return "", "", nil, false, err + } + + for _, schemaName := range schemasToSearch { + tablesInSchema, err := root.GetTableNames(ctx, schemaName) + if err != nil { + return "", "", nil, false, err + } + + correctedTableName, ok := sql.GetTableNameInsensitive(tableName, tablesInSchema) + if !ok { + continue + } + + // TODO: what schema name do we use for system tables? + tbl, ok, err := root.GetTable(ctx, doltdb.TableName{Name: correctedTableName, Schema: schemaName}) + if err != nil { + return "", "", nil, false, err + } else if !ok { + // Should be impossible + return "", "", nil, false, doltdb.ErrTableNotFound + } + + return correctedTableName, schemaName, tbl, true, nil + } + + return "", "", nil, false, nil +} + +// SearchPath returns all the schemas in the search_path setting, with elements like "$user" expanded +func SearchPath(ctx *sql.Context) ([]string, error) { + searchPathVar, err := ctx.GetSessionVariable(ctx, "search_path") + if err != nil { + return nil, err + } + + pathElems := strings.Split(searchPathVar.(string), ",") + path := make([]string, len(pathElems)) + for i, pathElem := range pathElems { + path[i] = normalizeSearchPathSchema(ctx, pathElem) + } + + return path, nil +} + +func normalizeSearchPathSchema(ctx *sql.Context, schemaName string) string { + schemaName = strings.Trim(schemaName, " ") + if schemaName == "\"$user\"" { + client := ctx.Session.Client() + return client.User + } + return schemaName +} + +// FirstExistingSchemaOnSearchPath returns the first schema in the search path that exists in the database. +func FirstExistingSchemaOnSearchPath(ctx *sql.Context, root doltdb.RootValue) (string, error) { + schemas, err := SearchPath(ctx) + if err != nil { + return "", err + } + + schemaName := "" + for _, s := range schemas { + var exists bool + schemaName, exists, err = doltdb.ResolveDatabaseSchema(ctx, root, s) + if err != nil { + return "", err + } + + if exists { + break + } + } + + // No existing schema found in the search_path and none specified in the statement means we can't create the table + if schemaName == "" { + return "", sql.ErrDatabaseNoDatabaseSchemaSelectedCreate.New() + } + + return schemaName, nil +} +