Fix dolt_statistics table for multiple schemas for doltgres

This commit is contained in:
Taylor Bantle
2024-11-21 16:21:09 -08:00
parent 6200a330ab
commit f79892c083
2 changed files with 27 additions and 45 deletions
+20 -2
View File
@@ -675,7 +675,18 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds
}
}
case doltdb.StatisticsTableName:
dt, found = dtables.NewStatisticsTable(ctx, db.Name(), db.ddb, asOf), true
var tables []string
var err error
branch, ok := asOf.(string)
if ok && branch != "" {
tables, err = db.GetTableNamesAsOf(ctx, branch)
} else {
tables, err = db.GetTableNames(ctx)
}
if err != nil {
return nil, false, err
}
dt, found = dtables.NewStatisticsTable(ctx, db.Name(), branch, tables), true
case doltdb.ProceduresTableName:
found = true
backingTable, _, err := db.getTable(ctx, root, doltdb.ProceduresTableName)
@@ -856,7 +867,14 @@ func (db Database) GetTableNamesAsOf(ctx *sql.Context, time interface{}) ([]stri
return nil, nil
}
tblNames, err := db.getAllTableNames(ctx, root, false)
showSystemTablesVar, err := ctx.GetSessionVariable(ctx, dsess.ShowSystemTables)
if err != nil {
return nil, err
}
showSystemTables := showSystemTablesVar.(int8) == 1
tblNames, err := db.getAllTableNames(ctx, root, showSystemTables)
if err != nil {
return nil, err
}
@@ -15,8 +15,6 @@
package dtables
import (
"fmt"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/stats"
@@ -28,21 +26,17 @@ import (
// StatisticsTable is a sql.Table implementation that implements a system table which shows the dolt commit log
type StatisticsTable struct {
dbName string
branch string
ddb *doltdb.DoltDB
dbName string
branch string
tableNames []string
}
var _ sql.Table = (*StatisticsTable)(nil)
var _ sql.StatisticsTable = (*StatisticsTable)(nil)
// NewStatisticsTable creates a StatisticsTable
func NewStatisticsTable(_ *sql.Context, dbName string, ddb *doltdb.DoltDB, asOf interface{}) sql.Table {
ret := &StatisticsTable{dbName: dbName, ddb: ddb}
if branch, ok := asOf.(string); ok {
ret.branch = branch
}
return ret
func NewStatisticsTable(_ *sql.Context, dbName, branch string, tableNames []string) sql.Table {
return &StatisticsTable{dbName: dbName, branch: branch, tableNames: tableNames}
}
// DataLength implements sql.StatisticsTable
@@ -79,20 +73,8 @@ type BranchStatsProvider interface {
// RowCount implements sql.StatisticsTable
func (st *StatisticsTable) RowCount(ctx *sql.Context) (uint64, bool, error) {
dSess := dsess.DSessFromSess(ctx.Session)
prov := dSess.Provider()
sqlDb, err := prov.Database(ctx, st.dbName)
if err != nil {
return 0, false, err
}
tables, err := sqlDb.GetTableNames(ctx)
if err != nil {
return 0, false, err
}
var cnt int
for _, table := range tables {
for _, table := range st.tableNames {
// only Dolt-specific provider has branch support
dbStats, err := dSess.StatsProvider().(BranchStatsProvider).GetTableDoltStats(ctx, st.branch, st.dbName, table)
if err != nil {
@@ -136,27 +118,9 @@ func (st *StatisticsTable) Partitions(*sql.Context) (sql.PartitionIter, error) {
// PartitionRows is a sql.Table interface function that gets a row iterator for a partition
func (st *StatisticsTable) PartitionRows(ctx *sql.Context, _ sql.Partition) (sql.RowIter, error) {
dSess := dsess.DSessFromSess(ctx.Session)
prov := dSess.Provider()
var sqlDb sql.Database
var err error
if st.branch != "" {
sqlDb, err = prov.Database(ctx, fmt.Sprintf("%s/%s", st.dbName, st.branch))
} else {
sqlDb, err = prov.Database(ctx, st.dbName)
}
if err != nil {
return nil, err
}
tables, err := sqlDb.GetTableNames(ctx)
if err != nil {
return nil, err
}
statsPro := dSess.StatsProvider().(BranchStatsProvider)
var dStats []sql.Statistic
for _, table := range tables {
for _, table := range st.tableNames {
dbStats, err := statsPro.GetTableDoltStats(ctx, st.branch, st.dbName, table)
if err != nil {
return nil, err