-
Notifications
You must be signed in to change notification settings - Fork 18
/
select_list.go
93 lines (86 loc) · 2.44 KB
/
select_list.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package zenq
import (
"sync"
"sync/atomic"
"unsafe"
)
// nodePool is a package-wide sync.Pool that leases *node values; its New
// function hands out a zeroed node whenever the pool has nothing to reuse.
// nodeGet and nodePut are bound method values of nodePool.Get and
// nodePool.Put, kept as short aliases.
var (
	nodePool = sync.Pool{
		New: func() any {
			return &node{}
		},
	}
	nodeGet = nodePool.Get
	nodePut = nodePool.Put
)
// List is a lock-free FIFO linked list following the Michael & Scott
// non-blocking queue algorithm.
//
// Paper:      https://www.cs.rochester.edu/u/scott/papers/1996_PODC_queues.pdf
// Pseudocode: https://www.cs.rochester.edu/research/synchronization/pseudocode/queues.html
//
// head always references a dummy node; the oldest entry, if any, is the node
// after it. tail references the last node, though it may briefly lag behind
// it until an Enqueue or Dequeue swings it forward.
//
// Build a List with NewList: the zero value has nil head and tail and cannot
// be enqueued to or dequeued from. Because the fields are atomics, a List
// should not be copied once goroutines are using it.
type List struct {
	head, tail atomic.Pointer[node]
}
// NewList returns an empty List. Both head and tail start out pointing at a
// single dummy node; real entries are always linked in after that node.
func NewList() List {
	dummy := nodeGet().(*node)
	// nodeGet may hand back a previously used node, so wipe every field.
	dummy.next.Store(nil)
	dummy.threadPtr = nil
	dummy.dataOut = nil

	var l List
	l.head.Store(dummy)
	l.tail.Store(dummy)
	return l
}
// node is one link of a List. next is accessed atomically because Enqueue
// and Dequeue load it and Enqueue links new nodes with CompareAndSwap; the
// remaining two fields are the payload that Dequeue hands back.
type node struct {
	next atomic.Pointer[node] // successor; nil while this is the last node

	// payload stored by Enqueue and returned by Dequeue
	threadPtr *unsafe.Pointer
	dataOut   *any
}
// Enqueue appends an entry carrying threadPtr and dataOut to the tail of the
// list without taking any lock.
//
// The node is always freshly allocated rather than leased from nodePool. A
// node that has already left the list may still be referenced by a concurrent
// Dequeue that loaded it earlier; linking such a node into the list again
// would let that stale goroutine's CompareAndSwap on head succeed (the ABA
// problem the Michael-Scott algorithm otherwise prevents with counted
// pointers). Fresh nodes plus garbage collection remove that hazard, and a
// fresh node's next is guaranteed to be nil.
func (l *List) Enqueue(threadPtr *unsafe.Pointer, dataOut *any) {
	n := &node{threadPtr: threadPtr, dataOut: dataOut}
	for {
		tail := l.tail.Load()
		next := tail.next.Load()
		if tail != l.tail.Load() {
			// tail moved while we were reading; take a fresh snapshot.
			continue
		}
		if next != nil {
			// tail is lagging behind the real last node: help swing it
			// forward, then retry.
			l.tail.CompareAndSwap(tail, next)
			continue
		}
		if tail.next.CompareAndSwap(nil, n) {
			// Linked in. Try to swing tail to the new node; if this fails,
			// another goroutine has already advanced it.
			l.tail.CompareAndSwap(tail, n)
			return
		}
	}
}
// Dequeue removes the oldest entry from the list and returns the threadPtr
// and dataOut it was enqueued with. It returns (nil, nil) if the list is
// empty.
//
// The list keeps a dummy node at its head. A successful Dequeue copies the
// values out of the node after the dummy and promotes that node to be the new
// dummy.
//
// The detached former dummy is deliberately neither reset nor returned to
// nodePool. Other Dequeue calls may still hold pointers to it, or to the
// promoted node, from an earlier snapshot and read their fields. Mutating or
// recycling the node would race with those non-atomic reads. Recycling would
// also let a stale CompareAndSwap on head succeed (ABA), which corrupts the
// list. The garbage collector reclaims the node once nothing references it.
// For the same reason, the promoted node keeps its payload pointers until it
// is detached in turn.
func (l *List) Dequeue() (threadPtr *unsafe.Pointer, dataOut *any) {
	for {
		head := l.head.Load()
		tail := l.tail.Load()
		next := head.next.Load()
		if head != l.head.Load() {
			// head moved while we were reading; snapshot is inconsistent.
			continue
		}
		if head == tail {
			if next == nil {
				return nil, nil // list is empty
			}
			// tail is falling behind: help advance it, then retry.
			l.tail.CompareAndSwap(tail, next)
			continue
		}
		// Copy the payload before the CAS, as in the reference algorithm.
		// This method never writes to a node after it is linked in, so the
		// read does not race even if another Dequeue wins.
		threadPtr, dataOut = next.threadPtr, next.dataOut
		if l.head.CompareAndSwap(head, next) {
			return threadPtr, dataOut
		}
	}
}