Skip to content

Commit

Permalink
Refactor EventBroadcaster
Browse files Browse the repository at this point in the history
  • Loading branch information
roekatz committed Sep 4, 2024
1 parent c757bd7 commit 804da6a
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 202 deletions.
270 changes: 86 additions & 184 deletions fastapi_websocket_pubsub/event_broadcaster.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import contextlib
from typing import Any

from broadcaster import Broadcast
Expand Down Expand Up @@ -30,84 +31,6 @@ class BroadcasterAlreadyStarted(EventBroadcasterException):
pass


class EventBroadcasterContextManager:
"""
Manages the context for the EventBroadcaster
Friend-like class of EventBroadcaster (accessing "protected" members )
"""

def __init__(
self,
event_broadcaster: "EventBroadcaster",
listen: bool = True,
share: bool = True,
) -> None:
"""
Provide a context manager for an EventBroadcaster, managing if it listens to events coming from the broadcaster
and if it subscribes to the internal notifier to share its events with the broadcaster
Args:
event_broadcaster (EventBroadcaster): the broadcaster we manage the context for.
share (bool, optional): Should we share events with the broadcaster. Defaults to True.
listen (bool, optional): Should we listen for incoming events from the broadcaster. Defaults to True.
"""
self._event_broadcaster = event_broadcaster
self._share: bool = share
self._listen: bool = listen

async def __aenter__(self):
async with self._event_broadcaster._context_manager_lock:
if self._listen:
self._event_broadcaster._listen_count += 1
if self._event_broadcaster._listen_count == 1:
# We have our first listener start the read-task for it (And all those who'd follow)
logger.info(
"Listening for incoming events from broadcast channel (first listener started)"
)
# Start task listening on incoming broadcasts
await self._event_broadcaster.start_reader_task()

if self._share:
self._event_broadcaster._share_count += 1
if self._event_broadcaster._share_count == 1:
# We have our first publisher
# Init the broadcast used for sharing (reading has its own)
logger.debug(
"Subscribing to ALL_TOPICS, and sharing messages with broadcast channel"
)
# Subscribe to internal events form our own event notifier and broadcast them
await self._event_broadcaster._subscribe_to_all_topics()
else:
logger.debug(
f"Did not subscribe to ALL_TOPICS: share count == {self._event_broadcaster._share_count}"
)
return self

async def __aexit__(self, exc_type, exc, tb):
async with self._event_broadcaster._context_manager_lock:
try:
if self._listen:
self._event_broadcaster._listen_count -= 1
# if this was last listener - we can stop the reading task
if self._event_broadcaster._listen_count == 0:
# Cancel task reading broadcast subscriptions
if self._event_broadcaster._subscription_task is not None:
logger.info("Cancelling broadcast listen task")
self._event_broadcaster._subscription_task.cancel()
self._event_broadcaster._subscription_task = None

if self._share:
self._event_broadcaster._share_count -= 1
# if this was last sharer - we can stop subscribing to internal events - we aren't sharing anymore
if self._event_broadcaster._share_count == 0:
# Unsubscribe from internal events
logger.debug("Unsubscribing from ALL TOPICS")
await self._event_broadcaster._unsubscribe_from_topics()

except:
logger.exception("Failed to exit EventBroadcaster context")


class EventBroadcaster:
"""
Bridge EventNotifier to work across processes and machines by sharing their events through a broadcasting channel
Expand Down Expand Up @@ -135,31 +58,57 @@ def __init__(
notifier (EventNotifier): the event notifier managing our internal events - which will be bridge via the broadcaster
channel (str, optional): Channel name. Defaults to "EventNotifier".
broadcast_type (Broadcast, optional): Broadcast class to use. None - Defaults to Broadcast.
is_publish_only (bool, optional): [For default context] Should the broadcaster only transmit events and not listen to any. Defaults to False
DEPRECATED is_publish_only (bool, optional): [For default context] Should the broadcaster only transmit events and not listen to any. Defaults to False
# TODO: Like that?
"""
# Broadcast init params
self._broadcast_url = broadcast_url
self._broadcast_type = broadcast_type or Broadcast
# Publish broadcast (initialized within async with statement)
self._sharing_broadcast_channel = None
# channel to operate on
self._channel = channel
# Async-io task for reading broadcasts (initialized within async with statement)
self._subscription_task = None
# Uniqueue instance id (used to avoid reading own notifications sent in broadcast)
self._id = gen_uid()
# The internal events notifier
self._notifier = notifier
self._is_publish_only = is_publish_only
self._publish_lock = None
# used to track creation / removal of resources needed per type (reader task->listen, and subscription to internal events->share)
self._listen_count: int = 0
self._share_count: int = 0
# If we opt to manage the context directly (i.e. call async with on the event broadcaster itself)
self._context_manager = None
self._context_manager_lock = asyncio.Lock()
self._tasks = set()
self.listening_broadcast_channel = None
self._broadcast_channel = None
self._connect_lock = asyncio.Lock()
self._refcount = 0
is_publish_only = is_publish_only # Depracated

async def connect(self):
async with self._connect_lock:
if self._refcount == 0: # TODO: Is that needed?
try:
self._broadcast_channel = self._broadcast_type(self._broadcast_url)
await self._broadcast_channel.connect()
except Exception as e:
logger.error(
f"Failed to connect to broadcast channel for reading incoming events: {e}"
)
raise e
await self._subscribe_notifier()
self._subscription_task = asyncio.create_task(
self.__read_notifications__()
)
self._refcount += 1

async def _close(self):
if self._broadcast_channel is not None:
await self._unsubscribe_notifier()
await self._broadcast_channel.disconnect()
await self.wait_until_done()
self._broadcast_channel = None

async def close(self):
async with self._connect_lock:
if self._refcount == 0:
return
self._refcount -= 1
if self._refcount == 0:
await self._close()

async def __aenter__(self):
await self.connect()

async def __aexit__(self, exc_type, exc, tb):
await self.close()

async def __broadcast_notifications__(self, subscription: Subscription, data):
"""
Expand All @@ -174,136 +123,89 @@ async def __broadcast_notifications__(self, subscription: Subscription, data):
{"topic": subscription.topic, "notifier_id": self._id}
)
)

note = BroadcastNotification(
notifier_id=self._id, topics=[subscription.topic], data=data
)

# Publish event to broadcast
# Publish event to broadcast using a new connection from connection pool
async with self._broadcast_type(
self._broadcast_url
) as sharing_broadcast_channel:
await sharing_broadcast_channel.publish(
self._channel, pydantic_serialize(note)
)

async def _subscribe_to_all_topics(self):
async def _subscribe_notifier(self):
return await self._notifier.subscribe(
self._id, ALL_TOPICS, self.__broadcast_notifications__
)

async def _unsubscribe_from_topics(self):
async def _unsubscribe_notifier(self):
return await self._notifier.unsubscribe(self._id)

def get_context(self, listen=True, share=True):
"""
Create a new context manager you can call 'async with' on, configuring the broadcaster for listening, sharing, or both.
Args:
listen (bool, optional): Should we listen for events incoming from the broadcast channel. Defaults to True.
share (bool, optional): Should we share events with the broadcast channel. Defaults to True.
Returns:
EventBroadcasterContextManager: the context
"""
return EventBroadcasterContextManager(self, listen=listen, share=share)
"""Backward compatibility for the old interface"""
return self

def get_listening_context(self):
return EventBroadcasterContextManager(self, listen=True, share=False)
"""Backward compatibility for the old interface"""
return self

def get_sharing_context(self):
return EventBroadcasterContextManager(self, listen=False, share=True)

async def __aenter__(self):
"""
Convince caller (also backward compaltability)
"""
if self._context_manager is None:
self._context_manager = self.get_context(listen=not self._is_publish_only)
return await self._context_manager.__aenter__()

async def __aexit__(self, exc_type, exc, tb):
await self._context_manager.__aexit__(exc_type, exc, tb)

async def start_reader_task(self):
"""Spawn a task reading incoming broadcasts and posting them to the intreal notifier
Raises:
BroadcasterAlreadyStarted: if called more than once per context
Returns:
the spawned task
"""
# Make sure a task wasn't started already
if self._subscription_task is not None:
# we already started a task for this worker process
logger.debug(
"No need for listen task, already started broadcast listen task for this notifier"
)
return

# Init new broadcast channel for reading
try:
if self.listening_broadcast_channel is None:
self.listening_broadcast_channel = self._broadcast_type(
self._broadcast_url
)
await self.listening_broadcast_channel.connect()
except Exception as e:
logger.error(
f"Failed to connect to broadcast channel for reading incoming events: {e}"
)
raise e

# Trigger the task
logger.debug("Spawning broadcast listen task")
self._subscription_task = asyncio.create_task(self.__read_notifications__())
return self._subscription_task
"""Backward compatibility for the old interface"""
return self

def get_reader_task(self):
return self._subscription_task

async def wait_until_done(self):
if self._subscription_task is not None:
await self._subscription_task
self._subscription_task = None

async def __read_notifications__(self):
"""
read incoming broadcasts and posting them to the intreal notifier
"""
logger.debug("Starting broadcaster listener")

notify_tasks = set()
try:
# Subscribe to our channel
async with self.listening_broadcast_channel.subscribe(
async with self._broadcast_channel.subscribe(
channel=self._channel
) as subscriber:
async for event in subscriber:
try:
notification = BroadcastNotification.parse_raw(event.message)
# Avoid re-publishing our own broadcasts
if notification.notifier_id != self._id:
logger.debug(
"Handling incoming broadcast event: {}".format(
{
"topics": notification.topics,
"src": notification.notifier_id,
}
)
notification = BroadcastNotification.parse_raw(event.message)
# Avoid re-publishing our own broadcasts
if notification.notifier_id != self._id:
logger.debug(
"Handling incoming broadcast event: {}".format(
{
"topics": notification.topics,
"src": notification.notifier_id,
}
)
# Notify subscribers of message received from broadcast
task = asyncio.create_task(
self._notifier.notify(
notification.topics,
notification.data,
notifier_id=self._id,
)
)
# Notify subscribers of message received from broadcast
task = asyncio.create_task(
self._notifier.notify(
notification.topics,
notification.data,
notifier_id=self._id,
)
)

self._tasks.add(task)
notify_tasks.add(task)

def cleanup(task):
self._tasks.remove(task)
def cleanup(t):
notify_tasks.remove(t)

task.add_done_callback(cleanup)
except:
logger.exception("Failed handling incoming broadcast")
task.add_done_callback(cleanup)
logger.info(
"No more events to read from subscriber (underlying connection closed)"
)
finally:
if self.listening_broadcast_channel is not None:
await self.listening_broadcast_channel.disconnect()
self.listening_broadcast_channel = None
# TODO: return_exceptions?
await asyncio.gather(*notify_tasks, return_exceptions=True)
34 changes: 16 additions & 18 deletions fastapi_websocket_pubsub/pub_sub_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
on_connect: List[Coroutine] = None,
on_disconnect: List[Coroutine] = None,
rpc_channel_get_remote_id: bool = False,
ignore_broadcaster_disconnected = True,
ignore_broadcaster_disconnected=True,
):
"""
The PubSub endpoint recives subscriptions from clients and publishes data back to them upon receiving relevant publications.
Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(
broadcaster
if isinstance(broadcaster, EventBroadcaster) or broadcaster is None
else EventBroadcaster(broadcaster, self.notifier)
)
) # TODO: Connect broadcaster if needed
self.methods = (
methods_class(self.notifier)
if methods_class is not None
Expand Down Expand Up @@ -128,22 +128,20 @@ async def on_disconnect(self, channel: RpcChannel):
await self.notifier.unsubscribe(subscriber_id)

async def main_loop(self, websocket: WebSocket, client_id: str = None, **kwargs):
if self.broadcaster is not None:
async with self.broadcaster:
logger.debug("Entering endpoint's main loop with broadcaster")
if self._ignore_broadcaster_disconnected:
await self.endpoint.main_loop(websocket, client_id=client_id, **kwargs)
else:
main_loop_task = asyncio.create_task(
self.endpoint.main_loop(websocket, client_id=client_id, **kwargs)
)
done, pending = await asyncio.wait([main_loop_task,
self.broadcaster.get_reader_task()],
return_when=asyncio.FIRST_COMPLETED)
logger.debug(f"task is done: {done}")
# broadcaster's reader task is used by other endpoints and shouldn't be cancelled
if main_loop_task in pending:
main_loop_task.cancel()
# TODO: Maybe just connect the broadcaster in a single place? then raise an event.
if self.broadcaster is not None and not self._ignore_broadcaster_disconnected:
logger.debug("Entering endpoint's main loop with broadcaster")
main_loop_task = asyncio.create_task(
self.endpoint.main_loop(websocket, client_id=client_id, **kwargs)
)
done, pending = await asyncio.wait(
[main_loop_task, self.broadcaster.get_reader_task()],
return_when=asyncio.FIRST_COMPLETED,
)
logger.debug(f"task is done: {done}")
# broadcaster's reader task is used by other endpoints and shouldn't be cancelled
if main_loop_task in pending:
main_loop_task.cancel()
else:
logger.debug("Entering endpoint's main loop without broadcaster")
await self.endpoint.main_loop(websocket, client_id=client_id, **kwargs)
Expand Down

0 comments on commit 804da6a

Please sign in to comment.