diff --git a/crates/ott-balancer-protocol/src/monolith.rs b/crates/ott-balancer-protocol/src/monolith.rs index 4b579592a..1bea847fc 100644 --- a/crates/ott-balancer-protocol/src/monolith.rs +++ b/crates/ott-balancer-protocol/src/monolith.rs @@ -11,6 +11,7 @@ use crate::{ClientId, RoomName}; #[typeshare] pub enum MsgB2M { Load(B2MLoad), + Unload(B2MUnload), Join(B2MJoin), Leave(B2MLeave), ClientMsg(B2MClientMsg), @@ -22,6 +23,12 @@ pub struct B2MLoad { pub room: RoomName, } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[typeshare] +pub struct B2MUnload { + pub room: RoomName, +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[typeshare] pub struct B2MJoin { diff --git a/server/clientmanager.ts b/server/clientmanager.ts index 18ca08e0e..b6bbae98f 100644 --- a/server/clientmanager.ts +++ b/server/clientmanager.ts @@ -229,39 +229,69 @@ function onBalancerDisconnect(conn: BalancerConnection) { } } -function onBalancerMessage(conn: BalancerConnection, message: MsgB2M) { +async function onBalancerMessage(conn: BalancerConnection, message: MsgB2M) { log.silly("balancer message: " + JSON.stringify(message)); - if (message.type === "join") { - const msg = message.payload; - const client = new BalancerClient(msg.room, msg.client, conn); - connections.push(client); - client.on("auth", onClientAuth); - client.on("message", onClientMessage); - client.on("disconnect", onClientDisconnect); - client.auth(msg.token); - } else if (message.type === "leave") { - const msg = message.payload; - const client = connections.find(c => c.id === msg.client); - if (client instanceof BalancerClient) { - client.leave(); - } else { - log.error( - `Balancer tried to make client leave that does not exist or is not a balancer client` - ); - } - } else if (message.type === "client_msg") { - const msg = message.payload; - const client = connections.find(c => c.id === msg.client_id); - if (client instanceof BalancerClient) { - client.receiveMessage(msg.payload as ClientMessage); - } else { - log.error( - `Balancer sent message for client that does not exist or is not a balancer client` - ); - } - } else { + + /** + * This is a type that maps the message type to the handler for that message type. + * + * Useful for handling enums that are discriminated by a string, like the ones generated by typeshare. + */ + type EnumHandler = { + [P in T["type"]]: ( + instruction: Extract + ) => Promise; + }; + + // the intersection type makes it so that it throws a compile error if all the enum variants aren't handled + const handlers: Record & EnumHandler = { + load: async message => { + const msg = message.payload; + await roommanager.getRoom(msg.room); + }, + unload: async message => { + const msg = message.payload; + await roommanager.unloadRoom(msg.room); + }, + join: async message => { + const msg = message.payload; + const client = new BalancerClient(msg.room, msg.client, conn); + connections.push(client); + client.on("auth", onClientAuth); + client.on("message", onClientMessage); + client.on("disconnect", onClientDisconnect); + client.auth(msg.token); + }, + leave: async message => { + const msg = message.payload; + const client = connections.find(c => c.id === msg.client); + if (client instanceof BalancerClient) { + client.leave(); + } else { + log.error( + `Balancer tried to make client leave that does not exist or is not a balancer client` + ); + } + }, + client_msg: async message => { + const msg = message.payload; + const client = connections.find(c => c.id === msg.client_id); + if (client instanceof BalancerClient) { + client.receiveMessage(msg.payload as ClientMessage); + } else { + log.error( + `Balancer sent message for client that does not exist or is not a balancer client` + ); + } + }, + }; + + const handler = handlers[message.type]; + if (!handler) { log.error(`Unknown balancer message type: ${(message as { type: string }).type}`); + return; } + await handler(message as any); // this cast is safe because the type is checked and narrowed above } function onBalancerError(conn: BalancerConnection, error: WebSocket.ErrorEvent) { diff --git a/server/generated.ts b/server/generated.ts index be32ea215..714c06fa5 100644 --- a/server/generated.ts +++ b/server/generated.ts @@ -12,6 +12,10 @@ export interface B2MLoad { room: RoomName; } +export interface B2MUnload { + room: RoomName; +} + export interface B2MJoin { room: RoomName; client: ClientId; @@ -88,6 +92,7 @@ export interface M2BKick { export type MsgB2M = | { type: "load"; payload: B2MLoad } + | { type: "unload"; payload: B2MUnload } | { type: "join"; payload: B2MJoin } | { type: "leave"; payload: B2MLeave } | { type: "client_msg"; payload: B2MClientMsg }; diff --git a/tools/balancer-tester/src/monolith.rs b/tools/balancer-tester/src/monolith.rs index bccaead99..c5d27cf14 100644 --- a/tools/balancer-tester/src/monolith.rs +++ b/tools/balancer-tester/src/monolith.rs @@ -58,6 +58,7 @@ impl SimMonolith { }; outbound_tx.send(self.build_message(msg)).await.unwrap(); } + MsgB2M::Unload(_) => todo!(), MsgB2M::Join(msg) => { let room = match self.rooms.get_mut(&msg.room) { Some(room) => room,