From de9e64b182d596935883095da6062764dab225b7 Mon Sep 17 00:00:00 2001 From: Calina Cenan Date: Wed, 20 Dec 2023 07:57:37 +0000 Subject: [PATCH] Fix more type issues. --- pdr_backend/ppss/data_pp.py | 8 ++++---- pdr_backend/util/feedstr.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pdr_backend/ppss/data_pp.py b/pdr_backend/ppss/data_pp.py index ccca9f812..b5ed04d07 100644 --- a/pdr_backend/ppss/data_pp.py +++ b/pdr_backend/ppss/data_pp.py @@ -3,8 +3,8 @@ import numpy as np from enforce_typing import enforce_types -from pdr_backend.models.feed import Feed -from pdr_backend.util.feedstr import Feeds, verify_feeds_strs +from pdr_backend.models.feed import Feed as FeedMixin +from pdr_backend.util.feedstr import Feed, Feeds, verify_feeds_strs from pdr_backend.util.listutil import remove_dups from pdr_backend.util.pairstr import unpack_pair_str from pdr_backend.util.timeframestr import Timeframe, verify_timeframe_str @@ -123,7 +123,7 @@ def filter_feeds_s(self) -> str: return f"{self.timeframe} {self.predict_feeds_strs}" @enforce_types - def filter_feeds(self, cand_feeds: Dict[str, Feed]) -> Dict[str, Feed]: + def filter_feeds(self, cand_feeds: Dict[str, FeedMixin]) -> Dict[str, FeedMixin]: """ @description Filter to feeds that fit self.predict_feeds' @@ -139,7 +139,7 @@ def filter_feeds(self, cand_feeds: Dict[str, Feed]) -> Dict[str, Feed]: (self.timeframe, feed.exchange, feed.pair) for feed in self.predict_feeds ] - final_feeds: Dict[str, Feed] = {} + final_feeds: Dict[str, FeedMixin] = {} found_tups = set() # to avoid duplicates for feed in cand_feeds.values(): assert isinstance(feed, Feed) diff --git a/pdr_backend/util/feedstr.py b/pdr_backend/util/feedstr.py index 82851ce99..0ed36e5be 100644 --- a/pdr_backend/util/feedstr.py +++ b/pdr_backend/util/feedstr.py @@ -3,7 +3,7 @@ Complementary to models/feed.py which models a prediction feed contract. """ -from typing import List, Tuple +from typing import List, Set from enforce_typing import enforce_types @@ -74,7 +74,7 @@ def from_str(feed_str: str, do_verify: bool = True) -> "Feed": class Feeds(List[Feed]): @staticmethod - def from_strs(feeds_strs: str, do_verify: bool = True) -> "Feeds": + def from_strs(feeds_strs: List[str], do_verify: bool = True) -> "Feeds": if do_verify: if not feeds_strs: raise ValueError(feeds_strs) @@ -98,20 +98,20 @@ def __eq__(self, other): return len(intersection) == len(self) and len(intersection) == len(other) @property - def pairs(self) -> List[str]: + def pairs(self) -> Set[str]: return set(feed.pair for feed in self) @property - def exchanges(self) -> List[str]: + def exchanges(self) -> Set[str]: return set(feed.exchange for feed in self) @property - def signals(self) -> List[str]: + def signals(self) -> Set[str]: return set(feed.signal for feed in self) @enforce_types -def _unpack_feeds_str(feeds_str: str) -> List[Tuple[str, str, str]]: +def _unpack_feeds_str(feeds_str: str) -> List[Feed]: """ @description Unpack a *single* feeds str. It can have >1 feeds of course. @@ -171,7 +171,7 @@ def verify_feeds_str(feeds_str: str): @argument feeds_str -- e.g. "binance oh ADA/USDT BTC-USDT" """ - Feeds.from_str(feeds_str, do_verify=True) + Feeds.from_str(feeds_str) @enforce_types