Skip to content

Commit

Permalink
add backtesting
Browse files Browse the repository at this point in the history
  • Loading branch information
Graeme22 committed Sep 30, 2024
1 parent 672d459 commit 00df153
Show file tree
Hide file tree
Showing 8 changed files with 293 additions and 31 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ install:
uv pip install -e .

lint:
uv run ruff check tastytrade/
uv run ruff check tests/
uv run ruff check .
uv run ruff format .
uv run mypy -p tastytrade
uv run mypy -p tests

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ authors = [

dependencies = [
"fake-useragent>=1.5.1",
"httpx>=0.27.2",
"pandas-market-calendars>=4.4.1",
"pydantic>=2.9.2",
"requests>=2.32.3",
Expand Down
65 changes: 50 additions & 15 deletions tastytrade/account.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import date, datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel

Expand Down Expand Up @@ -474,28 +474,63 @@ def get_balances(self, session: Session) -> AccountBalance:
def get_balance_snapshots(
self,
session: Session,
per_page: int = 250,
page_offset: Optional[int] = None,
currency: str = "USD",
end_date: Optional[date] = None,
start_date: Optional[date] = None,
snapshot_date: Optional[date] = None,
time_of_day: Optional[str] = None,
time_of_day: Literal["BOD", "EOD"] = "EOD",
) -> List[AccountBalanceSnapshot]:
"""
Returns a list of two balance snapshots. The first one is the
specified date, or, if not provided, the oldest snapshot available.
The second one is the most recent snapshot.
If you provide the snapshot date, you must also provide the time of
day.
Returns a list of balance snapshots. This list will
just have a few snapshots if you don't pass a start
date; otherwise, it will be each day's balances in
the given range.
:param session: the session to use for the request.
:param currency: the currency to show balances in.
:param start_date: the starting date of the range.
:param end_date: the ending date of the range.
:param snapshot_date: the date of the snapshot to get.
:param time_of_day:
the time of day of the snapshot to get, either 'EOD' or 'BOD'.
the time of day of the snapshots to get, either 'EOD' (End Of Day) or 'BOD' (Beginning Of Day).
"""
params = {"snapshot-date": snapshot_date, "time-of-day": time_of_day}
data = session.get(
f"/accounts/{self.account_number}/balance-snapshots",
params={k: v for k, v in params.items() if v is not None},
)
return [AccountBalanceSnapshot(**i) for i in data["items"]]
paginate = False
if page_offset is None:
page_offset = 0
paginate = True
params = {
"per-page": per_page,
"page-offset": page_offset,
"currency": currency,
"end-date": end_date,
"start-date": start_date,
"snapshot-date": snapshot_date,
"time-of-day": time_of_day,
}
snapshots = []
while True:
response = session.client.get(
(f"{session.base_url}/accounts/{self.account_number}/balance-snapshots"),
params={
k: v # type: ignore
for k, v in params.items()
if v is not None
},
)
validate_response(response)
json = response.json()
snapshots.extend([AccountBalanceSnapshot(**i) for i in json["data"]["items"]])
# handle pagination
pagination = json["pagination"]
if (
pagination["page-offset"] >= pagination["total-pages"] - 1
or not paginate
):
break
params["page-offset"] += 1 # type: ignore
return snapshots

def get_positions(
self,
Expand Down
205 changes: 205 additions & 0 deletions tastytrade/backtest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import asyncio
from datetime import date, datetime
from decimal import Decimal
from typing import AsyncGenerator, List, Literal, Optional

import httpx
from fake_useragent import UserAgent # type: ignore
from pydantic import BaseModel, Field
from pydantic.alias_generators import to_camel

from tastytrade import BACKTEST_URL
from tastytrade.session import Session
from tastytrade.utils import (
TastytradeError,
validate_response,
)


class BacktestJsonDataclass(BaseModel):
"""
Dataclass for converting backtest JSON naming conventions to snake case.
"""

class Config:
alias_generator = to_camel
populate_by_name = True


class BacktestEntry(BacktestJsonDataclass):
"""
Dataclass of parameters for backtest trade entry.
"""

use_exact_DTE: bool = Field(default=True, serialization_alias="useExactDTE")
maximum_active_trials: Optional[int] = None
maximum_active_trials_behavior: Optional[Literal["close oldest", "don't enter"]] = (
None
)
frequency: str = "every day"


class BacktestExit(BacktestJsonDataclass):
"""
Dataclass of parameters for backtest trade exit.
"""

after_days_in_trade: Optional[int] = None
stop_loss_percentage: Optional[int] = None
take_profit_percentage: Optional[int] = None
at_days_to_expiration: Optional[int] = None


class BacktestLeg(BacktestJsonDataclass):
"""
Dataclass of parameters for placing legs of backtest trades.
Leg delta must be a multiple of 5.
"""

days_until_expiration: int = 45
delta: int = 15
direction: Literal["buy", "sell"] = "sell"
quantity: int = 1
side: Literal["call", "put"] = "call"


class Backtest(BacktestJsonDataclass):
"""
Dataclass of configuration options for a backtest.
Date must be <= 2024-07-31.
"""

symbol: str
entry_conditions: BacktestEntry
exit_conditions: BacktestExit
legs: List[BacktestLeg]
start_date: date
end_date: date = date(2024, 7, 31)
status: str = "pending"


class BacktestSnapshot(BacktestJsonDataclass):
"""
Dataclass containing a snapshot in time during the backtest.
"""

date_time: datetime
profit_loss: Decimal
normalized_underlying_price: Optional[Decimal] = None
underlying_price: Optional[Decimal] = None


class BacktestTrial(BacktestJsonDataclass):
"""
Dataclass containing information on trades placed during the backtest.
"""

close_date_time: datetime
open_date_time: datetime
profit_loss: Decimal


class BacktestStatistics(BaseModel):
"""
Dataclass containing statistics on the overall performance of a backtest.
"""

class Config:
populate_by_name = True

avg_bp_per_trade: Decimal = Field(validation_alias="Avg. BPR per trade")
avg_daily_pnl_change: Decimal = Field(validation_alias="Avg. daily change in PNL")
avg_daily_net_liq_change: Decimal = Field(
validation_alias="Avg. daily change in net liq"
)
avg_days_in_trade: Decimal = Field(validation_alias="Avg. days in trade")
avg_premium: Decimal = Field(validation_alias="Avg. premium")
avg_profit_loss_per_trade: Decimal = Field(
validation_alias="Avg. profit/loss per trade"
)
avg_return_per_trade: Decimal = Field(validation_alias="Avg. return per trade")
highest_profit: Decimal = Field(validation_alias="Highest profit")
loss_percentage: Decimal = Field(validation_alias="Loss percentage")
losses: int = Field(validation_alias="Losses")
max_drawdown: Decimal = Field(validation_alias="Max drawdown")
number_of_trades: int = Field(validation_alias="Number of trades")
premium_capture_rate: Decimal = Field(validation_alias="Premium capture rate")
return_on_used_capital: Decimal = Field(validation_alias="Return on used capital")
total_fees: Decimal = Field(validation_alias="Total fees")
total_premium: Decimal = Field(validation_alias="Total premium")
total_profit_loss: Decimal = Field(validation_alias="Total profit/loss")
used_capital: Decimal = Field(validation_alias="Used capital")
win_percentage: Decimal = Field(validation_alias="Win percentage")
wins: int = Field(validation_alias="Wins")
worst_loss: Decimal = Field(validation_alias="Worst loss")


class BacktestResults(BacktestJsonDataclass):
"""
Dataclass containing partial or finished results of a backtest.
"""

snapshots: Optional[List[BacktestSnapshot]]
statistics: Optional[BacktestStatistics]
trials: Optional[List[BacktestTrial]]


class BacktestResponse(Backtest):
"""
Dataclass containing a backtest and associated information.
"""

created_at: datetime
id: str
results: BacktestResults
eta: Optional[int] = None
progress: Optional[Decimal] = None


class BacktestSession:
"""
Class for creating a backtesting session which can be reused for multiple backtests.
Example usage::
from tastytrade import BacktestSession, Backtest
from tqdm.asyncio import tqdm # progress bar
backtest = Backtest(...)
backtest_session = BacktestSession(session)
results = [r async for r in tqdm(backtest_session.run(backtest))]
print(results[-1])
"""

def __init__(self, session: Session):
if session.is_test:
raise TastytradeError("Certification sessions can't run backtests!")
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
"User-Agent": UserAgent().random,
}
# Pull backtest token
response = httpx.post(
f"{BACKTEST_URL}/sessions",
json={"tastytradeToken": session.session_token},
)
validate_response(response)
# Token used for backtesting
backtest_token = response.json()["token"]
headers["Authorization"] = f"Bearer {backtest_token}"
self.client = httpx.AsyncClient(base_url=BACKTEST_URL, headers=headers)

async def run(self, backtest: Backtest) -> AsyncGenerator[BacktestResponse, None]:
json = backtest.model_dump_json(by_alias=True, exclude_none=True)
response = await self.client.post("/backtests", data=json) # type: ignore
validate_response(response)
results = BacktestResponse(**response.json())
while results.status != "completed":
yield results
await asyncio.sleep(0.5)
response = await self.client.get(f"/backtests/{results.id}")
validate_response(response)
results = BacktestResponse(**response.json())
yield results
14 changes: 1 addition & 13 deletions tastytrade/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import requests
from fake_useragent import UserAgent # type: ignore

from tastytrade import API_URL, BACKTEST_URL, CERT_URL
from tastytrade import API_URL, CERT_URL
from tastytrade.utils import TastytradeError, TastytradeJsonDataclass, validate_response


Expand Down Expand Up @@ -96,18 +96,6 @@ def __init__(
self.streamer_token = data["token"]
#: URL for dxfeed websocket
self.dxlink_url = data["dxlink-url"]
if not is_test:
# Pull backtest token
response = requests.post(
f"{BACKTEST_URL}/sessions",
headers=headers,
json={"tastytradeToken": self.session_token},
)
validate_response(response)
#: Token used for backtesting
self.backtest_token = response.json()["token"]
else:
self.backtest_token = None

def get(self, url, **kwargs) -> Dict[str, Any]:
response = self.client.get(self.base_url + url, timeout=30, **kwargs)
Expand Down
5 changes: 4 additions & 1 deletion tastytrade/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from datetime import date, datetime, timedelta
from typing import Union

import pandas_market_calendars as mcal # type: ignore
import pytz
from httpx._models import Response as HTTPXReponse
from pydantic import BaseModel
from requests import Response

Expand Down Expand Up @@ -202,12 +204,13 @@ class TastytradeJsonDataclass(BaseModel):
A pydantic dataclass that converts keys from snake case to dasherized
and performs type validation and coercion.
"""

class Config:
alias_generator = _dasherize
populate_by_name = True


def validate_response(response: Response) -> None:
def validate_response(response: Union[Response, HTTPXReponse]) -> None:
"""
Checks if the given code is an error; if so, raises an exception.
Expand Down
28 changes: 28 additions & 0 deletions tests/test_backtest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from datetime import timedelta

import pytest

from tastytrade import today_in_new_york
from tastytrade.backtest import (
Backtest,
BacktestEntry,
BacktestExit,
BacktestLeg,
BacktestSession,
)

pytest_plugins = ("pytest_asyncio",)


@pytest.mark.asyncio
async def test_backtest_simple(session):
backtest_session = BacktestSession(session)
backtest = Backtest(
symbol="SPY",
entry_conditions=BacktestEntry(),
exit_conditions=BacktestExit(at_days_to_expiration=21),
legs=[BacktestLeg(), BacktestLeg(side="put")],
start_date=today_in_new_york() - timedelta(days=365),
)
results = [r async for r in backtest_session.run(backtest)]
assert results[-1].status == "completed"
Loading

0 comments on commit 00df153

Please sign in to comment.