diff --git a/bats/sql-batch.bats b/bats/sql-batch.bats index 5b05e8e154..c9ffc0a150 100644 --- a/bats/sql-batch.bats +++ b/bats/sql-batch.bats @@ -80,7 +80,7 @@ SQL [[ "$output" =~ "poop" ]] || false } -@test "sql dolt_reset('hard') function" { +@test "sql reset('hard') function" { mkdir test && cd test && dolt init dolt sql < LICENSE.md + dolt add . + + run dolt sql -q "SELECT DOLT_RESET('--hard')" + [ $status -eq 0 ] + + run dolt status + [ "$status" -eq 0 ] + [[ "$output" =~ "Untracked files:" ]] || false + [[ "$output" =~ ([[:space:]]*new doc:[[:space:]]*LICENSE.md) ]] || false + + # Tracked file gets reset + dolt commit -a -m "Add a the license file" + echo ~edited-license~ > LICENSE.md + + dolt add . + + run dolt sql -q "SELECT DOLT_RESET('--hard')" + [ $status -eq 0 ] + + run dolt status + [ "$status" -eq 0 ] + [[ "$output" =~ "Changes not staged for commit:" ]] || false + [[ "$output" =~ ([[:space:]]*modified:[[:space:]]*LICENSE.md) ]] || false +} + +@test "DOLT_RESET --soft works on unstaged and staged table changes" { + dolt sql -q "INSERT INTO test VALUES (1)" + + # Table should still be unstaged + run dolt sql -q "SELECT DOLT_RESET('--soft')" + [ $status -eq 0 ] + + run dolt status + [ "$status" -eq 0 ] + [[ "$output" =~ "Changes not staged for commit:" ]] || false + [[ "$output" =~ ([[:space:]]*modified:[[:space:]]*test) ]] || false + + dolt add . + + run dolt sql -q "SELECT DOLT_RESET('--soft')" + [ $status -eq 0 ] + + run dolt status + [ "$status" -eq 0 ] + [[ "$output" =~ "Changes not staged for commit:" ]] || false + [[ "$output" =~ ([[:space:]]*modified:[[:space:]]*test) ]] || false +} + +@test "DOLT_RESET --soft ignores staged docs" { + echo ~license~ > LICENSE.md + dolt add . + + run dolt sql -q "SELECT DOLT_RESET('--soft')" + [ $status -eq 0 ] + + run dolt status + [ "$status" -eq 0 ] + [[ "$output" =~ "Changes to be committed:" ]] || false + [[ "$output" =~ ([[:space:]]*new doc:[[:space:]]*LICENSE.md) ]] || false + + # Explicitly defining the file ignores it. + run dolt sql -q "SELECT DOLT_RESET('LICENSE.md')" + [ "$status" -eq 1 ] + [[ "$output" =~ ("error: the table(s) LICENSE.md do not exist") ]] || false +} + +@test "DOLT_RESET works on specific tables" { + dolt sql -q "INSERT INTO test VALUES (1)" + + # Table should still be unstaged + run dolt sql -q "SELECT DOLT_RESET('test')" + + run dolt status + [ "$status" -eq 0 ] + [[ "$output" =~ "Changes not staged for commit:" ]] || false + [[ "$output" =~ ([[:space:]]*modified:[[:space:]]*test) ]] || false + + dolt sql -q "CREATE TABLE test2 (pk int primary key);" + + dolt add . + run dolt sql -q "SELECT DOLT_RESET('test', 'test2')" + + run dolt status + [ "$status" -eq 0 ] + [[ "$output" =~ "Changes not staged for commit:" ]] || false + [[ "$output" =~ ([[:space:]]*modified:[[:space:]]*test) ]] || false + [[ "$output" =~ ([[:space:]]*new table:[[:space:]]*test2) ]] || false +} + +@test "DOLT_RESET --soft and --hard on the same table" { + # Make a change to the table and do a soft reset + dolt sql -q "INSERT INTO test VALUES (1)" + + run dolt sql -q "SELECT DOLT_RESET('test')" + [ "$status" -eq 0 ] + + run dolt status + [ "$status" -eq 0 ] + [[ "$output" =~ "Changes not staged for commit:" ]] || false + [[ "$output" =~ ([[:space:]]*modified:[[:space:]]*test) ]] || false + + # Add and unstage the table with a soft reset. Make sure the same data exists. + dolt add . + + run dolt sql -q "SELECT DOLT_RESET('test')" + [ "$status" -eq 0 ] + + run dolt status + [ "$status" -eq 0 ] + [[ "$output" =~ "Changes not staged for commit:" ]] || false + [[ "$output" =~ ([[:space:]]*modified:[[:space:]]*test) ]] || false + + run dolt sql -r csv -q "select * from test" + [[ "$output" =~ pk ]] || false + [[ "$output" =~ 1 ]] || false + + # Do a hard reset and validate the insert was wiped properly + run dolt sql -q "SELECT DOLT_RESET('--hard')" + + run dolt status + [ "$status" -eq 0 ] + [[ "$output" =~ "On branch master" ]] || false + [[ "$output" =~ "nothing to commit, working tree clean" ]] || false + + run dolt sql -r csv -q "select * from test" + [[ "$output" =~ pk ]] || false + [[ "$output" != 1 ]] || false +} diff --git a/bats/sql-server.bats b/bats/sql-server.bats index bf08e51c97..05fc73b65d 100644 --- a/bats/sql-server.bats +++ b/bats/sql-server.bats @@ -307,7 +307,7 @@ SQL [[ "$output" =~ "test" ]] || false multi_query 1 " - SET @@repo1_head = dolt_reset('hard'); + SET @@repo1_head = reset('hard'); REPLACE INTO dolt_branches (name,hash) VALUES ('master', @@repo1_head);" run dolt status @@ -319,7 +319,7 @@ SQL multi_query 1 " INSERT INTO test VALUES (8,8); - SET @@repo1_head=dolt_reset('hard'); + SET @@repo1_head = reset('hard'); REPLACE INTO dolt_branches (name,hash) VALUES ('master', @@repo1_head);" run dolt status diff --git a/go/cmd/dolt/cli/arg_parser_helpers.go b/go/cmd/dolt/cli/arg_parser_helpers.go index 32eec1bd0d..d9304dbeaa 100644 --- a/go/cmd/dolt/cli/arg_parser_helpers.go +++ b/go/cmd/dolt/cli/arg_parser_helpers.go @@ -78,6 +78,8 @@ const ( AuthorParam = "author" ForceFlag = "force" AllFlag = "all" + HardResetParam = "hard" + SoftResetParam = "soft" ) // Creates the argparser shared dolt commit cli and DOLT_COMMIT. @@ -98,3 +100,10 @@ func CreateAddArgParser() *argparser.ArgParser { ap.SupportsFlag("all", "A", "Stages any and all changes (adds, deletes, and modifications).") return ap } + +func CreateResetArgParser() *argparser.ArgParser { + ap := argparser.NewArgParser() + ap.SupportsFlag(HardResetParam, "", "Resets the working tables and staged tables. Any changes to tracked tables in the working tree since {{.LessThan}}commit{{.GreaterThan}} are discarded.") + ap.SupportsFlag(SoftResetParam, "", "Does not touch the working tables, but removes all tables staged to be committed.") + return ap +} diff --git a/go/cmd/dolt/commands/checkout.go b/go/cmd/dolt/commands/checkout.go index a2607e8c91..6b8a8dc2a8 100644 --- a/go/cmd/dolt/commands/checkout.go +++ b/go/cmd/dolt/commands/checkout.go @@ -121,7 +121,7 @@ func (cmd CheckoutCmd) Exec(ctx context.Context, commandStr string, args []strin return HandleVErrAndExitCode(verr, usagePrt) } - tbls, docs, err := actions.GetTblsAndDocDetails(dEnv, args) + tbls, docs, err := actions.GetTblsAndDocDetails(dEnv.DocsReadWriter(), args) if err != nil { verr := errhand.BuildDError("error: unable to parse arguments.").AddCause(err).Build() return HandleVErrAndExitCode(verr, usagePrt) diff --git a/go/cmd/dolt/commands/diff.go b/go/cmd/dolt/commands/diff.go index 36fa8ca5a3..a300049a77 100644 --- a/go/cmd/dolt/commands/diff.go +++ b/go/cmd/dolt/commands/diff.go @@ -874,7 +874,7 @@ func createSplitter(fromSch schema.Schema, toSch schema.Schema, joiner *rowconv. } func diffDoltDocs(ctx context.Context, dEnv *env.DoltEnv, from, to *doltdb.RootValue, dArgs *diffArgs) error { - _, docDetails, err := actions.GetTblsAndDocDetails(dEnv, dArgs.docSet.AsSlice()) + _, docDetails, err := actions.GetTblsAndDocDetails(dEnv.DocsReadWriter(), dArgs.docSet.AsSlice()) if err != nil { return err diff --git a/go/cmd/dolt/commands/merge.go b/go/cmd/dolt/commands/merge.go index d379d9cc48..d5e89869eb 100644 --- a/go/cmd/dolt/commands/merge.go +++ b/go/cmd/dolt/commands/merge.go @@ -418,7 +418,7 @@ func mergedRootToWorking(ctx context.Context, squash bool, dEnv *env.DoltEnv, me if err != nil { return errhand.BuildDError("error: failed to update docs to the new working root").AddCause(err).Build() } - verr = UpdateStagedWithVErr(dEnv, mergedRoot) + verr = UpdateStagedWithVErr(dEnv.DoltDB, dEnv.RepoStateWriter(), mergedRoot) if verr != nil { // Log a new message here to indicate that merge was successful, only staging failed. cli.Println("Unable to stage changes: add and commit to finish merge") diff --git a/go/cmd/dolt/commands/reset.go b/go/cmd/dolt/commands/reset.go index 0ba2aeed75..f5aee0b18f 100644 --- a/go/cmd/dolt/commands/reset.go +++ b/go/cmd/dolt/commands/reset.go @@ -27,7 +27,6 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" - "github.com/dolthub/dolt/go/libraries/utils/argparser" ) const ( @@ -71,20 +70,13 @@ func (cmd ResetCmd) Description() string { // CreateMarkdown creates a markdown file containing the helptext for the command at the given path func (cmd ResetCmd) CreateMarkdown(fs filesys.Filesys, path, commandStr string) error { - ap := cmd.createArgParser() + ap := cli.CreateResetArgParser() return CreateMarkdown(fs, path, cli.GetCommandDocumentation(commandStr, resetDocContent, ap)) } -func (cmd ResetCmd) createArgParser() *argparser.ArgParser { - ap := argparser.NewArgParser() - ap.SupportsFlag(HardResetParam, "", "Resets the working tables and staged tables. Any changes to tracked tables in the working tree since {{.LessThan}}commit{{.GreaterThan}} are discarded.") - ap.SupportsFlag(SoftResetParam, "", "Does not touch the working tables, but removes all tables staged to be committed.") - return ap -} - // Exec executes the command func (cmd ResetCmd) Exec(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEnv) int { - ap := cmd.createArgParser() + ap := cli.CreateResetArgParser() help, usage := cli.HelpAndUsagePrinters(cli.GetCommandDocumentation(commandStr, resetDocContent, ap)) apr := cli.ParseArgs(ap, args, help) @@ -94,170 +86,25 @@ func (cmd ResetCmd) Exec(ctx context.Context, commandStr string, args []string, workingRoot, stagedRoot, headRoot, verr := getAllRoots(ctx, dEnv) + var err error if verr == nil { if apr.ContainsAll(HardResetParam, SoftResetParam) { verr = errhand.BuildDError("error: --%s and --%s are mutually exclusive options.", HardResetParam, SoftResetParam).Build() + HandleVErrAndExitCode(verr, usage) } else if apr.Contains(HardResetParam) { - verr = resetHard(ctx, dEnv, apr, workingRoot, stagedRoot, headRoot) + err = actions.ResetHard(ctx, dEnv, apr, workingRoot, stagedRoot, headRoot) } else { - verr = resetSoft(ctx, dEnv, apr, stagedRoot, headRoot) + stagedRoot, err = actions.ResetSoft(ctx, dEnv, apr, stagedRoot, headRoot) + + if err != nil { + return handleResetError(err, usage) + } + + printNotStaged(ctx, dEnv, stagedRoot) } } - return HandleVErrAndExitCode(verr, usage) -} - -func resetHard(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults, workingRoot, stagedRoot, headRoot *doltdb.RootValue) errhand.VerboseError { - if apr.NArg() > 1 { - return errhand.BuildDError("--%s supports at most one additional param", HardResetParam).SetPrintUsage().Build() - } - - var newHead *doltdb.Commit - if apr.NArg() == 1 { - cs, err := doltdb.NewCommitSpec(apr.Arg(0)) - if err != nil { - return errhand.VerboseErrorFromError(err) - } - - newHead, err = dEnv.DoltDB.Resolve(ctx, cs, dEnv.RepoState.CWBHeadRef()) - if err != nil { - return errhand.VerboseErrorFromError(err) - } - - headRoot, err = newHead.GetRootValue() - if err != nil { - return errhand.VerboseErrorFromError(err) - } - } - - // need to save the state of files that aren't tracked - untrackedTables := make(map[string]*doltdb.Table) - wTblNames, err := workingRoot.GetTableNames(ctx) - - if err != nil { - return errhand.BuildDError("error: failed to read tables from the working set").AddCause(err).Build() - } - - for _, tblName := range wTblNames { - untrackedTables[tblName], _, err = workingRoot.GetTable(ctx, tblName) - - if err != nil { - return errhand.BuildDError("error: failed to read '%s' from the working set", tblName).AddCause(err).Build() - } - } - - headTblNames, err := stagedRoot.GetTableNames(ctx) - - if err != nil { - return errhand.BuildDError("error: failed to read tables from head").AddCause(err).Build() - } - - for _, tblName := range headTblNames { - delete(untrackedTables, tblName) - } - - newWkRoot := headRoot - for tblName, tbl := range untrackedTables { - if tblName != doltdb.DocTableName { - newWkRoot, err = newWkRoot.PutTable(ctx, tblName, tbl) - } - if err != nil { - return errhand.BuildDError("error: failed to write table back to database").Build() - } - } - - // TODO: update working and staged in one repo_state write. - err = dEnv.UpdateWorkingRoot(ctx, newWkRoot) - - if err != nil { - return errhand.BuildDError("error: failed to update the working tables.").AddCause(err).Build() - } - - _, err = dEnv.UpdateStagedRoot(ctx, headRoot) - - if err != nil { - return errhand.BuildDError("error: failed to update the staged tables.").AddCause(err).Build() - } - - err = actions.SaveTrackedDocsFromWorking(ctx, dEnv) - if err != nil { - return errhand.BuildDError("error: failed to update docs on the filesystem.").AddCause(err).Build() - } - - if newHead != nil { - if err = dEnv.DoltDB.SetHeadToCommit(ctx, dEnv.RepoState.CWBHeadRef(), newHead); err != nil { - return errhand.VerboseErrorFromError(err) - } - } - - return nil -} - -// RemoveDocsTbl takes a slice of table names and returns a new slice with DocTableName removed. -func RemoveDocsTbl(tbls []string) []string { - var result []string - for _, tblName := range tbls { - if tblName != doltdb.DocTableName { - result = append(result, tblName) - } - } - return result -} - -func resetSoft(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults, stagedRoot, headRoot *doltdb.RootValue) errhand.VerboseError { - tbls := apr.Args() - - if len(tbls) == 0 || (len(tbls) == 1 && tbls[0] == ".") { - var err error - tbls, err = doltdb.UnionTableNames(ctx, stagedRoot, headRoot) - - if err != nil { - return errhand.BuildDError("error: failed to get all tables").AddCause(err).Build() - } - } - - tables, docs, err := actions.GetTblsAndDocDetails(dEnv, tbls) - if err != nil { - return errhand.BuildDError("error: failed to get all tables").AddCause(err).Build() - } - - if len(docs) > 0 { - tables = RemoveDocsTbl(tables) - } - - verr := ValidateTablesWithVErr(tables, stagedRoot, headRoot) - - if verr != nil { - return verr - } - - stagedRoot, err = resetDocs(ctx, dEnv, headRoot, docs) - if err != nil { - return errhand.BuildDError("error: failed to reset docs").AddCause(err).Build() - } - - stagedRoot, verr = resetStaged(ctx, dEnv, tables, stagedRoot, headRoot) - - if verr != nil { - return verr - } - - printNotStaged(ctx, dEnv, stagedRoot) - return nil -} - -func resetDocs(ctx context.Context, dEnv *env.DoltEnv, headRoot *doltdb.RootValue, docDetails env.Docs) (newStgRoot *doltdb.RootValue, err error) { - docs, err := dEnv.GetDocsWithNewerTextFromRoot(ctx, headRoot, docDetails) - if err != nil { - return nil, err - } - - err = dEnv.PutDocsToWorking(ctx, docs) - if err != nil { - return nil, err - } - - return dEnv.PutDocsToStaged(ctx, docs) + return handleResetError(err, usage) } func printNotStaged(ctx context.Context, dEnv *env.DoltEnv, staged *doltdb.RootValue) { @@ -314,15 +161,24 @@ func printNotStaged(ctx context.Context, dEnv *env.DoltEnv, staged *doltdb.RootV } } -func resetStaged(ctx context.Context, dEnv *env.DoltEnv, tbls []string, staged, head *doltdb.RootValue) (*doltdb.RootValue, errhand.VerboseError) { - updatedRoot, err := actions.MoveTablesBetweenRoots(ctx, tbls, head, staged) +func handleResetError(err error, usage cli.UsagePrinter) int { + if actions.IsTblNotExist(err) { + tbls := actions.GetTablesForError(err) + bdr := errhand.BuildDError("Invalid Table(s):") - if err != nil { - tt := strings.Join(tbls, ", ") - return nil, errhand.BuildDError("error: failed to unstage tables: %s", tt).AddCause(err).Build() + for _, tbl := range tbls { + bdr.AddDetails("\t" + tbl) + } + + return HandleVErrAndExitCode(bdr.Build(), usage) } - return updatedRoot, UpdateStagedWithVErr(dEnv, updatedRoot) + var verr errhand.VerboseError = nil + if err != nil { + verr = errhand.BuildDError("error: Failed to reset changes.").AddCause(err).Build() + } + + return HandleVErrAndExitCode(verr, usage) } func getAllRoots(ctx context.Context, dEnv *env.DoltEnv) (*doltdb.RootValue, *doltdb.RootValue, *doltdb.RootValue, errhand.VerboseError) { diff --git a/go/cmd/dolt/commands/schcmds/show.go b/go/cmd/dolt/commands/schcmds/show.go index b392b863af..e94c5cddb0 100644 --- a/go/cmd/dolt/commands/schcmds/show.go +++ b/go/cmd/dolt/commands/schcmds/show.go @@ -25,6 +25,7 @@ import ( eventsapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi/v1alpha1" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" + "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" dsqle "github.com/dolthub/dolt/go/libraries/doltcore/sqle" "github.com/dolthub/dolt/go/libraries/utils/argparser" "github.com/dolthub/dolt/go/libraries/utils/filesys" @@ -126,7 +127,7 @@ func printSchemas(ctx context.Context, apr *argparser.ArgParseResults, dEnv *env return errhand.BuildDError("unable to get table names.").AddCause(err).Build() } - tables = commands.RemoveDocsTbl(tables) + tables = actions.RemoveDocsTable(tables) if len(tables) == 0 { cli.Println("No tables in working set") return nil diff --git a/go/cmd/dolt/commands/schcmds/tags.go b/go/cmd/dolt/commands/schcmds/tags.go index 2253ccbe94..ccc434c502 100644 --- a/go/cmd/dolt/commands/schcmds/tags.go +++ b/go/cmd/dolt/commands/schcmds/tags.go @@ -23,6 +23,7 @@ import ( "github.com/dolthub/dolt/go/cmd/dolt/commands" "github.com/dolthub/dolt/go/cmd/dolt/errhand" "github.com/dolthub/dolt/go/libraries/doltcore/env" + "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/utils/argparser" "github.com/dolthub/dolt/go/libraries/utils/filesys" @@ -82,7 +83,7 @@ func (cmd TagsCmd) Exec(ctx context.Context, commandStr string, args []string, d return commands.HandleVErrAndExitCode(errhand.BuildDError("unable to get table names.").AddCause(err).Build(), usage) } - tables = commands.RemoveDocsTbl(tables) + tables = actions.RemoveDocsTable(tables) if len(tables) == 0 { cli.Println("No tables in working set") return 0 diff --git a/go/cmd/dolt/commands/utils.go b/go/cmd/dolt/commands/utils.go index fcfd0857ba..3659ac4dec 100644 --- a/go/cmd/dolt/commands/utils.go +++ b/go/cmd/dolt/commands/utils.go @@ -58,8 +58,8 @@ func UpdateWorkingWithVErr(dEnv *env.DoltEnv, updatedRoot *doltdb.RootValue) err return nil } -func UpdateStagedWithVErr(dEnv *env.DoltEnv, updatedRoot *doltdb.RootValue) errhand.VerboseError { - _, err := dEnv.UpdateStagedRoot(context.Background(), updatedRoot) +func UpdateStagedWithVErr(ddb *doltdb.DoltDB, rsw env.RepoStateWriter, updatedRoot *doltdb.RootValue) errhand.VerboseError { + _, err := env.UpdateStagedRoot(context.Background(), ddb, rsw, updatedRoot) switch err { case doltdb.ErrNomsIO: @@ -71,27 +71,6 @@ func UpdateStagedWithVErr(dEnv *env.DoltEnv, updatedRoot *doltdb.RootValue) errh return nil } -func ValidateTablesWithVErr(tbls []string, roots ...*doltdb.RootValue) errhand.VerboseError { - err := actions.ValidateTables(context.TODO(), tbls, roots...) - - if err != nil { - if actions.IsTblNotExist(err) { - tbls := actions.GetTablesForError(err) - bdr := errhand.BuildDError("Invalid Table(s):") - - for _, tbl := range tbls { - bdr.AddDetails("\t" + tbl) - } - - return bdr.Build() - } else { - return errhand.BuildDError("fatal: " + err.Error()).Build() - } - } - - return nil -} - func ResolveCommitWithVErr(dEnv *env.DoltEnv, cSpecStr string) (*doltdb.Commit, errhand.VerboseError) { cs, err := doltdb.NewCommitSpec(cSpecStr) diff --git a/go/libraries/doltcore/env/actions/errors.go b/go/libraries/doltcore/env/actions/errors.go index e980fea88e..007bb6125d 100644 --- a/go/libraries/doltcore/env/actions/errors.go +++ b/go/libraries/doltcore/env/actions/errors.go @@ -42,7 +42,7 @@ func NewTblInConflictError(tbls []string) TblError { } func (te TblError) Error() string { - return "error: the tables " + strings.Join(te.tables, ", ") + string(te.tblErrType) + return "error: the table(s) " + strings.Join(te.tables, ", ") + " " + string(te.tblErrType) } func getTblErrType(err error) tblErrorType { diff --git a/go/libraries/doltcore/env/actions/reset.go b/go/libraries/doltcore/env/actions/reset.go new file mode 100644 index 0000000000..cbfc26842d --- /dev/null +++ b/go/libraries/doltcore/env/actions/reset.go @@ -0,0 +1,244 @@ +// Copyright 2020 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package actions + +import ( + "context" + "errors" + + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/env" + "github.com/dolthub/dolt/go/libraries/utils/argparser" +) + +func resetHardTables(ctx context.Context, dbData env.DbData, apr *argparser.ArgParseResults, workingRoot, stagedRoot, headRoot *doltdb.RootValue) (*doltdb.Commit, error) { + if apr.NArg() > 1 { + return nil, errors.New("--hard supports at most one additional param") + } + + ddb := dbData.Ddb + rsr := dbData.Rsr + rsw := dbData.Rsw + + var newHead *doltdb.Commit + if apr.NArg() == 1 { + cs, err := doltdb.NewCommitSpec(apr.Arg(0)) + if err != nil { + return nil, err + } + + newHead, err = ddb.Resolve(ctx, cs, rsr.CWBHeadRef()) + if err != nil { + return nil, err + } + + headRoot, err = newHead.GetRootValue() + if err != nil { + return nil, err + } + } + + // need to save the state of files that aren't tracked + untrackedTables := make(map[string]*doltdb.Table) + wTblNames, err := workingRoot.GetTableNames(ctx) + + if err != nil { + return nil, err + } + + for _, tblName := range wTblNames { + untrackedTables[tblName], _, err = workingRoot.GetTable(ctx, tblName) + + if err != nil { + return nil, err + } + } + + headTblNames, err := stagedRoot.GetTableNames(ctx) + + if err != nil { + return nil, err + } + + for _, tblName := range headTblNames { + delete(untrackedTables, tblName) + } + + newWkRoot := headRoot + for tblName, tbl := range untrackedTables { + if tblName != doltdb.DocTableName { + newWkRoot, err = newWkRoot.PutTable(ctx, tblName, tbl) + } + if err != nil { + return nil, errors.New("error: failed to write table back to database") + } + } + + _, err = env.UpdateWorkingRoot(ctx, ddb, rsw, newWkRoot) + + if err != nil { + return nil, err + } + + _, err = env.UpdateStagedRoot(ctx, ddb, rsw, headRoot) + + if err != nil { + return nil, err + } + + return newHead, nil +} + +func ResetHardTables(ctx context.Context, dbData env.DbData, apr *argparser.ArgParseResults, workingRoot, stagedRoot, headRoot *doltdb.RootValue) error { + newHead, err := resetHardTables(ctx, dbData, apr, workingRoot, stagedRoot, headRoot) + + if err != nil { + return err + } + + ddb := dbData.Ddb + rsr := dbData.Rsr + + if newHead != nil { + if err := ddb.SetHeadToCommit(ctx, rsr.CWBHeadRef(), newHead); err != nil { + return err + } + } + + return nil +} + +func ResetHard(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults, workingRoot, stagedRoot, headRoot *doltdb.RootValue) error { + dbData := dEnv.DbData() + + newHead, err := resetHardTables(ctx, dbData, apr, workingRoot, stagedRoot, headRoot) + + if err != nil { + return err + } + + err = SaveTrackedDocsFromWorking(ctx, dEnv) + if err != nil { + return err + } + + ddb := dbData.Ddb + rsr := dbData.Rsr + + if newHead != nil { + if err = ddb.SetHeadToCommit(ctx, rsr.CWBHeadRef(), newHead); err != nil { + return err + } + } + + return nil +} + +func ResetSoftTables(ctx context.Context, dbData env.DbData, apr *argparser.ArgParseResults, stagedRoot, headRoot *doltdb.RootValue) (*doltdb.RootValue, error) { + tables, err := getUnionedTables(ctx, apr.Args(), stagedRoot, headRoot) + tables = RemoveDocsTable(tables) + + if err != nil { + return nil, err + } + + err = ValidateTables(context.TODO(), tables, stagedRoot, headRoot) + + if err != nil { + return nil, err + } + + stagedRoot, err = resetStaged(ctx, dbData.Ddb, dbData.Rsw, tables, stagedRoot, headRoot) + + if err != nil { + return nil, err + } + + return stagedRoot, nil +} + +func ResetSoft(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults, stagedRoot, headRoot *doltdb.RootValue) (*doltdb.RootValue, error) { + tables, err := getUnionedTables(ctx, apr.Args(), stagedRoot, headRoot) + + if err != nil { + return nil, err + } + + dbData := dEnv.DbData() + tables, docs, err := GetTblsAndDocDetails(dbData.Drw, tables) + if err != nil { + return nil, err + } + + if len(docs) > 0 { + tables = RemoveDocsTable(tables) + } + + err = ValidateTables(context.TODO(), tables, stagedRoot, headRoot) + + if err != nil { + return nil, err + } + + stagedRoot, err = resetDocs(ctx, dEnv, headRoot, docs) + if err != nil { + return nil, err + } + + stagedRoot, err = resetStaged(ctx, dbData.Ddb, dbData.Rsw, tables, stagedRoot, headRoot) + + if err != nil { + return nil, err + } + + return stagedRoot, nil +} + +func getUnionedTables(ctx context.Context, tables []string, stagedRoot, headRoot *doltdb.RootValue) ([]string, error) { + if len(tables) == 0 || (len(tables) == 1 && tables[0] == ".") { + var err error + tables, err = doltdb.UnionTableNames(ctx, stagedRoot, headRoot) + + if err != nil { + return nil, err + } + } + + return tables, nil +} + +func resetDocs(ctx context.Context, dEnv *env.DoltEnv, headRoot *doltdb.RootValue, docDetails env.Docs) (newStgRoot *doltdb.RootValue, err error) { + docs, err := dEnv.GetDocsWithNewerTextFromRoot(ctx, headRoot, docDetails) + if err != nil { + return nil, err + } + + err = dEnv.PutDocsToWorking(ctx, docs) + if err != nil { + return nil, err + } + + return dEnv.PutDocsToStaged(ctx, docs) +} + +func resetStaged(ctx context.Context, ddb *doltdb.DoltDB, rsw env.RepoStateWriter, tbls []string, staged, head *doltdb.RootValue) (*doltdb.RootValue, error) { + updatedRoot, err := MoveTablesBetweenRoots(ctx, tbls, head, staged) + + if err != nil { + return nil, err + } + + return updatedRoot, env.UpdateStagedRootWithVErr(ddb, rsw, updatedRoot) +} diff --git a/go/libraries/doltcore/env/actions/table.go b/go/libraries/doltcore/env/actions/table.go index 9c0e3915bc..19ec72aee3 100644 --- a/go/libraries/doltcore/env/actions/table.go +++ b/go/libraries/doltcore/env/actions/table.go @@ -211,3 +211,14 @@ func validateTablesExist(ctx context.Context, currRoot *doltdb.RootValue, unknow return nil } + +// RemoveDocsTable takes a slice of table names and returns a new slice with DocTableName removed. +func RemoveDocsTable(tbls []string) []string { + var result []string + for _, tblName := range tbls { + if tblName != doltdb.DocTableName { + result = append(result, tblName) + } + } + return result +} diff --git a/go/libraries/doltcore/env/environment.go b/go/libraries/doltcore/env/environment.go index 1d0ebba6cc..35d84f2dfc 100644 --- a/go/libraries/doltcore/env/environment.go +++ b/go/libraries/doltcore/env/environment.go @@ -457,6 +457,10 @@ func (d *docsReadWriter) PutDocsToWorking(ctx context.Context, docDetails []dolt return d.dEnv.PutDocsToWorking(ctx, docDetails) } +func (d *docsReadWriter) PutDocsToStaged(ctx context.Context, docDetails []doltdb.DocDetails) (*doltdb.RootValue, error) { + return d.dEnv.PutDocsToStaged(ctx, docDetails) +} + func (d *docsReadWriter) ResetWorkingDocsToStagedDocs(ctx context.Context) error { return d.dEnv.ResetWorkingDocsToStagedDocs(ctx) } diff --git a/go/libraries/doltcore/env/repo_state.go b/go/libraries/doltcore/env/repo_state.go index a1b47d11f3..eb73d2cc37 100644 --- a/go/libraries/doltcore/env/repo_state.go +++ b/go/libraries/doltcore/env/repo_state.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" + "github.com/dolthub/dolt/go/cmd/dolt/errhand" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/ref" "github.com/dolthub/dolt/go/libraries/utils/filesys" @@ -44,6 +45,7 @@ type RepoStateWriter interface { type DocsReadWriter interface { GetAllValidDocDetails() ([]doltdb.DocDetails, error) PutDocsToWorking(ctx context.Context, docDetails []doltdb.DocDetails) error + PutDocsToStaged(ctx context.Context, docDetails []doltdb.DocDetails) (*doltdb.RootValue, error) ResetWorkingDocsToStagedDocs(ctx context.Context) error GetDocDetail(docName string) (doc doltdb.DocDetails, err error) } @@ -249,3 +251,38 @@ func UpdateStagedRoot(ctx context.Context, ddb *doltdb.DoltDB, rsw RepoStateWrit return h, nil } + +func UpdateStagedRootWithVErr(ddb *doltdb.DoltDB, rsw RepoStateWriter, updatedRoot *doltdb.RootValue) errhand.VerboseError { + _, err := UpdateStagedRoot(context.Background(), ddb, rsw, updatedRoot) + + switch err { + case doltdb.ErrNomsIO: + return errhand.BuildDError("fatal: failed to write value").Build() + case ErrStateUpdate: + return errhand.BuildDError("fatal: failed to update the staged root state").Build() + } + + return nil +} + +func GetRoots(ctx context.Context, ddb *doltdb.DoltDB, rsr RepoStateReader) (working *doltdb.RootValue, staged *doltdb.RootValue, head *doltdb.RootValue, err error) { + working, err = WorkingRoot(ctx, ddb, rsr) + + if err != nil { + return nil, nil, nil, err + } + + staged, err = StagedRoot(ctx, ddb, rsr) + + if err != nil { + return nil, nil, nil, err + } + + head, err = HeadRoot(ctx, ddb, rsr) + + if err != nil { + return nil, nil, nil, err + } + + return working, staged, head, nil +} diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_add.go b/go/libraries/doltcore/sqle/dfunctions/dolt_add.go index 78aa0eeca3..319c1f192a 100644 --- a/go/libraries/doltcore/sqle/dfunctions/dolt_add.go +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_add.go @@ -94,6 +94,11 @@ func (d DoltAddFunc) Type() sql.Type { } func (d DoltAddFunc) IsNullable() bool { + for _, child := range d.Children() { + if child.IsNullable() { + return true + } + } return false } diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_reset.go b/go/libraries/doltcore/sqle/dfunctions/dolt_reset.go new file mode 100644 index 0000000000..b838f58db0 --- /dev/null +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_reset.go @@ -0,0 +1,125 @@ +// Copyright 2020 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dfunctions + +import ( + "fmt" + "strings" + + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/dolt/go/cmd/dolt/cli" + "github.com/dolthub/dolt/go/libraries/doltcore/env" + "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle" +) + +const DoltResetFuncName = "dolt_reset" + +type DoltResetFunc struct { + children []sql.Expression +} + +func (d DoltResetFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + dbName := ctx.GetCurrentDatabase() + + if len(dbName) == 0 { + return 1, fmt.Errorf("Empty database name.") + } + + dSess := sqle.DSessFromSess(ctx.Session) + dbData, ok := dSess.GetDbData(dbName) + + if !ok { + return 1, fmt.Errorf("Could not load database %s", dbName) + } + + ap := cli.CreateResetArgParser() + args, err := getDoltArgs(ctx, row, d.Children()) + + if err != nil { + return 1, err + } + + apr := cli.ParseArgs(ap, args, nil) + + // Check if problems with args first. + if apr.ContainsAll(cli.HardResetParam, cli.SoftResetParam) { + return 1, fmt.Errorf("error: --%s and --%s are mutually exclusive options.", cli.HardResetParam, cli.SoftResetParam) + } + + // Get all the needed roots. + working, staged, head, err := env.GetRoots(ctx, dbData.Ddb, dbData.Rsr) + + if err != nil { + return 1, err + } + + if apr.Contains(cli.HardResetParam) { + err = actions.ResetHardTables(ctx, dbData, apr, working, staged, head) + } else { + _, err = actions.ResetSoftTables(ctx, dbData, apr, staged, head) + } + + if err != nil { + return 1, err + } + + return 0, nil +} + +func (d DoltResetFunc) Resolved() bool { + for _, child := range d.Children() { + if !child.Resolved() { + return false + } + } + return true +} + +func (d DoltResetFunc) String() string { + childrenStrings := make([]string, len(d.children)) + + for i, child := range d.children { + childrenStrings[i] = child.String() + } + + return fmt.Sprintf("DOLT_RESET(%s)", strings.Join(childrenStrings, ",")) +} + +func (d DoltResetFunc) Type() sql.Type { + return sql.Int8 +} + +func (d DoltResetFunc) IsNullable() bool { + for _, child := range d.Children() { + if child.IsNullable() { + return true + } + } + return false +} + +func (d DoltResetFunc) Children() []sql.Expression { + return d.children +} + +func (d DoltResetFunc) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewDoltResetFunc(children...) +} + +func NewDoltResetFunc(args ...sql.Expression) (sql.Expression, error) { + return DoltResetFunc{children: args}, nil +} diff --git a/go/libraries/doltcore/sqle/dfunctions/init.go b/go/libraries/doltcore/sqle/dfunctions/init.go index 9bc4427cc2..d5ddaf821c 100644 --- a/go/libraries/doltcore/sqle/dfunctions/init.go +++ b/go/libraries/doltcore/sqle/dfunctions/init.go @@ -20,10 +20,11 @@ var DoltFunctions = []sql.Function{ sql.Function1{Name: HashOfFuncName, Fn: NewHashOf}, sql.FunctionN{Name: CommitFuncName, Fn: NewCommitFunc}, sql.FunctionN{Name: MergeFuncName, Fn: NewMergeFunc}, - sql.Function1{Name: resetFuncName, Fn: NewDoltResetFunc}, + sql.Function1{Name: resetFuncName, Fn: NewResetFunc}, sql.Function0{Name: VersionFuncName, Fn: NewVersion}, sql.FunctionN{Name: DoltCommitFuncName, Fn: NewDoltCommitFunc}, sql.FunctionN{Name: DoltAddFuncName, Fn: NewDoltAddFunc}, + sql.FunctionN{Name: DoltResetFuncName, Fn: NewDoltResetFunc}, } // These are the DoltFunctions that get exposed to Dolthub Api. diff --git a/go/libraries/doltcore/sqle/dfunctions/reset.go b/go/libraries/doltcore/sqle/dfunctions/reset.go index 89c178f02e..f0a1bdebfb 100644 --- a/go/libraries/doltcore/sqle/dfunctions/reset.go +++ b/go/libraries/doltcore/sqle/dfunctions/reset.go @@ -26,22 +26,22 @@ import ( ) const ( - resetFuncName = "dolt_reset" + resetFuncName = "reset" resetHardParameter = "hard" ) -type DoltResetFunc struct { +type ResetFunc struct { expression.UnaryExpression } -// NewDoltResetFunc creates a new DoltResetFunc expression. -func NewDoltResetFunc(e sql.Expression) sql.Expression { - return DoltResetFunc{expression.UnaryExpression{Child: e}} +// NewDoltResetFunc creates a new ResetFunc expression. +func NewResetFunc(e sql.Expression) sql.Expression { + return ResetFunc{expression.UnaryExpression{Child: e}} } // Eval implements the Expression interface. -func (rf DoltResetFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { +func (rf ResetFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { val, err := rf.Child.Eval(ctx, row) if err != nil { return nil, err @@ -55,7 +55,7 @@ func (rf DoltResetFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) dSess := sqle.DSessFromSess(ctx.Session) var h hash.Hash - if strings.ToLower(arg) != "hard" { + if strings.ToLower(arg) != resetHardParameter { return nil, fmt.Errorf("invalid arugument to %s(): %s", resetFuncName, arg) } @@ -73,34 +73,34 @@ func (rf DoltResetFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) } // Resolved implements the Expression interface. -func (rf DoltResetFunc) Resolved() bool { +func (rf ResetFunc) Resolved() bool { return rf.Child.Resolved() } // String implements the Stringer interface. -func (rf DoltResetFunc) String() string { +func (rf ResetFunc) String() string { return fmt.Sprintf("RESET_HARD(%s)", rf.Child.String()) } // IsNullable implements the Expression interface. -func (rf DoltResetFunc) IsNullable() bool { +func (rf ResetFunc) IsNullable() bool { return false } // Children implements the Expression interface. -func (rf DoltResetFunc) Children() []sql.Expression { +func (rf ResetFunc) Children() []sql.Expression { return []sql.Expression{rf.Child} } // WithChildren implements the Expression interface. -func (rf DoltResetFunc) WithChildren(children ...sql.Expression) (sql.Expression, error) { +func (rf ResetFunc) WithChildren(children ...sql.Expression) (sql.Expression, error) { if len(children) != 1 { return nil, sql.ErrInvalidChildrenNumber.New(rf, len(children), 1) } - return NewDoltResetFunc(children[0]), nil + return NewResetFunc(children[0]), nil } // Type implements the Expression interface. -func (rf DoltResetFunc) Type() sql.Type { +func (rf ResetFunc) Type() sql.Type { return sql.Text }