mirror of
https://github.com/dolthub/dolt.git
synced 2026-04-22 02:50:04 -05:00
Merge pull request #9075 from dolthub/macneale4/amend
Don't alter branch head outside of transactions when amending a commit
This commit is contained in:
@@ -303,12 +303,14 @@ type PendingCommit struct {
|
||||
// commit, once written.
|
||||
// |headRef| is the ref of the HEAD the commit will update
|
||||
// |mergeParentCommits| are any merge parents for this commit
|
||||
// |amend| is a flag which indicates that additional parents should not be added to the provided |mergeParentCommits|.
|
||||
// |cm| is the metadata for the commit
|
||||
// The current branch head will be automatically filled in as the first parent at commit time.
|
||||
func (ddb *DoltDB) NewPendingCommit(
|
||||
ctx context.Context,
|
||||
roots Roots,
|
||||
mergeParentCommits []*Commit,
|
||||
amend bool,
|
||||
cm *datas.CommitMeta,
|
||||
) (*PendingCommit, error) {
|
||||
newstaged, val, err := ddb.writeRootValue(ctx, roots.Staged)
|
||||
@@ -322,7 +324,7 @@ func (ddb *DoltDB) NewPendingCommit(
|
||||
parents = append(parents, pc.dCommit.Addr())
|
||||
}
|
||||
|
||||
commitOpts := datas.CommitOptions{Parents: parents, Meta: cm}
|
||||
commitOpts := datas.CommitOptions{Parents: parents, Meta: cm, Amend: amend}
|
||||
return &PendingCommit{
|
||||
Roots: roots,
|
||||
Val: val,
|
||||
|
||||
+1
-3
@@ -98,9 +98,7 @@ func GetCommitStaged(
|
||||
return nil, NewTblSchemaConflictError(schConflicts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !props.Force {
|
||||
roots.Staged, err = doltdb.ValidateForeignKeysOnSchemas(ctx, roots.Staged)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -112,5 +110,5 @@ func GetCommitStaged(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.NewPendingCommit(ctx, roots, mergeParents, meta)
|
||||
return db.NewPendingCommit(ctx, roots, mergeParents, props.Amend, meta)
|
||||
}
|
||||
|
||||
@@ -291,7 +291,7 @@ func commitRoot(
|
||||
return err
|
||||
}
|
||||
|
||||
pcm, err := ddb.NewPendingCommit(ctx, roots, parents, meta)
|
||||
pcm, err := ddb.NewPendingCommit(ctx, roots, parents, false, meta)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -728,7 +728,6 @@ func (d *DoltSession) NewPendingCommit(
|
||||
// See NewPendingCommit
|
||||
func (d *DoltSession) newPendingCommit(ctx *sql.Context, branchState *branchState, roots doltdb.Roots, props actions.CommitStagedProps) (*doltdb.PendingCommit, error) {
|
||||
headCommit := branchState.headCommit
|
||||
headHash, _ := headCommit.HashOf()
|
||||
|
||||
if branchState.WorkingSet() == nil {
|
||||
return nil, doltdb.ErrOperationNotSupportedInDetachedHead
|
||||
@@ -755,52 +754,18 @@ func (d *DoltSession) newPendingCommit(ctx *sql.Context, branchState *branchStat
|
||||
// If the commit message isn't set and we're amending the previous commit,
|
||||
// go ahead and set the commit message from the current HEAD
|
||||
if props.Message == "" && props.Amend {
|
||||
cs, err := doltdb.NewCommitSpec("HEAD")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
headRef, err := branchState.dbData.Rsr.CWBHeadRef()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
optCmt, err := branchState.dbData.Ddb.Resolve(ctx, cs, headRef)
|
||||
commit, ok := optCmt.ToCommit()
|
||||
if !ok {
|
||||
return nil, doltdb.ErrGhostCommitEncountered
|
||||
}
|
||||
|
||||
meta, err := commit.GetCommitMeta(ctx)
|
||||
meta, err := headCommit.GetCommitMeta(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
props.Message = meta.Description
|
||||
}
|
||||
|
||||
// TODO: This is not the correct way to write this commit as an amend. While this commit is running
|
||||
// the branch head moves backwards and concurrency control here is not principled.
|
||||
newRoots, err := actions.ResetSoftToRef(ctx, branchState.dbData, "HEAD~1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = d.SetWorkingSet(ctx, ctx.GetCurrentDatabase(), branchState.WorkingSet().WithStagedRoot(newRoots.Staged))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roots.Head = newRoots.Head
|
||||
}
|
||||
|
||||
pendingCommit, err := actions.GetCommitStaged(ctx, roots, branchState.WorkingSet(), mergeParentCommits, branchState.dbData.Ddb, props)
|
||||
if err != nil {
|
||||
if props.Amend {
|
||||
_, err = actions.ResetSoftToRef(ctx, branchState.dbData, headHash.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if _, ok := err.(actions.NothingStaged); err != nil && !ok {
|
||||
// Special case for nothing staged, which is not an error
|
||||
if _, ok := err.(actions.NothingStaged); !ok {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,5 +32,10 @@ type CommitOptions struct {
|
||||
// parent.
|
||||
Parents []hash.Hash
|
||||
|
||||
// Amend flag indicates that the commit being build it to amend an existing commit. Generally we add the branch HEAD
|
||||
// as a parent, in addition to the parent set provided here. When we amend, we want to strictly use the commits
|
||||
// provided in |Parents|, and no others.
|
||||
Amend bool
|
||||
|
||||
Meta *CommitMeta
|
||||
}
|
||||
|
||||
@@ -586,7 +586,7 @@ func (db *database) BuildNewCommit(ctx context.Context, ds Dataset, v types.Valu
|
||||
if ok {
|
||||
opts.Parents = []hash.Hash{headAddr}
|
||||
}
|
||||
} else {
|
||||
} else if !opts.Amend {
|
||||
curr, ok := ds.MaybeHeadAddr()
|
||||
if ok {
|
||||
if !hasParentHash(opts, curr) {
|
||||
@@ -901,7 +901,7 @@ func (db *database) CommitWithWorkingSet(
|
||||
|
||||
// Prepend the current head hash to the list of parents if one was provided. This is only necessary if parents were
|
||||
// provided because we fill it in automatically in buildNewCommit otherwise.
|
||||
if len(opts.Parents) > 0 {
|
||||
if len(opts.Parents) > 0 && !opts.Amend {
|
||||
headHash, ok := commitDS.MaybeHeadAddr()
|
||||
if ok {
|
||||
if !hasParentHash(opts, headHash) {
|
||||
|
||||
@@ -0,0 +1,256 @@
|
||||
// Copyright 2024 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
driver "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/sql_server_driver"
|
||||
)
|
||||
|
||||
func TestCommitConcurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("SQL transaction with amend commit", testSQLTransactionWithAmendCommit)
|
||||
t.Run("SQL racing amend", testSQLRacingAmend)
|
||||
}
|
||||
|
||||
// testSQLTransactionWithAmendCommit verifies that two transactions started at the same state will not both be able
|
||||
// to commit using --amend. The first transaction will be able to commit, but the second should get an error.
|
||||
func testSQLTransactionWithAmendCommit(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
u, err := driver.NewDoltUser()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
u.Cleanup()
|
||||
})
|
||||
|
||||
rs, err := u.MakeRepoStore()
|
||||
require.NoError(t, err)
|
||||
repo, err := rs.MakeRepo("commit_concurrency_test")
|
||||
require.NoError(t, err)
|
||||
|
||||
srvSettings := &driver.Server{
|
||||
Args: []string{"--port", `{{get_port "server"}}`},
|
||||
DynamicPort: "server",
|
||||
}
|
||||
var ports DynamicPorts
|
||||
ports.global = &GlobalPorts
|
||||
ports.t = t
|
||||
server := MakeServer(t, repo, srvSettings, &ports)
|
||||
server.DBName = "commit_concurrency_test"
|
||||
|
||||
// Connect to the database
|
||||
db, err := server.DB(driver.Connection{User: "root"})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
})
|
||||
|
||||
_, err = db.ExecContext(ctx, `
|
||||
CREATE TABLE test_table (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
value VARCHAR(20)
|
||||
);`)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO test_table (value) VALUES ('initial')")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.ExecContext(ctx, "CALL DOLT_COMMIT('-A','-m', 'initial commit')")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a new context for the first (failing) transaction
|
||||
ctx1, cancel1 := context.WithCancel(ctx)
|
||||
defer cancel1()
|
||||
tx1, err := db.BeginTx(ctx1, nil)
|
||||
require.NoError(t, err)
|
||||
_, err = tx1.ExecContext(ctx1, "UPDATE test_table SET value = 'amended by tx1' WHERE id = 1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a new context for the second (succeeding) transaction
|
||||
ctx2, cancel2 := context.WithCancel(ctx)
|
||||
defer cancel2()
|
||||
tx2, err := db.BeginTx(ctx2, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update data within the second transaction
|
||||
_, err = tx2.ExecContext(ctx2, "UPDATE test_table SET value = 'amended by tx2' WHERE id = 1")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tx2.ExecContext(ctx2, "CALL DOLT_COMMIT('--amend', '-m', 'tx2 amended commit')")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Commit --amend will result in tx2 being committed. You can still make updates on tx1, but any commit should fail
|
||||
_, err = tx1.ExecContext(ctx1, "INSERT INTO test_table (value) VALUES ('new row by tx1')")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tx1.ExecContext(ctx1, "CALL DOLT_COMMIT('--amend', '-m', 'should fail')")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "this transaction conflicts with a committed transaction from another client, try restarting transaction")
|
||||
|
||||
// Verify that the data in the head is what we would expect
|
||||
row := db.QueryRowContext(ctx, "SELECT value FROM test_table WHERE id = 1")
|
||||
var value string
|
||||
err = row.Scan(&value)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "amended by tx2", value)
|
||||
|
||||
// Verify the commit message
|
||||
row = db.QueryRowContext(ctx, "SELECT message FROM dolt_log ORDER BY date DESC LIMIT 1")
|
||||
var commitMessage string
|
||||
err = row.Scan(&commitMessage)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "tx2 amended commit", commitMessage)
|
||||
|
||||
}
|
||||
|
||||
func testSQLRacingAmend(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
u, err := driver.NewDoltUser()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
u.Cleanup()
|
||||
})
|
||||
|
||||
rs, err := u.MakeRepoStore()
|
||||
require.NoError(t, err)
|
||||
repo, err := rs.MakeRepo("racing_amend_test")
|
||||
require.NoError(t, err)
|
||||
|
||||
srvSettings := &driver.Server{
|
||||
Args: []string{"--port", `{{get_port "server"}}`},
|
||||
DynamicPort: "server",
|
||||
}
|
||||
var ports DynamicPorts
|
||||
ports.global = &GlobalPorts
|
||||
ports.t = t
|
||||
server := MakeServer(t, repo, srvSettings, &ports)
|
||||
server.DBName = "racing_amend_test"
|
||||
|
||||
db, err := server.DB(driver.Connection{User: "root"})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
})
|
||||
|
||||
_, err = db.ExecContext(ctx, `
|
||||
CREATE TABLE test_table (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
value VARCHAR(20)
|
||||
);`)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO test_table VALUES (1, 'initial')")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.ExecContext(ctx, "CALL DOLT_COMMIT('-A','-m', 'initial commit')")
|
||||
require.NoError(t, err)
|
||||
|
||||
type txIdFunc struct {
|
||||
txNum int
|
||||
txFunc func() error
|
||||
}
|
||||
|
||||
var transactions []txIdFunc
|
||||
for txNum := 1; txNum <= 200; txNum++ {
|
||||
txCtx, cancel := context.WithTimeout(ctx, 15*time.Second) // Should never hit, but just in case
|
||||
tx, err := db.BeginTx(txCtx, nil)
|
||||
require.NoError(t, err)
|
||||
f := func() error {
|
||||
defer cancel()
|
||||
// update required to get a transaction conflict.
|
||||
_, e2 := tx.ExecContext(txCtx, "UPDATE test_table SET value = ? WHERE id = 1", fmt.Sprintf("tx%d value", txNum))
|
||||
require.NoError(t, e2)
|
||||
_, e2 = tx.ExecContext(txCtx, "INSERT INTO test_table (value) VALUES (?)", fmt.Sprintf("tx%d new row", txNum))
|
||||
require.NoError(t, e2)
|
||||
|
||||
// We want other transactions to have the chance to mess with the db. We'll verify that what's committed is what we expect.
|
||||
time.Sleep(time.Duration(rand.Intn(1000)+500) * time.Millisecond)
|
||||
|
||||
// This will commit the transaction, or error.
|
||||
_, e2 = tx.ExecContext(txCtx, "CALL DOLT_COMMIT('--amend','-a', '-m', ?)", fmt.Sprintf("tx%d amend", txNum))
|
||||
if e2 != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
return e2
|
||||
}
|
||||
transactions = append(transactions, txIdFunc{txNum: txNum, txFunc: f})
|
||||
}
|
||||
|
||||
rand.Shuffle(len(transactions), func(i, j int) {
|
||||
transactions[i], transactions[j] = transactions[j], transactions[i]
|
||||
})
|
||||
|
||||
var atomicInt atomic.Int32
|
||||
atomicInt.Store(-1)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, txn := range transactions {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// radomly sleep .5 - 1.5 seconds
|
||||
time.Sleep(time.Duration(rand.Intn(1000)+500) * time.Millisecond)
|
||||
err := txn.txFunc()
|
||||
|
||||
if err == nil {
|
||||
// If there are multiple updates, something went wrong.
|
||||
require.True(t, atomicInt.CompareAndSwap(-1, int32(txn.txNum)))
|
||||
}
|
||||
// Errors are expected.
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
winner := atomicInt.Load()
|
||||
require.NotEqual(t, -1, winner)
|
||||
|
||||
// Verify there are only 2 rows in the table
|
||||
rows, err := db.QueryContext(ctx, "SELECT COUNT(*) FROM test_table")
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
rows.Next()
|
||||
var count int
|
||||
err = rows.Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, count)
|
||||
|
||||
// Verify the final state
|
||||
row := db.QueryRowContext(ctx, "SELECT value FROM test_table WHERE id = 1")
|
||||
var value string
|
||||
err = row.Scan(&value)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fmt.Sprintf("tx%d value", winner), value)
|
||||
|
||||
row = db.QueryRowContext(ctx, "SELECT value FROM test_table WHERE id != 1")
|
||||
err = row.Scan(&value)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fmt.Sprintf("tx%d new row", winner), value)
|
||||
|
||||
// Verify the commit message
|
||||
row = db.QueryRowContext(ctx, "SELECT message FROM dolt_log ORDER BY date DESC LIMIT 1")
|
||||
var commitMessage string
|
||||
err = row.Scan(&commitMessage)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fmt.Sprintf("tx%d amend", winner), commitMessage)
|
||||
}
|
||||
Reference in New Issue
Block a user