Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Towards #1579 - technical indicators #1578

Merged
merged 38 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
5c1e813
Add ta_features parameter to PredictTrainFeedset constructor
trizin Sep 3, 2024
c2beeeb
Add ta_features parameter to PredictTrainFeedset constructor
trizin Sep 3, 2024
9e85f30
Add TechnicalIndicator class for calculating technical indicators
trizin Sep 3, 2024
5019bef
formatting
trizin Sep 3, 2024
7bf15c2
Add MACD technical indicator class
trizin Sep 3, 2024
bec5104
Add RSI technical indicator class
trizin Sep 3, 2024
56421d3
Add get_ta_indicator function for retrieving technical indicator class
trizin Sep 3, 2024
f5e60c3
Add ta_features parameter to SimEngine constructor
trizin Sep 3, 2024
6a6065b
Formatting
trizin Sep 3, 2024
e3af740
Add ta_features parameter to ppss.yaml
trizin Sep 3, 2024
5ae2d92
Format
trizin Sep 3, 2024
50baee7
Add technical indicator features to AimodelDataFactory
trizin Sep 3, 2024
97e5582
add ta
trizin Sep 4, 2024
56c2d8d
assert correct
trizin Sep 4, 2024
dfcbb78
Typo fix
trizin Sep 4, 2024
db249bd
Refactor TechnicalIndicator constructor parameter names for clarity
trizin Sep 4, 2024
c22e153
linter
trizin Sep 4, 2024
00258b6
linter
trizin Sep 4, 2024
a9474db
Formatting
trizin Sep 4, 2024
5d9514e
Add mypy configuration for ta package
trizin Sep 4, 2024
6f03cd5
Better handling
trizin Sep 4, 2024
ebd9ede
remove unused import
trizin Sep 4, 2024
5fe7048
Merge branch 'main' into issue1406-technical-indicators
trizin Sep 4, 2024
a3d38ce
Readability
trizin Sep 4, 2024
8690611
Merge branch 'issue1406-technical-indicators' of https://github.com/o…
trizin Sep 4, 2024
aa9a219
formatting
trizin Sep 4, 2024
279bf79
Add MockTechnicalIndicator for testing purposes
trizin Sep 4, 2024
ebf9071
Add conftest.py for technical indicators tests
trizin Sep 4, 2024
74443a9
test get_ta_indicator
trizin Sep 4, 2024
7273e75
Add unit test for MACD indicator
trizin Sep 4, 2024
1ad5c3e
test RSI calculation against ta library
trizin Sep 4, 2024
4d92a68
test TechnicalIndicator
trizin Sep 4, 2024
22693e1
Linter fixes
trizin Sep 4, 2024
cab2634
Merge branch 'main' into issue1406-technical-indicators
trizin Sep 5, 2024
d976a6c
Update tests
trizin Sep 5, 2024
f1e1ebf
Merge branch 'issue1406-technical-indicators' of https://github.com/o…
trizin Sep 5, 2024
dab110c
linter fixes
trizin Sep 5, 2024
cb9de91
Merge branch 'main' into issue1406-technical-indicators
trizin Sep 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,5 @@ ignore_missing_imports = True
[mypy-yaml.*]
ignore_missing_imports = True

[mypy-ta.*]
ignore_missing_imports = True
36 changes: 35 additions & 1 deletion pdr_backend/aimodel/aimodel_data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pdr_backend.cli.arg_feed import ArgFeed
from pdr_backend.cli.arg_feeds import ArgFeeds
from pdr_backend.ppss.predictoor_ss import PredictoorSS
from pdr_backend.technical_indicators import get_indicator
from pdr_backend.util.mathutil import fill_nans, has_nan

logger = logging.getLogger("aimodel_data_factory")
Expand Down Expand Up @@ -64,6 +65,7 @@ def create_xy(
predict_feed: ArgFeed,
train_feeds: Optional[ArgFeeds] = None,
do_fill_nans: bool = True,
ta_features: Optional[List[str]] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, pd.DataFrame, np.ndarray]:
"""
@description
Expand Down Expand Up @@ -116,6 +118,27 @@ def create_xy(
x_dim_len = len(train_feeds_list) * ss.autoregressive_n
diff = 0 if ss.transform == "None" else 1

features: List[pd.Series] = []
if ta_features:
for feed in train_feeds_list:
# Generate feed keys
feed_keys = {
key: f"{feed.exchange}:{feed.pair}:{key}"
for key in ["close", "open", "high", "low", "volume"]
}

for feature in ta_features:
ta_class = get_indicator.get_ta_indicator(feature)
if ta_class is None:
raise ValueError(f"Unknown TA feature: {feature}")

ta = ta_class(mergedohlcv_df.to_pandas(), **feed_keys)
features.append(ta.calculate())

# Verify the results
num_features = len(ta_features) * len(train_feeds_list)
assert len(features) == num_features
assert len(features[0]) == len(mergedohlcv_df)
# main work
xcol_list = [] # [col_i] : name_str
x_list = [] # [col_i] : Series. Build this up. Not df here (slow)
Expand Down Expand Up @@ -152,6 +175,16 @@ def create_xy(
x_col = hist_col + f":(z(t-{ds1})-z(t-{ds11}))/z(t-{ds11})"
xcol_list += [x_col]

for i, feature in enumerate(features):
assert type(feature) == pd.Series # type check for mypy
feature_np = list(feature.values)
features_shifted = pd.Series(
_slice(feature_np, -shift - N_train - 1, -shift)
)
x_list += [features_shifted]
xrecent_list += [pd.Series(_slice(feature_np, -shift, -shift + 1))]
xcol_list.append(f"{feature.name}_t-{ds1}-{i}")

# convert x lists to dfs, all at once. Faster than building up df.
assert len(x_list) == len(xrecent_list) == len(xcol_list)
x_df = pd.concat(x_list, keys=xcol_list, axis=1)
Expand Down Expand Up @@ -181,7 +214,8 @@ def create_xy(
# postconditions
assert X.shape[0] == yraw.shape[0] == ytran.shape[0]
assert X.shape[0] <= (N_train + 1)
assert X.shape[1] == x_dim_len
feature_dims = len(features) * len(train_feeds_list) * ss.autoregressive_n
assert X.shape[1] == x_dim_len + feature_dims
assert isinstance(x_df, pd.DataFrame)

assert "timestamp" not in x_df.columns
Expand Down
36 changes: 27 additions & 9 deletions pdr_backend/cli/predict_train_feedset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright 2024 Ocean Protocol Foundation
# SPDX-License-Identifier: Apache-2.0
#
from typing import List
from typing import List, Dict

from enforce_typing import enforce_types
from typeguard import check_type
Expand All @@ -21,29 +21,44 @@ class PredictTrainFeedset:
"""

@enforce_types
def __init__(self, predict: ArgFeed, train_on: ArgFeeds):
def __init__(
self,
predict: ArgFeed,
train_on: ArgFeeds,
ta_features: List[str] = [],
):
self.predict: ArgFeed = predict
self.train_on: ArgFeeds = train_on
self.ta_features: List[str] = ta_features if ta_features else []

@enforce_types
def __str__(self) -> str:
return str(self.to_dict())

@enforce_types
def __eq__(self, other):
return self.predict == other.predict and self.train_on == other.train_on
return (
self.predict == other.predict
and self.train_on == other.train_on
and self.ta_features == other.ta_features
)

@enforce_types
def to_dict(self):
return {"predict": str(self.predict), "train_on": str(self.train_on)}
def to_dict(self) -> Dict:
return {
"predict": str(self.predict),
"train_on": str(self.train_on),
"ta_features": self.ta_features,
}

@classmethod
def from_dict(cls, feedset_dict: dict) -> "PredictTrainFeedset":
"""
@arguments
feedset_dict -- has the following format:
{"predict":predict_feed_str (1 feed),
"train_on":train_on_feeds_str (>=1 feeds)}
"train_on":train_on_feeds_str (>=1 feeds),
"ta_features":list of extra features}
Note just ONE predict feed is allowed, not >=1.

Here are three examples. from_dict() gives the same output for each.
Expand All @@ -52,11 +67,14 @@ def from_dict(cls, feedset_dict: dict) -> "PredictTrainFeedset":
2. { "predict" : "binance BTC/USDT o 1h",
"train_on" : "binance BTC/USDT o 1h, binance ETH/USDT o 1h"}
3. { "predict" : "binance BTC/USDT o 1h",
"train_on" : ["binance BTC/USDT o 1h", "binance ETH/USDT o 1h"]}
"train_on" : ["binance BTC/USDT o 1h", "binance ETH/USDT o 1h"],
"ta_features": ["rsi", "macd"]}
"""
predict = ArgFeed.from_str(feedset_dict["predict"])
train_on = ArgFeeds.from_strs(_as_list(feedset_dict["train_on"]))
return cls(predict, train_on)
train_on = ArgFeeds.from_strs(_as_list(feedset_dict.get("train_on")))
ta_features = feedset_dict.get("ta_features", [])
check_type(ta_features, List[str])
return cls(predict, train_on, ta_features)

@property
def timeframe_ms(self) -> int:
Expand Down
12 changes: 9 additions & 3 deletions pdr_backend/cli/predict_train_feedsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ def from_list_of_dict(cls, feedset_list: List[dict]) -> "PredictTrainFeedsets":
@arguments
feedset_list -- list of feedset_dict,
where feedset_dict has the following format:
{"predict":predict_feeds_str,
"train_on":train_on_feeds_str}
{
"predict": predict_feeds_str,
"train_on": train_on_feeds_str,
"ta_features": ["feature1", "feature2"]
}
Note that >=1 predict feeds are allowed for a given feedset_dict.

Example feedset_list = [
Expand All @@ -45,10 +48,12 @@ def from_list_of_dict(cls, feedset_list: List[dict]) -> "PredictTrainFeedsets":
"binance BTC/USDT ETH/USDT DOT/USDT c 5m",
"kraken BTC/USDT c 5m",
],
"ta_features": ["macd", "rsi"]
},
{
"predict": "binance ETH/USDT ADA/USDT c 5m",
"train_on": "binance BTC/USDT DOT/USDT c 5m, kraken BTC/USDT c 5m",
"ta_features": ["ema", "rvi"]
},
"""
final_list = []
Expand All @@ -57,9 +62,10 @@ def from_list_of_dict(cls, feedset_list: List[dict]) -> "PredictTrainFeedsets":
raise ValueError(feedset_dict)

predict_feeds: ArgFeeds = parse_feed_obj(feedset_dict["predict"])
ta_features = feedset_dict.get("ta_features", [])
for predict in predict_feeds:
train_on = parse_feed_obj(feedset_dict["train_on"])
feedset = PredictTrainFeedset(predict, train_on)
feedset = PredictTrainFeedset(predict, train_on, ta_features)
final_list.append(feedset)

return cls(final_list)
Expand Down
17 changes: 11 additions & 6 deletions pdr_backend/cli/test/test_predict_train_feedset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,16 @@ def test_feedset_main():
assert feedset.train_on == ARG_FEEDS
assert feedset.timeframe_ms == ARG_FEED.timeframe.ms

assert feedset.to_dict() == {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS_STR}
assert feedset.to_dict() == {
"predict": ARG_FEED_STR,
"train_on": ARG_FEEDS_STR,
"ta_features": [],
}

assert (
str(feedset)
== "{'predict': 'binance BTC/USDT o 1h', 'train_on': 'binance BTC/USDT ETH/USDT o 1h'}"
# pylint: disable=line-too-long
== "{'predict': 'binance BTC/USDT o 1h', 'train_on': 'binance BTC/USDT ETH/USDT o 1h', 'ta_features': []}"
)


Expand Down Expand Up @@ -62,24 +67,24 @@ def test_feedset_eq_diff():
@enforce_types
def test_feedset_from_dict():
# "train_on" as str
d = {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS_STR}
d = {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS_STR, "ta_features": []}
feedset = PredictTrainFeedset.from_dict(d)
assert feedset.predict == ARG_FEED
assert feedset.train_on == ARG_FEEDS
assert feedset.to_dict() == d

# "train_on" as list
d = {"predict": ARG_FEED_STR, "train_on": [ARG_FEEDS_STR]}
d = {"predict": ARG_FEED_STR, "train_on": [ARG_FEEDS_STR], "ta_features": []}
feedset = PredictTrainFeedset.from_dict(d)
assert feedset.predict == ARG_FEED
assert feedset.train_on == ARG_FEEDS

# "predict" value must be a str
d = {"predict": ARG_FEED, "train_on": ARG_FEEDS_STR}
d = {"predict": ARG_FEED, "train_on": ARG_FEEDS_STR, "ta_features": []}
with pytest.raises(TypeError):
feedset = PredictTrainFeedset.from_dict(d)

# "train_on" value must be a str
d = {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS}
d = {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS, "ta_features": []}
with pytest.raises(TypeCheckError):
feedset = PredictTrainFeedset.from_dict(d)
9 changes: 7 additions & 2 deletions pdr_backend/cli/test/test_predict_train_feedsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
ARG_FEEDS: ArgFeeds = ArgFeeds.from_str(ARG_FEEDS_STR)

# ("predict", "train_on") set
FEEDSET_DICT = {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS_STR}
FEEDSET_DICT = {"predict": ARG_FEED_STR, "train_on": ARG_FEEDS_STR, "ta_features": []}
FEEDSET = PredictTrainFeedset(predict=ARG_FEED, train_on=ARG_FEEDS)


Expand All @@ -38,7 +38,8 @@ def test_feedsets_1_feedset():
assert feedsets == PredictTrainFeedsets([FEEDSET])
assert (
str(feedsets)
== "[{'predict': 'binance BTC/USDT o 1h', 'train_on': 'binance BTC/USDT ETH/USDT o 1h'}]"
# pylint: disable=line-too-long
== "[{'predict': 'binance BTC/USDT o 1h', 'train_on': 'binance BTC/USDT ETH/USDT o 1h', 'ta_features': []}]"
)

feedsets2 = PredictTrainFeedsets.from_list_of_dict([FEEDSET_DICT])
Expand Down Expand Up @@ -102,18 +103,22 @@ def test_feedsets_from_list_of_dict__thorough():
{
"predict": "binance BTC/USDT c 5m",
"train_on": "binance BTC/USDT DOT/USDT ETH/USDT c 5m, kraken BTC/USDT c 5m",
"ta_features": [],
},
{
"predict": "kraken BTC/USDT c 5m",
"train_on": "binance BTC/USDT DOT/USDT ETH/USDT c 5m, kraken BTC/USDT c 5m",
"ta_features": [],
},
{
"predict": "binance ETH/USDT c 5m",
"train_on": "binance BTC/USDT DOT/USDT c 5m, kraken BTC/USDT c 5m",
"ta_features": [],
},
{
"predict": "binance ADA/USDT c 5m",
"train_on": "binance BTC/USDT DOT/USDT c 5m, kraken BTC/USDT c 5m",
"ta_features": [],
},
]
assert feedset_list2 == target_feedset_list2
Expand Down
2 changes: 2 additions & 0 deletions pdr_backend/sim/sim_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,15 @@ def run_one_iter(self, test_i: int, mergedohlcv_df: pl.DataFrame):
data_f = AimodelDataFactory(pdr_ss) # type: ignore[arg-type]
predict_feed = self.predict_train_feedset.predict
train_feeds = self.predict_train_feedset.train_on
features = self.predict_train_feedset.ta_features

# X, ycont, and x_df are all expressed in % change wrt prev candle
X, ytran, yraw, x_df, _ = data_f.create_xy(
mergedohlcv_df,
testshift,
predict_feed,
train_feeds,
ta_features=features,
)
colnames = list(x_df.columns)

Expand Down
16 changes: 16 additions & 0 deletions pdr_backend/technical_indicators/get_indicator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Optional, Type
from pdr_backend.technical_indicators.indicators.macd import MACD
from pdr_backend.technical_indicators.indicators.rsi import RSI
from pdr_backend.technical_indicators.technical_indicator import TechnicalIndicator

indicators = {
"rsi": RSI,
"macd": MACD,
}


def get_ta_indicator(indicator: str) -> Optional[Type[TechnicalIndicator]]:
"""
Returns the technical indicator class based on the input indicator name.
"""
return indicators.get(indicator)
24 changes: 24 additions & 0 deletions pdr_backend/technical_indicators/indicators/macd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pandas as pd
import ta
from pdr_backend.technical_indicators.technical_indicator import TechnicalIndicator


class MACD(TechnicalIndicator):
"""
Moving Average Convergence Divergence (MACD) technical indicator.
"""

def calculate(self, *args, **kwargs) -> pd.Series:
"""
Calculates the MACD value based on the input data.

@param:
window_fast - The window size for the fast EMA calculation (default=12).
window_slow - The window size for the slow EMA calculation (default=26).
"""
window_fast = kwargs.get("window_fast", 12)
window_slow = kwargs.get("window_slow", 26)
macd = ta.trend.MACD(
close=self._close(), window_fast=window_fast, window_slow=window_slow
)
return macd.macd()
20 changes: 20 additions & 0 deletions pdr_backend/technical_indicators/indicators/rsi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pandas as pd
import ta
from pdr_backend.technical_indicators.technical_indicator import TechnicalIndicator


class RSI(TechnicalIndicator):
"""
Relative Strength Index (RSI) technical indicator.
"""

def calculate(self, *args, **kwargs) -> pd.Series:
"""
Calculates the RSI value based on the input data.

@param:
window - The window size for the RSI calculation (default=14).
"""
window = kwargs.get("window", 14)
rsi = ta.momentum.RSIIndicator(close=self._close(), window=window).rsi()
return rsi
Loading
Loading