diff --git a/pdr_backend/accuracy/app.py b/pdr_backend/accuracy/app.py index 68c5fabce..cc8873567 100644 --- a/pdr_backend/accuracy/app.py +++ b/pdr_backend/accuracy/app.py @@ -77,7 +77,10 @@ def save_statistics_to_file(): "0x4ac2e51f9b1b0ca9e000dfe6032b24639b172703", network_param ) - contract_information = fetch_contract_id_and_spe(contract_addresses, network_param) + contracts_list_unfiltered = fetch_contract_id_and_spe( + contract_addresses, + network_param, + ) while True: try: @@ -85,13 +88,13 @@ def save_statistics_to_file(): for statistic_type in statistic_types: seconds_per_epoch = statistic_type["seconds_per_epoch"] - contracts = list( + contracts_list = list( filter( lambda item, spe=seconds_per_epoch: int( item["seconds_per_epoch"] ) == spe, - contract_information, + contracts_list_unfiltered, ) ) @@ -99,10 +102,14 @@ def save_statistics_to_file(): statistic_type["alias"] ) - contract_ids = [contract["id"] for contract in contracts] - # Get statistics for all contracts + contract_ids = [contract_item["ID"] for contract_item in contracts_list] + statistics = calculate_statistics_for_all_assets( - contract_ids, contracts, start_ts_param, end_ts_param, network_param + contract_ids, + contracts_list, + start_ts_param, + end_ts_param, + network_param, ) output.append( diff --git a/pdr_backend/data_eng/gql_data_factory.py b/pdr_backend/data_eng/gql_data_factory.py new file mode 100644 index 000000000..f84b25085 --- /dev/null +++ b/pdr_backend/data_eng/gql_data_factory.py @@ -0,0 +1,227 @@ +import os +from typing import Dict, Callable + +from enforce_typing import enforce_types +import polars as pl + +from pdr_backend.data_eng.plutil import ( + has_data, + newest_ut, +) +from pdr_backend.data_eng.table_pdr_predictions import ( + predictions_schema, + get_pdr_predictions_df, +) +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.util.networkutil import get_sapphire_postfix +from pdr_backend.util.subgraph_predictions import get_all_contract_ids_by_owner +from pdr_backend.util.timeutil import pretty_timestr, current_ut + + +@enforce_types +class GQLDataFactory: + """ + Roles: + - From each GQL API, fill >=1 gql_dfs -> parquet files data lake + - From gql_dfs, calculate other dfs and stats + - All timestamps, after fetching, are transformed into milliseconds wherever appropriate + + Finally: + - "timestamp" values are ut: int is unix time, UTC, in ms (not s) + - "datetime" values ares python datetime.datetime, UTC + """ + + def __init__(self, ppss: PPSS): + self.ppss = ppss + + # filter by feed contract address + network = get_sapphire_postfix(ppss.web3_pp.network) + contract_list = get_all_contract_ids_by_owner( + owner_address=self.ppss.web3_pp.owner_addrs, + network=network, + ) + contract_list = [f.lower() for f in contract_list] + + # configure all tables that will be recorded onto lake + self.record_config = { + "pdr_predictions": { + "fetch_fn": get_pdr_predictions_df, + "schema": predictions_schema, + "config": { + "contract_list": contract_list, + }, + }, + } + + def get_gql_dfs(self) -> Dict[str, pl.DataFrame]: + """ + @description + Get historical dataframes across many feeds and timeframes. + + @return + predictions_df -- *polars* Dataframe. See class docstring + """ + print("Get predictions data across many feeds and timeframes.") + + # Ss_timestamp is calculated dynamically if ss.fin_timestr = "now". + # But, we don't want fin_timestamp changing as we gather data here. 
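+        # (fetching can take a while, so a moving "now" would shift the window mid-run)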
+ # To solve, for a given call to this method, we make a constant fin_ut + fin_ut = self.ppss.data_ss.fin_timestamp + + print(f" Data start: {pretty_timestr(self.ppss.data_ss.st_timestamp)}") + print(f" Data fin: {pretty_timestr(fin_ut)}") + + self._update(fin_ut) + gql_dfs = self._load_parquet(fin_ut) + + print("Get historical data across many subgraphs. Done.") + + # postconditions + assert len(gql_dfs.values()) > 0 + for df in gql_dfs.values(): + assert isinstance(df, pl.DataFrame) + + return gql_dfs + + def _update(self, fin_ut: int): + """ + @description + Iterate across all gql queries and update their parquet files: + - Predictoors + - Slots + - Claims + + Improve this by: + 1. Break out raw data from any transformed/cleaned data + 2. Integrate other queries and summaries + 3. Integrate config/pp if needed + @arguments + fin_ut -- a timestamp, in ms, in UTC + """ + + for k, record in self.record_config.items(): + filename = self._parquet_filename(k) + print(f" filename={filename}") + + st_ut = self._calc_start_ut(filename) + print(f" Aim to fetch data from start time: {pretty_timestr(st_ut)}") + if st_ut > min(current_ut(), fin_ut): + print(" Given start time, no data to gather. Exit.") + continue + + # to satisfy mypy, get an explicit function pointer + do_fetch: Callable[[str, int, int, Dict], pl.DataFrame] = record["fetch_fn"] + + # call the function + print(f" Fetching {k}") + df = do_fetch(self.ppss.web3_pp.network, st_ut, fin_ut, record["config"]) + + # postcondition + if len(df) > 0: + assert df.schema == record["schema"] + + # save to parquet + self._save_parquet(filename, df) + + def _calc_start_ut(self, filename: str) -> int: + """ + @description + Calculate start timestamp, reconciling whether file exists and where + its data starts. If file exists, you can only append to end. + + @arguments + filename - parquet file with data. May or may not exist. + + @return + start_ut - timestamp (ut) to start grabbing data for (in ms) + """ + if not os.path.exists(filename): + print(" No file exists yet, so will fetch all data") + return self.ppss.data_ss.st_timestamp + + print(" File already exists") + if not has_data(filename): + print(" File has no data, so delete it") + os.remove(filename) + return self.ppss.data_ss.st_timestamp + + file_utN = newest_ut(filename) + return file_utN + 1000 + + def _load_parquet(self, fin_ut: int) -> Dict[str, pl.DataFrame]: + """ + @arguments + fin_ut -- finish timestamp + + @return + gql_dfs -- dict of [gql_filename] : df + Where df has columns=GQL_COLS+"datetime", and index=timestamp + """ + print(" Load parquet.") + st_ut = self.ppss.data_ss.st_timestamp + + dfs: Dict[str, pl.DataFrame] = {} # [parquet_filename] : df + + for k, record in self.record_config.items(): + filename = self._parquet_filename(k) + print(f" filename={filename}") + + # load all data from file + df = pl.read_parquet(filename) + df = df.filter( + (pl.col("timestamp") >= st_ut) & (pl.col("timestamp") <= fin_ut) + ) + + # postcondition + assert df.schema == record["schema"] + dfs[k] = df + + return dfs + + def _parquet_filename(self, filename_str: str) -> str: + """ + @description + Computes the lake-path for the parquet file. + + @arguments + filename_str -- eg "subgraph_predictions" + + @return + parquet_filename -- name for parquet file. 
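+            Eg "parquet_data/subgraph_predictions.parquet" (under ppss.data_ss.parquet_dir)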
+ """ + basename = f"{filename_str}.parquet" + filename = os.path.join(self.ppss.data_ss.parquet_dir, basename) + return filename + + @enforce_types + def _save_parquet(self, filename: str, df: pl.DataFrame): + """write to parquet file + parquet only supports appending via the pyarrow engine + """ + + # precondition + assert "timestamp" in df.columns and df["timestamp"].dtype == pl.Int64 + assert len(df) > 0 + if len(df) > 1: + assert ( + df.head(1)["timestamp"].to_list()[0] + < df.tail(1)["timestamp"].to_list()[0] + ) + + if os.path.exists(filename): # "append" existing file + cur_df = pl.read_parquet(filename) + df = pl.concat([cur_df, df]) + + # check for duplicates and throw error if any found + duplicate_rows = df.filter(pl.struct("ID").is_duplicated()) + if len(duplicate_rows) > 0: + raise Exception( + f"Not saved. Duplicate rows found. {len(duplicate_rows)} rows: {duplicate_rows}" + ) + + df.write_parquet(filename) + n_new = df.shape[0] - cur_df.shape[0] + print(f" Just appended {n_new} df rows to file {filename}") + else: # write new file + df.write_parquet(filename) + print(f" Just saved df with {df.shape[0]} rows to new file {filename}") diff --git a/pdr_backend/data_eng/table_pdr_predictions.py b/pdr_backend/data_eng/table_pdr_predictions.py new file mode 100644 index 000000000..b131ce947 --- /dev/null +++ b/pdr_backend/data_eng/table_pdr_predictions.py @@ -0,0 +1,88 @@ +from typing import List, Dict +from enforce_typing import enforce_types + +import polars as pl +from polars import Utf8, Int64, Float64, Boolean + +from pdr_backend.util.networkutil import get_sapphire_postfix +from pdr_backend.util.subgraph_predictions import ( + fetch_filtered_predictions, + FilterMode, +) +from pdr_backend.util.timeutil import ms_to_seconds + +# RAW_PREDICTIONS_SCHEMA +predictions_schema = { + "ID": Utf8, + "pair": Utf8, + "timeframe": Utf8, + "prediction": Boolean, + "stake": Float64, + "trueval": Boolean, + "timestamp": Int64, + "source": Utf8, + "payout": Float64, + "slot": Int64, + "user": Utf8, +} + + +def _object_list_to_df(objects: List[object], schema: Dict) -> pl.DataFrame: + """ + @description + Convert list objects to a dataframe using their __dict__ structure. + """ + # Get all predictions into a dataframe + obj_dicts = [object.__dict__ for object in objects] + obj_df = pl.DataFrame(obj_dicts, schema=schema) + assert obj_df.schema == schema + + return obj_df + + +def _transform_timestamp_to_ms(df: pl.DataFrame) -> pl.DataFrame: + df = df.with_columns( + [ + pl.col("timestamp").mul(1000).alias("timestamp"), + ] + ) + return df + + +@enforce_types +def get_pdr_predictions_df( + network: str, st_ut: int, fin_ut: int, config: Dict +) -> pl.DataFrame: + """ + @description + Fetch raw predictions from predictoor subgraph + Update function for graphql query, returns raw data + + Transforms ts into ms as required for data factory + """ + network = get_sapphire_postfix(network) + + # fetch predictions + predictions = fetch_filtered_predictions( + ms_to_seconds(st_ut), + ms_to_seconds(fin_ut), + config["contract_list"], + network, + FilterMode.CONTRACT_TS, + payout_only=False, + trueval_only=False, + ) + + if len(predictions) == 0: + print(" No predictions to fetch. 
Exit.") + return pl.DataFrame() + + # convert predictions to df and transform timestamp into ms + predictions_df = _object_list_to_df(predictions, predictions_schema) + predictions_df = _transform_timestamp_to_ms(predictions_df) + + # cull any records outside of our time range and sort them by timestamp + predictions_df = predictions_df.filter( + pl.col("timestamp").is_between(st_ut, fin_ut) + ).sort("timestamp") + + return predictions_df diff --git a/pdr_backend/data_eng/test/conftest.py b/pdr_backend/data_eng/test/conftest.py index 794dba56a..e0bf3b306 100644 --- a/pdr_backend/data_eng/test/conftest.py +++ b/pdr_backend/data_eng/test/conftest.py @@ -1,10 +1,49 @@ +from typing import List + +from enforce_typing import enforce_types import pytest from pdr_backend.data_eng.model_factory import ModelFactory +from pdr_backend.models.prediction import ( + Prediction, + mock_prediction, +) from pdr_backend.ppss.model_ss import ModelSS +from pdr_backend.util.test_data import ( + sample_first_predictions, + sample_second_predictions, + sample_daily_predictions, +) @pytest.fixture(scope="session") def model_factory(): model_ss = ModelSS({"approach": "LIN"}) return ModelFactory(model_ss) + + +@enforce_types +@pytest.fixture(scope="session") +def _sample_first_predictions() -> List[Prediction]: + return [ + mock_prediction(prediction_tuple) + for prediction_tuple in sample_first_predictions + ] + + +@enforce_types +@pytest.fixture(scope="session") +def _sample_second_predictions() -> List[Prediction]: + return [ + mock_prediction(prediction_tuple) + for prediction_tuple in sample_second_predictions + ] + + +@enforce_types +@pytest.fixture(scope="session") +def _sample_daily_predictions() -> List[Prediction]: + return [ + mock_prediction(prediction_tuple) + for prediction_tuple in sample_daily_predictions + ] diff --git a/pdr_backend/data_eng/test/resources.py b/pdr_backend/data_eng/test/resources.py index c130ef7d8..82539e7bb 100644 --- a/pdr_backend/data_eng/test/resources.py +++ b/pdr_backend/data_eng/test/resources.py @@ -5,6 +5,7 @@ from pdr_backend.data_eng.constants import TOHLCV_COLS, TOHLCV_SCHEMA_PL +from pdr_backend.data_eng.gql_data_factory import GQLDataFactory from pdr_backend.data_eng.model_data_factory import ModelDataFactory from pdr_backend.data_eng.ohlcv_data_factory import OhlcvDataFactory from pdr_backend.data_eng.plutil import ( @@ -14,6 +15,8 @@ ) from pdr_backend.ppss.data_pp import DataPP from pdr_backend.ppss.data_ss import DataSS +from pdr_backend.ppss.ppss import mock_ppss +from pdr_backend.ppss.web3_pp import mock_web3_pp @enforce_types @@ -35,6 +38,15 @@ def _data_pp_ss_1feed(tmpdir, feed, st_timestr=None, fin_timestr=None): return pp, ss, ohlcv_data_factory, model_data_factory +@enforce_types +def _gql_data_factory(tmpdir, feed, st_timestr=None, fin_timestr=None): + network = "sapphire-mainnet" + ppss = mock_ppss("5m", [feed], network, str(tmpdir), st_timestr, fin_timestr) + ppss.web3_pp = mock_web3_pp(network) + gql_data_factory = GQLDataFactory(ppss) + return ppss, gql_data_factory + + @enforce_types def _data_pp(predict_feeds) -> DataPP: return DataPP( diff --git a/pdr_backend/data_eng/test/test_gql_data_factory.py b/pdr_backend/data_eng/test/test_gql_data_factory.py new file mode 100644 index 000000000..a96f1b748 --- /dev/null +++ b/pdr_backend/data_eng/test/test_gql_data_factory.py @@ -0,0 +1,280 @@ +from typing import List +from unittest.mock import patch + +from enforce_typing import enforce_types +import polars as pl + +from pdr_backend.data_eng.test.resources 
import _gql_data_factory +from pdr_backend.data_eng.table_pdr_predictions import predictions_schema +from pdr_backend.ppss.web3_pp import del_network_override +from pdr_backend.util.subgraph_predictions import FilterMode +from pdr_backend.util.timeutil import timestr_to_ut + +# ==================================================================== +# test parquet updating +pdr_predictions_record = "pdr_predictions" + + +@patch("pdr_backend.data_eng.table_pdr_predictions.fetch_filtered_predictions") +@patch("pdr_backend.data_eng.gql_data_factory.get_all_contract_ids_by_owner") +def test_update_gql1( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_predictions, + tmpdir, + _sample_daily_predictions, + monkeypatch, +): + del_network_override(monkeypatch) + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + _test_update_gql( + mock_fetch_filtered_predictions, + tmpdir, + _sample_daily_predictions, + "2023-11-02_0:00", + "2023-11-04_0:00", + n_preds=2, + ) + + +@patch("pdr_backend.data_eng.table_pdr_predictions.fetch_filtered_predictions") +@patch("pdr_backend.data_eng.gql_data_factory.get_all_contract_ids_by_owner") +def test_update_gql2( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_predictions, + tmpdir, + _sample_daily_predictions, + monkeypatch, +): + del_network_override(monkeypatch) + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + _test_update_gql( + mock_fetch_filtered_predictions, + tmpdir, + _sample_daily_predictions, + "2023-11-02_0:00", + "2023-11-06_0:00", + n_preds=4, + ) + + +@patch("pdr_backend.data_eng.table_pdr_predictions.fetch_filtered_predictions") +@patch("pdr_backend.data_eng.gql_data_factory.get_all_contract_ids_by_owner") +def test_update_gql3( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_predictions, + tmpdir, + _sample_daily_predictions, + monkeypatch, +): + del_network_override(monkeypatch) + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + _test_update_gql( + mock_fetch_filtered_predictions, + tmpdir, + _sample_daily_predictions, + "2023-11-01_0:00", + "2023-11-07_0:00", + n_preds=6, + ) + + +@patch("pdr_backend.data_eng.table_pdr_predictions.fetch_filtered_predictions") +@patch("pdr_backend.data_eng.gql_data_factory.get_all_contract_ids_by_owner") +def test_update_gql_iteratively( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_predictions, + tmpdir, + _sample_daily_predictions, + monkeypatch, +): + del_network_override(monkeypatch) + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + + iterations = [ + ("2023-11-02_0:00", "2023-11-04_0:00", 2), + ("2023-11-01_0:00", "2023-11-05_0:00", 3), + ("2023-11-02_0:00", "2023-11-07_0:00", 5), + ] + + for st_timestr, fin_timestr, n_preds in iterations: + _test_update_gql( + mock_fetch_filtered_predictions, + tmpdir, + _sample_daily_predictions, + st_timestr, + fin_timestr, + n_preds=n_preds, + ) + + +@enforce_types +def _test_update_gql( + mock_fetch_filtered_predictions, + tmpdir, + sample_predictions, + st_timestr: str, + fin_timestr: str, + n_preds, +): + """ + @arguments + n_preds -- expected # predictions. Typically int. 
If '>1K', expect >1000 + """ + + _, gql_data_factory = _gql_data_factory( + tmpdir, + "binanceus h ETH/USDT", + st_timestr, + fin_timestr, + ) + + # setup: filename + # everything will be inside the gql folder + filename = gql_data_factory._parquet_filename(pdr_predictions_record) + assert ".parquet" in filename + + fin_ut = timestr_to_ut(fin_timestr) + st_ut = gql_data_factory._calc_start_ut(filename) + + # calculate ms locally so we can filter raw Predictions + st_ut_sec = st_ut // 1000 + fin_ut_sec = fin_ut // 1000 + + # filter preds that will be returned from subgraph to client + target_preds = [ + x for x in sample_predictions if st_ut_sec <= x.timestamp <= fin_ut_sec + ] + mock_fetch_filtered_predictions.return_value = target_preds + + # work 1: update parquet + gql_data_factory._update(fin_ut) + + # assert params + mock_fetch_filtered_predictions.assert_called_with( + st_ut_sec, + fin_ut_sec, + ["0x123"], + "mainnet", + FilterMode.CONTRACT_TS, + payout_only=False, + trueval_only=False, + ) + + # read parquet and columns + def _preds_in_parquet(filename: str) -> List[int]: + df = pl.read_parquet(filename) + assert df.schema == predictions_schema + return df["timestamp"].to_list() + + # assert expected length of preds in parquet + preds: List[int] = _preds_in_parquet(filename) + if isinstance(n_preds, int): + assert len(preds) == n_preds + elif n_preds == ">1K": + assert len(preds) > 1000 + + # preds may not match start or end time + assert preds[0] != st_ut + assert preds[-1] != fin_ut + + # assert all target_preds are registered in parquet + target_preds_ts = [pred.__dict__["timestamp"] for pred in target_preds] + for target_pred in target_preds_ts: + assert target_pred * 1000 in preds + + +@patch("pdr_backend.data_eng.table_pdr_predictions.fetch_filtered_predictions") +@patch("pdr_backend.data_eng.gql_data_factory.get_all_contract_ids_by_owner") +def test_load_and_verify_schema( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_predictions, + tmpdir, + _sample_daily_predictions, + monkeypatch, +): + del_network_override(monkeypatch) + st_timestr = "2023-11-02_0:00" + fin_timestr = "2023-11-07_0:00" + + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + + _test_update_gql( + mock_fetch_filtered_predictions, + tmpdir, + _sample_daily_predictions, + st_timestr, + fin_timestr, + n_preds=5, + ) + + _, gql_data_factory = _gql_data_factory( + tmpdir, + "binanceus h ETH/USDT", + st_timestr, + fin_timestr, + ) + + fin_ut = timestr_to_ut(fin_timestr) + gql_dfs = gql_data_factory._load_parquet(fin_ut) + + assert len(gql_dfs) == 1 + assert len(gql_dfs[pdr_predictions_record]) == 5 + assert gql_dfs[pdr_predictions_record].schema == predictions_schema + + +# ==================================================================== +# test if appropriate calls are made + + +@enforce_types +@patch("pdr_backend.data_eng.gql_data_factory.get_all_contract_ids_by_owner") +@patch("pdr_backend.data_eng.gql_data_factory.GQLDataFactory._update") +@patch("pdr_backend.data_eng.gql_data_factory.GQLDataFactory._load_parquet") +def test_get_gql_dfs_calls( + mock_load_parquet, + mock_update, + mock_get_all_contract_ids_by_owner, + tmpdir, + _sample_daily_predictions, + monkeypatch, +): + """Test core DataFactory functions are being called""" + del_network_override(monkeypatch) + + st_timestr = "2023-11-02_0:00" + fin_timestr = "2023-11-07_0:00" + + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + + _, gql_data_factory = _gql_data_factory( + tmpdir, + "binanceus h ETH/USDT", + 
st_timestr, + fin_timestr, + ) + + # calculate ms locally so we can filter raw Predictions + st_ut = timestr_to_ut(st_timestr) + fin_ut = timestr_to_ut(fin_timestr) + st_ut_sec = st_ut // 1000 + fin_ut_sec = fin_ut // 1000 + + # mock_load_parquet should return the values from a simple code block + mock_load_parquet.return_value = { + pdr_predictions_record: pl.DataFrame( + [ + x.__dict__ + for x in _sample_daily_predictions + if st_ut_sec <= x.timestamp <= fin_ut_sec + ] + ).with_columns([pl.col("timestamp").mul(1000).alias("timestamp")]) + } + + # call and assert + gql_dfs = gql_data_factory.get_gql_dfs() + assert isinstance(gql_dfs, dict) + assert isinstance(gql_dfs[pdr_predictions_record], pl.DataFrame) + assert len(gql_dfs[pdr_predictions_record]) == 5 + + mock_update.assert_called_once() + mock_load_parquet.assert_called_once() diff --git a/pdr_backend/models/prediction.py b/pdr_backend/models/prediction.py index 073b8e80a..a0d6623cd 100644 --- a/pdr_backend/models/prediction.py +++ b/pdr_backend/models/prediction.py @@ -1,15 +1,17 @@ from typing import Union +from enforce_typing import enforce_types + +@enforce_types class Prediction: # pylint: disable=too-many-instance-attributes - # pylint: disable=redefined-builtin def __init__( self, - id: str, + ID: str, pair: str, timeframe: str, - prediction: Union[bool, None], + prediction: Union[bool, None], # prediction = subgraph.predicted_value stake: Union[float, None], trueval: Union[bool, None], timestamp: int, # timestamp == prediction submitted timestamp @@ -18,7 +20,7 @@ def __init__( slot: int, # slot/epoch timestamp user: str, ) -> None: - self.id = id + self.ID = ID self.pair = pair self.timeframe = timeframe self.prediction = prediction @@ -29,3 +31,38 @@ def __init__( self.payout = payout self.slot = slot self.user = user + + +# ========================================================================= +# utilities for testing + + +@enforce_types +def mock_prediction(prediction_tuple: tuple) -> Prediction: + ( + pair_str, + timeframe_str, + prediction, + stake, + trueval, + timestamp, + source, + payout, + slot, + user, + ) = prediction_tuple + + ID = f"{pair_str}-{timeframe_str}-{slot}-{user}" + return Prediction( + ID=ID, + pair=pair_str, + timeframe=timeframe_str, + prediction=prediction, + stake=stake, + trueval=trueval, + timestamp=timestamp, + source=source, + payout=payout, + slot=slot, + user=user, + ) diff --git a/pdr_backend/models/test/test_prediction.py b/pdr_backend/models/test/test_prediction.py new file mode 100644 index 000000000..2f4c858c2 --- /dev/null +++ b/pdr_backend/models/test/test_prediction.py @@ -0,0 +1,24 @@ +from pdr_backend.models.prediction import mock_prediction, Prediction + +from pdr_backend.util.test_data import ( + sample_first_predictions, +) + + +def test_mock_predictions(): + predictions = [ + mock_prediction(prediction_tuple) + for prediction_tuple in sample_first_predictions + ] + + assert len(predictions) == 2 + assert isinstance(predictions[0], Prediction) + assert isinstance(predictions[1], Prediction) + assert ( + predictions[0].ID + == "ADA/USDT-5m-1701503100-0xaaaa4cb4ff2584bad80ff5f109034a891c3d88dd" + ) + assert ( + predictions[1].ID + == "BTC/USDT-5m-1701589500-0xaaaa4cb4ff2584bad80ff5f109034a891c3d88dd" + ) diff --git a/pdr_backend/ppss/ppss.py b/pdr_backend/ppss/ppss.py index 8fcdfe49f..fd665ba3b 100644 --- a/pdr_backend/ppss/ppss.py +++ b/pdr_backend/ppss/ppss.py @@ -90,7 +90,12 @@ def mock_feed_ppss( @enforce_types def mock_ppss( - timeframe: str, predict_feeds: List[str], 
network: Optional[str] = None, tmpdir=None + timeframe: str, + predict_feeds: List[str], + network: Optional[str] = None, + tmpdir: Optional[str] = None, + st_timestr: Optional[str] = "2023-06-18", + fin_timestr: Optional[str] = "2023-06-21", ) -> PPSS: network = network or "development" yaml_str = fast_test_yaml_str(tmpdir) @@ -112,8 +117,8 @@ def mock_ppss( { "input_feeds": predict_feeds, "parquet_dir": os.path.join(tmpdir, "parquet_data"), - "st_timestr": "2023-06-18", - "fin_timestr": "2023-06-21", + "st_timestr": st_timestr, + "fin_timestr": fin_timestr, "max_n_train": 100, "autoregressive_n": 2, } diff --git a/pdr_backend/ppss/test/test_web3_pp.py b/pdr_backend/ppss/test/test_web3_pp.py index acf342c95..2e97d3a27 100644 --- a/pdr_backend/ppss/test/test_web3_pp.py +++ b/pdr_backend/ppss/test/test_web3_pp.py @@ -147,7 +147,10 @@ def _mock_contract(*args, **kwarg): # pylint: disable=unused-argument m.contract_address = feed.address return m - with patch("pdr_backend.ppss.web3_pp.PredictoorContract", _mock_contract): + with patch( + "pdr_backend.models.predictoor_contract.PredictoorContract", + _mock_contract, + ): contracts = web3_pp.get_contracts([feed.address]) assert list(contracts.keys()) == [feed.address] assert contracts[feed.address].contract_address == feed.address diff --git a/pdr_backend/ppss/web3_pp.py b/pdr_backend/ppss/web3_pp.py index 9a06596c3..5b80cf3e2 100644 --- a/pdr_backend/ppss/web3_pp.py +++ b/pdr_backend/ppss/web3_pp.py @@ -166,8 +166,8 @@ def del_network_override(monkeypatch): def mock_web3_pp(network: str) -> Web3PP: D1 = { "address_file": "address.json 1", - "rpc_url": "rpc url 1", - "subgraph_url": "subgraph url 1", + "rpc_url": "http://example.com/rpc", + "subgraph_url": "http://example.com/subgraph", "stake_token": "0xStake1", "owner_addrs": "0xOwner1", } @@ -273,7 +273,8 @@ def inplace_mock_w3_and_contract_with_tracking( mock_contract_func = Mock() mock_contract_func.return_value = _mock_pdr_contract monkeypatch.setattr( - "pdr_backend.ppss.web3_pp.PredictoorContract", mock_contract_func + "pdr_backend.models.predictoor_contract.PredictoorContract", + mock_contract_func, ) def advance_func(*args, **kwargs): # pylint: disable=unused-argument diff --git a/pdr_backend/predictoor/approach1/test/test_predictoor_agent1.py b/pdr_backend/predictoor/approach1/test/test_predictoor_agent1.py index 916134729..7ea375415 100644 --- a/pdr_backend/predictoor/approach1/test/test_predictoor_agent1.py +++ b/pdr_backend/predictoor/approach1/test/test_predictoor_agent1.py @@ -3,4 +3,4 @@ def test_predictoor_agent1(tmpdir, monkeypatch): - run_agent_test(tmpdir, monkeypatch, PredictoorAgent1) + run_agent_test(str(tmpdir), monkeypatch, PredictoorAgent1) diff --git a/pdr_backend/predictoor/approach3/test/test_predictoor_agent3.py b/pdr_backend/predictoor/approach3/test/test_predictoor_agent3.py index e0e53ba5b..53e8ee36d 100644 --- a/pdr_backend/predictoor/approach3/test/test_predictoor_agent3.py +++ b/pdr_backend/predictoor/approach3/test/test_predictoor_agent3.py @@ -3,4 +3,4 @@ def test_predictoor_agent3(tmpdir, monkeypatch): - run_agent_test(tmpdir, monkeypatch, PredictoorAgent3) + run_agent_test(str(tmpdir), monkeypatch, PredictoorAgent3) diff --git a/pdr_backend/predictoor/test/predictoor_agent_runner.py b/pdr_backend/predictoor/test/predictoor_agent_runner.py index 3a63fd064..ac7da105a 100644 --- a/pdr_backend/predictoor/test/predictoor_agent_runner.py +++ b/pdr_backend/predictoor/test/predictoor_agent_runner.py @@ -19,7 +19,7 @@ @enforce_types -def run_agent_test(tmpdir, 
monkeypatch, predictoor_agent_class): +def run_agent_test(tmpdir: str, monkeypatch, predictoor_agent_class): monkeypatch.setenv("PRIVATE_KEY", PRIV_KEY) feed, ppss = mock_feed_ppss("5m", "binanceus", "BTC/USDT", tmpdir=tmpdir) inplace_mock_query_feed_contracts(ppss.web3_pp, feed) diff --git a/pdr_backend/util/check_network.py b/pdr_backend/util/check_network.py index bf15be41f..3f74e0ed6 100644 --- a/pdr_backend/util/check_network.py +++ b/pdr_backend/util/check_network.py @@ -34,8 +34,10 @@ def print_stats(contract_dict, field_name, threshold=0.9): def check_dfbuyer(dfbuyer_addr, contract_query_result, subgraph_url, tokens): ts_now = time.time() ts_start_time = int((ts_now // WEEK) * WEEK) + + contracts_sg_dict = contract_query_result["data"]["predictContracts"] contract_addresses = [ - i["id"] for i in contract_query_result["data"]["predictContracts"] + contract_sg_dict["id"] for contract_sg_dict in contracts_sg_dict ] sofar = get_consume_so_far_per_contract( subgraph_url, @@ -66,6 +68,7 @@ def get_expected_consume(for_ts: int, tokens: int): return n_intervals * amount_per_feed_per_interval +@enforce_types def check_network_main(ppss: PPSS, lookback_hours: int): subgraph_url = ppss.web3_pp.subgraph_url web3_config = ppss.web3_pp.web3_config diff --git a/pdr_backend/util/cli_module.py b/pdr_backend/util/cli_module.py index 6fd6f319b..18975488b 100644 --- a/pdr_backend/util/cli_module.py +++ b/pdr_backend/util/cli_module.py @@ -167,7 +167,7 @@ def do_get_traction_info(): print_args(args) ppss = PPSS(yaml_filename=args.PPSS_FILE, network=args.NETWORK) - get_traction_info_main(ppss, args.FEEDS, args.ST, args.END, args.PQDIR) + get_traction_info_main(ppss, args.ST, args.END, args.PQDIR) @enforce_types diff --git a/pdr_backend/util/csvs.py b/pdr_backend/util/csvs.py index 46a0bd636..af2b8f745 100644 --- a/pdr_backend/util/csvs.py +++ b/pdr_backend/util/csvs.py @@ -99,7 +99,7 @@ def save_analysis_csv(all_predictions: List[Prediction], csv_output_dir: str): for prediction in predictions: writer.writerow( [ - prediction.id, + prediction.ID, prediction.timestamp, prediction.slot, prediction.stake, diff --git a/pdr_backend/util/get_predictions_info.py b/pdr_backend/util/get_predictions_info.py index e0cb1a9f8..bcecbe07f 100644 --- a/pdr_backend/util/get_predictions_info.py +++ b/pdr_backend/util/get_predictions_info.py @@ -4,6 +4,7 @@ from pdr_backend.ppss.ppss import PPSS from pdr_backend.util.csvs import save_analysis_csv +from pdr_backend.util.networkutil import get_sapphire_postfix from pdr_backend.util.predictoor_stats import get_cli_statistics from pdr_backend.util.subgraph_predictions import ( get_all_contract_ids_by_owner, @@ -21,14 +22,7 @@ def get_predictions_info_main( end_timestr: str, pq_dir: str, ): - # get network - if "main" in ppss.web3_pp.network: - network = "mainnet" - elif "test" in ppss.web3_pp.network: - network = "testnet" - else: - raise ValueError(ppss.web3_pp.network) - + network = get_sapphire_postfix(ppss.web3_pp.network) start_ut: int = ms_to_seconds(timestr_to_ut(start_timestr)) end_ut: int = ms_to_seconds(timestr_to_ut(end_timestr)) diff --git a/pdr_backend/util/get_predictoors_info.py b/pdr_backend/util/get_predictoors_info.py index 0e38984e6..7fcbaeb3b 100644 --- a/pdr_backend/util/get_predictoors_info.py +++ b/pdr_backend/util/get_predictoors_info.py @@ -4,6 +4,7 @@ from pdr_backend.ppss.ppss import PPSS from pdr_backend.util.csvs import save_prediction_csv +from pdr_backend.util.networkutil import get_sapphire_postfix from pdr_backend.util.predictoor_stats import 
get_cli_statistics from pdr_backend.util.subgraph_predictions import ( fetch_filtered_predictions, @@ -20,13 +21,7 @@ def get_predictoors_info_main( end_timestr: str, csv_output_dir: str, ): - if "main" in ppss.web3_pp.network: - network = "mainnet" - elif "test" in ppss.web3_pp.network: - network = "testnet" - else: - raise ValueError(ppss.web3_pp.network) - + network = get_sapphire_postfix(ppss.web3_pp.network) start_ut: int = ms_to_seconds(timestr_to_ut(start_timestr)) end_ut: int = ms_to_seconds(timestr_to_ut(end_timestr)) diff --git a/pdr_backend/util/get_traction_info.py b/pdr_backend/util/get_traction_info.py index 4fca86c6f..cb0f4c6c4 100644 --- a/pdr_backend/util/get_traction_info.py +++ b/pdr_backend/util/get_traction_info.py @@ -11,68 +11,31 @@ plot_traction_daily_statistics, plot_slot_daily_statistics, ) -from pdr_backend.util.subgraph_predictions import ( - get_all_contract_ids_by_owner, - fetch_filtered_predictions, - FilterMode, -) -from pdr_backend.util.timeutil import ms_to_seconds, timestr_to_ut +from pdr_backend.data_eng.gql_data_factory import GQLDataFactory @enforce_types def get_traction_info_main( - ppss: PPSS, addrs_str: str, start_timestr: str, end_timestr: str, pq_dir: str + ppss: PPSS, start_timestr: str, end_timestr: str, pq_dir: str ): - # get network - if "main" in ppss.web3_pp.network: - network = "mainnet" - elif "test" in ppss.web3_pp.network: - network = "testnet" - else: - raise ValueError(ppss.web3_pp.network) - - start_ut: int = ms_to_seconds(timestr_to_ut(start_timestr)) - end_ut: int = ms_to_seconds(timestr_to_ut(end_timestr)) - - # filter by contract address - if addrs_str == "": - address_filter = [] - elif "," in addrs_str: - address_filter = addrs_str.lower().split(",") - else: - address_filter = [addrs_str.lower()] + data_ss = ppss.data_ss + data_ss.d["st_timestr"] = start_timestr + data_ss.d["fin_timestr"] = end_timestr - contract_list = get_all_contract_ids_by_owner( - owner_address=ppss.web3_pp.owner_addrs, - network=network, - ) + gql_data_factory = GQLDataFactory(ppss) + gql_dfs = gql_data_factory.get_gql_dfs() - contract_list = [ - x.lower() - for x in contract_list - if x.lower() in address_filter or address_filter == [] - ] - - # fetch predictions - predictions = fetch_filtered_predictions( - start_ut, - end_ut, - contract_list, - network, - FilterMode.CONTRACT, - payout_only=False, - trueval_only=False, - ) - - if len(predictions) == 0: + if len(gql_dfs) == 0: print("No records found. 
Please adjust start and end times.") return + predictions_df = gql_dfs["pdr_predictions"] + # calculate predictoor traction statistics and draw plots - stats_df = get_traction_statistics(predictions) + stats_df = get_traction_statistics(predictions_df) plot_traction_cum_sum_statistics(stats_df, pq_dir) plot_traction_daily_statistics(stats_df, pq_dir) # calculate slot statistics and draw plots - slots_df = get_slot_statistics(predictions) + slots_df = get_slot_statistics(predictions_df) plot_slot_daily_statistics(slots_df, pq_dir) diff --git a/pdr_backend/util/networkutil.py b/pdr_backend/util/networkutil.py index eba18a7ad..e6b155c99 100644 --- a/pdr_backend/util/networkutil.py +++ b/pdr_backend/util/networkutil.py @@ -12,6 +12,16 @@ def is_sapphire_network(chain_id: int) -> bool: return chain_id in [SAPPHIRE_TESTNET_CHAINID, SAPPHIRE_MAINNET_CHAINID] +@enforce_types +def get_sapphire_postfix(network: str) -> str: + if network == "sapphire-testnet": + return "testnet" + if network == "sapphire-mainnet": + return "mainnet" + + raise ValueError(f"'{network}' is not valid name") + + @enforce_types def send_encrypted_tx( contract_instance, diff --git a/pdr_backend/util/predictoor_stats.py b/pdr_backend/util/predictoor_stats.py index 3502255ca..300913abb 100644 --- a/pdr_backend/util/predictoor_stats.py +++ b/pdr_backend/util/predictoor_stats.py @@ -200,13 +200,7 @@ def get_cli_statistics(all_predictions: List[Prediction]) -> None: @enforce_types -def get_traction_statistics( - all_predictions: List[Prediction], -) -> pl.DataFrame: - # Get all predictions into a dataframe - preds_dicts = [pred.__dict__ for pred in all_predictions] - preds_df = pl.DataFrame(preds_dicts) - +def get_traction_statistics(preds_df: pl.DataFrame) -> pl.DataFrame: # Calculate predictoor traction statistics # Predictoor addresses are aggregated historically stats_df = ( @@ -287,7 +281,7 @@ def plot_traction_cum_sum_statistics(stats_df: pl.DataFrame, pq_dir: str) -> Non ticks = int(len(dates) / 5) if len(dates) > 5 else 2 # draw cum_unique_predictoors - chart_path = os.path.join(charts_dir, "cum_daily_unique_predictoors.png") + chart_path = os.path.join(charts_dir, "daily_cumulative_unique_predictoors.png") plt.figure(figsize=(10, 6)) plt.plot( stats_df["datetime"].to_pandas(), @@ -306,13 +300,7 @@ def plot_traction_cum_sum_statistics(stats_df: pl.DataFrame, pq_dir: str) -> Non @enforce_types -def get_slot_statistics( - all_predictions: List[Prediction], -) -> pl.DataFrame: - # Get all predictions into a dataframe - preds_dicts = [pred.__dict__ for pred in all_predictions] - preds_df = pl.DataFrame(preds_dicts) - +def get_slot_statistics(preds_df: pl.DataFrame) -> pl.DataFrame: # Create a key to group predictions slots_df = ( preds_df.with_columns( @@ -409,7 +397,7 @@ def plot_slot_daily_statistics(slots_df: pl.DataFrame, pq_dir: str) -> None: ticks = int(len(dates) / 5) if len(dates) > 5 else 2 # draw daily predictoor stake in $OCEAN - chart_path = os.path.join(charts_dir, "daily_slot_average_stake.png") + chart_path = os.path.join(charts_dir, "daily_average_stake.png") plt.figure(figsize=(10, 6)) plt.plot( slots_daily_df["datetime"].to_pandas(), @@ -419,7 +407,7 @@ def plot_slot_daily_statistics(slots_df: pl.DataFrame, pq_dir: str) -> None: ) plt.xlabel("Date") plt.ylabel("Average $OCEAN Staked") - plt.title("Daily Average $OCEAN Staked") + plt.title("Daily average $OCEAN staked per slot, across all Feeds") plt.xticks(range(0, len(dates), ticks), dates[::ticks], rotation=90) plt.tight_layout() plt.savefig(chart_path) 
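
Note that get_traction_statistics() and get_slot_statistics() above now take a polars DataFrame rather than a List[Prediction]; the __dict__-to-DataFrame conversion moved out to the caller (GQLDataFactory). A minimal caller-side sketch, illustrative only and not part of this patch, where `predictions` is assumed to be a List[Prediction]:

    import polars as pl

    from pdr_backend.util.predictoor_stats import get_slot_statistics

    # same conversion the pre-patch helpers did internally
    preds_df = pl.DataFrame([pred.__dict__ for pred in predictions])
    slots_df = get_slot_statistics(preds_df)
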
@@ -437,7 +425,7 @@ def plot_slot_daily_statistics(slots_df: pl.DataFrame, pq_dir: str) -> None: ) plt.xlabel("Date") plt.ylabel("Average Predictoors") - plt.title("Daily Average Predictoors") + plt.title("Average # Predictoors competing per slot, per feed") plt.xticks(range(0, len(dates), ticks), dates[::ticks], rotation=90) plt.tight_layout() plt.savefig(chart_path) diff --git a/pdr_backend/util/subgraph_predictions.py b/pdr_backend/util/subgraph_predictions.py index 8b8d11ecf..5e7c3e24f 100644 --- a/pdr_backend/util/subgraph_predictions.py +++ b/pdr_backend/util/subgraph_predictions.py @@ -9,7 +9,7 @@ class ContractIdAndSPE(TypedDict): - id: str + ID: str seconds_per_epoch: int name: str @@ -18,6 +18,7 @@ class FilterMode(Enum): NONE = 0 CONTRACT = 1 PREDICTOOR = 2 + CONTRACT_TS = 3 @enforce_types @@ -67,7 +68,9 @@ def fetch_filtered_predictions( # pylint: disable=line-too-long if filter_mode == FilterMode.NONE: - where_clause = f", where: {{slot_: {{slot_gt: {start_ts}, slot_lt: {end_ts}}}}}" + where_clause = f", where: {{timestamp_gt: {start_ts}, timestamp_lt: {end_ts}}}" + elif filter_mode == FilterMode.CONTRACT_TS: + where_clause = f", where: {{timestamp_gt: {start_ts}, timestamp_lt: {end_ts}, slot_: {{predictContract_in: {json.dumps(filters)}}}}}" elif filter_mode == FilterMode.CONTRACT: where_clause = f", where: {{slot_: {{predictContract_in: {json.dumps(filters)}, slot_gt: {start_ts}, slot_lt: {end_ts}}}}}" elif filter_mode == FilterMode.PREDICTOOR: @@ -123,35 +126,37 @@ def fetch_filtered_predictions( if len(data) == 0: break - for prediction in data: - info725 = prediction["slot"]["predictContract"]["token"]["nft"]["nftData"] + for prediction_sg_dict in data: + info725 = prediction_sg_dict["slot"]["predictContract"]["token"]["nft"][ + "nftData" + ] info = info725_to_info(info725) pair = info["pair"] timeframe = info["timeframe"] source = info["source"] - timestamp = prediction["timestamp"] - slot = prediction["slot"]["slot"] - user = prediction["user"]["id"] + timestamp = prediction_sg_dict["timestamp"] + slot = prediction_sg_dict["slot"]["slot"] + user = prediction_sg_dict["user"]["id"] trueval = None payout = None predicted_value = None stake = None - if payout_only is True and prediction["payout"] is None: + if payout_only is True and prediction_sg_dict["payout"] is None: continue - if not prediction["payout"] is None: - stake = float(prediction["stake"]) - trueval = prediction["payout"]["trueValue"] - predicted_value = prediction["payout"]["predictedValue"] - payout = float(prediction["payout"]["payout"]) + if not prediction_sg_dict["payout"] is None: + stake = float(prediction_sg_dict["stake"]) + trueval = prediction_sg_dict["payout"]["trueValue"] + predicted_value = prediction_sg_dict["payout"]["predictedValue"] + payout = float(prediction_sg_dict["payout"]["payout"]) if trueval_only is True and trueval is None: continue - prediction_obj = Prediction( - id=prediction["id"], + prediction = Prediction( + ID=prediction_sg_dict["id"], pair=pair, timeframe=timeframe, prediction=predicted_value, @@ -163,7 +168,7 @@ def fetch_filtered_predictions( slot=slot, user=user, ) - predictions.append(prediction_obj) + predictions.append(prediction) return predictions @@ -227,21 +232,17 @@ def fetch_contract_id_and_spe( contract_addresses: List[str], network: str ) -> List[ContractIdAndSPE]: """ - This function queries a GraphQL endpoint to retrieve contract details such as - the contract ID and seconds per epoch for each contract address provided. 
- It supports querying both mainnet and testnet networks. + @description + Query a GraphQL endpoint to retrieve details of contracts, like + contract ID and seconds per epoch. - Args: - contract_addresses (List[str]): A list of contract addresses to query. - network (str): The blockchain network to query ('mainnet' or 'testnet'). - - Raises: - Exception: If the network is not 'mainnet' or 'testnet', or if no data is returned. + @arguments + contract_addresses - contract addresses to query + network - where to query. Eg 'mainnet' or 'testnet' - Returns: - List[ContractDetail]: A list of dictionaries containing contract details. + @return + contracts_list - where each item has contract details """ - if network not in ("mainnet", "testnet"): raise Exception("Invalid network, pick mainnet or testnet") @@ -268,15 +269,15 @@ def fetch_contract_id_and_spe( if "data" not in result: raise Exception("Error fetching contracts: No data returned") - # Parse the results and construct ContractDetail objects - contract_data = result["data"]["predictContracts"] - contracts: List[ContractIdAndSPE] = [ - { - "id": contract["id"], - "seconds_per_epoch": contract["secondsPerEpoch"], - "name": contract["token"]["name"], + contracts_sg_dict = result["data"]["predictContracts"] + + contracts_list: List[ContractIdAndSPE] = [] + for contract_sg_dict in contracts_sg_dict: + contract_item: ContractIdAndSPE = { + "ID": contract_sg_dict["id"], + "seconds_per_epoch": contract_sg_dict["secondsPerEpoch"], + "name": contract_sg_dict["token"]["name"], } - for contract in contract_data - ] + contracts_list.append(contract_item) - return contracts + return contracts_list diff --git a/pdr_backend/util/subgraph_slot.py b/pdr_backend/util/subgraph_slot.py index d2b74fb6c..7f76ccf13 100644 --- a/pdr_backend/util/subgraph_slot.py +++ b/pdr_backend/util/subgraph_slot.py @@ -9,7 +9,7 @@ @dataclass class PredictSlot: - id: str + ID: str slot: str trueValues: List[Dict[str, Any]] roundSumStakesUp: float @@ -165,7 +165,7 @@ def fetch_slots_for_all_assets( slots_by_asset: Dict[str, List[PredictSlot]] = {} for slot in all_slots: - slot_id = slot.id + slot_id = slot.ID # split the id to get the asset id asset_id = slot_id.split("-")[0] if asset_id not in slots_by_asset: @@ -228,7 +228,7 @@ def process_single_slot( return None # split the id to get the slot timestamp - timestamp = int(slot.id.split("-")[1]) # Using dot notation for attribute access + timestamp = int(slot.ID.split("-")[1]) # Using dot notation for attribute access if ( end_of_previous_day_timestamp - SECONDS_IN_A_DAY @@ -244,7 +244,7 @@ def process_single_slot( ) if prediction_result is None: - print("Prediction result is None for slot: ", slot.id) + print("Prediction result is None for slot: ", slot.ID) return ( staked_yesterday, staked_today, @@ -308,7 +308,7 @@ def aggregate_statistics( @enforce_types def calculate_statistics_for_all_assets( asset_ids: List[str], - contracts: List[ContractIdAndSPE], + contracts_list: List[ContractIdAndSPE], start_ts_param: int, end_ts_param: int, network: str = "mainnet", @@ -346,15 +346,20 @@ def calculate_statistics_for_all_assets( ) # filter contracts to get the contract with the current asset id - contract = next( - (contract for contract in contracts if contract["id"] == asset_id), + contract_item = next( + ( + contract_item + for contract_item in contracts_list + if contract_item["ID"] == asset_id + ), None, ) overall_stats[asset_id] = { - "token_name": contract["name"] if contract else None, + "token_name": contract_item["name"] 
if contract_item else None, "average_accuracy": average_accuracy, "total_staked_yesterday": staked_yesterday, "total_staked_today": staked_today, } + return overall_stats diff --git a/pdr_backend/util/test_data.py b/pdr_backend/util/test_data.py new file mode 100644 index 000000000..bb1149e93 --- /dev/null +++ b/pdr_backend/util/test_data.py @@ -0,0 +1,176 @@ +sample_first_predictions = [ + ( + "ADA/USDT", + "5m", + True, + 0.0500, + False, + 1701503000, + "binance", + 0.0, + 1701503100, + "0xaaaa4cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "BTC/USDT", + "5m", + True, + 0.0500, + True, + 1701589400, + "binance", + 0.0, + 1701589500, + "0xaaaa4cb4ff2584bad80ff5f109034a891c3d88dd", + ), +] + +sample_second_predictions = [ + ( + "ETH/USDT", + "5m", + True, + 0.0500, + True, + 1701675800, + "binance", + 0.0500, + 1701675900, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "BTC/USDT", + "1h", + True, + 0.0500, + False, + 1701503100, + "binance", + 0.0, + 1701503000, + "0xbbbb4cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "ADA/USDT", + "5m", + True, + 0.0500, + True, + 1701589400, + "binance", + 0.0500, + 1701589500, + "0xbbbb4cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "BNB/USDT", + "1h", + True, + 0.0500, + True, + 1701675800, + "kraken", + 0.0500, + 1701675900, + "0xbbbb4cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "ETH/USDT", + "1h", + True, + None, + False, + 1701589400, + "binance", + 0.0, + 1701589500, + "0xcccc4cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "ETH/USDT", + "5m", + True, + 0.0500, + True, + 1701675800, + "binance", + 0.0500, + 1701675900, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), +] + +sample_daily_predictions = [ + ( + "ETH/USDT", + "5m", + True, + 0.0500, + True, + 1698865200, + "binance", + 0.0500, + 1698865200, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "BTC/USDT", + "1h", + True, + 0.0500, + False, + 1698951600, + "binance", + 0.0, + 1698951600, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "ADA/USDT", + "5m", + True, + 0.0500, + True, + 1699038000, + "binance", + 0.0500, + 1699038000, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "BNB/USDT", + "1h", + True, + 0.0500, + True, + 1699124400, + "kraken", + 0.0500, + 1699124400, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "ETH/USDT", + "1h", + True, + None, + False, + 1699214400, + "binance", + 0.0, + 1701589500, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "ETH/USDT", + "5m", + True, + 0.0500, + True, + 1699300800, + "binance", + 0.0500, + 1699300800, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), +] diff --git a/pdr_backend/util/test_ganache/test_networkutil.py b/pdr_backend/util/test_ganache/test_networkutil.py index b6fb4f1d4..cd8607ba5 100644 --- a/pdr_backend/util/test_ganache/test_networkutil.py +++ b/pdr_backend/util/test_ganache/test_networkutil.py @@ -11,6 +11,7 @@ ) from pdr_backend.util.networkutil import ( is_sapphire_network, + get_sapphire_postfix, send_encrypted_tx, tx_call_params, tx_gas_price, @@ -26,6 +27,25 @@ def test_is_sapphire_network(): assert is_sapphire_network(SAPPHIRE_MAINNET_CHAINID) +@enforce_types +def test_get_sapphire_postfix(): + assert get_sapphire_postfix("sapphire-testnet"), "testnet" + assert get_sapphire_postfix("sapphire-mainnet"), "mainnet" + + unwanteds = [ + "oasis_saphire_testnet", + "saphire_mainnet", + "barge-pytest", + "barge-predictoor-bot", + "development", + "foo", + "", + ] + for unwanted in unwanteds: + with pytest.raises(ValueError): + assert 
get_sapphire_postfix(unwanted) + + @enforce_types def test_send_encrypted_tx( mock_send_encrypted_sapphire_tx, # pylint: disable=redefined-outer-name diff --git a/pdr_backend/util/test_ganache/test_subgraph_predictions.py b/pdr_backend/util/test_ganache/test_subgraph_predictions.py index 77ece7498..68873d53a 100644 --- a/pdr_backend/util/test_ganache/test_subgraph_predictions.py +++ b/pdr_backend/util/test_ganache/test_subgraph_predictions.py @@ -11,7 +11,7 @@ SAMPLE_PREDICTION = Prediction( # pylint: disable=line-too-long - id="0x18f54cc21b7a2fdd011bea06bba7801b280e3151-1698527100-0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ID="0x18f54cc21b7a2fdd011bea06bba7801b280e3151-1698527100-0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", pair="ADA/USDT", timeframe="5m", prediction=True, @@ -144,15 +144,18 @@ def test_get_all_contract_ids_by_owner( def test_fetch_contract_id_and_spe( mock_query_subgraph, ): # pylint: disable=unused-argument - contract_details = fetch_contract_id_and_spe( + contracts_list = fetch_contract_id_and_spe( contract_addresses=["contract1", "contract2"], network="mainnet" ) - assert len(contract_details) == 2 - assert contract_details[0]["id"] == "contract1" - assert contract_details[0]["seconds_per_epoch"] == 300 - assert contract_details[0]["name"] == "token1" - assert contract_details[1]["id"] == "contract2" - assert contract_details[1]["seconds_per_epoch"] == 600 - assert contract_details[1]["name"] == "token2" + assert len(contracts_list) == 2 + + c0, c1 = contracts_list # pylint: disable=unbalanced-tuple-unpacking + assert c0["ID"] == "contract1" + assert c0["seconds_per_epoch"] == 300 + assert c0["name"] == "token1" + assert c1["ID"] == "contract2" + assert c1["seconds_per_epoch"] == 600 + assert c1["name"] == "token2" + mock_query_subgraph.assert_called_once() diff --git a/pdr_backend/util/test_ganache/test_subgraph_slot.py b/pdr_backend/util/test_ganache/test_subgraph_slot.py index 5f41ecd92..f0b842140 100644 --- a/pdr_backend/util/test_ganache/test_subgraph_slot.py +++ b/pdr_backend/util/test_ganache/test_subgraph_slot.py @@ -17,9 +17,9 @@ # Sample data for tests SAMPLE_PREDICT_SLOT = PredictSlot( - id="1-12345", + ID="1-12345", slot="12345", - trueValues=[{"id": "1", "trueValue": True}], + trueValues=[{"ID": "1", "trueValue": True}], roundSumStakesUp=150.0, roundSumStakes=100.0, ) @@ -39,9 +39,9 @@ def test_get_predict_slots_query(): # Sample data for tests SAMPLE_PREDICT_SLOT = PredictSlot( - id="0xAsset-12345", + ID="0xAsset-12345", slot="12345", - trueValues=[{"id": "1", "trueValue": True}], + trueValues=[{"ID": "1", "trueValue": True}], roundSumStakesUp=150.0, roundSumStakes=100.0, ) @@ -89,7 +89,7 @@ def test_get_slots(mock_query_subgraph): # Verify that the slots contain instances of PredictSlot assert isinstance(result_slots[0], PredictSlot) # Verify the first slot's data matches the sample - assert result_slots[0].id == "0xAsset-12345" + assert result_slots[0].ID == "0xAsset-12345" @enforce_types @@ -140,23 +140,23 @@ def test_aggregate_statistics(): @enforce_types @patch("pdr_backend.util.subgraph_slot.fetch_slots_for_all_assets") def test_calculate_statistics_for_all_assets(mock_fetch_slots): - # Set up the mock to return a predetermined value + # Mocks mock_fetch_slots.return_value = {"0xAsset": [SAMPLE_PREDICT_SLOT] * 1000} - # Contracts List - contracts: List[ContractIdAndSPE] = [ - {"id": "0xAsset", "seconds_per_epoch": 300, "name": "TEST/USDT"} + contracts_list: List[ContractIdAndSPE] = [ + {"ID": "0xAsset", "seconds_per_epoch": 300, "name": 
"TEST/USDT"} ] - # Test the calculate_statistics_for_all_assets function + + # Main work statistics = calculate_statistics_for_all_assets( asset_ids=["0xAsset"], - contracts=contracts, + contracts_list=contracts_list, start_ts_param=1000, end_ts_param=2000, network="mainnet", ) - # Verify that the statistics are calculated as expected + + # Verify assert statistics["0xAsset"]["average_accuracy"] == 100.0 - # Verify that the mock was called as expected mock_fetch_slots.assert_called_once_with(["0xAsset"], 1000, 2000, "mainnet") @@ -175,6 +175,6 @@ def test_fetch_slots_for_all_assets(mock_query_subgraph): assert "0xAsset" in result assert all(isinstance(slot, PredictSlot) for slot in result["0xAsset"]) assert len(result["0xAsset"]) == 1 - assert result["0xAsset"][0].id == "0xAsset-12345" + assert result["0xAsset"][0].ID == "0xAsset-12345" # Verify that the mock was called mock_query_subgraph.assert_called() diff --git a/pdr_backend/util/test_noganache/conftest.py b/pdr_backend/util/test_noganache/conftest.py index 1f3caf868..5bce9d84d 100644 --- a/pdr_backend/util/test_noganache/conftest.py +++ b/pdr_backend/util/test_noganache/conftest.py @@ -1,135 +1,46 @@ -from unittest.mock import Mock - +from typing import List from enforce_typing import enforce_types import pytest -from pdr_backend.ppss.ppss import PPSS, fast_test_yaml_str -from pdr_backend.util.subgraph_predictions import Prediction +from pdr_backend.models.prediction import mock_prediction, Prediction +from pdr_backend.ppss.ppss import mock_ppss +from pdr_backend.util.test_data import ( + sample_first_predictions, + sample_second_predictions, + sample_daily_predictions, +) @enforce_types @pytest.fixture(scope="session") -def mock_ppss(tmpdir_factory): +def _mock_ppss(tmpdir_factory): tmpdir = tmpdir_factory.mktemp("my_tmpdir") - s = fast_test_yaml_str(tmpdir) - ppss = PPSS(yaml_str=s, network="development") - ppss.web3_pp = Mock() + ppss = mock_ppss("5m", ["binance c BTC/USDT"], "sapphire-mainnet", str(tmpdir)) return ppss @enforce_types @pytest.fixture(scope="session") -def sample_first_predictions(): +def _sample_first_predictions() -> List[Prediction]: + return [ + mock_prediction(prediction_tuple) + for prediction_tuple in sample_first_predictions + ] + + +@enforce_types +@pytest.fixture(scope="session") +def _sample_second_predictions() -> List[Prediction]: return [ - Prediction( - id="1", - pair="ADA/USDT", - timeframe="5m", - prediction=True, - stake=0.0500, - trueval=False, - timestamp=1701503000, - source="binance", - payout=0.0, - slot=1701503100, - user="0xaaaa4cb4ff2584bad80ff5f109034a891c3d88dd", - ), - Prediction( - id="2", - pair="BTC/USDT", - timeframe="5m", - prediction=True, - stake=0.0500, - trueval=True, - timestamp=1701589400, - source="binance", - payout=0.0, - slot=1701589500, - user="0xaaaa4cb4ff2584bad80ff5f109034a891c3d88dd", - ), + mock_prediction(prediction_tuple) + for prediction_tuple in sample_second_predictions ] @enforce_types @pytest.fixture(scope="session") -def sample_second_predictions(): +def _sample_daily_predictions() -> List[Prediction]: return [ - Prediction( - id="3", - pair="ETH/USDT", - timeframe="5m", - prediction=True, - stake=0.0500, - trueval=True, - timestamp=1701675800, - source="binance", - payout=0.0500, - slot=1701675900, - user="0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", - ), - Prediction( - id="4", - pair="BTC/USDT", - timeframe="1h", - prediction=True, - stake=0.0500, - trueval=False, - timestamp=1701503100, - source="binance", - payout=0.0, - slot=1701503000, - 
user="0xbbbb4cb4ff2584bad80ff5f109034a891c3d88dd", - ), - Prediction( - id="5", - pair="ADA/USDT", - timeframe="5m", - prediction=True, - stake=0.0500, - trueval=True, - timestamp=1701589400, - source="binance", - payout=0.0500, - slot=1701589500, - user="0xbbbb4cb4ff2584bad80ff5f109034a891c3d88dd", - ), - Prediction( - id="6", - pair="BNB/USDT", - timeframe="1h", - prediction=True, - stake=0.0500, - trueval=True, - timestamp=1701675800, - source="kraken", - payout=0.0500, - slot=1701675900, - user="0xbbbb4cb4ff2584bad80ff5f109034a891c3d88dd", - ), - Prediction( - id="7", - pair="ETH/USDT", - timeframe="1h", - prediction=True, - stake=None, - trueval=False, - timestamp=1701589400, - source="binance", - payout=0.0, - slot=1701589500, - user="0xcccc4cb4ff2584bad80ff5f109034a891c3d88dd", - ), - Prediction( - id="8", - pair="ETH/USDT", - timeframe="5m", - prediction=True, - stake=0.0500, - trueval=True, - timestamp=1701675800, - source="binance", - payout=0.0500, - slot=1701675900, - user="0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", - ), + mock_prediction(prediction_tuple) + for prediction_tuple in sample_daily_predictions ] diff --git a/pdr_backend/util/test_noganache/test_checknetwork.py b/pdr_backend/util/test_noganache/test_checknetwork.py index e8d3bf99f..36b3f2817 100644 --- a/pdr_backend/util/test_noganache/test_checknetwork.py +++ b/pdr_backend/util/test_noganache/test_checknetwork.py @@ -1,9 +1,8 @@ import time -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest -from pdr_backend.ppss.ppss import PPSS, fast_test_yaml_str from pdr_backend.util.check_network import ( WEEK, check_dfbuyer, @@ -107,17 +106,6 @@ def test_get_expected_consume(): assert get_expected_consume(for_ts, tokens) == expected -@pytest.fixture(name="mock_ppss_") -def mock_ppss(tmpdir): - s = fast_test_yaml_str(tmpdir) - ppss = PPSS(yaml_str=s, network="development") - ppss.web3_pp = Mock() - ppss.web3_pp.subgraph_url = "http://example.com/subgraph" - ppss.web3_pp.web3_config = MagicMock() - ppss.web3_pp.web3_config.w3.eth.chain_id = 1 - return ppss - - @patch("pdr_backend.util.check_network.get_opf_addresses") @patch("pdr_backend.util.subgraph.query_subgraph") @patch("pdr_backend.util.check_network.Token") @@ -125,7 +113,7 @@ def test_check_network_main( mock_token, mock_query_subgraph, mock_get_opf_addresses, - mock_ppss_, + _mock_ppss, ): mock_get_opf_addresses.return_value = { "dfbuyer": "0xdfBuyerAddress", @@ -133,10 +121,14 @@ def test_check_network_main( } mock_query_subgraph.return_value = {"data": {"predictContracts": []}} mock_token.return_value.balanceOf.return_value = 1000 * 1e18 - mock_ppss_.web3_pp.web3_config.w3.eth.get_balance.return_value = 1000 * 1e18 - check_network_main(mock_ppss_, lookback_hours=24) + + mock_w3 = Mock() # pylint: disable=not-callable + mock_w3.eth.chain_id = 1 + mock_w3.eth.get_balance.return_value = 1000 * 1e18 + _mock_ppss.web3_pp.web3_config.w3 = mock_w3 + check_network_main(_mock_ppss, lookback_hours=24) mock_get_opf_addresses.assert_called_once_with(1) - assert mock_query_subgraph.call_count == 2 + assert mock_query_subgraph.call_count == 1 mock_token.assert_called() - mock_ppss_.web3_pp.web3_config.w3.eth.get_balance.assert_called() + _mock_ppss.web3_pp.web3_config.w3.eth.get_balance.assert_called() diff --git a/pdr_backend/util/test_noganache/test_get_predictions_info.py b/pdr_backend/util/test_noganache/test_get_predictions_info.py index d8c9b2698..5c2f501c5 100644 --- 
diff --git a/pdr_backend/util/test_noganache/test_checknetwork.py b/pdr_backend/util/test_noganache/test_checknetwork.py
index e8d3bf99f..36b3f2817 100644
--- a/pdr_backend/util/test_noganache/test_checknetwork.py
+++ b/pdr_backend/util/test_noganache/test_checknetwork.py
@@ -1,9 +1,8 @@
 import time
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import Mock, patch

 import pytest

-from pdr_backend.ppss.ppss import PPSS, fast_test_yaml_str
 from pdr_backend.util.check_network import (
     WEEK,
     check_dfbuyer,
@@ -107,17 +106,6 @@ def test_get_expected_consume():
         assert get_expected_consume(for_ts, tokens) == expected


-@pytest.fixture(name="mock_ppss_")
-def mock_ppss(tmpdir):
-    s = fast_test_yaml_str(tmpdir)
-    ppss = PPSS(yaml_str=s, network="development")
-    ppss.web3_pp = Mock()
-    ppss.web3_pp.subgraph_url = "http://example.com/subgraph"
-    ppss.web3_pp.web3_config = MagicMock()
-    ppss.web3_pp.web3_config.w3.eth.chain_id = 1
-    return ppss
-
-
 @patch("pdr_backend.util.check_network.get_opf_addresses")
 @patch("pdr_backend.util.subgraph.query_subgraph")
 @patch("pdr_backend.util.check_network.Token")
@@ -125,7 +113,7 @@ def test_check_network_main(
     mock_token,
     mock_query_subgraph,
     mock_get_opf_addresses,
-    mock_ppss_,
+    _mock_ppss,
 ):
     mock_get_opf_addresses.return_value = {
         "dfbuyer": "0xdfBuyerAddress",
@@ -133,10 +121,14 @@ def test_check_network_main(
     }
     mock_query_subgraph.return_value = {"data": {"predictContracts": []}}
     mock_token.return_value.balanceOf.return_value = 1000 * 1e18
-    mock_ppss_.web3_pp.web3_config.w3.eth.get_balance.return_value = 1000 * 1e18
-    check_network_main(mock_ppss_, lookback_hours=24)
+
+    mock_w3 = Mock()  # pylint: disable=not-callable
+    mock_w3.eth.chain_id = 1
+    mock_w3.eth.get_balance.return_value = 1000 * 1e18
+    _mock_ppss.web3_pp.web3_config.w3 = mock_w3
+    check_network_main(_mock_ppss, lookback_hours=24)

     mock_get_opf_addresses.assert_called_once_with(1)
-    assert mock_query_subgraph.call_count == 2
+    assert mock_query_subgraph.call_count == 1
     mock_token.assert_called()
-    mock_ppss_.web3_pp.web3_config.w3.eth.get_balance.assert_called()
+    _mock_ppss.web3_pp.web3_config.w3.eth.get_balance.assert_called()
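The test_check_network_main hunk above replaces the deleted fixture's pre-wired MagicMock with a Mock built right next to the assertions that read it. The mechanics, reduced to plain unittest.mock (runnable standalone; the address string is just sample data from the hunk):

    from unittest.mock import Mock

    mock_w3 = Mock()
    mock_w3.eth.chain_id = 1  # attribute access auto-creates child mocks
    mock_w3.eth.get_balance.return_value = 1000 * 1e18

    assert mock_w3.eth.chain_id == 1
    assert mock_w3.eth.get_balance("0xdfBuyerAddress") == 1000 * 1e18
    mock_w3.eth.get_balance.assert_called_once_with("0xdfBuyerAddress")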
-@patch("pdr_backend.util.get_predictoors_info.get_cli_statistics") -def test_get_predictoors_info_main_mainnet( - mock_get_cli_statistics, - mock_save_prediction_csv, - mock_fetch_filtered_predictions, - mock_ppss_, -): - mock_ppss_.web3_pp.network = "main" - mock_fetch_filtered_predictions.return_value = [] - - get_predictoors_info_main( - mock_ppss_, - "0x123", - "2023-01-01", - "2023-01-02", - "parquet_data/", - ) - - mock_fetch_filtered_predictions.assert_called_with( - 1672531200, - 1672617600, - ["0x123"], - "mainnet", - FilterMode.PREDICTOOR, - ) - mock_save_prediction_csv.assert_called_with([], "parquet_data/") - mock_get_cli_statistics.assert_called_with([]) +@enforce_types +def test_get_predictoors_info_main_mainnet(tmpdir, monkeypatch): + del_network_override(monkeypatch) + ppss = mock_ppss("5m", ["binance c BTC/USDT"], "sapphire-mainnet", str(tmpdir)) + + mock_fetch = Mock(return_value=[]) + mock_save = Mock() + mock_getstats = Mock() + + PATH = "pdr_backend.util.get_predictoors_info" + with patch(f"{PATH}.fetch_filtered_predictions", mock_fetch), patch( + f"{PATH}.save_prediction_csv", mock_save + ), patch(f"{PATH}.get_cli_statistics", mock_getstats): + get_predictoors_info_main( + ppss, + "0x123", + "2023-01-01", + "2023-01-02", + "parquet_data/", + ) + + mock_fetch.assert_called_with( + 1672531200, + 1672617600, + ["0x123"], + "mainnet", + FilterMode.PREDICTOOR, + ) + mock_save.assert_called_with([], "parquet_data/") + mock_getstats.assert_called_with([]) diff --git a/pdr_backend/util/test_noganache/test_get_traction_info.py b/pdr_backend/util/test_noganache/test_get_traction_info.py index b3fdd1250..4a9b09d59 100644 --- a/pdr_backend/util/test_noganache/test_get_traction_info.py +++ b/pdr_backend/util/test_noganache/test_get_traction_info.py @@ -1,43 +1,75 @@ -from unittest.mock import patch +from unittest.mock import Mock, patch from enforce_typing import enforce_types +import polars as pl +from pdr_backend.ppss.ppss import mock_ppss +from pdr_backend.ppss.web3_pp import del_network_override from pdr_backend.util.get_traction_info import get_traction_info_main from pdr_backend.util.subgraph_predictions import FilterMode +from pdr_backend.util.timeutil import timestr_to_ut @enforce_types -@patch("pdr_backend.util.get_traction_info.get_traction_statistics") -@patch("pdr_backend.util.get_traction_info.get_all_contract_ids_by_owner") -@patch("pdr_backend.util.get_traction_info.fetch_filtered_predictions") -@patch("pdr_backend.util.get_traction_info.plot_traction_cum_sum_statistics") -@patch("pdr_backend.util.get_traction_info.plot_traction_daily_statistics") def test_get_traction_info_main_mainnet( - mock_plot_traction_daily_statistics, - mock_plot_traction_cum_sum_statistics, - mock_fetch_filtered_predictions, - mock_get_all_contract_ids_by_owner, - mock_get_traction_statistics, - mock_ppss, - sample_first_predictions, + _sample_daily_predictions, + tmpdir, + monkeypatch, ): - mock_ppss.web3_pp.network = "main" - mock_get_all_contract_ids_by_owner.return_value = ["0x123", "0x234"] - mock_fetch_filtered_predictions.return_value = sample_first_predictions - - get_traction_info_main( - mock_ppss, "0x123", "2023-01-01", "2023-01-02", "parquet_data/" - ) - - mock_fetch_filtered_predictions.assert_called_with( - 1672531200, - 1672617600, - ["0x123"], - "mainnet", - FilterMode.CONTRACT, - payout_only=False, - trueval_only=False, - ) - mock_get_traction_statistics.assert_called_with(sample_first_predictions) - mock_plot_traction_cum_sum_statistics.assert_called() - 
diff --git a/pdr_backend/util/test_noganache/test_get_predictoors_info.py b/pdr_backend/util/test_noganache/test_get_predictoors_info.py
index 2a9f644de..1a660f5d2 100644
--- a/pdr_backend/util/test_noganache/test_get_predictoors_info.py
+++ b/pdr_backend/util/test_noganache/test_get_predictoors_info.py
@@ -1,46 +1,40 @@
 from unittest.mock import Mock, patch

-import pytest
+from enforce_typing import enforce_types

-from pdr_backend.ppss.ppss import PPSS, fast_test_yaml_str
+from pdr_backend.ppss.ppss import mock_ppss
+from pdr_backend.ppss.web3_pp import del_network_override
 from pdr_backend.util.get_predictoors_info import get_predictoors_info_main
 from pdr_backend.util.subgraph_predictions import FilterMode


-@pytest.fixture(name="mock_ppss_")
-def mock_ppss(tmpdir):
-    s = fast_test_yaml_str(tmpdir)
-    ppss = PPSS(yaml_str=s, network="development")
-    ppss.web3_pp = Mock()
-    return ppss
-
-
-@patch("pdr_backend.util.get_predictoors_info.fetch_filtered_predictions")
-@patch("pdr_backend.util.get_predictoors_info.save_prediction_csv")
-@patch("pdr_backend.util.get_predictoors_info.get_cli_statistics")
-def test_get_predictoors_info_main_mainnet(
-    mock_get_cli_statistics,
-    mock_save_prediction_csv,
-    mock_fetch_filtered_predictions,
-    mock_ppss_,
-):
-    mock_ppss_.web3_pp.network = "main"
-    mock_fetch_filtered_predictions.return_value = []
-
-    get_predictoors_info_main(
-        mock_ppss_,
-        "0x123",
-        "2023-01-01",
-        "2023-01-02",
-        "parquet_data/",
-    )
-
-    mock_fetch_filtered_predictions.assert_called_with(
-        1672531200,
-        1672617600,
-        ["0x123"],
-        "mainnet",
-        FilterMode.PREDICTOOR,
-    )
-    mock_save_prediction_csv.assert_called_with([], "parquet_data/")
-    mock_get_cli_statistics.assert_called_with([])
+@enforce_types
+def test_get_predictoors_info_main_mainnet(tmpdir, monkeypatch):
+    del_network_override(monkeypatch)
+    ppss = mock_ppss("5m", ["binance c BTC/USDT"], "sapphire-mainnet", str(tmpdir))
+
+    mock_fetch = Mock(return_value=[])
+    mock_save = Mock()
+    mock_getstats = Mock()
+
+    PATH = "pdr_backend.util.get_predictoors_info"
+    with patch(f"{PATH}.fetch_filtered_predictions", mock_fetch), patch(
+        f"{PATH}.save_prediction_csv", mock_save
+    ), patch(f"{PATH}.get_cli_statistics", mock_getstats):
+        get_predictoors_info_main(
+            ppss,
+            "0x123",
+            "2023-01-01",
+            "2023-01-02",
+            "parquet_data/",
+        )
+
+        mock_fetch.assert_called_with(
+            1672531200,
+            1672617600,
+            ["0x123"],
+            "mainnet",
+            FilterMode.PREDICTOOR,
+        )
+        mock_save.assert_called_with([], "parquet_data/")
+        mock_getstats.assert_called_with([])
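The traction and predictoor-stats tests below all feed polars DataFrames where they previously passed raw Prediction lists. Their shared construction, shown standalone (the Prediction class here is a stub with illustrative fields; the real model lives in pdr_backend/models/prediction.py):

    import polars as pl

    class Prediction:  # stub; the real model has more fields
        def __init__(self, pair: str, timeframe: str, timestamp: int):
            self.pair = pair
            self.timeframe = timeframe
            self.timestamp = timestamp  # seconds at this point

    predictions = [
        Prediction("BTC/USDT", "5m", 1698883200),
        Prediction("ETH/USDT", "5m", 1698883500),
    ]

    # __dict__ turns each model into a row dict; polars infers the schema.
    preds_df = pl.DataFrame([pred.__dict__ for pred in predictions])

    # Scale timestamps to ms, matching the lake convention used elsewhere.
    preds_df = preds_df.with_columns(pl.col("timestamp").mul(1000).alias("timestamp"))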
diff --git a/pdr_backend/util/test_noganache/test_get_traction_info.py b/pdr_backend/util/test_noganache/test_get_traction_info.py
index b3fdd1250..4a9b09d59 100644
--- a/pdr_backend/util/test_noganache/test_get_traction_info.py
+++ b/pdr_backend/util/test_noganache/test_get_traction_info.py
@@ -1,43 +1,75 @@
-from unittest.mock import patch
+from unittest.mock import Mock, patch

 from enforce_typing import enforce_types
+import polars as pl

+from pdr_backend.ppss.ppss import mock_ppss
+from pdr_backend.ppss.web3_pp import del_network_override
 from pdr_backend.util.get_traction_info import get_traction_info_main
 from pdr_backend.util.subgraph_predictions import FilterMode
+from pdr_backend.util.timeutil import timestr_to_ut


 @enforce_types
-@patch("pdr_backend.util.get_traction_info.get_traction_statistics")
-@patch("pdr_backend.util.get_traction_info.get_all_contract_ids_by_owner")
-@patch("pdr_backend.util.get_traction_info.fetch_filtered_predictions")
-@patch("pdr_backend.util.get_traction_info.plot_traction_cum_sum_statistics")
-@patch("pdr_backend.util.get_traction_info.plot_traction_daily_statistics")
 def test_get_traction_info_main_mainnet(
-    mock_plot_traction_daily_statistics,
-    mock_plot_traction_cum_sum_statistics,
-    mock_fetch_filtered_predictions,
-    mock_get_all_contract_ids_by_owner,
-    mock_get_traction_statistics,
-    mock_ppss,
-    sample_first_predictions,
+    _sample_daily_predictions,
+    tmpdir,
+    monkeypatch,
 ):
-    mock_ppss.web3_pp.network = "main"
-    mock_get_all_contract_ids_by_owner.return_value = ["0x123", "0x234"]
-    mock_fetch_filtered_predictions.return_value = sample_first_predictions
-
-    get_traction_info_main(
-        mock_ppss, "0x123", "2023-01-01", "2023-01-02", "parquet_data/"
-    )
-
-    mock_fetch_filtered_predictions.assert_called_with(
-        1672531200,
-        1672617600,
-        ["0x123"],
-        "mainnet",
-        FilterMode.CONTRACT,
-        payout_only=False,
-        trueval_only=False,
-    )
-    mock_get_traction_statistics.assert_called_with(sample_first_predictions)
-    mock_plot_traction_cum_sum_statistics.assert_called()
-    mock_plot_traction_daily_statistics.assert_called()
+    del_network_override(monkeypatch)
+    ppss = mock_ppss("5m", ["binance c BTC/USDT"], "sapphire-mainnet", str(tmpdir))
+
+    mock_traction_stat = Mock()
+    mock_plot_cumsum = Mock()
+    mock_plot_daily = Mock()
+    mock_getids = Mock(return_value=["0x123"])
+    mock_fetch = Mock(return_value=_sample_daily_predictions)
+
+    PATH = "pdr_backend.util.get_traction_info"
+    PATH2 = "pdr_backend.data_eng"
+    with patch(f"{PATH}.get_traction_statistics", mock_traction_stat), patch(
+        f"{PATH}.plot_traction_cum_sum_statistics", mock_plot_cumsum
+    ), patch(f"{PATH}.plot_traction_daily_statistics", mock_plot_daily), patch(
+        f"{PATH2}.gql_data_factory.get_all_contract_ids_by_owner", mock_getids
+    ), patch(
+        f"{PATH2}.table_pdr_predictions.fetch_filtered_predictions", mock_fetch
+    ):
+        st_timestr = "2023-11-02"
+        fin_timestr = "2023-11-05"
+
+        get_traction_info_main(ppss, st_timestr, fin_timestr, "parquet_data/")
+
+        mock_fetch.assert_called_with(
+            1698883200,
+            1699142400,
+            ["0x123"],
+            "mainnet",
+            FilterMode.CONTRACT_TS,
+            payout_only=False,
+            trueval_only=False,
+        )
+
+        # calculate ms locally so we can filter raw Predictions
+        st_ut = timestr_to_ut(st_timestr)
+        fin_ut = timestr_to_ut(fin_timestr)
+        st_ut_sec = st_ut // 1000
+        fin_ut_sec = fin_ut // 1000
+
+        # Get all predictions into a dataframe
+        preds = [
+            x
+            for x in _sample_daily_predictions
+            if st_ut_sec <= x.timestamp <= fin_ut_sec
+        ]
+        preds = [pred.__dict__ for pred in preds]
+        preds_df = pl.DataFrame(preds)
+        preds_df = preds_df.with_columns(
+            [
+                pl.col("timestamp").mul(1000).alias("timestamp"),
+            ]
+        )
+
+        # Assert calls and values
+        assert mock_traction_stat.call_args[0][0].equals(preds_df)
+        mock_plot_cumsum.assert_called()
+        mock_plot_daily.assert_called()
diff --git a/pdr_backend/util/test_noganache/test_predictoor_stats.py b/pdr_backend/util/test_noganache/test_predictoor_stats.py
index b54bb88cb..986be3133 100644
--- a/pdr_backend/util/test_noganache/test_predictoor_stats.py
+++ b/pdr_backend/util/test_noganache/test_predictoor_stats.py
@@ -18,9 +18,9 @@


 @enforce_types
-def test_aggregate_prediction_statistics(sample_first_predictions):
+def test_aggregate_prediction_statistics(_sample_first_predictions):
     stats, correct_predictions = aggregate_prediction_statistics(
-        sample_first_predictions
+        _sample_first_predictions
     )
     assert isinstance(stats, dict)
     assert "pair_timeframe" in stats
@@ -29,9 +29,9 @@ def test_aggregate_prediction_statistics(sample_first_predictions):


 @enforce_types
-def test_get_endpoint_statistics(sample_first_predictions):
+def test_get_endpoint_statistics(_sample_first_predictions):
     accuracy, pair_timeframe_stats, predictoor_stats = get_endpoint_statistics(
-        sample_first_predictions
+        _sample_first_predictions
     )
     assert isinstance(accuracy, float)
     assert isinstance(pair_timeframe_stats, List)  # List[PairTimeframeStat]
@@ -61,8 +61,8 @@ def test_get_endpoint_statistics(sample_first_predictions):


 @enforce_types
-def test_get_cli_statistics(capsys, sample_first_predictions):
-    get_cli_statistics(sample_first_predictions)
+def test_get_cli_statistics(capsys, _sample_first_predictions):
+    get_cli_statistics(_sample_first_predictions)
     captured = capsys.readouterr()
     output = captured.out
     assert "Overall Accuracy" in output
@@ -73,10 +73,15 @@ def test_get_cli_statistics(capsys, sample_first_predictions):
 @enforce_types
 @patch("matplotlib.pyplot.savefig")
 def test_get_traction_statistics(
-    mock_savefig, sample_first_predictions, sample_second_predictions
+    mock_savefig, _sample_first_predictions, _sample_second_predictions
 ):
-    predictions = sample_first_predictions + sample_second_predictions
-    stats_df = get_traction_statistics(predictions)
+    predictions = _sample_first_predictions + _sample_second_predictions
+
+    # Get all predictions into a dataframe
+    preds_dicts = [pred.__dict__ for pred in predictions]
+    preds_df = pl.DataFrame(preds_dicts)
+
+    stats_df = get_traction_statistics(preds_df)
     assert isinstance(stats_df, pl.DataFrame)
     assert stats_df.shape == (3, 3)
     assert "datetime" in stats_df.columns
@@ -91,9 +96,15 @@ def test_get_traction_statistics(


 @enforce_types
-def test_get_slot_statistics(sample_first_predictions, sample_second_predictions):
-    predictions = sample_first_predictions + sample_second_predictions
-    slots_df = get_slot_statistics(predictions)
+def test_get_slot_statistics(_sample_first_predictions, _sample_second_predictions):
+    predictions = _sample_first_predictions + _sample_second_predictions
+
+    # Get all predictions into a dataframe
+    preds_dicts = [pred.__dict__ for pred in predictions]
+    preds_df = pl.DataFrame(preds_dicts)
+
+    # calculate slot stats
+    slots_df = get_slot_statistics(preds_df)
     assert isinstance(slots_df, pl.DataFrame)
     assert slots_df.shape == (7, 9)
@@ -116,10 +127,16 @@ def test_get_slot_statistics(sample_first_predictions, sample_second_predictions


 @enforce_types
 @patch("matplotlib.pyplot.savefig")
 def test_plot_slot_statistics(
-    mock_savefig, sample_first_predictions, sample_second_predictions
+    mock_savefig, _sample_first_predictions, _sample_second_predictions
 ):
-    predictions = sample_first_predictions + sample_second_predictions
-    slots_df = get_slot_statistics(predictions)
+    predictions = _sample_first_predictions + _sample_second_predictions
+
+    # Get all predictions into a dataframe
+    preds_dicts = [pred.__dict__ for pred in predictions]
+    preds_df = pl.DataFrame(preds_dicts)
+
+    # calculate slot stats
+    slots_df = get_slot_statistics(preds_df)
     slot_daily_df = calculate_slot_daily_statistics(slots_df)

     for key in [