diff --git a/example_publisher/publisher.py b/example_publisher/publisher.py index 0063e34..6dd5662 100644 --- a/example_publisher/publisher.py +++ b/example_publisher/publisher.py @@ -32,12 +32,20 @@ def __init__(self, config: Config) -> None: if not getattr(self.config, self.config.provider_engine): raise ValueError(f"Missing {self.config.provider_engine} config") - if self.config.provider_engine == "coin_gecko": + if ( + self.config.provider_engine == "coin_gecko" + and config.coin_gecko is not None + ): self.provider = CoinGecko(config.coin_gecko) - elif self.config.provider_engine == "pyth_replicator": + elif ( + self.config.provider_engine == "pyth_replicator" + and config.pyth_replicator is not None + ): self.provider: Provider = PythReplicator(config.pyth_replicator) else: - raise ValueError(f"Unknown provider {self.config.provider_engine}") + raise ValueError( + f"Unknown provider {self.config.provider_engine}, possibly the env variables is not set." + ) self.pythd: Pythd = Pythd( address=config.pythd.endpoint, diff --git a/example_publisher/pythd.py b/example_publisher/pythd.py index 5c8a7de..e9ee423 100644 --- a/example_publisher/pythd.py +++ b/example_publisher/pythd.py @@ -2,8 +2,8 @@ from dataclasses import dataclass, field import sys import traceback -from dataclasses_json import config, dataclass_json -from typing import Awaitable, Callable, List +from dataclasses_json import config, DataClassJsonMixin +from typing import Any, Callable, Coroutine, List from structlog import get_logger from jsonrpc_websocket import Server @@ -15,22 +15,19 @@ TRADING = "trading" -@dataclass_json @dataclass -class Price: +class Price(DataClassJsonMixin): account: str exponent: int = field(metadata=config(field_name="price_exponent")) -@dataclass_json @dataclass -class Metadata: +class Metadata(DataClassJsonMixin): symbol: str -@dataclass_json @dataclass -class Product: +class Product(DataClassJsonMixin): account: str metadata: Metadata = field(metadata=config(field_name="attr_dict")) prices: List[Price] = field(metadata=config(field_name="price")) @@ -38,14 +35,16 @@ class Product: class Pythd: def __init__( - self, address: str, on_notify_price_sched: Callable[[SubscriptionId], Awaitable] + self, + address: str, + on_notify_price_sched: Callable[[SubscriptionId], Coroutine[Any, Any, None]], ) -> None: self.address = address - self.server: Server = None + self.server: Server self.on_notify_price_sched = on_notify_price_sched self._tasks = set() - async def connect(self) -> Server: + async def connect(self): self.server = Server(self.address) self.server.notify_price_sched = self._notify_price_sched task = await self.server.ws_connect() diff --git a/example_publisher/tests/test_pyth_replicator_manual_aggregate.py b/example_publisher/tests/test_pyth_replicator_manual_aggregate.py index b0d2a9c..a003744 100644 --- a/example_publisher/tests/test_pyth_replicator_manual_aggregate.py +++ b/example_publisher/tests/test_pyth_replicator_manual_aggregate.py @@ -1,9 +1,10 @@ import random +from typing import List from example_publisher.providers.pyth_replicator import manual_aggregate def test_manual_aggregate_works(): - prices = [1, 2, 3, 4, 5, 6, 8, 10, 12, 14] + prices: List[float] = [1, 2, 3, 4, 5, 6, 8, 10, 12, 14] random.shuffle(prices) agg_price, agg_confidence_interval = manual_aggregate(prices)