From ffdcc8b414cc6fb1784b3aed810142eb1524ff3b Mon Sep 17 00:00:00 2001 From: Graeme Holliday Date: Mon, 9 Dec 2024 21:17:15 -0500 Subject: [PATCH] add streamer reconnection callbacks --- docs/account-streamer.rst | 22 ++++++++++ docs/data-streamer.rst | 19 +++++++++ tastytrade/streamer.py | 90 ++++++++++++++++++++++++++++++--------- 3 files changed, 112 insertions(+), 19 deletions(-) diff --git a/docs/account-streamer.rst b/docs/account-streamer.rst index 9b118b0..bb4f0be 100644 --- a/docs/account-streamer.rst +++ b/docs/account-streamer.rst @@ -1,6 +1,9 @@ Account Streamer ================ +Basic usage +----------- + The account streamer is used to track account-level updates, such as order fills, watchlist updates and quote alerts. Typically, you'll want a separate task running for the account streamer, which can then notify your application about important events. @@ -35,3 +38,22 @@ Probably the most important information the account streamer handles is order fi async for order in streamer.listen(PlacedOrder): print(order) + +Retry callback +-------------- + +The account streamer has a special "callback" function which can be used to execute arbitrary code whenever the websocket reconnects. This is useful for re-subscribing to whatever alerts you wanted to subscribe to initially (in fact, you can probably use the same function/code you use when initializing the connection). +The callback function should look something like this: + +.. code-block:: python + + async def callback(streamer: AlertStreamer, arg1, arg2): + await streamer.subscribe_quote_alerts() + +The requirements are that the first parameter be the `AlertStreamer` instance, and the function should be asynchronous. Other than that, you have the flexibility to decide what arguments you want to use. +This callback can then be used when creating the streamer: + +.. code-block:: python + + async with AlertStreamer(session, reconnect_fn=callback, reconnect_args=(arg1, arg2)) as streamer: + # ... diff --git a/docs/data-streamer.rst b/docs/data-streamer.rst index 4d39680..1353d74 100644 --- a/docs/data-streamer.rst +++ b/docs/data-streamer.rst @@ -146,3 +146,22 @@ Now, we can access the quotes and greeks at any time, and they'll be up-to-date print(live_prices.quotes[symbol], live_prices.greeks[symbol]) >>> Quote(eventSymbol='.SPY230721C387', eventTime=0, sequence=0, timeNanoPart=0, bidTime=1689365699000, bidExchangeCode='X', bidPrice=62.01, bidSize=50.0, askTime=1689365699000, askExchangeCode='X', askPrice=62.83, askSize=50.0) Greeks(eventSymbol='.SPY230721C387', eventTime=0, eventFlags=0, index=7255910303911641088, time=1689398266363, sequence=0, price=62.6049270064687, volatility=0.536152815048564, delta=0.971506591907638, gamma=0.001814464566110275, theta=-0.1440768557397271, rho=0.0831882577866199, vega=0.0436861878838861) + +Retry callback +-------------- + +The data streamer has a special "callback" function which can be used to execute arbitrary code whenever the websocket reconnects. This is useful for re-subscribing to whatever events you wanted to subscribe to initially (in fact, you can probably use the same function/code you use when initializing the connection). +The callback function should look something like this: + +.. code-block:: python + + async def callback(streamer: DXLinkStreamer, arg1, arg2): + await streamer.subscribe(Quote, ['SPY']) + +The requirements are that the first parameter be the `DXLinkStreamer` instance, and the function should be asynchronous. Other than that, you have the flexibility to decide what arguments you want to use. +This callback can then be used when creating the streamer: + +.. code-block:: python + + async with DXLinkStreamer(session, reconnect_fn=callback, reconnect_args=(arg1, arg2)) as streamer: + # ... diff --git a/tastytrade/streamer.py b/tastytrade/streamer.py index 6debe66..3361427 100644 --- a/tastytrade/streamer.py +++ b/tastytrade/streamer.py @@ -6,7 +6,16 @@ from decimal import Decimal from enum import Enum from ssl import SSLContext, create_default_context -from typing import Any, AsyncIterator, Optional, Type, TypeVar, Union +from typing import ( + Any, + AsyncIterator, + Callable, + Coroutine, + Optional, + Type, + TypeVar, + Union, +) from pydantic import model_validator from websockets.asyncio.client import ClientConnection, connect @@ -181,15 +190,26 @@ class AlertStreamer: """ - def __init__(self, session: Session): + def __init__( + self, + session: Session, + reconnect_args: tuple[Any, ...] = (), + reconnect_fn: Optional[Callable[..., Coroutine[Any, Any, None]]] = None, + ): #: The active session used to initiate the streamer or make requests self.token: str = session.session_token #: The base url for the streamer websocket self.base_url: str = CERT_STREAMER_URL if session.is_test else STREAMER_URL + #: An async function to be called upon reconnection. The first argument must be + #: of type `AlertStreamer` and will be a reference to the streamer object. + self.reconnect_fn = reconnect_fn + #: Variable number of arguments to pass to the reconnect function + self.reconnect_args = reconnect_args self._queues: dict[str, Queue] = defaultdict(Queue) self._websocket: Optional[ClientConnection] = None self._connect_task = asyncio.create_task(self._connect()) + self._reconnect_task = None async def __aenter__(self): time_out = 100 @@ -202,19 +222,27 @@ async def __aenter__(self): return self @classmethod - async def create(cls, session: Session) -> "AlertStreamer": - self = cls(session) + async def create( + cls, + session: Session, + *, + reconnect_args: tuple[Any, ...] = (), + reconnect_fn: Optional[Callable[..., Coroutine[Any, Any, None]]] = None, + ) -> "AlertStreamer": + self = cls(session, reconnect_args=reconnect_args, reconnect_fn=reconnect_fn) return await self.__aenter__() async def __aexit__(self, *exc): - await self.close() + self.close() - async def close(self): + def close(self): """ - Closes the websocket connection and cancels the heartbeat task. + Closes the websocket connection and cancels the pending tasks. """ self._connect_task.cancel() self._heartbeat_task.cancel() + if self._reconnect_task is not None: + self._reconnect_task.cancel() async def _connect(self) -> None: """ @@ -222,11 +250,16 @@ async def _connect(self) -> None: token provided during initialization. """ headers = {"Authorization": f"Bearer {self.token}"} + reconnecting = False async for websocket in connect(self.base_url, additional_headers=headers): self._websocket = websocket self._heartbeat_task = asyncio.create_task(self._heartbeat()) logger.debug("Websocket connection established.") + if reconnecting and self.reconnect_fn is not None: + self._reconnect_task = asyncio.create_task( + self.reconnect_fn(self, *self.reconnect_args) + ) try: async for raw_message in websocket: logger.debug("raw message: %s", raw_message) @@ -236,6 +269,7 @@ async def _connect(self) -> None: await self._map_message(type_str, data["data"]) except ConnectionClosed: logger.debug("Websocket connection closed, retrying...") + reconnecting = True continue async def listen(self, alert_class: Type[T]) -> AsyncIterator[T]: @@ -340,7 +374,11 @@ class DXLinkStreamer: """ def __init__( - self, session: Session, ssl_context: SSLContext = create_default_context() + self, + session: Session, + reconnect_args: tuple[Any, ...] = (), + reconnect_fn: Optional[Callable[..., Coroutine[Any, Any, None]]] = None, + ssl_context: SSLContext = create_default_context(), ): self._counter = 0 self._lock: Lock = Lock() @@ -357,8 +395,13 @@ def __init__( "Underlying": 17, } self._subscription_state: dict[str, str] = defaultdict(lambda: "CHANNEL_CLOSED") + #: An async function to be called upon reconnection. The first argument must be + #: of type `DXLinkStreamer` and will be a reference to the streamer object. + self.reconnect_fn = reconnect_fn + #: Variable number of arguments to pass to the reconnect function + self.reconnect_args = reconnect_args - #: The unique client identifier received from the server + # The unique client identifier received from the server self._session = session self._authenticated = False self._wss_url = session.dxlink_url @@ -366,6 +409,7 @@ def __init__( self._ssl_context = ssl_context self._connect_task = asyncio.create_task(self._connect()) + self._reconnect_task = None async def __aenter__(self): time_out = 100 @@ -379,36 +423,38 @@ async def __aenter__(self): @classmethod async def create( - cls, session: Session, ssl_context: SSLContext = create_default_context() + cls, + session: Session, + reconnect_fn: Optional[Callable[..., Coroutine[Any, Any, None]]] = None, + ssl_context: SSLContext = create_default_context(), ) -> "DXLinkStreamer": - self = cls(session, ssl_context=ssl_context) + self = cls(session, reconnect_fn=reconnect_fn, ssl_context=ssl_context) return await self.__aenter__() async def __aexit__(self, *exc): - await self.close() + self.close() - async def close(self): + def close(self): """ Closes the websocket connection and cancels the heartbeat task. """ self._connect_task.cancel() self._heartbeat_task.cancel() + if self._reconnect_task is not None: + self._reconnect_task.cancel() async def _connect(self) -> None: """ Connect to the websocket server using the URL and authorization token provided during initialization. """ - + reconnecting = False async for websocket in connect(self._wss_url, ssl=self._ssl_context): + self._websocket = websocket + await self._setup_connection() try: - self._websocket = websocket - await self._setup_connection() - - # main loop async for raw_message in websocket: message = json.loads(raw_message) - logger.debug("received: %s", message) if message["type"] == "SETUP": await self._authenticate_connection() @@ -419,6 +465,11 @@ async def _connect(self) -> None: self._heartbeat_task = asyncio.create_task( self._heartbeat() ) + # run reconnect hook upon auth completion + if reconnecting and self.reconnect_fn is not None: + self._reconnect_task = asyncio.create_task( + self.reconnect_fn(self, *self.reconnect_args) + ) elif message["type"] == "CHANNEL_OPENED": channel = next( k @@ -445,6 +496,7 @@ async def _connect(self) -> None: raise TastytradeError("Unknown message type:", message) except ConnectionClosed: logger.debug("Websocket connection closed, retrying...") + reconnecting = True continue async def _setup_connection(self):