From 14036bb011f115f3f8227742f0eccc274007ffd7 Mon Sep 17 00:00:00 2001 From: Graeme22 Date: Tue, 14 Nov 2023 19:21:22 -0500 Subject: [PATCH] add extensive tests (cert session only); coverage >95 enforced --- .github/workflows/python-app.yml | 2 +- Makefile | 5 +- docs/index.rst | 1 + requirements.txt | 1 + tastytrade/account.py | 36 +++++++----- tastytrade/dxfeed/event.py | 2 +- tastytrade/instruments.py | 6 +- tastytrade/metrics.py | 15 +++-- tastytrade/search.py | 7 ++- tastytrade/session.py | 4 +- tastytrade/streamer.py | 2 +- tastytrade/utils.py | 2 +- tastytrade/watchlists.py | 24 ++++---- tests/conftest.py | 15 +++++ tests/test_account.py | 97 ++++++++++++++++++++++++++++++++ tests/test_instruments.py | 59 +++++++++++++++++++ tests/test_metrics.py | 5 ++ tests/test_session.py | 12 ++-- tests/test_streamer.py | 23 ++++++++ 19 files changed, 267 insertions(+), 51 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_account.py create mode 100644 tests/test_instruments.py create mode 100644 tests/test_metrics.py create mode 100644 tests/test_streamer.py diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 6d26701..c520ac0 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -31,7 +31,7 @@ jobs: mypy -p tastytrade - name: Testing... run: | - python -m pytest --cov=tastytrade --cov-report=term-missing tests/ + python -m pytest --cov=tastytrade --cov-report=term-missing tests/ --cov-fail-under=95 env: TT_USERNAME: ${{ secrets.TT_USERNAME }} TT_PASSWORD: ${{ secrets.TT_PASSWORD }} diff --git a/Makefile b/Makefile index fa448eb..0f44586 100644 --- a/Makefile +++ b/Makefile @@ -14,10 +14,7 @@ lint: mypy -p tests test: - python -m pytest --cov=tastytrade --cov-report=term-missing tests/ - -test: - python -m pytest --cov=tastytrade --cov-report=term-missing tests/ + python -m pytest --cov=tastytrade --cov-report=term-missing tests/ --cov-fail-under=95 install: env/bin/pip install -e . diff --git a/docs/index.rst b/docs/index.rst index 85c099a..4f04e59 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,6 +7,7 @@ installation sessions data-streamer + watchlists .. toctree:: :maxdepth: 2 diff --git a/requirements.txt b/requirements.txt index 06d349f..d9be753 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ websockets==11.0.3 pydantic==1.10.11 pytest==7.4.0 pytest_cov==4.1.0 +pytest-asyncio==0.21.1 diff --git a/tastytrade/account.py b/tastytrade/account.py index 6edf380..b5697d2 100644 --- a/tastytrade/account.py +++ b/tastytrade/account.py @@ -1,16 +1,22 @@ from datetime import date, datetime from decimal import Decimal -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import requests +from pydantic import BaseModel from tastytrade.order import (InstrumentType, NewOrder, OrderStatus, PlacedOrder, PlacedOrderResponse, PriceEffect) -from tastytrade.session import Session +from tastytrade.session import ProductionSession, Session from tastytrade.utils import (TastytradeError, TastytradeJsonDataclass, validate_response) +class EmptyDict(BaseModel): + class Config: + extra = 'forbid' + + class AccountBalance(TastytradeJsonDataclass): """ Dataclass containing account balance information. @@ -168,14 +174,9 @@ class MarginReportEntry(TastytradeJsonDataclass): code: str underlying_symbol: str underlying_type: str - expected_price_range_up_percent: Decimal - expected_price_range_down_percent: Decimal - point_of_no_return_percent: Decimal margin_calculation_type: str margin_requirement: Decimal margin_requirement_effect: PriceEffect - initial_requirement: Decimal - initial_requirement_effect: PriceEffect maintenance_requirement: Decimal maintenance_requirement_effect: PriceEffect buying_power: Decimal @@ -183,6 +184,11 @@ class MarginReportEntry(TastytradeJsonDataclass): groups: List[Dict[str, Any]] price_increase_percent: Decimal price_decrease_percent: Decimal + expected_price_range_up_percent: Optional[Decimal] = None + expected_price_range_down_percent: Optional[Decimal] = None + point_of_no_return_percent: Optional[Decimal] = None + initial_requirement: Optional[Decimal] = None + initial_requirement_effect: Optional[PriceEffect] = None class MarginReport(TastytradeJsonDataclass): @@ -207,8 +213,8 @@ class MarginReport(TastytradeJsonDataclass): reg_t_option_buying_power_effect: PriceEffect maintenance_excess: Decimal maintenance_excess_effect: PriceEffect - groups: List[MarginReportEntry] last_state_timestamp: int + groups: List[Union[MarginReportEntry, EmptyDict]] initial_requirement: Optional[Decimal] = None initial_requirement_effect: Optional[PriceEffect] = None @@ -706,15 +712,16 @@ def get_total_fees( def get_net_liquidating_value_history( self, - session: Session, + session: ProductionSession, time_back: Optional[str] = None, start_time: Optional[datetime] = None - ) -> List[NetLiqOhlc]: + ) -> List[NetLiqOhlc]: # pragma: no cover """ Returns a list of account net liquidating value snapshots over the specified time period. - :param session: the session to use for the request. + :param session: + the session to use for the request, can't be certification. :param time_back: the time period to get net liquidating value snapshots for. This param is required if start_time is not given. Possible values are: @@ -766,13 +773,14 @@ def get_position_limit(self, session: Session) -> PositionLimit: def get_effective_margin_requirements( self, - session: Session, + session: ProductionSession, symbol: str - ) -> MarginRequirement: + ) -> MarginRequirement: # pragma: no cover """ Get the effective margin requirements for a given symbol. - :param session: the session to use for the request. + :param session: + the session to use for the request, can't be certification :param symbol: the symbol to get margin requirements for. :return: a :class:`MarginRequirement` object. diff --git a/tastytrade/dxfeed/event.py b/tastytrade/dxfeed/event.py index c388d23..2cb25d8 100644 --- a/tastytrade/dxfeed/event.py +++ b/tastytrade/dxfeed/event.py @@ -25,7 +25,7 @@ class EventType(str, Enum): class Event(ABC): @classmethod - def from_stream(cls, data: list) -> List['Event']: + def from_stream(cls, data: list) -> List['Event']: # pragma: no cover """ Makes a list of event objects from a list of raw trade data fetched by a :class:`~tastyworks.streamer.DataStreamer`. diff --git a/tastytrade/instruments.py b/tastytrade/instruments.py index 0a121f6..79f7626 100644 --- a/tastytrade/instruments.py +++ b/tastytrade/instruments.py @@ -6,7 +6,7 @@ import requests from tastytrade.order import InstrumentType, TradeableTastytradeJsonDataclass -from tastytrade.session import Session +from tastytrade.session import ProductionSession, Session from tastytrade.utils import TastytradeJsonDataclass, validate_response @@ -1060,9 +1060,9 @@ def get_option_chain( def get_future_option_chain( - session: Session, + session: ProductionSession, symbol: str -) -> Dict[date, List[FutureOption]]: +) -> Dict[date, List[FutureOption]]: # pragma: no cover """ Returns a mapping of expiration date to a list of futures options objects representing the options chain for the given symbol. diff --git a/tastytrade/metrics.py b/tastytrade/metrics.py index 65c5312..9c59a7b 100644 --- a/tastytrade/metrics.py +++ b/tastytrade/metrics.py @@ -4,7 +4,7 @@ import requests -from tastytrade.session import Session +from tastytrade.session import ProductionSession, Session from tastytrade.utils import TastytradeJsonDataclass, validate_response @@ -91,9 +91,9 @@ class MarketMetricInfo(TastytradeJsonDataclass): def get_market_metrics( - session: Session, + session: ProductionSession, symbols: List[str] -) -> List[MarketMetricInfo]: +) -> List[MarketMetricInfo]: # pragma: no cover """ Retrieves market metrics for the given symbols. @@ -114,7 +114,10 @@ def get_market_metrics( return [MarketMetricInfo(**entry) for entry in data] -def get_dividends(session: Session, symbol: str) -> List[DividendInfo]: +def get_dividends( + session: ProductionSession, + symbol: str +) -> List[DividendInfo]: # pragma: no cover """ Retrieves dividend information for the given symbol. @@ -136,10 +139,10 @@ def get_dividends(session: Session, symbol: str) -> List[DividendInfo]: def get_earnings( - session: Session, + session: ProductionSession, symbol: str, start_date: date -) -> List[EarningsInfo]: +) -> List[EarningsInfo]: # pragma: no cover """ Retrieves earnings information for the given symbol. diff --git a/tastytrade/search.py b/tastytrade/search.py index b101517..4322cca 100644 --- a/tastytrade/search.py +++ b/tastytrade/search.py @@ -2,7 +2,7 @@ import requests -from tastytrade.session import Session +from tastytrade.session import ProductionSession from tastytrade.utils import TastytradeJsonDataclass @@ -14,7 +14,10 @@ class SymbolData(TastytradeJsonDataclass): description: str -def symbol_search(session: Session, symbol: str) -> List[SymbolData]: +def symbol_search( + session: ProductionSession, + symbol: str +) -> List[SymbolData]: # pragma: no cover """ Performs a symbol search using the Tastytrade API and returns a list of symbols that are similar to the given search phrase. diff --git a/tastytrade/session.py b/tastytrade/session.py index 3bba0a3..7ee9471 100644 --- a/tastytrade/session.py +++ b/tastytrade/session.py @@ -113,7 +113,7 @@ def __init__( self.validate() -class ProductionSession(Session): +class ProductionSession(Session): # pragma: no cover """ Contains a local user login which can then be used to interact with the remote API. @@ -326,7 +326,7 @@ def get_time_and_sale( def _map_event( event_type: str, event_dict: Any # Usually Dict[str, Any]; sometimes a list -) -> Event: +) -> Event: # pragma: no cover """ Parses the raw JSON data from the dxfeed REST API into event objects. diff --git a/tastytrade/streamer.py b/tastytrade/streamer.py index a8b505a..d387c46 100644 --- a/tastytrade/streamer.py +++ b/tastytrade/streamer.py @@ -248,7 +248,7 @@ async def _subscribe( await self._websocket.send(json.dumps(message)) # type: ignore -class DataStreamer: +class DataStreamer: # pragma: no cover """ A :class:`DataStreamer` object is used to fetch quotes or greeks for a given symbol or list of symbols. It should always be diff --git a/tastytrade/utils.py b/tastytrade/utils.py index beeb0ce..3e39356 100644 --- a/tastytrade/utils.py +++ b/tastytrade/utils.py @@ -30,7 +30,7 @@ class Config: allow_population_by_field_name = True -def validate_response(response: Response) -> None: +def validate_response(response: Response) -> None: # pragma: no cover """ Checks if the given code is an error; if so, raises an exception. diff --git a/tastytrade/watchlists.py b/tastytrade/watchlists.py index d6ed584..3f839da 100644 --- a/tastytrade/watchlists.py +++ b/tastytrade/watchlists.py @@ -3,7 +3,7 @@ import requests from tastytrade.instruments import InstrumentType -from tastytrade.session import Session +from tastytrade.session import ProductionSession from tastytrade.utils import TastytradeJsonDataclass, validate_response @@ -19,7 +19,7 @@ class Pair(TastytradeJsonDataclass): right_quantity: int -class PairsWatchlist(TastytradeJsonDataclass): +class PairsWatchlist(TastytradeJsonDataclass): # pragma: no cover """ Dataclass that represents a pairs watchlist object. """ @@ -28,7 +28,7 @@ class PairsWatchlist(TastytradeJsonDataclass): pairs_equations: List[Pair] @classmethod - def get_pairs_watchlists(cls, session: Session) -> List['PairsWatchlist']: + def get_pairs_watchlists(cls, session: ProductionSession) -> List['PairsWatchlist']: """ Fetches a list of all Tastytrade public pairs watchlists. @@ -48,7 +48,7 @@ def get_pairs_watchlists(cls, session: Session) -> List['PairsWatchlist']: @classmethod def get_pairs_watchlist( cls, - session: Session, + session: ProductionSession, name: str ) -> 'PairsWatchlist': """ @@ -70,7 +70,7 @@ def get_pairs_watchlist( return cls(**data) -class Watchlist(TastytradeJsonDataclass): +class Watchlist(TastytradeJsonDataclass): # pragma: no cover """ Dataclass that represents a watchlist object (public or private), with functions to update, publish, modify and remove watchlists. @@ -83,7 +83,7 @@ class Watchlist(TastytradeJsonDataclass): @classmethod def get_public_watchlists( cls, - session: Session, + session: ProductionSession, counts_only: bool = False ) -> List['Watchlist']: """ @@ -106,7 +106,7 @@ def get_public_watchlists( return [cls(**entry) for entry in data] @classmethod - def get_public_watchlist(cls, session: Session, name: str) -> 'Watchlist': + def get_public_watchlist(cls, session: ProductionSession, name: str) -> 'Watchlist': """ Fetches a Tastytrade public watchlist by name. @@ -126,7 +126,7 @@ def get_public_watchlist(cls, session: Session, name: str) -> 'Watchlist': return cls(**data) @classmethod - def get_private_watchlists(cls, session: Session) -> List['Watchlist']: + def get_private_watchlists(cls, session: ProductionSession) -> List['Watchlist']: """ Fetches a the user's private watchlists. @@ -145,7 +145,7 @@ def get_private_watchlists(cls, session: Session) -> List['Watchlist']: return [cls(**entry) for entry in data] @classmethod - def get_private_watchlist(cls, session: Session, name: str) -> 'Watchlist': + def get_private_watchlist(cls, session: ProductionSession, name: str) -> 'Watchlist': """ Fetches a user's watchlist by name. @@ -165,7 +165,7 @@ def get_private_watchlist(cls, session: Session, name: str) -> 'Watchlist': return cls(**data) @classmethod - def remove_private_watchlist(cls, session: Session, name: str) -> None: + def remove_private_watchlist(cls, session: ProductionSession, name: str) -> None: """ Deletes the named private watchlist. @@ -178,7 +178,7 @@ def remove_private_watchlist(cls, session: Session, name: str) -> None: ) validate_response(response) - def upload_private_watchlist(self, session: Session) -> None: + def upload_private_watchlist(self, session: ProductionSession) -> None: """ Creates a private remote watchlist identical to this local one. @@ -191,7 +191,7 @@ def upload_private_watchlist(self, session: Session) -> None: ) validate_response(response) - def update_private_watchlist(self, session: Session) -> None: + def update_private_watchlist(self, session: ProductionSession) -> None: """ Updates the existing private remote watchlist. diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a2252fd --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,15 @@ +import os +import pytest + +from tastytrade import CertificationSession + + +@pytest.fixture(scope='session') +def session(): + username = os.environ.get('TT_USERNAME', None) + password = os.environ.get('TT_PASSWORD', None) + + session = CertificationSession(username, password) + yield session + + session.destroy() diff --git a/tests/test_account.py b/tests/test_account.py new file mode 100644 index 0000000..20f668a --- /dev/null +++ b/tests/test_account.py @@ -0,0 +1,97 @@ +from decimal import Decimal +import pytest + +from tastytrade import Account +from tastytrade.instruments import Equity +from tastytrade.order import NewOrder, OrderAction, OrderTimeInForce, OrderType, PriceEffect + + +@pytest.fixture +def account(session): + return Account.get_accounts(session)[1] + + +def test_get_account(session, account): + acc = Account.get_account(session, account.account_number) + assert acc == account + + +def test_get_trading_status(session, account): + account.get_trading_status(session) + + +def test_get_balances(session, account): + account.get_balances(session) + + +def test_get_balance_snapshots(session, account): + account.get_balance_snapshots(session) + + +def test_get_positions(session, account): + account.get_positions(session) + + +def test_get_history(session, account): + account.get_history(session) + + +def test_get_transaction(session, account): + TX_ID = 42961 # opening deposit + account.get_transaction(session, TX_ID) + + +def test_get_total_fees(session, account): + account.get_total_fees(session) + + +def test_get_position_limit(session, account): + account.get_position_limit(session) + + +def test_get_margin_requirements(session, account): + account.get_margin_requirements(session) + + +@pytest.fixture +def new_order(session): + symbol = Equity.get_equity(session, 'SPY') + leg = symbol.build_leg(Decimal(1), OrderAction.BUY_TO_OPEN) + + return NewOrder( + time_in_force=OrderTimeInForce.DAY, + order_type=OrderType.LIMIT, + legs=[leg], + price=Decimal(420), # over $3 so will never fill + price_effect=PriceEffect.DEBIT + ) + + +@pytest.fixture +def placed_order(session, account, new_order): + return account.place_order(session, new_order, dry_run=False).order + + +def test_place_and_delete_order(session, account, new_order): + order = account.place_order(session, new_order, dry_run=False).order + account.delete_order(session, order.id) + + +def test_replace_order(session, account, new_order, placed_order): + account.replace_order(session, placed_order.id, new_order) + + +def test_get_order(session, account, placed_order): + assert account.get_order(session, placed_order.id).id == placed_order.id + + +def test_delete_order(session, account, placed_order): + account.delete_order(session, placed_order.id) + + +def test_get_order_history(session, account): + account.get_order_history(session) + + +def test_get_live_orders(session, account): + account.get_live_orders(session) diff --git a/tests/test_instruments.py b/tests/test_instruments.py new file mode 100644 index 0000000..5337cc8 --- /dev/null +++ b/tests/test_instruments.py @@ -0,0 +1,59 @@ +from tastytrade.instruments import Cryptocurrency, Equity, FutureProduct, FutureOptionProduct, Option, NestedOptionChain, Warrant, get_quantity_decimal_precisions, get_option_chain + + +def test_get_cryptocurrency(session): + Cryptocurrency.get_cryptocurrency(session, 'ETH/USD') + + +def test_get_cryptocurrencies(session): + Cryptocurrency.get_cryptocurrencies(session) + + +def test_get_active_equities(session): + Equity.get_active_equities(session, page_offset=0) + + +def test_get_equities(session): + Equity.get_equities(session, ['AAPL', 'SPY']) + + +def test_get_equity(session): + Equity.get_equity(session, 'AAPL') + + +def test_get_future_product(session): + FutureProduct.get_future_product(session, 'ZN') + + +def test_get_future_option_product(session): + FutureOptionProduct.get_future_option_product(session, 'LO') + + +def test_get_future_option_products(session): + FutureOptionProduct.get_future_option_products(session) + + +def test_get_future_products(session): + FutureProduct.get_future_products(session) + + +def test_get_nested_option_chain(session): + NestedOptionChain.get_chain(session, 'SPY') + + +def test_get_warrants(session): + Warrant.get_warrants(session) + + +def test_get_quantity_decimal_precisions(session): + get_quantity_decimal_precisions(session) + + +def test_get_option_chain(session): + chain = get_option_chain(session, 'SPY') + symbols = [] + for options in chain.values(): + symbols.extend([o.symbol for o in options]) + break + Option.get_option(session, symbols[0]) + Option.get_options(session, symbols) diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..7a6f9bf --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,5 @@ +from tastytrade.metrics import get_risk_free_rate + + +def test_get_risk_free_rate(session): + get_risk_free_rate(session) diff --git a/tests/test_session.py b/tests/test_session.py index 175ae3e..813d1aa 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -3,10 +3,14 @@ from tastytrade import CertificationSession -def test_session(): +def test_get_customer(session): + assert session.get_customer() != {} + + +def test_destroy(): + # here we create a new session to avoid destroying the active one username = os.environ.get('TT_USERNAME', None) - assert username is not None password = os.environ.get('TT_PASSWORD', None) - assert password is not None - session = CertificationSession(username, password) # noqa: F841 + session = CertificationSession(username, password) + assert session.destroy() diff --git a/tests/test_streamer.py b/tests/test_streamer.py new file mode 100644 index 0000000..68b5783 --- /dev/null +++ b/tests/test_streamer.py @@ -0,0 +1,23 @@ +import asyncio +import pytest +import pytest_asyncio + +from tastytrade import Account, AlertStreamer + +pytest_plugins = ('pytest_asyncio',) + + +@pytest_asyncio.fixture +async def streamer(session): + streamer = await AlertStreamer.create(session) + yield streamer + streamer.close() + + +@pytest.mark.asyncio +async def test_subscribe_all(session, streamer): + await streamer.public_watchlists_subscribe() + await streamer.quote_alerts_subscribe() + await streamer.user_message_subscribe(session) + accounts = Account.get_accounts(session) + await streamer.account_subscribe(accounts)