go/store/skip: optimized skip list Iter

This commit is contained in:
Andy Arthur
2023-03-02 12:05:31 -08:00
parent 9936aabdf1
commit ad134f2c6b
+24 -33
View File
@@ -113,10 +113,9 @@ func (l *List) Revert() {
func (l *List) Truncate() {
l.nodes = l.nodes[:1]
// point sentinel.prev at itself
s := l.getNode(sentinelId)
s := l.nodePtr(sentinelId)
s.next = tower{}
s.prev = sentinelId
l.updateNode(s)
l.checkpoint = nodeId(1)
l.count = 0
}
@@ -139,7 +138,7 @@ func (l *List) Get(key []byte) (val []byte, ok bool) {
var id nodeId
next, prev := l.headPointer(), sentinelId
for lvl := maxHeight; lvl >= 0; {
nd := l.getNodeRef(next[lvl])
nd := l.nodePtr(next[lvl])
// descend if we can't advance at |lvl|
if l.compareKeys(key, nd.key) < 0 {
id = prev
@@ -150,7 +149,7 @@ func (l *List) Get(key []byte) (val []byte, ok bool) {
next = nd.next
prev = nd.id
}
node := l.getNodeRef(id)
node := l.nodePtr(id)
if l.compareKeys(key, node.key) == 0 {
val, ok = node.val, true
}
@@ -170,7 +169,7 @@ func (l *List) Put(key, val []byte) {
var path tower
next, prev := l.headPointer(), sentinelId
for h := maxHeight; h >= 0; {
curr := l.getNodeRef(next[h])
curr := l.nodePtr(next[h])
// descend if we can't advance at |lvl|
if l.compareKeys(key, curr.key) <= 0 {
path[h] = prev
@@ -183,8 +182,8 @@ func (l *List) Put(key, val []byte) {
}
// check if |key| exists in |l|
node := l.getNodeRef(path[0])
node = l.getNodeRef(node.next[0])
node := l.nodePtr(path[0])
node = l.nodePtr(node.next[0])
if l.compareKeys(key, node.key) == 0 {
l.overwrite(key, val, path, node)
@@ -214,15 +213,15 @@ func (l *List) insert(key, value []byte, path tower) {
id: id,
height: l.rollHeight(key),
})
novel := l.getNodeRef(id)
novel := l.nodePtr(id)
for h := uint8(0); h <= novel.height; h++ {
// set forward pointers
n := l.getNodeRef(path[h])
n := l.nodePtr(path[h])
novel.next[h] = n.next[h]
n.next[h] = novel.id
}
// set back pointers
n := l.getNodeRef(novel.next[0])
n := l.nodePtr(novel.next[0])
novel.prev = n.prev
n.prev = novel.id
}
@@ -239,16 +238,16 @@ func (l *List) overwrite(key, value []byte, path tower, old *skipNode) {
})
for h := uint8(0); h <= old.height; h++ {
// set forward pointers
n := l.getNodeRef(path[h])
n := l.nodePtr(path[h])
n.next[h] = id
}
// set back pointer
n := l.getNodeRef(old.next[0])
n := l.nodePtr(old.next[0])
n.prev = id
}
type ListIter struct {
curr skipNode
curr *skipNode
list *List
}
@@ -259,13 +258,13 @@ func (it *ListIter) Current() (key, val []byte) {
// Advance advances the iterator.
func (it *ListIter) Advance() {
it.curr = it.list.getNode(it.curr.next[0])
it.curr = it.list.nodePtr(it.curr.next[0])
return
}
// Retreat retreats the iterator.
func (it *ListIter) Retreat() {
it.curr = it.list.getNode(it.curr.prev)
it.curr = it.list.nodePtr(it.curr.prev)
return
}
@@ -308,19 +307,19 @@ func (l *List) IterAtEnd() *ListIter {
}
// seek returns the skipNode with the smallest key >= |key|.
func (l *List) seek(key []byte) skipNode {
func (l *List) seek(key []byte) *skipNode {
return l.seekWithFn(func(curr []byte) (advance bool) {
return l.compareKeys(key, curr) > 0
})
}
func (l *List) seekWithFn(cb SeekFn) (node skipNode) {
func (l *List) seekWithFn(cb SeekFn) (node *skipNode) {
ptr := l.headPointer()
for h := int64(maxHeight); h >= 0; h-- {
node = l.getNode(ptr[h])
node = l.nodePtr(ptr[h])
for cb(node.key) {
ptr = node.next
node = l.getNode(ptr[h])
node = l.nodePtr(ptr[h])
}
}
return
@@ -330,27 +329,19 @@ func (l *List) headPointer() tower {
return l.nodes[0].next
}
func (l *List) firstNode() skipNode {
return l.getNode(l.nodes[0].next[0])
func (l *List) firstNode() *skipNode {
return l.nodePtr(l.nodes[0].next[0])
}
func (l *List) lastNode() skipNode {
s := l.getNode(sentinelId)
return l.getNode(s.prev)
func (l *List) lastNode() *skipNode {
s := l.nodePtr(sentinelId)
return l.nodePtr(s.prev)
}
func (l *List) getNode(id nodeId) skipNode {
return l.nodes[id]
}
func (l *List) getNodeRef(id nodeId) *skipNode {
func (l *List) nodePtr(id nodeId) *skipNode {
return &l.nodes[id]
}
func (l *List) updateNode(node skipNode) {
l.nodes[node.id] = node
}
func (l *List) nextNodeId() nodeId {
return nodeId(len(l.nodes))
}