diff --git a/crates/harness-tests/src/main.rs b/crates/harness-tests/src/main.rs index 2874dcad3..1200ad882 100644 --- a/crates/harness-tests/src/main.rs +++ b/crates/harness-tests/src/main.rs @@ -6,6 +6,7 @@ use test_context::test_context; mod connection; mod routing; +mod state; #[test_context(TestRunner)] #[tokio::test] diff --git a/crates/harness-tests/src/state.rs b/crates/harness-tests/src/state.rs new file mode 100644 index 000000000..7577865e9 --- /dev/null +++ b/crates/harness-tests/src/state.rs @@ -0,0 +1,50 @@ +//! Tests for how the balancer handles room state in the context of managing rooms on monoliths. + +use std::time::Duration; + +use harness::{Client, Monolith, TestRunner}; +use ott_balancer_protocol::monolith::{B2MUnload, MsgB2M}; +use test_context::test_context; + +#[test_context(TestRunner)] +#[tokio::test] +async fn should_unload_duplicate_rooms(ctx: &TestRunner) { + let mut m1 = Monolith::new(ctx).await.unwrap(); + let mut m2 = Monolith::new(ctx).await.unwrap(); + + m1.show().await; + m2.show().await; + + m1.load_room("foo").await; + m2.load_room("foo").await; + + m2.wait_recv().await; + + let recv = m2.collect_recv(); + assert_eq!(recv.len(), 1); + assert!(matches!(recv[0], MsgB2M::Unload(B2MUnload { .. }))); +} + +#[test_context(TestRunner)] +#[tokio::test] +async fn should_unload_duplicate_rooms_and_route_correctly(ctx: &TestRunner) { + let mut m1 = Monolith::new(ctx).await.unwrap(); + let mut m2 = Monolith::new(ctx).await.unwrap(); + + m1.show().await; + m2.show().await; + + m1.load_room("foo").await; + m2.load_room("foo").await; + + tokio::time::timeout(Duration::from_millis(100), m2.wait_recv()) + .await + .expect("timed out waiting for unload"); + + let mut c = Client::new(ctx).unwrap(); + c.join("foo").await; + + tokio::time::timeout(Duration::from_millis(100), m1.wait_recv()) + .await + .expect("timed out waiting for client join"); +} diff --git a/crates/ott-balancer-bin/src/balancer.rs b/crates/ott-balancer-bin/src/balancer.rs index 78db3ea20..e27dd4649 100644 --- a/crates/ott-balancer-bin/src/balancer.rs +++ b/crates/ott-balancer-bin/src/balancer.rs @@ -1,6 +1,8 @@ use std::{collections::HashMap, sync::Arc}; -use ott_balancer_protocol::monolith::{B2MClientMsg, B2MJoin, B2MLeave, MsgM2B, RoomMetadata}; +use ott_balancer_protocol::monolith::{ + B2MClientMsg, B2MJoin, B2MLeave, B2MUnload, MsgM2B, RoomMetadata, +}; use ott_balancer_protocol::*; use rand::seq::IteratorRandom; use serde_json::value::RawValue; @@ -343,12 +345,15 @@ impl BalancerContext { match locator.load_epoch().cmp(&load_epoch) { std::cmp::Ordering::Less => { // we already have an older version of this room + self.unload_room(monolith_id, metadata.name.clone()).await?; return Err(anyhow::anyhow!("room already loaded")); } std::cmp::Ordering::Greater => { // we have an newer version of this room, remove it - self.remove_room(&metadata.name, locator.monolith_id()) + self.unload_room(locator.monolith_id(), metadata.name.clone()) .await?; + // self.remove_room(&metadata.name, locator.monolith_id()) + // .await?; } _ => {} } @@ -414,6 +419,12 @@ impl BalancerContext { .ok_or(anyhow::anyhow!("no monoliths available"))?; Ok(selected) } + + pub async fn unload_room(&self, monolith: MonolithId, room: RoomName) -> anyhow::Result<()> { + let monolith = self.monoliths.get(&monolith).unwrap(); + monolith.send(B2MUnload { room }).await?; + Ok(()) + } } pub async fn join_client( diff --git a/crates/ott-balancer-protocol/src/monolith.rs b/crates/ott-balancer-protocol/src/monolith.rs index 7c6f3b6ed..03036812a 100644 --- a/crates/ott-balancer-protocol/src/monolith.rs +++ b/crates/ott-balancer-protocol/src/monolith.rs @@ -58,6 +58,12 @@ impl From for MsgB2M { } } +impl From for MsgB2M { + fn from(val: B2MUnload) -> Self { + Self::Unload(val) + } +} + impl From for MsgB2M { fn from(val: B2MJoin) -> Self { Self::Join(val)