From a6fd379a227d58149d9851c7ded839d4152dd487 Mon Sep 17 00:00:00 2001 From: mxxntype <59417007+mxxntype@users.noreply.github.com> Date: Tue, 21 May 2024 05:09:28 +0300 Subject: [PATCH] fix(server): Prevent `broadcast::Receiver` leaks with a custom channel --- Cargo.lock | 2 ++ cli/src/main.rs | 31 +++++++++------- server/Cargo.toml | 3 +- server/src/channel/mod.rs | 40 +++++++++++++++++++++ server/src/lib.rs | 1 + server/src/services/chat.rs | 72 +++++++++++++++++++++++++++++++------ 6 files changed, 125 insertions(+), 24 deletions(-) create mode 100644 server/src/channel/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 322ff06..6faba9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1254,6 +1254,7 @@ dependencies = [ "thiserror", "tokio", "tokio-stream", + "tokio-util", "tonic", "tonic-build", "tracing", @@ -1382,6 +1383,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] diff --git a/cli/src/main.rs b/cli/src/main.rs index 7c7b11a..a08aee8 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -139,12 +139,6 @@ async fn existing_room( .trim_matches(|c| c == '(' || c == ')'); let chosen_room = Uuid::from_str(chosen_room).unwrap(); - let mut message_stream = chat - .subscribe_to_room(proto::Uuid::from(chosen_room)) - .await - .unwrap() - .into_inner(); - let room_action = Listbox::new(["Send new messages", "Listen to messages"]) .title("Would you like to send new messages or listen to incoming ones?") .prompt() @@ -173,6 +167,7 @@ async fn existing_room( }, "Listen to messages" => { + // Print older, in-database messages. let messages = chat .list_messages(Into::::into(chosen_room)) .await @@ -181,12 +176,24 @@ async fn existing_room( .messages; for msg in messages.into_iter() { - print_message(msg); + print_message(&msg); } - while let Ok(event) = message_stream.next().await.unwrap() { + // Print live messages. 
+ let mut message_stream = chat + .subscribe_to_room(proto::Uuid::from(chosen_room)) + .await + .unwrap() + .into_inner(); + + 'message_listener: while let Ok(event) = message_stream.next().await.unwrap() { match event.event.unwrap() { - Event::NewMessage(msg) => print_message(msg), + Event::NewMessage(msg) => { + print_message(&msg); + if msg.text.as_str() == "exit" { + break 'message_listener; + } + } } } } @@ -195,11 +202,11 @@ async fn existing_room( } } -fn print_message(msg: ServersideMessage) { +fn print_message(msg: &ServersideMessage) { println!( "{} | {}: {}", - msg.timestamp.unwrap().blue(), - msg.sender_uuid.unwrap().uuid.green(), + msg.timestamp.clone().unwrap().blue(), + msg.sender_uuid.clone().unwrap().uuid.green(), msg.text ); } diff --git a/server/Cargo.toml b/server/Cargo.toml index 9e3cbc4..eecc7e6 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -22,7 +22,8 @@ rand_core = "0.6.4" redis = { version = "0.25.3", features = ["uuid", "tokio-comp", "aio"] } thiserror = "1.0.61" tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread"] } -tokio-stream = "0.1.15" +tokio-stream = { version = "0.1.15", features = ["sync"] } +tokio-util = "0.7.11" tonic = "0.11.0" tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } diff --git a/server/src/channel/mod.rs b/server/src/channel/mod.rs new file mode 100644 index 0000000..1c365c1 --- /dev/null +++ b/server/src/channel/mod.rs @@ -0,0 +1,40 @@ +//! # A wrapper around a [`mpsc`] channel that detects disconnects. +//! +//! Implements the [`Deref`] trait (`Target = mpsc::Receiver`), and uses a [`oneshot`] +//! channel to send a single message back when the whole thing gets dropped. +//! +//! 
**Source & further reading:** + +use futures::Stream; +use std::task::{Context, Poll}; +use std::{ops::Deref, pin::Pin}; +use tokio::sync::{mpsc, oneshot}; + +pub struct DisconnectChannel { + pub(crate) signal_sender: Option>, + pub(crate) inner_channel: mpsc::Receiver, +} + +impl Stream for DisconnectChannel { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner_channel).poll_recv(cx) + } +} + +impl Deref for DisconnectChannel { + type Target = mpsc::Receiver; + + fn deref(&self) -> &Self::Target { + &self.inner_channel + } +} + +impl Drop for DisconnectChannel { + fn drop(&mut self) { + if let Some(drop_signal) = self.signal_sender.take() { + let _ = drop_signal.send(()); + } + } +} diff --git a/server/src/lib.rs b/server/src/lib.rs index 8cbb190..297072b 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -1,6 +1,7 @@ #![deny(clippy::unwrap_used)] pub mod auth; +pub mod channel; pub mod entities; pub mod persistence; pub mod services; diff --git a/server/src/services/chat.rs b/server/src/services/chat.rs index 9d1cca7..473f4f3 100644 --- a/server/src/services/chat.rs +++ b/server/src/services/chat.rs @@ -1,17 +1,19 @@ use crate::auth::AuthenticatedRequest; +use crate::channel::DisconnectChannel; use crate::entities::{Message, Room, User}; use crate::proto::user_lookup_request::Identifier; use crate::proto::{ClientsideMessage, ClientsideRoom, MessageList, RoomList}; use crate::proto::{RoomWithUserCreationRequest, UserLookupRequest}; use crate::proto::{ServersideMessage, ServersideRoom, ServersideRoomEvent, ServersideUserEvent}; -use crate::{persistence, proto}; +use crate::{channel, persistence, proto}; use diesel::r2d2::{ConnectionManager, PooledConnection}; use diesel::PgConnection; use redis::aio::MultiplexedConnection; use redis::{AsyncCommands, Client, RedisResult}; use std::env; -use tokio::sync::{broadcast, mpsc}; +use tokio::sync::{broadcast, mpsc, oneshot}; use 
tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; use tonic::{Request, Response, Status}; use uuid::Uuid; @@ -22,7 +24,7 @@ pub struct Chat { cache_client: redis::Client, // Message passing channel. - message_sender: broadcast::Sender, + message_tx: broadcast::Sender, } #[tonic::async_trait] @@ -268,7 +270,7 @@ impl proto::chat_server::Chat for Chat { Status::internal("Could not send the message due to an internal error") })?; - match self.message_sender.send(message) { + match self.message_tx.send(message) { Ok(recv_count) => tracing::trace!(message = "Broadcasting message", ?recv_count), Err(err) => tracing::error!(message = "Could not broadcast message", ?err), } @@ -344,9 +346,9 @@ impl proto::chat_server::Chat for Chat { Ok(Response::new(private_room_uuid.into())) } - type SubscribeToRoomStream = ReceiverStream>; + type SubscribeToRoomStream = DisconnectChannel>; - #[tracing::instrument(skip_all)] + #[tracing::instrument(skip_all, a = 1)] async fn subscribe_to_room( &self, request: Request, @@ -376,12 +378,45 @@ impl proto::chat_server::Chat for Chat { )); } - let (grpc_tx, grpc_rx) = mpsc::channel(4); + // The 'streamer' thread (see below) needs a cache connection. let mut cache_connection = self.acquire_cache_connection().await?; - let mut message_receiver = self.message_sender.subscribe(); + // NOTE: Read this. + // + // There are a total of 3 channels involved in this whole streaming thing: + // - An internal `broadcast` channel that transfers messages from `SendMessage` RPC calls; + // - A `DisconnectChannel`, which holds another 2 channels inside: + // - A `mpsc` Tokio channel, which performs gRPC streaming; + // - A `oneshot` Tokio channel, which fires when the client disconnects. + + // The 'canceller' thread will cancel this token when the client disconnects. 
+ let cancellation_token = CancellationToken::new(); + let token_clone = cancellation_token.clone(); + + let (grpc_tx, grpc_rx) = mpsc::channel(4); + let (disconnect_tx, disconnect_rx) = oneshot::channel::<()>(); + let disconnect_channel = channel::DisconnectChannel { + signal_sender: Some(disconnect_tx), + inner_channel: grpc_rx, + }; + + let mut message_rx = self.message_tx.subscribe(); + let receiver_count = self.message_tx.receiver_count(); + tracing::debug!(message = "New room subscriber", room = ?subscribed_room, total_subscribers = %receiver_count); + + // This is the 'canceller' thread. + // + // This task will cancel the token when the client disconnects, which will shut down + // the streaming thread (see below) and cause the broadcast::Receiver to drop. tokio::spawn(async move { - while let Ok(msg) = message_receiver.recv().await { + let _ = disconnect_rx.await; + tracing::trace!("Client disconnected, cancelling streaming thread"); + cancellation_token.cancel(); + }); + + // The logic for the streaming thread, extracted into a variable to help rustfmt. + let streaming_closure = async move { + while let Ok(msg) = message_rx.recv().await { let message_room = msg.room_uuid; let subscriber_rooms: Vec = cache_connection .lrange(subscriber, 0, -1) @@ -413,9 +448,24 @@ impl proto::chat_server::Chat for Chat { } } } + }; + + // This is the 'streamer' thread. + // + // This thread will receive all messages sent via the `SendMessage` RPC call, and + // mirror them to all subscribers. Without a canceller thread, a cancellation token + // and a hacky DisconnectChannel, this thread would never terminate, meaning there + // would soon be thousands of hanging broadcast::Receivers with no real client. + tokio::spawn(async move { + tokio::select! 
{ + _ = token_clone.cancelled() => { + tracing::debug!(message = "Streamer thread cancelled"); + } + _ = streaming_closure => {} + } }); - Ok(Response::new(ReceiverStream::new(grpc_rx))) + Ok(Response::new(disconnect_channel)) } type SubscribeToUserStream = ReceiverStream>; @@ -480,7 +530,7 @@ impl Chat { Ok(Self { persistence_pool, cache_client, - message_sender, + message_tx: message_sender, }) }