From 7f7b9d893ab516c23c77ccb39aabc1a4f60235f9 Mon Sep 17 00:00:00 2001 From: Ro'e Katz Date: Wed, 4 Sep 2024 18:34:25 +0300 Subject: [PATCH] Refactor EventBroadcaster --- fastapi_websocket_pubsub/event_broadcaster.py | 270 ++++++------------ fastapi_websocket_pubsub/pub_sub_server.py | 4 +- 2 files changed, 88 insertions(+), 186 deletions(-) diff --git a/fastapi_websocket_pubsub/event_broadcaster.py b/fastapi_websocket_pubsub/event_broadcaster.py index 4aa2d55..124a67a 100644 --- a/fastapi_websocket_pubsub/event_broadcaster.py +++ b/fastapi_websocket_pubsub/event_broadcaster.py @@ -1,4 +1,5 @@ import asyncio +import contextlib from typing import Any from broadcaster import Broadcast @@ -30,84 +31,6 @@ class BroadcasterAlreadyStarted(EventBroadcasterException): pass -class EventBroadcasterContextManager: - """ - Manages the context for the EventBroadcaster - Friend-like class of EventBroadcaster (accessing "protected" members ) - """ - - def __init__( - self, - event_broadcaster: "EventBroadcaster", - listen: bool = True, - share: bool = True, - ) -> None: - """ - Provide a context manager for an EventBroadcaster, managing if it listens to events coming from the broadcaster - and if it subscribes to the internal notifier to share its events with the broadcaster - - Args: - event_broadcaster (EventBroadcaster): the broadcaster we manage the context for. - share (bool, optional): Should we share events with the broadcaster. Defaults to True. - listen (bool, optional): Should we listen for incoming events from the broadcaster. Defaults to True. - """ - self._event_broadcaster = event_broadcaster - self._share: bool = share - self._listen: bool = listen - - async def __aenter__(self): - async with self._event_broadcaster._context_manager_lock: - if self._listen: - self._event_broadcaster._listen_count += 1 - if self._event_broadcaster._listen_count == 1: - # We have our first listener start the read-task for it (And all those who'd follow) - logger.info( - "Listening for incoming events from broadcast channel (first listener started)" - ) - # Start task listening on incoming broadcasts - await self._event_broadcaster.start_reader_task() - - if self._share: - self._event_broadcaster._share_count += 1 - if self._event_broadcaster._share_count == 1: - # We have our first publisher - # Init the broadcast used for sharing (reading has its own) - logger.debug( - "Subscribing to ALL_TOPICS, and sharing messages with broadcast channel" - ) - # Subscribe to internal events form our own event notifier and broadcast them - await self._event_broadcaster._subscribe_to_all_topics() - else: - logger.debug( - f"Did not subscribe to ALL_TOPICS: share count == {self._event_broadcaster._share_count}" - ) - return self - - async def __aexit__(self, exc_type, exc, tb): - async with self._event_broadcaster._context_manager_lock: - try: - if self._listen: - self._event_broadcaster._listen_count -= 1 - # if this was last listener - we can stop the reading task - if self._event_broadcaster._listen_count == 0: - # Cancel task reading broadcast subscriptions - if self._event_broadcaster._subscription_task is not None: - logger.info("Cancelling broadcast listen task") - self._event_broadcaster._subscription_task.cancel() - self._event_broadcaster._subscription_task = None - - if self._share: - self._event_broadcaster._share_count -= 1 - # if this was last sharer - we can stop subscribing to internal events - we aren't sharing anymore - if self._event_broadcaster._share_count == 0: - # Unsubscribe from internal events - logger.debug("Unsubscribing from ALL TOPICS") - await self._event_broadcaster._unsubscribe_from_topics() - - except: - logger.exception("Failed to exit EventBroadcaster context") - - class EventBroadcaster: """ Bridge EventNotifier to work across processes and machines by sharing their events through a broadcasting channel @@ -135,31 +58,57 @@ def __init__( notifier (EventNotifier): the event notifier managing our internal events - which will be bridge via the broadcaster channel (str, optional): Channel name. Defaults to "EventNotifier". broadcast_type (Broadcast, optional): Broadcast class to use. None - Defaults to Broadcast. - is_publish_only (bool, optional): [For default context] Should the broadcaster only transmit events and not listen to any. Defaults to False + DEPRECATED is_publish_only (bool, optional): [For default context] Should the broadcaster only transmit events and not listen to any. Defaults to False + # TODO: Like that? """ - # Broadcast init params self._broadcast_url = broadcast_url self._broadcast_type = broadcast_type or Broadcast - # Publish broadcast (initialized within async with statement) - self._sharing_broadcast_channel = None - # channel to operate on self._channel = channel - # Async-io task for reading broadcasts (initialized within async with statement) self._subscription_task = None - # Uniqueue instance id (used to avoid reading own notifications sent in broadcast) self._id = gen_uid() - # The internal events notifier self._notifier = notifier - self._is_publish_only = is_publish_only - self._publish_lock = None - # used to track creation / removal of resources needed per type (reader task->listen, and subscription to internal events->share) - self._listen_count: int = 0 - self._share_count: int = 0 - # If we opt to manage the context directly (i.e. call async with on the event broadcaster itself) - self._context_manager = None - self._context_manager_lock = asyncio.Lock() - self._tasks = set() - self.listening_broadcast_channel = None + self._broadcast_channel = None + self._connect_lock = asyncio.Lock() + self._refcount = 0 + is_publish_only = is_publish_only # Depracated + + async def connect(self): + async with self._connect_lock: + if self._refcount == 0: # TODO: Is that needed? + try: + self._broadcast_channel = self._broadcast_type(self._broadcast_url) + await self._broadcast_channel.connect() + except Exception as e: + logger.error( + f"Failed to connect to broadcast channel for reading incoming events: {e}" + ) + raise e + await self._subscribe_notifier() + self._subscription_task = asyncio.create_task( + self.__read_notifications__() + ) + self._refcount += 1 + + async def _close(self): + if self._broadcast_channel is not None: + await self._unsubscribe_notifier() + await self._broadcast_channel.disconnect() + await self.wait_until_done() + self._broadcast_channel = None + + async def close(self): + async with self._connect_lock: + if self._refcount == 0: + return + self._refcount -= 1 + if self._refcount == 0: + await self._close() + + async def __aenter__(self): + await self.connect() + + async def __aexit__(self, exc_type, exc, tb): + await self.close() async def __broadcast_notifications__(self, subscription: Subscription, data): """ @@ -174,11 +123,12 @@ async def __broadcast_notifications__(self, subscription: Subscription, data): {"topic": subscription.topic, "notifier_id": self._id} ) ) + note = BroadcastNotification( notifier_id=self._id, topics=[subscription.topic], data=data ) - # Publish event to broadcast + # Publish event to broadcast using a new connection from connection pool async with self._broadcast_type( self._broadcast_url ) as sharing_broadcast_channel: @@ -186,124 +136,76 @@ async def __broadcast_notifications__(self, subscription: Subscription, data): self._channel, pydantic_serialize(note) ) - async def _subscribe_to_all_topics(self): + async def _subscribe_notifier(self): return await self._notifier.subscribe( self._id, ALL_TOPICS, self.__broadcast_notifications__ ) - async def _unsubscribe_from_topics(self): + async def _unsubscribe_notifier(self): return await self._notifier.unsubscribe(self._id) def get_context(self, listen=True, share=True): - """ - Create a new context manager you can call 'async with' on, configuring the broadcaster for listening, sharing, or both. - - Args: - listen (bool, optional): Should we listen for events incoming from the broadcast channel. Defaults to True. - share (bool, optional): Should we share events with the broadcast channel. Defaults to True. - - Returns: - EventBroadcasterContextManager: the context - """ - return EventBroadcasterContextManager(self, listen=listen, share=share) + """Backward compatibility for the old interface""" + return self def get_listening_context(self): - return EventBroadcasterContextManager(self, listen=True, share=False) + """Backward compatibility for the old interface""" + return self def get_sharing_context(self): - return EventBroadcasterContextManager(self, listen=False, share=True) - - async def __aenter__(self): - """ - Convince caller (also backward compaltability) - """ - if self._context_manager is None: - self._context_manager = self.get_context(listen=not self._is_publish_only) - return await self._context_manager.__aenter__() - - async def __aexit__(self, exc_type, exc, tb): - await self._context_manager.__aexit__(exc_type, exc, tb) - - async def start_reader_task(self): - """Spawn a task reading incoming broadcasts and posting them to the intreal notifier - Raises: - BroadcasterAlreadyStarted: if called more than once per context - Returns: - the spawned task - """ - # Make sure a task wasn't started already - if self._subscription_task is not None: - # we already started a task for this worker process - logger.debug( - "No need for listen task, already started broadcast listen task for this notifier" - ) - return - - # Init new broadcast channel for reading - try: - if self.listening_broadcast_channel is None: - self.listening_broadcast_channel = self._broadcast_type( - self._broadcast_url - ) - await self.listening_broadcast_channel.connect() - except Exception as e: - logger.error( - f"Failed to connect to broadcast channel for reading incoming events: {e}" - ) - raise e - - # Trigger the task - logger.debug("Spawning broadcast listen task") - self._subscription_task = asyncio.create_task(self.__read_notifications__()) - return self._subscription_task + """Backward compatibility for the old interface""" + return self def get_reader_task(self): return self._subscription_task + async def wait_until_done(self): + if self._subscription_task is not None: + await self._subscription_task + self._subscription_task = None + async def __read_notifications__(self): """ read incoming broadcasts and posting them to the intreal notifier """ logger.debug("Starting broadcaster listener") + + notify_tasks = set() try: # Subscribe to our channel - async with self.listening_broadcast_channel.subscribe( + async with self._broadcast_channel.subscribe( channel=self._channel ) as subscriber: async for event in subscriber: - try: - notification = BroadcastNotification.parse_raw(event.message) - # Avoid re-publishing our own broadcasts - if notification.notifier_id != self._id: - logger.debug( - "Handling incoming broadcast event: {}".format( - { - "topics": notification.topics, - "src": notification.notifier_id, - } - ) + notification = BroadcastNotification.parse_raw(event.message) + # Avoid re-publishing our own broadcasts + if notification.notifier_id != self._id: + logger.debug( + "Handling incoming broadcast event: {}".format( + { + "topics": notification.topics, + "src": notification.notifier_id, + } ) - # Notify subscribers of message received from broadcast - task = asyncio.create_task( - self._notifier.notify( - notification.topics, - notification.data, - notifier_id=self._id, - ) + ) + # Notify subscribers of message received from broadcast + task = asyncio.create_task( + self._notifier.notify( + notification.topics, + notification.data, + notifier_id=self._id, ) + ) - self._tasks.add(task) + notify_tasks.add(task) - def cleanup(task): - self._tasks.remove(task) + def cleanup(t): + notify_tasks.remove(t) - task.add_done_callback(cleanup) - except: - logger.exception("Failed handling incoming broadcast") + task.add_done_callback(cleanup) logger.info( "No more events to read from subscriber (underlying connection closed)" ) finally: - if self.listening_broadcast_channel is not None: - await self.listening_broadcast_channel.disconnect() - self.listening_broadcast_channel = None + # TODO: return_exceptions? + await asyncio.gather(*notify_tasks, return_exceptions=True) diff --git a/fastapi_websocket_pubsub/pub_sub_server.py b/fastapi_websocket_pubsub/pub_sub_server.py index 83f3913..0a861a3 100644 --- a/fastapi_websocket_pubsub/pub_sub_server.py +++ b/fastapi_websocket_pubsub/pub_sub_server.py @@ -34,7 +34,7 @@ def __init__( on_connect: List[Coroutine] = None, on_disconnect: List[Coroutine] = None, rpc_channel_get_remote_id: bool = False, - ignore_broadcaster_disconnected = True, + ignore_broadcaster_disconnected: bool = True, ): """ The PubSub endpoint recives subscriptions from clients and publishes data back to them upon receiving relevant publications. @@ -104,7 +104,7 @@ async def publish(self, topics: Union[TopicList, Topic], data=None): logger.debug(f"Publishing message to topics: {topics}") if self.broadcaster is not None: logger.debug(f"Acquiring broadcaster sharing context") - async with self.broadcaster.get_context(listen=False, share=True): + async with self.broadcaster: await self.notifier.notify(topics, data, notifier_id=self._id) # otherwise just notify else: