diff --git a/crates/harness-tests/src/routing.rs b/crates/harness-tests/src/routing.rs index 23d56ae9c..76dd2b887 100644 --- a/crates/harness-tests/src/routing.rs +++ b/crates/harness-tests/src/routing.rs @@ -132,12 +132,23 @@ async fn route_ws_to_correct_monolith_race(ctx: &mut TestRunner) { m.load_room(room_name.clone()).await; let mut client = Client::new(ctx).unwrap(); - client.join(room_name).await; + client.join(room_name.clone()).await; println!("waiting for monolith to receive join message"); - tokio::time::timeout(Duration::from_secs(1), m.wait_recv()) - .await - .expect("msg recv timeout"); + // this more accurately emulates what the client would actually do + loop { + tokio::select! { + result = tokio::time::timeout(Duration::from_secs(1), m.wait_recv()) => { + result.expect("msg recv timeout"); + break; + }, + _ = client.wait_for_disconnect() => { + println!("client disconnected, retrying"); + client.join(room_name.clone()).await; + continue; + } + }; + } let recvd = m.collect_recv(); assert_eq!(recvd.len(), 1); diff --git a/crates/harness/src/client.rs b/crates/harness/src/client.rs index 90b751ef9..866e135d7 100644 --- a/crates/harness/src/client.rs +++ b/crates/harness/src/client.rs @@ -80,6 +80,15 @@ impl Client { let _ = stream.close(None).await; } + pub async fn wait_for_disconnect(&mut self) { + if !self.connected() { + return; + } + + let mut stream = self.stream.take().unwrap(); + while stream.next().await.is_some() {} + } + pub async fn recv(&mut self) -> anyhow::Result<Message> { if let Some(stream) = self.stream.as_mut() { match tokio::time::timeout(Duration::from_millis(200), stream.next()).await {