Skip to content

Commit

Permalink
EventBroadcaster: Add simple connect & close methods, clean the code …
Browse files Browse the repository at this point in the history
…a bit
  • Loading branch information
roekatz committed Sep 9, 2024
1 parent 671c189 commit ee3c2ee
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 169 deletions.
281 changes: 119 additions & 162 deletions fastapi_websocket_pubsub/event_broadcaster.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import contextlib
from typing import Any

from broadcaster import Broadcast
Expand All @@ -22,14 +23,6 @@ class BroadcastNotification(BaseModel):
data: Any = None


class EventBroadcasterException(Exception):
pass


class BroadcasterAlreadyStarted(EventBroadcasterException):
pass


class EventBroadcasterContextManager:
"""
Manages the context for the EventBroadcaster
Expand All @@ -56,56 +49,18 @@ def __init__(
self._listen: bool = listen

async def __aenter__(self):
async with self._event_broadcaster._context_manager_lock:
if self._listen:
self._event_broadcaster._listen_count += 1
if self._event_broadcaster._listen_count == 1:
# We have our first listener start the read-task for it (And all those who'd follow)
logger.info(
"Listening for incoming events from broadcast channel (first listener started)"
)
# Start task listening on incoming broadcasts
await self._event_broadcaster.start_reader_task()

if self._share:
self._event_broadcaster._share_count += 1
if self._event_broadcaster._share_count == 1:
# We have our first publisher
# Init the broadcast used for sharing (reading has its own)
logger.debug(
"Subscribing to ALL_TOPICS, and sharing messages with broadcast channel"
)
# Subscribe to internal events form our own event notifier and broadcast them
await self._event_broadcaster._subscribe_to_all_topics()
else:
logger.debug(
f"Did not subscribe to ALL_TOPICS: share count == {self._event_broadcaster._share_count}"
)
return self
await self._event_broadcaster.connect(self._listen, self._share)

async def __aexit__(self, exc_type, exc, tb):
async with self._event_broadcaster._context_manager_lock:
try:
if self._listen:
self._event_broadcaster._listen_count -= 1
# if this was last listener - we can stop the reading task
if self._event_broadcaster._listen_count == 0:
# Cancel task reading broadcast subscriptions
if self._event_broadcaster._subscription_task is not None:
logger.info("Cancelling broadcast listen task")
self._event_broadcaster._subscription_task.cancel()
self._event_broadcaster._subscription_task = None

if self._share:
self._event_broadcaster._share_count -= 1
# if this was last sharer - we can stop subscribing to internal events - we aren't sharing anymore
if self._event_broadcaster._share_count == 0:
# Unsubscribe from internal events
logger.debug("Unsubscribing from ALL TOPICS")
await self._event_broadcaster._unsubscribe_from_topics()

except:
logger.exception("Failed to exit EventBroadcaster context")
await self._event_broadcaster.close(self._listen, self._share)


class EventBroadcasterException(Exception):
pass


class BroadcasterAlreadyStarted(EventBroadcasterException):
pass


class EventBroadcaster:
Expand Down Expand Up @@ -137,62 +92,46 @@ def __init__(
broadcast_type (Broadcast, optional): Broadcast class to use. None - Defaults to Broadcast.
is_publish_only (bool, optional): [For default context] Should the broadcaster only transmit events and not listen to any. Defaults to False
"""
# Broadcast init params
self._broadcast_url = broadcast_url
self._broadcast_type = broadcast_type or Broadcast
# Publish broadcast (initialized within async with statement)
self._sharing_broadcast_channel = None
# channel to operate on
self._channel = channel
# Async-io task for reading broadcasts (initialized within async with statement)
self._subscription_task = None
# Uniqueue instance id (used to avoid reading own notifications sent in broadcast)
self._id = gen_uid()
# The internal events notifier
self._notifier = notifier
self._broadcast_channel = None
self._connect_lock = asyncio.Lock()
self._listen_refcount = 0
self._share_refcount = 0
self._is_publish_only = is_publish_only
self._publish_lock = None
# used to track creation / removal of resources needed per type (reader task->listen, and subscription to internal events->share)
self._listen_count: int = 0
self._share_count: int = 0
# If we opt to manage the context directly (i.e. call async with on the event broadcaster itself)
self._context_manager = None
self._context_manager_lock = asyncio.Lock()
self._tasks = set()
self.listening_broadcast_channel = None

async def __broadcast_notifications__(self, subscription: Subscription, data):
async def connect(self, listen=True, share=True):
"""
Share incoming internal notifications with the entire broadcast channel
Args:
subscription (Subscription): the subscription that got triggered
data: the event data
This connects the listening channel
"""
logger.info(
"Broadcasting incoming event: {}".format(
{"topic": subscription.topic, "notifier_id": self._id}
)
)
note = BroadcastNotification(
notifier_id=self._id, topics=[subscription.topic], data=data
)
async with self._connect_lock:
if listen:
await self._connect_listen()
self._listen_refcount += 1

# Publish event to broadcast
async with self._broadcast_type(
self._broadcast_url
) as sharing_broadcast_channel:
await sharing_broadcast_channel.publish(
self._channel, pydantic_serialize(note)
)
if share:
await self._connect_share()
self._share_refcount += 1

async def _subscribe_to_all_topics(self):
return await self._notifier.subscribe(
self._id, ALL_TOPICS, self.__broadcast_notifications__
)
async def close(self, listen=True, share=True):
async with self._connect_lock:
if listen:
await self._close_listen()
self._listen_refcount -= 1

if share:
await self._close_share()
self._share_refcount -= 1

async def __aenter__(self):
await self.connect(listen=not self._is_publish_only)

async def _unsubscribe_from_topics(self):
return await self._notifier.unsubscribe(self._id)
async def __aexit__(self, exc_type, exc, tb):
await self.close(listen=not self._is_publish_only)

def get_context(self, listen=True, share=True):
"""
Expand All @@ -213,97 +152,115 @@ def get_listening_context(self):
def get_sharing_context(self):
return EventBroadcasterContextManager(self, listen=False, share=True)

async def __aenter__(self):
async def __broadcast_notifications__(self, subscription: Subscription, data):
"""
Convince caller (also backward compaltability)
Share incoming internal notifications with the entire broadcast channel
Args:
subscription (Subscription): the subscription that got triggered
data: the event data
"""
if self._context_manager is None:
self._context_manager = self.get_context(listen=not self._is_publish_only)
return await self._context_manager.__aenter__()
logger.info(
"Broadcasting incoming event: {}".format(
{"topic": subscription.topic, "notifier_id": self._id}
)
)

async def __aexit__(self, exc_type, exc, tb):
await self._context_manager.__aexit__(exc_type, exc, tb)
note = BroadcastNotification(
notifier_id=self._id, topics=[subscription.topic], data=data
)

async def start_reader_task(self):
"""Spawn a task reading incoming broadcasts and posting them to the intreal notifier
Raises:
BroadcasterAlreadyStarted: if called more than once per context
Returns:
the spawned task
"""
# Make sure a task wasn't started already
if self._subscription_task is not None:
# we already started a task for this worker process
logger.debug(
"No need for listen task, already started broadcast listen task for this notifier"
# Publish event to broadcast using a new connection from connection pool
async with self._broadcast_type(
self._broadcast_url
) as sharing_broadcast_channel:
await sharing_broadcast_channel.publish(
self._channel, pydantic_serialize(note)
)
return

# Init new broadcast channel for reading
try:
if self.listening_broadcast_channel is None:
self.listening_broadcast_channel = self._broadcast_type(
self._broadcast_url
)
await self.listening_broadcast_channel.connect()
except Exception as e:
logger.error(
f"Failed to connect to broadcast channel for reading incoming events: {e}"
async def _connect_share(self):
if self._share_refcount == 0:
return await self._notifier.subscribe(
self._id, ALL_TOPICS, self.__broadcast_notifications__
)
raise e

# Trigger the task
logger.debug("Spawning broadcast listen task")
self._subscription_task = asyncio.create_task(self.__read_notifications__())
return self._subscription_task
async def _close_share(self):
if self._share_refcount == 1:
return await self._notifier.unsubscribe(self._id)

async def _connect_listen(self):
if self._listen_refcount == 0:
if self._listen_refcount == 0:
try:
self._broadcast_channel = self._broadcast_type(self._broadcast_url)
await self._broadcast_channel.connect()
except Exception as e:
logger.error(
f"Failed to connect to broadcast channel for reading incoming events: {e}"
)
raise e
self._subscription_task = asyncio.create_task(
self.__read_notifications__()
)
return await self._notifier.subscribe(
self._id, ALL_TOPICS, self.__broadcast_notifications__
)

async def _close_listen(self):
if self._listen_refcount == 1 and self._broadcast_channel is not None:
await self._broadcast_channel.disconnect()
await self.wait_until_done()
self._broadcast_channel = None

def get_reader_task(self):
return self._subscription_task

async def wait_until_done(self):
if self._subscription_task is not None:
await self._subscription_task
self._subscription_task = None

async def __read_notifications__(self):
"""
read incoming broadcasts and posting them to the intreal notifier
"""
logger.debug("Starting broadcaster listener")

notify_tasks = set()
try:
# Subscribe to our channel
async with self.listening_broadcast_channel.subscribe(
async with self._broadcast_channel.subscribe(
channel=self._channel
) as subscriber:
async for event in subscriber:
try:
notification = BroadcastNotification.parse_raw(event.message)
# Avoid re-publishing our own broadcasts
if notification.notifier_id != self._id:
logger.debug(
"Handling incoming broadcast event: {}".format(
{
"topics": notification.topics,
"src": notification.notifier_id,
}
)
notification = BroadcastNotification.parse_raw(event.message)
# Avoid re-publishing our own broadcasts
if notification.notifier_id != self._id:
logger.debug(
"Handling incoming broadcast event: {}".format(
{
"topics": notification.topics,
"src": notification.notifier_id,
}
)
# Notify subscribers of message received from broadcast
task = asyncio.create_task(
self._notifier.notify(
notification.topics,
notification.data,
notifier_id=self._id,
)
)
# Notify subscribers of message received from broadcast
task = asyncio.create_task(
self._notifier.notify(
notification.topics,
notification.data,
notifier_id=self._id,
)
)

self._tasks.add(task)
notify_tasks.add(task)

def cleanup(task):
self._tasks.remove(task)
def cleanup(t):
notify_tasks.remove(t)

task.add_done_callback(cleanup)
except:
logger.exception("Failed handling incoming broadcast")
task.add_done_callback(cleanup)
logger.info(
"No more events to read from subscriber (underlying connection closed)"
)
finally:
if self.listening_broadcast_channel is not None:
await self.listening_broadcast_channel.disconnect()
self.listening_broadcast_channel = None
await asyncio.gather(*notify_tasks, return_exceptions=True)
19 changes: 12 additions & 7 deletions fastapi_websocket_pubsub/pub_sub_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
on_connect: List[Coroutine] = None,
on_disconnect: List[Coroutine] = None,
rpc_channel_get_remote_id: bool = False,
ignore_broadcaster_disconnected = True,
ignore_broadcaster_disconnected: bool = True,
):
"""
The PubSub endpoint recives subscriptions from clients and publishes data back to them upon receiving relevant publications.
Expand Down Expand Up @@ -104,7 +104,7 @@ async def publish(self, topics: Union[TopicList, Topic], data=None):
logger.debug(f"Publishing message to topics: {topics}")
if self.broadcaster is not None:
logger.debug(f"Acquiring broadcaster sharing context")
async with self.broadcaster.get_context(listen=False, share=True):
async with self.broadcaster.get_sharing_context():
await self.notifier.notify(topics, data, notifier_id=self._id)
# otherwise just notify
else:
Expand Down Expand Up @@ -132,14 +132,19 @@ async def main_loop(self, websocket: WebSocket, client_id: str = None, **kwargs)
async with self.broadcaster:
logger.debug("Entering endpoint's main loop with broadcaster")
if self._ignore_broadcaster_disconnected:
await self.endpoint.main_loop(websocket, client_id=client_id, **kwargs)
await self.endpoint.main_loop(
websocket, client_id=client_id, **kwargs
)
else:
main_loop_task = asyncio.create_task(
self.endpoint.main_loop(websocket, client_id=client_id, **kwargs)
self.endpoint.main_loop(
websocket, client_id=client_id, **kwargs
)
)
done, pending = await asyncio.wait(
[main_loop_task, self.broadcaster.get_reader_task()],
return_when=asyncio.FIRST_COMPLETED,
)
done, pending = await asyncio.wait([main_loop_task,
self.broadcaster.get_reader_task()],
return_when=asyncio.FIRST_COMPLETED)
logger.debug(f"task is done: {done}")
# broadcaster's reader task is used by other endpoints and shouldn't be cancelled
if main_loop_task in pending:
Expand Down

0 comments on commit ee3c2ee

Please sign in to comment.