First pass at making filter branch accept multiple queries and ignore errors

This commit is contained in:
Zach Musgrave
2023-11-01 17:37:03 -07:00
parent 9385688ea5
commit 32a7d190f2

View File

@@ -24,7 +24,6 @@ import (
sqle "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"
"github.com/dolthub/vitess/go/vt/sqlparser"
"github.com/fatih/color"
"gopkg.in/src-d/go-errors.v1"
@@ -49,7 +48,7 @@ const (
var filterBranchDocs = cli.CommandDocumentationContent{
ShortDesc: "Edits the commit history using the provided query",
LongDesc: `Traverses the commit history to the initial commit starting at the current HEAD commit. Replays all commits, rewriting the history using the provided SQL query.
LongDesc: `Traverses the commit history to the initial commit starting at the current HEAD commit. Replays all commits, rewriting the history using the provided SQL queries. Separate multiple queries with semicolons. Use the DELIMITER syntax to define stored procedures, triggers, etc.
If a {{.LessThan}}commit-spec{{.GreaterThan}} is provided, the traversal will stop when the commit is reached and rewriting will begin at that commit, or will error if the commit is not found.
@@ -59,7 +58,7 @@ If the {{.EmphasisLeft}}--all{{.EmphasisRight}} flag is supplied, filter-branch
`,
Synopsis: []string{
"[--all] {{.LessThan}}query{{.GreaterThan}} [{{.LessThan}}commit{{.GreaterThan}}]",
"[--all] {{.LessThan}}queries{{.GreaterThan}} [{{.LessThan}}commit{{.GreaterThan}}]",
},
}
@@ -87,6 +86,7 @@ func (cmd FilterBranchCmd) ArgParser() *argparser.ArgParser {
ap.SupportsFlag(cli.VerboseFlag, "v", "logs more information")
ap.SupportsFlag(branchesFlag, "b", "filter all branches")
ap.SupportsFlag(cli.AllFlag, "a", "filter all branches and tags")
ap.SupportsFlag(continueFlag, "c", "log a warning and continue if any errors occur executing statements")
return ap
}
@@ -113,18 +113,24 @@ func (cmd FilterBranchCmd) Exec(ctx context.Context, commandStr string, args []s
query := apr.Arg(0)
verbose := apr.Contains(cli.VerboseFlag)
continueOnErr := apr.Contains(continueFlag)
notFound := make(missingTbls)
replay := func(ctx context.Context, commit, _, _ *doltdb.Commit) (*doltdb.RootValue, error) {
var cmHash, before hash.Hash
var root *doltdb.RootValue
if verbose {
var err error
cmHash, err = commit.HashOf()
if err != nil {
return nil, err
}
cli.Printf("processing commit %s\n", cmHash.String())
root, err := commit.GetRootValue(ctx)
if verbose {
cli.Printf("processing commit %s\n", cmHash.String())
}
root, err = commit.GetRootValue(ctx)
if err != nil {
return nil, err
}
@@ -134,22 +140,30 @@ func (cmd FilterBranchCmd) Exec(ctx context.Context, commandStr string, args []s
}
}
root, err := processFilterQuery(ctx, dEnv, commit, query, notFound)
updatedRoot, err := processFilterQuery(ctx, dEnv, commit, query, notFound)
if err != nil {
return nil, err
if continueOnErr {
cli.PrintErrln("error encountered processing commit %s (continuing): %s", cmHash.String(), err.Error())
return root, nil
} else {
return nil, err
}
}
if verbose {
after, err := root.HashOf()
after, err := updatedRoot.HashOf()
if err != nil {
return nil, err
}
if before != after {
cli.Printf("updated commit %s (root: %s -> %s)\n",
cmHash.String(), before.String(), after.String())
} else {
cli.Printf("no changes to commit %s", cmHash.String())
}
}
return root, nil
return updatedRoot, nil
}
nerf, err := getNerf(ctx, dEnv, apr)
@@ -215,51 +229,38 @@ func processFilterQuery(ctx context.Context, dEnv *env.DoltEnv, cm *doltdb.Commi
return nil, err
}
sqlStatement, err := sqlparser.Parse(query)
scanner := NewSqlStatementScanner(strings.NewReader(query))
if err != nil {
return nil, err
}
for scanner.Scan() {
q := scanner.Text()
_, itr, err := eng.Query(sqlCtx, q)
itr := sql.RowsToRowIter() // empty RowIter
switch sqlStatement.(type) {
case *sqlparser.Insert, *sqlparser.Update:
_, itr, err = eng.Query(sqlCtx, query)
err, ok := captureTblNotFoundErr(err, mt, rh)
if ok {
// table doesn't exist, save the error and continue
return root, nil
}
if err != nil {
return nil, err
}
case *sqlparser.Delete:
_, itr, err = eng.Query(sqlCtx, query)
case *sqlparser.AlterTable:
_, itr, err = eng.Query(sqlCtx, query)
case *sqlparser.DDL:
_, itr, err = eng.Query(sqlCtx, query)
case *sqlparser.Select, *sqlparser.OtherRead, *sqlparser.Show, *sqlparser.Explain, *sqlparser.SetOp:
return nil, fmt.Errorf("filter-branch queries must be write queries: '%s'", query)
default:
return nil, fmt.Errorf("SQL statement not supported for filter-branch: '%s'", query)
}
err, ok := captureTblNotFoundErr(err, mt, rh)
if ok {
// table doesn't exist, save the error and continue
return root, nil
}
if err != nil {
return nil, err
}
for {
_, err = itr.Next(sqlCtx)
if err == io.EOF {
break
} else if err != nil {
for {
_, err = itr.Next(sqlCtx)
if err == io.EOF {
break
} else if err != nil {
return nil, err
}
}
err = itr.Close(sqlCtx)
if err != nil {
return nil, err
}
}
err = itr.Close(sqlCtx)
if err != nil {
return nil, err
}
sess := dsess.DSessFromSess(sqlCtx.Session)
ws, err := sess.WorkingSet(sqlCtx, filterDbName)
if err != nil {