mirror of
https://github.com/dolthub/dolt.git
synced 2026-03-09 11:19:01 -05:00
refactor queryist utils
This commit is contained in:
@@ -16,28 +16,20 @@ package sqlserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
sql2 "database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/go-mysql-server/sql/types"
|
||||
"github.com/fatih/color"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/gocraft/dbr/v2"
|
||||
"github.com/gocraft/dbr/v2/dialect"
|
||||
|
||||
"github.com/dolthub/dolt/go/cmd/dolt/cli"
|
||||
"github.com/dolthub/dolt/go/cmd/dolt/commands"
|
||||
eventsapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi/v1alpha1"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/argparser"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/svcs"
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -548,127 +540,3 @@ func getYAMLServerConfig(fs filesys.Filesys, path string) (ServerConfig, error)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// BuildConnectionStringQueryist returns a Queryist that connects to the server specified by the given server config. Presence in this
|
||||
// module isn't ideal, but it's the only way to get the server config into the queryist.
|
||||
func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, creds *cli.UserPassword, apr *argparser.ArgParseResults, host string, port int, useTLS bool, dbRev string) (cli.LateBindQueryist, error) {
|
||||
clientConfig, err := GetClientConfig(cwdFS, creds, apr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// ParseDSN currently doesn't support `/` in the db name
|
||||
dbName, _ := dsess.SplitRevisionDbName(dbRev)
|
||||
parsedMySQLConfig, err := mysql.ParseDSN(ConnectionString(clientConfig, dbName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedMySQLConfig.DBName = dbRev
|
||||
parsedMySQLConfig.Addr = fmt.Sprintf("%s:%d", host, port)
|
||||
|
||||
if useTLS {
|
||||
parsedMySQLConfig.TLSConfig = "true"
|
||||
}
|
||||
|
||||
mysqlConnector, err := mysql.NewConnector(parsedMySQLConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := &dbr.Connection{DB: sql2.OpenDB(mysqlConnector), EventReceiver: nil, Dialect: dialect.MySQL}
|
||||
|
||||
queryist := ConnectionQueryist{connection: conn}
|
||||
|
||||
var lateBind cli.LateBindQueryist = func(ctx context.Context) (cli.Queryist, *sql.Context, func(), error) {
|
||||
sqlCtx := sql.NewContext(ctx)
|
||||
sqlCtx.SetCurrentDatabase(dbRev)
|
||||
return queryist, sqlCtx, func() { conn.Conn(ctx) }, nil
|
||||
}
|
||||
|
||||
return lateBind, nil
|
||||
}
|
||||
|
||||
// ConnectionQueryist executes queries by connecting to a running mySql server.
|
||||
type ConnectionQueryist struct {
|
||||
connection *dbr.Connection
|
||||
}
|
||||
|
||||
var _ cli.Queryist = ConnectionQueryist{}
|
||||
|
||||
func (c ConnectionQueryist) Query(ctx *sql.Context, query string) (sql.Schema, sql.RowIter, error) {
|
||||
rows, err := c.connection.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
rowIter, err := NewMysqlRowWrapper(rows)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return rowIter.Schema(), rowIter, nil
|
||||
}
|
||||
|
||||
type MysqlRowWrapper struct {
|
||||
rows *sql2.Rows
|
||||
schema sql.Schema
|
||||
finished bool
|
||||
vRow []*string
|
||||
iRow []interface{}
|
||||
}
|
||||
|
||||
var _ sql.RowIter = (*MysqlRowWrapper)(nil)
|
||||
|
||||
func NewMysqlRowWrapper(rows *sql2.Rows) (*MysqlRowWrapper, error) {
|
||||
colNames, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
schema := make(sql.Schema, len(colNames))
|
||||
vRow := make([]*string, len(colNames))
|
||||
iRow := make([]interface{}, len(colNames))
|
||||
for i, colName := range colNames {
|
||||
schema[i] = &sql.Column{
|
||||
Name: colName,
|
||||
Type: types.LongText,
|
||||
Nullable: true,
|
||||
}
|
||||
iRow[i] = &vRow[i]
|
||||
}
|
||||
return &MysqlRowWrapper{
|
||||
rows: rows,
|
||||
schema: schema,
|
||||
finished: !rows.Next(),
|
||||
vRow: vRow,
|
||||
iRow: iRow,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *MysqlRowWrapper) Schema() sql.Schema {
|
||||
return s.schema
|
||||
}
|
||||
|
||||
func (s *MysqlRowWrapper) Next(*sql.Context) (sql.Row, error) {
|
||||
if s.finished {
|
||||
return nil, io.EOF
|
||||
}
|
||||
err := s.rows.Scan(s.iRow...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlRow := make(sql.Row, len(s.vRow))
|
||||
for i, val := range s.vRow {
|
||||
if val != nil {
|
||||
sqlRow[i] = *val
|
||||
}
|
||||
}
|
||||
s.finished = !s.rows.Next()
|
||||
return sqlRow, nil
|
||||
}
|
||||
|
||||
func (s *MysqlRowWrapper) HasMoreRows() bool {
|
||||
return !s.finished
|
||||
}
|
||||
|
||||
func (s *MysqlRowWrapper) Close(*sql.Context) error {
|
||||
return s.rows.Close()
|
||||
}
|
||||
|
||||
@@ -654,7 +654,7 @@ If you're interested in running this command against a remote host, hit us up on
|
||||
port = 3306
|
||||
}
|
||||
useTLS := !apr.Contains(cli.NoTLSFlag)
|
||||
return sqlserver.BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, host, port, useTLS, useDb)
|
||||
return BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, host, port, useTLS, useDb)
|
||||
} else {
|
||||
_, hasPort := apr.GetInt(cli.PortFlag)
|
||||
if hasPort {
|
||||
@@ -708,7 +708,7 @@ If you're interested in running this command against a remote host, hit us up on
|
||||
if !creds.Specified {
|
||||
creds = &cli.UserPassword{Username: sqlserver.LocalConnectionUser, Password: lock.Secret, Specified: false}
|
||||
}
|
||||
return sqlserver.BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, "localhost", lock.Port, false, useDb)
|
||||
return BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, "localhost", lock.Port, false, useDb)
|
||||
}
|
||||
|
||||
if verbose {
|
||||
|
||||
157
go/cmd/dolt/queryist_utils.go
Normal file
157
go/cmd/dolt/queryist_utils.go
Normal file
@@ -0,0 +1,157 @@
|
||||
// Copyright 2023 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
sql2 "database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/dolthub/dolt/go/cmd/dolt/cli"
|
||||
"github.com/dolthub/dolt/go/cmd/dolt/commands/sqlserver"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/argparser"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/go-mysql-server/sql/types"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/gocraft/dbr/v2"
|
||||
"github.com/gocraft/dbr/v2/dialect"
|
||||
)
|
||||
|
||||
// BuildConnectionStringQueryist returns a Queryist that connects to the server specified by the given server config. Presence in this
|
||||
// module isn't ideal, but it's the only way to get the server config into the queryist.
|
||||
func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, creds *cli.UserPassword, apr *argparser.ArgParseResults, host string, port int, useTLS bool, dbRev string) (cli.LateBindQueryist, error) {
|
||||
clientConfig, err := sqlserver.GetClientConfig(cwdFS, creds, apr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// ParseDSN currently doesn't support `/` in the db name
|
||||
dbName, _ := dsess.SplitRevisionDbName(dbRev)
|
||||
parsedMySQLConfig, err := mysql.ParseDSN(sqlserver.ConnectionString(clientConfig, dbName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedMySQLConfig.DBName = dbRev
|
||||
parsedMySQLConfig.Addr = fmt.Sprintf("%s:%d", host, port)
|
||||
|
||||
if useTLS {
|
||||
parsedMySQLConfig.TLSConfig = "true"
|
||||
}
|
||||
|
||||
mysqlConnector, err := mysql.NewConnector(parsedMySQLConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := &dbr.Connection{DB: sql2.OpenDB(mysqlConnector), EventReceiver: nil, Dialect: dialect.MySQL}
|
||||
|
||||
queryist := ConnectionQueryist{connection: conn}
|
||||
|
||||
var lateBind cli.LateBindQueryist = func(ctx context.Context) (cli.Queryist, *sql.Context, func(), error) {
|
||||
sqlCtx := sql.NewContext(ctx)
|
||||
sqlCtx.SetCurrentDatabase(dbRev)
|
||||
return queryist, sqlCtx, func() { conn.Conn(ctx) }, nil
|
||||
}
|
||||
|
||||
return lateBind, nil
|
||||
}
|
||||
|
||||
// ConnectionQueryist executes queries by connecting to a running mySql server.
|
||||
type ConnectionQueryist struct {
|
||||
connection *dbr.Connection
|
||||
}
|
||||
|
||||
var _ cli.Queryist = ConnectionQueryist{}
|
||||
|
||||
func (c ConnectionQueryist) Query(ctx *sql.Context, query string) (sql.Schema, sql.RowIter, error) {
|
||||
rows, err := c.connection.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
rowIter, err := NewMysqlRowWrapper(rows)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return rowIter.Schema(), rowIter, nil
|
||||
}
|
||||
|
||||
type MysqlRowWrapper struct {
|
||||
rows *sql2.Rows
|
||||
schema sql.Schema
|
||||
finished bool
|
||||
vRow []*string
|
||||
iRow []interface{}
|
||||
}
|
||||
|
||||
var _ sql.RowIter = (*MysqlRowWrapper)(nil)
|
||||
|
||||
func NewMysqlRowWrapper(rows *sql2.Rows) (*MysqlRowWrapper, error) {
|
||||
colNames, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
schema := make(sql.Schema, len(colNames))
|
||||
vRow := make([]*string, len(colNames))
|
||||
iRow := make([]interface{}, len(colNames))
|
||||
for i, colName := range colNames {
|
||||
schema[i] = &sql.Column{
|
||||
Name: colName,
|
||||
Type: types.LongText,
|
||||
Nullable: true,
|
||||
}
|
||||
iRow[i] = &vRow[i]
|
||||
}
|
||||
return &MysqlRowWrapper{
|
||||
rows: rows,
|
||||
schema: schema,
|
||||
finished: !rows.Next(),
|
||||
vRow: vRow,
|
||||
iRow: iRow,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *MysqlRowWrapper) Schema() sql.Schema {
|
||||
return s.schema
|
||||
}
|
||||
|
||||
func (s *MysqlRowWrapper) Next(*sql.Context) (sql.Row, error) {
|
||||
if s.finished {
|
||||
return nil, io.EOF
|
||||
}
|
||||
err := s.rows.Scan(s.iRow...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlRow := make(sql.Row, len(s.vRow))
|
||||
for i, val := range s.vRow {
|
||||
if val != nil {
|
||||
sqlRow[i] = *val
|
||||
}
|
||||
}
|
||||
s.finished = !s.rows.Next()
|
||||
return sqlRow, nil
|
||||
}
|
||||
|
||||
func (s *MysqlRowWrapper) HasMoreRows() bool {
|
||||
return !s.finished
|
||||
}
|
||||
|
||||
func (s *MysqlRowWrapper) Close(*sql.Context) error {
|
||||
return s.rows.Close()
|
||||
}
|
||||
Reference in New Issue
Block a user