Skip to content

Commit

Permalink
improve thread parker performance
Browse files Browse the repository at this point in the history
  • Loading branch information
alphadose committed Aug 5, 2022
1 parent 23116b0 commit 5718598
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 29 deletions.
51 changes: 27 additions & 24 deletions thread_parker.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@ import (
// useful for saving up resources by parking excess goroutines and pre-empt them when required with minimal latency overhead
// Uses the same lock-free linked list implementation as in `list.go`
type ThreadParker[T any] struct {
head unsafe.Pointer
tail unsafe.Pointer
head atomic.Pointer[parkSpot[T]]
tail atomic.Pointer[parkSpot[T]]
}

// NewThreadParker returns a new thread parker.
func NewThreadParker[T any](n unsafe.Pointer) *ThreadParker[T] {
return &ThreadParker[T]{head: n, tail: n}
func NewThreadParker[T any](spot *parkSpot[T]) *ThreadParker[T] {
var ptr atomic.Pointer[parkSpot[T]]
ptr.Store(spot)
return &ThreadParker[T]{head: ptr, tail: ptr}
}

// a single parked goroutine
type parkSpot[T any] struct {
next unsafe.Pointer
next atomic.Pointer[parkSpot[T]]
threadPtr unsafe.Pointer
value T
}
Expand All @@ -29,43 +31,44 @@ type parkSpot[T any] struct {
// This keeps only one parked goroutine in state at all times
// the parked goroutine is called with minimal overhead via goready() due to both being in userland
// This ensures there is no thundering herd https://en.wikipedia.org/wiki/Thundering_herd_problem
func (tp *ThreadParker[T]) Park(nextNode unsafe.Pointer) {
var tail, next unsafe.Pointer
func (tp *ThreadParker[T]) Park(nextNode *parkSpot[T]) {
var tail, next *parkSpot[T]
for {
tail = atomic.LoadPointer(&tp.tail)
next = atomic.LoadPointer(&((*parkSpot[T])(tail)).next)
if tail == atomic.LoadPointer(&tp.tail) {
tail = tp.tail.Load()
next = tail.next.Load()
if tail == tp.tail.Load() {
if next == nil {
if atomic.CompareAndSwapPointer(&((*parkSpot[T])(tail)).next, next, nextNode) {
atomic.CompareAndSwapPointer(&tp.tail, tail, nextNode)
if tail.next.CompareAndSwap(next, nextNode) {
tp.tail.CompareAndSwap(tail, nextNode)
return
}
} else {
atomic.CompareAndSwapPointer(&tp.tail, tail, next)
tp.tail.CompareAndSwap(tail, next)
}
}
}
}

// Ready calls one parked goroutine from the queue if available
func (tp *ThreadParker[T]) Ready() (data T, ok bool, freeable *parkSpot[T]) {
var head, tail, next unsafe.Pointer
var head, tail, next *parkSpot[T]
for {
head = atomic.LoadPointer(&tp.head)
tail = atomic.LoadPointer(&tp.tail)
next = atomic.LoadPointer(&((*parkSpot[T])(head)).next)
if head == atomic.LoadPointer(&tp.head) {
head = tp.head.Load()
tail = tp.tail.Load()
next = head.next.Load()
if head == tp.head.Load() {
if head == tail {
if next == nil {
return
}
atomic.CompareAndSwapPointer(&tp.tail, tail, next)
tp.tail.CompareAndSwap(tail, next)
} else {
safe_ready((*parkSpot[T])(next).threadPtr)
data, ok = (*parkSpot[T])(next).value, true
if atomic.CompareAndSwapPointer(&tp.head, head, next) {
freeable = (*parkSpot[T])(head)
freeable.next, freeable.threadPtr = nil, nil
safe_ready(next.threadPtr)
data, ok = next.value, true
if tp.head.CompareAndSwap(head, next) {
freeable = head
freeable.threadPtr = nil
freeable.next.Store(nil)
return
}
}
Expand Down
11 changes: 6 additions & 5 deletions zenq.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ func New[T any](size uint32) *ZenQ[T] {
parkPool = sync.Pool{New: func() any { return new(parkSpot[T]) }}
)
for idx := uint32(0); idx < queueSize; idx++ {
n := parkPool.Get().(*parkSpot[T])
n.threadPtr, n.next = nil, nil
contents[idx].writeParker = NewThreadParker[T](unsafe.Pointer(n))
spot := parkPool.Get().(*parkSpot[T])
spot.threadPtr = nil
contents[idx].writeParker = NewThreadParker(spot)
}
zenq := &ZenQ[T]{
metaQ: metaQ{
Expand Down Expand Up @@ -170,8 +170,9 @@ direct_send:
wait()
case SlotCommitted:
n := self.alloc().(*parkSpot[T])
n.threadPtr, n.next, n.value = GetG(), nil, value
slot.writeParker.Park(unsafe.Pointer(n))
n.threadPtr, n.value = GetG(), value
n.next.Store(nil)
slot.writeParker.Park(n)
mcall(fast_park)
return
case SlotEmpty:
Expand Down

0 comments on commit 5718598

Please sign in to comment.