fix normal db compatibility that use @ character

This commit is contained in:
elianddb
2026-02-18 16:41:28 -08:00
parent fa4da6b6d8
commit 6a0b0fc895
12 changed files with 543 additions and 148 deletions
+1
View File
@@ -26,4 +26,5 @@ CLAUDE.md
.gitattributes
.de/
.cursor
AGENTS.md
+124
View File
@@ -0,0 +1,124 @@
// Copyright 2026 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package prompt
import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
)
// Context contains shell prompt database and revision values.
type Context struct {
BaseDatabase string
ActiveRevision string
}
// Resolver resolves prompt context for the active SQL session. The bool return indicates whether context was resolved.
type Resolver interface {
Resolve(sqlCtx *sql.Context, queryist cli.Queryist) (Context, bool, error)
}
type doltSystemVariablesResolver struct{}
type sqlDBActiveBranchResolver struct{}
type chainedResolver struct {
resolvers []Resolver
}
// NewResolver returns prompt context resolution with a variables-first strategy and a legacy fallback.
func NewResolver() Resolver {
return chainedResolver{
resolvers: []Resolver{
doltSystemVariablesResolver{},
sqlDBActiveBranchResolver{},
},
}
}
func (p chainedResolver) Resolve(sqlCtx *sql.Context, queryist cli.Queryist) (Context, bool, error) {
for _, resolver := range p.resolvers {
context, ok, err := resolver.Resolve(sqlCtx, queryist)
if err != nil {
return Context{}, false, err
}
if ok {
return context, true, nil
}
}
return Context{}, false, nil
}
func (doltSystemVariablesResolver) Resolve(sqlCtx *sql.Context, queryist cli.Queryist) (Context, bool, error) {
variableValues, err := cli.GetSystemVariableValues(queryist, sqlCtx, dsess.DoltBaseDatabase, dsess.DoltActiveRevision)
if err != nil {
return Context{}, false, err
}
baseDatabase, hasBase := variableValues[dsess.DoltBaseDatabase]
activeRevision, hasRevision := variableValues[dsess.DoltActiveRevision]
if !hasBase || !hasRevision {
return Context{}, false, nil
}
return Context{
BaseDatabase: baseDatabase,
ActiveRevision: activeRevision,
}, true, nil
}
func (sqlDBActiveBranchResolver) Resolve(sqlCtx *sql.Context, queryist cli.Queryist) (Context, bool, error) {
dbRows, err := cli.GetRowsForSql(queryist, sqlCtx, "select database() as db")
if err != nil {
return Context{}, false, err
}
if len(dbRows) == 0 || dbRows[0] == nil {
return Context{}, true, nil
}
baseDatabase := ""
activeRevision := ""
dbName, err := cli.GetStringColumnValue(dbRows[0])
if err != nil {
return Context{}, false, err
}
baseDatabase, activeRevision = doltdb.SplitRevisionDbName(dbName)
// Revision-qualified names already contain the revision and do not require active_branch().
if activeRevision != "" {
return Context{
BaseDatabase: baseDatabase,
ActiveRevision: activeRevision,
}, true, nil
}
branchRows, err := cli.GetRowsForSql(queryist, sqlCtx, "select active_branch() as branch")
if err != nil {
return Context{}, false, err
}
if len(branchRows) > 0 && branchRows[0] != nil {
activeRevision, err = cli.GetStringColumnValue(branchRows[0])
if err != nil {
return Context{}, false, err
}
}
return Context{
BaseDatabase: baseDatabase,
ActiveRevision: activeRevision,
}, true, nil
}
+54
View File
@@ -16,6 +16,7 @@ package cli
import (
"fmt"
"strings"
"github.com/dolthub/go-mysql-server/sql"
)
@@ -71,6 +72,42 @@ func SetSystemVar(queryist Queryist, sqlCtx *sql.Context, newVal bool) (func() e
return update, err
}
// GetSystemVariableValues returns a map of lower-case variable names to values for all variables that exist on the
// connected server. Variables missing from the result map are not supported by that server version.
func GetSystemVariableValues(queryist Queryist, sqlCtx *sql.Context, variableNames ...string) (values map[string]string, err error) {
values = make(map[string]string, len(variableNames))
var queryBuilder strings.Builder
queryBuilder.WriteString("SHOW VARIABLES WHERE VARIABLE_NAME IN (")
for i, variableName := range variableNames {
queryBuilder.WriteRune('\'')
queryBuilder.WriteString(variableName)
queryBuilder.WriteRune('\'')
if i != len(variableNames)-1 {
queryBuilder.WriteRune(',')
}
}
queryBuilder.WriteRune(')')
rows, err := GetRowsForSql(queryist, sqlCtx, queryBuilder.String())
if err != nil {
return nil, err
}
for _, row := range rows {
name, err := GetStringColumnValue(row[0])
if err != nil {
continue
}
value, err := GetStringColumnValue(row[1])
if err != nil {
continue
}
values[strings.ToLower(name)] = value
}
return values, nil
}
func GetRowsForSql(queryist Queryist, sqlCtx *sql.Context, query string) ([]sql.Row, error) {
_, rowIter, _, err := queryist.Query(sqlCtx, query)
if err != nil {
@@ -83,3 +120,20 @@ func GetRowsForSql(queryist Queryist, sqlCtx *sql.Context, query string) ([]sql.
return rows, nil
}
func GetStringColumnValue(value any) (str string, err error) {
if value == nil {
return "", nil
}
switch v := value.(type) {
case string:
return v, nil
case []byte:
return string(v), nil
case fmt.Stringer:
return v.String(), nil
default:
return "", fmt.Errorf("unexpected type %T, expected string-like column value", value)
}
}
+33 -38
View File
@@ -40,6 +40,7 @@ import (
"gopkg.in/src-d/go-errors.v1"
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/cmd/dolt/cli/prompt"
"github.com/dolthub/dolt/go/cmd/dolt/commands/engine"
"github.com/dolthub/dolt/go/cmd/dolt/errhand"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
@@ -705,7 +706,10 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu
_ = iohelp.WriteLine(cli.CliOut, welcomeMsg)
historyFile := filepath.Join(".sqlhistory") // history file written to working dir
db, branch, _ := getDBBranchFromSession(sqlCtx, qryist)
db, branch, ok := getDBBranchFromSession(sqlCtx, qryist)
if !ok {
return fmt.Errorf("Warning: unable to determine database branch from session")
}
dirty := false
if branch != "" {
dirty, _ = isDirty(sqlCtx, qryist)
@@ -957,12 +961,19 @@ func preprocessQuery(query, lastQuery string, cliCtx cli.CliContext) (CommandTyp
// postCommandUpdate is a helper function that is run after the shell has completed a command. It updates the the database
// if needed, and generates new prompts for the shell (based on the branch and if the workspace is dirty).
func postCommandUpdate(sqlCtx *sql.Context, qryist cli.Queryist) (string, string) {
db, branch, _ := getDBBranchFromSession(sqlCtx, qryist)
dirty := false
if branch != "" {
dirty, _ = isDirty(sqlCtx, qryist)
db, revision, ok := getDBBranchFromSession(sqlCtx, qryist)
if !ok {
cli.PrintErrln(color.YellowString("Failed to resolve database revision."))
}
return formattedPrompts(db, branch, dirty)
dirty := false
if revision != "" {
var err error
dirty, err = isDirty(sqlCtx, qryist)
if err != nil {
cli.PrintErrln(err.Error())
}
}
return formattedPrompts(db, revision, dirty)
}
// formattedPrompts returns the prompt and multiline prompt for the current session. If the db is empty, the prompt will
@@ -994,7 +1005,7 @@ func formattedPrompts(db, branch string, dirty bool) (string, string) {
return fmt.Sprintf("%s/%s%s> ", cyanDb, yellowBr, dirtyStr), multi
}
// getDBBranchFromSession returns the current database name and current branch for the session, handling all the errors
// getDBBranchFromSession returns the current database name and current branch for the session, handling all the errors
// along the way by printing red error messages to the CLI. If there was an issue getting the db name, the ok return
// value will be false and the strings will be empty.
func getDBBranchFromSession(sqlCtx *sql.Context, qryist cli.Queryist) (db string, branch string, ok bool) {
@@ -1004,41 +1015,28 @@ func getDBBranchFromSession(sqlCtx *sql.Context, qryist cli.Queryist) (db string
return "", "", false
}
defer qryist.Query(sqlCtx, "set lock_warnings = 0")
if sqlCtx.Session.GetCurrentDatabase() == "" {
return "", "", true
}
_, resp, _, err := qryist.Query(sqlCtx, "select database() as db, active_branch() as branch")
resolver := prompt.NewResolver()
promptContext, resolved, err := resolver.Resolve(sqlCtx, qryist)
if err != nil {
cli.Println(color.RedString("Failure to get DB Name for session: " + err.Error()))
return db, branch, false
cli.Println(color.RedString("Failed to resolve shell prompt: " + err.Error()))
return "", "", false
}
// Expect single row result, with two columns: db name, branch name.
row, err := resp.Next(sqlCtx)
if err != nil {
cli.Println(color.RedString("Failure to get DB Name for session: " + err.Error()))
return db, branch, false
}
if len(row) != 2 {
cli.Println(color.RedString("Runtime error. Invalid column count."))
return db, branch, false
if !resolved {
return "", "", true
}
if row[1] == nil {
branch = ""
} else {
branch = row[1].(string)
}
if row[0] == nil {
db = ""
} else {
fullName := row[0].(string)
db, _ = doltdb.SplitRevisionDbName(fullName)
}
return db, branch, true
return promptContext.BaseDatabase, promptContext.ActiveRevision, true
}
// isDirty returns true if the workspace is dirty, false otherwise. This function _assumes_ you are on a database
// with a branch. If you are not, you will get an error.
func isDirty(sqlCtx *sql.Context, qryist cli.Queryist) (bool, error) {
const promptDirtyStatusErrPrefix = "Failed to determine shell prompt '*' (dirty) status"
_, _, _, err := qryist.Query(sqlCtx, "set lock_warnings = 1")
if err != nil {
return false, err
@@ -1048,18 +1046,15 @@ func isDirty(sqlCtx *sql.Context, qryist cli.Queryist) (bool, error) {
_, resp, _, err := qryist.Query(sqlCtx, "select count(table_name) > 0 as dirty from dolt_status")
if err != nil {
cli.Println(color.RedString("Failure to get DB Name for session: " + err.Error()))
return false, err
return false, fmt.Errorf("%s: %w", promptDirtyStatusErrPrefix, err)
}
// Expect single row result, with one boolean column.
row, err := resp.Next(sqlCtx)
if err != nil {
cli.Println(color.RedString("Failure to get DB Name for session: " + err.Error()))
return false, err
return false, fmt.Errorf("%s: %w", promptDirtyStatusErrPrefix, err)
}
if len(row) != 1 {
cli.Println(color.RedString("Runtime error. Invalid column count."))
return false, fmt.Errorf("invalid column count")
return false, fmt.Errorf("%s: invalid column count", promptDirtyStatusErrPrefix)
}
return getStrBoolColAsBool(row[0])
+12 -13
View File
@@ -2394,27 +2394,26 @@ func RevisionDbName(baseName string, rev string) string {
return baseName + DbRevisionDelimiter + rev
}
// SplitRevisionDbName inspects the |dbName| for the DbRevisionDelimiter or DbRevisionDelimiterAlias and returns the
// separated base name and revision.
// SplitRevisionDbName returns the base database name and revision from a traditional revision-qualified name. Splits on
// the first "/".
func SplitRevisionDbName(dbName string) (string, string) {
index := strings.IndexAny(dbName, DbRevisionDelimiter+DbRevisionDelimiterAlias)
if index == -1 {
return dbName, ""
if idx := strings.Index(dbName, DbRevisionDelimiter); idx >= 0 {
return dbName[:idx], dbName[idx+1:]
}
return dbName[:index], dbName[index+1:]
return dbName, ""
}
// NormalizeRevisionDelimiter inspects the database name for a DbRevisionDelimiterAlias and if found, rewrites the
// database name to contain the normal DbRevisionDelimiter.
func NormalizeRevisionDelimiter(dbName string) (rewrite string, usesDelimiterAlias bool, err error) {
if !strings.Contains(dbName, DbRevisionDelimiterAlias) {
return dbName, false, nil
// NormalizeRevisionDelimiter rewrites "base@revision" names to "base/revision". Names that already contain "/" are
// returned unchanged so bases that include "@" keep their existing interpretation.
func NormalizeRevisionDelimiter(dbName string) (rewrite string, usesDelimiterAlias bool) {
if strings.Contains(dbName, DbRevisionDelimiter) {
return dbName, false
}
base, revision, found := strings.Cut(dbName, DbRevisionDelimiterAlias)
if !found {
return dbName, true, fmt.Errorf("could not resolve revision delimiter '%s' in %s", DbRevisionDelimiterAlias, dbName)
return dbName, found
}
return RevisionDbName(base, revision), true, nil
return RevisionDbName(base, revision), true
}
+96 -57
View File
@@ -418,7 +418,8 @@ func (p *DoltDatabaseProvider) HasDatabase(ctx *sql.Context, name string) bool {
}
func (p *DoltDatabaseProvider) AllDatabases(ctx *sql.Context) (all []sql.Database) {
_, currentRevision := doltdb.SplitRevisionDbName(ctx.GetCurrentDatabase())
normalized, usesDelimiterAlias := doltdb.NormalizeRevisionDelimiter(ctx.GetCurrentDatabase())
baseName, currentRevision := doltdb.SplitRevisionDbName(normalized)
p.mu.RLock()
showBranches, err := dsess.GetBooleanSystemVar(ctx, dsess.ShowBranchDatabases)
@@ -426,15 +427,31 @@ func (p *DoltDatabaseProvider) AllDatabases(ctx *sql.Context) (all []sql.Databas
ctx.GetLogger().Warn(err)
}
potentialConflictDB, conflictDBOk := p.databases[strings.ToLower(ctx.GetCurrentDatabase())]
_, baseDBOk := p.databases[strings.ToLower(baseName)]
skipConflictDBName := ""
if conflictDBOk && baseDBOk && usesDelimiterAlias {
_, ok, err := p.databaseForRevision(ctx, normalized, ctx.GetCurrentDatabase())
if err == nil && ok {
skipConflictDBName = potentialConflictDB.AliasedName()
}
}
all = make([]sql.Database, 0, len(p.databases))
for _, db := range p.databases {
if skipConflictDBName == db.AliasedName() {
continue
}
all = append(all, db)
if showBranches && db.Name() != clusterdb.DoltClusterDbName {
revisionDbs, err := p.allRevisionDbs(ctx, db)
revisionDbs, err := p.allRevisionDbs(ctx, db, normalized)
if err != nil {
// TODO: this interface is wrong, needs to return errors
ctx.GetLogger().Warnf("error fetching revision databases: %s", err.Error())
if !sql.ErrDatabaseNotFound.Is(err) {
ctx.GetLogger().Warnf("error fetching revision databases: %s", err.Error())
}
continue
}
all = append(all, revisionDbs...)
@@ -445,16 +462,15 @@ func (p *DoltDatabaseProvider) AllDatabases(ctx *sql.Context) (all []sql.Databas
// If there's a revision database in use, include it in the list (but don't double-count). When showBranches is off
// we still include the current revision db if one is in use, so the active database is always visible in SHOW
// DATABASES.
if currentRevision != "" && !showBranches {
rewrittenName, _, err := doltdb.NormalizeRevisionDelimiter(ctx.GetCurrentDatabase())
usingNormalDB := usesDelimiterAlias && conflictDBOk && !baseDBOk
if currentRevision != "" && !showBranches && !usingNormalDB {
rdb, ok, err := p.databaseForRevision(ctx, normalized, ctx.GetCurrentDatabase())
if err != nil {
ctx.GetLogger().Warn(err)
}
rdb, ok, err := p.databaseForRevision(ctx, rewrittenName, ctx.GetCurrentDatabase())
if err != nil || !ok {
// TODO: this interface is wrong, needs to return errors
ctx.GetLogger().Warnf("error fetching revision databases: %s", err.Error())
} else {
if !sql.ErrDatabaseNotFound.Is(err) {
ctx.GetLogger().Warnf("error fetching revision databases: %s", err.Error())
}
} else if ok {
all = append(all, rdb)
}
}
@@ -487,17 +503,13 @@ func (p *DoltDatabaseProvider) DoltDatabases() []dsess.SqlDatabase {
}
// allRevisionDbs returns all revision dbs for the database given
func (p *DoltDatabaseProvider) allRevisionDbs(ctx *sql.Context, db dsess.SqlDatabase) ([]sql.Database, error) {
func (p *DoltDatabaseProvider) allRevisionDbs(ctx *sql.Context, db dsess.SqlDatabase, currDb string) ([]sql.Database, error) {
branches, err := db.DbData().Ddb.GetBranches(ctx)
if err != nil {
return nil, err
}
revDbs := make([]sql.Database, 0, len(branches))
rewrittenName, _, err := doltdb.NormalizeRevisionDelimiter(ctx.GetCurrentDatabase())
if err != nil {
return nil, err
}
for _, branch := range branches {
revisionQualifiedName := fmt.Sprintf("%s/%s", db.Name(), branch.GetPath())
requestedName := revisionQualifiedName
@@ -505,7 +517,7 @@ func (p *DoltDatabaseProvider) allRevisionDbs(ctx *sql.Context, db dsess.SqlData
var ok bool
// If the current DB matches, it means we're either using `@` or `/` delimited revision database name. So, we
// replace the revisionQualifiedName with the [ctx.GetCurrentDatabase] result to maintain the exact delimiter.
if revisionQualifiedName == rewrittenName {
if revisionQualifiedName == currDb {
requestedName = ctx.GetCurrentDatabase()
}
revDb, ok, err = p.databaseForRevision(ctx, revisionQualifiedName, requestedName)
@@ -860,7 +872,8 @@ func (p *DoltDatabaseProvider) cloneDatabaseFromRemote(
// DropDatabase implements the sql.MutableDatabaseProvider interface
func (p *DoltDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error {
_, revision := doltdb.SplitRevisionDbName(name)
if revision != "" {
normalized, usesDelimiterAlias := doltdb.NormalizeRevisionDelimiter(name)
if revision != "" || (usesDelimiterAlias && p.HasDatabase(ctx, normalized)) {
return fmt.Errorf("unable to drop revision database: %s", name)
}
@@ -1058,24 +1071,15 @@ func (p *DoltDatabaseProvider) invalidateDbStateInAllSessions(ctx *sql.Context,
}
func (p *DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revisionQualifiedName string, requestedName string) (dsess.SqlDatabase, bool, error) {
if !strings.Contains(revisionQualifiedName, doltdb.DbRevisionDelimiter) {
if !strings.ContainsAny(revisionQualifiedName, doltdb.DbRevisionDelimiterAlias+doltdb.DbRevisionDelimiter) {
return nil, false, nil
}
baseName, rev := doltdb.SplitRevisionDbName(revisionQualifiedName)
// Look in the session cache for this DB before doing any IO to figure out what's being asked for
sess := dsess.DSessFromSess(ctx.Session)
dbCache := sess.DatabaseCache(ctx)
db, ok := dbCache.GetCachedRevisionDb(revisionQualifiedName, requestedName)
if ok {
return db, true, nil
}
p.mu.RLock()
srcDb, ok := p.databases[formatDbMapKeyName(baseName)]
srcDb, srcOk := p.databases[formatDbMapKeyName(baseName)]
p.mu.RUnlock()
if !ok {
if !srcOk {
return nil, false, nil
}
@@ -1084,6 +1088,14 @@ func (p *DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revisionQua
return nil, false, err
}
// Cached revision db only when the revision resolves in current refs. Branch and tag changes can stale session
// cache entries, so validate the revision type before returning a cached db.
dbCache := dsess.DSessFromSess(ctx.Session).DatabaseCache(ctx)
db, ok := dbCache.GetCachedRevisionDb(revisionQualifiedName, requestedName)
if ok && dbType != dsess.RevisionTypeNone {
return db, true, nil
}
switch dbType {
case dsess.RevisionTypeBranch:
// fetch the upstream head if this is a replicated db
@@ -1442,61 +1454,70 @@ func (p *DoltDatabaseProvider) BaseDatabase(ctx *sql.Context, name string) (dses
// SessionDatabase implements dsess.SessionDatabaseProvider
func (p *DoltDatabaseProvider) SessionDatabase(ctx *sql.Context, name string) (dsess.SqlDatabase, bool, error) {
baseName, revision := doltdb.SplitRevisionDbName(name)
normalized, usesDelimiterAlias := doltdb.NormalizeRevisionDelimiter(name)
baseName, revision := doltdb.SplitRevisionDbName(normalized)
var ok bool
p.mu.RLock()
db, ok := p.databases[strings.ToLower(baseName)]
var rawDB dsess.SqlDatabase
rawDBOk := false
if usesDelimiterAlias {
rawDB, rawDBOk = p.databases[strings.ToLower(name)]
}
standby := *p.isStandby
p.mu.RUnlock()
var err error
if !rawDBOk && usesDelimiterAlias {
rawDB, err = p.databaseForClone(ctx, strings.ToLower(name))
// Ignore error, revision needs to be evaluated first.
rawDBOk = rawDB != nil && err == nil
}
// If the database doesn't exist and this is a read replica, attempt to clone it from the remote
if !ok {
var err error
db, err = p.databaseForClone(ctx, strings.ToLower(baseName))
if err != nil {
if err != nil && !usesDelimiterAlias {
return nil, false, err
}
if db == nil {
ok = db != nil
if !ok && !rawDBOk {
return nil, false, nil
}
}
// Some DB implementations don't support addressing by versioned names, so return directly if we have one of those
if !db.Versioned() {
if ok && !db.Versioned() {
return wrapForStandby(db, standby), true, nil
}
// Convert to a revision database before returning. If we got a non-qualified name, convert it to a qualified name
// using the session's current head
usingDefaultBranch := false
head := ""
sess := dsess.DSessFromSess(ctx.Session)
revisionQualifiedName, _, err := doltdb.NormalizeRevisionDelimiter(name)
if revision == "" {
var err error
head, ok, err = sess.CurrentHead(ctx, baseName)
usingDefaultBranch := false
var head string
if ok && revision == "" {
head, usingDefaultBranch, err = p.resolveCurrentOrDefaultHead(ctx, sess, db, baseName)
if err != nil {
return nil, false, err
}
// A newly created session may not have any info on current head stored yet, in which case we get the default
// branch for the db itself instead.
if !ok {
usingDefaultBranch = true
head, err = dsess.DefaultHead(ctx, baseName, db)
if err != nil {
return nil, false, err
}
}
revisionQualifiedName = baseName + doltdb.DbRevisionDelimiter + head
normalized = baseName + doltdb.DbRevisionDelimiter + head
}
db, ok, err = p.databaseForRevision(ctx, normalized, name)
if (!ok || err != nil) && rawDBOk {
if !rawDB.Versioned() {
return wrapForStandby(rawDB, standby), true, nil
}
head, usingDefaultBranch, err = p.resolveCurrentOrDefaultHead(ctx, sess, rawDB, name)
if err != nil {
return nil, false, err
}
normalized = name + doltdb.DbRevisionDelimiter + head
db, ok, err = p.databaseForRevision(ctx, normalized, name)
}
db, ok, err = p.databaseForRevision(ctx, revisionQualifiedName, name)
if err != nil {
if sql.ErrDatabaseNotFound.Is(err) && usingDefaultBranch {
// We can return a better error message here in some cases
@@ -1515,6 +1536,24 @@ func (p *DoltDatabaseProvider) SessionDatabase(ctx *sql.Context, name string) (d
return wrapForStandby(db, standby), true, nil
}
func (p *DoltDatabaseProvider) resolveCurrentOrDefaultHead(ctx *sql.Context, sess *dsess.DoltSession, db dsess.SqlDatabase, baseName string) (resolvedHead string, usedDefaultHead bool, err error) {
resolvedHead, ok, err := sess.CurrentHead(ctx, baseName)
if err != nil {
return "", false, err
}
if ok {
return resolvedHead, false, nil
}
// A newly created session may not have any info on current head stored yet, in which case we get the default
// branch for the db itself instead.
resolvedHead, err = dsess.DefaultHead(ctx, baseName, db)
if err != nil {
return "", false, err
}
return resolvedHead, true, nil
}
// Function implements the FunctionProvider interface
func (p *DoltDatabaseProvider) Function(_ *sql.Context, name string) (sql.Function, bool) {
fn, ok := p.functions[strings.ToLower(name)]
+39 -11
View File
@@ -176,9 +176,7 @@ func GetTableResolver(ctx *sql.Context, dbName string) (doltdb.TableResolver, er
// the interface returned by the public method.
func (d *DoltSession) lookupDbState(ctx *sql.Context, dbName string) (*branchState, bool, error) {
dbName = strings.ToLower(dbName)
var baseName, rev string
baseName, rev = doltdb.SplitRevisionDbName(dbName)
baseName, rev := doltdb.SplitRevisionDbName(dbName)
d.mu.Lock()
dbState, dbStateFound := d.dbStates[baseName]
@@ -191,7 +189,6 @@ func (d *DoltSession) lookupDbState(ctx *sql.Context, dbName string) (*branchSta
}
branchState, ok := dbState.heads[strings.ToLower(rev)]
if ok {
if dbState.Err != nil {
return nil, false, dbState.Err
@@ -200,15 +197,16 @@ func (d *DoltSession) lookupDbState(ctx *sql.Context, dbName string) (*branchSta
}
}
// No state for this db / branch combination yet, look it up from the provider. We use the unqualified DB name (no
// branch) if the current DB has not yet been loaded into this session. It will resolve to that DB's default branch
// in that case.
revisionQualifiedName := dbName
if rev != "" {
revisionQualifiedName = doltdb.RevisionDbName(baseName, rev)
// usesDelimiterAlias once normalized goes to false, so any `@` character will not be treated as a delimiter.
normalized, usesDelimiterAlias := doltdb.NormalizeRevisionDelimiter(dbName)
if usesDelimiterAlias {
state, ok, err := d.lookupDbState(ctx, normalized)
if err == nil && ok {
return state, ok, err
}
}
database, ok, err := d.provider.SessionDatabase(ctx, revisionQualifiedName)
database, ok, err := d.provider.SessionDatabase(ctx, dbName)
if err != nil {
return nil, false, err
}
@@ -1308,6 +1306,36 @@ func (d *DoltSession) SetSessionVariable(ctx *sql.Context, key string, value int
return d.Session.SetSessionVariable(ctx, key, value)
}
// UseDatabase updates session state for the selected database and publishes prompt context variables so clients can
// render an unambiguous base database and active revision over the SQL wire protocol.
func (d *DoltSession) UseDatabase(ctx *sql.Context, db sql.Database) error {
if err := d.Session.UseDatabase(ctx, db); err != nil {
return err
}
baseDatabase := db.Name()
if aliased, ok := db.(sql.AliasedDatabase); ok {
baseDatabase = aliased.AliasedName()
}
activeRevision := ""
if sqlDB, ok := db.(SqlDatabase); ok {
activeRevision = sqlDB.Revision()
}
return d.setPromptContextVars(ctx, baseDatabase, activeRevision)
}
func (d *DoltSession) setPromptContextVars(ctx *sql.Context, baseDatabase, activeRevision string) error {
if err := d.Session.SetSessionVariable(ctx, DoltBaseDatabase, baseDatabase); err != nil {
return err
}
if err := d.Session.SetSessionVariable(ctx, DoltActiveRevision, activeRevision); err != nil {
return err
}
return nil
}
func (d *DoltSession) setHeadRefSessionVar(ctx *sql.Context, db, value string) error {
headRef, err := ref.Parse(value)
if err != nil {
@@ -57,6 +57,8 @@ const (
DoltLogLevel = "dolt_log_level"
ShowSystemTables = "dolt_show_system_tables"
AllowCICreation = "dolt_allow_ci_creation"
DoltBaseDatabase = "dolt_base_database"
DoltActiveRevision = "dolt_active_revision"
DoltClusterRoleVariable = "dolt_cluster_role"
DoltClusterRoleEpochVariable = "dolt_cluster_role_epoch"
@@ -404,6 +404,91 @@ var DoltRevisionDbScripts = []queries.ScriptTest{
},
},
},
{
Name: "database revision specs: db revision delimiter alias '@' is ignored when no revision exists",
SetUpScript: []string{
"create database `mydb@branch1`;",
"create table t1(t int);",
"call dolt_commit('-Am', 'init t1');",
"create database `test-10382`;",
"use `test-10382`;",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "use `mydb@branch1`;",
Expected: []sql.Row{},
},
{
Query: "drop database `test-10382`;",
Expected: []sql.Row{{types.NewOkResult(1)}},
},
{
Query: "select database();",
Expected: []sql.Row{{"mydb@branch1"}},
},
{
Query: "show databases",
Expected: []sql.Row{{"information_schema"}, {"mydb"}, {"mydb@branch1"}, {"mysql"}},
},
{
Query: "set dolt_show_branch_databases = on;",
Expected: []sql.Row{{types.NewOkResult(0)}},
},
{
Query: "show databases",
Expected: []sql.Row{{"information_schema"}, {"mydb"}, {"mydb/main"}, {"mydb@branch1"}, {"mydb@branch1/main"}, {"mysql"}},
},
{
Query: "use `mydb@branch1`;",
Expected: []sql.Row{},
},
{
Query: "call dolt_branch('branch2');",
Expected: []sql.Row{{0}},
},
{
Query: "use `mydb@branch1@branch2`;",
// The `@` delimiter is interpreted at the first index found, so the above is not supported.
ExpectedErr: sql.ErrDatabaseNotFound,
},
{
Query: "use `mydb@branch1`;",
Expected: []sql.Row{},
},
{
Query: "show databases",
Expected: []sql.Row{{"information_schema"}, {"mydb"}, {"mydb/main"}, {"mydb@branch1"}, {"mydb@branch1/main"}, {"mydb@branch1/branch2"}, {"mysql"}},
},
{
Query: "call dolt_branch('branch@');",
Expected: []sql.Row{{0}},
},
{
Query: "show databases",
Expected: []sql.Row{{"information_schema"}, {"mydb"}, {"mydb/main"}, {"mydb@branch1"}, {"mydb@branch1/main"}, {"mydb@branch1/branch2"}, {"mydb@branch1/branch@"}, {"mysql"}},
},
{
Query: "set dolt_show_branch_databases = off;",
Expected: []sql.Row{{types.NewOkResult(0)}},
},
{
Query: "show databases",
Expected: []sql.Row{{"information_schema"}, {"mydb"}, {"mydb@branch1"}, {"mysql"}},
},
{
Query: "select * from t1;",
ExpectedErr: sql.ErrTableNotFound,
},
{
Query: "use mydb;",
Expected: []sql.Row{},
},
{
Query: "select * from t1;",
Expected: []sql.Row{},
},
},
},
{
Name: "database revision specs: db revision delimiter alias '@'",
SetUpScript: []string{
@@ -413,29 +498,49 @@ var DoltRevisionDbScripts = []queries.ScriptTest{
"insert into t01 values (1, 1), (2, 2);",
"call dolt_commit('-am', 'adding rows to table t01 on main');",
"call dolt_tag('tag1');",
"call dolt_branch('branch1');",
"insert into t01 values (3, 3);",
"call dolt_commit('-am', 'adding another row to table t01 on main');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "create database `mydb@branch1`;",
Expected: []sql.Row{{types.NewOkResult(1)}},
},
{
Query: "use mydb;",
Expected: []sql.Row{},
},
{
Query: "call dolt_branch('branch1');",
Expected: []sql.Row{{0}},
},
{
Query: "insert into t01 values (3, 3);",
Expected: []sql.Row{{types.NewOkResult(1)}},
},
{
Query: "call dolt_commit('-am', 'adding rows to table t01');",
SkipResultsCheck: true,
},
{
Query: "use `mydb@main`;",
Expected: []sql.Row{},
},
{
Query: "show databases;",
Expected: []sql.Row{{"information_schema"}, {"mydb"}, {"mydb@main"}, {"mysql"}},
Query: "show databases;",
// The mydb@branch1 database is shown, not the revision `branch1` from `mydb` cause we're on `main`.
Expected: []sql.Row{{"information_schema"}, {"mydb"}, {"mydb@branch1"}, {"mydb@main"}, {"mysql"}},
},
{
Query: "use `mydb@branch1`;",
Expected: []sql.Row{},
},
{
Query: "show databases;",
Query: "show databases;",
// The revision branch1 is shown, not the `mydb@branch1` database.
Expected: []sql.Row{{"information_schema"}, {"mydb"}, {"mydb@branch1"}, {"mysql"}},
},
{
Query: "select database();",
Query: "select database();",
// We want to see the revision shown in the format it was requested, this is not the literal db.
Expected: []sql.Row{{"mydb@branch1"}},
},
{
@@ -468,7 +573,7 @@ var DoltRevisionDbScripts = []queries.ScriptTest{
},
{
Query: "show databases;",
Expected: []sql.Row{{"information_schema"}, {"mydb"}, {"mydb/main"}, {"mydb/branch1"}, {"mysql"}},
Expected: []sql.Row{{"information_schema"}, {"mydb"}, {"mydb/main"}, {"mydb/branch1"}, {"mydb@branch1"}, {"mydb@branch1/main"}, {"mysql"}},
},
{
Query: "select * from `mydb@branch1`.t01;",
@@ -478,22 +583,10 @@ var DoltRevisionDbScripts = []queries.ScriptTest{
Query: "select * from `mydb@tag1`.t01;",
Expected: []sql.Row{{1, 1}, {2, 2}},
},
{
Query: "drop database `mydb@branch1`;",
ExpectedErrStr: "unable to drop revision database: mydb@branch1",
},
{
Query: "create database `mydb@branch1`;",
ExpectedErrStr: "can't create database mydb@branch1; database exists",
},
{
Query: "use `mydb@branch1`;",
Expected: []sql.Row{},
},
{
Query: "show databases;",
Expected: []sql.Row{{"information_schema"}, {"mydb"}, {"mydb/main"}, {"mydb@branch1"}, {"mysql"}},
},
{
Query: "create table parent(id int primary key);",
Expected: []sql.Row{{types.NewOkResult(0)}},
@@ -516,7 +609,7 @@ var DoltRevisionDbScripts = []queries.ScriptTest{
},
{
Query: "show databases;",
Expected: []sql.Row{{"information_schema"}, {"mydb"}, {"mydb/main"}, {"mydb/branch1"}, {"mysql"}},
Expected: []sql.Row{{"information_schema"}, {"mydb"}, {"mydb/main"}, {"mydb/branch1"}, {"mydb@branch1"}, {"mydb@branch1/main"}, {"mysql"}},
},
{
Query: "select database();",
@@ -530,6 +623,11 @@ var DoltRevisionDbScripts = []queries.ScriptTest{
Query: "select column_name from information_schema.columns where table_schema = database() and table_name = 't01' order by ordinal_position;",
Expected: []sql.Row{{"pk"}, {"c1"}},
},
{
Query: "drop database `mydb@branch1`;",
// The name above can be resolved to a real revision so we error out, keeping parity with CREATE below.
ExpectedErrStr: "unable to drop revision database: mydb@branch1",
},
{
Query: "select table_name from information_schema.tables where table_schema = database() and table_name = 't01';",
Expected: []sql.Row{{"t01"}},
@@ -547,8 +645,29 @@ var DoltRevisionDbScripts = []queries.ScriptTest{
Expected: []sql.Row{},
},
{
Query: "create schema `mydb@branch1`;",
ExpectedErr: sql.ErrDatabaseExists,
Query: "call dolt_branch('-D', 'branch1');",
Expected: []sql.Row{{0}},
},
{
Query: "drop database `mydb@branch1`;",
Expected: []sql.Row{{types.NewOkResult(1)}},
},
{
Query: "call dolt_branch('branch1');",
Expected: []sql.Row{{0}},
},
{
Query: "create database `mydb@branch1`;",
// This is a result of GMS' internal call to the providers' HasDatabase
ExpectedErrStr: "can't create database mydb@branch1; database exists",
},
{
Query: "call dolt_branch('-D', 'branch1');",
Expected: []sql.Row{{0}},
},
{
Query: "create database `mydb@branch1`;",
Expected: []sql.Row{{types.NewOkResult(1)}},
},
},
},
@@ -292,6 +292,20 @@ var DoltSystemVariables = []sql.SystemVariable{
Type: types.NewSystemBoolType(dsess.AllowCICreation),
Default: int8(0),
},
&sql.MysqlSystemVariable{
Name: dsess.DoltBaseDatabase,
Dynamic: true,
Scope: sql.GetMysqlScope(sql.SystemVariableScope_Session),
Type: types.NewSystemStringType(dsess.DoltBaseDatabase),
Default: "",
},
&sql.MysqlSystemVariable{
Name: dsess.DoltActiveRevision,
Dynamic: true,
Scope: sql.GetMysqlScope(sql.SystemVariableScope_Session),
Type: types.NewSystemStringType(dsess.DoltActiveRevision),
Default: "",
},
}
func AddDoltSystemVariables() {
@@ -554,6 +568,20 @@ func AddDoltSystemVariables() {
Type: types.NewSystemBoolType(dsess.AllowCICreation),
Default: int8(0),
},
&sql.MysqlSystemVariable{
Name: dsess.DoltBaseDatabase,
Dynamic: true,
Scope: sql.GetMysqlScope(sql.SystemVariableScope_Session),
Type: types.NewSystemStringType(dsess.DoltBaseDatabase),
Default: "",
},
&sql.MysqlSystemVariable{
Name: dsess.DoltActiveRevision,
Dynamic: true,
Scope: sql.GetMysqlScope(sql.SystemVariableScope_Session),
Type: types.NewSystemStringType(dsess.DoltActiveRevision),
Default: "",
},
})
sql.SystemVariables.AddSystemVariables(DoltSystemVariables)
}
@@ -6,7 +6,11 @@ set env(NO_COLOR) 1
spawn dolt sql
expect_with_defaults {>} { send "create schema mydb;\r" }
expect_with_defaults {>} { send "create database `mydb@branch1`;\r" }
expect_with_defaults {>} { send "use `mydb@branch1`;\r" }
expect_with_defaults_after {Database Changed} {mydb@branch1/main\*?>} { send "create schema `mydb`;\r" }
expect_with_defaults {>} { send "use mydb;\r" }
@@ -30,13 +34,15 @@ expect_with_defaults {>} { send "create table t1(i int);\r" }
expect_with_defaults {>} { send "call dolt_commit('-Am', 'create table t1');\r" }
expect_with_defaults {>} { send "execute use_stmt;\r" }
expect_with_defaults {>} { send "show tables;\r" }
expect_with_defaults_after {Empty set} {mydb\*?>} { send "show tables;\r" }
expect_with_defaults_after {t1} {>} { send "execute use_stmt;\r" }
expect_with_defaults {>} { send "create database `mydb@branch1`;\r" }
expect_with_defaults_after {Empty set} {mydb/[0-9a-v]{32}\*?>} { send "show tables;\r" }
expect_with_defaults_after {database exists} {mydb\*?>} { send "exit;\r" }
expect_with_defaults_after {Empty set} {mydb/[0-9a-v]{32}\*?>} { send "create database `mydb@branch1`;\r" }
expect_with_defaults_after {database exists} {mydb/[0-9a-v]{32}\*?>} { send "exit;\r" }
expect eof
exit 0
+1 -1
View File
@@ -1149,7 +1149,7 @@ expect eof
# bats test_tags=no_lambda
@test "sql-shell: sql shell respects revision database as current database" {
skiponwindows "Need to install expect and make this script work on windows."
# skiponwindows "Need to install expect and make this script work on windows."
run expect "$BATS_TEST_DIRNAME"/sql-shell-revision-db.expect
[ "$status" -eq 0 ]
}