From 37c1933ea916c5fd443d3b59d381a995b0b56d5e Mon Sep 17 00:00:00 2001 From: Taylor Bantle Date: Fri, 22 Nov 2024 15:17:18 -0800 Subject: [PATCH] Get stats for same table name, different schemas --- go/libraries/doltcore/sqle/database.go | 10 ++++++- .../doltcore/sqle/dtables/statistics_table.go | 11 +++---- go/libraries/doltcore/sqle/statsnoms/iter.go | 16 +++++----- go/libraries/doltcore/sqle/statsnoms/load.go | 5 ++-- .../doltcore/sqle/statspro/analyze.go | 8 +++-- .../doltcore/sqle/statspro/auto_refresh.go | 9 ++++-- .../doltcore/sqle/statspro/stats_provider.go | 30 +++++++++++++------ 7 files changed, 61 insertions(+), 28 deletions(-) diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index 46cde1c2c3..61edb64dc2 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -675,6 +675,14 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds } } case doltdb.StatisticsTableName: + if resolve.UseSearchPath && db.schemaName == "" { + schemaName, err := resolve.FirstExistingSchemaOnSearchPath(ctx, root) + if err != nil { + return nil, false, err + } + db.schemaName = schemaName + } + var tables []string var err error branch, ok := asOf.(string) @@ -686,7 +694,7 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds if err != nil { return nil, false, err } - dt, found = dtables.NewStatisticsTable(ctx, db.Name(), branch, tables), true + dt, found = dtables.NewStatisticsTable(ctx, db.Name(), db.schemaName, branch, tables), true case doltdb.ProceduresTableName: found = true backingTable, _, err := db.getTable(ctx, root, doltdb.ProceduresTableName) diff --git a/go/libraries/doltcore/sqle/dtables/statistics_table.go b/go/libraries/doltcore/sqle/dtables/statistics_table.go index c61d1c4937..fda463e7e4 100644 --- a/go/libraries/doltcore/sqle/dtables/statistics_table.go +++ b/go/libraries/doltcore/sqle/dtables/statistics_table.go @@ -27,6 +27,7 @@ import ( // StatisticsTable is a sql.Table implementation that implements a system table which shows the dolt commit log type StatisticsTable struct { dbName string + schemaName string branch string tableNames []string } @@ -35,8 +36,8 @@ var _ sql.Table = (*StatisticsTable)(nil) var _ sql.StatisticsTable = (*StatisticsTable)(nil) // NewStatisticsTable creates a StatisticsTable -func NewStatisticsTable(_ *sql.Context, dbName, branch string, tableNames []string) sql.Table { - return &StatisticsTable{dbName: dbName, branch: branch, tableNames: tableNames} +func NewStatisticsTable(_ *sql.Context, dbName, schemaName, branch string, tableNames []string) sql.Table { + return &StatisticsTable{dbName: dbName, schemaName: schemaName, branch: branch, tableNames: tableNames} } // DataLength implements sql.StatisticsTable @@ -67,7 +68,7 @@ func (st *StatisticsTable) DataLength(ctx *sql.Context) (uint64, error) { } type BranchStatsProvider interface { - GetTableDoltStats(ctx *sql.Context, branch, db, table string) ([]sql.Statistic, error) + GetTableDoltStats(ctx *sql.Context, branch, db, schema, table string) ([]sql.Statistic, error) } // RowCount implements sql.StatisticsTable @@ -76,7 +77,7 @@ func (st *StatisticsTable) RowCount(ctx *sql.Context) (uint64, bool, error) { var cnt int for _, table := range st.tableNames { // only Dolt-specific provider has branch support - dbStats, err := dSess.StatsProvider().(BranchStatsProvider).GetTableDoltStats(ctx, st.branch, st.dbName, table) + dbStats, err := dSess.StatsProvider().(BranchStatsProvider).GetTableDoltStats(ctx, st.branch, st.dbName, st.schemaName, table) if err != nil { } @@ -121,7 +122,7 @@ func (st *StatisticsTable) PartitionRows(ctx *sql.Context, _ sql.Partition) (sql statsPro := dSess.StatsProvider().(BranchStatsProvider) var dStats []sql.Statistic for _, table := range st.tableNames { - dbStats, err := statsPro.GetTableDoltStats(ctx, st.branch, st.dbName, table) + dbStats, err := statsPro.GetTableDoltStats(ctx, st.branch, st.dbName, st.schemaName, table) if err != nil { return nil, err } diff --git a/go/libraries/doltcore/sqle/statsnoms/iter.go b/go/libraries/doltcore/sqle/statsnoms/iter.go index 685b497582..59b9456eed 100644 --- a/go/libraries/doltcore/sqle/statsnoms/iter.go +++ b/go/libraries/doltcore/sqle/statsnoms/iter.go @@ -32,7 +32,7 @@ import ( var ErrIncompatibleVersion = errors.New("client stats version mismatch") -func NewStatsIter(ctx *sql.Context, m prolly.Map) (*statsIter, error) { +func NewStatsIter(ctx *sql.Context, schemaName string, m prolly.Map) (*statsIter, error) { iter, err := m.IterAll(ctx) if err != nil { return nil, err @@ -43,11 +43,12 @@ func NewStatsIter(ctx *sql.Context, m prolly.Map) (*statsIter, error) { ns := m.NodeStore() return &statsIter{ - iter: iter, - kb: keyBuilder, - vb: valueBuilder, - ns: ns, - planb: planbuilder.New(ctx, nil, nil, nil), + iter: iter, + kb: keyBuilder, + vb: valueBuilder, + ns: ns, + schemaName: schemaName, + planb: planbuilder.New(ctx, nil, nil, nil), }, nil } @@ -61,6 +62,7 @@ type statsIter struct { ns tree.NodeStore planb *planbuilder.Builder currentQual string + schemaName string currentTypes []sql.Type } @@ -118,7 +120,7 @@ func (s *statsIter) Next(ctx *sql.Context) (sql.Row, error) { typs[i] = strings.TrimSpace(t) } - qual := sql.NewStatQualifier(dbName, tableName, indexName) + qual := sql.NewStatQualifier(dbName, s.schemaName, tableName, indexName) if curQual := qual.String(); !strings.EqualFold(curQual, s.currentQual) { s.currentQual = curQual s.currentTypes, err = parseTypeStrings(typs) diff --git a/go/libraries/doltcore/sqle/statsnoms/load.go b/go/libraries/doltcore/sqle/statsnoms/load.go index 043eb3fced..55b438b1ca 100644 --- a/go/libraries/doltcore/sqle/statsnoms/load.go +++ b/go/libraries/doltcore/sqle/statsnoms/load.go @@ -39,7 +39,8 @@ import ( func loadStats(ctx *sql.Context, db dsess.SqlDatabase, m prolly.Map) (map[sql.StatQualifier]*statspro.DoltStats, error) { qualToStats := make(map[sql.StatQualifier]*statspro.DoltStats) - iter, err := NewStatsIter(ctx, m) + schemaName := db.SchemaName() + iter, err := NewStatsIter(ctx, schemaName, m) if err != nil { return nil, err } @@ -72,7 +73,7 @@ func loadStats(ctx *sql.Context, db dsess.SqlDatabase, m prolly.Map) (map[sql.St typs[i] = strings.TrimSpace(t) } - qual := sql.NewStatQualifier(dbName, tableName, indexName) + qual := sql.NewStatQualifier(dbName, schemaName, tableName, indexName) if currentStat.Statistic.Qual.String() != qual.String() { if !currentStat.Statistic.Qual.Empty() { currentStat.Statistic.LowerBnd, currentStat.Tb, err = loadLowerBound(ctx, db, currentStat.Statistic.Qual, len(currentStat.Columns())) diff --git a/go/libraries/doltcore/sqle/statspro/analyze.go b/go/libraries/doltcore/sqle/statspro/analyze.go index ecd470b0d4..0672b6f6f8 100644 --- a/go/libraries/doltcore/sqle/statspro/analyze.go +++ b/go/libraries/doltcore/sqle/statspro/analyze.go @@ -100,6 +100,10 @@ func (p *Provider) RefreshTableStatsWithBranch(ctx *sql.Context, table sql.Table tableName := strings.ToLower(table.Name()) dbName := strings.ToLower(db) + var schemaName string + if schTab, ok := table.(sql.DatabaseSchemaTable); ok { + schemaName = strings.ToLower(schTab.DatabaseSchema().SchemaName()) + } iat, ok := table.(sql.IndexAddressableTable) if !ok { @@ -146,7 +150,7 @@ func (p *Provider) RefreshTableStatsWithBranch(ctx *sql.Context, table sql.Table ctx.GetLogger().Debugf("statistics refresh: detected table schema change: %s,%s/%s", dbName, table, branch) statDb.SetSchemaHash(branch, tableName, schHash) - stats, err := p.GetTableDoltStats(ctx, branch, dbName, tableName) + stats, err := p.GetTableDoltStats(ctx, branch, dbName, schemaName, tableName) if err != nil { return err } @@ -163,7 +167,7 @@ func (p *Provider) RefreshTableStatsWithBranch(ctx *sql.Context, table sql.Table cols[i] = strings.TrimPrefix(strings.ToLower(c), tablePrefix) } - qual := sql.NewStatQualifier(db, table.Name(), strings.ToLower(idx.ID())) + qual := sql.NewStatQualifier(db, schemaName, table.Name(), strings.ToLower(idx.ID())) curStat, ok := statDb.GetStat(branch, qual) if !ok { curStat = NewDoltStats() diff --git a/go/libraries/doltcore/sqle/statspro/auto_refresh.go b/go/libraries/doltcore/sqle/statspro/auto_refresh.go index 6bc92380fa..d275453321 100644 --- a/go/libraries/doltcore/sqle/statspro/auto_refresh.go +++ b/go/libraries/doltcore/sqle/statspro/auto_refresh.go @@ -163,13 +163,18 @@ func (p *Provider) checkRefresh(ctx *sql.Context, sqlDb sql.Database, dbName, br return err } + var schemaName string + if schTab, ok := sqlTable.(sql.DatabaseSchemaTable); ok { + schemaName = strings.ToLower(schTab.DatabaseSchema().SchemaName()) + } + if oldSchHash := statDb.GetSchemaHash(branch, table); oldSchHash.IsEmpty() { statDb.SetSchemaHash(branch, table, schHash) } else if oldSchHash != schHash { ctx.GetLogger().Debugf("statistics refresh: detected table schema change: %s,%s/%s", dbName, table, branch) statDb.SetSchemaHash(branch, table, schHash) - stats, err := p.GetTableDoltStats(ctx, branch, dbName, table) + stats, err := p.GetTableDoltStats(ctx, branch, schemaName, dbName, table) if err != nil { return err } @@ -191,7 +196,7 @@ func (p *Provider) checkRefresh(ctx *sql.Context, sqlDb sql.Database, dbName, br // collect indexes and ranges to be updated var idxMetas []indexMeta for _, index := range indexes { - qual := sql.NewStatQualifier(dbName, table, strings.ToLower(index.ID())) + qual := sql.NewStatQualifier(dbName, schemaName, table, strings.ToLower(index.ID())) qualExists[qual] = true curStat, ok := statDb.GetStat(branch, qual) if !ok { diff --git a/go/libraries/doltcore/sqle/statspro/stats_provider.go b/go/libraries/doltcore/sqle/statspro/stats_provider.go index f1cb411a15..4e05e60e26 100644 --- a/go/libraries/doltcore/sqle/statspro/stats_provider.go +++ b/go/libraries/doltcore/sqle/statspro/stats_provider.go @@ -169,11 +169,15 @@ func (p *Provider) GetTableStats(ctx *sql.Context, db string, table sql.Table) ( return nil, nil } - // TODO: schema name - return p.GetTableDoltStats(ctx, branch, db, table.Name()) + var schemaName string + if schTab, ok := table.(sql.DatabaseSchemaTable); ok { + schemaName = strings.ToLower(schTab.DatabaseSchema().SchemaName()) + } + + return p.GetTableDoltStats(ctx, branch, db, schemaName, table.Name()) } -func (p *Provider) GetTableDoltStats(ctx *sql.Context, branch, db, table string) ([]sql.Statistic, error) { +func (p *Provider) GetTableDoltStats(ctx *sql.Context, branch, db, schema, table string) ([]sql.Statistic, error) { statDb, ok := p.getStatDb(db) if !ok || statDb == nil { return nil, nil @@ -190,7 +194,7 @@ func (p *Provider) GetTableDoltStats(ctx *sql.Context, branch, db, table string) var ret []sql.Statistic for _, qual := range statDb.ListStatQuals(branch) { - if strings.EqualFold(db, qual.Database) && strings.EqualFold(table, qual.Tab) { + if strings.EqualFold(db, qual.Database) && strings.EqualFold(schema, qual.Sch) && strings.EqualFold(table, qual.Tab) { stat, _ := statDb.GetStat(branch, qual) ret = append(ret, stat) } @@ -333,8 +337,12 @@ func (p *Provider) RowCount(ctx *sql.Context, db string, table sql.Table) (uint6 return 0, err } - // TODO: schema name - priStats, ok := statDb.GetStat(branch, sql.NewStatQualifier(db, table.Name(), "primary")) + var schemaName string + if schTab, ok := table.(sql.DatabaseSchemaTable); ok { + schemaName = strings.ToLower(schTab.DatabaseSchema().SchemaName()) + } + + priStats, ok := statDb.GetStat(branch, sql.NewStatQualifier(db, schemaName, table.Name(), "primary")) if !ok { return 0, nil } @@ -354,8 +362,12 @@ func (p *Provider) DataLength(ctx *sql.Context, db string, table sql.Table) (uin return 0, err } - // TODO: schema name - priStats, ok := statDb.GetStat(branch, sql.NewStatQualifier(db, table.Name(), "primary")) + var schemaName string + if schTab, ok := table.(sql.DatabaseSchemaTable); ok { + schemaName = strings.ToLower(schTab.DatabaseSchema().SchemaName()) + } + + priStats, ok := statDb.GetStat(branch, sql.NewStatQualifier(db, schemaName, table.Name(), "primary")) if !ok { return 0, nil } @@ -404,7 +416,7 @@ func (p *Provider) Prune(ctx *sql.Context) error { } defer p.UnlockTable(branch, dbName, t) - tableStats, err := p.GetTableDoltStats(ctx, branch, dbName, t) + tableStats, err := p.GetTableDoltStats(ctx, branch, dbName, sqlDb.SchemaName(), t) if err != nil { return err }