From 1ed4d8a90cbf3ff45c7b68363c2658e5f7af0214 Mon Sep 17 00:00:00 2001 From: Graeme Holliday Date: Mon, 16 Dec 2024 15:26:47 -0500 Subject: [PATCH] add advanced order instructions, handle CancelledError in streamer tasks (#187) * add advanced order instructions, handle CancelledError in streamer tasks * lint * small tweaks --- docs/conf.py | 2 +- pyproject.toml | 2 +- tastytrade/__init__.py | 2 +- tastytrade/order.py | 12 +++++++++ tastytrade/streamer.py | 59 ++++++++++++++++++++++++++---------------- tests/test_streamer.py | 4 +-- uv.lock | 2 +- 7 files changed, 55 insertions(+), 28 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 2062e19..fa00cf8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,7 +13,7 @@ project = "tastytrade" copyright = "2024, Graeme Holliday" author = "Graeme Holliday" -release = "9.4" +release = "9.5" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 7054dc5..a529a4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "tastytrade" -version = "9.4" +version = "9.5" description = "An unofficial, sync/async SDK for Tastytrade!" readme = "README.md" requires-python = ">=3.9" diff --git a/tastytrade/__init__.py b/tastytrade/__init__.py index 5c62315..b4ae6ab 100644 --- a/tastytrade/__init__.py +++ b/tastytrade/__init__.py @@ -4,7 +4,7 @@ BACKTEST_URL = "https://backtester.vast.tastyworks.com" CERT_URL = "https://api.cert.tastyworks.com" VAST_URL = "https://vast.tastyworks.com" -VERSION = "9.4" +VERSION = "9.5" logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) diff --git a/tastytrade/order.py b/tastytrade/order.py index 2442c27..45e4219 100644 --- a/tastytrade/order.py +++ b/tastytrade/order.py @@ -221,6 +221,17 @@ class OrderRule(TastytradeJsonDataclass): order_conditions: list[OrderCondition] +class AdvancedInstructions(TastytradeJsonDataclass): + """ + Dataclass containing advanced order rules. + """ + + #: By default, if a position meant to be closed by a closing order is no longer + #: open, the API will turn it into an opening order. With this flag, the API would + #: instead discard the closing order. + strict_position_effect_validation: bool = False + + class NewOrder(TastytradeJsonDataclass): """ Dataclass containing information about a new order. Also used for @@ -241,6 +252,7 @@ class NewOrder(TastytradeJsonDataclass): partition_key: Optional[str] = None preflight_id: Optional[str] = None rules: Optional[OrderRule] = None + advanced_instructions: Optional[AdvancedInstructions] = None @computed_field @property diff --git a/tastytrade/streamer.py b/tastytrade/streamer.py index 9d40238..779476e 100644 --- a/tastytrade/streamer.py +++ b/tastytrade/streamer.py @@ -229,16 +229,19 @@ def __await__(self): return self.__aenter__().__await__() async def __aexit__(self, *exc): - self.close() + await self.close() - def close(self) -> None: + async def close(self) -> None: """ Closes the websocket connection and cancels the pending tasks. """ self._connect_task.cancel() self._heartbeat_task.cancel() - if self._reconnect_task is not None: + tasks = [self._connect_task, self._heartbeat_task] + if self._reconnect_task is not None and not self._reconnect_task.done(): self._reconnect_task.cancel() + tasks.append(self._reconnect_task) + await asyncio.gather(*tasks) async def _connect(self) -> None: """ @@ -265,6 +268,9 @@ async def _connect(self) -> None: await self._map_message(type_str, data["data"]) except ConnectionClosed as e: logger.error(f"Websocket connection closed with {e}") + except asyncio.CancelledError: + logger.debug("Websocket interrupted, cancelling main loop.") + return logger.debug("Websocket connection closed, retrying...") reconnecting = True @@ -327,10 +333,14 @@ async def _heartbeat(self) -> None: Sends a heartbeat message every 10 seconds to keep the connection alive. """ - while True: - await self._subscribe(SubscriptionType.HEARTBEAT, "") - # send the heartbeat every 10 seconds - await asyncio.sleep(10) + try: + while True: + await self._subscribe(SubscriptionType.HEARTBEAT, "") + # send the heartbeat every 10 seconds + await asyncio.sleep(10) + except asyncio.CancelledError: + logger.debug("Websocket interrupted, cancelling heartbeat.") + return async def _subscribe( self, @@ -399,7 +409,6 @@ def __init__( #: Variable number of arguments to pass to the reconnect function self.reconnect_args = reconnect_args - self._session = session self._authenticated = False self._wss_url = session.dxlink_url self._auth_token = session.streamer_token @@ -421,16 +430,19 @@ def __await__(self): return self.__aenter__().__await__() async def __aexit__(self, *exc): - self.close() + await self.close() - def close(self) -> None: + async def close(self) -> None: """ Closes the websocket connection and cancels the heartbeat task. """ self._connect_task.cancel() self._heartbeat_task.cancel() - if self._reconnect_task is not None: + tasks = [self._connect_task, self._heartbeat_task] + if self._reconnect_task is not None and not self._reconnect_task.done(): self._reconnect_task.cancel() + tasks.append(self._reconnect_task) + await asyncio.gather(*tasks) async def _connect(self) -> None: """ @@ -487,6 +499,9 @@ async def _connect(self) -> None: logger.error(f"Streamer error: {message}") except ConnectionClosed as e: logger.error(f"Websocket connection closed with {e}") + except asyncio.CancelledError: + logger.debug("Websocket interrupted, cancelling main loop.") + return logger.debug("Websocket connection closed, retrying...") reconnecting = True @@ -558,12 +573,15 @@ async def _heartbeat(self) -> None: alive. """ message = {"type": "KEEPALIVE", "channel": 0} - - while True: - logger.debug("sending keepalive message: %s", message) - await self._websocket.send(json.dumps(message)) - # send the heartbeat every 30 seconds - await asyncio.sleep(30) + try: + while True: + logger.debug("sending keepalive message: %s", message) + await self._websocket.send(json.dumps(message)) + # send the heartbeat every 30 seconds + await asyncio.sleep(30) + except asyncio.CancelledError: + logger.debug("Websocket interrupted, cancelling heartbeat.") + return async def subscribe(self, event_class: Type[EventType], symbols: list[str]) -> None: """ @@ -644,7 +662,7 @@ async def unsubscribe( Removes existing subscription for given list of symbols. For candles, use :meth:`unsubscribe_candle` instead. - :param event_type: type of subscription to remove + :param event_class: type of subscription to remove :param symbols: list of symbols to unsubscribe from """ if not self._authenticated: @@ -663,11 +681,10 @@ async def subscribe_candle( symbols: list[str], interval: str, start_time: datetime, - end_time: Optional[datetime] = None, extended_trading_hours: bool = False, ) -> None: """ - Subscribes to time series data for the given symbol. + Subscribes to candle data for the given list of symbols. :param symbols: list of symbols to get data for :param interval: @@ -696,8 +713,6 @@ async def subscribe_candle( for ticker in symbols ], } - if end_time is not None: - raise TastytradeError("End time no longer supported") await self._websocket.send(json.dumps(message)) async def unsubscribe_candle( diff --git a/tests/test_streamer.py b/tests/test_streamer.py index b59024d..aa8a248 100644 --- a/tests/test_streamer.py +++ b/tests/test_streamer.py @@ -49,7 +49,7 @@ async def test_account_streamer_reconnect(session): await streamer._websocket.close() # type: ignore await asyncio.sleep(3) assert "test" in ref - streamer.close() + await streamer.close() async def reconnect_trades(streamer: DXLinkStreamer): @@ -64,4 +64,4 @@ async def test_dxlink_streamer_reconnect(session): await asyncio.sleep(3) trade = await streamer.get_event(Trade) assert trade.event_symbol == "SPX" - streamer.close() + await streamer.close() diff --git a/uv.lock b/uv.lock index 86315c8..a545b88 100644 --- a/uv.lock +++ b/uv.lock @@ -546,7 +546,7 @@ wheels = [ [[package]] name = "tastytrade" -version = "9.4" +version = "9.5" source = { editable = "." } dependencies = [ { name = "httpx" },