diff --git a/fastapi_websocket_pubsub/event_broadcaster.py b/fastapi_websocket_pubsub/event_broadcaster.py index 4aa2d55..7635b11 100644 --- a/fastapi_websocket_pubsub/event_broadcaster.py +++ b/fastapi_websocket_pubsub/event_broadcaster.py @@ -1,4 +1,5 @@ import asyncio +import contextlib from typing import Any from broadcaster import Broadcast @@ -22,14 +23,6 @@ class BroadcastNotification(BaseModel): data: Any = None -class EventBroadcasterException(Exception): - pass - - -class BroadcasterAlreadyStarted(EventBroadcasterException): - pass - - class EventBroadcasterContextManager: """ Manages the context for the EventBroadcaster @@ -56,56 +49,18 @@ def __init__( self._listen: bool = listen async def __aenter__(self): - async with self._event_broadcaster._context_manager_lock: - if self._listen: - self._event_broadcaster._listen_count += 1 - if self._event_broadcaster._listen_count == 1: - # We have our first listener start the read-task for it (And all those who'd follow) - logger.info( - "Listening for incoming events from broadcast channel (first listener started)" - ) - # Start task listening on incoming broadcasts - await self._event_broadcaster.start_reader_task() - - if self._share: - self._event_broadcaster._share_count += 1 - if self._event_broadcaster._share_count == 1: - # We have our first publisher - # Init the broadcast used for sharing (reading has its own) - logger.debug( - "Subscribing to ALL_TOPICS, and sharing messages with broadcast channel" - ) - # Subscribe to internal events form our own event notifier and broadcast them - await self._event_broadcaster._subscribe_to_all_topics() - else: - logger.debug( - f"Did not subscribe to ALL_TOPICS: share count == {self._event_broadcaster._share_count}" - ) - return self + await self._event_broadcaster.connect(self._listen, self._share) async def __aexit__(self, exc_type, exc, tb): - async with self._event_broadcaster._context_manager_lock: - try: - if self._listen: - self._event_broadcaster._listen_count -= 1 - # if this was last listener - we can stop the reading task - if self._event_broadcaster._listen_count == 0: - # Cancel task reading broadcast subscriptions - if self._event_broadcaster._subscription_task is not None: - logger.info("Cancelling broadcast listen task") - self._event_broadcaster._subscription_task.cancel() - self._event_broadcaster._subscription_task = None - - if self._share: - self._event_broadcaster._share_count -= 1 - # if this was last sharer - we can stop subscribing to internal events - we aren't sharing anymore - if self._event_broadcaster._share_count == 0: - # Unsubscribe from internal events - logger.debug("Unsubscribing from ALL TOPICS") - await self._event_broadcaster._unsubscribe_from_topics() - - except: - logger.exception("Failed to exit EventBroadcaster context") + await self._event_broadcaster.close(self._listen, self._share) + + +class EventBroadcasterException(Exception): + pass + + +class BroadcasterAlreadyStarted(EventBroadcasterException): + pass class EventBroadcaster: @@ -137,62 +92,46 @@ def __init__( broadcast_type (Broadcast, optional): Broadcast class to use. None - Defaults to Broadcast. is_publish_only (bool, optional): [For default context] Should the broadcaster only transmit events and not listen to any. Defaults to False """ - # Broadcast init params self._broadcast_url = broadcast_url self._broadcast_type = broadcast_type or Broadcast - # Publish broadcast (initialized within async with statement) - self._sharing_broadcast_channel = None - # channel to operate on self._channel = channel - # Async-io task for reading broadcasts (initialized within async with statement) self._subscription_task = None - # Uniqueue instance id (used to avoid reading own notifications sent in broadcast) self._id = gen_uid() - # The internal events notifier self._notifier = notifier + self._broadcast_channel = None + self._connect_lock = asyncio.Lock() + self._listen_refcount = 0 + self._share_refcount = 0 self._is_publish_only = is_publish_only - self._publish_lock = None - # used to track creation / removal of resources needed per type (reader task->listen, and subscription to internal events->share) - self._listen_count: int = 0 - self._share_count: int = 0 - # If we opt to manage the context directly (i.e. call async with on the event broadcaster itself) - self._context_manager = None - self._context_manager_lock = asyncio.Lock() - self._tasks = set() - self.listening_broadcast_channel = None - async def __broadcast_notifications__(self, subscription: Subscription, data): + async def connect(self, listen=True, share=True): """ - Share incoming internal notifications with the entire broadcast channel - - Args: - subscription (Subscription): the subscription that got triggered - data: the event data + This connects the listening channel """ - logger.info( - "Broadcasting incoming event: {}".format( - {"topic": subscription.topic, "notifier_id": self._id} - ) - ) - note = BroadcastNotification( - notifier_id=self._id, topics=[subscription.topic], data=data - ) + async with self._connect_lock: + if listen: + await self._connect_listen() + self._listen_refcount += 1 - # Publish event to broadcast - async with self._broadcast_type( - self._broadcast_url - ) as sharing_broadcast_channel: - await sharing_broadcast_channel.publish( - self._channel, pydantic_serialize(note) - ) + if share: + await self._connect_share() + self._share_refcount += 1 - async def _subscribe_to_all_topics(self): - return await self._notifier.subscribe( - self._id, ALL_TOPICS, self.__broadcast_notifications__ - ) + async def close(self, listen=True, share=True): + async with self._connect_lock: + if listen: + await self._close_listen() + self._listen_refcount -= 1 + + if share: + await self._close_share() + self._share_refcount -= 1 + + async def __aenter__(self): + await self.connect(listen=not self._is_publish_only) - async def _unsubscribe_from_topics(self): - return await self._notifier.unsubscribe(self._id) + async def __aexit__(self, exc_type, exc, tb): + await self.close(listen=not self._is_publish_only) def get_context(self, listen=True, share=True): """ @@ -213,97 +152,115 @@ def get_listening_context(self): def get_sharing_context(self): return EventBroadcasterContextManager(self, listen=False, share=True) - async def __aenter__(self): + async def __broadcast_notifications__(self, subscription: Subscription, data): """ - Convince caller (also backward compaltability) + Share incoming internal notifications with the entire broadcast channel + + Args: + subscription (Subscription): the subscription that got triggered + data: the event data """ - if self._context_manager is None: - self._context_manager = self.get_context(listen=not self._is_publish_only) - return await self._context_manager.__aenter__() + logger.info( + "Broadcasting incoming event: {}".format( + {"topic": subscription.topic, "notifier_id": self._id} + ) + ) - async def __aexit__(self, exc_type, exc, tb): - await self._context_manager.__aexit__(exc_type, exc, tb) + note = BroadcastNotification( + notifier_id=self._id, topics=[subscription.topic], data=data + ) - async def start_reader_task(self): - """Spawn a task reading incoming broadcasts and posting them to the intreal notifier - Raises: - BroadcasterAlreadyStarted: if called more than once per context - Returns: - the spawned task - """ - # Make sure a task wasn't started already - if self._subscription_task is not None: - # we already started a task for this worker process - logger.debug( - "No need for listen task, already started broadcast listen task for this notifier" + # Publish event to broadcast using a new connection from connection pool + async with self._broadcast_type( + self._broadcast_url + ) as sharing_broadcast_channel: + await sharing_broadcast_channel.publish( + self._channel, pydantic_serialize(note) ) - return - # Init new broadcast channel for reading - try: - if self.listening_broadcast_channel is None: - self.listening_broadcast_channel = self._broadcast_type( - self._broadcast_url - ) - await self.listening_broadcast_channel.connect() - except Exception as e: - logger.error( - f"Failed to connect to broadcast channel for reading incoming events: {e}" + async def _connect_share(self): + if self._share_refcount == 0: + return await self._notifier.subscribe( + self._id, ALL_TOPICS, self.__broadcast_notifications__ ) - raise e - # Trigger the task - logger.debug("Spawning broadcast listen task") - self._subscription_task = asyncio.create_task(self.__read_notifications__()) - return self._subscription_task + async def _close_share(self): + if self._share_refcount == 1: + return await self._notifier.unsubscribe(self._id) + + async def _connect_listen(self): + if self._listen_refcount == 0: + if self._listen_refcount == 0: + try: + self._broadcast_channel = self._broadcast_type(self._broadcast_url) + await self._broadcast_channel.connect() + except Exception as e: + logger.error( + f"Failed to connect to broadcast channel for reading incoming events: {e}" + ) + raise e + self._subscription_task = asyncio.create_task( + self.__read_notifications__() + ) + return await self._notifier.subscribe( + self._id, ALL_TOPICS, self.__broadcast_notifications__ + ) + + async def _close_listen(self): + if self._listen_refcount == 1 and self._broadcast_channel is not None: + await self._broadcast_channel.disconnect() + await self.wait_until_done() + self._broadcast_channel = None def get_reader_task(self): return self._subscription_task + async def wait_until_done(self): + if self._subscription_task is not None: + await self._subscription_task + self._subscription_task = None + async def __read_notifications__(self): """ read incoming broadcasts and posting them to the intreal notifier """ logger.debug("Starting broadcaster listener") + + notify_tasks = set() try: # Subscribe to our channel - async with self.listening_broadcast_channel.subscribe( + async with self._broadcast_channel.subscribe( channel=self._channel ) as subscriber: async for event in subscriber: - try: - notification = BroadcastNotification.parse_raw(event.message) - # Avoid re-publishing our own broadcasts - if notification.notifier_id != self._id: - logger.debug( - "Handling incoming broadcast event: {}".format( - { - "topics": notification.topics, - "src": notification.notifier_id, - } - ) + notification = BroadcastNotification.parse_raw(event.message) + # Avoid re-publishing our own broadcasts + if notification.notifier_id != self._id: + logger.debug( + "Handling incoming broadcast event: {}".format( + { + "topics": notification.topics, + "src": notification.notifier_id, + } ) - # Notify subscribers of message received from broadcast - task = asyncio.create_task( - self._notifier.notify( - notification.topics, - notification.data, - notifier_id=self._id, - ) + ) + # Notify subscribers of message received from broadcast + task = asyncio.create_task( + self._notifier.notify( + notification.topics, + notification.data, + notifier_id=self._id, ) + ) - self._tasks.add(task) + notify_tasks.add(task) - def cleanup(task): - self._tasks.remove(task) + def cleanup(t): + notify_tasks.remove(t) - task.add_done_callback(cleanup) - except: - logger.exception("Failed handling incoming broadcast") + task.add_done_callback(cleanup) logger.info( "No more events to read from subscriber (underlying connection closed)" ) finally: - if self.listening_broadcast_channel is not None: - await self.listening_broadcast_channel.disconnect() - self.listening_broadcast_channel = None + await asyncio.gather(*notify_tasks, return_exceptions=True) diff --git a/fastapi_websocket_pubsub/pub_sub_server.py b/fastapi_websocket_pubsub/pub_sub_server.py index 83f3913..539f72a 100644 --- a/fastapi_websocket_pubsub/pub_sub_server.py +++ b/fastapi_websocket_pubsub/pub_sub_server.py @@ -34,7 +34,7 @@ def __init__( on_connect: List[Coroutine] = None, on_disconnect: List[Coroutine] = None, rpc_channel_get_remote_id: bool = False, - ignore_broadcaster_disconnected = True, + ignore_broadcaster_disconnected: bool = True, ): """ The PubSub endpoint recives subscriptions from clients and publishes data back to them upon receiving relevant publications. @@ -104,7 +104,7 @@ async def publish(self, topics: Union[TopicList, Topic], data=None): logger.debug(f"Publishing message to topics: {topics}") if self.broadcaster is not None: logger.debug(f"Acquiring broadcaster sharing context") - async with self.broadcaster.get_context(listen=False, share=True): + async with self.broadcaster.get_sharing_context(): await self.notifier.notify(topics, data, notifier_id=self._id) # otherwise just notify else: @@ -132,14 +132,19 @@ async def main_loop(self, websocket: WebSocket, client_id: str = None, **kwargs) async with self.broadcaster: logger.debug("Entering endpoint's main loop with broadcaster") if self._ignore_broadcaster_disconnected: - await self.endpoint.main_loop(websocket, client_id=client_id, **kwargs) + await self.endpoint.main_loop( + websocket, client_id=client_id, **kwargs + ) else: main_loop_task = asyncio.create_task( - self.endpoint.main_loop(websocket, client_id=client_id, **kwargs) + self.endpoint.main_loop( + websocket, client_id=client_id, **kwargs + ) + ) + done, pending = await asyncio.wait( + [main_loop_task, self.broadcaster.get_reader_task()], + return_when=asyncio.FIRST_COMPLETED, ) - done, pending = await asyncio.wait([main_loop_task, - self.broadcaster.get_reader_task()], - return_when=asyncio.FIRST_COMPLETED) logger.debug(f"task is done: {done}") # broadcaster's reader task is used by other endpoints and shouldn't be cancelled if main_loop_task in pending: