go: sqle/dsess: Serialize transactions on a given database branch, instead of globally on all databases and all branches.

This commit is contained in:
Aaron Son
2025-04-23 15:03:53 -07:00
parent 1e27c1093e
commit 6a26d5fee4
3 changed files with 207 additions and 3 deletions
@@ -31,6 +31,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
"github.com/dolthub/dolt/go/libraries/doltcore/merge"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/keymutex"
"github.com/dolthub/dolt/go/store/datas"
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/prolly"
@@ -158,7 +159,7 @@ func (tx DoltTransaction) GetInitialRoot(dbName string) (hash.Hash, bool) {
return startPoint.rootHash, ok
}
var txLock sync.Mutex
var txLocks = keymutex.NewMapped()
// Commit attempts to merge the working set given into the current working set.
// Uses the same algorithm as merge.RootMerger:
@@ -395,11 +396,16 @@ func (tx *DoltTransaction) doCommit(
mergeOpts := branchState.EditOpts()
lockID := dbName + "\u0000" + workingSet.Ref().String()
for i := 0; i < maxTxCommitRetries; i++ {
updatedWs, newCommit, err := func() (*doltdb.WorkingSet, *doltdb.Commit, error) {
// Serialize commits, since only one can possibly succeed at a time anyway
txLock.Lock()
defer txLock.Unlock()
err := txLocks.Lock(ctx, lockID)
if err != nil {
return nil, nil, err
}
defer txLocks.Unlock(lockID)
newWorkingSet := false
+88
View File
@@ -0,0 +1,88 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package keymutex
import (
"context"
"sync"
"golang.org/x/sync/semaphore"
)
// A keymutex allows a caller to gain exclusive access to a critical
// section associated with the provided |key|. No callers will enter
// the critical section concurrently, and a caller which arrives while
// the critical section is occupied will block until it is available.
//
// A keymutex's Lock function should respect Context cancelation.
type Keymutex interface {
Lock(ctx context.Context, id string) error
Unlock(id string)
}
// Returns a Keymutex which stores mutexes in a map. This Keymutex has
// relatively high per-lock overhead, but allows all separate |key|s
// to make concurrent progress.
func NewMapped() Keymutex {
return &mapKeymutex{
states: make(map[string]*mapKeymutexState),
}
}
type mapKeymutex struct {
mu sync.Mutex
states map[string]*mapKeymutexState
}
type mapKeymutexState struct {
sema *semaphore.Weighted
waitCnt int
}
func newMapKeymutexState() *mapKeymutexState {
return &mapKeymutexState{
sema: semaphore.NewWeighted(1),
}
}
func (m *mapKeymutex) Lock(ctx context.Context, id string) error {
m.mu.Lock()
defer m.mu.Unlock()
var state *mapKeymutexState
var ok bool
if state, ok = m.states[id]; !ok {
state = newMapKeymutexState()
m.states[id] = state
}
if state.sema.TryAcquire(1) {
return nil
}
state.waitCnt += 1
m.mu.Unlock()
err := state.sema.Acquire(ctx, 1)
m.mu.Lock()
state.waitCnt -= 1
return err
}
func (m *mapKeymutex) Unlock(id string) {
m.mu.Lock()
defer m.mu.Unlock()
state := m.states[id]
state.sema.Release(1)
if state.waitCnt == 0 {
delete(m.states, id)
}
}
@@ -0,0 +1,110 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package keymutex
import (
"context"
"runtime"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMapped(t *testing.T) {
t.Run("Cleanup", func(t *testing.T) {
mutexes := NewMapped()
func() {
for _, s := range []string{"a", "b", "c", "d", "e", "f", "g"} {
require.NoError(t, mutexes.Lock(context.Background(), s))
defer mutexes.Unlock(s)
}
}()
assert.Len(t, mutexes.(*mapKeymutex).states, 0)
})
t.Run("Exclusion", func(t *testing.T) {
mutexes := NewMapped()
var wg sync.WaitGroup
var fours int
var eights int
for i := 0; i < 4; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 512; i++ {
require.NoError(t, mutexes.Lock(context.Background(), "fours"))
fours += 1
mutexes.Unlock("fours")
}
}()
}
for i := 0; i < 8; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 256; i++ {
require.NoError(t, mutexes.Lock(context.Background(), "eights"))
eights += 1
mutexes.Unlock("eights")
}
}()
}
wg.Wait()
assert.Equal(t, fours, 2048)
assert.Equal(t, eights, 2048)
})
t.Run("Canceled", func(t *testing.T) {
mutexes := NewMapped()
require.NoError(t, mutexes.Lock(context.Background(), "taken"))
ctx, cancel := context.WithCancel(context.Background())
cancel()
require.Error(t, mutexes.Lock(ctx, "taken"), context.Canceled)
var cancels []func()
var wg sync.WaitGroup
wg.Add(64)
for i := 0; i < 64; i++ {
ctx, cancel := context.WithCancel(context.Background())
cancels = append(cancels, cancel)
go func() {
defer wg.Done()
require.ErrorIs(t, mutexes.Lock(ctx, "taken"), context.Canceled)
}()
}
var successWg sync.WaitGroup
successWg.Add(1)
go func() {
defer successWg.Done()
require.NoError(t, mutexes.Lock(context.Background(), "taken"))
defer mutexes.Unlock("taken")
}()
for {
mutexes.(*mapKeymutex).mu.Lock()
if mutexes.(*mapKeymutex).states["taken"].waitCnt == 65 {
mutexes.(*mapKeymutex).mu.Unlock()
break
}
mutexes.(*mapKeymutex).mu.Unlock()
runtime.Gosched()
}
for _, f := range cancels {
f()
}
wg.Wait()
mutexes.Unlock("taken")
successWg.Wait()
})
}