From f13bedece9f21e21e10b1e33d12918d5db9df234 Mon Sep 17 00:00:00 2001 From: Andy Arthur Date: Wed, 7 Sep 2022 15:54:22 -0700 Subject: [PATCH 1/5] added Get and Put benchmarks --- go/store/skip/list_test.go | 119 +++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) diff --git a/go/store/skip/list_test.go b/go/store/skip/list_test.go index 5a03d62f1c..a650ada665 100644 --- a/go/store/skip/list_test.go +++ b/go/store/skip/list_test.go @@ -86,6 +86,125 @@ func TestMemoryFootprint(t *testing.T) { assert.Equal(t, 20, sz) } +func BenchmarkList(b *testing.B) { + b.Run("benchmark Get", func(b *testing.B) { + b.Run("n=64", func(b *testing.B) { + vals := randomInts(64) + l := NewSkipList(bytes.Compare) + for i := range vals { + l.Put(vals[i], vals[i]) + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, ok := l.Get(vals[i%64]) + if !ok { + b.Fail() + } + } + b.ReportAllocs() + }) + b.Run("n=512", func(b *testing.B) { + vals := randomInts(512) + l := NewSkipList(bytes.Compare) + for i := range vals { + l.Put(vals[i], vals[i]) + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, ok := l.Get(vals[i%512]) + if !ok { + b.Fail() + } + } + b.ReportAllocs() + }) + b.Run("n=4096", func(b *testing.B) { + vals := randomInts(4096) + l := NewSkipList(bytes.Compare) + for i := range vals { + l.Put(vals[i], vals[i]) + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, ok := l.Get(vals[i%4096]) + if !ok { + b.Fail() + } + } + b.ReportAllocs() + }) + b.Run("n=32768", func(b *testing.B) { + vals := randomInts(32768) + l := NewSkipList(bytes.Compare) + for i := range vals { + l.Put(vals[i], vals[i]) + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, ok := l.Get(vals[i%32768]) + if !ok { + b.Fail() + } + } + b.ReportAllocs() + }) + }) + b.Run("benchmark Put", func(b *testing.B) { + b.Run("n=64", func(b *testing.B) { + vals := randomInts(64) + l := NewSkipList(bytes.Compare) + for i := 0; i < b.N; i++ { + j := i % 64 + if j == 0 { + l.Truncate() + } + l.Put(vals[j], vals[j]) + } + b.ReportAllocs() + }) + b.Run("n=512", func(b *testing.B) { + vals := randomInts(512) + l := NewSkipList(bytes.Compare) + for i := 0; i < b.N; i++ { + j := i % 512 + if j == 0 { + l.Truncate() + } + l.Put(vals[j], vals[j]) + } + b.ReportAllocs() + }) + b.Run("n=4096", func(b *testing.B) { + vals := randomInts(4096) + l := NewSkipList(bytes.Compare) + for i := 0; i < b.N; i++ { + j := i % 4096 + if j == 0 { + l.Truncate() + } + l.Put(vals[j], vals[j]) + } + b.ReportAllocs() + }) + b.Run("n=32768", func(b *testing.B) { + vals := randomInts(32768) + l := NewSkipList(bytes.Compare) + for i := 0; i < b.N; i++ { + j := i % 32768 + if j == 0 { + l.Truncate() + } + l.Put(vals[j], vals[j]) + } + b.ReportAllocs() + }) + }) +} + func testSkipList(t *testing.T, compare ValueCmp, vals ...[]byte) { src.Shuffle(len(vals), func(i, j int) { vals[i], vals[j] = vals[j], vals[i] From b336716fe01ac1d23d0d1967f264a63c02d3541e Mon Sep 17 00:00:00 2001 From: Andy Arthur Date: Wed, 7 Sep 2022 16:36:55 -0700 Subject: [PATCH 2/5] increase max skip.List height from 5 to 10, improve height promotion probabilities, cleanup comparator callback types --- go/store/prolly/range_iter.go | 6 +- go/store/skip/list.go | 106 ++++++++++++++-------------------- go/store/skip/list_test.go | 46 +++++++-------- 3 files changed, 68 insertions(+), 90 deletions(-) diff --git a/go/store/prolly/range_iter.go b/go/store/prolly/range_iter.go index e6ae226648..6aff9b3fb1 100644 --- a/go/store/prolly/range_iter.go +++ b/go/store/prolly/range_iter.go @@ -133,11 +133,11 @@ func memIterFromRange(list *skip.List, rng Range) *memRangeIter { } } -// skipSearchFromRange is a skip.SearchFn used to initialize -// a skip.List iterator for a given Range. The skip.SearchFn +// skipSearchFromRange is a skip.SeekFn used to initialize +// a skip.List iterator for a given Range. The skip.SeekFn // returns true if the iter being initialized is not yet // within the bounds of Range |rng|. -func skipSearchFromRange(rng Range) skip.SearchFn { +func skipSearchFromRange(rng Range) skip.SeekFn { return func(nodeKey []byte) bool { if nodeKey == nil { return false diff --git a/go/store/skip/list.go b/go/store/skip/list.go index 878d124c0b..91fde8457c 100644 --- a/go/store/skip/list.go +++ b/go/store/skip/list.go @@ -17,35 +17,28 @@ package skip import ( "math" "math/rand" - - "github.com/zeebo/xxh3" ) const ( - maxCount = math.MaxUint32 - 1 - - maxHeight = uint8(5) - highest = maxHeight - 1 - + maxHeight = 9 + maxCount = math.MaxUint32 - 1 sentinelId = nodeId(0) ) +type KeyOrder func(l, r []byte) (cmp int) + +type SeekFn func(key []byte) (advance bool) + type List struct { - nodes []skipNode - count uint32 - + nodes []skipNode + count uint32 checkpoint nodeId - cmp ValueCmp - salt uint64 + keyOrder KeyOrder } -type ValueCmp func(left, right []byte) int - -type SearchFn func(nodeKey []byte) bool - type nodeId uint32 -type skipPointer [maxHeight]nodeId +type skipPointer [maxHeight + 1]nodeId type skipNode struct { key, val []byte @@ -56,7 +49,7 @@ type skipNode struct { height uint8 } -func NewSkipList(cmp ValueCmp) *List { +func NewSkipList(order KeyOrder) *List { nodes := make([]skipNode, 0, 8) // initialize sentinel node @@ -71,8 +64,7 @@ func NewSkipList(cmp ValueCmp) *List { return &List{ nodes: nodes, checkpoint: nodeId(1), - cmp: cmp, - salt: rand.Uint64(), + keyOrder: order, } } @@ -123,7 +115,7 @@ func (l *List) Put(key, val []byte) { if key == nil { panic("key must be non-nil") } - if l.Count() >= maxCount { + if len(l.nodes) >= maxCount { panic("list has no capacity") } @@ -147,7 +139,7 @@ func (l *List) pathToKey(key []byte) (path skipPointer) { next := l.headPointer() prev := sentinelId - for lvl := int(highest); lvl >= 0; { + for lvl := int(maxHeight); lvl >= 0; { curr := l.getNode(next[lvl]) // descend if we can't advance at |lvl| @@ -168,7 +160,7 @@ func (l *List) pathBeforeKey(key []byte) (path skipPointer) { next := l.headPointer() prev := sentinelId - for lvl := int(highest); lvl >= 0; { + for lvl := int(maxHeight); lvl >= 0; { curr := l.getNode(next[lvl]) // descend if we can't advance at |lvl| @@ -190,7 +182,7 @@ func (l *List) insert(key, value []byte, path skipPointer) { key: key, val: value, id: l.nextNodeId(), - height: rollHeight(key, l.salt), + height: rollHeight(), } l.nodes = append(l.nodes, novel) @@ -255,22 +247,20 @@ func (it *ListIter) Retreat() { func (l *List) GetIterAt(key []byte) (it *ListIter) { return l.GetIterFromSearchFn(func(nodeKey []byte) bool { - return l.compareKeysWithFn(key, nodeKey, l.cmp) > 0 + return l.compareKeys(key, nodeKey) > 0 }) } -func (l *List) GetIterFromSearchFn(kontinue SearchFn) (it *ListIter) { +func (l *List) GetIterFromSearchFn(fn SeekFn) (it *ListIter) { it = &ListIter{ - curr: l.seekWithSearchFn(kontinue), + curr: l.seekWithFn(fn), list: l, } - if it.curr.id == sentinelId { // try to keep |it| in bounds if |key| is // greater than the largest key in |l| it.Retreat() } - return } @@ -290,20 +280,16 @@ func (l *List) IterAtEnd() *ListIter { // seek returns the skipNode with the smallest key >= |key|. func (l *List) seek(key []byte) skipNode { - return l.seekWithCompare(key, l.cmp) -} - -func (l *List) seekWithCompare(key []byte, cmp ValueCmp) (node skipNode) { - return l.seekWithSearchFn(func(nodeKey []byte) bool { - return l.compareKeysWithFn(key, nodeKey, cmp) > 0 + return l.seekWithFn(func(curr []byte) (advance bool) { + return l.compareKeys(key, curr) > 0 }) } -func (l *List) seekWithSearchFn(kontinue SearchFn) (node skipNode) { +func (l *List) seekWithFn(cb SeekFn) (node skipNode) { ptr := l.headPointer() - for h := int64(highest); h >= 0; h-- { + for h := int64(maxHeight); h >= 0; h-- { node = l.getNode(ptr[h]) - for kontinue(node.key) { + for cb(node.key) { ptr = node.next node = l.getNode(ptr[h]) } @@ -336,43 +322,35 @@ func (l *List) nextNodeId() nodeId { return nodeId(len(l.nodes)) } -func (l *List) compare(left, right skipNode) int { - return l.compareKeys(left.key, right.key) -} - func (l *List) compareKeys(left, right []byte) int { - return l.compareKeysWithFn(left, right, l.cmp) -} - -func (l *List) compareKeysWithFn(left, right []byte, cmp ValueCmp) int { if right == nil { return -1 // |right| is sentinel key } - return cmp(left, right) + return l.keyOrder(left, right) } -const ( - pattern0 = uint64(1<<3 - 1) - pattern1 = uint64(1<<6 - 1) - pattern2 = uint64(1<<9 - 1) - pattern3 = uint64(1<<12 - 1) +var ( + // Precompute the skiplist probabilities so that the optimal + // p-value can be used (inverse of Euler's number). + // + // https://github.com/andy-kimball/arenaskl/blob/master/skl.go + probabilities = [maxHeight]uint32{} + randSrc = rand.New(rand.NewSource(rand.Int63())) ) -func rollHeight(key []byte, salt uint64) (h uint8) { - roll := xxh3.HashSeed(key, salt) - patterns := []uint64{ - pattern0, - pattern1, - pattern2, - pattern3, +func init() { + p := float64(1.0) + for i := uint8(0); i < maxHeight; i++ { + p /= math.E + probabilities[i] = uint32(float64(math.MaxUint32) * p) } +} - for _, pat := range patterns { - if uint64(roll)&pat != pat { - break - } +func rollHeight() (h uint8) { + rnd := randSrc.Uint32() + h = 0 + for h < maxHeight && rnd <= probabilities[h] { h++ } - return } diff --git a/go/store/skip/list_test.go b/go/store/skip/list_test.go index a650ada665..2415a69113 100644 --- a/go/store/skip/list_test.go +++ b/go/store/skip/list_test.go @@ -25,10 +25,10 @@ import ( "github.com/stretchr/testify/assert" ) -// var src = rand.New(rand.NewSource(time.Now().Unix())) -var src = rand.New(rand.NewSource(0)) - func TestSkipList(t *testing.T) { + // set constant seed to improve debugging + randSrc = rand.New(rand.NewSource(0)) + t.Run("test skip list", func(t *testing.T) { vals := [][]byte{ b("a"), b("b"), b("c"), b("d"), b("e"), @@ -39,7 +39,7 @@ func TestSkipList(t *testing.T) { }) t.Run("test skip list of random bytes", func(t *testing.T) { - vals := randomVals((src.Int63() % 10_000) + 100) + vals := randomVals((randSrc.Int63() % 10_000) + 100) testSkipList(t, bytes.Compare, vals...) }) t.Run("test with custom compare function", func(t *testing.T) { @@ -48,7 +48,7 @@ func TestSkipList(t *testing.T) { r := int64(binary.LittleEndian.Uint64(right)) return int(l - r) } - vals := randomInts((src.Int63() % 10_000) + 100) + vals := randomInts((randSrc.Int63() % 10_000) + 100) testSkipList(t, compare, vals...) }) } @@ -64,7 +64,7 @@ func TestSkipListCheckpoints(t *testing.T) { }) t.Run("test skip list of random bytes", func(t *testing.T) { - vals := randomVals((src.Int63() % 10_000) + 100) + vals := randomVals((randSrc.Int63() % 10_000) + 100) testSkipListCheckpoints(t, bytes.Compare, vals...) }) t.Run("test with custom compare function", func(t *testing.T) { @@ -73,7 +73,7 @@ func TestSkipListCheckpoints(t *testing.T) { r := int64(binary.LittleEndian.Uint64(right)) return int(l - r) } - vals := randomInts((src.Int63() % 10_000) + 100) + vals := randomInts((randSrc.Int63() % 10_000) + 100) testSkipListCheckpoints(t, compare, vals...) }) } @@ -81,9 +81,9 @@ func TestSkipListCheckpoints(t *testing.T) { func TestMemoryFootprint(t *testing.T) { var sz int sz = int(unsafe.Sizeof(skipNode{})) - assert.Equal(t, 80, sz) + assert.Equal(t, 104, sz) sz = int(unsafe.Sizeof(skipPointer{})) - assert.Equal(t, 20, sz) + assert.Equal(t, 40, sz) } func BenchmarkList(b *testing.B) { @@ -205,8 +205,8 @@ func BenchmarkList(b *testing.B) { }) } -func testSkipList(t *testing.T, compare ValueCmp, vals ...[]byte) { - src.Shuffle(len(vals), func(i, j int) { +func testSkipList(t *testing.T, compare KeyOrder, vals ...[]byte) { + randSrc.Shuffle(len(vals), func(i, j int) { vals[i], vals[j] = vals[j], vals[i] }) @@ -244,8 +244,8 @@ func testSkipListPuts(t *testing.T, list *List, vals ...[]byte) { } func testSkipListGets(t *testing.T, list *List, vals ...[]byte) { - // get in different order - src.Shuffle(len(vals), func(i, j int) { + // get in different keyOrder + randSrc.Shuffle(len(vals), func(i, j int) { vals[i], vals[j] = vals[j], vals[i] }) @@ -268,7 +268,7 @@ func testSkipListUpdates(t *testing.T, list *List, vals ...[]byte) { } assert.Equal(t, len(vals), list.Count()) - src.Shuffle(len(vals), func(i, j int) { + randSrc.Shuffle(len(vals), func(i, j int) { vals[i], vals[j] = vals[j], vals[i] }) for _, exp := range vals { @@ -282,7 +282,7 @@ func testSkipListUpdates(t *testing.T, list *List, vals ...[]byte) { } func testSkipListIterForward(t *testing.T, list *List, vals ...[]byte) { - // put |vals| back in order + // put |vals| back in keyOrder sort.Slice(vals, func(i, j int) bool { return list.compareKeys(vals[i], vals[j]) < 0 }) @@ -297,7 +297,7 @@ func testSkipListIterForward(t *testing.T, list *List, vals ...[]byte) { // test iter at for k := 0; k < 10; k++ { - idx = src.Int() % len(vals) + idx = randSrc.Int() % len(vals) key := vals[idx] act := validateIterForwardFrom(t, list, key) exp := len(vals) - idx @@ -311,14 +311,14 @@ func testSkipListIterForward(t *testing.T, list *List, vals ...[]byte) { } func testSkipListIterBackward(t *testing.T, list *List, vals ...[]byte) { - // put |vals| back in order + // put |vals| back in keyOrder sort.Slice(vals, func(i, j int) bool { return list.compareKeys(vals[i], vals[j]) < 0 }) // test iter at for k := 0; k < 10; k++ { - idx := src.Int() % len(vals) + idx := randSrc.Int() % len(vals) key := vals[idx] act := validateIterBackwardFrom(t, list, key) assert.Equal(t, idx+1, act) @@ -395,8 +395,8 @@ func validateIterBackwardFrom(t *testing.T, l *List, key []byte) (count int) { func randomVals(cnt int64) (vals [][]byte) { vals = make([][]byte, cnt) for i := range vals { - bb := make([]byte, (src.Int63()%91)+10) - src.Read(bb) + bb := make([]byte, (randSrc.Int63()%91)+10) + randSrc.Read(bb) vals[i] = bb } return @@ -406,7 +406,7 @@ func randomInts(cnt int64) (vals [][]byte) { vals = make([][]byte, cnt) for i := range vals { vals[i] = make([]byte, 8) - v := uint64(src.Int63()) + v := uint64(randSrc.Int63()) binary.LittleEndian.PutUint64(vals[i], v) } return @@ -436,8 +436,8 @@ func iterAllBackwards(l *List, cb func([]byte, []byte)) { } } -func testSkipListCheckpoints(t *testing.T, compare ValueCmp, data ...[]byte) { - src.Shuffle(len(data), func(i, j int) { +func testSkipListCheckpoints(t *testing.T, compare KeyOrder, data ...[]byte) { + randSrc.Shuffle(len(data), func(i, j int) { data[i], data[j] = data[j], data[i] }) From dd42b134d2c2956522752310e50d55f35e6afc2d Mon Sep 17 00:00:00 2001 From: Andy Arthur Date: Wed, 7 Sep 2022 17:03:29 -0700 Subject: [PATCH 3/5] improve skip.List documentation --- go/store/skip/list.go | 68 +++++++++++++++++++++++++++----------- go/store/skip/list_test.go | 2 +- 2 files changed, 49 insertions(+), 21 deletions(-) diff --git a/go/store/skip/list.go b/go/store/skip/list.go index 91fde8457c..a5e059fc10 100644 --- a/go/store/skip/list.go +++ b/go/store/skip/list.go @@ -25,30 +25,47 @@ const ( sentinelId = nodeId(0) ) +// A KeyOrder determines the ordering of two keys |l| and |r|. type KeyOrder func(l, r []byte) (cmp int) +// A SeekFn facilitates seeking into a List. It returns true +// if the seek operation should advance past |key|. type SeekFn func(key []byte) (advance bool) +// List is an in-memory skip-list. type List struct { - nodes []skipNode - count uint32 + // nodes contains all skipNode's in the List. + // skipNode's are assigned ascending id's and + // are stored in the order they were created, + // i.e. skipNode.id stores its index in |nodes| + nodes []skipNode + + // count stores the current number of items in + // the list (updates are not made in-place) + count uint32 + + // checkpoint stores the nodeId of the last + // checkpoint made. All nodes created after this + // point will be discarded on a Revert() checkpoint nodeId - keyOrder KeyOrder + + // keyOrder determines the ordering of items + keyOrder KeyOrder } type nodeId uint32 -type skipPointer [maxHeight + 1]nodeId +type tower [maxHeight + 1]nodeId type skipNode struct { key, val []byte - - id nodeId - next skipPointer - prev nodeId - height uint8 + id nodeId + next tower + prev nodeId + height uint8 } +// NewSkipList returns a new skip.List. func NewSkipList(order KeyOrder) *List { nodes := make([]skipNode, 0, 8) @@ -57,7 +74,7 @@ func NewSkipList(order KeyOrder) *List { id: sentinelId, key: nil, val: nil, height: maxHeight, - next: skipPointer{}, + next: tower{}, prev: sentinelId, }) @@ -87,21 +104,27 @@ func (l *List) Truncate() { l.nodes = l.nodes[:1] // point sentinel.prev at itself s := l.getNode(sentinelId) - s.next = skipPointer{} + s.next = tower{} s.prev = sentinelId l.updateNode(s) + l.checkpoint = nodeId(1) l.count = 0 } +// Count returns the number of items in the list. func (l *List) Count() int { return int(l.count) } +// Has returns true if |key| is a member of the list. func (l *List) Has(key []byte) (ok bool) { _, ok = l.Get(key) return } +// Get returns the value associated with |key| and true +// if |key| is a member of the list, otherwise it returns +// nil and false. func (l *List) Get(key []byte) (val []byte, ok bool) { path := l.pathToKey(key) node := l.getNode(path[0]) @@ -111,6 +134,7 @@ func (l *List) Get(key []byte) (val []byte, ok bool) { return } +// Put adds |key| and |values| to the list. func (l *List) Put(key, val []byte) { if key == nil { panic("key must be non-nil") @@ -135,7 +159,7 @@ func (l *List) Put(key, val []byte) { } } -func (l *List) pathToKey(key []byte) (path skipPointer) { +func (l *List) pathToKey(key []byte) (path tower) { next := l.headPointer() prev := sentinelId @@ -156,7 +180,7 @@ func (l *List) pathToKey(key []byte) (path skipPointer) { return } -func (l *List) pathBeforeKey(key []byte) (path skipPointer) { +func (l *List) pathBeforeKey(key []byte) (path tower) { next := l.headPointer() prev := sentinelId @@ -177,7 +201,7 @@ func (l *List) pathBeforeKey(key []byte) (path skipPointer) { return } -func (l *List) insert(key, value []byte, path skipPointer) { +func (l *List) insert(key, value []byte, path tower) { novel := skipNode{ key: key, val: value, @@ -202,7 +226,7 @@ func (l *List) insert(key, value []byte, path skipPointer) { l.updateNode(n) } -func (l *List) overwrite(key, value []byte, path skipPointer, old skipNode) { +func (l *List) overwrite(key, value []byte, path tower, old skipNode) { novel := old novel.id = l.nextNodeId() novel.key = key @@ -227,30 +251,32 @@ type ListIter struct { list *List } -func (it *ListIter) Count() int { - return it.list.Count() -} - +// Current returns the current key and value of the iterator. func (it *ListIter) Current() (key, val []byte) { return it.curr.key, it.curr.val } +// Advance advances the iterator. func (it *ListIter) Advance() { it.curr = it.list.getNode(it.curr.next[0]) return } +// Retreat retreats the iterator. func (it *ListIter) Retreat() { it.curr = it.list.getNode(it.curr.prev) return } +// GetIterAt creates an iterator starting at the first item +// of the list whose key is greater than or equal to |key|. func (l *List) GetIterAt(key []byte) (it *ListIter) { return l.GetIterFromSearchFn(func(nodeKey []byte) bool { return l.compareKeys(key, nodeKey) > 0 }) } +// GetIterFromSearchFn creates an iterator using a SeekFn. func (l *List) GetIterFromSearchFn(fn SeekFn) (it *ListIter) { it = &ListIter{ curr: l.seekWithFn(fn), @@ -264,6 +290,7 @@ func (l *List) GetIterFromSearchFn(fn SeekFn) (it *ListIter) { return } +// IterAtStart creates an iterator at the start of the list. func (l *List) IterAtStart() *ListIter { return &ListIter{ curr: l.firstNode(), @@ -271,6 +298,7 @@ func (l *List) IterAtStart() *ListIter { } } +// IterAtEnd creates an iterator at the end of the list. func (l *List) IterAtEnd() *ListIter { return &ListIter{ curr: l.lastNode(), @@ -297,7 +325,7 @@ func (l *List) seekWithFn(cb SeekFn) (node skipNode) { return } -func (l *List) headPointer() skipPointer { +func (l *List) headPointer() tower { return l.nodes[0].next } diff --git a/go/store/skip/list_test.go b/go/store/skip/list_test.go index 2415a69113..05720cd8d8 100644 --- a/go/store/skip/list_test.go +++ b/go/store/skip/list_test.go @@ -82,7 +82,7 @@ func TestMemoryFootprint(t *testing.T) { var sz int sz = int(unsafe.Sizeof(skipNode{})) assert.Equal(t, 104, sz) - sz = int(unsafe.Sizeof(skipPointer{})) + sz = int(unsafe.Sizeof(tower{})) assert.Equal(t, 40, sz) } From e95cb86cd606d3e371a14d213de869fbdd11eb9f Mon Sep 17 00:00:00 2001 From: Andy Arthur Date: Wed, 7 Sep 2022 17:08:05 -0700 Subject: [PATCH 4/5] fix Revert() bug --- go/store/skip/list.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/go/store/skip/list.go b/go/store/skip/list.go index a5e059fc10..27661472ee 100644 --- a/go/store/skip/list.go +++ b/go/store/skip/list.go @@ -92,11 +92,13 @@ func (l *List) Checkpoint() { // Revert reverts to the last recorded checkpoint. func (l *List) Revert() { - keepers := l.nodes[1:l.checkpoint] + cp := l.checkpoint + keepers := l.nodes[1:cp] l.Truncate() for _, nd := range keepers { l.Put(nd.key, nd.val) } + l.checkpoint = cp } // Truncate deletes all entries from the list. From 2e0cc823daeb3be5107bcbf90c98fb4e183c9e96 Mon Sep 17 00:00:00 2001 From: Andy Arthur Date: Thu, 8 Sep 2022 10:22:07 -0700 Subject: [PATCH 5/5] fix race on global RNG --- go/store/skip/list.go | 16 +++++++++------- go/store/skip/list_test.go | 5 ++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/go/store/skip/list.go b/go/store/skip/list.go index 27661472ee..2a371ff9c0 100644 --- a/go/store/skip/list.go +++ b/go/store/skip/list.go @@ -15,8 +15,8 @@ package skip import ( + "hash/maphash" "math" - "math/rand" ) const ( @@ -51,6 +51,9 @@ type List struct { // keyOrder determines the ordering of items keyOrder KeyOrder + + // seed is hash salt + seed maphash.Seed } type nodeId uint32 @@ -82,6 +85,7 @@ func NewSkipList(order KeyOrder) *List { nodes: nodes, checkpoint: nodeId(1), keyOrder: order, + seed: maphash.MakeSeed(), } } @@ -208,7 +212,7 @@ func (l *List) insert(key, value []byte, path tower) { key: key, val: value, id: l.nextNodeId(), - height: rollHeight(), + height: l.rollHeight(key), } l.nodes = append(l.nodes, novel) @@ -365,7 +369,6 @@ var ( // // https://github.com/andy-kimball/arenaskl/blob/master/skl.go probabilities = [maxHeight]uint32{} - randSrc = rand.New(rand.NewSource(rand.Int63())) ) func init() { @@ -376,10 +379,9 @@ func init() { } } -func rollHeight() (h uint8) { - rnd := randSrc.Uint32() - h = 0 - for h < maxHeight && rnd <= probabilities[h] { +func (l *List) rollHeight(key []byte) (h uint8) { + rnd := maphash.Bytes(l.seed, key) + for h < maxHeight && uint32(rnd) <= probabilities[h] { h++ } return diff --git a/go/store/skip/list_test.go b/go/store/skip/list_test.go index 05720cd8d8..4b0a72dff8 100644 --- a/go/store/skip/list_test.go +++ b/go/store/skip/list_test.go @@ -25,10 +25,9 @@ import ( "github.com/stretchr/testify/assert" ) -func TestSkipList(t *testing.T) { - // set constant seed to improve debugging - randSrc = rand.New(rand.NewSource(0)) +var randSrc = rand.New(rand.NewSource(0)) +func TestSkipList(t *testing.T) { t.Run("test skip list", func(t *testing.T) { vals := [][]byte{ b("a"), b("b"), b("c"), b("d"), b("e"),