Skip to content

Commit

Permalink
add extensive tests (cert session only); coverage >95 enforced
Browse files Browse the repository at this point in the history
  • Loading branch information
Graeme22 committed Nov 15, 2023
1 parent bd74562 commit 14036bb
Show file tree
Hide file tree
Showing 19 changed files with 267 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
mypy -p tastytrade
- name: Testing...
run: |
python -m pytest --cov=tastytrade --cov-report=term-missing tests/
python -m pytest --cov=tastytrade --cov-report=term-missing tests/ --cov-fail-under=95
env:
TT_USERNAME: ${{ secrets.TT_USERNAME }}
TT_PASSWORD: ${{ secrets.TT_PASSWORD }}
5 changes: 1 addition & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ lint:
mypy -p tests

test:
python -m pytest --cov=tastytrade --cov-report=term-missing tests/

test:
python -m pytest --cov=tastytrade --cov-report=term-missing tests/
python -m pytest --cov=tastytrade --cov-report=term-missing tests/ --cov-fail-under=95

install:
env/bin/pip install -e .
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
installation
sessions
data-streamer
watchlists

.. toctree::
:maxdepth: 2
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ websockets==11.0.3
pydantic==1.10.11
pytest==7.4.0
pytest_cov==4.1.0
pytest-asyncio==0.21.1
36 changes: 22 additions & 14 deletions tastytrade/account.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
from datetime import date, datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import requests
from pydantic import BaseModel

from tastytrade.order import (InstrumentType, NewOrder, OrderStatus,
PlacedOrder, PlacedOrderResponse, PriceEffect)
from tastytrade.session import Session
from tastytrade.session import ProductionSession, Session
from tastytrade.utils import (TastytradeError, TastytradeJsonDataclass,
validate_response)


class EmptyDict(BaseModel):
class Config:
extra = 'forbid'


class AccountBalance(TastytradeJsonDataclass):
"""
Dataclass containing account balance information.
Expand Down Expand Up @@ -168,21 +174,21 @@ class MarginReportEntry(TastytradeJsonDataclass):
code: str
underlying_symbol: str
underlying_type: str
expected_price_range_up_percent: Decimal
expected_price_range_down_percent: Decimal
point_of_no_return_percent: Decimal
margin_calculation_type: str
margin_requirement: Decimal
margin_requirement_effect: PriceEffect
initial_requirement: Decimal
initial_requirement_effect: PriceEffect
maintenance_requirement: Decimal
maintenance_requirement_effect: PriceEffect
buying_power: Decimal
buying_power_effect: PriceEffect
groups: List[Dict[str, Any]]
price_increase_percent: Decimal
price_decrease_percent: Decimal
expected_price_range_up_percent: Optional[Decimal] = None
expected_price_range_down_percent: Optional[Decimal] = None
point_of_no_return_percent: Optional[Decimal] = None
initial_requirement: Optional[Decimal] = None
initial_requirement_effect: Optional[PriceEffect] = None


class MarginReport(TastytradeJsonDataclass):
Expand All @@ -207,8 +213,8 @@ class MarginReport(TastytradeJsonDataclass):
reg_t_option_buying_power_effect: PriceEffect
maintenance_excess: Decimal
maintenance_excess_effect: PriceEffect
groups: List[MarginReportEntry]
last_state_timestamp: int
groups: List[Union[MarginReportEntry, EmptyDict]]
initial_requirement: Optional[Decimal] = None
initial_requirement_effect: Optional[PriceEffect] = None

Expand Down Expand Up @@ -706,15 +712,16 @@ def get_total_fees(

def get_net_liquidating_value_history(
self,
session: Session,
session: ProductionSession,
time_back: Optional[str] = None,
start_time: Optional[datetime] = None
) -> List[NetLiqOhlc]:
) -> List[NetLiqOhlc]: # pragma: no cover
"""
Returns a list of account net liquidating value snapshots over the
specified time period.
:param session: the session to use for the request.
:param session:
the session to use for the request, can't be certification.
:param time_back:
the time period to get net liquidating value snapshots for. This
param is required if start_time is not given. Possible values are:
Expand Down Expand Up @@ -766,13 +773,14 @@ def get_position_limit(self, session: Session) -> PositionLimit:

def get_effective_margin_requirements(
self,
session: Session,
session: ProductionSession,
symbol: str
) -> MarginRequirement:
) -> MarginRequirement: # pragma: no cover
"""
Get the effective margin requirements for a given symbol.
:param session: the session to use for the request.
:param session:
the session to use for the request, can't be certification
:param symbol: the symbol to get margin requirements for.
:return: a :class:`MarginRequirement` object.
Expand Down
2 changes: 1 addition & 1 deletion tastytrade/dxfeed/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class EventType(str, Enum):

class Event(ABC):
@classmethod
def from_stream(cls, data: list) -> List['Event']:
def from_stream(cls, data: list) -> List['Event']: # pragma: no cover
"""
Makes a list of event objects from a list of raw trade data fetched by
a :class:`~tastyworks.streamer.DataStreamer`.
Expand Down
6 changes: 3 additions & 3 deletions tastytrade/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import requests

from tastytrade.order import InstrumentType, TradeableTastytradeJsonDataclass
from tastytrade.session import Session
from tastytrade.session import ProductionSession, Session
from tastytrade.utils import TastytradeJsonDataclass, validate_response


Expand Down Expand Up @@ -1060,9 +1060,9 @@ def get_option_chain(


def get_future_option_chain(
session: Session,
session: ProductionSession,
symbol: str
) -> Dict[date, List[FutureOption]]:
) -> Dict[date, List[FutureOption]]: # pragma: no cover
"""
Returns a mapping of expiration date to a list of futures options
objects representing the options chain for the given symbol.
Expand Down
15 changes: 9 additions & 6 deletions tastytrade/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import requests

from tastytrade.session import Session
from tastytrade.session import ProductionSession, Session
from tastytrade.utils import TastytradeJsonDataclass, validate_response


Expand Down Expand Up @@ -91,9 +91,9 @@ class MarketMetricInfo(TastytradeJsonDataclass):


def get_market_metrics(
session: Session,
session: ProductionSession,
symbols: List[str]
) -> List[MarketMetricInfo]:
) -> List[MarketMetricInfo]: # pragma: no cover
"""
Retrieves market metrics for the given symbols.
Expand All @@ -114,7 +114,10 @@ def get_market_metrics(
return [MarketMetricInfo(**entry) for entry in data]


def get_dividends(session: Session, symbol: str) -> List[DividendInfo]:
def get_dividends(
session: ProductionSession,
symbol: str
) -> List[DividendInfo]: # pragma: no cover
"""
Retrieves dividend information for the given symbol.
Expand All @@ -136,10 +139,10 @@ def get_dividends(session: Session, symbol: str) -> List[DividendInfo]:


def get_earnings(
session: Session,
session: ProductionSession,
symbol: str,
start_date: date
) -> List[EarningsInfo]:
) -> List[EarningsInfo]: # pragma: no cover
"""
Retrieves earnings information for the given symbol.
Expand Down
7 changes: 5 additions & 2 deletions tastytrade/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import requests

from tastytrade.session import Session
from tastytrade.session import ProductionSession
from tastytrade.utils import TastytradeJsonDataclass


Expand All @@ -14,7 +14,10 @@ class SymbolData(TastytradeJsonDataclass):
description: str


def symbol_search(session: Session, symbol: str) -> List[SymbolData]:
def symbol_search(
session: ProductionSession,
symbol: str
) -> List[SymbolData]: # pragma: no cover
"""
Performs a symbol search using the Tastytrade API and returns a
list of symbols that are similar to the given search phrase.
Expand Down
4 changes: 2 additions & 2 deletions tastytrade/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
self.validate()


class ProductionSession(Session):
class ProductionSession(Session): # pragma: no cover
"""
Contains a local user login which can then be used to interact with the
remote API.
Expand Down Expand Up @@ -326,7 +326,7 @@ def get_time_and_sale(
def _map_event(
event_type: str,
event_dict: Any # Usually Dict[str, Any]; sometimes a list
) -> Event:
) -> Event: # pragma: no cover
"""
Parses the raw JSON data from the dxfeed REST API into event objects.
Expand Down
2 changes: 1 addition & 1 deletion tastytrade/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ async def _subscribe(
await self._websocket.send(json.dumps(message)) # type: ignore


class DataStreamer:
class DataStreamer: # pragma: no cover
"""
A :class:`DataStreamer` object is used to fetch quotes or greeks
for a given symbol or list of symbols. It should always be
Expand Down
2 changes: 1 addition & 1 deletion tastytrade/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Config:
allow_population_by_field_name = True


def validate_response(response: Response) -> None:
def validate_response(response: Response) -> None: # pragma: no cover
"""
Checks if the given code is an error; if so, raises an exception.
Expand Down
24 changes: 12 additions & 12 deletions tastytrade/watchlists.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import requests

from tastytrade.instruments import InstrumentType
from tastytrade.session import Session
from tastytrade.session import ProductionSession
from tastytrade.utils import TastytradeJsonDataclass, validate_response


Expand All @@ -19,7 +19,7 @@ class Pair(TastytradeJsonDataclass):
right_quantity: int


class PairsWatchlist(TastytradeJsonDataclass):
class PairsWatchlist(TastytradeJsonDataclass): # pragma: no cover
"""
Dataclass that represents a pairs watchlist object.
"""
Expand All @@ -28,7 +28,7 @@ class PairsWatchlist(TastytradeJsonDataclass):
pairs_equations: List[Pair]

@classmethod
def get_pairs_watchlists(cls, session: Session) -> List['PairsWatchlist']:
def get_pairs_watchlists(cls, session: ProductionSession) -> List['PairsWatchlist']:
"""
Fetches a list of all Tastytrade public pairs watchlists.
Expand All @@ -48,7 +48,7 @@ def get_pairs_watchlists(cls, session: Session) -> List['PairsWatchlist']:
@classmethod
def get_pairs_watchlist(
cls,
session: Session,
session: ProductionSession,
name: str
) -> 'PairsWatchlist':
"""
Expand All @@ -70,7 +70,7 @@ def get_pairs_watchlist(
return cls(**data)


class Watchlist(TastytradeJsonDataclass):
class Watchlist(TastytradeJsonDataclass): # pragma: no cover
"""
Dataclass that represents a watchlist object (public or private),
with functions to update, publish, modify and remove watchlists.
Expand All @@ -83,7 +83,7 @@ class Watchlist(TastytradeJsonDataclass):
@classmethod
def get_public_watchlists(
cls,
session: Session,
session: ProductionSession,
counts_only: bool = False
) -> List['Watchlist']:
"""
Expand All @@ -106,7 +106,7 @@ def get_public_watchlists(
return [cls(**entry) for entry in data]

@classmethod
def get_public_watchlist(cls, session: Session, name: str) -> 'Watchlist':
def get_public_watchlist(cls, session: ProductionSession, name: str) -> 'Watchlist':
"""
Fetches a Tastytrade public watchlist by name.
Expand All @@ -126,7 +126,7 @@ def get_public_watchlist(cls, session: Session, name: str) -> 'Watchlist':
return cls(**data)

@classmethod
def get_private_watchlists(cls, session: Session) -> List['Watchlist']:
def get_private_watchlists(cls, session: ProductionSession) -> List['Watchlist']:
"""
Fetches a the user's private watchlists.
Expand All @@ -145,7 +145,7 @@ def get_private_watchlists(cls, session: Session) -> List['Watchlist']:
return [cls(**entry) for entry in data]

@classmethod
def get_private_watchlist(cls, session: Session, name: str) -> 'Watchlist':
def get_private_watchlist(cls, session: ProductionSession, name: str) -> 'Watchlist':
"""
Fetches a user's watchlist by name.
Expand All @@ -165,7 +165,7 @@ def get_private_watchlist(cls, session: Session, name: str) -> 'Watchlist':
return cls(**data)

@classmethod
def remove_private_watchlist(cls, session: Session, name: str) -> None:
def remove_private_watchlist(cls, session: ProductionSession, name: str) -> None:
"""
Deletes the named private watchlist.
Expand All @@ -178,7 +178,7 @@ def remove_private_watchlist(cls, session: Session, name: str) -> None:
)
validate_response(response)

def upload_private_watchlist(self, session: Session) -> None:
def upload_private_watchlist(self, session: ProductionSession) -> None:
"""
Creates a private remote watchlist identical to this local one.
Expand All @@ -191,7 +191,7 @@ def upload_private_watchlist(self, session: Session) -> None:
)
validate_response(response)

def update_private_watchlist(self, session: Session) -> None:
def update_private_watchlist(self, session: ProductionSession) -> None:
"""
Updates the existing private remote watchlist.
Expand Down
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os
import pytest

from tastytrade import CertificationSession


@pytest.fixture(scope='session')
def session():
username = os.environ.get('TT_USERNAME', None)
password = os.environ.get('TT_PASSWORD', None)

session = CertificationSession(username, password)
yield session

session.destroy()
Loading

0 comments on commit 14036bb

Please sign in to comment.