diff --git a/fidget/src/mesh/mt/pool.rs b/fidget/src/mesh/mt/pool.rs
index e4f5a8c9..767bf313 100644
--- a/fidget/src/mesh/mt/pool.rs
+++ b/fidget/src/mesh/mt/pool.rs
@@ -56,7 +56,6 @@ impl ThreadPool {
             threads,
             counter: &self.counter,
             index,
-            phase: 1,
         }
     }
 }
@@ -66,47 +65,16 @@ pub struct ThreadContext<'a> {
     threads: std::sync::RwLockReadGuard<'a, Vec<std::thread::Thread>>,
     counter: &'a AtomicUsize,
     index: usize,
-
-    /// We operate in 4 phases, depending on the value of `phase % 4`:
-    ///
-    /// `phase` | byte | direction | start | end
-    /// --------|------|-----------|-------|-----
-    ///    0    |  0   |    up     |   0   | `N`
-    ///    1    |  1   |    up     |   0   | `N`
-    ///    2    |  0   |   down    |  `N`  |  0
-    ///    3    |  1   |   down    |  `N`  |  0
-    ///
-    /// (`N` is `self.threads.len()`)
-    ///
-    /// Note that the pairs of adjacent phases are non-interfering: if a thread
-    /// in phase 0 notices that it has hit `N`, then it can immediately enter
-    /// phase 1 and start modifying byte 1 of the counter, without other threads
-    /// in phase 0 noticing.
-    phase: u8,
 }
 
 impl ThreadContext<'_> {
-    /// Returns `true` if any threads in the pool are sleeping
-    fn any_sleeping(&self) -> bool {
-        let mut v = self.counter.load(Ordering::Acquire);
-        if (self.phase % 2) == 1 {
-            v >>= 8;
-        }
-        // Check to see if any other threads are sleeping; otherwise, skip the
-        // unparking step (because it costs time)
-        (((self.phase % 4) / 2 == 0) && v != 0)
-            || (((self.phase % 4) / 2 == 1) && v != self.threads.len())
-    }
-
     /// If some threads in the pool are sleeping, wakes them up
     ///
     /// This function should be called when work is available.
     pub fn wake(&self) {
-        if self.any_sleeping() {
-            for (i, t) in self.threads.iter().enumerate() {
-                if i != self.index {
-                    t.unpark();
-                }
+        for (i, t) in self.threads.iter().enumerate() {
+            if i != self.index {
+                t.unpark();
             }
         }
     }
@@ -118,9 +86,7 @@ impl ThreadContext<'_> {
     /// (i.e. with `i == self.index`)
     pub fn wake_one(&self, i: usize) {
         assert_ne!(i, self.index);
-        if self.any_sleeping() {
-            self.threads[i].unpark();
-        }
+        self.threads[i].unpark();
     }
 
     /// Used to record that a piece of data has been scheduled for processing
@@ -148,14 +114,8 @@ impl ThreadContext<'_> {
     }
 
     fn done(&self, v: usize) -> bool {
-        (v >> 16 == 0) // No MPSC work queued up
-            & match self.phase % 4 {
-                0 => v & 0xFF == self.threads.len(),
-                1 => v >> 8 == self.threads.len(),
-                2 => v & 0xFF == 0,
-                3 => v >> 8 == 0,
-                _ => unreachable!(),
-            }
+        v >> 16 == 0 // No MPSC work queued up
+            && v >> 8 == self.threads.len() // all threads sleeping
     }
 
     /// Sends the given thread to sleep
@@ -164,13 +124,7 @@ impl ThreadContext<'_> {
     /// threads in the pool have requested to sleep, indicating that all work is
     /// done and they should now halt.
     pub fn sleep(&mut self) -> bool {
-        let v = match self.phase % 4 {
-            0 => self.counter.fetch_add(1, Ordering::Release) + 1,
-            1 => self.counter.fetch_add(256, Ordering::Release) + 256,
-            2 => self.counter.fetch_sub(1, Ordering::Release) - 1,
-            3 => self.counter.fetch_sub(256, Ordering::Release) - 256,
-            _ => unreachable!(),
-        };
+        let v = self.counter.fetch_add(256, Ordering::Release) + 256;
 
         // At this point, the thread doesn't have any work to do, so we'll
         // consider putting it to sleep. However, if every other thread is
@@ -197,18 +151,11 @@ impl ThreadContext<'_> {
         }
 
         if done {
-            self.phase += 1;
+            false // stop looping
         } else {
-            // Back to the grind
-            match self.phase % 4 {
-                0 => self.counter.fetch_sub(1, Ordering::Release),
-                1 => self.counter.fetch_sub(256, Ordering::Release),
-                2 => self.counter.fetch_add(1, Ordering::Release),
-                3 => self.counter.fetch_add(256, Ordering::Release),
-                _ => unreachable!(),
-            };
+            self.counter.fetch_sub(256, Ordering::Release);
+            true // keep going
         }
-        !done
     }
 }
 
@@ -355,14 +302,12 @@ mod test {
             s.spawn(move || {
                 let mut ctx = pool.start(i);
                 let t = std::time::Duration::from_millis(1);
-                for _ in 0..8 {
-                    for _ in 0..i {
-                        std::thread::sleep(t);
-                        ctx.wake();
-                    }
-                    while ctx.sleep() {
-                        // Loop forever
-                    }
+                for _ in 0..i {
+                    std::thread::sleep(t);
+                    ctx.wake();
+                }
+                while ctx.sleep() {
+                    // Loop forever
                 }
                 done.fetch_add(1, Ordering::Release);
             });
@@ -370,4 +315,48 @@ mod test {
         });
         assert_eq!(done.load(Ordering::Acquire), N);
     }
+
+    #[test]
+    fn queue_and_thread_pool() {
+        const N: usize = 8;
+        let mut queues = QueuePool::new(N);
+        let pool = &ThreadPool::new(N);
+        let mut counters = [0i32; N];
+        const DEPTH: usize = 16;
+        queues[0].push(DEPTH);
+
+        // Confirm that stealing distributes the work across threads
+        std::thread::scope(|s| {
+            for (i, (q, c)) in
+                queues.iter_mut().zip(counters.iter_mut()).enumerate()
+            {
+                s.spawn(move || {
+                    let mut ctx = pool.start(i);
+                    loop {
+                        if let Some(v) = q.pop() {
+                            *c += 1;
+                            if v != 0 {
+                                q.push(v - 1);
+                                q.push(v - 1);
+                            }
+                            if q.changed() {
+                                ctx.wake();
+                            }
+                            continue;
+                        }
+                        if !ctx.sleep() {
+                            break;
+                        }
+                    }
+                });
+            }
+        });
+
+        const EXPECTED_COUNT: usize = (1 << (DEPTH + 1)) - 1;
+        assert_eq!(
+            counters.iter().sum::<i32>(),
+            EXPECTED_COUNT as i32,
+            "threads did not complete all work"
+        );
+    }
 }
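For reference, below is a minimal, self-contained sketch of the counter protocol that the patched `sleep()`, `done()`, and `wake()` converge on: byte 1 counts threads that have asked to sleep, bits 16 and up appear to count scheduled-but-unprocessed MPSC work (per the "No MPSC work queued up" comment), and byte 0 is no longer used. Only that layout comes from the patch; the `SLEEP`/`WORK` constants, the `Barrier`-based registration, and the `work_left` stand-in for a real work queue are illustrative assumptions, not fidget code.

// Standalone sketch of the simplified protocol (not part of the patch)
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Barrier, RwLock};

const SLEEP: usize = 1 << 8; // one sleeping thread, counted in byte 1
const WORK: usize = 1 << 16; // one queued work item, counted in bits 16+

fn main() {
    const N: usize = 4;
    const ITEMS: usize = 10;
    let counter = AtomicUsize::new(WORK * ITEMS); // pretend ITEMS are queued
    let work_left = AtomicUsize::new(ITEMS); // stand-in for a real queue
    let threads = RwLock::new(Vec::new());
    let barrier = Barrier::new(N);

    std::thread::scope(|s| {
        for i in 0..N {
            let (counter, work_left, threads, barrier) =
                (&counter, &work_left, &threads, &barrier);
            s.spawn(move || {
                // Register this thread so that peers can unpark it
                threads.write().unwrap().push(std::thread::current());
                barrier.wait(); // wait until all N threads are registered
                let threads = threads.read().unwrap();
                loop {
                    // Try to claim a work item
                    if work_left
                        .fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| {
                            v.checked_sub(1)
                        })
                        .is_ok()
                    {
                        counter.fetch_sub(WORK, Ordering::Release);
                        continue;
                    }
                    // No work: request sleep, mirroring the patched `sleep()`
                    let v = counter.fetch_add(SLEEP, Ordering::Release) + SLEEP;
                    if v >> 16 == 0 && v >> 8 == N {
                        // Nothing queued and everyone is asleep; wake the
                        // others so they can also observe this and exit
                        for (j, t) in threads.iter().enumerate() {
                            if j != i {
                                t.unpark();
                            }
                        }
                        break;
                    }
                    std::thread::park();
                    counter.fetch_sub(SLEEP, Ordering::Release); // back to work
                }
            });
        }
    });
    assert_eq!(work_left.load(Ordering::Acquire), 0);
}

Note the design point this sketch exercises: because `unpark()` sets a token that a later `park()` consumes, a thread woken before it actually parks does not block, so the "last thread to increment the sleep count wakes everyone" handshake cannot lose a wakeup even under unlucky interleavings.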