From 32ad8cdb51ce6509a0f51f48953f755219edc028 Mon Sep 17 00:00:00 2001 From: Pivot19 Date: Fri, 5 Feb 2021 01:04:42 +0900 Subject: [PATCH 1/7] Add MarketCloseAwareDynamicNeutral --- .../labelizer/ternary/dynamic_neutral.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/backlight/labelizer/ternary/dynamic_neutral.py b/src/backlight/labelizer/ternary/dynamic_neutral.py index 3ae0f1d..4f1f21a 100644 --- a/src/backlight/labelizer/ternary/dynamic_neutral.py +++ b/src/backlight/labelizer/ternary/dynamic_neutral.py @@ -46,3 +46,33 @@ def neutral_window(self) -> str: @property def neutral_hard_limit(self) -> str: return self._params["neutral_hard_limit"] + + +class MarketCloseAwareDynamicNeutralLabelizer(DynamicNeutralLabelizer): + def _calculate_dynamic_neutral_range(self, diff_abs: pd.Series) -> pd.Series: + + df = pd.DataFrame(diff_abs, columns=["res"]) + df.loc[:, "est"] = df.index.tz_convert("America/New_York") + freq = int( + pd.Timedelta(self._params["neutral_window"]) + / pd.Timedelta(diff_abs.index.freq) + ) + + mask = ( + ~((df.est.dt.hour <= 17) & (df.est.dt.dayofweek == 6)) + & ((df.est.dt.hour < 16) | (df.est.dt.hour > 17)) + & ~((df.est.dt.hour >= 16) & (df.est.dt.dayofweek == 4)) + & (df.est.dt.dayofweek != 5) + ) + + dnr = ( + df.loc[mask, "res"] + .rolling(freq) + .quantile(self.neutral_ratio) + .reindex(diff_abs.index) + .fillna(100) + ) + + dnr[dnr < self.neutral_hard_limit] = self.neutral_hard_limit + + return dnr From e927132a65741a0e6cd081b3d3f179254d5b9bc7 Mon Sep 17 00:00:00 2001 From: Pivot19 Date: Fri, 5 Feb 2021 19:46:24 +0900 Subject: [PATCH 2/7] add new 2 labelizers --- .../labelizer/ternary/hybrid_neutral.py | 48 +++++++++++ .../labelizer/ternary/static_neutral.py | 84 +++++++++++++++++++ 2 files changed, 132 insertions(+) create mode 100644 src/backlight/labelizer/ternary/hybrid_neutral.py create mode 100644 src/backlight/labelizer/ternary/static_neutral.py diff --git a/src/backlight/labelizer/ternary/hybrid_neutral.py b/src/backlight/labelizer/ternary/hybrid_neutral.py new file mode 100644 index 0000000..6603116 --- /dev/null +++ b/src/backlight/labelizer/ternary/hybrid_neutral.py @@ -0,0 +1,48 @@ +import pandas as pd + +from backlight.datasource.marketdata import MarketData +from backlight.labelizer.common import LabelType, TernaryDirection +from backlight.labelizer.labelizer import Label +from backlight.labelizer.ternary.static_neutral import StaticNeutralLabelizer +from backlight.labelizer.ternary.dynamic_neutral import ( + MarketCloseAwareDynamicNeutralLabelizer, +) + + +class HybridNeutralLabelizer( + StaticNeutralLabelizer, MarketCloseAwareDynamicNeutralLabelizer +): + def __init__(self, **kwargs: str) -> None: + super().__init__(**kwargs) + self.validate_params() + + def validate_params(self) -> None: + super(HybridNeutralLabelizer, self).validate_params() + super(MarketCloseAwareDynamicNeutralLabelizer, self).validate_params() + assert "alpha" in self._params + assert 0 <= self._params["alpha"] <= 1 + + def _calculate_hybrid_neutral_range(self, diff_abs: pd.Series) -> pd.Series: + snr = self._calculate_static_neutral_range(diff_abs) + dnr = self._calculate_dynamic_neutral_range(diff_abs) + return self.alpha * snr + (1 - self.alpha) * dnr + + def create(self, mkt: MarketData) -> pd.DataFrame: + mid = mkt.mid.copy() + future_price = mid.shift(freq="-{}".format(self._params["lookahead"])) + diff = (future_price - mid).reindex(mid.index) + diff_abs = diff.abs() + neutral_range = self._calculate_hybrid_neutral_range(diff_abs) + df = mid.to_frame("mid") + df.loc[:, "label_diff"] = diff + df.loc[:, "neutral_range"] = neutral_range + df.loc[df.label_diff > 0, "label"] = TernaryDirection.UP.value + df.loc[df.label_diff < 0, "label"] = TernaryDirection.DOWN.value + df.loc[diff_abs < neutral_range, "label"] = TernaryDirection.NEUTRAL.value + df = Label(df[["label_diff", "label", "neutral_range"]]) + df.label_type = LabelType.TERNARY + return df + + @property + def alpha(self) -> float: + return self._params["alpha"] diff --git a/src/backlight/labelizer/ternary/static_neutral.py b/src/backlight/labelizer/ternary/static_neutral.py new file mode 100644 index 0000000..b8b4a06 --- /dev/null +++ b/src/backlight/labelizer/ternary/static_neutral.py @@ -0,0 +1,84 @@ +import pandas as pd +import numpy as np + +from backlight.datasource.marketdata import MarketData +from backlight.labelizer.common import LabelType, TernaryDirection +from backlight.labelizer.labelizer import Labelizer, Label + + +class StaticNeutralLabelizer(Labelizer): + """Generates session-aware static labels + + Args: + lookahead (str): Lookahead period + session_splits (list[datetime.time]): EST local time to split sessions + neutral_ratio (float): 0 < x < 1, Percentage of NEUTRAL labels + window_start (str): Start date for lookback window + window_end (str): End date for lookback window + neutral_hard_limit (float): The minimum diff to label UP/DOWN + """ + + def validate_params(self) -> None: + assert "lookahead" in self._params + assert "session_splits" in self._params + assert len(self._params["session_splits"]) + assert "neutral_ratio" in self._params + assert "window_start" in self._params + assert "window_end" in self._params + assert "neutral_hard_limit" in self._params + + def _calculate_static_neutral_range(self, diff_abs: pd.Series) -> pd.Series: + df = pd.DataFrame(diff_abs, columns=["diff"]) + df.loc[:, "est"] = df.index.tz_convert("America/New_York") + df.loc[:, "res"] = np.nan + + mask = ( + (df.index >= self._params["window_start"]) + & (df.index < self._params["window_end"]) + & ~((df.est.dt.hour <= 17) & (df.est.dt.dayofweek == 6)) + & ((df.est.dt.hour < 16) | (df.est.dt.hour > 17)) + & ~((df.est.dt.hour >= 16) & (df.est.dt.dayofweek == 4)) + & (df.est.dt.dayofweek != 5) + ) + + splits = sorted(self._params["session_splits"]) + shifted_splits = splits[1:] + splits[:1] + + for s, t in list(zip(splits, shifted_splits)): + if s >= t: + scope = (df.est.dt.time >= s) | (df.est.dt.time < t) + else: + scope = (df.est.dt.time >= s) & (df.est.dt.time < t) + df.loc[scope, "res"] = df.loc[scope & mask, "diff"].quantile( + self.neutral_ratio + ) + + return df.res + + def create(self, mkt: MarketData) -> pd.DataFrame: + mid = mkt.mid.copy() + future_price = mid.shift(freq="-{}".format(self._params["lookahead"])) + diff = (future_price - mid).reindex(mid.index) + diff_abs = diff.abs() + neutral_range = self._calculate_static_neutral_range(diff_abs) + df = mid.to_frame("mid") + df.loc[:, "label_diff"] = diff + df.loc[:, "neutral_range"] = neutral_range + df.loc[df.label_diff > 0, "label"] = TernaryDirection.UP.value + df.loc[df.label_diff < 0, "label"] = TernaryDirection.DOWN.value + df.loc[diff_abs < neutral_range, "label"] = TernaryDirection.NEUTRAL.value + df = Label(df[["label_diff", "label", "neutral_range"]]) + df.label_type = LabelType.TERNARY + return df + + @property + def neutral_ratio(self) -> str: + return self._params["neutral_ratio"] + + @property + def session_splits(self) -> str: + return self._params["session_splits"] + + @property + def neutral_hard_limit(self) -> str: + return self._params["neutral_hard_limit"] From 1e477c38c9af3c601bf1b512ab9a2479ec3252c0 Mon Sep 17 00:00:00 2001 From: Pivot19 Date: Fri, 5 Feb 2021 20:15:13 +0900 Subject: [PATCH 3/7] modify fillna() and type(float) --- src/backlight/labelizer/ternary/dynamic_neutral.py | 2 +- src/backlight/labelizer/ternary/hybrid_neutral.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/backlight/labelizer/ternary/dynamic_neutral.py b/src/backlight/labelizer/ternary/dynamic_neutral.py index 4f1f21a..0659573 100644 --- a/src/backlight/labelizer/ternary/dynamic_neutral.py +++ b/src/backlight/labelizer/ternary/dynamic_neutral.py @@ -70,7 +70,7 @@ def _calculate_dynamic_neutral_range(self, diff_abs: pd.Series) -> pd.Series: .rolling(freq) .quantile(self.neutral_ratio) .reindex(diff_abs.index) - .fillna(100) + .ffill() ) dnr[dnr < self.neutral_hard_limit] = self.neutral_hard_limit diff --git a/src/backlight/labelizer/ternary/hybrid_neutral.py b/src/backlight/labelizer/ternary/hybrid_neutral.py index 6603116..f49a555 100644 --- a/src/backlight/labelizer/ternary/hybrid_neutral.py +++ b/src/backlight/labelizer/ternary/hybrid_neutral.py @@ -20,7 +20,7 @@ def validate_params(self) -> None: super(HybridNeutralLabelizer, self).validate_params() super(MarketCloseAwareDynamicNeutralLabelizer, self).validate_params() assert "alpha" in self._params - assert 0 <= self._params["alpha"] <= 1 + assert 0 <= float(self._params["alpha"]) <= 1 def _calculate_hybrid_neutral_range(self, diff_abs: pd.Series) -> pd.Series: snr = self._calculate_static_neutral_range(diff_abs) @@ -45,4 +45,4 @@ def create(self, mkt: MarketData) -> pd.DataFrame: @property def alpha(self) -> float: - return self._params["alpha"] + return float(self._params["alpha"]) From 936eca83d9072e92912dfa637db91091d3005370 Mon Sep 17 00:00:00 2001 From: Pivot19 Date: Fri, 5 Feb 2021 20:44:44 +0900 Subject: [PATCH 4/7] apply neutral_hard_limit to StaticNeutral --- src/backlight/labelizer/ternary/static_neutral.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/backlight/labelizer/ternary/static_neutral.py b/src/backlight/labelizer/ternary/static_neutral.py index b8b4a06..50345e7 100644 --- a/src/backlight/labelizer/ternary/static_neutral.py +++ b/src/backlight/labelizer/ternary/static_neutral.py @@ -53,7 +53,10 @@ def _calculate_static_neutral_range(self, diff_abs: pd.Series) -> pd.Series: self.neutral_ratio ) - return df.res + snr = df.res + snr[snr < self.neutral_hard_limit] = self.neutral_hard_limit + + return snr def create(self, mkt: MarketData) -> pd.DataFrame: mid = mkt.mid.copy() From 5b5e6c425a64b8329f1cb5b9f1a1922b916cf269 Mon Sep 17 00:00:00 2001 From: Pivot19 Date: Sat, 6 Feb 2021 01:07:01 +0900 Subject: [PATCH 5/7] Add unittest with small bug fix --- .../labelizer/ternary/dynamic_neutral.py | 2 +- .../labelizer/ternary/static_neutral.py | 9 ++- .../test_labelizer_ternary_dynamic_neutral.py | 59 +++++++++++++++++ .../test_labelizer_ternary_hybrid_neutral.py | 63 +++++++++++++++++++ .../test_labelizer_ternary_static_neutral.py | 61 ++++++++++++++++++ 5 files changed, 188 insertions(+), 6 deletions(-) create mode 100644 tests/labelizer/test_labelizer_ternary_dynamic_neutral.py create mode 100644 tests/labelizer/test_labelizer_ternary_hybrid_neutral.py create mode 100644 tests/labelizer/test_labelizer_ternary_static_neutral.py diff --git a/src/backlight/labelizer/ternary/dynamic_neutral.py b/src/backlight/labelizer/ternary/dynamic_neutral.py index 0659573..e9aa80a 100644 --- a/src/backlight/labelizer/ternary/dynamic_neutral.py +++ b/src/backlight/labelizer/ternary/dynamic_neutral.py @@ -51,7 +51,7 @@ def neutral_hard_limit(self) -> str: class MarketCloseAwareDynamicNeutralLabelizer(DynamicNeutralLabelizer): def _calculate_dynamic_neutral_range(self, diff_abs: pd.Series) -> pd.Series: - df = pd.DataFrame(diff_abs, columns=["res"]) + df = pd.DataFrame(diff_abs.values, index=diff_abs.index, columns=["res"]) df.loc[:, "est"] = df.index.tz_convert("America/New_York") freq = int( pd.Timedelta(self._params["neutral_window"]) diff --git a/src/backlight/labelizer/ternary/static_neutral.py b/src/backlight/labelizer/ternary/static_neutral.py index 50345e7..8ebc95d 100644 --- a/src/backlight/labelizer/ternary/static_neutral.py +++ b/src/backlight/labelizer/ternary/static_neutral.py @@ -28,7 +28,7 @@ def validate_params(self) -> None: assert "neutral_hard_limit" in self._params def _calculate_static_neutral_range(self, diff_abs: pd.Series) -> pd.Series: - df = pd.DataFrame(diff_abs, columns=["diff"]) + df = pd.DataFrame(diff_abs.values, index=diff_abs.index, columns=["diff"]) df.loc[:, "est"] = df.index.tz_convert("America/New_York") df.loc[:, "res"] = np.nan @@ -49,14 +49,13 @@ def _calculate_static_neutral_range(self, diff_abs: pd.Series) -> pd.Series: scope = (df.est.dt.time >= s) | (df.est.dt.time < t) else: scope = (df.est.dt.time >= s) & (df.est.dt.time < t) - df.loc[scope, "res"] = df.loc[scope & mask, "diff"].quantile( + df.loc[scope, "res"] = df.loc[(scope & mask), "diff"].quantile( self.neutral_ratio ) - snr = df.res - snr[snr < self.neutral_hard_limit] = self.neutral_hard_limit + df.loc[(df.res < self.neutral_hard_limit), "res"] = self.neutral_hard_limit - return snr + return df.res def create(self, mkt: MarketData) -> pd.DataFrame: mid = mkt.mid.copy() diff --git a/tests/labelizer/test_labelizer_ternary_dynamic_neutral.py b/tests/labelizer/test_labelizer_ternary_dynamic_neutral.py new file mode 100644 index 0000000..dfd5f46 --- /dev/null +++ b/tests/labelizer/test_labelizer_ternary_dynamic_neutral.py @@ -0,0 +1,59 @@ +from backlight.labelizer.ternary.dynamic_neutral import ( + MarketCloseAwareDynamicNeutralLabelizer as module, +) + +import pytest +import pandas as pd +import numpy as np + + +@pytest.fixture +def sample_df(): + index = pd.date_range( + "2017-09-04 13:00:00+00:00", "2017-09-05 13:00:00+00:00", freq="1H" + ) + return pd.DataFrame( + index=index, + data=np.array( + [ + [109.68, 109.69, 109.685], + [109.585, 109.595, 109.59], + [109.525, 109.535, 109.53], + [109.6, 109.61, 109.605], + [109.695, 109.7, 109.6975], + [109.565, 109.705, 109.635], + [109.63, 109.685, 109.6575], + [109.555, 109.675, 109.615], + [109.7, 109.75, 109.725], + [109.67, 109.72, 109.695], + [109.66, 109.675, 109.6675], + [109.8, 109.815, 109.8075], + [109.565, 109.575, 109.57], + [109.535, 109.545, 109.54], + [109.32, 109.33, 109.325], + [109.27, 109.275, 109.2725], + [109.345, 109.355, 109.35], + [109.305, 109.315, 109.31], + [109.3, 109.31, 109.305], + [109.445, 109.46, 109.4525], + [109.42, 109.425, 109.4225], + [109.385, 109.395, 109.39], + [109.305, 109.315, 109.31], + [109.365, 109.375, 109.37], + [109.365, 109.375, 109.37], + ] + ), + columns=["bid", "ask", "mid"], + ) + + +def test_create(sample_df): + lbl_args = { + "lookahead": "1H", + "neutral_ratio": 0.5, + "neutral_window": "3H", + "neutral_hard_limit": 0.00, + } + lbl = module(**lbl_args).create(sample_df) + assert lbl.label.sum() == -3 + assert lbl.neutral_range.isna().sum() == 2 diff --git a/tests/labelizer/test_labelizer_ternary_hybrid_neutral.py b/tests/labelizer/test_labelizer_ternary_hybrid_neutral.py new file mode 100644 index 0000000..62318d8 --- /dev/null +++ b/tests/labelizer/test_labelizer_ternary_hybrid_neutral.py @@ -0,0 +1,63 @@ +from backlight.labelizer.ternary.hybrid_neutral import HybridNeutralLabelizer as module + +import pytest +import pandas as pd +import numpy as np +import datetime + + +@pytest.fixture +def sample_df(): + index = pd.date_range( + "2017-09-04 13:00:00+00:00", "2017-09-05 13:00:00+00:00", freq="1H" + ) + return pd.DataFrame( + index=index, + data=np.array( + [ + [109.68, 109.69, 109.685], + [109.585, 109.595, 109.59], + [109.525, 109.535, 109.53], + [109.6, 109.61, 109.605], + [109.695, 109.7, 109.6975], + [109.565, 109.705, 109.635], + [109.63, 109.685, 109.6575], + [109.555, 109.675, 109.615], + [109.7, 109.75, 109.725], + [109.67, 109.72, 109.695], + [109.66, 109.675, 109.6675], + [109.8, 109.815, 109.8075], + [109.565, 109.575, 109.57], + [109.535, 109.545, 109.54], + [109.32, 109.33, 109.325], + [109.27, 109.275, 109.2725], + [109.345, 109.355, 109.35], + [109.305, 109.315, 109.31], + [109.3, 109.31, 109.305], + [109.445, 109.46, 109.4525], + [109.42, 109.425, 109.4225], + [109.385, 109.395, 109.39], + [109.305, 109.315, 109.31], + [109.365, 109.375, 109.37], + [109.365, 109.375, 109.37], + ] + ), + columns=["bid", "ask", "mid"], + ) + + +def test_create(sample_df): + lbl_args = { + "lookahead": "1H", + "neutral_ratio": 0.5, + "session_splits": [datetime.time(9), datetime.time(18)], + "neutral_window": "3H", # noqa Simple approach: lookahead = lookback + "neutral_hard_limit": 0.00, + "window_start": "20170904 12:00:00+0000", + "window_end": "20170905 06:00:00+0000", + "alpha": 0.5, + } + + lbl = module(**lbl_args).create(sample_df) + assert lbl.label.sum() == 1 + assert lbl.neutral_range.isna().sum() == 2 diff --git a/tests/labelizer/test_labelizer_ternary_static_neutral.py b/tests/labelizer/test_labelizer_ternary_static_neutral.py new file mode 100644 index 0000000..1a9dba6 --- /dev/null +++ b/tests/labelizer/test_labelizer_ternary_static_neutral.py @@ -0,0 +1,61 @@ +from backlight.labelizer.ternary.static_neutral import StaticNeutralLabelizer as module + +import pytest +import pandas as pd +import numpy as np +import datetime + + +@pytest.fixture +def sample_df(): + index = pd.date_range( + "2017-09-04 13:00:00+00:00", "2017-09-05 13:00:00+00:00", freq="1H" + ) + return pd.DataFrame( + index=index, + data=np.array( + [ + [109.68, 109.69, 109.685], + [109.585, 109.595, 109.59], + [109.525, 109.535, 109.53], + [109.6, 109.61, 109.605], + [109.695, 109.7, 109.6975], + [109.565, 109.705, 109.635], + [109.63, 109.685, 109.6575], + [109.555, 109.675, 109.615], + [109.7, 109.75, 109.725], + [109.67, 109.72, 109.695], + [109.66, 109.675, 109.6675], + [109.8, 109.815, 109.8075], + [109.565, 109.575, 109.57], + [109.535, 109.545, 109.54], + [109.32, 109.33, 109.325], + [109.27, 109.275, 109.2725], + [109.345, 109.355, 109.35], + [109.305, 109.315, 109.31], + [109.3, 109.31, 109.305], + [109.445, 109.46, 109.4525], + [109.42, 109.425, 109.4225], + [109.385, 109.395, 109.39], + [109.305, 109.315, 109.31], + [109.365, 109.375, 109.37], + [109.365, 109.375, 109.37], + ] + ), + columns=["bid", "ask", "mid"], + ) + + +def test_create(sample_df): + lbl_args = { + "lookahead": "1H", + "neutral_ratio": 0.5, + "session_splits": [datetime.time(9), datetime.time(18)], + "neutral_hard_limit": 0.00, + "window_start": "20170901 12:00:00+0000", + "window_end": "20170904 12:00:00+0000", + } + + lbl = module(**lbl_args).create(sample_df) + assert lbl.label.sum() == 1 + assert lbl.neutral_range.isna().sum() == 0 From 83cbc9b2a8a015731451f2129a617380e09587c4 Mon Sep 17 00:00:00 2001 From: Pivot19 Date: Sat, 6 Feb 2021 01:12:12 +0900 Subject: [PATCH 6/7] fix unittest --- tests/labelizer/test_labelizer_ternary_static_neutral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/labelizer/test_labelizer_ternary_static_neutral.py b/tests/labelizer/test_labelizer_ternary_static_neutral.py index 1a9dba6..aab9ef1 100644 --- a/tests/labelizer/test_labelizer_ternary_static_neutral.py +++ b/tests/labelizer/test_labelizer_ternary_static_neutral.py @@ -52,8 +52,8 @@ def test_create(sample_df): "neutral_ratio": 0.5, "session_splits": [datetime.time(9), datetime.time(18)], "neutral_hard_limit": 0.00, - "window_start": "20170901 12:00:00+0000", - "window_end": "20170904 12:00:00+0000", + "window_start": "20170904 12:00:00+0000", + "window_end": "20170905 06:00:00+0000", } lbl = module(**lbl_args).create(sample_df) From a9ef72fae7748f62986fa9307dc5218f95aeedee Mon Sep 17 00:00:00 2001 From: Pivot19 Date: Tue, 9 Feb 2021 00:39:39 +0900 Subject: [PATCH 7/7] add NYK time variable name --- src/backlight/labelizer/ternary/dynamic_neutral.py | 10 +++++----- src/backlight/labelizer/ternary/static_neutral.py | 14 +++++++------- .../test_labelizer_ternary_hybrid_neutral.py | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/backlight/labelizer/ternary/dynamic_neutral.py b/src/backlight/labelizer/ternary/dynamic_neutral.py index e9aa80a..929874a 100644 --- a/src/backlight/labelizer/ternary/dynamic_neutral.py +++ b/src/backlight/labelizer/ternary/dynamic_neutral.py @@ -52,17 +52,17 @@ class MarketCloseAwareDynamicNeutralLabelizer(DynamicNeutralLabelizer): def _calculate_dynamic_neutral_range(self, diff_abs: pd.Series) -> pd.Series: df = pd.DataFrame(diff_abs.values, index=diff_abs.index, columns=["res"]) - df.loc[:, "est"] = df.index.tz_convert("America/New_York") + df.loc[:, "nyk_time"] = df.index.tz_convert("America/New_York") freq = int( pd.Timedelta(self._params["neutral_window"]) / pd.Timedelta(diff_abs.index.freq) ) mask = ( - ~((df.est.dt.hour <= 17) & (df.est.dt.dayofweek == 6)) - & ((df.est.dt.hour < 16) | (df.est.dt.hour > 17)) - & ~((df.est.dt.hour >= 16) & (df.est.dt.dayofweek == 4)) - & (df.est.dt.dayofweek != 5) + ~((df.nyk_time.dt.hour <= 17) & (df.nyk_time.dt.dayofweek == 6)) + & ((df.nyk_time.dt.hour < 16) | (df.nyk_time.dt.hour > 17)) + & ~((df.nyk_time.dt.hour >= 16) & (df.nyk_time.dt.dayofweek == 4)) + & (df.nyk_time.dt.dayofweek != 5) ) dnr = ( diff --git a/src/backlight/labelizer/ternary/static_neutral.py b/src/backlight/labelizer/ternary/static_neutral.py index 8ebc95d..60ac391 100644 --- a/src/backlight/labelizer/ternary/static_neutral.py +++ b/src/backlight/labelizer/ternary/static_neutral.py @@ -29,16 +29,16 @@ def validate_params(self) -> None: def _calculate_static_neutral_range(self, diff_abs: pd.Series) -> pd.Series: df = pd.DataFrame(diff_abs.values, index=diff_abs.index, columns=["diff"]) - df.loc[:, "est"] = df.index.tz_convert("America/New_York") + df.loc[:, "nyk_time"] = df.index.tz_convert("America/New_York") df.loc[:, "res"] = np.nan mask = ( (df.index >= self._params["window_start"]) & (df.index < self._params["window_end"]) - & ~((df.est.dt.hour <= 17) & (df.est.dt.dayofweek == 6)) - & ((df.est.dt.hour < 16) | (df.est.dt.hour > 17)) - & ~((df.est.dt.hour >= 16) & (df.est.dt.dayofweek == 4)) - & (df.est.dt.dayofweek != 5) + & ~((df.nyk_time.dt.hour <= 17) & (df.nyk_time.dt.dayofweek == 6)) + & ((df.nyk_time.dt.hour < 16) | (df.nyk_time.dt.hour > 17)) + & ~((df.nyk_time.dt.hour >= 16) & (df.nyk_time.dt.dayofweek == 4)) + & (df.nyk_time.dt.dayofweek != 5) ) splits = sorted(self._params["session_splits"]) @@ -46,9 +46,9 @@ def _calculate_static_neutral_range(self, diff_abs: pd.Series) -> pd.Series: for s, t in list(zip(splits, shifted_splits)): if s >= t: - scope = (df.est.dt.time >= s) | (df.est.dt.time < t) + scope = (df.nyk_time.dt.time >= s) | (df.nyk_time.dt.time < t) else: - scope = (df.est.dt.time >= s) & (df.est.dt.time < t) + scope = (df.nyk_time.dt.time >= s) & (df.nyk_time.dt.time < t) df.loc[scope, "res"] = df.loc[(scope & mask), "diff"].quantile( self.neutral_ratio ) diff --git a/tests/labelizer/test_labelizer_ternary_hybrid_neutral.py b/tests/labelizer/test_labelizer_ternary_hybrid_neutral.py index 62318d8..3100df9 100644 --- a/tests/labelizer/test_labelizer_ternary_hybrid_neutral.py +++ b/tests/labelizer/test_labelizer_ternary_hybrid_neutral.py @@ -51,7 +51,7 @@ def test_create(sample_df): "lookahead": "1H", "neutral_ratio": 0.5, "session_splits": [datetime.time(9), datetime.time(18)], - "neutral_window": "3H", # noqa Simple approach: lookahead = lookback + "neutral_window": "3H", "neutral_hard_limit": 0.00, "window_start": "20170904 12:00:00+0000", "window_end": "20170905 06:00:00+0000",