diff --git a/crates/harness/src/monolith.rs b/crates/harness/src/monolith.rs index adf8225d0..e85babbb1 100644 --- a/crates/harness/src/monolith.rs +++ b/crates/harness/src/monolith.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, net::{IpAddr, Ipv6Addr, SocketAddr}, pin::Pin, sync::{ @@ -14,7 +14,7 @@ use http_body_util::{BodyExt, Full}; use hyper::{body::Incoming as IncomingBody, Request}; use hyper::{service::Service, Response}; -use ott_balancer_protocol::{monolith::*, RoomName}; +use ott_balancer_protocol::{monolith::*, ClientId, RoomName}; use tokio::{net::TcpListener, sync::Notify}; use tracing::warn; use tungstenite::Message; @@ -52,6 +52,7 @@ pub(crate) struct MonolithState { response_mocks: HashMap, rooms: HashMap, room_load_epoch: Arc, + clients: HashSet, } impl Monolith { @@ -103,7 +104,21 @@ impl Monolith { match msg { Ok(msg) => { println!("monolith: incoming msg: {}", msg); - state.lock().unwrap().received_raw.push(msg); + let mut state = state.lock().unwrap(); + state.received_raw.push(msg.clone()); + // TODO: there's a better way to generalize this + if let Message::Text(m) = msg { + let msg: MsgB2M = serde_json::from_str(&m).unwrap(); + match msg { + MsgB2M::Join(join) => { + state.clients.insert(join.client); + }, + MsgB2M::Leave(leave) => { + state.clients.remove(&leave.client); + }, + _ => {}, + } + } _notif_recv.notify_one(); }, Err(e) => { @@ -178,6 +193,10 @@ impl Monolith { } } + pub fn clients(&self) -> HashSet { + self.state.lock().unwrap().clients.clone() + } + /// Tell the provider to add this monolith to the list of available monoliths. pub async fn show(&mut self) { println!("showing monolith"); @@ -371,3 +390,27 @@ pub struct MockRequest { pub headers: hyper::HeaderMap, pub body: Bytes, } + +#[cfg(test)] +mod tests { + use test_context::test_context; + + use crate::{Client, Monolith, TestRunner}; + + #[test_context(TestRunner)] + #[tokio::test] + async fn should_track_clients(ctx: &mut TestRunner) { + let mut m = Monolith::new(ctx).await.unwrap(); + m.show().await; + + let mut c1 = Client::new(ctx).unwrap(); + c1.join("foo").await; + + m.wait_recv().await; + assert_eq!(m.clients().len(), 1); + + c1.disconnect().await; + m.wait_recv().await; + assert_eq!(m.clients().len(), 0); + } +}