diff --git a/go/libraries/doltcore/env/actions/reset.go b/go/libraries/doltcore/env/actions/reset.go index 4d43702376..bec0a40027 100644 --- a/go/libraries/doltcore/env/actions/reset.go +++ b/go/libraries/doltcore/env/actions/reset.go @@ -25,7 +25,6 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/ref" "github.com/dolthub/dolt/go/libraries/doltcore/schema" - "github.com/dolthub/dolt/go/libraries/utils/argparser" "github.com/dolthub/dolt/go/store/datas" ) @@ -201,13 +200,13 @@ func ResetHard( return nil } -func ResetSoftTables(ctx context.Context, apr *argparser.ArgParseResults, roots doltdb.Roots) (doltdb.Roots, error) { - tables, err := getUnionedTables(ctx, tableNamesFromArgs(apr.Args), roots.Staged, roots.Head) +func ResetSoftTables(ctx context.Context, tableNames []doltdb.TableName, roots doltdb.Roots) (doltdb.Roots, error) { + tables, err := getUnionedTables(ctx, tableNames, roots.Staged, roots.Head) if err != nil { return doltdb.Roots{}, err } - err = ValidateTables(context.TODO(), tables, roots.Staged, roots.Head) + err = ValidateTables(ctx, tables, roots.Staged, roots.Head) if err != nil { return doltdb.Roots{}, err } @@ -220,14 +219,6 @@ func ResetSoftTables(ctx context.Context, apr *argparser.ArgParseResults, roots return roots, nil } -func tableNamesFromArgs(args []string) []doltdb.TableName { - tbls := make([]doltdb.TableName, len(args)) - for i, arg := range args { - tbls[i] = doltdb.TableName{Name: arg} - } - return tbls -} - // ResetSoftToRef matches the `git reset --soft ` pattern. It returns a new Roots with the Staged and Head values // set to the commit specified by the spec string. The Working root is not set func ResetSoftToRef(ctx context.Context, dbData env.DbData, cSpecStr string) (doltdb.Roots, error) { @@ -266,7 +257,7 @@ func ResetSoftToRef(ctx context.Context, dbData env.DbData, cSpecStr string) (do } func getUnionedTables(ctx context.Context, tables []doltdb.TableName, stagedRoot, headRoot doltdb.RootValue) ([]doltdb.TableName, error) { - if len(tables) == 0 || (len(tables) == 1 && tables[0].Name == ".") { + if len(tables) == 0 { var err error tables, err = doltdb.UnionTableNames(ctx, stagedRoot, headRoot) diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_reset.go b/go/libraries/doltcore/sqle/dprocedures/dolt_reset.go index 0c53b7a548..42c61b48e0 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_reset.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_reset.go @@ -25,6 +25,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve" "github.com/dolthub/dolt/go/libraries/utils/argparser" ) @@ -93,14 +94,19 @@ func doDoltReset(ctx *sql.Context, args []string) (int, error) { } } else { if apr.NArg() != 1 || (apr.NArg() == 1 && apr.Arg(0) == ".") { - err := resetSoftTables(ctx, apr, roots, dSess, dbName) + err := resetSoftTables(ctx, nil, roots, dSess, dbName) if err != nil { return 1, err } } else { // check if the input is a table name or commit ref - if isTableInRoots(ctx, roots, apr.Arg(0)) { - err := resetSoftTables(ctx, apr, roots, dSess, dbName) + tblName, inRoots, err := isTableInRoots(ctx, roots, apr.Arg(0)) + if err != nil { + return 0, err + } + + if inRoots { + err := resetSoftTables(ctx, []doltdb.TableName{tblName}, roots, dSess, dbName) if err != nil { return 1, err } @@ -144,23 +150,44 @@ func resetSoftToRef( } // isTableInRoots returns true if the table given exists in any of the roots given -func isTableInRoots(ctx *sql.Context, roots doltdb.Roots, tableName string) bool { - _, tableNameInHead, _ := roots.Head.ResolveTableName(ctx, doltdb.TableName{Name: tableName}) - _, tableNameInStaged, _ := roots.Staged.ResolveTableName(ctx, doltdb.TableName{Name: tableName}) - _, tableNameInWorking, _ := roots.Working.ResolveTableName(ctx, doltdb.TableName{Name: tableName}) - isTableName := tableNameInHead || tableNameInStaged || tableNameInWorking - return isTableName +func isTableInRoots(ctx *sql.Context, roots doltdb.Roots, tableName string) (doltdb.TableName, bool, error) { + resolvedName, _, tableNameInHead, err := resolve.Table(ctx, roots.Head, tableName) + if err != nil { + return resolvedName, false, err + } + if tableNameInHead { + return resolvedName, true, nil + } + + resolvedName, _, tableNameInStaged, err := resolve.Table(ctx, roots.Staged, tableName) + if err != nil { + return resolvedName, false, err + } + if tableNameInStaged { + return resolvedName, true, nil + } + + resolvedName, _, tableNameInWorking, err := resolve.Table(ctx, roots.Working, tableName) + if err != nil { + return resolvedName, false, err + } + if tableNameInWorking { + return resolvedName, true, nil + } + + return doltdb.TableName{}, false, nil } -// resetSoftTables replaces staged tables named from HEAD +// resetSoftTables replaces staged tables named from HEAD. A nil table name slice resets all table names from +// HEAD and STAGED func resetSoftTables( ctx *sql.Context, - apr *argparser.ArgParseResults, + tableNames []doltdb.TableName, roots doltdb.Roots, dSess *dsess.DoltSession, dbName string, ) error { - roots, err := actions.ResetSoftTables(ctx, apr, roots) + roots, err := actions.ResetSoftTables(ctx, tableNames, roots) if err != nil { return err } diff --git a/go/libraries/doltcore/sqle/dtables/constraint_violations_prolly.go b/go/libraries/doltcore/sqle/dtables/constraint_violations_prolly.go index bb93d8d8a1..cc3745a018 100644 --- a/go/libraries/doltcore/sqle/dtables/constraint_violations_prolly.go +++ b/go/libraries/doltcore/sqle/dtables/constraint_violations_prolly.go @@ -24,6 +24,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/merge" "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/index" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" "github.com/dolthub/dolt/go/store/hash" "github.com/dolthub/dolt/go/store/pool" @@ -33,7 +34,7 @@ import ( ) func newProllyCVTable(ctx *sql.Context, tblName string, root doltdb.RootValue, rs RootSetter) (sql.Table, error) { - tbl, tblName, ok, err := doltdb.GetTableInsensitive(ctx, root, doltdb.TableName{Name: tblName}) + resolvedName, tbl, ok, err := resolve.Table(ctx, root, tblName) if err != nil { return nil, err } else if !ok { @@ -43,7 +44,7 @@ func newProllyCVTable(ctx *sql.Context, tblName string, root doltdb.RootValue, r if err != nil { return nil, err } - sqlSch, err := sqlutil.FromDoltSchema("", doltdb.DoltConstViolTablePrefix+tblName, cvSch) + sqlSch, err := sqlutil.FromDoltSchema("", doltdb.DoltConstViolTablePrefix+resolvedName.Name, cvSch) if err != nil { return nil, err } @@ -54,7 +55,7 @@ func newProllyCVTable(ctx *sql.Context, tblName string, root doltdb.RootValue, r } m := durable.ProllyMapFromArtifactIndex(arts) return &prollyConstraintViolationsTable{ - tblName: tblName, + tblName: resolvedName.Name, root: root, sqlSch: sqlSch, tbl: tbl,