Fixed more interface breakage

This commit is contained in:
Zach Musgrave
2021-10-18 15:15:06 -07:00
parent 5bc4388a5f
commit 4141dfe11b
4 changed files with 96 additions and 143 deletions
+5 -27
View File
@@ -37,7 +37,6 @@ import (
dsqle "github.com/dolthub/dolt/go/libraries/doltcore/sqle"
_ "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dfunctions"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/utils/tracing"
)
// Serve starts a MySQL-compatible server. Returns any errors that were encountered.
@@ -171,7 +170,7 @@ func portInUse(hostPort string) bool {
}
func newSessionBuilder(sqlEngine *sqle.Engine, dConf *env.DoltCliConfig, pro dsqle.DoltDatabaseProvider, mrEnv env.MultiRepoEnv, autocommit bool) server.SessionBuilder {
return func(ctx context.Context, conn *mysql.Conn, host string) (sql.Session, *sql.IndexRegistry, *sql.ViewRegistry, error) {
return func(ctx context.Context, conn *mysql.Conn, host string) (sql.Session, error) {
tmpSqlCtx := sql.NewEmptyContext()
client := sql.Client{Address: conn.RemoteAddr().String(), User: conn.User, Capabilities: conn.Capabilities}
@@ -179,46 +178,25 @@ func newSessionBuilder(sqlEngine *sqle.Engine, dConf *env.DoltCliConfig, pro dsq
doltDbs := dsqle.DbsAsDSQLDBs(sqlEngine.Analyzer.Catalog.AllDatabases())
dbStates, err := getDbStates(ctx, doltDbs)
if err != nil {
return nil, nil, nil, err
return nil, err
}
doltSess, err := dsess.NewSession(tmpSqlCtx, mysqlSess, pro, dConf, dbStates...)
if err != nil {
return nil, nil, nil, err
return nil, err
}
err = doltSess.SetSessionVariable(tmpSqlCtx, sql.AutoCommitSessionVar, autocommit)
if err != nil {
return nil, nil, nil, err
return nil, err
}
ir := sql.NewIndexRegistry()
vr := sql.NewViewRegistry()
sqlCtx := sql.NewContext(
ctx,
sql.WithIndexRegistry(ir),
sql.WithViewRegistry(vr),
sql.WithSession(doltSess),
sql.WithTracer(tracing.Tracer(ctx)))
dbs := dsqle.DbsAsDSQLDBs(sqlEngine.Analyzer.Catalog.AllDatabases())
for _, db := range dbs {
root, err := db.GetRoot(sqlCtx)
if err != nil {
cli.PrintErrln(err)
return nil, nil, nil, err
}
err = dsqle.RegisterSchemaFragments(sqlCtx, db, root)
if err != nil {
cli.PrintErr(err)
return nil, nil, nil, err
}
db.DbData().Ddb.SetCommitHookLogger(ctx, doltSess.GetLogger().Logger.Out)
}
return doltSess, ir, vr, nil
return doltSess, nil
}
}
+29 -96
View File
@@ -872,35 +872,36 @@ func (db Database) Flush(ctx *sql.Context) error {
return nil
}
// GetView implements sql.ViewDatabase
func (db Database) GetView(ctx *sql.Context, viewName string) (string, bool, error) {
stbl, found, err := db.GetTableInsensitive(ctx, doltdb.SchemasTableName)
root, err := db.GetRoot(ctx)
if err != nil {
return "", false, err
}
if !found {
return "", false, nil
}
tbl := stbl.(*WritableDoltTable)
row, exists, err := fragFromSchemasTable(ctx, tbl, "view", viewName)
tbl, ok, err := root.GetTable(ctx, doltdb.SchemasTableName)
if err != nil {
return "", false, err
}
if !exists {
if !ok {
return "", false, nil
}
if len(row) < 4 {
return "", false, errDoltSchemasTableFormat
fragments, err := getSchemaFragmentsOfType(ctx, tbl, "view")
if err != nil {
return "", false, err
}
if def, ok := row[2].(string); ok {
return def, true, nil
} else {
return "", false, errDoltSchemasTableFormat
for _, fragment := range fragments {
if strings.ToLower(fragment.name) == strings.ToLower(viewName) {
return fragment.fragment, true, nil
}
}
return "", false, nil
}
// GetView implements sql.ViewDatabase
func (db Database) AllViews(ctx *sql.Context) ([]sql.ViewDefinition, error) {
root, err := db.GetRoot(ctx)
if err != nil {
@@ -915,53 +916,18 @@ func (db Database) AllViews(ctx *sql.Context) ([]sql.ViewDefinition, error) {
return nil, nil
}
sch, err := tbl.GetSchema(ctx)
frags, err := getSchemaFragmentsOfType(ctx, tbl, "view")
if err != nil {
return nil, err
}
typeCol, ok := sch.GetAllCols().GetByName(doltdb.SchemasTablesTypeCol)
if !ok {
return nil, errDoltSchemasTableFormat
}
nameCol, ok := sch.GetAllCols().GetByName(doltdb.SchemasTablesNameCol)
if !ok {
return nil, errDoltSchemasTableFormat
}
fragCol, ok := sch.GetAllCols().GetByName(doltdb.SchemasTablesFragmentCol)
if !ok {
return nil, errDoltSchemasTableFormat
}
rowData, err := tbl.GetRowData(ctx)
if err != nil {
return nil, err
}
var views []sql.ViewDefinition
err = rowData.Iter(ctx, func(key, val types.Value) (stop bool, err error) {
dRow, err := row.FromNoms(sch, key.(types.Tuple), val.(types.Tuple))
if err != nil {
return true, err
}
if typeColVal, ok := dRow.GetColVal(typeCol.Tag); ok && typeColVal.Equals(types.String("view")) {
name, ok := dRow.GetColVal(nameCol.Tag)
if !ok {
taggedVals, _ := dRow.TaggedValues()
return true, fmt.Errorf("missing `%s` value for view row: (%s)", doltdb.SchemasTablesNameCol, taggedVals)
}
def, ok := dRow.GetColVal(fragCol.Tag)
if !ok {
taggedVals, _ := dRow.TaggedValues()
return true, fmt.Errorf("missing `%s` value for view row: (%s)", doltdb.SchemasTablesFragmentCol, taggedVals)
}
views = append(views, sql.ViewDefinition{
Name: string(name.(types.String)),
TextDefinition: string(def.(types.String)),
})
}
return false, nil
})
for _, frag := range frags {
views = append(views, sql.ViewDefinition{
Name: frag.name,
TextDefinition: frag.fragment,
})
}
if err != nil {
return nil, err
}
@@ -998,55 +964,22 @@ func (db Database) GetTriggers(ctx *sql.Context) ([]sql.TriggerDefinition, error
return nil, nil
}
sch, err := tbl.GetSchema(ctx)
frags, err := getSchemaFragmentsOfType(ctx, tbl, "view")
if err != nil {
return nil, err
}
typeCol, ok := sch.GetAllCols().GetByName(doltdb.SchemasTablesTypeCol)
if !ok {
return nil, errDoltSchemasTableFormat
}
nameCol, ok := sch.GetAllCols().GetByName(doltdb.SchemasTablesNameCol)
if !ok {
return nil, errDoltSchemasTableFormat
}
fragCol, ok := sch.GetAllCols().GetByName(doltdb.SchemasTablesFragmentCol)
if !ok {
return nil, errDoltSchemasTableFormat
}
rowData, err := tbl.GetRowData(ctx)
if err != nil {
return nil, err
}
var triggers []sql.TriggerDefinition
err = rowData.Iter(ctx, func(key, val types.Value) (stop bool, err error) {
dRow, err := row.FromNoms(sch, key.(types.Tuple), val.(types.Tuple))
if err != nil {
return true, err
}
if typeColVal, ok := dRow.GetColVal(typeCol.Tag); ok && typeColVal.Equals(types.String("trigger")) {
name, ok := dRow.GetColVal(nameCol.Tag)
if !ok {
taggedVals, _ := dRow.TaggedValues()
return true, fmt.Errorf("missing `%s` value for trigger row: (%s)", doltdb.SchemasTablesNameCol, taggedVals)
}
createStmt, ok := dRow.GetColVal(fragCol.Tag)
if !ok {
taggedVals, _ := dRow.TaggedValues()
return true, fmt.Errorf("missing `%s` value for trigger row: (%s)", doltdb.SchemasTablesFragmentCol, taggedVals)
}
triggers = append(triggers, sql.TriggerDefinition{
Name: string(name.(types.String)),
CreateStatement: string(createStmt.(types.String)),
})
}
return false, nil
})
for _, frag := range frags {
triggers = append(triggers, sql.TriggerDefinition{
Name: frag.name,
CreateStatement: frag.fragment,
})
}
if err != nil {
return nil, err
}
return triggers, nil
}
@@ -50,10 +50,7 @@ const (
type DoltHarness struct {
Version string
engine *sqle.Engine
sess *dsess.Session
idxReg *sql.IndexRegistry
viewReg *sql.ViewRegistry
}
func (h *DoltHarness) EngineStr() string {
@@ -69,8 +66,6 @@ func (h *DoltHarness) ExecuteStatement(statement string) error {
ctx := sql.NewContext(
context.Background(),
sql.WithPid(rand.Uint64()),
sql.WithIndexRegistry(h.idxReg),
sql.WithViewRegistry(h.viewReg),
sql.WithSession(h.sess))
statement = normalizeStatement(statement)
@@ -88,8 +83,6 @@ func (h *DoltHarness) ExecuteQuery(statement string) (schema string, results []s
ctx := sql.NewContext(
context.Background(),
sql.WithPid(uint64(pid)),
sql.WithIndexRegistry(h.idxReg),
sql.WithViewRegistry(h.viewReg),
sql.WithSession(h.sess))
var sch sql.Schema
@@ -141,13 +134,9 @@ func innerInit(h *DoltHarness, dEnv *env.DoltEnv) error {
}
h.sess = dsess.DefaultSession()
h.idxReg = sql.NewIndexRegistry()
h.viewReg = sql.NewViewRegistry()
ctx := sql.NewContext(
context.Background(),
sql.WithIndexRegistry(h.idxReg),
sql.WithViewRegistry(h.viewReg),
sql.WithSession(h.sess))
dbs := h.engine.Analyzer.Catalog.AllDatabases()
@@ -158,7 +147,6 @@ func innerInit(h *DoltHarness, dEnv *env.DoltEnv) error {
sess := dsess.DSessFromSess(ctx.Session)
err := sess.AddDB(ctx, getDbState(db, dEnv))
if err != nil {
return err
}
@@ -169,13 +157,6 @@ func innerInit(h *DoltHarness, dEnv *env.DoltEnv) error {
}
err = dsqlDB.SetRoot(ctx, root)
if err != nil {
return err
}
err = dsql.RegisterSchemaFragments(ctx, dsqlDB, root)
if err != nil {
return err
}
+62 -1
View File
@@ -29,6 +29,7 @@ import (
)
var errDoltSchemasTableFormat = fmt.Errorf("`%s` schema in unexpected format", doltdb.SchemasTableName)
var noSchemaIndexDefined = fmt.Errorf("could not find index `%s` on system table `%s`", doltdb.SchemasTablesIndexName, doltdb.SchemasTableName)
// The fixed dolt schema for the `dolt_schemas` table.
func SchemasTableSchema() schema.Schema {
@@ -206,7 +207,7 @@ func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType str
}
}
if fragNameIndex == nil {
return nil, false, fmt.Errorf("could not find index `%s` on system table `%s`", doltdb.SchemasTablesIndexName, doltdb.SchemasTableName)
return nil, false, noSchemaIndexDefined
}
indexLookup, err := fragNameIndex.Get(fragType, name)
@@ -228,3 +229,63 @@ func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType str
return nil, false, err
}
}
type schemaFragment struct {
name string
fragment string
}
func getSchemaFragmentsOfType(ctx *sql.Context, tbl *doltdb.Table, fragmentType string) ([]schemaFragment, error) {
sch, err := tbl.GetSchema(ctx)
if err != nil {
return nil, err
}
typeCol, ok := sch.GetAllCols().GetByName(doltdb.SchemasTablesTypeCol)
if !ok {
return nil, errDoltSchemasTableFormat
}
nameCol, ok := sch.GetAllCols().GetByName(doltdb.SchemasTablesNameCol)
if !ok {
return nil, errDoltSchemasTableFormat
}
fragCol, ok := sch.GetAllCols().GetByName(doltdb.SchemasTablesFragmentCol)
if !ok {
return nil, errDoltSchemasTableFormat
}
rowData, err := tbl.GetRowData(ctx)
if err != nil {
return nil, err
}
var fragments []schemaFragment
err = rowData.Iter(ctx, func(key, val types.Value) (stop bool, err error) {
dRow, err := row.FromNoms(sch, key.(types.Tuple), val.(types.Tuple))
if err != nil {
return true, err
}
if typeColVal, ok := dRow.GetColVal(typeCol.Tag); ok && typeColVal.Equals(types.String(fragmentType)) {
name, ok := dRow.GetColVal(nameCol.Tag)
if !ok {
taggedVals, _ := dRow.TaggedValues()
return true, fmt.Errorf("missing `%s` value for view row: (%s)", doltdb.SchemasTablesNameCol, taggedVals)
}
def, ok := dRow.GetColVal(fragCol.Tag)
if !ok {
taggedVals, _ := dRow.TaggedValues()
return true, fmt.Errorf("missing `%s` value for view row: (%s)", doltdb.SchemasTablesFragmentCol, taggedVals)
}
fragments = append(fragments, schemaFragment{
name: string(name.(types.String)),
fragment: string(def.(types.String)),
})
}
return false, nil
})
if err != nil {
return nil, err
}
return fragments, nil
}