Skip to content

Commit

Permalink
add streamer reconnection callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
Graeme22 committed Dec 10, 2024
1 parent b7e88e2 commit ffdcc8b
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 19 deletions.
22 changes: 22 additions & 0 deletions docs/account-streamer.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
Account Streamer
================

Basic usage
-----------

The account streamer is used to track account-level updates, such as order fills, watchlist updates and quote alerts.
Typically, you'll want a separate task running for the account streamer, which can then notify your application about important events.

Expand Down Expand Up @@ -35,3 +38,22 @@ Probably the most important information the account streamer handles is order fi
async for order in streamer.listen(PlacedOrder):
print(order)
Retry callback
--------------

The account streamer has a special "callback" function which can be used to execute arbitrary code whenever the websocket reconnects. This is useful for re-subscribing to whatever alerts you wanted to subscribe to initially (in fact, you can probably use the same function/code you use when initializing the connection).
The callback function should look something like this:

.. code-block:: python
async def callback(streamer: AlertStreamer, arg1, arg2):
await streamer.subscribe_quote_alerts()
The requirements are that the first parameter be the `AlertStreamer` instance, and the function should be asynchronous. Other than that, you have the flexibility to decide what arguments you want to use.
This callback can then be used when creating the streamer:

.. code-block:: python
async with AlertStreamer(session, reconnect_fn=callback, reconnect_args=(arg1, arg2)) as streamer:
# ...
19 changes: 19 additions & 0 deletions docs/data-streamer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,22 @@ Now, we can access the quotes and greeks at any time, and they'll be up-to-date
print(live_prices.quotes[symbol], live_prices.greeks[symbol])
>>> Quote(eventSymbol='.SPY230721C387', eventTime=0, sequence=0, timeNanoPart=0, bidTime=1689365699000, bidExchangeCode='X', bidPrice=62.01, bidSize=50.0, askTime=1689365699000, askExchangeCode='X', askPrice=62.83, askSize=50.0) Greeks(eventSymbol='.SPY230721C387', eventTime=0, eventFlags=0, index=7255910303911641088, time=1689398266363, sequence=0, price=62.6049270064687, volatility=0.536152815048564, delta=0.971506591907638, gamma=0.001814464566110275, theta=-0.1440768557397271, rho=0.0831882577866199, vega=0.0436861878838861)

Retry callback
--------------

The data streamer has a special "callback" function which can be used to execute arbitrary code whenever the websocket reconnects. This is useful for re-subscribing to whatever events you wanted to subscribe to initially (in fact, you can probably use the same function/code you use when initializing the connection).
The callback function should look something like this:

.. code-block:: python
async def callback(streamer: DXLinkStreamer, arg1, arg2):
await streamer.subscribe(Quote, ['SPY'])
The requirements are that the first parameter be the `DXLinkStreamer` instance, and the function should be asynchronous. Other than that, you have the flexibility to decide what arguments you want to use.
This callback can then be used when creating the streamer:

.. code-block:: python
async with DXLinkStreamer(session, reconnect_fn=callback, reconnect_args=(arg1, arg2)) as streamer:
# ...
90 changes: 71 additions & 19 deletions tastytrade/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@
from decimal import Decimal
from enum import Enum
from ssl import SSLContext, create_default_context
from typing import Any, AsyncIterator, Optional, Type, TypeVar, Union
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
Optional,
Type,
TypeVar,
Union,
)

from pydantic import model_validator
from websockets.asyncio.client import ClientConnection, connect
Expand Down Expand Up @@ -181,15 +190,26 @@ class AlertStreamer:
"""

def __init__(self, session: Session):
def __init__(
self,
session: Session,
reconnect_args: tuple[Any, ...] = (),
reconnect_fn: Optional[Callable[..., Coroutine[Any, Any, None]]] = None,
):
#: The active session used to initiate the streamer or make requests
self.token: str = session.session_token
#: The base url for the streamer websocket
self.base_url: str = CERT_STREAMER_URL if session.is_test else STREAMER_URL
#: An async function to be called upon reconnection. The first argument must be
#: of type `AlertStreamer` and will be a reference to the streamer object.
self.reconnect_fn = reconnect_fn
#: Variable number of arguments to pass to the reconnect function
self.reconnect_args = reconnect_args

self._queues: dict[str, Queue] = defaultdict(Queue)
self._websocket: Optional[ClientConnection] = None
self._connect_task = asyncio.create_task(self._connect())
self._reconnect_task = None

async def __aenter__(self):
time_out = 100
Expand All @@ -202,31 +222,44 @@ async def __aenter__(self):
return self

@classmethod
async def create(cls, session: Session) -> "AlertStreamer":
self = cls(session)
async def create(
cls,
session: Session,
*,
reconnect_args: tuple[Any, ...] = (),
reconnect_fn: Optional[Callable[..., Coroutine[Any, Any, None]]] = None,
) -> "AlertStreamer":
self = cls(session, reconnect_args=reconnect_args, reconnect_fn=reconnect_fn)
return await self.__aenter__()

async def __aexit__(self, *exc):
await self.close()
self.close()

async def close(self):
def close(self):
"""
Closes the websocket connection and cancels the heartbeat task.
Closes the websocket connection and cancels the pending tasks.
"""
self._connect_task.cancel()
self._heartbeat_task.cancel()
if self._reconnect_task is not None:
self._reconnect_task.cancel()

async def _connect(self) -> None:
"""
Connect to the websocket server using the URL and authorization
token provided during initialization.
"""
headers = {"Authorization": f"Bearer {self.token}"}
reconnecting = False
async for websocket in connect(self.base_url, additional_headers=headers):
self._websocket = websocket
self._heartbeat_task = asyncio.create_task(self._heartbeat())
logger.debug("Websocket connection established.")

if reconnecting and self.reconnect_fn is not None:
self._reconnect_task = asyncio.create_task(
self.reconnect_fn(self, *self.reconnect_args)
)
try:
async for raw_message in websocket:
logger.debug("raw message: %s", raw_message)
Expand All @@ -236,6 +269,7 @@ async def _connect(self) -> None:
await self._map_message(type_str, data["data"])
except ConnectionClosed:
logger.debug("Websocket connection closed, retrying...")
reconnecting = True
continue

async def listen(self, alert_class: Type[T]) -> AsyncIterator[T]:
Expand Down Expand Up @@ -340,7 +374,11 @@ class DXLinkStreamer:
"""

def __init__(
self, session: Session, ssl_context: SSLContext = create_default_context()
self,
session: Session,
reconnect_args: tuple[Any, ...] = (),
reconnect_fn: Optional[Callable[..., Coroutine[Any, Any, None]]] = None,
ssl_context: SSLContext = create_default_context(),
):
self._counter = 0
self._lock: Lock = Lock()
Expand All @@ -357,15 +395,21 @@ def __init__(
"Underlying": 17,
}
self._subscription_state: dict[str, str] = defaultdict(lambda: "CHANNEL_CLOSED")
#: An async function to be called upon reconnection. The first argument must be
#: of type `DXLinkStreamer` and will be a reference to the streamer object.
self.reconnect_fn = reconnect_fn
#: Variable number of arguments to pass to the reconnect function
self.reconnect_args = reconnect_args

#: The unique client identifier received from the server
# The unique client identifier received from the server
self._session = session
self._authenticated = False
self._wss_url = session.dxlink_url
self._auth_token = session.streamer_token
self._ssl_context = ssl_context

self._connect_task = asyncio.create_task(self._connect())
self._reconnect_task = None

async def __aenter__(self):
time_out = 100
Expand All @@ -379,36 +423,38 @@ async def __aenter__(self):

@classmethod
async def create(
cls, session: Session, ssl_context: SSLContext = create_default_context()
cls,
session: Session,
reconnect_fn: Optional[Callable[..., Coroutine[Any, Any, None]]] = None,
ssl_context: SSLContext = create_default_context(),
) -> "DXLinkStreamer":
self = cls(session, ssl_context=ssl_context)
self = cls(session, reconnect_fn=reconnect_fn, ssl_context=ssl_context)
return await self.__aenter__()

async def __aexit__(self, *exc):
await self.close()
self.close()

async def close(self):
def close(self):
"""
Closes the websocket connection and cancels the heartbeat task.
"""
self._connect_task.cancel()
self._heartbeat_task.cancel()
if self._reconnect_task is not None:
self._reconnect_task.cancel()

async def _connect(self) -> None:
"""
Connect to the websocket server using the URL and
authorization token provided during initialization.
"""

reconnecting = False
async for websocket in connect(self._wss_url, ssl=self._ssl_context):
self._websocket = websocket
await self._setup_connection()
try:
self._websocket = websocket
await self._setup_connection()

# main loop
async for raw_message in websocket:
message = json.loads(raw_message)

logger.debug("received: %s", message)
if message["type"] == "SETUP":
await self._authenticate_connection()
Expand All @@ -419,6 +465,11 @@ async def _connect(self) -> None:
self._heartbeat_task = asyncio.create_task(
self._heartbeat()
)
# run reconnect hook upon auth completion
if reconnecting and self.reconnect_fn is not None:
self._reconnect_task = asyncio.create_task(
self.reconnect_fn(self, *self.reconnect_args)
)
elif message["type"] == "CHANNEL_OPENED":
channel = next(
k
Expand All @@ -445,6 +496,7 @@ async def _connect(self) -> None:
raise TastytradeError("Unknown message type:", message)
except ConnectionClosed:
logger.debug("Websocket connection closed, retrying...")
reconnecting = True
continue

async def _setup_connection(self):
Expand Down

0 comments on commit ffdcc8b

Please sign in to comment.