Skip to content

Commit

Permalink
Fix #1346: [PPSS] Add setters & input validation
Browse files Browse the repository at this point in the history
  • Loading branch information
trentmc committed Jul 5, 2024
1 parent 02bb11b commit 946e30d
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 19 deletions.
42 changes: 35 additions & 7 deletions pdr_backend/ppss/aimodel_data_ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,26 @@ def __init__(self, d: dict):
self.d = d

# test inputs
if not 0 < self.max_n_train:
raise ValueError(self.max_n_train)
if not 0 < self.autoregressive_n < np.inf:
raise ValueError(self.autoregressive_n)
if self.transform not in TRANSFORM_OPTIONS:
raise ValueError(self.transform)
self.validate_max_n_train(self.max_n_train)
self.validate_autoregressive_n(self.autoregressive_n)
self.validate_transform(self.transform)

# --------------------------------
# validators
@staticmethod
def validate_max_n_train(max_n_train: int):
if not 0 < max_n_train:
raise ValueError(max_n_train)

@staticmethod
def validate_autoregressive_n(autoregressive_n: int):
if not 0 < autoregressive_n < np.inf:
raise ValueError(autoregressive_n)

@staticmethod
def validate_transform(transform: str):
if transform not in TRANSFORM_OPTIONS:
raise ValueError(transform)

# --------------------------------
# yaml properties
Expand All @@ -50,10 +64,24 @@ def autoregressive_n(self) -> int:
return self.d["autoregressive_n"]

@property
def transform(self) -> int:
def transform(self) -> str:
"""eg 'RelDiff'"""
return self.d["transform"]

# --------------------------------
# setters
def set_max_n_train(self, max_n_train: int):
self.validate_max_n_train(max_n_train)
self.d["max_n_train"] = max_n_train

def set_autoregressive_n(self, autoregressive_n: int):
self.validate_autoregressive_n(autoregressive_n)
self.d["autoregressive_n"] = autoregressive_n

def set_transform(self, transform: str):
self.validate_transform(transform)
self.d["transform"] = transform


# =========================================================================
# utilities for testing
Expand Down
18 changes: 17 additions & 1 deletion pdr_backend/ppss/aimodel_ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def __init__(self, d: dict):
raise ValueError(self.calibrate_probs)
if self.calibrate_regr not in CALIBRATE_REGR_OPTIONS:
raise ValueError(self.calibrate_regr)
self.validate_train_every_n_epochs(self.train_every_n_epochs)

# --------------------------------
# validators -- add as needed, when setters are added
def validate_train_every_n_epochs(self, n: int):
if n <= 0:
raise ValueError(n)

# --------------------------------
# yaml properties
Expand Down Expand Up @@ -135,6 +142,12 @@ def weight_recent_n(self) -> Tuple[int, int]:
return 10000, 0
raise ValueError(self.weight_recent)

# --------------------------------
# setters (only add as needed)
def set_train_every_n_epochs(self, n: int):
self.validate_train_every_n_epochs(n)
self.d["train_every_n_epochs"] = n


# =========================================================================
# utilities for testing
Expand All @@ -147,14 +160,17 @@ def aimodel_ss_test_dict(
balance_classes: Optional[str] = None,
calibrate_probs: Optional[str] = None,
calibrate_regr: Optional[str] = None,
train_every_n_epochs: Optional[int] = None,
) -> dict:
"""Use this function's return dict 'd' to construct AimodelSS(d)"""
d = {
"approach": approach or "ClassifLinearRidge",
"weight_recent": weight_recent or "10x_5x",
"balance_classes": balance_classes or "SMOTE",
"train_every_n_epochs": 1,
"calibrate_probs": calibrate_probs or "CalibratedClassifierCV_Sigmoid",
"calibrate_regr": calibrate_regr or "None",
"train_every_n_epochs": (
1 if train_every_n_epochs is None else train_every_n_epochs
),
}
return d
17 changes: 11 additions & 6 deletions pdr_backend/ppss/lake_ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,19 @@ def __str__(self) -> str:


@enforce_types
def lake_ss_test_dict(lake_dir: str, feeds: Optional[list] = None):
def lake_ss_test_dict(
lake_dir: str,
feeds: Optional[list] = None,
st_timestr: Optional[str] = None,
fin_timestr: Optional[str] = None,
timeframe: Optional[str] = None,
):
"""Use this function's return dict 'd' to construct LakeSS(d)"""
feeds = feeds or ["binance BTC/USDT c 5m"]
d = {
"feeds": feeds,
"feeds": feeds or ["binance BTC/USDT c 5m"],
"lake_dir": lake_dir,
"st_timestr": "2023-06-18",
"fin_timestr": "2023-06-30",
"timeframe": "5m",
"st_timestr": st_timestr or "2023-06-18",
"fin_timestr": fin_timestr or "2023-06-30",
"timeframe": timeframe or "5m",
}
return d
2 changes: 1 addition & 1 deletion pdr_backend/ppss/predictoor_ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def predictoor_ss_test_dict(
"approach": 1,
"stake_amount": 1,
"sim_only": {
"others_stake": 2313,
"others_stake": 2313.0,
"others_accuracy": 0.50001,
"revenue": 0.93007,
},
Expand Down
24 changes: 20 additions & 4 deletions pdr_backend/ppss/sim_ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,21 @@ def __init__(self, d: dict):
s = f"Couldn't find log dir, so created one at: {self.log_dir}"
logger.warning(s)

# check test_n
test_n = d["test_n"]
# validate data
self.validate_test_n(self.test_n)
self.validate_tradetype(self.tradetype)

# --------------------------------
# validators
@staticmethod
def validate_test_n(test_n: int):
if not isinstance(test_n, int):
raise TypeError(test_n)
if not 0 < test_n < np.inf:
raise ValueError(test_n)

# check tradetype
tradetype = d["tradetype"]
@staticmethod
def validate_tradetype(tradetype: str):
if not isinstance(tradetype, str):
raise TypeError(tradetype)
if tradetype not in TRADETYPE_OPTIONS:
Expand Down Expand Up @@ -70,6 +76,16 @@ def is_final_iter(self, iter_i: int) -> bool:
raise ValueError(iter_i)
return (iter_i + 1) == self.test_n

# --------------------------------
# setters
def set_test_n(self, test_n: int):
self.validate_test_n(test_n)
self.d["test_n"] = test_n

def set_tradetype(self, tradetype: str):
self.validate_tradetype(tradetype)
self.d["tradetype"] = tradetype


# =========================================================================
# utilities for testing
Expand Down
28 changes: 28 additions & 0 deletions pdr_backend/ppss/test/test_aimodel_data_ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,31 @@ def test_aimodel_data_ss__bad_inputs():

with pytest.raises(TypeError):
AimodelDataSS(aimodel_data_ss_test_dict(transform=3.1))


@enforce_types
def test_aimodel_data_ss__setters():
d = aimodel_data_ss_test_dict()
ss = AimodelDataSS(d)

# max_n_train
ss.set_max_n_train(32)
assert ss.max_n_train == 32
with pytest.raises(ValueError):
ss.set_max_n_train(0)
with pytest.raises(ValueError):
ss.set_max_n_train(-5)

# autoregressive_n
ss.set_autoregressive_n(12)
assert ss.autoregressive_n == 12
with pytest.raises(ValueError):
ss.set_autoregressive_n(0)
with pytest.raises(ValueError):
ss.set_autoregressive_n(-5)

# transform
ss.set_transform("RelDiff")
assert ss.transform == "RelDiff"
with pytest.raises(ValueError):
ss.set_transform("foo")
24 changes: 24 additions & 0 deletions pdr_backend/ppss/test/test_aimodel_ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_aimodel_ss__default_values():
ss.calibrate_probs == d["calibrate_probs"] == "CalibratedClassifierCV_Sigmoid"
)
assert ss.calibrate_regr == d["calibrate_regr"] == "None"
assert ss.train_every_n_epochs == d["train_every_n_epochs"] == 1

# str
assert "AimodelSS" in str(ss)
Expand Down Expand Up @@ -69,6 +70,9 @@ def test_aimodel_ss__nondefault_values():
ss = AimodelSS(aimodel_ss_test_dict(calibrate_regr=calibrate_regr))
assert ss.calibrate_regr == calibrate_regr and calibrate_regr in str(ss)

ss = AimodelSS(aimodel_ss_test_dict(train_every_n_epochs=44))
assert ss.train_every_n_epochs == 44


@enforce_types
def test_aimodel_ss__bad_inputs():
Expand All @@ -88,6 +92,12 @@ def test_aimodel_ss__bad_inputs():
with pytest.raises(ValueError):
AimodelSS(aimodel_ss_test_dict(calibrate_regr="foo"))

with pytest.raises(ValueError):
AimodelSS(aimodel_ss_test_dict(train_every_n_epochs=0))

with pytest.raises(ValueError):
AimodelSS(aimodel_ss_test_dict(train_every_n_epochs=-5))


@enforce_types
def test_aimodel_ss__calibrate_probs_skmethod():
Expand All @@ -100,3 +110,17 @@ def test_aimodel_ss__calibrate_probs_skmethod():
ss = AimodelSS(d)
assert ss.calibrate_probs_skmethod(100) == "sigmoid" # because N is small
assert ss.calibrate_probs_skmethod(1000) == "isotonic"


@enforce_types
def test_aimodel_ss__setters():
d = aimodel_ss_test_dict()
ss = AimodelSS(d)

# train_every_n_epochs
ss.set_train_every_n_epochs(77)
assert ss.train_every_n_epochs == 77
with pytest.raises(ValueError):
ss.set_train_every_n_epochs(0)
with pytest.raises(ValueError):
ss.set_train_every_n_epochs(-5)
14 changes: 14 additions & 0 deletions pdr_backend/ppss/test/test_lake_ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,17 @@ def test_lake_ss_test_dict_2_specify_feeds(tmpdir):
d = lake_ss_test_dict(lake_dir, feeds)
assert d["lake_dir"] == lake_dir
assert d["feeds"] == feeds


@enforce_types
def test_lake_ss_test_dict_3_nondefault_time_settings(tmpdir):
lake_dir = os.path.join(tmpdir, "lake_data")
d = lake_ss_test_dict(
lake_dir,
st_timestr="2023-01-20",
fin_timestr="2023-01-21",
timeframe="1h",
)
assert d["st_timestr"] == "2023-01-20"
assert d["fin_timestr"] == "2023-01-21"
assert d["timeframe"] == "1h"
20 changes: 20 additions & 0 deletions pdr_backend/ppss/test/test_sim_ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,26 @@ def test_sim_ss_is_final_iter(tmpdir):
_ = ss.is_final_iter(11)


@enforce_types
def test_sim_ss_setters(tmpdir):
d = sim_ss_test_dict(_logdir(tmpdir))
ss = SimSS(d)

# test_n
ss.set_test_n(32)
assert ss.test_n == 32
with pytest.raises(ValueError):
ss.set_test_n(0)
with pytest.raises(ValueError):
ss.set_test_n(-5)

# tradetype
ss.set_tradetype("livereal")
assert ss.tradetype == "livereal"
with pytest.raises(ValueError):
ss.set_tradetype("foo")


# ====================================================================
# helper funcs
@enforce_types
Expand Down

0 comments on commit 946e30d

Please sign in to comment.