Get stats for same table name, different schemas

This commit is contained in:
Taylor Bantle
2024-11-22 15:17:18 -08:00
parent f79892c083
commit 37c1933ea9
7 changed files with 61 additions and 28 deletions

View File

@@ -675,6 +675,14 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds
}
}
case doltdb.StatisticsTableName:
if resolve.UseSearchPath && db.schemaName == "" {
schemaName, err := resolve.FirstExistingSchemaOnSearchPath(ctx, root)
if err != nil {
return nil, false, err
}
db.schemaName = schemaName
}
var tables []string
var err error
branch, ok := asOf.(string)
@@ -686,7 +694,7 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds
if err != nil {
return nil, false, err
}
dt, found = dtables.NewStatisticsTable(ctx, db.Name(), branch, tables), true
dt, found = dtables.NewStatisticsTable(ctx, db.Name(), db.schemaName, branch, tables), true
case doltdb.ProceduresTableName:
found = true
backingTable, _, err := db.getTable(ctx, root, doltdb.ProceduresTableName)

View File

@@ -27,6 +27,7 @@ import (
// StatisticsTable is a sql.Table implementation that implements a system table which shows the dolt commit log
type StatisticsTable struct {
dbName string
schemaName string
branch string
tableNames []string
}
@@ -35,8 +36,8 @@ var _ sql.Table = (*StatisticsTable)(nil)
var _ sql.StatisticsTable = (*StatisticsTable)(nil)
// NewStatisticsTable creates a StatisticsTable
func NewStatisticsTable(_ *sql.Context, dbName, branch string, tableNames []string) sql.Table {
return &StatisticsTable{dbName: dbName, branch: branch, tableNames: tableNames}
func NewStatisticsTable(_ *sql.Context, dbName, schemaName, branch string, tableNames []string) sql.Table {
return &StatisticsTable{dbName: dbName, schemaName: schemaName, branch: branch, tableNames: tableNames}
}
// DataLength implements sql.StatisticsTable
@@ -67,7 +68,7 @@ func (st *StatisticsTable) DataLength(ctx *sql.Context) (uint64, error) {
}
type BranchStatsProvider interface {
GetTableDoltStats(ctx *sql.Context, branch, db, table string) ([]sql.Statistic, error)
GetTableDoltStats(ctx *sql.Context, branch, db, schema, table string) ([]sql.Statistic, error)
}
// RowCount implements sql.StatisticsTable
@@ -76,7 +77,7 @@ func (st *StatisticsTable) RowCount(ctx *sql.Context) (uint64, bool, error) {
var cnt int
for _, table := range st.tableNames {
// only Dolt-specific provider has branch support
dbStats, err := dSess.StatsProvider().(BranchStatsProvider).GetTableDoltStats(ctx, st.branch, st.dbName, table)
dbStats, err := dSess.StatsProvider().(BranchStatsProvider).GetTableDoltStats(ctx, st.branch, st.dbName, st.schemaName, table)
if err != nil {
}
@@ -121,7 +122,7 @@ func (st *StatisticsTable) PartitionRows(ctx *sql.Context, _ sql.Partition) (sql
statsPro := dSess.StatsProvider().(BranchStatsProvider)
var dStats []sql.Statistic
for _, table := range st.tableNames {
dbStats, err := statsPro.GetTableDoltStats(ctx, st.branch, st.dbName, table)
dbStats, err := statsPro.GetTableDoltStats(ctx, st.branch, st.dbName, st.schemaName, table)
if err != nil {
return nil, err
}

View File

@@ -32,7 +32,7 @@ import (
var ErrIncompatibleVersion = errors.New("client stats version mismatch")
func NewStatsIter(ctx *sql.Context, m prolly.Map) (*statsIter, error) {
func NewStatsIter(ctx *sql.Context, schemaName string, m prolly.Map) (*statsIter, error) {
iter, err := m.IterAll(ctx)
if err != nil {
return nil, err
@@ -43,11 +43,12 @@ func NewStatsIter(ctx *sql.Context, m prolly.Map) (*statsIter, error) {
ns := m.NodeStore()
return &statsIter{
iter: iter,
kb: keyBuilder,
vb: valueBuilder,
ns: ns,
planb: planbuilder.New(ctx, nil, nil, nil),
iter: iter,
kb: keyBuilder,
vb: valueBuilder,
ns: ns,
schemaName: schemaName,
planb: planbuilder.New(ctx, nil, nil, nil),
}, nil
}
@@ -61,6 +62,7 @@ type statsIter struct {
ns tree.NodeStore
planb *planbuilder.Builder
currentQual string
schemaName string
currentTypes []sql.Type
}
@@ -118,7 +120,7 @@ func (s *statsIter) Next(ctx *sql.Context) (sql.Row, error) {
typs[i] = strings.TrimSpace(t)
}
qual := sql.NewStatQualifier(dbName, tableName, indexName)
qual := sql.NewStatQualifier(dbName, s.schemaName, tableName, indexName)
if curQual := qual.String(); !strings.EqualFold(curQual, s.currentQual) {
s.currentQual = curQual
s.currentTypes, err = parseTypeStrings(typs)

View File

@@ -39,7 +39,8 @@ import (
func loadStats(ctx *sql.Context, db dsess.SqlDatabase, m prolly.Map) (map[sql.StatQualifier]*statspro.DoltStats, error) {
qualToStats := make(map[sql.StatQualifier]*statspro.DoltStats)
iter, err := NewStatsIter(ctx, m)
schemaName := db.SchemaName()
iter, err := NewStatsIter(ctx, schemaName, m)
if err != nil {
return nil, err
}
@@ -72,7 +73,7 @@ func loadStats(ctx *sql.Context, db dsess.SqlDatabase, m prolly.Map) (map[sql.St
typs[i] = strings.TrimSpace(t)
}
qual := sql.NewStatQualifier(dbName, tableName, indexName)
qual := sql.NewStatQualifier(dbName, schemaName, tableName, indexName)
if currentStat.Statistic.Qual.String() != qual.String() {
if !currentStat.Statistic.Qual.Empty() {
currentStat.Statistic.LowerBnd, currentStat.Tb, err = loadLowerBound(ctx, db, currentStat.Statistic.Qual, len(currentStat.Columns()))

View File

@@ -100,6 +100,10 @@ func (p *Provider) RefreshTableStatsWithBranch(ctx *sql.Context, table sql.Table
tableName := strings.ToLower(table.Name())
dbName := strings.ToLower(db)
var schemaName string
if schTab, ok := table.(sql.DatabaseSchemaTable); ok {
schemaName = strings.ToLower(schTab.DatabaseSchema().SchemaName())
}
iat, ok := table.(sql.IndexAddressableTable)
if !ok {
@@ -146,7 +150,7 @@ func (p *Provider) RefreshTableStatsWithBranch(ctx *sql.Context, table sql.Table
ctx.GetLogger().Debugf("statistics refresh: detected table schema change: %s,%s/%s", dbName, table, branch)
statDb.SetSchemaHash(branch, tableName, schHash)
stats, err := p.GetTableDoltStats(ctx, branch, dbName, tableName)
stats, err := p.GetTableDoltStats(ctx, branch, dbName, schemaName, tableName)
if err != nil {
return err
}
@@ -163,7 +167,7 @@ func (p *Provider) RefreshTableStatsWithBranch(ctx *sql.Context, table sql.Table
cols[i] = strings.TrimPrefix(strings.ToLower(c), tablePrefix)
}
qual := sql.NewStatQualifier(db, table.Name(), strings.ToLower(idx.ID()))
qual := sql.NewStatQualifier(db, schemaName, table.Name(), strings.ToLower(idx.ID()))
curStat, ok := statDb.GetStat(branch, qual)
if !ok {
curStat = NewDoltStats()

View File

@@ -163,13 +163,18 @@ func (p *Provider) checkRefresh(ctx *sql.Context, sqlDb sql.Database, dbName, br
return err
}
var schemaName string
if schTab, ok := sqlTable.(sql.DatabaseSchemaTable); ok {
schemaName = strings.ToLower(schTab.DatabaseSchema().SchemaName())
}
if oldSchHash := statDb.GetSchemaHash(branch, table); oldSchHash.IsEmpty() {
statDb.SetSchemaHash(branch, table, schHash)
} else if oldSchHash != schHash {
ctx.GetLogger().Debugf("statistics refresh: detected table schema change: %s,%s/%s", dbName, table, branch)
statDb.SetSchemaHash(branch, table, schHash)
stats, err := p.GetTableDoltStats(ctx, branch, dbName, table)
stats, err := p.GetTableDoltStats(ctx, branch, schemaName, dbName, table)
if err != nil {
return err
}
@@ -191,7 +196,7 @@ func (p *Provider) checkRefresh(ctx *sql.Context, sqlDb sql.Database, dbName, br
// collect indexes and ranges to be updated
var idxMetas []indexMeta
for _, index := range indexes {
qual := sql.NewStatQualifier(dbName, table, strings.ToLower(index.ID()))
qual := sql.NewStatQualifier(dbName, schemaName, table, strings.ToLower(index.ID()))
qualExists[qual] = true
curStat, ok := statDb.GetStat(branch, qual)
if !ok {

View File

@@ -169,11 +169,15 @@ func (p *Provider) GetTableStats(ctx *sql.Context, db string, table sql.Table) (
return nil, nil
}
// TODO: schema name
return p.GetTableDoltStats(ctx, branch, db, table.Name())
var schemaName string
if schTab, ok := table.(sql.DatabaseSchemaTable); ok {
schemaName = strings.ToLower(schTab.DatabaseSchema().SchemaName())
}
return p.GetTableDoltStats(ctx, branch, db, schemaName, table.Name())
}
func (p *Provider) GetTableDoltStats(ctx *sql.Context, branch, db, table string) ([]sql.Statistic, error) {
func (p *Provider) GetTableDoltStats(ctx *sql.Context, branch, db, schema, table string) ([]sql.Statistic, error) {
statDb, ok := p.getStatDb(db)
if !ok || statDb == nil {
return nil, nil
@@ -190,7 +194,7 @@ func (p *Provider) GetTableDoltStats(ctx *sql.Context, branch, db, table string)
var ret []sql.Statistic
for _, qual := range statDb.ListStatQuals(branch) {
if strings.EqualFold(db, qual.Database) && strings.EqualFold(table, qual.Tab) {
if strings.EqualFold(db, qual.Database) && strings.EqualFold(schema, qual.Sch) && strings.EqualFold(table, qual.Tab) {
stat, _ := statDb.GetStat(branch, qual)
ret = append(ret, stat)
}
@@ -333,8 +337,12 @@ func (p *Provider) RowCount(ctx *sql.Context, db string, table sql.Table) (uint6
return 0, err
}
// TODO: schema name
priStats, ok := statDb.GetStat(branch, sql.NewStatQualifier(db, table.Name(), "primary"))
var schemaName string
if schTab, ok := table.(sql.DatabaseSchemaTable); ok {
schemaName = strings.ToLower(schTab.DatabaseSchema().SchemaName())
}
priStats, ok := statDb.GetStat(branch, sql.NewStatQualifier(db, schemaName, table.Name(), "primary"))
if !ok {
return 0, nil
}
@@ -354,8 +362,12 @@ func (p *Provider) DataLength(ctx *sql.Context, db string, table sql.Table) (uin
return 0, err
}
// TODO: schema name
priStats, ok := statDb.GetStat(branch, sql.NewStatQualifier(db, table.Name(), "primary"))
var schemaName string
if schTab, ok := table.(sql.DatabaseSchemaTable); ok {
schemaName = strings.ToLower(schTab.DatabaseSchema().SchemaName())
}
priStats, ok := statDb.GetStat(branch, sql.NewStatQualifier(db, schemaName, table.Name(), "primary"))
if !ok {
return 0, nil
}
@@ -404,7 +416,7 @@ func (p *Provider) Prune(ctx *sql.Context) error {
}
defer p.UnlockTable(branch, dbName, t)
tableStats, err := p.GetTableDoltStats(ctx, branch, dbName, t)
tableStats, err := p.GetTableDoltStats(ctx, branch, dbName, sqlDb.SchemaName(), t)
if err != nil {
return err
}