Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add advanced order instructions, handle CancelledError in streamer tasks #187

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
project = "tastytrade"
copyright = "2024, Graeme Holliday"
author = "Graeme Holliday"
release = "9.4"
release = "9.5"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "tastytrade"
version = "9.4"
version = "9.5"
description = "An unofficial, sync/async SDK for Tastytrade!"
readme = "README.md"
requires-python = ">=3.9"
Expand Down
2 changes: 1 addition & 1 deletion tastytrade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
BACKTEST_URL = "https://backtester.vast.tastyworks.com"
CERT_URL = "https://api.cert.tastyworks.com"
VAST_URL = "https://vast.tastyworks.com"
VERSION = "9.4"
VERSION = "9.5"

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
Expand Down
12 changes: 12 additions & 0 deletions tastytrade/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,17 @@ class OrderRule(TastytradeJsonDataclass):
order_conditions: list[OrderCondition]


class AdvancedInstructions(TastytradeJsonDataclass):
"""
Dataclass containing advanced order rules.
"""

#: By default, if a position meant to be closed by a closing order is no longer
#: open, the API will turn it into an opening order. With this flag, the API would
#: instead discard the closing order.
strict_position_effect_validation: bool = False


class NewOrder(TastytradeJsonDataclass):
"""
Dataclass containing information about a new order. Also used for
Expand All @@ -241,6 +252,7 @@ class NewOrder(TastytradeJsonDataclass):
partition_key: Optional[str] = None
preflight_id: Optional[str] = None
rules: Optional[OrderRule] = None
advanced_instructions: Optional[AdvancedInstructions] = None

@computed_field
@property
Expand Down
59 changes: 37 additions & 22 deletions tastytrade/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,16 +229,19 @@ def __await__(self):
return self.__aenter__().__await__()

async def __aexit__(self, *exc):
self.close()
await self.close()

def close(self) -> None:
async def close(self) -> None:
"""
Closes the websocket connection and cancels the pending tasks.
"""
self._connect_task.cancel()
self._heartbeat_task.cancel()
if self._reconnect_task is not None:
tasks = [self._connect_task, self._heartbeat_task]
if self._reconnect_task is not None and not self._reconnect_task.done():
self._reconnect_task.cancel()
tasks.append(self._reconnect_task)
await asyncio.gather(*tasks)

async def _connect(self) -> None:
"""
Expand All @@ -265,6 +268,9 @@ async def _connect(self) -> None:
await self._map_message(type_str, data["data"])
except ConnectionClosed as e:
logger.error(f"Websocket connection closed with {e}")
except asyncio.CancelledError:
logger.debug("Websocket interrupted, cancelling main loop.")
return
logger.debug("Websocket connection closed, retrying...")
reconnecting = True

Expand Down Expand Up @@ -327,10 +333,14 @@ async def _heartbeat(self) -> None:
Sends a heartbeat message every 10 seconds to keep the connection
alive.
"""
while True:
await self._subscribe(SubscriptionType.HEARTBEAT, "")
# send the heartbeat every 10 seconds
await asyncio.sleep(10)
try:
while True:
await self._subscribe(SubscriptionType.HEARTBEAT, "")
# send the heartbeat every 10 seconds
await asyncio.sleep(10)
except asyncio.CancelledError:
logger.debug("Websocket interrupted, cancelling heartbeat.")
return

async def _subscribe(
self,
Expand Down Expand Up @@ -399,7 +409,6 @@ def __init__(
#: Variable number of arguments to pass to the reconnect function
self.reconnect_args = reconnect_args

self._session = session
self._authenticated = False
self._wss_url = session.dxlink_url
self._auth_token = session.streamer_token
Expand All @@ -421,16 +430,19 @@ def __await__(self):
return self.__aenter__().__await__()

async def __aexit__(self, *exc):
self.close()
await self.close()

def close(self) -> None:
async def close(self) -> None:
"""
Closes the websocket connection and cancels the heartbeat task.
"""
self._connect_task.cancel()
self._heartbeat_task.cancel()
if self._reconnect_task is not None:
tasks = [self._connect_task, self._heartbeat_task]
if self._reconnect_task is not None and not self._reconnect_task.done():
self._reconnect_task.cancel()
tasks.append(self._reconnect_task)
await asyncio.gather(*tasks)

async def _connect(self) -> None:
"""
Expand Down Expand Up @@ -487,6 +499,9 @@ async def _connect(self) -> None:
logger.error(f"Streamer error: {message}")
except ConnectionClosed as e:
logger.error(f"Websocket connection closed with {e}")
except asyncio.CancelledError:
logger.debug("Websocket interrupted, cancelling main loop.")
return
logger.debug("Websocket connection closed, retrying...")
reconnecting = True

Expand Down Expand Up @@ -558,12 +573,15 @@ async def _heartbeat(self) -> None:
alive.
"""
message = {"type": "KEEPALIVE", "channel": 0}

while True:
logger.debug("sending keepalive message: %s", message)
await self._websocket.send(json.dumps(message))
# send the heartbeat every 30 seconds
await asyncio.sleep(30)
try:
while True:
logger.debug("sending keepalive message: %s", message)
await self._websocket.send(json.dumps(message))
# send the heartbeat every 30 seconds
await asyncio.sleep(30)
except asyncio.CancelledError:
logger.debug("Websocket interrupted, cancelling heartbeat.")
return

async def subscribe(self, event_class: Type[EventType], symbols: list[str]) -> None:
"""
Expand Down Expand Up @@ -644,7 +662,7 @@ async def unsubscribe(
Removes existing subscription for given list of symbols.
For candles, use :meth:`unsubscribe_candle` instead.

:param event_type: type of subscription to remove
:param event_class: type of subscription to remove
:param symbols: list of symbols to unsubscribe from
"""
if not self._authenticated:
Expand All @@ -663,11 +681,10 @@ async def subscribe_candle(
symbols: list[str],
interval: str,
start_time: datetime,
end_time: Optional[datetime] = None,
extended_trading_hours: bool = False,
) -> None:
"""
Subscribes to time series data for the given symbol.
Subscribes to candle data for the given list of symbols.

:param symbols: list of symbols to get data for
:param interval:
Expand Down Expand Up @@ -696,8 +713,6 @@ async def subscribe_candle(
for ticker in symbols
],
}
if end_time is not None:
raise TastytradeError("End time no longer supported")
await self._websocket.send(json.dumps(message))

async def unsubscribe_candle(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def test_account_streamer_reconnect(session):
await streamer._websocket.close() # type: ignore
await asyncio.sleep(3)
assert "test" in ref
streamer.close()
await streamer.close()


async def reconnect_trades(streamer: DXLinkStreamer):
Expand All @@ -64,4 +64,4 @@ async def test_dxlink_streamer_reconnect(session):
await asyncio.sleep(3)
trade = await streamer.get_event(Trade)
assert trade.event_symbol == "SPX"
streamer.close()
await streamer.close()
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading