fix(server): Prevent broadcast::Receiver leaks with a custom channel
mxxntype committed May 21, 2024
1 parent 1e06d92 commit a6fd379
Showing 6 changed files with 125 additions and 24 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

31 changes: 19 additions & 12 deletions cli/src/main.rs
@@ -139,12 +139,6 @@ async fn existing_room(
.trim_matches(|c| c == '(' || c == ')');
let chosen_room = Uuid::from_str(chosen_room).unwrap();

- let mut message_stream = chat
-     .subscribe_to_room(proto::Uuid::from(chosen_room))
-     .await
-     .unwrap()
-     .into_inner();

let room_action = Listbox::new(["Send new messages", "Listen to messages"])
.title("Would you like to send new messages or listen to incoming ones?")
.prompt()
@@ -173,6 +167,7 @@ async fn existing_room(
},

"Listen to messages" => {
+ // Print older, in-database messages.
let messages = chat
.list_messages(Into::<proto::Uuid>::into(chosen_room))
.await
@@ -181,12 +176,24 @@ async fn existing_room(
.messages;

for msg in messages.into_iter() {
- print_message(msg);
+ print_message(&msg);
}

- while let Ok(event) = message_stream.next().await.unwrap() {
+ // Print live messages.
+ let mut message_stream = chat
+     .subscribe_to_room(proto::Uuid::from(chosen_room))
+     .await
+     .unwrap()
+     .into_inner();
+
+ 'message_listener: while let Ok(event) = message_stream.next().await.unwrap() {
match event.event.unwrap() {
- Event::NewMessage(msg) => print_message(msg),
+ Event::NewMessage(msg) => {
+     print_message(&msg);
+     if msg.text.as_str() == "exit" {
+         break 'message_listener;
+     }
+ }
}
}
}
@@ -195,11 +202,11 @@ async fn existing_room(
}
}

- fn print_message(msg: ServersideMessage) {
+ fn print_message(msg: &ServersideMessage) {
println!(
"{} | {}: {}",
- msg.timestamp.unwrap().blue(),
- msg.sender_uuid.unwrap().uuid.green(),
+ msg.timestamp.clone().unwrap().blue(),
+ msg.sender_uuid.clone().unwrap().uuid.green(),
msg.text
);
}
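The switch to `&ServersideMessage` is also why the `.clone()` calls appear: you cannot move an `Option` field out of a shared reference, so the fields are cloned before `.unwrap()`. A minimal illustration of the borrow rule (hypothetical standalone types, not the prost-generated ones):

#[derive(Clone)]
struct Timestamp(String);

struct ServersideMessage {
    timestamp: Option<Timestamp>,
    text: String,
}

fn print_message(msg: &ServersideMessage) {
    // `msg.timestamp.unwrap()` would try to move the field out of `*msg`,
    // which the borrow checker rejects behind a shared reference.
    let ts = msg.timestamp.clone().unwrap();
    println!("{} | {}", ts.0, msg.text);
}

fn main() {
    let msg = ServersideMessage {
        timestamp: Some(Timestamp("2024-05-21T12:00:00Z".into())),
        text: "hello".into(),
    };
    print_message(&msg);
}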
3 changes: 2 additions & 1 deletion server/Cargo.toml
@@ -22,7 +22,8 @@ rand_core = "0.6.4"
redis = { version = "0.25.3", features = ["uuid", "tokio-comp", "aio"] }
thiserror = "1.0.61"
tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread"] }
- tokio-stream = "0.1.15"
+ tokio-stream = { version = "0.1.15", features = ["sync"] }
+ tokio-util = "0.7.11"
tonic = "0.11.0"
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
40 changes: 40 additions & 0 deletions server/src/channel/mod.rs
@@ -0,0 +1,40 @@
//! # A wrapper around a [`mpsc`] channel that detects disconnects.
//!
//! Implements the [`Deref`] trait (`Target = mpsc::Receiver<T>`), and uses a [`oneshot`]
//! channel to send a single message back when the whole thing gets dropped.
//!
//! **Source & further reading:** <https://github.com/hyperium/tonic/issues/377>
use futures::Stream;
use std::task::{Context, Poll};
use std::{ops::Deref, pin::Pin};
use tokio::sync::{mpsc, oneshot};

pub struct DisconnectChannel<T> {
    pub(crate) signal_sender: Option<oneshot::Sender<()>>,
    pub(crate) inner_channel: mpsc::Receiver<T>,
}

impl<T> Stream for DisconnectChannel<T> {
    type Item = T;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.inner_channel).poll_recv(cx)
    }
}

impl<T> Deref for DisconnectChannel<T> {
    type Target = mpsc::Receiver<T>;

    fn deref(&self) -> &Self::Target {
        &self.inner_channel
    }
}

impl<T> Drop for DisconnectChannel<T> {
    fn drop(&mut self) {
        if let Some(drop_signal) = self.signal_sender.take() {
            let _ = drop_signal.send(());
        }
    }
}
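For illustration, a minimal sketch of the drop-signal contract, written as a hypothetical unit test inside `server/src/channel/mod.rs` (it has to live in this crate, since both fields are `pub(crate)`); this is not part of the commit:

#[cfg(test)]
mod tests {
    use super::DisconnectChannel;
    use tokio::sync::{mpsc, oneshot};

    #[tokio::test]
    async fn drop_fires_disconnect_signal() {
        let (signal_tx, signal_rx) = oneshot::channel::<()>();
        let (_grpc_tx, grpc_rx) = mpsc::channel::<u32>(4);

        // Wrap the mpsc receiver; the oneshot sender fires when this is dropped.
        let stream = DisconnectChannel {
            signal_sender: Some(signal_tx),
            inner_channel: grpc_rx,
        };

        // Simulate the client going away: tonic drops the response stream...
        drop(stream);

        // ...and the other end of the oneshot observes the disconnect.
        assert!(signal_rx.await.is_ok());
    }
}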
1 change: 1 addition & 0 deletions server/src/lib.rs
@@ -1,6 +1,7 @@
#![deny(clippy::unwrap_used)]

pub mod auth;
+ pub mod channel;
pub mod entities;
pub mod persistence;
pub mod services;
72 changes: 61 additions & 11 deletions server/src/services/chat.rs
@@ -1,17 +1,19 @@
use crate::auth::AuthenticatedRequest;
+ use crate::channel::DisconnectChannel;
use crate::entities::{Message, Room, User};
use crate::proto::user_lookup_request::Identifier;
use crate::proto::{ClientsideMessage, ClientsideRoom, MessageList, RoomList};
use crate::proto::{RoomWithUserCreationRequest, UserLookupRequest};
use crate::proto::{ServersideMessage, ServersideRoom, ServersideRoomEvent, ServersideUserEvent};
- use crate::{persistence, proto};
+ use crate::{channel, persistence, proto};
use diesel::r2d2::{ConnectionManager, PooledConnection};
use diesel::PgConnection;
use redis::aio::MultiplexedConnection;
use redis::{AsyncCommands, Client, RedisResult};
use std::env;
- use tokio::sync::{broadcast, mpsc};
+ use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_stream::wrappers::ReceiverStream;
+ use tokio_util::sync::CancellationToken;
use tonic::{Request, Response, Status};
use uuid::Uuid;

@@ -22,7 +24,7 @@ pub struct Chat {
cache_client: redis::Client,

// Message passing channel.
- message_sender: broadcast::Sender<Message>,
+ message_tx: broadcast::Sender<Message>,
}

#[tonic::async_trait]
@@ -268,7 +270,7 @@ impl proto::chat_server::Chat for Chat {
Status::internal("Could not send the message due to an internal error")
})?;

- match self.message_sender.send(message) {
+ match self.message_tx.send(message) {
Ok(recv_count) => tracing::trace!(message = "Broadcasting message", ?recv_count),
Err(err) => tracing::error!(message = "Could not broadcast message", ?err),
}
@@ -344,9 +346,9 @@ impl proto::chat_server::Chat for Chat {
Ok(Response::new(private_room_uuid.into()))
}

- type SubscribeToRoomStream = ReceiverStream<Result<ServersideRoomEvent, Status>>;
+ type SubscribeToRoomStream = DisconnectChannel<Result<ServersideRoomEvent, Status>>;

#[tracing::instrument(skip_all)]
async fn subscribe_to_room(
&self,
request: Request<proto::Uuid>,
@@ -376,12 +378,45 @@ impl proto::chat_server::Chat for Chat {
));
}

- let (grpc_tx, grpc_rx) = mpsc::channel(4);
+ // The 'streamer' thread (see below) needs a cache connection.
let mut cache_connection = self.acquire_cache_connection().await?;
- let mut message_receiver = self.message_sender.subscribe();

+ // NOTE: Read this.
+ //
+ // There are a total of 3 channels involved in this whole streaming thing:
+ // - An internal `broadcast` channel that transfers messages from `SendMessage` RPC calls;
+ // - A `DisconnectChannel`, which holds another 2 channels inside:
+ //   - A `mpsc` Tokio channel, which performs gRPC streaming;
+ //   - A `oneshot` Tokio channel, which fires when the client disconnects.
+
+ // The 'canceller' thread will cancel this token when the client disconnects.
+ let cancellation_token = CancellationToken::new();
+ let token_clone = cancellation_token.clone();
+
+ let (grpc_tx, grpc_rx) = mpsc::channel(4);
+ let (disconnect_tx, disconnect_rx) = oneshot::channel::<()>();
+ let disconnect_channel = channel::DisconnectChannel {
+     signal_sender: Some(disconnect_tx),
+     inner_channel: grpc_rx,
+ };
+
+ let mut message_rx = self.message_tx.subscribe();
+ let receiver_count = self.message_tx.receiver_count();
+ tracing::debug!(message = "New room subscriber", room = ?subscribed_room, total_subscribers = %receiver_count);

+ // This is the 'canceller' thread.
+ //
+ // This task will cancel the token when the client disconnects, which will shut down
+ // the streaming thread (see below) and cause the broadcast::Receiver to drop.
tokio::spawn(async move {
-     while let Ok(msg) = message_receiver.recv().await {
+     let _ = disconnect_rx.await;
+     tracing::trace!("Client disconnected, cancelling streaming thread");
+     cancellation_token.cancel();
+ });

+ // The logic for the streaming thread, extracted into a variable to help rustfmt.
+ let streaming_closure = async move {
+     while let Ok(msg) = message_rx.recv().await {
let message_room = msg.room_uuid;
let subscriber_rooms: Vec<Uuid> = cache_connection
.lrange(subscriber, 0, -1)
@@ -413,9 +448,24 @@
}
}
}
+ };

+ // This is the 'streamer' thread.
+ //
+ // This thread will receive all messages sent via the `SendMessage` RPC call, and
+ // mirror them to all subscribers. Without a canceller thread, a cancellation token
+ // and a hacky DisconnectChannel, this thread would never terminate, meaning there
+ // would soon be a thousand hanging broadcast::Receivers with no real client.
+ tokio::spawn(async move {
+     tokio::select! {
+         _ = token_clone.cancelled() => {
+             tracing::debug!(message = "Streamer thread cancelled");
+         }
+         _ = streaming_closure => {}
+     }
+ });

- Ok(Response::new(ReceiverStream::new(grpc_rx)))
+ Ok(Response::new(disconnect_channel))
}

type SubscribeToUserStream = ReceiverStream<Result<ServersideUserEvent, Status>>;
@@ -480,7 +530,7 @@ impl Chat {
Ok(Self {
persistence_pool,
cache_client,
- message_sender,
+ message_tx: message_sender,
})
}

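To see the canceller/streamer handshake in isolation, here is a self-contained toy sketch of the same pattern (hypothetical names and payload type, not code from this repo; assumes tokio with the sync and macros features plus tokio-util):

use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    // Stand-ins for the server's three channels: broadcast for fan-out,
    // mpsc for the gRPC stream, oneshot for the disconnect signal.
    let (message_tx, _) = broadcast::channel::<String>(16);
    let mut message_rx = message_tx.subscribe();
    let (grpc_tx, _grpc_rx) = mpsc::channel::<String>(4);
    let (disconnect_tx, disconnect_rx) = oneshot::channel::<()>();

    let cancellation_token = CancellationToken::new();
    let token_clone = cancellation_token.clone();

    // Canceller: waits for the disconnect signal, then cancels the token.
    tokio::spawn(async move {
        let _ = disconnect_rx.await;
        cancellation_token.cancel();
    });

    // Streamer: mirrors broadcast messages into the gRPC channel until cancelled.
    let streamer = tokio::spawn(async move {
        tokio::select! {
            _ = token_clone.cancelled() => {}
            _ = async {
                while let Ok(msg) = message_rx.recv().await {
                    let _ = grpc_tx.send(msg).await;
                }
            } => {}
        }
    });

    // Simulate a client disconnect: DisconnectChannel's Drop impl would fire this...
    let _ = disconnect_tx.send(());

    // ...which terminates the streamer, dropping its broadcast::Receiver.
    streamer.await.unwrap();
    assert_eq!(message_tx.receiver_count(), 0);
}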
