Skip to content

Commit

Permalink
Fix more type issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
calina-c committed Dec 20, 2023
1 parent baeceac commit de9e64b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
8 changes: 4 additions & 4 deletions pdr_backend/ppss/data_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import numpy as np
from enforce_typing import enforce_types

from pdr_backend.models.feed import Feed
from pdr_backend.util.feedstr import Feeds, verify_feeds_strs
from pdr_backend.models.feed import Feed as FeedMixin
from pdr_backend.util.feedstr import Feed, Feeds, verify_feeds_strs
from pdr_backend.util.listutil import remove_dups
from pdr_backend.util.pairstr import unpack_pair_str
from pdr_backend.util.timeframestr import Timeframe, verify_timeframe_str
Expand Down Expand Up @@ -123,7 +123,7 @@ def filter_feeds_s(self) -> str:
return f"{self.timeframe} {self.predict_feeds_strs}"

@enforce_types
def filter_feeds(self, cand_feeds: Dict[str, Feed]) -> Dict[str, Feed]:
def filter_feeds(self, cand_feeds: Dict[str, FeedMixin]) -> Dict[str, FeedMixin]:
"""
@description
Filter to feeds that fit self.predict_feeds'
Expand All @@ -139,7 +139,7 @@ def filter_feeds(self, cand_feeds: Dict[str, Feed]) -> Dict[str, Feed]:
(self.timeframe, feed.exchange, feed.pair) for feed in self.predict_feeds
]

final_feeds: Dict[str, Feed] = {}
final_feeds: Dict[str, FeedMixin] = {}
found_tups = set() # to avoid duplicates
for feed in cand_feeds.values():
assert isinstance(feed, Feed)
Expand Down
14 changes: 7 additions & 7 deletions pdr_backend/util/feedstr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Complementary to models/feed.py which models a prediction feed contract.
"""

from typing import List, Tuple
from typing import List, Set

from enforce_typing import enforce_types

Expand Down Expand Up @@ -74,7 +74,7 @@ def from_str(feed_str: str, do_verify: bool = True) -> "Feed":

class Feeds(List[Feed]):
@staticmethod
def from_strs(feeds_strs: str, do_verify: bool = True) -> "Feeds":
def from_strs(feeds_strs: List[str], do_verify: bool = True) -> "Feeds":
if do_verify:
if not feeds_strs:
raise ValueError(feeds_strs)
Expand All @@ -98,20 +98,20 @@ def __eq__(self, other):
return len(intersection) == len(self) and len(intersection) == len(other)

@property
def pairs(self) -> List[str]:
def pairs(self) -> Set[str]:
return set(feed.pair for feed in self)

@property
def exchanges(self) -> List[str]:
def exchanges(self) -> Set[str]:
return set(feed.exchange for feed in self)

@property
def signals(self) -> List[str]:
def signals(self) -> Set[str]:
return set(feed.signal for feed in self)


@enforce_types
def _unpack_feeds_str(feeds_str: str) -> List[Tuple[str, str, str]]:
def _unpack_feeds_str(feeds_str: str) -> List[Feed]:
"""
@description
Unpack a *single* feeds str. It can have >1 feeds of course.
Expand Down Expand Up @@ -171,7 +171,7 @@ def verify_feeds_str(feeds_str: str):
@argument
feeds_str -- e.g. "binance oh ADA/USDT BTC-USDT"
"""
Feeds.from_str(feeds_str, do_verify=True)
Feeds.from_str(feeds_str)


@enforce_types
Expand Down

0 comments on commit de9e64b

Please sign in to comment.