From 79a283389c0b54453946ef5d93ceb9a3be65eb58 Mon Sep 17 00:00:00 2001 From: trizin <25263018+trizin@users.noreply.github.com> Date: Fri, 6 Sep 2024 00:42:53 +0300 Subject: [PATCH] Towards #1579 - technical indicators (#1578) * Add ta_features parameter to PredictTrainFeedset constructor * Add ta_features parameter to PredictTrainFeedset constructor * Add TechnicalIndicator class for calculating technical indicators * formatting * Add MACD technical indicator class * Add RSI technical indicator class * Add get_ta_indicator function for retrieving technical indicator class * Add ta_features parameter to SimEngine constructor * Formatting * Add ta_features parameter to ppss.yaml * Format * Add technical indicator features to AimodelDataFactory The code changes in `aimodel_data_factory.py` introduce the `ta_features` parameter to the `AimodelDataFactory` class. This parameter allows for the calculation of technical indicator features based on the provided feeds. The technical indicators are retrieved using the `get_ta_indicator` function, which has been added in a recent commit. * add ta * assert correct * Typo fix * Refactor TechnicalIndicator constructor parameter names for clarity * linter * linter * Formatting * Add mypy configuration for ta package * Better handling * remove unused import * Readability * formatting * Add MockTechnicalIndicator for testing purposes * Add conftest.py for technical indicators tests * test get_ta_indicator * Add unit test for MACD indicator * test RSI calculation against ta library * test TechnicalIndicator * Linter fixes * Update tests * linter fixes --- mypy.ini | 2 + pdr_backend/aimodel/aimodel_data_factory.py | 36 +++++++- pdr_backend/cli/predict_train_feedset.py | 36 ++++++-- pdr_backend/cli/predict_train_feedsets.py | 12 ++- .../cli/test/test_predict_train_feedset.py | 17 ++-- .../cli/test/test_predict_train_feedsets.py | 9 +- pdr_backend/sim/sim_engine.py | 2 + .../technical_indicators/get_indicator.py | 16 ++++ .../technical_indicators/indicators/macd.py | 24 ++++++ .../technical_indicators/indicators/rsi.py | 20 +++++ .../technical_indicator.py | 83 +++++++++++++++++++ .../technical_indicators/tests/conftest.py | 14 ++++ .../tests/test_get_indicator.py | 17 ++++ .../technical_indicators/tests/test_macd.py | 22 +++++ .../technical_indicators/tests/test_rsi.py | 22 +++++ .../tests/test_technical_indicator.py | 40 +++++++++ ppss.yaml | 1 + setup.py | 1 + 18 files changed, 353 insertions(+), 21 deletions(-) create mode 100644 pdr_backend/technical_indicators/get_indicator.py create mode 100644 pdr_backend/technical_indicators/indicators/macd.py create mode 100644 pdr_backend/technical_indicators/indicators/rsi.py create mode 100644 pdr_backend/technical_indicators/technical_indicator.py create mode 100644 pdr_backend/technical_indicators/tests/conftest.py create mode 100644 pdr_backend/technical_indicators/tests/test_get_indicator.py create mode 100644 pdr_backend/technical_indicators/tests/test_macd.py create mode 100644 pdr_backend/technical_indicators/tests/test_rsi.py create mode 100644 pdr_backend/technical_indicators/tests/test_technical_indicator.py diff --git a/mypy.ini b/mypy.ini index 37f689433..3ea042d16 100644 --- a/mypy.ini +++ b/mypy.ini @@ -95,3 +95,5 @@ ignore_missing_imports = True [mypy-yaml.*] ignore_missing_imports = True +[mypy-ta.*] +ignore_missing_imports = True diff --git a/pdr_backend/aimodel/aimodel_data_factory.py b/pdr_backend/aimodel/aimodel_data_factory.py index 561393700..a02bec45d 100644 --- a/pdr_backend/aimodel/aimodel_data_factory.py +++ b/pdr_backend/aimodel/aimodel_data_factory.py @@ -14,6 +14,7 @@ from pdr_backend.cli.arg_feed import ArgFeed from pdr_backend.cli.arg_feeds import ArgFeeds from pdr_backend.ppss.predictoor_ss import PredictoorSS +from pdr_backend.technical_indicators import get_indicator from pdr_backend.util.mathutil import fill_nans, has_nan logger = logging.getLogger("aimodel_data_factory") @@ -64,6 +65,7 @@ def create_xy( predict_feed: ArgFeed, train_feeds: Optional[ArgFeeds] = None, do_fill_nans: bool = True, + ta_features: Optional[List[str]] = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, pd.DataFrame, np.ndarray]: """ @description @@ -116,6 +118,27 @@ def create_xy( x_dim_len = len(train_feeds_list) * ss.autoregressive_n diff = 0 if ss.transform == "None" else 1 + features: List[pd.Series] = [] + if ta_features: + for feed in train_feeds_list: + # Generate feed keys + feed_keys = { + key: f"{feed.exchange}:{feed.pair}:{key}" + for key in ["close", "open", "high", "low", "volume"] + } + + for feature in ta_features: + ta_class = get_indicator.get_ta_indicator(feature) + if ta_class is None: + raise ValueError(f"Unknown TA feature: {feature}") + + ta = ta_class(mergedohlcv_df.to_pandas(), **feed_keys) + features.append(ta.calculate()) + + # Verify the results + num_features = len(ta_features) * len(train_feeds_list) + assert len(features) == num_features + assert len(features[0]) == len(mergedohlcv_df) # main work xcol_list = [] # [col_i] : name_str x_list = [] # [col_i] : Series. Build this up. Not df here (slow) @@ -152,6 +175,16 @@ def create_xy( x_col = hist_col + f":(z(t-{ds1})-z(t-{ds11}))/z(t-{ds11})" xcol_list += [x_col] + for i, feature in enumerate(features): + assert type(feature) == pd.Series # type check for mypy + feature_np = list(feature.values) + features_shifted = pd.Series( + _slice(feature_np, -shift - N_train - 1, -shift) + ) + x_list += [features_shifted] + xrecent_list += [pd.Series(_slice(feature_np, -shift, -shift + 1))] + xcol_list.append(f"{feature.name}_t-{ds1}-{i}") + # convert x lists to dfs, all at once. Faster than building up df. assert len(x_list) == len(xrecent_list) == len(xcol_list) x_df = pd.concat(x_list, keys=xcol_list, axis=1) @@ -181,7 +214,8 @@ def create_xy( # postconditions assert X.shape[0] == yraw.shape[0] == ytran.shape[0] assert X.shape[0] <= (N_train + 1) - assert X.shape[1] == x_dim_len + feature_dims = len(features) * len(train_feeds_list) * ss.autoregressive_n + assert X.shape[1] == x_dim_len + feature_dims assert isinstance(x_df, pd.DataFrame) assert "timestamp" not in x_df.columns diff --git a/pdr_backend/cli/predict_train_feedset.py b/pdr_backend/cli/predict_train_feedset.py index 5e885ea1a..11116d4ea 100644 --- a/pdr_backend/cli/predict_train_feedset.py +++ b/pdr_backend/cli/predict_train_feedset.py @@ -2,7 +2,7 @@ # Copyright 2024 Ocean Protocol Foundation # SPDX-License-Identifier: Apache-2.0 # -from typing import List +from typing import List, Dict from enforce_typing import enforce_types from typeguard import check_type @@ -21,9 +21,15 @@ class PredictTrainFeedset: """ @enforce_types - def __init__(self, predict: ArgFeed, train_on: ArgFeeds): + def __init__( + self, + predict: ArgFeed, + train_on: ArgFeeds, + ta_features: List[str] = [], + ): self.predict: ArgFeed = predict self.train_on: ArgFeeds = train_on + self.ta_features: List[str] = ta_features if ta_features else [] @enforce_types def __str__(self) -> str: @@ -31,11 +37,19 @@ def __str__(self) -> str: @enforce_types def __eq__(self, other): - return self.predict == other.predict and self.train_on == other.train_on + return ( + self.predict == other.predict + and self.train_on == other.train_on + and self.ta_features == other.ta_features + ) @enforce_types - def to_dict(self): - return {"predict": str(self.predict), "train_on": str(self.train_on)} + def to_dict(self) -> Dict: + return { + "predict": str(self.predict), + "train_on": str(self.train_on), + "ta_features": self.ta_features, + } @classmethod def from_dict(cls, feedset_dict: dict) -> "PredictTrainFeedset": @@ -43,7 +57,8 @@ def from_dict(cls, feedset_dict: dict) -> "PredictTrainFeedset": @arguments feedset_dict -- has the following format: {"predict":predict_feed_str (1 feed), - "train_on":train_on_feeds_str (>=1 feeds)} + "train_on":train_on_feeds_str (>=1 feeds), + "ta_features":list of extra features} Note just ONE predict feed is allowed, not >=1. Here are three examples. from_dict() gives the same output for each. @@ -52,11 +67,14 @@ def from_dict(cls, feedset_dict: dict) -> "PredictTrainFeedset": 2. { "predict" : "binance BTC/USDT o 1h", "train_on" : "binance BTC/USDT o 1h, binance ETH/USDT o 1h"} 3. { "predict" : "binance BTC/USDT o 1h", - "train_on" : ["binance BTC/USDT o 1h", "binance ETH/USDT o 1h"]} + "train_on" : ["binance BTC/USDT o 1h", "binance ETH/USDT o 1h"], + "ta_features": ["rsi", "macd"]} """ predict = ArgFeed.from_str(feedset_dict["predict"]) - train_on = ArgFeeds.from_strs(_as_list(feedset_dict["train_on"])) - return cls(predict, train_on) + train_on = ArgFeeds.from_strs(_as_list(feedset_dict.get("train_on"))) + ta_features = feedset_dict.get("ta_features", []) + check_type(ta_features, List[str]) + return cls(predict, train_on, ta_features) @property def timeframe_ms(self) -> int: diff --git a/pdr_backend/cli/predict_train_feedsets.py b/pdr_backend/cli/predict_train_feedsets.py index f2336cef2..1e80ef204 100644 --- a/pdr_backend/cli/predict_train_feedsets.py +++ b/pdr_backend/cli/predict_train_feedsets.py @@ -34,8 +34,11 @@ def from_list_of_dict(cls, feedset_list: List[dict]) -> "PredictTrainFeedsets": @arguments feedset_list -- list of feedset_dict, where feedset_dict has the following format: - {"predict":predict_feeds_str, - "train_on":train_on_feeds_str} + { + "predict": predict_feeds_str, + "train_on": train_on_feeds_str, + "ta_features": ["feature1", "feature2"] + } Note that >=1 predict feeds are allowed for a given feedset_dict. Example feedset_list = [ @@ -45,10 +48,12 @@ def from_list_of_dict(cls, feedset_list: List[dict]) -> "PredictTrainFeedsets": "binance BTC/USDT ETH/USDT DOT/USDT c 5m", "kraken BTC/USDT c 5m", ], + "ta_features": ["macd", "rsi"] }, { "predict": "binance ETH/USDT ADA/USDT c 5m", "train_on": "binance BTC/USDT DOT/USDT c 5m, kraken BTC/USDT c 5m", + "ta_features": ["ema", "rvi"] }, """ final_list = [] @@ -57,9 +62,10 @@ def from_list_of_dict(cls, feedset_list: List[dict]) -> "PredictTrainFeedsets": raise ValueError(feedset_dict) predict_feeds: ArgFeeds = parse_feed_obj(feedset_dict["predict"]) + ta_features = feedset_dict.get("ta_features", []) for predict in predict_feeds: train_on = parse_feed_obj(feedset_dict["train_on"]) - feedset = PredictTrainFeedset(predict, train_on) + feedset = PredictTrainFeedset(predict, train_on, ta_features) final_list.append(feedset) return cls(final_list) diff --git a/pdr_backend/cli/test/test_predict_train_feedset.py b/pdr_backend/cli/test/test_predict_train_feedset.py index b6a2896c5..26454252b 100644 --- a/pdr_backend/cli/test/test_predict_train_feedset.py +++ b/pdr_backend/cli/test/test_predict_train_feedset.py @@ -28,11 +28,16 @@ def test_feedset_main(): assert feedset.train_on == ARG_FEEDS assert feedset.timeframe_ms == ARG_FEED.timeframe.ms - assert feedset.to_dict() == {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS_STR} + assert feedset.to_dict() == { + "predict": ARG_FEED_STR, + "train_on": ARG_FEEDS_STR, + "ta_features": [], + } assert ( str(feedset) - == "{'predict': 'binance BTC/USDT o 1h', 'train_on': 'binance BTC/USDT ETH/USDT o 1h'}" + # pylint: disable=line-too-long + == "{'predict': 'binance BTC/USDT o 1h', 'train_on': 'binance BTC/USDT ETH/USDT o 1h', 'ta_features': []}" ) @@ -62,24 +67,24 @@ def test_feedset_eq_diff(): @enforce_types def test_feedset_from_dict(): # "train_on" as str - d = {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS_STR} + d = {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS_STR, "ta_features": []} feedset = PredictTrainFeedset.from_dict(d) assert feedset.predict == ARG_FEED assert feedset.train_on == ARG_FEEDS assert feedset.to_dict() == d # "train_on" as list - d = {"predict": ARG_FEED_STR, "train_on": [ARG_FEEDS_STR]} + d = {"predict": ARG_FEED_STR, "train_on": [ARG_FEEDS_STR], "ta_features": []} feedset = PredictTrainFeedset.from_dict(d) assert feedset.predict == ARG_FEED assert feedset.train_on == ARG_FEEDS # "predict" value must be a str - d = {"predict": ARG_FEED, "train_on": ARG_FEEDS_STR} + d = {"predict": ARG_FEED, "train_on": ARG_FEEDS_STR, "ta_features": []} with pytest.raises(TypeError): feedset = PredictTrainFeedset.from_dict(d) # "train_on" value must be a str - d = {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS} + d = {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS, "ta_features": []} with pytest.raises(TypeCheckError): feedset = PredictTrainFeedset.from_dict(d) diff --git a/pdr_backend/cli/test/test_predict_train_feedsets.py b/pdr_backend/cli/test/test_predict_train_feedsets.py index 97fe15543..21bd4fa13 100644 --- a/pdr_backend/cli/test/test_predict_train_feedsets.py +++ b/pdr_backend/cli/test/test_predict_train_feedsets.py @@ -21,7 +21,7 @@ ARG_FEEDS: ArgFeeds = ArgFeeds.from_str(ARG_FEEDS_STR) # ("predict", "train_on") set -FEEDSET_DICT = {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS_STR} +FEEDSET_DICT = {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS_STR, "ta_features": []} FEEDSET = PredictTrainFeedset(predict=ARG_FEED, train_on=ARG_FEEDS) @@ -38,7 +38,8 @@ def test_feedsets_1_feedset(): assert feedsets == PredictTrainFeedsets([FEEDSET]) assert ( str(feedsets) - == "[{'predict': 'binance BTC/USDT o 1h', 'train_on': 'binance BTC/USDT ETH/USDT o 1h'}]" + # pylint: disable=line-too-long + == "[{'predict': 'binance BTC/USDT o 1h', 'train_on': 'binance BTC/USDT ETH/USDT o 1h', 'ta_features': []}]" ) feedsets2 = PredictTrainFeedsets.from_list_of_dict([FEEDSET_DICT]) @@ -102,18 +103,22 @@ def test_feedsets_from_list_of_dict__thorough(): { "predict": "binance BTC/USDT c 5m", "train_on": "binance BTC/USDT DOT/USDT ETH/USDT c 5m, kraken BTC/USDT c 5m", + "ta_features": [], }, { "predict": "kraken BTC/USDT c 5m", "train_on": "binance BTC/USDT DOT/USDT ETH/USDT c 5m, kraken BTC/USDT c 5m", + "ta_features": [], }, { "predict": "binance ETH/USDT c 5m", "train_on": "binance BTC/USDT DOT/USDT c 5m, kraken BTC/USDT c 5m", + "ta_features": [], }, { "predict": "binance ADA/USDT c 5m", "train_on": "binance BTC/USDT DOT/USDT c 5m, kraken BTC/USDT c 5m", + "ta_features": [], }, ] assert feedset_list2 == target_feedset_list2 diff --git a/pdr_backend/sim/sim_engine.py b/pdr_backend/sim/sim_engine.py index 0313da556..e98034330 100644 --- a/pdr_backend/sim/sim_engine.py +++ b/pdr_backend/sim/sim_engine.py @@ -129,6 +129,7 @@ def run_one_iter(self, test_i: int, mergedohlcv_df: pl.DataFrame): data_f = AimodelDataFactory(pdr_ss) # type: ignore[arg-type] predict_feed = self.predict_train_feedset.predict train_feeds = self.predict_train_feedset.train_on + features = self.predict_train_feedset.ta_features # X, ycont, and x_df are all expressed in % change wrt prev candle X, ytran, yraw, x_df, _ = data_f.create_xy( @@ -136,6 +137,7 @@ def run_one_iter(self, test_i: int, mergedohlcv_df: pl.DataFrame): testshift, predict_feed, train_feeds, + ta_features=features, ) colnames = list(x_df.columns) diff --git a/pdr_backend/technical_indicators/get_indicator.py b/pdr_backend/technical_indicators/get_indicator.py new file mode 100644 index 000000000..20ec92709 --- /dev/null +++ b/pdr_backend/technical_indicators/get_indicator.py @@ -0,0 +1,16 @@ +from typing import Optional, Type +from pdr_backend.technical_indicators.indicators.macd import MACD +from pdr_backend.technical_indicators.indicators.rsi import RSI +from pdr_backend.technical_indicators.technical_indicator import TechnicalIndicator + +indicators = { + "rsi": RSI, + "macd": MACD, +} + + +def get_ta_indicator(indicator: str) -> Optional[Type[TechnicalIndicator]]: + """ + Returns the technical indicator class based on the input indicator name. + """ + return indicators.get(indicator) diff --git a/pdr_backend/technical_indicators/indicators/macd.py b/pdr_backend/technical_indicators/indicators/macd.py new file mode 100644 index 000000000..46eb63675 --- /dev/null +++ b/pdr_backend/technical_indicators/indicators/macd.py @@ -0,0 +1,24 @@ +import pandas as pd +import ta +from pdr_backend.technical_indicators.technical_indicator import TechnicalIndicator + + +class MACD(TechnicalIndicator): + """ + Moving Average Convergence Divergence (MACD) technical indicator. + """ + + def calculate(self, *args, **kwargs) -> pd.Series: + """ + Calculates the MACD value based on the input data. + + @param: + window_fast - The window size for the fast EMA calculation (default=12). + window_slow - The window size for the slow EMA calculation (default=26). + """ + window_fast = kwargs.get("window_fast", 12) + window_slow = kwargs.get("window_slow", 26) + macd = ta.trend.MACD( + close=self._close(), window_fast=window_fast, window_slow=window_slow + ) + return macd.macd() diff --git a/pdr_backend/technical_indicators/indicators/rsi.py b/pdr_backend/technical_indicators/indicators/rsi.py new file mode 100644 index 000000000..864321d03 --- /dev/null +++ b/pdr_backend/technical_indicators/indicators/rsi.py @@ -0,0 +1,20 @@ +import pandas as pd +import ta +from pdr_backend.technical_indicators.technical_indicator import TechnicalIndicator + + +class RSI(TechnicalIndicator): + """ + Relative Strength Index (RSI) technical indicator. + """ + + def calculate(self, *args, **kwargs) -> pd.Series: + """ + Calculates the RSI value based on the input data. + + @param: + window - The window size for the RSI calculation (default=14). + """ + window = kwargs.get("window", 14) + rsi = ta.momentum.RSIIndicator(close=self._close(), window=window).rsi() + return rsi diff --git a/pdr_backend/technical_indicators/technical_indicator.py b/pdr_backend/technical_indicators/technical_indicator.py new file mode 100644 index 000000000..cd69a5748 --- /dev/null +++ b/pdr_backend/technical_indicators/technical_indicator.py @@ -0,0 +1,83 @@ +from abc import ABC, abstractmethod +from enforce_typing import enforce_types +import pandas as pd + + +@enforce_types +class TechnicalIndicator(ABC): + """ + Abstract base class for technical indicators. + + Attributes: + df - pd.DataFrame + The input dataframe containing the time series data. + open - str + The name of the column containing opening price data. + high - str + The name of the column containing high price data. + low - str + The name of the column containing low price data. + close - str + The name of the column containing closing price data. + volume - str + The name of the column containing volume data. + + Methods: + calculate(*args, **kwargs) -> pd.Series + Calculates the indicator value based on the input data. + """ + + def __init__( + self, + df: pd.DataFrame, + open_col: str, + high_col: str, + low_col: str, + close_col: str, + volume_col: str, + ): + """ + Initializes a TechnicalIndicator object. + @param: + open - name of column containing opening price data, + high - name of column containing high price data, + low - name of column containing low price data, + close - name of column containing closing price data, + volume - name of column containing volume data, + """ + self.df = df + self.open_col = open_col + self.high_col = high_col + self.low_col = low_col + self.close_col = close_col + self.volume_col = volume_col + + def _open(self): + return self.df[self.open_col] + + def _high(self): + return self.df[self.high_col] + + def _low(self): + return self.df[self.low_col] + + def _close(self): + return self.df[self.close_col] + + def _volume(self): + return self.df[self.volume_col] + + @abstractmethod + def calculate(self, *args, **kwargs) -> pd.Series: + """ + Calculates the indicator value based on the input data. + + @return + pd.Series - the indicator. + """ + + +class MockTechnicalIndicator(TechnicalIndicator): + def calculate(self, *args, **kwargs) -> pd.Series: + # Example implementation for testing purposes + return self._close() * 0.5 diff --git a/pdr_backend/technical_indicators/tests/conftest.py b/pdr_backend/technical_indicators/tests/conftest.py new file mode 100644 index 000000000..63d46643d --- /dev/null +++ b/pdr_backend/technical_indicators/tests/conftest.py @@ -0,0 +1,14 @@ +import pandas as pd +import pytest + + +@pytest.fixture +def sample_df(): + data = { + "open": [1.0, 2.0, 3.0, 4.0, 5.0], + "high": [1.5, 2.5, 3.5, 4.5, 5.5], + "low": [0.5, 1.5, 2.5, 3.5, 4.5], + "close": [1.2, 2.2, 3.2, 4.2, 5.2], + "volume": [1000, 1500, 2000, 2500, 3000], + } + return pd.DataFrame(data) diff --git a/pdr_backend/technical_indicators/tests/test_get_indicator.py b/pdr_backend/technical_indicators/tests/test_get_indicator.py new file mode 100644 index 000000000..c1244bac6 --- /dev/null +++ b/pdr_backend/technical_indicators/tests/test_get_indicator.py @@ -0,0 +1,17 @@ +from pdr_backend.technical_indicators.get_indicator import get_ta_indicator +from pdr_backend.technical_indicators.indicators.macd import MACD +from pdr_backend.technical_indicators.indicators.rsi import RSI + + +def test_get_ta_indicator_valid(): + assert get_ta_indicator("rsi") == RSI + assert get_ta_indicator("macd") == MACD + + +def test_get_ta_indicator_invalid(): + assert get_ta_indicator("invalid_indicator") is None + + +def test_get_ta_indicator_case_sensitivity(): + assert get_ta_indicator("RSI") is None + assert get_ta_indicator("MACD") is None diff --git a/pdr_backend/technical_indicators/tests/test_macd.py b/pdr_backend/technical_indicators/tests/test_macd.py new file mode 100644 index 000000000..313255432 --- /dev/null +++ b/pdr_backend/technical_indicators/tests/test_macd.py @@ -0,0 +1,22 @@ +import pandas as pd +import ta +from pdr_backend.technical_indicators.indicators.macd import MACD + + +def test_macd(sample_df): + macd_indicator = MACD( + df=sample_df, + open_col="open", + high_col="high", + low_col="low", + close_col="close", + volume_col="volume", + ) + + macd_result = macd_indicator.calculate(window_fast=12, window_slow=26) + + expected_macd = ta.trend.MACD( + close=sample_df["close"], window_fast=12, window_slow=26 + ).macd() + + pd.testing.assert_series_equal(macd_result, expected_macd, check_dtype=False) diff --git a/pdr_backend/technical_indicators/tests/test_rsi.py b/pdr_backend/technical_indicators/tests/test_rsi.py new file mode 100644 index 000000000..47e35c8f5 --- /dev/null +++ b/pdr_backend/technical_indicators/tests/test_rsi.py @@ -0,0 +1,22 @@ +import pandas as pd +import ta +from pdr_backend.technical_indicators.indicators.rsi import RSI + + +def test_rsi(sample_df): + rsi_indicator = RSI( + df=sample_df, + open_col="open", + high_col="high", + low_col="low", + close_col="close", + volume_col="volume", + ) + + # Calculate RSI + rsi_result = rsi_indicator.calculate(window=14) + + # Expected RSI calculation using `ta` library + expected_rsi = ta.momentum.RSIIndicator(close=sample_df["close"], window=14).rsi() + + pd.testing.assert_series_equal(rsi_result, expected_rsi, check_dtype=False) diff --git a/pdr_backend/technical_indicators/tests/test_technical_indicator.py b/pdr_backend/technical_indicators/tests/test_technical_indicator.py new file mode 100644 index 000000000..40c758b03 --- /dev/null +++ b/pdr_backend/technical_indicators/tests/test_technical_indicator.py @@ -0,0 +1,40 @@ +import pytest +import pandas as pd + +from pdr_backend.technical_indicators.technical_indicator import ( + MockTechnicalIndicator, + TechnicalIndicator, +) + + +def test_mock_technical_indicator(sample_df): + indicator = MockTechnicalIndicator( + df=sample_df, + open_col="open", + high_col="high", + low_col="low", + close_col="close", + volume_col="volume", + ) + + assert indicator._open().equals(sample_df["open"]) + assert indicator._high().equals(sample_df["high"]) + assert indicator._low().equals(sample_df["low"]) + assert indicator._close().equals(sample_df["close"]) + assert indicator._volume().equals(sample_df["volume"]) + + expected_result = sample_df["close"] * 0.5 + pd.testing.assert_series_equal(indicator.calculate(), expected_result) + + +def test_abstract_method_implementation(): + with pytest.raises(TypeError): + # pylint: disable=abstract-class-instantiated + TechnicalIndicator( + df=pd.DataFrame(), + open_col="open", + high_col="high", + low_col="low", + close_col="close", + volume_col="volume", + ) diff --git a/ppss.yaml b/ppss.yaml index f3ce357a0..b9ad0ea38 100644 --- a/ppss.yaml +++ b/ppss.yaml @@ -19,6 +19,7 @@ predictoor_ss: - predict: binance BTC/USDT c 5m train_on: - binance BTC/USDT c 5m + ta_features: [] # list of TA features to use. approach: 2 # 1->50/50; 2->two-sided model-based; 3-> one-sided model-based stake_amount: 100 # How much your bot stakes. In OCEAN per epoch, per feed sim_only: diff --git a/setup.py b/setup.py index 2dd56e018..81d8b1ded 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,7 @@ "web3==6.20.2", "sapphire.py==0.2.3", "stopit==1.1.2", + "ta==0.11.0", "ocean-contracts==2.1.0", # install this last ]