From ce5ef1d23a0d6f0f7cdb015d3d7dd503577474d5 Mon Sep 17 00:00:00 2001 From: Graeme Holliday Date: Tue, 10 Dec 2024 14:54:49 -0500 Subject: [PATCH] add streamer reconnection callbacks (#186) * add streamer reconnection callbacks * add reconnect args to create() * change create to __await__ --- docs/account-streamer.rst | 24 +++++++ docs/data-streamer.rst | 25 ++++++- pyproject.toml | 4 +- tastytrade/streamer.py | 137 +++++++++++++++++++++++++------------- tests/test_streamer.py | 36 ++++++++++ uv.lock | 10 +-- 6 files changed, 180 insertions(+), 56 deletions(-) diff --git a/docs/account-streamer.rst b/docs/account-streamer.rst index 9b118b0..e8b97c3 100644 --- a/docs/account-streamer.rst +++ b/docs/account-streamer.rst @@ -1,6 +1,9 @@ Account Streamer ================ +Basic usage +----------- + The account streamer is used to track account-level updates, such as order fills, watchlist updates and quote alerts. Typically, you'll want a separate task running for the account streamer, which can then notify your application about important events. @@ -35,3 +38,24 @@ Probably the most important information the account streamer handles is order fi async for order in streamer.listen(PlacedOrder): print(order) + +Retry callback +-------------- + +The account streamer has a special "callback" function which can be used to execute arbitrary code whenever the websocket reconnects. This is useful for re-subscribing to whatever alerts you wanted to subscribe to initially (in fact, you can probably use the same function/code you use when initializing the connection). +The callback function should look something like this: + +.. code-block:: python + + async def callback(streamer: AlertStreamer, arg1, arg2): + await streamer.subscribe_quote_alerts() + +The requirements are that the first parameter be the `AlertStreamer` instance, and the function should be asynchronous. Other than that, you have the flexibility to decide what arguments you want to use. +This callback can then be used when creating the streamer: + +.. code-block:: python + + async with AlertStreamer(session, reconnect_fn=callback, reconnect_args=(arg1, arg2)) as streamer: + # ... + +The reconnection uses `websockets`' exponential backoff algorithm, which can be configured through environment variables `here `_. diff --git a/docs/data-streamer.rst b/docs/data-streamer.rst index 4d39680..4f85755 100644 --- a/docs/data-streamer.rst +++ b/docs/data-streamer.rst @@ -10,7 +10,7 @@ You can create a streamer using an active production session: .. code-block:: python from tastytrade import DXLinkStreamer - streamer = await DXLinkStreamer.create(session) + streamer = await DXLinkStreamer(session) Or, you can create a streamer using an asynchronous context manager: @@ -110,7 +110,7 @@ For example, we can use the streamer to create an option chain that will continu # the `streamer_symbol` property is the symbol used by the streamer streamer_symbols = [o.streamer_symbol for o in options] - streamer = await DXLinkStreamer.create(session) + streamer = await DXLinkStreamer(session) # subscribe to quotes and greeks for all options on that date await streamer.subscribe(Quote, [symbol] + streamer_symbols) await streamer.subscribe(Greeks, streamer_symbols) @@ -146,3 +146,24 @@ Now, we can access the quotes and greeks at any time, and they'll be up-to-date print(live_prices.quotes[symbol], live_prices.greeks[symbol]) >>> Quote(eventSymbol='.SPY230721C387', eventTime=0, sequence=0, timeNanoPart=0, bidTime=1689365699000, bidExchangeCode='X', bidPrice=62.01, bidSize=50.0, askTime=1689365699000, askExchangeCode='X', askPrice=62.83, askSize=50.0) Greeks(eventSymbol='.SPY230721C387', eventTime=0, eventFlags=0, index=7255910303911641088, time=1689398266363, sequence=0, price=62.6049270064687, volatility=0.536152815048564, delta=0.971506591907638, gamma=0.001814464566110275, theta=-0.1440768557397271, rho=0.0831882577866199, vega=0.0436861878838861) + +Retry callback +-------------- + +The data streamer has a special "callback" function which can be used to execute arbitrary code whenever the websocket reconnects. This is useful for re-subscribing to whatever events you wanted to subscribe to initially (in fact, you can probably use the same function/code you use when initializing the connection). +The callback function should look something like this: + +.. code-block:: python + + async def callback(streamer: DXLinkStreamer, arg1, arg2): + await streamer.subscribe(Quote, ['SPY']) + +The requirements are that the first parameter be the `DXLinkStreamer` instance, and the function should be asynchronous. Other than that, you have the flexibility to decide what arguments you want to use. +This callback can then be used when creating the streamer: + +.. code-block:: python + + async with DXLinkStreamer(session, reconnect_fn=callback, reconnect_args=(arg1, arg2)) as streamer: + # ... + +The reconnection uses `websockets`' exponential backoff algorithm, which can be configured through environment variables `here `_. diff --git a/pyproject.toml b/pyproject.toml index 109533f..7054dc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "httpx>=0.27.2", "pandas-market-calendars>=4.4.1", "pydantic>=2.9.2", - "websockets>=14.1", + "websockets>=14.1,<15", ] [project.urls] @@ -29,7 +29,7 @@ dev-dependencies = [ "pytest-aio>=1.5.0", "pytest-cov>=5.0.0", "ruff>=0.6.9", - "pyright>=1.1.384", + "pyright>=1.1.390", ] [tool.setuptools.package-data] diff --git a/tastytrade/streamer.py b/tastytrade/streamer.py index 6debe66..9d40238 100644 --- a/tastytrade/streamer.py +++ b/tastytrade/streamer.py @@ -1,12 +1,21 @@ import asyncio import json -from asyncio import Lock, Queue, QueueEmpty +from asyncio import Queue, QueueEmpty from collections import defaultdict from datetime import datetime from decimal import Decimal from enum import Enum from ssl import SSLContext, create_default_context -from typing import Any, AsyncIterator, Optional, Type, TypeVar, Union +from typing import ( + Any, + AsyncIterator, + Callable, + Coroutine, + Optional, + Type, + TypeVar, + Union, +) from pydantic import model_validator from websockets.asyncio.client import ClientConnection, connect @@ -158,8 +167,8 @@ class AlertStreamer: """ Used to subscribe to account-level updates (balances, orders, positions), public watchlist updates, quote alerts, and user-level messages. It should - always be initialized as an async context manager, or with the `create` - function, since the object cannot be fully instantiated without async. + always be initialized as an async context manager, or by awaiting it, + since the object cannot be fully instantiated without async. Example usage:: @@ -179,17 +188,32 @@ class AlertStreamer: async for order in streamer.listen(PlacedOrder): print(order) + Or:: + + streamer = await AlertStreamer(session) + """ - def __init__(self, session: Session): + def __init__( + self, + session: Session, + reconnect_args: tuple[Any, ...] = (), + reconnect_fn: Optional[Callable[..., Coroutine[Any, Any, None]]] = None, + ): #: The active session used to initiate the streamer or make requests self.token: str = session.session_token #: The base url for the streamer websocket self.base_url: str = CERT_STREAMER_URL if session.is_test else STREAMER_URL + #: An async function to be called upon reconnection. The first argument must be + #: of type `AlertStreamer` and will be a reference to the streamer object. + self.reconnect_fn = reconnect_fn + #: Variable number of arguments to pass to the reconnect function + self.reconnect_args = reconnect_args self._queues: dict[str, Queue] = defaultdict(Queue) self._websocket: Optional[ClientConnection] = None self._connect_task = asyncio.create_task(self._connect()) + self._reconnect_task = None async def __aenter__(self): time_out = 100 @@ -201,20 +225,20 @@ async def __aenter__(self): return self - @classmethod - async def create(cls, session: Session) -> "AlertStreamer": - self = cls(session) - return await self.__aenter__() + def __await__(self): + return self.__aenter__().__await__() async def __aexit__(self, *exc): - await self.close() + self.close() - async def close(self): + def close(self) -> None: """ - Closes the websocket connection and cancels the heartbeat task. + Closes the websocket connection and cancels the pending tasks. """ self._connect_task.cancel() self._heartbeat_task.cancel() + if self._reconnect_task is not None: + self._reconnect_task.cancel() async def _connect(self) -> None: """ @@ -222,11 +246,16 @@ async def _connect(self) -> None: token provided during initialization. """ headers = {"Authorization": f"Bearer {self.token}"} + reconnecting = False async for websocket in connect(self.base_url, additional_headers=headers): self._websocket = websocket self._heartbeat_task = asyncio.create_task(self._heartbeat()) logger.debug("Websocket connection established.") + if reconnecting and self.reconnect_fn is not None: + self._reconnect_task = asyncio.create_task( + self.reconnect_fn(self, *self.reconnect_args) + ) try: async for raw_message in websocket: logger.debug("raw message: %s", raw_message) @@ -234,9 +263,10 @@ async def _connect(self) -> None: type_str = data.get("type") if type_str is not None: await self._map_message(type_str, data["data"]) - except ConnectionClosed: - logger.debug("Websocket connection closed, retrying...") - continue + except ConnectionClosed as e: + logger.error(f"Websocket connection closed with {e}") + logger.debug("Websocket connection closed, retrying...") + reconnecting = True async def listen(self, alert_class: Type[T]) -> AsyncIterator[T]: """ @@ -252,7 +282,7 @@ async def listen(self, alert_class: Type[T]) -> AsyncIterator[T]: while True: yield await self._queues[cls_str].get() - async def _map_message(self, type_str: str, data: dict): + async def _map_message(self, type_str: str, data: dict) -> None: """ I'm not sure what the user-status messages look like, so they're absent. """ @@ -322,8 +352,8 @@ class DXLinkStreamer: """ A :class:`DXLinkStreamer` object is used to fetch quotes or greeks for a given symbol or list of symbols. It should always be initialized as an - async context manager, or with the `create` function, since the object - cannot be fully instantiated without async. + async context manager, or by awaiting it, since the object cannot be + fully instantiated without async. Example usage:: @@ -337,13 +367,19 @@ class DXLinkStreamer: quote = await streamer.get_event(Quote) print(quote) + Or:: + + streamer = await DXLinkStreamer(session) + """ def __init__( - self, session: Session, ssl_context: SSLContext = create_default_context() + self, + session: Session, + reconnect_args: tuple[Any, ...] = (), + reconnect_fn: Optional[Callable[..., Coroutine[Any, Any, None]]] = None, + ssl_context: SSLContext = create_default_context(), ): - self._counter = 0 - self._lock: Lock = Lock() self._queues: dict[str, Queue] = defaultdict(Queue) self._channels: dict[str, int] = { "Candle": 1, @@ -357,17 +393,21 @@ def __init__( "Underlying": 17, } self._subscription_state: dict[str, str] = defaultdict(lambda: "CHANNEL_CLOSED") + #: An async function to be called upon reconnection. The first argument must be + #: of type `DXLinkStreamer` and will be a reference to the streamer object. + self.reconnect_fn = reconnect_fn + #: Variable number of arguments to pass to the reconnect function + self.reconnect_args = reconnect_args - #: The unique client identifier received from the server self._session = session self._authenticated = False self._wss_url = session.dxlink_url self._auth_token = session.streamer_token self._ssl_context = ssl_context - - self._connect_task = asyncio.create_task(self._connect()) + self._reconnect_task = None async def __aenter__(self): + self._connect_task = asyncio.create_task(self._connect()) time_out = 100 while not self._authenticated: await asyncio.sleep(0.1) @@ -377,38 +417,33 @@ async def __aenter__(self): return self - @classmethod - async def create( - cls, session: Session, ssl_context: SSLContext = create_default_context() - ) -> "DXLinkStreamer": - self = cls(session, ssl_context=ssl_context) - return await self.__aenter__() + def __await__(self): + return self.__aenter__().__await__() async def __aexit__(self, *exc): - await self.close() + self.close() - async def close(self): + def close(self) -> None: """ Closes the websocket connection and cancels the heartbeat task. """ self._connect_task.cancel() self._heartbeat_task.cancel() + if self._reconnect_task is not None: + self._reconnect_task.cancel() async def _connect(self) -> None: """ Connect to the websocket server using the URL and authorization token provided during initialization. """ - + reconnecting = False async for websocket in connect(self._wss_url, ssl=self._ssl_context): + self._websocket = websocket + await self._setup_connection() try: - self._websocket = websocket - await self._setup_connection() - - # main loop async for raw_message in websocket: message = json.loads(raw_message) - logger.debug("received: %s", message) if message["type"] == "SETUP": await self._authenticate_connection() @@ -419,13 +454,20 @@ async def _connect(self) -> None: self._heartbeat_task = asyncio.create_task( self._heartbeat() ) + # run reconnect hook upon auth completion + if reconnecting and self.reconnect_fn is not None: + self._subscription_state.clear() + reconnecting = False + self._reconnect_task = asyncio.create_task( + self.reconnect_fn(self, *self.reconnect_args) + ) elif message["type"] == "CHANNEL_OPENED": channel = next( k for k, v in self._channels.items() if v == message["channel"] ) - self._subscription_state[channel] = message["type"] + self._subscription_state[channel] = "CHANNEL_OPENED" logger.debug("Channel opened: %s", message) elif message["type"] == "CHANNEL_CLOSED": channel = next( @@ -433,7 +475,7 @@ async def _connect(self) -> None: for k, v in self._channels.items() if v == message["channel"] ) - self._subscription_state[channel] = message["type"] + del self._subscription_state[channel] logger.debug("Channel closed: %s", message) elif message["type"] == "FEED_CONFIG": logger.debug("Feed configured: %s", message) @@ -442,12 +484,13 @@ async def _connect(self) -> None: elif message["type"] == "KEEPALIVE": pass else: - raise TastytradeError("Unknown message type:", message) - except ConnectionClosed: - logger.debug("Websocket connection closed, retrying...") - continue + logger.error(f"Streamer error: {message}") + except ConnectionClosed as e: + logger.error(f"Websocket connection closed with {e}") + logger.debug("Websocket connection closed, retrying...") + reconnecting = True - async def _setup_connection(self): + async def _setup_connection(self) -> None: message = { "type": "SETUP", "channel": 0, @@ -457,7 +500,7 @@ async def _setup_connection(self): } await self._websocket.send(json.dumps(message)) - async def _authenticate_connection(self): + async def _authenticate_connection(self) -> None: message = { "type": "AUTH", "channel": 0, @@ -687,7 +730,7 @@ async def unsubscribe_candle( } await self._websocket.send(json.dumps(message)) - async def _map_message(self, message) -> None: # pragma: no cover + async def _map_message(self, message) -> None: """ Takes the raw JSON data, parses the events and places them into their respective queues. diff --git a/tests/test_streamer.py b/tests/test_streamer.py index fb7747e..b59024d 100644 --- a/tests/test_streamer.py +++ b/tests/test_streamer.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime, timedelta from tastytrade import Account, AlertStreamer, DXLinkStreamer @@ -29,3 +30,38 @@ async def test_dxlink_streamer(session): await streamer.unsubscribe_candle(subs[0], "1d") await streamer.unsubscribe(Quote, [subs[0]]) await streamer.unsubscribe_all(Quote) + + +async def reconnect_alerts(streamer: AlertStreamer, ref: dict[str, bool]): + await streamer.subscribe_quote_alerts() + ref["test"] = True + + +async def test_account_streamer_reconnect(session): + ref = {} + streamer = await AlertStreamer( + session, reconnect_args=(ref,), reconnect_fn=reconnect_alerts + ) + await streamer.subscribe_public_watchlists() + await streamer.subscribe_user_messages(session) + accounts = Account.get_accounts(session) + await streamer.subscribe_accounts(accounts) + await streamer._websocket.close() # type: ignore + await asyncio.sleep(3) + assert "test" in ref + streamer.close() + + +async def reconnect_trades(streamer: DXLinkStreamer): + await streamer.subscribe(Trade, ["SPX"]) + + +async def test_dxlink_streamer_reconnect(session): + streamer = await DXLinkStreamer(session, reconnect_fn=reconnect_trades) + await streamer.subscribe(Quote, ["SPY"]) + _ = await streamer.get_event(Quote) + await streamer._websocket.close() + await asyncio.sleep(3) + trade = await streamer.get_event(Trade) + assert trade.event_symbol == "SPX" + streamer.close() diff --git a/uv.lock b/uv.lock index 27ec446..86315c8 100644 --- a/uv.lock +++ b/uv.lock @@ -427,15 +427,15 @@ wheels = [ [[package]] name = "pyright" -version = "1.1.384" +version = "1.1.390" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodeenv" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/84/00/a23114619f9d005f4b0f35e037c76cee029174d090a6f73a355749c74f4a/pyright-1.1.384.tar.gz", hash = "sha256:25e54d61f55cbb45f1195ff89c488832d7a45d59f3e132f178fdf9ef6cafc706", size = 21956 } +sdist = { url = "https://files.pythonhosted.org/packages/ba/42/1e0392f35dd275f9f775baf7c86407cef7f6a0d9b8e099a93e5422a7e571/pyright-1.1.390.tar.gz", hash = "sha256:aad7f160c49e0fbf8209507a15e17b781f63a86a1facb69ca877c71ef2e9538d", size = 21950 } wheels = [ - { url = "https://files.pythonhosted.org/packages/6d/4a/e7f4d71d194ba675f3577d11eebe4e17a592c4d1c3f9986d4b321ba3c809/pyright-1.1.384-py3-none-any.whl", hash = "sha256:f0b6f4db2da38f27aeb7035c26192f034587875f751b847e9ad42ed0c704ac9e", size = 18578 }, + { url = "https://files.pythonhosted.org/packages/43/20/3f492ca789fb17962ad23619959c7fa642082621751514296c58de3bb801/pyright-1.1.390-py3-none-any.whl", hash = "sha256:ecebfba5b6b50af7c1a44c2ba144ba2ab542c227eb49bc1f16984ff714e0e110", size = 18579 }, ] [[package]] @@ -569,12 +569,12 @@ requires-dist = [ { name = "httpx", specifier = ">=0.27.2" }, { name = "pandas-market-calendars", specifier = ">=4.4.1" }, { name = "pydantic", specifier = ">=2.9.2" }, - { name = "websockets", specifier = ">=14.1" }, + { name = "websockets", specifier = ">=14.1,<15" }, ] [package.metadata.requires-dev] dev = [ - { name = "pyright", specifier = ">=1.1.384" }, + { name = "pyright", specifier = ">=1.1.390" }, { name = "pytest", specifier = ">=8.3.3" }, { name = "pytest-aio", specifier = ">=1.5.0" }, { name = "pytest-cov", specifier = ">=5.0.0" },