diff --git a/go/cmd/dolt/commands/reflog.go b/go/cmd/dolt/commands/reflog.go index 0653aba3fd..4c07cfe773 100644 --- a/go/cmd/dolt/commands/reflog.go +++ b/go/cmd/dolt/commands/reflog.go @@ -136,12 +136,12 @@ type ReflogInfo struct { func printReflog(rows []sql.Row, queryist cli.Queryist, sqlCtx *sql.Context) int { var reflogInfo []ReflogInfo - // Get the current branch - curBranch := "" - res, err := GetRowsForSql(queryist, sqlCtx, "SELECT active_branch()") + // Get the hash of HEAD for the `HEAD ->` decoration + headHash := "" + res, err := GetRowsForSql(queryist, sqlCtx, "SELECT hashof('HEAD')") if err == nil { - // still print the reflog even if we can't get the current branch - curBranch = res[0][0].(string) + // still print the reflog even if we can't get the hash + headHash = res[0][0].(string) } for _, row := range rows { @@ -151,13 +151,13 @@ func printReflog(rows []sql.Row, queryist cli.Queryist, sqlCtx *sql.Context) int reflogInfo = append(reflogInfo, ReflogInfo{ref, commitHash, commitMessage}) } - reflogToStdOut(reflogInfo, curBranch) + reflogToStdOut(reflogInfo, headHash) return 0 } // reflogToStdOut takes a list of ReflogInfo and prints the reflog to stdout -func reflogToStdOut(reflogInfo []ReflogInfo, curBranch string) { +func reflogToStdOut(reflogInfo []ReflogInfo, headHash string) { if cli.ExecuteWithStdioRestored == nil { return } @@ -169,8 +169,12 @@ func reflogToStdOut(reflogInfo []ReflogInfo, curBranch string) { // TODO: use short hash instead line := []string{fmt.Sprintf("\033[33m%s\033[0m", info.commitHash)} // commit hash in yellow (33m) - processedRef := processRefForReflog(info.ref, curBranch) - line = append(line, fmt.Sprintf("\033[33m(%s\033[33m)\033[0m", processedRef)) // () in yellow (33m) + processedRef := processRefForReflog(info.ref) + if headHash != "" && headHash == info.commitHash { + line = append(line, fmt.Sprintf("\033[33m(\033[36;1mHEAD -> %s\033[33m)\033[0m", processedRef)) // HEAD in cyan (36;1) + } else { + line = append(line, fmt.Sprintf("\033[33m(%s\033[33m)\033[0m", processedRef)) // () in yellow (33m) + } line = append(line, fmt.Sprintf("%s\n", info.commitMessage)) pager.Writer.Write([]byte(strings.Join(line, " "))) } @@ -178,13 +182,9 @@ func reflogToStdOut(reflogInfo []ReflogInfo, curBranch string) { } // processRefForReflog takes a full ref (e.g. refs/heads/master) or tag name and returns the ref name (e.g. master) with relevant decoration. -func processRefForReflog(fullRef string, curBranch string) string { +func processRefForReflog(fullRef string) string { if strings.HasPrefix(fullRef, "refs/heads/") { - branch := strings.TrimPrefix(fullRef, "refs/heads/") - if curBranch != "" && branch == curBranch { - return fmt.Sprintf("\033[36;1mHEAD -> \033[32;1m%s\033[0m", branch) // HEAD in cyan (36;1), branch in green (32;1m) - } - return fmt.Sprintf("\033[32;1m%s\033[0m", branch) // branch in green (32;1m) + return fmt.Sprintf("\033[32;1m%s\033[0m", strings.TrimPrefix(fullRef, "refs/heads/")) // branch in green (32;1m) } else if strings.HasPrefix(fullRef, "refs/tags/") { return fmt.Sprintf("\033[33mtag: %s\033[0m", strings.TrimPrefix(fullRef, "refs/tags/")) // tag in yellow (33m) } else if strings.HasPrefix(fullRef, "refs/remotes/") { diff --git a/go/libraries/doltcore/sqle/reflog_table_function.go b/go/libraries/doltcore/sqle/reflog_table_function.go index f0124a84bf..5e9ef5f37d 100644 --- a/go/libraries/doltcore/sqle/reflog_table_function.go +++ b/go/libraries/doltcore/sqle/reflog_table_function.go @@ -29,10 +29,9 @@ import ( ) type ReflogTableFunction struct { - ctx *sql.Context - database sql.Database - refExpr sql.Expression - showAll bool + ctx *sql.Context + database sql.Database + refAndArgExprs []sql.Expression } var _ sql.TableFunction = (*ReflogTableFunction)(nil) @@ -66,17 +65,30 @@ func (rltf *ReflogTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.Row } var refName string - if rltf.refExpr != nil { - target, err := rltf.refExpr.Eval(ctx, row) + showAll := false + for _, expr := range rltf.refAndArgExprs { + target, err := expr.Eval(ctx, row) if err != nil { return nil, fmt.Errorf("error evaluating expression (%s): %s", - rltf.refExpr.String(), err.Error()) + expr.String(), err.Error()) } - - refName, ok = target.(string) + targetStr, ok := target.(string) if !ok { return nil, fmt.Errorf("argument (%v) is not a string value, but a %T", target, target) } + + if targetStr == "--all" { + if showAll { + return nil, fmt.Errorf("error: multiple values provided for `all`") + } + showAll = true + } else { + if refName != "" { + return nil, fmt.Errorf("error: %s has too many positional arguments. Expected at most %d, found %d: %s", + rltf.Name(), 1, 2, rltf.refAndArgExprs) + } + refName = targetStr + } } ddb := sqlDb.DbData().Ddb @@ -112,7 +124,7 @@ func (rltf *ReflogTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.Row } // skip workspace refs by default if doltRef.GetType() == ref.WorkspaceRefType { - if !rltf.showAll { + if !showAll { return nil } } @@ -179,19 +191,17 @@ func (rltf *ReflogTableFunction) Schema() sql.Schema { } func (rltf *ReflogTableFunction) Resolved() bool { - if rltf.refExpr != nil { - return rltf.refExpr.Resolved() + for _, expr := range rltf.refAndArgExprs { + return expr.Resolved() } return true } func (rltf *ReflogTableFunction) String() string { var args []string - if rltf.showAll { - args = append(args, "'--all'") - } - if rltf.refExpr != nil { - args = append(args, rltf.refExpr.String()) + + for _, expr := range rltf.refAndArgExprs { + args = append(args, expr.String()) } return fmt.Sprintf("DOLT_REFLOG(%s)", strings.Join(args, ", ")) } @@ -218,10 +228,7 @@ func (rltf *ReflogTableFunction) IsReadOnly() bool { } func (rltf *ReflogTableFunction) Expressions() []sql.Expression { - if rltf.refExpr != nil { - return []sql.Expression{rltf.refExpr} - } - return []sql.Expression{} + return rltf.refAndArgExprs } func (rltf *ReflogTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) { @@ -230,22 +237,7 @@ func (rltf *ReflogTableFunction) WithExpressions(expression ...sql.Expression) ( } new := *rltf - - if len(expression) == 2 { - if expression[0].String() == "'--all'" && expression[1].String() == "'--all'" { - return nil, fmt.Errorf("error: multiple values provided for `all`") - } - if expression[0].String() != "'--all'" && expression[1].String() != "'--all'" { - return nil, fmt.Errorf("error: %s has too many positional arguments. Expected at most %d, found %d: %s", rltf.Name(), 1, 2, expression) - } - } - for _, expr := range expression { - if expr.String() != "'--all'" { - new.refExpr = expr - } else { - new.showAll = true - } - } + new.refAndArgExprs = expression return &new, nil } diff --git a/integration-tests/bats/reflog.bats b/integration-tests/bats/reflog.bats index 93d54fda02..942f983466 100755 --- a/integration-tests/bats/reflog.bats +++ b/integration-tests/bats/reflog.bats @@ -241,3 +241,20 @@ SQL [ "$status" -eq 0 ] [ "${#lines[@]}" -eq 0 ] } + +@test "reflog: 'HEAD -> ' decoration only appears on HEAD entries" { + setup_common + + dolt sql -q "create table t (i int primary key, j int);" + dolt sql -q "insert into t values (1, 1), (2, 2), (3, 3)"; + dolt commit -Am "initial commit" + + run dolt reflog + [ "$status" -eq 0 ] + [ "${#lines[@]}" -eq 2 ] + line1=$(echo "${lines[0]}" | sed -E 's/\x1b\[[0-9;]*m//g') # remove special characters for color + line2=$(echo "${lines[1]}" | sed -E 's/\x1b\[[0-9;]*m//g') # remove special characters for color + [[ "$line1" =~ "(HEAD -> main) initial commit" ]] || false + [[ "$line2" =~ "Initialize data repository" ]] || false + [[ ! "$line2" =~ "HEAD" ]] || false +}