From bd045ae65809af34887e2d13ecb14f3696513d77 Mon Sep 17 00:00:00 2001 From: Stephanie You Date: Thu, 30 Nov 2023 15:32:43 -0800 Subject: [PATCH] refactor queryist utils --- go/cmd/dolt/commands/sqlserver/sqlserver.go | 136 +---------------- go/cmd/dolt/dolt.go | 4 +- go/cmd/dolt/queryist_utils.go | 157 ++++++++++++++++++++ 3 files changed, 161 insertions(+), 136 deletions(-) create mode 100644 go/cmd/dolt/queryist_utils.go diff --git a/go/cmd/dolt/commands/sqlserver/sqlserver.go b/go/cmd/dolt/commands/sqlserver/sqlserver.go index 44186e815f..55382499cb 100644 --- a/go/cmd/dolt/commands/sqlserver/sqlserver.go +++ b/go/cmd/dolt/commands/sqlserver/sqlserver.go @@ -16,28 +16,20 @@ package sqlserver import ( "context" - sql2 "database/sql" "fmt" - "io" "path/filepath" "strconv" "strings" - "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/types" - "github.com/fatih/color" - "github.com/go-sql-driver/mysql" - "github.com/gocraft/dbr/v2" - "github.com/gocraft/dbr/v2/dialect" - "github.com/dolthub/dolt/go/cmd/dolt/cli" "github.com/dolthub/dolt/go/cmd/dolt/commands" eventsapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi/v1alpha1" "github.com/dolthub/dolt/go/libraries/doltcore/env" - "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" "github.com/dolthub/dolt/go/libraries/utils/argparser" "github.com/dolthub/dolt/go/libraries/utils/filesys" "github.com/dolthub/dolt/go/libraries/utils/svcs" + "github.com/dolthub/go-mysql-server/sql" + "github.com/fatih/color" ) const ( @@ -548,127 +540,3 @@ func getYAMLServerConfig(fs filesys.Filesys, path string) (ServerConfig, error) return cfg, nil } - -// BuildConnectionStringQueryist returns a Queryist that connects to the server specified by the given server config. Presence in this -// module isn't ideal, but it's the only way to get the server config into the queryist. -func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, creds *cli.UserPassword, apr *argparser.ArgParseResults, host string, port int, useTLS bool, dbRev string) (cli.LateBindQueryist, error) { - clientConfig, err := GetClientConfig(cwdFS, creds, apr) - if err != nil { - return nil, err - } - - // ParseDSN currently doesn't support `/` in the db name - dbName, _ := dsess.SplitRevisionDbName(dbRev) - parsedMySQLConfig, err := mysql.ParseDSN(ConnectionString(clientConfig, dbName)) - if err != nil { - return nil, err - } - - parsedMySQLConfig.DBName = dbRev - parsedMySQLConfig.Addr = fmt.Sprintf("%s:%d", host, port) - - if useTLS { - parsedMySQLConfig.TLSConfig = "true" - } - - mysqlConnector, err := mysql.NewConnector(parsedMySQLConfig) - if err != nil { - return nil, err - } - - conn := &dbr.Connection{DB: sql2.OpenDB(mysqlConnector), EventReceiver: nil, Dialect: dialect.MySQL} - - queryist := ConnectionQueryist{connection: conn} - - var lateBind cli.LateBindQueryist = func(ctx context.Context) (cli.Queryist, *sql.Context, func(), error) { - sqlCtx := sql.NewContext(ctx) - sqlCtx.SetCurrentDatabase(dbRev) - return queryist, sqlCtx, func() { conn.Conn(ctx) }, nil - } - - return lateBind, nil -} - -// ConnectionQueryist executes queries by connecting to a running mySql server. -type ConnectionQueryist struct { - connection *dbr.Connection -} - -var _ cli.Queryist = ConnectionQueryist{} - -func (c ConnectionQueryist) Query(ctx *sql.Context, query string) (sql.Schema, sql.RowIter, error) { - rows, err := c.connection.QueryContext(ctx, query) - if err != nil { - return nil, nil, err - } - rowIter, err := NewMysqlRowWrapper(rows) - if err != nil { - return nil, nil, err - } - return rowIter.Schema(), rowIter, nil -} - -type MysqlRowWrapper struct { - rows *sql2.Rows - schema sql.Schema - finished bool - vRow []*string - iRow []interface{} -} - -var _ sql.RowIter = (*MysqlRowWrapper)(nil) - -func NewMysqlRowWrapper(rows *sql2.Rows) (*MysqlRowWrapper, error) { - colNames, err := rows.Columns() - if err != nil { - return nil, err - } - schema := make(sql.Schema, len(colNames)) - vRow := make([]*string, len(colNames)) - iRow := make([]interface{}, len(colNames)) - for i, colName := range colNames { - schema[i] = &sql.Column{ - Name: colName, - Type: types.LongText, - Nullable: true, - } - iRow[i] = &vRow[i] - } - return &MysqlRowWrapper{ - rows: rows, - schema: schema, - finished: !rows.Next(), - vRow: vRow, - iRow: iRow, - }, nil -} - -func (s *MysqlRowWrapper) Schema() sql.Schema { - return s.schema -} - -func (s *MysqlRowWrapper) Next(*sql.Context) (sql.Row, error) { - if s.finished { - return nil, io.EOF - } - err := s.rows.Scan(s.iRow...) - if err != nil { - return nil, err - } - sqlRow := make(sql.Row, len(s.vRow)) - for i, val := range s.vRow { - if val != nil { - sqlRow[i] = *val - } - } - s.finished = !s.rows.Next() - return sqlRow, nil -} - -func (s *MysqlRowWrapper) HasMoreRows() bool { - return !s.finished -} - -func (s *MysqlRowWrapper) Close(*sql.Context) error { - return s.rows.Close() -} diff --git a/go/cmd/dolt/dolt.go b/go/cmd/dolt/dolt.go index a170747660..1e2655f053 100644 --- a/go/cmd/dolt/dolt.go +++ b/go/cmd/dolt/dolt.go @@ -654,7 +654,7 @@ If you're interested in running this command against a remote host, hit us up on port = 3306 } useTLS := !apr.Contains(cli.NoTLSFlag) - return sqlserver.BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, host, port, useTLS, useDb) + return BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, host, port, useTLS, useDb) } else { _, hasPort := apr.GetInt(cli.PortFlag) if hasPort { @@ -708,7 +708,7 @@ If you're interested in running this command against a remote host, hit us up on if !creds.Specified { creds = &cli.UserPassword{Username: sqlserver.LocalConnectionUser, Password: lock.Secret, Specified: false} } - return sqlserver.BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, "localhost", lock.Port, false, useDb) + return BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, "localhost", lock.Port, false, useDb) } if verbose { diff --git a/go/cmd/dolt/queryist_utils.go b/go/cmd/dolt/queryist_utils.go new file mode 100644 index 0000000000..66d23c39c5 --- /dev/null +++ b/go/cmd/dolt/queryist_utils.go @@ -0,0 +1,157 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + sql2 "database/sql" + "fmt" + "io" + + "github.com/dolthub/dolt/go/cmd/dolt/cli" + "github.com/dolthub/dolt/go/cmd/dolt/commands/sqlserver" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" + "github.com/dolthub/dolt/go/libraries/utils/argparser" + "github.com/dolthub/dolt/go/libraries/utils/filesys" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" + "github.com/go-sql-driver/mysql" + "github.com/gocraft/dbr/v2" + "github.com/gocraft/dbr/v2/dialect" +) + +// BuildConnectionStringQueryist returns a Queryist that connects to the server specified by the given server config. Presence in this +// module isn't ideal, but it's the only way to get the server config into the queryist. +func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, creds *cli.UserPassword, apr *argparser.ArgParseResults, host string, port int, useTLS bool, dbRev string) (cli.LateBindQueryist, error) { + clientConfig, err := sqlserver.GetClientConfig(cwdFS, creds, apr) + if err != nil { + return nil, err + } + + // ParseDSN currently doesn't support `/` in the db name + dbName, _ := dsess.SplitRevisionDbName(dbRev) + parsedMySQLConfig, err := mysql.ParseDSN(sqlserver.ConnectionString(clientConfig, dbName)) + if err != nil { + return nil, err + } + + parsedMySQLConfig.DBName = dbRev + parsedMySQLConfig.Addr = fmt.Sprintf("%s:%d", host, port) + + if useTLS { + parsedMySQLConfig.TLSConfig = "true" + } + + mysqlConnector, err := mysql.NewConnector(parsedMySQLConfig) + if err != nil { + return nil, err + } + + conn := &dbr.Connection{DB: sql2.OpenDB(mysqlConnector), EventReceiver: nil, Dialect: dialect.MySQL} + + queryist := ConnectionQueryist{connection: conn} + + var lateBind cli.LateBindQueryist = func(ctx context.Context) (cli.Queryist, *sql.Context, func(), error) { + sqlCtx := sql.NewContext(ctx) + sqlCtx.SetCurrentDatabase(dbRev) + return queryist, sqlCtx, func() { conn.Conn(ctx) }, nil + } + + return lateBind, nil +} + +// ConnectionQueryist executes queries by connecting to a running mySql server. +type ConnectionQueryist struct { + connection *dbr.Connection +} + +var _ cli.Queryist = ConnectionQueryist{} + +func (c ConnectionQueryist) Query(ctx *sql.Context, query string) (sql.Schema, sql.RowIter, error) { + rows, err := c.connection.QueryContext(ctx, query) + if err != nil { + return nil, nil, err + } + rowIter, err := NewMysqlRowWrapper(rows) + if err != nil { + return nil, nil, err + } + return rowIter.Schema(), rowIter, nil +} + +type MysqlRowWrapper struct { + rows *sql2.Rows + schema sql.Schema + finished bool + vRow []*string + iRow []interface{} +} + +var _ sql.RowIter = (*MysqlRowWrapper)(nil) + +func NewMysqlRowWrapper(rows *sql2.Rows) (*MysqlRowWrapper, error) { + colNames, err := rows.Columns() + if err != nil { + return nil, err + } + schema := make(sql.Schema, len(colNames)) + vRow := make([]*string, len(colNames)) + iRow := make([]interface{}, len(colNames)) + for i, colName := range colNames { + schema[i] = &sql.Column{ + Name: colName, + Type: types.LongText, + Nullable: true, + } + iRow[i] = &vRow[i] + } + return &MysqlRowWrapper{ + rows: rows, + schema: schema, + finished: !rows.Next(), + vRow: vRow, + iRow: iRow, + }, nil +} + +func (s *MysqlRowWrapper) Schema() sql.Schema { + return s.schema +} + +func (s *MysqlRowWrapper) Next(*sql.Context) (sql.Row, error) { + if s.finished { + return nil, io.EOF + } + err := s.rows.Scan(s.iRow...) + if err != nil { + return nil, err + } + sqlRow := make(sql.Row, len(s.vRow)) + for i, val := range s.vRow { + if val != nil { + sqlRow[i] = *val + } + } + s.finished = !s.rows.Next() + return sqlRow, nil +} + +func (s *MysqlRowWrapper) HasMoreRows() bool { + return !s.finished +} + +func (s *MysqlRowWrapper) Close(*sql.Context) error { + return s.rows.Close() +}