Skip to content

Commit

Permalink
Simplify pools (#82)
Browse files Browse the repository at this point in the history
- Fix a potential race condition between checking `any_sleeping` and
waking threads
- Removed `ThreadPool::phase`, because we only ever have two phases
(startup and running)
- Add a unit test that checks both `ThreadPool` and `QueuePool` together
  • Loading branch information
mkeeter authored Apr 13, 2024
1 parent da3b7ee commit c8ebee1
Showing 1 changed file with 60 additions and 71 deletions.
131 changes: 60 additions & 71 deletions fidget/src/mesh/mt/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ impl ThreadPool {
threads,
counter: &self.counter,
index,
phase: 1,
}
}
}
Expand All @@ -66,47 +65,16 @@ pub struct ThreadContext<'a> {
threads: std::sync::RwLockReadGuard<'a, Vec<std::thread::Thread>>,
counter: &'a AtomicUsize,
index: usize,

/// We operate in 4 phases, depending on the value of `phase % 4`:
///
/// `phase` | byte | direction | start | end
/// --------|------|-----------|-------|-----
/// 0 | 0 | up | 0 | `N`
/// 1 | 1 | up | 0 | `N`
/// 2 | 0 | down | `N` | 0
/// 3 | 1 | down | `N` | 0
///
/// (`N` is `self.threads.len()`)
///
/// Note that the pairs of adjacent phases are non-interfering: if a thread
/// in phase 0 notices that it has hit `N`, then it can immediately enter
/// phase 1 and start modifing byte 1 of the counter, without other threads
/// in phase 0 noticing.
phase: u8,
}

impl ThreadContext<'_> {
/// Returns `true` if any threads in the pool are sleeping
fn any_sleeping(&self) -> bool {
let mut v = self.counter.load(Ordering::Acquire);
if (self.phase % 2) == 1 {
v >>= 8;
}
// Check to see if any other threads are sleeping; otherwise, skip the
// unparking step (because it costs time)
(((self.phase % 2) / 2 == 0) && v != 0)
|| (((self.phase % 2) / 2 == 1) && v != self.threads.len())
}

/// If some threads in the pool are sleeping, wakes them up
///
/// This function should be called when work is available.
pub fn wake(&self) {
if self.any_sleeping() {
for (i, t) in self.threads.iter().enumerate() {
if i != self.index {
t.unpark();
}
for (i, t) in self.threads.iter().enumerate() {
if i != self.index {
t.unpark();
}
}
}
Expand All @@ -118,9 +86,7 @@ impl ThreadContext<'_> {
/// (i.e. with `i == self.index`)
pub fn wake_one(&self, i: usize) {
assert_ne!(i, self.index);
if self.any_sleeping() {
self.threads[i].unpark();
}
self.threads[i].unpark();
}

/// Used to record that a piece of data has been scheduled for processing
Expand Down Expand Up @@ -148,14 +114,8 @@ impl ThreadContext<'_> {
}

fn done(&self, v: usize) -> bool {
(v >> 16 == 0) // No MPSC work queued up
& match self.phase % 4 {
0 => v & 0xFF == self.threads.len(),
1 => v >> 8 == self.threads.len(),
2 => v & 0xFF == 0,
3 => v >> 8 == 0,
_ => unreachable!(),
}
v >> 16 == 0 // No MPSC work queued up
&& v >> 8 == self.threads.len() // all threads sleeping
}

/// Sends the given thread to sleep
Expand All @@ -164,13 +124,7 @@ impl ThreadContext<'_> {
/// threads in the pool have requested to sleep, indicating that all work is
/// done and they should now halt.
pub fn sleep(&mut self) -> bool {
let v = match self.phase % 4 {
0 => self.counter.fetch_add(1, Ordering::Release) + 1,
1 => self.counter.fetch_add(256, Ordering::Release) + 256,
2 => self.counter.fetch_sub(1, Ordering::Release) - 1,
3 => self.counter.fetch_sub(256, Ordering::Release) - 256,
_ => unreachable!(),
};
let v = self.counter.fetch_add(256, Ordering::Release) + 256;

// At this point, the thread doesn't have any work to do, so we'll
// consider putting it to sleep. However, if every other thread is
Expand All @@ -197,18 +151,11 @@ impl ThreadContext<'_> {
}

if done {
self.phase += 1;
false // stop looping
} else {
// Back to the grind
match self.phase % 4 {
0 => self.counter.fetch_sub(1, Ordering::Release),
1 => self.counter.fetch_sub(256, Ordering::Release),
2 => self.counter.fetch_add(1, Ordering::Release),
3 => self.counter.fetch_add(256, Ordering::Release),
_ => unreachable!(),
};
self.counter.fetch_sub(256, Ordering::Release);
true // keep going
}
!done
}
}

Expand Down Expand Up @@ -355,19 +302,61 @@ mod test {
s.spawn(move || {
let mut ctx = pool.start(i);
let t = std::time::Duration::from_millis(1);
for _ in 0..8 {
for _ in 0..i {
std::thread::sleep(t);
ctx.wake();
}
while ctx.sleep() {
// Loop forever
}
for _ in 0..i {
std::thread::sleep(t);
ctx.wake();
}
while ctx.sleep() {
// Loop forever
}
done.fetch_add(1, Ordering::Release);
});
}
});
assert_eq!(done.load(Ordering::Acquire), N);
}

#[test]
fn queue_and_thread_pool() {
const N: usize = 8;
let mut queues = QueuePool::new(N);
let pool = &ThreadPool::new(N);
let mut counters = [0i32; N];
const DEPTH: usize = 16;
queues[0].push(DEPTH);

// Confirm that stealing leads to shared work between two threads
std::thread::scope(|s| {
for (i, (q, c)) in
queues.iter_mut().zip(counters.iter_mut()).enumerate()
{
s.spawn(move || {
let mut ctx = pool.start(i);
loop {
if let Some(v) = q.pop() {
*c += 1;
if v != 0 {
q.push(v - 1);
q.push(v - 1);
}
if q.changed() {
ctx.wake();
}
continue;
}
if !ctx.sleep() {
break;
}
}
});
}
});

const EXPECTED_COUNT: usize = (1 << (DEPTH + 1)) - 1;
assert_eq!(
counters.iter().sum::<i32>(),
EXPECTED_COUNT as i32,
"threads did not complete all work"
);
}
}

0 comments on commit c8ebee1

Please sign in to comment.