Merge pull request #4287 from dolthub/andy/update-skip-list

[no-release-notes] Document and refactor `skip.List`
This commit is contained in:
AndyA
2022-09-08 11:45:15 -07:00
committed by GitHub
3 changed files with 233 additions and 105 deletions

View File

@@ -133,11 +133,11 @@ func memIterFromRange(list *skip.List, rng Range) *memRangeIter {
}
}
// skipSearchFromRange is a skip.SearchFn used to initialize
// a skip.List iterator for a given Range. The skip.SearchFn
// skipSearchFromRange is a skip.SeekFn used to initialize
// a skip.List iterator for a given Range. The skip.SeekFn
// returns true if the iter being initialized is not yet
// within the bounds of Range |rng|.
func skipSearchFromRange(rng Range) skip.SearchFn {
func skipSearchFromRange(rng Range) skip.SeekFn {
return func(nodeKey []byte) bool {
if nodeKey == nil {
return false

View File

@@ -15,48 +15,61 @@
package skip
import (
"hash/maphash"
"math"
"math/rand"
"github.com/zeebo/xxh3"
)
const (
maxCount = math.MaxUint32 - 1
maxHeight = uint8(5)
highest = maxHeight - 1
maxHeight = 9
maxCount = math.MaxUint32 - 1
sentinelId = nodeId(0)
)
// A KeyOrder determines the ordering of two keys |l| and |r|.
type KeyOrder func(l, r []byte) (cmp int)
// A SeekFn facilitates seeking into a List. It returns true
// if the seek operation should advance past |key|.
type SeekFn func(key []byte) (advance bool)
// List is an in-memory skip-list.
type List struct {
// nodes contains all skipNode's in the List.
// skipNode's are assigned ascending id's and
// are stored in the order they were created,
// i.e. skipNode.id stores its index in |nodes|
nodes []skipNode
// count stores the current number of items in
// the list (updates are not made in-place)
count uint32
// checkpoint stores the nodeId of the last
// checkpoint made. All nodes created after this
// point will be discarded on a Revert()
checkpoint nodeId
cmp ValueCmp
salt uint64
// keyOrder determines the ordering of items
keyOrder KeyOrder
// seed is hash salt
seed maphash.Seed
}
type ValueCmp func(left, right []byte) int
type SearchFn func(nodeKey []byte) bool
type nodeId uint32
type skipPointer [maxHeight]nodeId
type tower [maxHeight + 1]nodeId
type skipNode struct {
key, val []byte
id nodeId
next skipPointer
prev nodeId
height uint8
id nodeId
next tower
prev nodeId
height uint8
}
func NewSkipList(cmp ValueCmp) *List {
// NewSkipList returns a new skip.List.
func NewSkipList(order KeyOrder) *List {
nodes := make([]skipNode, 0, 8)
// initialize sentinel node
@@ -64,15 +77,15 @@ func NewSkipList(cmp ValueCmp) *List {
id: sentinelId,
key: nil, val: nil,
height: maxHeight,
next: skipPointer{},
next: tower{},
prev: sentinelId,
})
return &List{
nodes: nodes,
checkpoint: nodeId(1),
cmp: cmp,
salt: rand.Uint64(),
keyOrder: order,
seed: maphash.MakeSeed(),
}
}
@@ -83,11 +96,13 @@ func (l *List) Checkpoint() {
// Revert reverts to the last recorded checkpoint.
func (l *List) Revert() {
keepers := l.nodes[1:l.checkpoint]
cp := l.checkpoint
keepers := l.nodes[1:cp]
l.Truncate()
for _, nd := range keepers {
l.Put(nd.key, nd.val)
}
l.checkpoint = cp
}
// Truncate deletes all entries from the list.
@@ -95,21 +110,27 @@ func (l *List) Truncate() {
l.nodes = l.nodes[:1]
// point sentinel.prev at itself
s := l.getNode(sentinelId)
s.next = skipPointer{}
s.next = tower{}
s.prev = sentinelId
l.updateNode(s)
l.checkpoint = nodeId(1)
l.count = 0
}
// Count returns the number of items in the list.
func (l *List) Count() int {
return int(l.count)
}
// Has returns true if |key| is a member of the list.
func (l *List) Has(key []byte) (ok bool) {
_, ok = l.Get(key)
return
}
// Get returns the value associated with |key| and true
// if |key| is a member of the list, otherwise it returns
// nil and false.
func (l *List) Get(key []byte) (val []byte, ok bool) {
path := l.pathToKey(key)
node := l.getNode(path[0])
@@ -119,11 +140,12 @@ func (l *List) Get(key []byte) (val []byte, ok bool) {
return
}
// Put adds |key| and |values| to the list.
func (l *List) Put(key, val []byte) {
if key == nil {
panic("key must be non-nil")
}
if l.Count() >= maxCount {
if len(l.nodes) >= maxCount {
panic("list has no capacity")
}
@@ -143,11 +165,11 @@ func (l *List) Put(key, val []byte) {
}
}
func (l *List) pathToKey(key []byte) (path skipPointer) {
func (l *List) pathToKey(key []byte) (path tower) {
next := l.headPointer()
prev := sentinelId
for lvl := int(highest); lvl >= 0; {
for lvl := int(maxHeight); lvl >= 0; {
curr := l.getNode(next[lvl])
// descend if we can't advance at |lvl|
@@ -164,11 +186,11 @@ func (l *List) pathToKey(key []byte) (path skipPointer) {
return
}
func (l *List) pathBeforeKey(key []byte) (path skipPointer) {
func (l *List) pathBeforeKey(key []byte) (path tower) {
next := l.headPointer()
prev := sentinelId
for lvl := int(highest); lvl >= 0; {
for lvl := int(maxHeight); lvl >= 0; {
curr := l.getNode(next[lvl])
// descend if we can't advance at |lvl|
@@ -185,12 +207,12 @@ func (l *List) pathBeforeKey(key []byte) (path skipPointer) {
return
}
func (l *List) insert(key, value []byte, path skipPointer) {
func (l *List) insert(key, value []byte, path tower) {
novel := skipNode{
key: key,
val: value,
id: l.nextNodeId(),
height: rollHeight(key, l.salt),
height: l.rollHeight(key),
}
l.nodes = append(l.nodes, novel)
@@ -210,7 +232,7 @@ func (l *List) insert(key, value []byte, path skipPointer) {
l.updateNode(n)
}
func (l *List) overwrite(key, value []byte, path skipPointer, old skipNode) {
func (l *List) overwrite(key, value []byte, path tower, old skipNode) {
novel := old
novel.id = l.nextNodeId()
novel.key = key
@@ -235,45 +257,46 @@ type ListIter struct {
list *List
}
func (it *ListIter) Count() int {
return it.list.Count()
}
// Current returns the current key and value of the iterator.
func (it *ListIter) Current() (key, val []byte) {
return it.curr.key, it.curr.val
}
// Advance advances the iterator.
func (it *ListIter) Advance() {
it.curr = it.list.getNode(it.curr.next[0])
return
}
// Retreat retreats the iterator.
func (it *ListIter) Retreat() {
it.curr = it.list.getNode(it.curr.prev)
return
}
// GetIterAt creates an iterator starting at the first item
// of the list whose key is greater than or equal to |key|.
func (l *List) GetIterAt(key []byte) (it *ListIter) {
return l.GetIterFromSearchFn(func(nodeKey []byte) bool {
return l.compareKeysWithFn(key, nodeKey, l.cmp) > 0
return l.compareKeys(key, nodeKey) > 0
})
}
func (l *List) GetIterFromSearchFn(kontinue SearchFn) (it *ListIter) {
// GetIterFromSearchFn creates an iterator using a SeekFn.
func (l *List) GetIterFromSearchFn(fn SeekFn) (it *ListIter) {
it = &ListIter{
curr: l.seekWithSearchFn(kontinue),
curr: l.seekWithFn(fn),
list: l,
}
if it.curr.id == sentinelId {
// try to keep |it| in bounds if |key| is
// greater than the largest key in |l|
it.Retreat()
}
return
}
// IterAtStart creates an iterator at the start of the list.
func (l *List) IterAtStart() *ListIter {
return &ListIter{
curr: l.firstNode(),
@@ -281,6 +304,7 @@ func (l *List) IterAtStart() *ListIter {
}
}
// IterAtEnd creates an iterator at the end of the list.
func (l *List) IterAtEnd() *ListIter {
return &ListIter{
curr: l.lastNode(),
@@ -290,20 +314,16 @@ func (l *List) IterAtEnd() *ListIter {
// seek returns the skipNode with the smallest key >= |key|.
func (l *List) seek(key []byte) skipNode {
return l.seekWithCompare(key, l.cmp)
}
func (l *List) seekWithCompare(key []byte, cmp ValueCmp) (node skipNode) {
return l.seekWithSearchFn(func(nodeKey []byte) bool {
return l.compareKeysWithFn(key, nodeKey, cmp) > 0
return l.seekWithFn(func(curr []byte) (advance bool) {
return l.compareKeys(key, curr) > 0
})
}
func (l *List) seekWithSearchFn(kontinue SearchFn) (node skipNode) {
func (l *List) seekWithFn(cb SeekFn) (node skipNode) {
ptr := l.headPointer()
for h := int64(highest); h >= 0; h-- {
for h := int64(maxHeight); h >= 0; h-- {
node = l.getNode(ptr[h])
for kontinue(node.key) {
for cb(node.key) {
ptr = node.next
node = l.getNode(ptr[h])
}
@@ -311,7 +331,7 @@ func (l *List) seekWithSearchFn(kontinue SearchFn) (node skipNode) {
return
}
func (l *List) headPointer() skipPointer {
func (l *List) headPointer() tower {
return l.nodes[0].next
}
@@ -336,43 +356,33 @@ func (l *List) nextNodeId() nodeId {
return nodeId(len(l.nodes))
}
func (l *List) compare(left, right skipNode) int {
return l.compareKeys(left.key, right.key)
}
func (l *List) compareKeys(left, right []byte) int {
return l.compareKeysWithFn(left, right, l.cmp)
}
func (l *List) compareKeysWithFn(left, right []byte, cmp ValueCmp) int {
if right == nil {
return -1 // |right| is sentinel key
}
return cmp(left, right)
return l.keyOrder(left, right)
}
const (
pattern0 = uint64(1<<3 - 1)
pattern1 = uint64(1<<6 - 1)
pattern2 = uint64(1<<9 - 1)
pattern3 = uint64(1<<12 - 1)
var (
// Precompute the skiplist probabilities so that the optimal
// p-value can be used (inverse of Euler's number).
//
// https://github.com/andy-kimball/arenaskl/blob/master/skl.go
probabilities = [maxHeight]uint32{}
)
func rollHeight(key []byte, salt uint64) (h uint8) {
roll := xxh3.HashSeed(key, salt)
patterns := []uint64{
pattern0,
pattern1,
pattern2,
pattern3,
func init() {
p := float64(1.0)
for i := uint8(0); i < maxHeight; i++ {
p /= math.E
probabilities[i] = uint32(float64(math.MaxUint32) * p)
}
}
for _, pat := range patterns {
if uint64(roll)&pat != pat {
break
}
func (l *List) rollHeight(key []byte) (h uint8) {
rnd := maphash.Bytes(l.seed, key)
for h < maxHeight && uint32(rnd) <= probabilities[h] {
h++
}
return
}

View File

@@ -25,8 +25,7 @@ import (
"github.com/stretchr/testify/assert"
)
// var src = rand.New(rand.NewSource(time.Now().Unix()))
var src = rand.New(rand.NewSource(0))
var randSrc = rand.New(rand.NewSource(0))
func TestSkipList(t *testing.T) {
t.Run("test skip list", func(t *testing.T) {
@@ -39,7 +38,7 @@ func TestSkipList(t *testing.T) {
})
t.Run("test skip list of random bytes", func(t *testing.T) {
vals := randomVals((src.Int63() % 10_000) + 100)
vals := randomVals((randSrc.Int63() % 10_000) + 100)
testSkipList(t, bytes.Compare, vals...)
})
t.Run("test with custom compare function", func(t *testing.T) {
@@ -48,7 +47,7 @@ func TestSkipList(t *testing.T) {
r := int64(binary.LittleEndian.Uint64(right))
return int(l - r)
}
vals := randomInts((src.Int63() % 10_000) + 100)
vals := randomInts((randSrc.Int63() % 10_000) + 100)
testSkipList(t, compare, vals...)
})
}
@@ -64,7 +63,7 @@ func TestSkipListCheckpoints(t *testing.T) {
})
t.Run("test skip list of random bytes", func(t *testing.T) {
vals := randomVals((src.Int63() % 10_000) + 100)
vals := randomVals((randSrc.Int63() % 10_000) + 100)
testSkipListCheckpoints(t, bytes.Compare, vals...)
})
t.Run("test with custom compare function", func(t *testing.T) {
@@ -73,7 +72,7 @@ func TestSkipListCheckpoints(t *testing.T) {
r := int64(binary.LittleEndian.Uint64(right))
return int(l - r)
}
vals := randomInts((src.Int63() % 10_000) + 100)
vals := randomInts((randSrc.Int63() % 10_000) + 100)
testSkipListCheckpoints(t, compare, vals...)
})
}
@@ -81,13 +80,132 @@ func TestSkipListCheckpoints(t *testing.T) {
func TestMemoryFootprint(t *testing.T) {
var sz int
sz = int(unsafe.Sizeof(skipNode{}))
assert.Equal(t, 80, sz)
sz = int(unsafe.Sizeof(skipPointer{}))
assert.Equal(t, 20, sz)
assert.Equal(t, 104, sz)
sz = int(unsafe.Sizeof(tower{}))
assert.Equal(t, 40, sz)
}
func testSkipList(t *testing.T, compare ValueCmp, vals ...[]byte) {
src.Shuffle(len(vals), func(i, j int) {
func BenchmarkList(b *testing.B) {
b.Run("benchmark Get", func(b *testing.B) {
b.Run("n=64", func(b *testing.B) {
vals := randomInts(64)
l := NewSkipList(bytes.Compare)
for i := range vals {
l.Put(vals[i], vals[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, ok := l.Get(vals[i%64])
if !ok {
b.Fail()
}
}
b.ReportAllocs()
})
b.Run("n=512", func(b *testing.B) {
vals := randomInts(512)
l := NewSkipList(bytes.Compare)
for i := range vals {
l.Put(vals[i], vals[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, ok := l.Get(vals[i%512])
if !ok {
b.Fail()
}
}
b.ReportAllocs()
})
b.Run("n=4096", func(b *testing.B) {
vals := randomInts(4096)
l := NewSkipList(bytes.Compare)
for i := range vals {
l.Put(vals[i], vals[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, ok := l.Get(vals[i%4096])
if !ok {
b.Fail()
}
}
b.ReportAllocs()
})
b.Run("n=32768", func(b *testing.B) {
vals := randomInts(32768)
l := NewSkipList(bytes.Compare)
for i := range vals {
l.Put(vals[i], vals[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, ok := l.Get(vals[i%32768])
if !ok {
b.Fail()
}
}
b.ReportAllocs()
})
})
b.Run("benchmark Put", func(b *testing.B) {
b.Run("n=64", func(b *testing.B) {
vals := randomInts(64)
l := NewSkipList(bytes.Compare)
for i := 0; i < b.N; i++ {
j := i % 64
if j == 0 {
l.Truncate()
}
l.Put(vals[j], vals[j])
}
b.ReportAllocs()
})
b.Run("n=512", func(b *testing.B) {
vals := randomInts(512)
l := NewSkipList(bytes.Compare)
for i := 0; i < b.N; i++ {
j := i % 512
if j == 0 {
l.Truncate()
}
l.Put(vals[j], vals[j])
}
b.ReportAllocs()
})
b.Run("n=4096", func(b *testing.B) {
vals := randomInts(4096)
l := NewSkipList(bytes.Compare)
for i := 0; i < b.N; i++ {
j := i % 4096
if j == 0 {
l.Truncate()
}
l.Put(vals[j], vals[j])
}
b.ReportAllocs()
})
b.Run("n=32768", func(b *testing.B) {
vals := randomInts(32768)
l := NewSkipList(bytes.Compare)
for i := 0; i < b.N; i++ {
j := i % 32768
if j == 0 {
l.Truncate()
}
l.Put(vals[j], vals[j])
}
b.ReportAllocs()
})
})
}
func testSkipList(t *testing.T, compare KeyOrder, vals ...[]byte) {
randSrc.Shuffle(len(vals), func(i, j int) {
vals[i], vals[j] = vals[j], vals[i]
})
@@ -125,8 +243,8 @@ func testSkipListPuts(t *testing.T, list *List, vals ...[]byte) {
}
func testSkipListGets(t *testing.T, list *List, vals ...[]byte) {
// get in different order
src.Shuffle(len(vals), func(i, j int) {
// get in different keyOrder
randSrc.Shuffle(len(vals), func(i, j int) {
vals[i], vals[j] = vals[j], vals[i]
})
@@ -149,7 +267,7 @@ func testSkipListUpdates(t *testing.T, list *List, vals ...[]byte) {
}
assert.Equal(t, len(vals), list.Count())
src.Shuffle(len(vals), func(i, j int) {
randSrc.Shuffle(len(vals), func(i, j int) {
vals[i], vals[j] = vals[j], vals[i]
})
for _, exp := range vals {
@@ -163,7 +281,7 @@ func testSkipListUpdates(t *testing.T, list *List, vals ...[]byte) {
}
func testSkipListIterForward(t *testing.T, list *List, vals ...[]byte) {
// put |vals| back in order
// put |vals| back in keyOrder
sort.Slice(vals, func(i, j int) bool {
return list.compareKeys(vals[i], vals[j]) < 0
})
@@ -178,7 +296,7 @@ func testSkipListIterForward(t *testing.T, list *List, vals ...[]byte) {
// test iter at
for k := 0; k < 10; k++ {
idx = src.Int() % len(vals)
idx = randSrc.Int() % len(vals)
key := vals[idx]
act := validateIterForwardFrom(t, list, key)
exp := len(vals) - idx
@@ -192,14 +310,14 @@ func testSkipListIterForward(t *testing.T, list *List, vals ...[]byte) {
}
func testSkipListIterBackward(t *testing.T, list *List, vals ...[]byte) {
// put |vals| back in order
// put |vals| back in keyOrder
sort.Slice(vals, func(i, j int) bool {
return list.compareKeys(vals[i], vals[j]) < 0
})
// test iter at
for k := 0; k < 10; k++ {
idx := src.Int() % len(vals)
idx := randSrc.Int() % len(vals)
key := vals[idx]
act := validateIterBackwardFrom(t, list, key)
assert.Equal(t, idx+1, act)
@@ -276,8 +394,8 @@ func validateIterBackwardFrom(t *testing.T, l *List, key []byte) (count int) {
func randomVals(cnt int64) (vals [][]byte) {
vals = make([][]byte, cnt)
for i := range vals {
bb := make([]byte, (src.Int63()%91)+10)
src.Read(bb)
bb := make([]byte, (randSrc.Int63()%91)+10)
randSrc.Read(bb)
vals[i] = bb
}
return
@@ -287,7 +405,7 @@ func randomInts(cnt int64) (vals [][]byte) {
vals = make([][]byte, cnt)
for i := range vals {
vals[i] = make([]byte, 8)
v := uint64(src.Int63())
v := uint64(randSrc.Int63())
binary.LittleEndian.PutUint64(vals[i], v)
}
return
@@ -317,8 +435,8 @@ func iterAllBackwards(l *List, cb func([]byte, []byte)) {
}
}
func testSkipListCheckpoints(t *testing.T, compare ValueCmp, data ...[]byte) {
src.Shuffle(len(data), func(i, j int) {
func testSkipListCheckpoints(t *testing.T, compare KeyOrder, data ...[]byte) {
randSrc.Shuffle(len(data), func(i, j int) {
data[i], data[j] = data[j], data[i]
})