Skip to content

Commit

Permalink
add advanced order instructions, handle CancelledError in streamer ta…
Browse files Browse the repository at this point in the history
…sks (#187)

* add advanced order instructions, handle CancelledError in streamer tasks

* lint

* small tweaks
  • Loading branch information
Graeme22 authored Dec 16, 2024
1 parent e743cfa commit 1ed4d8a
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 28 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
project = "tastytrade"
copyright = "2024, Graeme Holliday"
author = "Graeme Holliday"
release = "9.4"
release = "9.5"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "tastytrade"
version = "9.4"
version = "9.5"
description = "An unofficial, sync/async SDK for Tastytrade!"
readme = "README.md"
requires-python = ">=3.9"
Expand Down
2 changes: 1 addition & 1 deletion tastytrade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
BACKTEST_URL = "https://backtester.vast.tastyworks.com"
CERT_URL = "https://api.cert.tastyworks.com"
VAST_URL = "https://vast.tastyworks.com"
VERSION = "9.4"
VERSION = "9.5"

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
Expand Down
12 changes: 12 additions & 0 deletions tastytrade/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,17 @@ class OrderRule(TastytradeJsonDataclass):
order_conditions: list[OrderCondition]


class AdvancedInstructions(TastytradeJsonDataclass):
"""
Dataclass containing advanced order rules.
"""

#: By default, if a position meant to be closed by a closing order is no longer
#: open, the API will turn it into an opening order. With this flag, the API would
#: instead discard the closing order.
strict_position_effect_validation: bool = False


class NewOrder(TastytradeJsonDataclass):
"""
Dataclass containing information about a new order. Also used for
Expand All @@ -241,6 +252,7 @@ class NewOrder(TastytradeJsonDataclass):
partition_key: Optional[str] = None
preflight_id: Optional[str] = None
rules: Optional[OrderRule] = None
advanced_instructions: Optional[AdvancedInstructions] = None

@computed_field
@property
Expand Down
59 changes: 37 additions & 22 deletions tastytrade/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,16 +229,19 @@ def __await__(self):
return self.__aenter__().__await__()

async def __aexit__(self, *exc):
self.close()
await self.close()

def close(self) -> None:
async def close(self) -> None:
"""
Closes the websocket connection and cancels the pending tasks.
"""
self._connect_task.cancel()
self._heartbeat_task.cancel()
if self._reconnect_task is not None:
tasks = [self._connect_task, self._heartbeat_task]
if self._reconnect_task is not None and not self._reconnect_task.done():
self._reconnect_task.cancel()
tasks.append(self._reconnect_task)
await asyncio.gather(*tasks)

async def _connect(self) -> None:
"""
Expand All @@ -265,6 +268,9 @@ async def _connect(self) -> None:
await self._map_message(type_str, data["data"])
except ConnectionClosed as e:
logger.error(f"Websocket connection closed with {e}")
except asyncio.CancelledError:
logger.debug("Websocket interrupted, cancelling main loop.")
return
logger.debug("Websocket connection closed, retrying...")
reconnecting = True

Expand Down Expand Up @@ -327,10 +333,14 @@ async def _heartbeat(self) -> None:
Sends a heartbeat message every 10 seconds to keep the connection
alive.
"""
while True:
await self._subscribe(SubscriptionType.HEARTBEAT, "")
# send the heartbeat every 10 seconds
await asyncio.sleep(10)
try:
while True:
await self._subscribe(SubscriptionType.HEARTBEAT, "")
# send the heartbeat every 10 seconds
await asyncio.sleep(10)
except asyncio.CancelledError:
logger.debug("Websocket interrupted, cancelling heartbeat.")
return

async def _subscribe(
self,
Expand Down Expand Up @@ -399,7 +409,6 @@ def __init__(
#: Variable number of arguments to pass to the reconnect function
self.reconnect_args = reconnect_args

self._session = session
self._authenticated = False
self._wss_url = session.dxlink_url
self._auth_token = session.streamer_token
Expand All @@ -421,16 +430,19 @@ def __await__(self):
return self.__aenter__().__await__()

async def __aexit__(self, *exc):
self.close()
await self.close()

def close(self) -> None:
async def close(self) -> None:
"""
Closes the websocket connection and cancels the heartbeat task.
"""
self._connect_task.cancel()
self._heartbeat_task.cancel()
if self._reconnect_task is not None:
tasks = [self._connect_task, self._heartbeat_task]
if self._reconnect_task is not None and not self._reconnect_task.done():
self._reconnect_task.cancel()
tasks.append(self._reconnect_task)
await asyncio.gather(*tasks)

async def _connect(self) -> None:
"""
Expand Down Expand Up @@ -487,6 +499,9 @@ async def _connect(self) -> None:
logger.error(f"Streamer error: {message}")
except ConnectionClosed as e:
logger.error(f"Websocket connection closed with {e}")
except asyncio.CancelledError:
logger.debug("Websocket interrupted, cancelling main loop.")
return
logger.debug("Websocket connection closed, retrying...")
reconnecting = True

Expand Down Expand Up @@ -558,12 +573,15 @@ async def _heartbeat(self) -> None:
alive.
"""
message = {"type": "KEEPALIVE", "channel": 0}

while True:
logger.debug("sending keepalive message: %s", message)
await self._websocket.send(json.dumps(message))
# send the heartbeat every 30 seconds
await asyncio.sleep(30)
try:
while True:
logger.debug("sending keepalive message: %s", message)
await self._websocket.send(json.dumps(message))
# send the heartbeat every 30 seconds
await asyncio.sleep(30)
except asyncio.CancelledError:
logger.debug("Websocket interrupted, cancelling heartbeat.")
return

async def subscribe(self, event_class: Type[EventType], symbols: list[str]) -> None:
"""
Expand Down Expand Up @@ -644,7 +662,7 @@ async def unsubscribe(
Removes existing subscription for given list of symbols.
For candles, use :meth:`unsubscribe_candle` instead.
:param event_type: type of subscription to remove
:param event_class: type of subscription to remove
:param symbols: list of symbols to unsubscribe from
"""
if not self._authenticated:
Expand All @@ -663,11 +681,10 @@ async def subscribe_candle(
symbols: list[str],
interval: str,
start_time: datetime,
end_time: Optional[datetime] = None,
extended_trading_hours: bool = False,
) -> None:
"""
Subscribes to time series data for the given symbol.
Subscribes to candle data for the given list of symbols.
:param symbols: list of symbols to get data for
:param interval:
Expand Down Expand Up @@ -696,8 +713,6 @@ async def subscribe_candle(
for ticker in symbols
],
}
if end_time is not None:
raise TastytradeError("End time no longer supported")
await self._websocket.send(json.dumps(message))

async def unsubscribe_candle(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def test_account_streamer_reconnect(session):
await streamer._websocket.close() # type: ignore
await asyncio.sleep(3)
assert "test" in ref
streamer.close()
await streamer.close()


async def reconnect_trades(streamer: DXLinkStreamer):
Expand All @@ -64,4 +64,4 @@ async def test_dxlink_streamer_reconnect(session):
await asyncio.sleep(3)
trade = await streamer.get_event(Trade)
assert trade.event_symbol == "SPX"
streamer.close()
await streamer.close()
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1ed4d8a

Please sign in to comment.