Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Graeme22 committed Oct 9, 2024
1 parent 96ae495 commit 094732d
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 94 deletions.
95 changes: 24 additions & 71 deletions tastytrade/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,9 @@ class Account(TastytradeJsonDataclass):
submitting_user_id: Optional[str] = None

@classmethod
async def a_get_accounts(cls, session: Session, include_closed=False) -> List["Account"]:
async def a_get_accounts(
cls, session: Session, include_closed=False
) -> List["Account"]:
"""
Gets all trading accounts associated with the Tastytrade user.
Expand Down Expand Up @@ -579,69 +581,6 @@ async def a_get_balance_snapshots(
params["page-offset"] += 1 # type: ignore
return snapshots

async def a_get_balance_snapshots(
self,
session: Session,
per_page: int = 250,
page_offset: Optional[int] = None,
currency: str = "USD",
end_date: Optional[date] = None,
start_date: Optional[date] = None,
snapshot_date: Optional[date] = None,
time_of_day: Literal["BOD", "EOD"] = "EOD",
) -> List[AccountBalanceSnapshot]:
"""
Returns a list of balance snapshots. This list will
just have a few snapshots if you don't pass a start
date; otherwise, it will be each day's balances in
the given range.
:param session: the session to use for the request.
:param currency: the currency to show balances in.
:param start_date: the starting date of the range.
:param end_date: the ending date of the range.
:param snapshot_date: the date of the snapshot to get.
:param time_of_day:
the time of day of the snapshots to get, either 'EOD' (End Of Day) or 'BOD' (Beginning Of Day).
"""
paginate = False
if page_offset is None:
page_offset = 0
paginate = True
params = {
"per-page": per_page,
"page-offset": page_offset,
"currency": currency,
"end-date": end_date,
"start-date": start_date,
"snapshot-date": snapshot_date,
"time-of-day": time_of_day,
}
snapshots = []
while True:
response = await session.async_client.get(
f"/accounts/{self.account_number}/balance-snapshots",
params={
k: v # type: ignore
for k, v in params.items()
if v is not None
},
)
validate_response(response)
json = response.json()
snapshots.extend(
[AccountBalanceSnapshot(**i) for i in json["data"]["items"]]
)
# handle pagination
pagination = json["pagination"]
if (
pagination["page-offset"] >= pagination["total-pages"] - 1
or not paginate
):
break
params["page-offset"] += 1 # type: ignore
return snapshots

def get_balance_snapshots(
self,
session: Session,
Expand Down Expand Up @@ -992,7 +931,9 @@ async def a_get_transaction(self, session: Session, id: int) -> Transaction:
:param session: the session to use for the request.
:param id: the ID of the transaction to fetch.
"""
data = await session._a_get(f"/accounts/{self.account_number}/transactions/{id}")
data = await session._a_get(
f"/accounts/{self.account_number}/transactions/{id}"
)
return Transaction(**data)

def get_transaction(self, session: Session, id: int) -> Transaction:
Expand Down Expand Up @@ -1164,7 +1105,9 @@ async def a_get_margin_requirements(self, session: Session) -> MarginReport:
:param session: the session to use for the request.
"""
data = await session._a_get(f"/margin/accounts/{self.account_number}/requirements")
data = await session._a_get(
f"/margin/accounts/{self.account_number}/requirements"
)
return MarginReport(**data)

def get_margin_requirements(self, session: Session) -> MarginReport:
Expand Down Expand Up @@ -1195,13 +1138,17 @@ def get_live_orders(self, session: Session) -> List[PlacedOrder]:
data = session._get(f"/accounts/{self.account_number}/orders/live")
return [PlacedOrder(**i) for i in data["items"]]

async def a_get_live_complex_orders(self, session: Session) -> List[PlacedComplexOrder]:
async def a_get_live_complex_orders(
self, session: Session
) -> List[PlacedComplexOrder]:
"""
Get complex orders placed today for the account.
:param session: the session to use for the request.
"""
data = await session._a_get(f"/accounts/{self.account_number}/complex-orders/live")
data = await session._a_get(
f"/accounts/{self.account_number}/complex-orders/live"
)
return [PlacedComplexOrder(**i) for i in data["items"]]

def get_live_complex_orders(self, session: Session) -> List[PlacedComplexOrder]:
Expand All @@ -1213,7 +1160,9 @@ def get_live_complex_orders(self, session: Session) -> List[PlacedComplexOrder]:
data = session._get(f"/accounts/{self.account_number}/complex-orders/live")
return [PlacedComplexOrder(**i) for i in data["items"]]

async def a_get_complex_order(self, session: Session, order_id: int) -> PlacedComplexOrder:
async def a_get_complex_order(
self, session: Session, order_id: int
) -> PlacedComplexOrder:
"""
Gets a complex order with the given ID.
Expand Down Expand Up @@ -1244,7 +1193,9 @@ async def a_get_order(self, session: Session, order_id: int) -> PlacedOrder:
:param session: the session to use for the request.
:param order_id: the ID of the order to fetch.
"""
data = await session._a_get(f"/accounts/{self.account_number}/orders/{order_id}")
data = await session._a_get(
f"/accounts/{self.account_number}/orders/{order_id}"
)
return PlacedOrder(**data)

def get_order(self, session: Session, order_id: int) -> PlacedOrder:
Expand All @@ -1264,7 +1215,9 @@ async def a_delete_complex_order(self, session: Session, order_id: int) -> None:
:param session: the session to use for the request.
:param order_id: the ID of the order to delete.
"""
await session._a_delete(f"/accounts/{self.account_number}/complex-orders/{order_id}")
await session._a_delete(
f"/accounts/{self.account_number}/complex-orders/{order_id}"
)

def delete_complex_order(self, session: Session, order_id: int) -> None:
"""
Expand Down
16 changes: 12 additions & 4 deletions tastytrade/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,9 @@ def get_cryptocurrencies(
return [cls(**i) for i in data["items"]]

@classmethod
async def a_get_cryptocurrency(cls, session: Session, symbol: str) -> "Cryptocurrency":
async def a_get_cryptocurrency(
cls, session: Session, symbol: str
) -> "Cryptocurrency":
"""
Returns a Cryptocurrency object from the given symbol.
Expand Down Expand Up @@ -557,7 +559,9 @@ async def a_get_option(
"""
symbol = symbol.replace("/", "%2F")
params = {"active": active} if active is not None else None
data = await session._a_get(f"/instruments/equity-options/{symbol}", params=params)
data = await session._a_get(
f"/instruments/equity-options/{symbol}", params=params
)
return cls(**data)

@classmethod
Expand Down Expand Up @@ -1133,7 +1137,9 @@ class NestedFutureOptionChain(TastytradeJsonDataclass):
option_chains: List[NestedFutureOptionSubchain]

@classmethod
async def a_get_chain(cls, session: Session, symbol: str) -> "NestedFutureOptionChain":
async def a_get_chain(
cls, session: Session, symbol: str
) -> "NestedFutureOptionChain":
"""
Gets the futures option chain for the given symbol in nested format.
Expand Down Expand Up @@ -1226,7 +1232,9 @@ def get_warrant(cls, session: Session, symbol: str) -> "Warrant":
FutureProduct.model_rebuild()


async def a_get_quantity_decimal_precisions(session: Session) -> List[QuantityDecimalPrecision]:
async def a_get_quantity_decimal_precisions(
session: Session,
) -> List[QuantityDecimalPrecision]:
"""
Returns a list of QuantityDecimalPrecision objects for different
types of instruments.
Expand Down
12 changes: 9 additions & 3 deletions tastytrade/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,18 @@ class MarketMetricInfo(TastytradeJsonDataclass):
borrow_rate: Optional[Decimal] = None


async def a_get_market_metrics(session: Session, symbols: List[str]) -> List[MarketMetricInfo]:
async def a_get_market_metrics(
session: Session, symbols: List[str]
) -> List[MarketMetricInfo]:
"""
Retrieves market metrics for the given symbols.
:param session: active user session to use
:param symbols: list of symbols to retrieve metrics for
"""
data = await session._a_get("/market-metrics", params={"symbols": ",".join(symbols)})
data = await session._a_get(
"/market-metrics", params={"symbols": ",".join(symbols)}
)
return [MarketMetricInfo(**i) for i in data["items"]]


Expand Down Expand Up @@ -163,7 +167,9 @@ def get_dividends(session: Session, symbol: str) -> List[DividendInfo]:
return [DividendInfo(**i) for i in data["items"]]


async def a_get_earnings(session: Session, symbol: str, start_date: date) -> List[EarningsInfo]:
async def a_get_earnings(
session: Session, symbol: str, start_date: date
) -> List[EarningsInfo]:
"""
Retrieves earnings information for the given symbol.
Expand Down
1 change: 1 addition & 0 deletions tastytrade/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ class User(TastytradeJsonDataclass):
"""
Dataclass containing information about a Tastytrade user.
"""

email: str
external_id: str
is_confirmed: bool
Expand Down
4 changes: 3 additions & 1 deletion tastytrade/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,9 @@ async def subscribe_user_messages(self, session: Session) -> None:
"""
Subscribes to user-level messages, e.g. new account creation.
"""
await self._subscribe(SubscriptionType.USER_MESSAGE, value=session.user.external_id)
await self._subscribe(
SubscriptionType.USER_MESSAGE, value=session.user.external_id
)

async def _heartbeat(self) -> None:
"""
Expand Down
12 changes: 9 additions & 3 deletions tastytrade/watchlists.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def get_pairs_watchlists(cls, session: Session) -> List["PairsWatchlist"]:
return [cls(**i) for i in data["items"]]

@classmethod
async def a_get_pairs_watchlist(cls, session: Session, name: str) -> "PairsWatchlist":
async def a_get_pairs_watchlist(
cls, session: Session, name: str
) -> "PairsWatchlist":
"""
Fetches a Tastytrade public pairs watchlist by name.
Expand Down Expand Up @@ -91,7 +93,9 @@ async def a_get_public_watchlists(
:param session: the session to use for the request.
:param counts_only: whether to only fetch the counts of the watchlists.
"""
data = await session._a_get("/public-watchlists", params={"counts-only": counts_only})
data = await session._a_get(
"/public-watchlists", params={"counts-only": counts_only}
)
return [cls(**i) for i in data["items"]]

@classmethod
Expand Down Expand Up @@ -213,7 +217,9 @@ async def a_update_private_watchlist(self, session: Session) -> None:
:param session: the session to use for the request.
"""
await session._a_put(f"/watchlists/{self.name}", json=self.model_dump(by_alias=True))
await session._a_put(
f"/watchlists/{self.name}", json=self.model_dump(by_alias=True)
)

def update_private_watchlist(self, session: Session) -> None:
"""
Expand Down
22 changes: 13 additions & 9 deletions tests/test_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,24 +182,24 @@ def test_replace_and_delete_order(session, account, new_order, placed_order):

def test_place_oco_order(session, account):
# account must have a share of F for this to work
symbol = Equity.get_equity(session, 'F')
symbol = Equity.get_equity(session, "F")
closing = symbol.build_leg(Decimal(1), OrderAction.SELL_TO_CLOSE)
oco = NewComplexOrder(
orders=[
NewOrder(
time_in_force=OrderTimeInForce.GTC,
order_type=OrderType.LIMIT,
legs=[closing],
price=Decimal('100'), # will never fill
price_effect=PriceEffect.CREDIT
price=Decimal("100"), # will never fill
price_effect=PriceEffect.CREDIT,
),
NewOrder(
time_in_force=OrderTimeInForce.GTC,
order_type=OrderType.STOP,
legs=[closing],
stop_trigger=Decimal('1.5'), # will never fill
price_effect=PriceEffect.CREDIT
)
stop_trigger=Decimal("1.5"), # will never fill
price_effect=PriceEffect.CREDIT,
),
]
)
resp2 = account.place_complex_order(session, oco, dry_run=False)
Expand Down Expand Up @@ -256,14 +256,18 @@ async def placed_order_async(session, account, new_order):

async def test_get_order_async(session, account, placed_order_async):
sleep(3)
placed = await account.a_get_order(session, placed_order_async.id)
placed = await account.a_get_order(session, placed_order_async.id)
assert placed.id == placed_order_async.id


async def test_replace_and_delete_order_async(session, account, new_order, placed_order_async):
async def test_replace_and_delete_order_async(
session, account, new_order, placed_order_async
):
modified_order = new_order.model_copy()
modified_order.price = Decimal("2.01")
replaced = await account.a_replace_order(session, placed_order_async.id, modified_order)
replaced = await account.a_replace_order(
session, placed_order_async.id, modified_order
)
sleep(3)
await account.a_delete_order(session, replaced.id)

Expand Down
18 changes: 16 additions & 2 deletions tests/test_dxfeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def test_parse_infinities_and_nan():
quote_data = ['SPY', 0, 0, 0, 0, 'Q', 0, 'Q', '-Infinity', 'Infinity', 'NaN', 'NaN']
quote_data = ["SPY", 0, 0, 0, 0, "Q", 0, "Q", "-Infinity", "Infinity", "NaN", "NaN"]
quote = Quote.from_stream(quote_data)[0]
quote = cast(Quote, quote)
assert quote.bidPrice is None
Expand All @@ -17,5 +17,19 @@ def test_parse_infinities_and_nan():

def test_malformatted_data():
with pytest.raises(TastytradeError):
quote_data = ['SPY', 0, 0, 0, 0, 'Q', 0, 'Q', 576.88, 576.9, 230.0, 300.0, 'extra']
quote_data = [
"SPY",
0,
0,
0,
0,
"Q",
0,
"Q",
576.88,
576.9,
230.0,
300.0,
"extra",
]
_ = Quote.from_stream(quote_data)
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 094732d

Please sign in to comment.