From 00df153060439ad4db2ca84a94c7e3ec94af91f4 Mon Sep 17 00:00:00 2001 From: Graeme Holliday Date: Mon, 30 Sep 2024 16:18:34 -0500 Subject: [PATCH] add backtesting --- Makefile | 4 +- pyproject.toml | 1 + tastytrade/account.py | 65 ++++++++++--- tastytrade/backtest.py | 205 +++++++++++++++++++++++++++++++++++++++++ tastytrade/session.py | 14 +-- tastytrade/utils.py | 5 +- tests/test_backtest.py | 28 ++++++ uv.lock | 2 + 8 files changed, 293 insertions(+), 31 deletions(-) create mode 100644 tastytrade/backtest.py create mode 100644 tests/test_backtest.py diff --git a/Makefile b/Makefile index 0b83cb1..e81c09d 100644 --- a/Makefile +++ b/Makefile @@ -5,8 +5,8 @@ install: uv pip install -e . lint: - uv run ruff check tastytrade/ - uv run ruff check tests/ + uv run ruff check . + uv run ruff format . uv run mypy -p tastytrade uv run mypy -p tests diff --git a/pyproject.toml b/pyproject.toml index 432317d..60ab408 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ authors = [ dependencies = [ "fake-useragent>=1.5.1", + "httpx>=0.27.2", "pandas-market-calendars>=4.4.1", "pydantic>=2.9.2", "requests>=2.32.3", diff --git a/tastytrade/account.py b/tastytrade/account.py index 1551eb1..8b8adc4 100644 --- a/tastytrade/account.py +++ b/tastytrade/account.py @@ -1,6 +1,6 @@ from datetime import date, datetime from decimal import Decimal -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel @@ -474,28 +474,63 @@ def get_balances(self, session: Session) -> AccountBalance: def get_balance_snapshots( self, session: Session, + per_page: int = 250, + page_offset: Optional[int] = None, + currency: str = "USD", + end_date: Optional[date] = None, + start_date: Optional[date] = None, snapshot_date: Optional[date] = None, - time_of_day: Optional[str] = None, + time_of_day: Literal["BOD", "EOD"] = "EOD", ) -> List[AccountBalanceSnapshot]: """ - Returns a list of two balance snapshots. The first one is the - specified date, or, if not provided, the oldest snapshot available. - The second one is the most recent snapshot. - - If you provide the snapshot date, you must also provide the time of - day. + Returns a list of balance snapshots. This list will + just have a few snapshots if you don't pass a start + date; otherwise, it will be each day's balances in + the given range. :param session: the session to use for the request. + :param currency: the currency to show balances in. + :param start_date: the starting date of the range. + :param end_date: the ending date of the range. :param snapshot_date: the date of the snapshot to get. :param time_of_day: - the time of day of the snapshot to get, either 'EOD' or 'BOD'. + the time of day of the snapshots to get, either 'EOD' (End Of Day) or 'BOD' (Beginning Of Day). """ - params = {"snapshot-date": snapshot_date, "time-of-day": time_of_day} - data = session.get( - f"/accounts/{self.account_number}/balance-snapshots", - params={k: v for k, v in params.items() if v is not None}, - ) - return [AccountBalanceSnapshot(**i) for i in data["items"]] + paginate = False + if page_offset is None: + page_offset = 0 + paginate = True + params = { + "per-page": per_page, + "page-offset": page_offset, + "currency": currency, + "end-date": end_date, + "start-date": start_date, + "snapshot-date": snapshot_date, + "time-of-day": time_of_day, + } + snapshots = [] + while True: + response = session.client.get( + (f"{session.base_url}/accounts/{self.account_number}/balance-snapshots"), + params={ + k: v # type: ignore + for k, v in params.items() + if v is not None + }, + ) + validate_response(response) + json = response.json() + snapshots.extend([AccountBalanceSnapshot(**i) for i in json["data"]["items"]]) + # handle pagination + pagination = json["pagination"] + if ( + pagination["page-offset"] >= pagination["total-pages"] - 1 + or not paginate + ): + break + params["page-offset"] += 1 # type: ignore + return snapshots def get_positions( self, diff --git a/tastytrade/backtest.py b/tastytrade/backtest.py new file mode 100644 index 0000000..1ebb9aa --- /dev/null +++ b/tastytrade/backtest.py @@ -0,0 +1,205 @@ +import asyncio +from datetime import date, datetime +from decimal import Decimal +from typing import AsyncGenerator, List, Literal, Optional + +import httpx +from fake_useragent import UserAgent # type: ignore +from pydantic import BaseModel, Field +from pydantic.alias_generators import to_camel + +from tastytrade import BACKTEST_URL +from tastytrade.session import Session +from tastytrade.utils import ( + TastytradeError, + validate_response, +) + + +class BacktestJsonDataclass(BaseModel): + """ + Dataclass for converting backtest JSON naming conventions to snake case. + """ + + class Config: + alias_generator = to_camel + populate_by_name = True + + +class BacktestEntry(BacktestJsonDataclass): + """ + Dataclass of parameters for backtest trade entry. + """ + + use_exact_DTE: bool = Field(default=True, serialization_alias="useExactDTE") + maximum_active_trials: Optional[int] = None + maximum_active_trials_behavior: Optional[Literal["close oldest", "don't enter"]] = ( + None + ) + frequency: str = "every day" + + +class BacktestExit(BacktestJsonDataclass): + """ + Dataclass of parameters for backtest trade exit. + """ + + after_days_in_trade: Optional[int] = None + stop_loss_percentage: Optional[int] = None + take_profit_percentage: Optional[int] = None + at_days_to_expiration: Optional[int] = None + + +class BacktestLeg(BacktestJsonDataclass): + """ + Dataclass of parameters for placing legs of backtest trades. + Leg delta must be a multiple of 5. + """ + + days_until_expiration: int = 45 + delta: int = 15 + direction: Literal["buy", "sell"] = "sell" + quantity: int = 1 + side: Literal["call", "put"] = "call" + + +class Backtest(BacktestJsonDataclass): + """ + Dataclass of configuration options for a backtest. + Date must be <= 2024-07-31. + """ + + symbol: str + entry_conditions: BacktestEntry + exit_conditions: BacktestExit + legs: List[BacktestLeg] + start_date: date + end_date: date = date(2024, 7, 31) + status: str = "pending" + + +class BacktestSnapshot(BacktestJsonDataclass): + """ + Dataclass containing a snapshot in time during the backtest. + """ + + date_time: datetime + profit_loss: Decimal + normalized_underlying_price: Optional[Decimal] = None + underlying_price: Optional[Decimal] = None + + +class BacktestTrial(BacktestJsonDataclass): + """ + Dataclass containing information on trades placed during the backtest. + """ + + close_date_time: datetime + open_date_time: datetime + profit_loss: Decimal + + +class BacktestStatistics(BaseModel): + """ + Dataclass containing statistics on the overall performance of a backtest. + """ + + class Config: + populate_by_name = True + + avg_bp_per_trade: Decimal = Field(validation_alias="Avg. BPR per trade") + avg_daily_pnl_change: Decimal = Field(validation_alias="Avg. daily change in PNL") + avg_daily_net_liq_change: Decimal = Field( + validation_alias="Avg. daily change in net liq" + ) + avg_days_in_trade: Decimal = Field(validation_alias="Avg. days in trade") + avg_premium: Decimal = Field(validation_alias="Avg. premium") + avg_profit_loss_per_trade: Decimal = Field( + validation_alias="Avg. profit/loss per trade" + ) + avg_return_per_trade: Decimal = Field(validation_alias="Avg. return per trade") + highest_profit: Decimal = Field(validation_alias="Highest profit") + loss_percentage: Decimal = Field(validation_alias="Loss percentage") + losses: int = Field(validation_alias="Losses") + max_drawdown: Decimal = Field(validation_alias="Max drawdown") + number_of_trades: int = Field(validation_alias="Number of trades") + premium_capture_rate: Decimal = Field(validation_alias="Premium capture rate") + return_on_used_capital: Decimal = Field(validation_alias="Return on used capital") + total_fees: Decimal = Field(validation_alias="Total fees") + total_premium: Decimal = Field(validation_alias="Total premium") + total_profit_loss: Decimal = Field(validation_alias="Total profit/loss") + used_capital: Decimal = Field(validation_alias="Used capital") + win_percentage: Decimal = Field(validation_alias="Win percentage") + wins: int = Field(validation_alias="Wins") + worst_loss: Decimal = Field(validation_alias="Worst loss") + + +class BacktestResults(BacktestJsonDataclass): + """ + Dataclass containing partial or finished results of a backtest. + """ + + snapshots: Optional[List[BacktestSnapshot]] + statistics: Optional[BacktestStatistics] + trials: Optional[List[BacktestTrial]] + + +class BacktestResponse(Backtest): + """ + Dataclass containing a backtest and associated information. + """ + + created_at: datetime + id: str + results: BacktestResults + eta: Optional[int] = None + progress: Optional[Decimal] = None + + +class BacktestSession: + """ + Class for creating a backtesting session which can be reused for multiple backtests. + + Example usage:: + + from tastytrade import BacktestSession, Backtest + from tqdm.asyncio import tqdm # progress bar + + backtest = Backtest(...) + backtest_session = BacktestSession(session) + results = [r async for r in tqdm(backtest_session.run(backtest))] + print(results[-1]) + + """ + + def __init__(self, session: Session): + if session.is_test: + raise TastytradeError("Certification sessions can't run backtests!") + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + "User-Agent": UserAgent().random, + } + # Pull backtest token + response = httpx.post( + f"{BACKTEST_URL}/sessions", + json={"tastytradeToken": session.session_token}, + ) + validate_response(response) + # Token used for backtesting + backtest_token = response.json()["token"] + headers["Authorization"] = f"Bearer {backtest_token}" + self.client = httpx.AsyncClient(base_url=BACKTEST_URL, headers=headers) + + async def run(self, backtest: Backtest) -> AsyncGenerator[BacktestResponse, None]: + json = backtest.model_dump_json(by_alias=True, exclude_none=True) + response = await self.client.post("/backtests", data=json) # type: ignore + validate_response(response) + results = BacktestResponse(**response.json()) + while results.status != "completed": + yield results + await asyncio.sleep(0.5) + response = await self.client.get(f"/backtests/{results.id}") + validate_response(response) + results = BacktestResponse(**response.json()) + yield results diff --git a/tastytrade/session.py b/tastytrade/session.py index a5f984d..83b587d 100644 --- a/tastytrade/session.py +++ b/tastytrade/session.py @@ -3,7 +3,7 @@ import requests from fake_useragent import UserAgent # type: ignore -from tastytrade import API_URL, BACKTEST_URL, CERT_URL +from tastytrade import API_URL, CERT_URL from tastytrade.utils import TastytradeError, TastytradeJsonDataclass, validate_response @@ -96,18 +96,6 @@ def __init__( self.streamer_token = data["token"] #: URL for dxfeed websocket self.dxlink_url = data["dxlink-url"] - if not is_test: - # Pull backtest token - response = requests.post( - f"{BACKTEST_URL}/sessions", - headers=headers, - json={"tastytradeToken": self.session_token}, - ) - validate_response(response) - #: Token used for backtesting - self.backtest_token = response.json()["token"] - else: - self.backtest_token = None def get(self, url, **kwargs) -> Dict[str, Any]: response = self.client.get(self.base_url + url, timeout=30, **kwargs) diff --git a/tastytrade/utils.py b/tastytrade/utils.py index cba2e0d..32c26dc 100644 --- a/tastytrade/utils.py +++ b/tastytrade/utils.py @@ -1,7 +1,9 @@ from datetime import date, datetime, timedelta +from typing import Union import pandas_market_calendars as mcal # type: ignore import pytz +from httpx._models import Response as HTTPXReponse from pydantic import BaseModel from requests import Response @@ -202,12 +204,13 @@ class TastytradeJsonDataclass(BaseModel): A pydantic dataclass that converts keys from snake case to dasherized and performs type validation and coercion. """ + class Config: alias_generator = _dasherize populate_by_name = True -def validate_response(response: Response) -> None: +def validate_response(response: Union[Response, HTTPXReponse]) -> None: """ Checks if the given code is an error; if so, raises an exception. diff --git a/tests/test_backtest.py b/tests/test_backtest.py new file mode 100644 index 0000000..a22da95 --- /dev/null +++ b/tests/test_backtest.py @@ -0,0 +1,28 @@ +from datetime import timedelta + +import pytest + +from tastytrade import today_in_new_york +from tastytrade.backtest import ( + Backtest, + BacktestEntry, + BacktestExit, + BacktestLeg, + BacktestSession, +) + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_backtest_simple(session): + backtest_session = BacktestSession(session) + backtest = Backtest( + symbol="SPY", + entry_conditions=BacktestEntry(), + exit_conditions=BacktestExit(at_days_to_expiration=21), + legs=[BacktestLeg(), BacktestLeg(side="put")], + start_date=today_in_new_york() - timedelta(days=365), + ) + results = [r async for r in backtest_session.run(backtest)] + assert results[-1].status == "completed" diff --git a/uv.lock b/uv.lock index 09e1e51..9dd8c45 100644 --- a/uv.lock +++ b/uv.lock @@ -2659,6 +2659,7 @@ version = "8.4" source = { virtual = "." } dependencies = [ { name = "fake-useragent", marker = "python_full_version >= '3.10'" }, + { name = "httpx", marker = "python_full_version >= '3.10'" }, { name = "pandas-market-calendars", marker = "python_full_version >= '3.10'" }, { name = "pydantic", marker = "python_full_version >= '3.10'" }, { name = "requests", marker = "python_full_version >= '3.10'" }, @@ -2685,6 +2686,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "fake-useragent", specifier = ">=1.5.1" }, + { name = "httpx", specifier = ">=0.27.2" }, { name = "pandas-market-calendars", specifier = ">=4.4.1" }, { name = "pydantic", specifier = ">=2.9.2" }, { name = "requests", specifier = ">=2.32.3" },