Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Graeme22 committed Nov 28, 2023
1 parent cb45c9c commit dc339ef
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 50 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ jobs:
python -m pytest --cov=tastytrade --cov-report=term-missing tests/ --cov-fail-under=95
env:
TT_USERNAME: ${{ secrets.TT_USERNAME }}
TT_PASSWORD: ${{ secrets.TT_PASSWORD }}
TT_PASSWORD: ${{ secrets.TT_PASSWORD }}
89 changes: 69 additions & 20 deletions tastytrade/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import requests
from pydantic import BaseModel

from tastytrade.order import (InstrumentType, NewOCOOrder, NewOrder,
OrderStatus, PlacedOrder, PlacedOrderResponse,
PriceEffect)
from tastytrade.order import (InstrumentType, NewComplexOrder, NewOrder,
OrderStatus, PlacedComplexOrder, PlacedOrder,
PlacedOrderResponse, PriceEffect)
from tastytrade.session import ProductionSession, Session
from tastytrade.utils import (TastytradeError, TastytradeJsonDataclass,
validate_response)
Expand Down Expand Up @@ -457,7 +457,8 @@ def get_trading_status(self, session: Session) -> TradingStatus:
:return: a Tastytrade 'TradingStatus' object in JSON format.
"""
response = requests.get(
f'{session.base_url}/accounts/{self.account_number}/trading-status', # noqa: E501
(f'{session.base_url}/accounts/{self.account_number}/'
'trading-status'),
headers=session.headers
)
validate_response(response) # throws exception if not 200
Expand Down Expand Up @@ -512,7 +513,8 @@ def get_balance_snapshots(
}

response = requests.get(
f'{session.base_url}/accounts/{self.account_number}/balance-snapshots', # noqa: E501
(f'{session.base_url}/accounts/{self.account_number}/balance-'
'snapshots'),
headers=session.headers,
params={k: v for k, v in params.items() if v is not None}
)
Expand Down Expand Up @@ -651,7 +653,8 @@ def get_history(
results = []
while True:
response = requests.get(
f'{session.base_url}/accounts/{self.account_number}/transactions', # noqa: E501
(f'{session.base_url}/accounts/{self.account_number}/'
'transactions'),
headers=session.headers,
params={k: v for k, v in params.items() if v is not None}
)
Expand Down Expand Up @@ -683,7 +686,8 @@ def get_transaction(
:return: a Tastytrade 'Transaction' object in JSON format.
"""
response = requests.get(
f'{session.base_url}/accounts/{self.account_number}/transactions/{id}', # noqa: E501
(f'{session.base_url}/accounts/{self.account_number}/transactions'
f'/{id}'),
headers=session.headers
)
validate_response(response)
Expand All @@ -707,7 +711,8 @@ def get_total_fees(
"""
params: Dict[str, Any] = {'date': date}
response = requests.get(
f'{session.base_url}/accounts/{self.account_number}/transactions/total-fees', # noqa: E501
(f'{session.base_url}/accounts/{self.account_number}/transactions/'
'total-fees'),
headers=session.headers,
params=params
)
Expand Down Expand Up @@ -748,7 +753,8 @@ def get_net_liquidating_value_history(
params = {'time-back': time_back}

response = requests.get(
f'{session.base_url}/accounts/{self.account_number}/net-liq/history', # noqa: E501
(f'{session.base_url}/accounts/{self.account_number}/net-liq/'
'history'),
headers=session.headers,
params=params
)
Expand All @@ -767,7 +773,8 @@ def get_position_limit(self, session: Session) -> PositionLimit:
:return: a Tastytrade 'PositionLimit' object in JSON format.
"""
response = requests.get(
f'{session.base_url}/accounts/{self.account_number}/position-limit', # noqa: E501
(f'{session.base_url}/accounts/{self.account_number}/position-'
'limit'),
headers=session.headers
)
validate_response(response)
Expand All @@ -793,7 +800,8 @@ def get_effective_margin_requirements(
if symbol:
symbol = symbol.replace('/', '%2F')
response = requests.get(
f'{session.base_url}/accounts/{self.account_number}/margin-requirements/{symbol}/effective', # noqa: E501
(f'{session.base_url}/accounts/{self.account_number}/margin-'
f'requirements/{symbol}/effective'),
headers=session.headers
)
validate_response(response)
Expand All @@ -812,7 +820,8 @@ def get_margin_requirements(self, session: Session) -> MarginReport:
:return: a :class:`MarginReport` object.
"""
response = requests.get(
f'{session.base_url}/margin/accounts/{self.account_number}/requirements', # noqa: E501
(f'{session.base_url}/margin/accounts/{self.account_number}/'
'requirements'),
headers=session.headers
)
validate_response(response)
Expand All @@ -839,16 +848,41 @@ def get_live_orders(self, session: Session) -> List[PlacedOrder]:

return [PlacedOrder(**entry) for entry in data]

def get_complex_order(
self,
session: Session,
order_id: str
) -> PlacedComplexOrder:
"""
Gets a complex order with the given ID.
:param session: the session to use for the request.
:return:
a :class:`PlacedComplexOrder` object corresponding to the given ID
"""
response = requests.get(
(f'{session.base_url}/accounts/{self.account_number}/complex-'
f'orders/{order_id}'),
headers=session.headers
)
validate_response(response)

data = response.json()['data']

return PlacedComplexOrder(**data)

def get_order(self, session: Session, order_id: str) -> PlacedOrder:
"""
Gets an order with the given ID.
:param session: the session to use for the request.
:return: an :class:`Order` object corresponding to the given ID.
:return: a :class:`PlacedOrder` object corresponding to the given ID
"""
response = requests.get(
f'{session.base_url}/accounts/{self.account_number}/orders/{order_id}', # noqa: E501
(f'{session.base_url}/accounts/{self.account_number}/orders'
f'/{order_id}'),
headers=session.headers
)
validate_response(response)
Expand All @@ -857,6 +891,20 @@ def get_order(self, session: Session, order_id: str) -> PlacedOrder:

return PlacedOrder(**data)

def delete_complex_order(self, session: Session, order_id: str) -> None:
"""
Delete a complex order by ID.
:param session: the session to use for the request.
:param order_id: the ID of the order to delete.
"""
response = requests.delete(
(f'{session.base_url}/accounts/{self.account_number}/complex-'
f'orders/{order_id}'),
headers=session.headers
)
validate_response(response)

def delete_order(self, session: Session, order_id: str) -> None:
"""
Delete an order by ID.
Expand All @@ -865,7 +913,8 @@ def delete_order(self, session: Session, order_id: str) -> None:
:param order_id: the ID of the order to delete.
"""
response = requests.delete(
f'{session.base_url}/accounts/{self.account_number}/orders/{order_id}', # noqa: E501
(f'{session.base_url}/accounts/{self.account_number}/orders'
f'/{order_id}'),
headers=session.headers
)
validate_response(response)
Expand Down Expand Up @@ -979,10 +1028,10 @@ def place_order(

return PlacedOrderResponse(**data)

def place_complex_orders(
def place_complex_order(
self,
session: Session,
order: NewOCOOrder,
order: NewComplexOrder,
dry_run=True
) -> PlacedOrderResponse:
"""
Expand All @@ -995,14 +1044,13 @@ def place_complex_orders(
:return: a :class:`PlacedOrderResponse` object for the placed order.
"""
url = (f'{session.base_url}/accounts/{self.account_number}'
f'/complex-orders')
'/complex-orders')
if dry_run:
url += '/dry-run'
headers = session.headers
# required because we're passing the JSON as a string
headers['Content-Type'] = 'application/json'
json = order.json(exclude_none=True, by_alias=True)
json = json.replace('complex-order-type', 'type')

response = requests.post(url, headers=session.headers, data=json)
validate_response(response)
Expand Down Expand Up @@ -1031,7 +1079,8 @@ def replace_order(
# required because we're passing the JSON as a string
headers['Content-Type'] = 'application/json'
response = requests.put(
f'{session.base_url}/accounts/{self.account_number}/orders/{old_order_id}', # noqa: E501
(f'{session.base_url}/accounts/{self.account_number}/orders'
f'/{old_order_id}'),
headers=headers,
data=new_order.json(
exclude={'legs'},
Expand Down
6 changes: 4 additions & 2 deletions tastytrade/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,8 @@ def get_future_product(
"""
code = code.replace('/', '')
response = requests.get(
f'{session.base_url}/instruments/future-products/{exchange}/{code}', # noqa: E501
(f'{session.base_url}/instruments/future-products/{exchange}/'
f'{code}'),
headers=session.headers
)
validate_response(response)
Expand Down Expand Up @@ -768,7 +769,8 @@ def get_future_option_product(
"""
root_symbol = root_symbol.replace('/', '')
response = requests.get(
f'{session.base_url}/instruments/future-option-products/{exchange}/{root_symbol}', # noqa: E501
(f'{session.base_url}/instruments/future-option-products/'
f'{exchange}/{root_symbol}'),
headers=session.headers
)
validate_response(response)
Expand Down
6 changes: 4 additions & 2 deletions tastytrade/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def get_dividends(
"""
symbol = symbol.replace('/', '%2F')
response = requests.get(
f'{session.base_url}/market-metrics/historic-corporate-events/dividends/{symbol}', # noqa: E501
(f'{session.base_url}/market-metrics/historic-corporate-events/'
f'dividends/{symbol}'),
headers=session.headers
)
validate_response(response)
Expand All @@ -155,7 +156,8 @@ def get_earnings(
symbol = symbol.replace('/', '%2F')
params: Dict[str, Any] = {'start-date': start_date}
response = requests.get(
f'{session.base_url}/market-metrics/historic-corporate-events/earnings-reports/{symbol}', # noqa: E501
(f'{session.base_url}/market-metrics/historic-corporate-events/'
f'earnings-reports/{symbol}'),
headers=session.headers,
params=params
)
Expand Down
40 changes: 22 additions & 18 deletions tastytrade/order.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import date, datetime
from decimal import Decimal
from enum import Enum
from typing import List, Optional
from typing import Dict, List, Optional

from tastytrade import VERSION
from tastytrade.utils import TastytradeJsonDataclass
Expand Down Expand Up @@ -111,7 +111,7 @@ class FillInfo(TastytradeJsonDataclass):
quantity: Decimal
fill_price: Decimal
filled_at: datetime
destination_venue: str
destination_venue: Optional[str] = None
ext_group_fill_id: Optional[str] = None
ext_exec_id: Optional[str] = None

Expand All @@ -126,7 +126,7 @@ class Leg(TastytradeJsonDataclass):
instrument_type: InstrumentType
symbol: str
action: OrderAction
quantity: Decimal
quantity: Optional[Decimal] = None
remaining_quantity: Optional[Decimal] = None
fills: Optional[List[FillInfo]] = None

Expand Down Expand Up @@ -229,22 +229,20 @@ class NewOrder(TastytradeJsonDataclass):
rules: Optional[OrderRule] = None


class NewOCOOrder(TastytradeJsonDataclass):
class NewComplexOrder(TastytradeJsonDataclass):
"""
Dataclass containing information about an OCO order.
Dataclass containing information about a new OTOCO order.
Also used for modifying existing orders.
"""
complex_order_type: ComplexOrderType
source: str = f'tastyware/tastytrade:v{VERSION}'
orders: List[NewOrder]
source: str = f'tastyware/tastytrade:v{VERSION}'
trigger_order: Optional[NewOrder] = None
type: ComplexOrderType = ComplexOrderType.OCO


class NewOTOCOOrder(NewOCOOrder):
"""
Dataclass containing information about a new OTOCO order.
Also used for modifying existing orders.
"""
trigger_order: NewOrder
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.trigger_order is not None:
self.type = ComplexOrderType.OTOCO


class PlacedOrder(TastytradeJsonDataclass):
Expand All @@ -255,7 +253,6 @@ class PlacedOrder(TastytradeJsonDataclass):
account_number: str
time_in_force: OrderTimeInForce
order_type: OrderType
size: str
underlying_symbol: str
underlying_instrument_type: InstrumentType
status: OrderStatus
Expand All @@ -264,6 +261,7 @@ class PlacedOrder(TastytradeJsonDataclass):
edited: bool
updated_at: datetime
legs: List[Leg]
size: Optional[str] = None
id: Optional[str] = None
price: Optional[Decimal] = None
price_effect: Optional[PriceEffect] = None
Expand Down Expand Up @@ -291,14 +289,20 @@ class PlacedOrder(TastytradeJsonDataclass):
order_rule: Optional[OrderRule] = None


class ComplexOrder(TastytradeJsonDataclass):
class PlacedComplexOrder(TastytradeJsonDataclass):
"""
Dataclass containing information about a complex order.
Dataclass containing information about an already placed complex order.
"""
account_number: str
type: str
orders: List[PlacedOrder]
id: Optional[str] = None
trigger_order: Optional[PlacedOrder] = None
terminal_at: Optional[str] = None
ratio_price_threshold: Optional[Decimal] = None
ratio_price_comparator: Optional[str] = None
ratio_price_is_threshold_based_on_notional: Optional[bool] = None
related_orders: Optional[List[Dict[str, str]]] = None


class BuyingPowerEffect(TastytradeJsonDataclass):
Expand Down Expand Up @@ -344,7 +348,7 @@ class PlacedOrderResponse(TastytradeJsonDataclass):
buying_power_effect: BuyingPowerEffect
fee_calculation: FeeCalculation
order: Optional[PlacedOrder] = None
complex_order: Optional[ComplexOrder] = None
complex_order: Optional[PlacedComplexOrder] = None
warnings: Optional[List[Message]] = None
errors: Optional[List[Message]] = None

Expand Down
10 changes: 8 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@

from tastytrade import ProductionSession

CERT_USERNAME = 'tastyware'
CERT_PASSWORD = ':4s-S9/9L&Q~C]@v'


@pytest.fixture(scope='session')
def get_cert_credentials():
return CERT_USERNAME, CERT_PASSWORD


@pytest.fixture(scope='session')
def session():
username = os.environ.get('TT_USERNAME', None)
password = os.environ.get('TT_PASSWORD', None)

assert username is not None
assert password is not None

session = ProductionSession(username, password)
yield session

session.destroy()
Loading

0 comments on commit dc339ef

Please sign in to comment.