Skip to content

Commit

Permalink
update websockets, add autoreconnect (#180)
Browse files Browse the repository at this point in the history
* update websockets, add autoreconnect

* upgrade websockets dep
  • Loading branch information
Graeme22 authored Dec 9, 2024
1 parent 895779e commit 6b0506b
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 189 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
project = "tastytrade"
copyright = "2024, Graeme Holliday"
author = "Graeme Holliday"
release = "9.3"
release = "9.4"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ build-backend = "hatchling.build"

[project]
name = "tastytrade"
version = "9.3"
version = "9.4"
description = "An unofficial, sync/async SDK for Tastytrade!"
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
license = {file = "LICENSE"}
authors = [
{name = "Graeme Holliday", email = "[email protected]"},
Expand All @@ -16,7 +16,7 @@ dependencies = [
"httpx>=0.27.2",
"pandas-market-calendars>=4.4.1",
"pydantic>=2.9.2",
"websockets>=13.1",
"websockets>=14.1",
]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion tastytrade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
BACKTEST_URL = "https://backtester.vast.tastyworks.com"
CERT_URL = "https://api.cert.tastyworks.com"
VAST_URL = "https://vast.tastyworks.com"
VERSION = "9.3"
VERSION = "9.4"

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
Expand Down
114 changes: 62 additions & 52 deletions tastytrade/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from ssl import SSLContext, create_default_context
from typing import Any, AsyncIterator, Optional, Type, TypeVar, Union

import websockets
from pydantic import model_validator
from websockets import WebSocketClientProtocol
from websockets.asyncio.client import ClientConnection, connect
from websockets.exceptions import ConnectionClosed

from tastytrade import logger
from tastytrade.account import Account, AccountBalance, CurrentPosition, TradingStatus
Expand Down Expand Up @@ -188,7 +188,7 @@ def __init__(self, session: Session):
self.base_url: str = CERT_STREAMER_URL if session.is_test else STREAMER_URL

self._queues: dict[str, Queue] = defaultdict(Queue)
self._websocket: Optional[WebSocketClientProtocol] = None
self._websocket: Optional[ClientConnection] = None
self._connect_task = asyncio.create_task(self._connect())

async def __aenter__(self):
Expand Down Expand Up @@ -222,19 +222,21 @@ async def _connect(self) -> None:
token provided during initialization.
"""
headers = {"Authorization": f"Bearer {self.token}"}
async with websockets.connect(
self.base_url, extra_headers=headers
) as websocket: # type: ignore
async for websocket in connect(self.base_url, additional_headers=headers):
self._websocket = websocket
self._heartbeat_task = asyncio.create_task(self._heartbeat())

while True:
raw_message = await self._websocket.recv() # type: ignore
logger.debug("raw message: %s", raw_message)
data = json.loads(raw_message)
type_str = data.get("type")
if type_str is not None:
await self._map_message(type_str, data["data"])
logger.debug("Websocket connection established.")

try:
async for raw_message in websocket:
logger.debug("raw message: %s", raw_message)
data = json.loads(raw_message)
type_str = data.get("type")
if type_str is not None:
await self._map_message(type_str, data["data"])
except ConnectionClosed:
logger.debug("Websocket connection closed, retrying...")
continue

async def listen(self, alert_class: Type[T]) -> AsyncIterator[T]:
"""
Expand Down Expand Up @@ -398,44 +400,52 @@ async def _connect(self) -> None:
authorization token provided during initialization.
"""

async with websockets.connect(
self._wss_url, ssl=self._ssl_context
) as websocket:
self._websocket = websocket
await self._setup_connection()

# main loop
while True:
raw_message = await self._websocket.recv()
message = json.loads(raw_message)

logger.debug("received: %s", message)
if message["type"] == "SETUP":
await self._authenticate_connection()
elif message["type"] == "AUTH_STATE":
if message["state"] == "AUTHORIZED":
self._authenticated = True
self._heartbeat_task = asyncio.create_task(self._heartbeat())
elif message["type"] == "CHANNEL_OPENED":
channel = next(
k for k, v in self._channels.items() if v == message["channel"]
)
self._subscription_state[channel] = message["type"]
logger.debug("Channel opened: %s", message)
elif message["type"] == "CHANNEL_CLOSED":
channel = next(
k for k, v in self._channels.items() if v == message["channel"]
)
self._subscription_state[channel] = message["type"]
logger.debug("Channel closed: %s", message)
elif message["type"] == "FEED_CONFIG":
logger.debug("Feed configured: %s", message)
elif message["type"] == "FEED_DATA":
await self._map_message(message["data"])
elif message["type"] == "KEEPALIVE":
pass
else:
raise TastytradeError("Unknown message type:", message)
async for websocket in connect(self._wss_url, ssl=self._ssl_context):
try:
self._websocket = websocket
await self._setup_connection()

# main loop
async for raw_message in websocket:
message = json.loads(raw_message)

logger.debug("received: %s", message)
if message["type"] == "SETUP":
await self._authenticate_connection()
elif message["type"] == "AUTH_STATE":
if message["state"] == "AUTHORIZED":
logger.debug("Websocket connection established.")
self._authenticated = True
self._heartbeat_task = asyncio.create_task(
self._heartbeat()
)
elif message["type"] == "CHANNEL_OPENED":
channel = next(
k
for k, v in self._channels.items()
if v == message["channel"]
)
self._subscription_state[channel] = message["type"]
logger.debug("Channel opened: %s", message)
elif message["type"] == "CHANNEL_CLOSED":
channel = next(
k
for k, v in self._channels.items()
if v == message["channel"]
)
self._subscription_state[channel] = message["type"]
logger.debug("Channel closed: %s", message)
elif message["type"] == "FEED_CONFIG":
logger.debug("Feed configured: %s", message)
elif message["type"] == "FEED_DATA":
await self._map_message(message["data"])
elif message["type"] == "KEEPALIVE":
pass
else:
raise TastytradeError("Unknown message type:", message)
except ConnectionClosed:
logger.debug("Websocket connection closed, retrying...")
continue

async def _setup_connection(self):
message = {
Expand Down
Loading

0 comments on commit 6b0506b

Please sign in to comment.