From b81744b6e34de5ac80f1d8aa52fc74d940a1702f Mon Sep 17 00:00:00 2001 From: kngwyu Date: Thu, 12 Sep 2024 18:20:30 +0900 Subject: [PATCH] gops_params_override option --- experiments/cf_simple.py | 2 ++ src/emevo/exp_utils.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/experiments/cf_simple.py b/experiments/cf_simple.py index 9e2d7e6..0d56259 100644 --- a/experiments/cf_simple.py +++ b/experiments/cf_simple.py @@ -402,6 +402,7 @@ def evolve( env_override: str = "", birth_override: str = "", hazard_override: str = "", + gops_params_override: str = "", logdir: Path = Path("./log"), log_mode: LogMode = LogMode.REWARD_LOG_STATE, log_interval: int = 1000, @@ -424,6 +425,7 @@ def evolve( cfconfig.apply_override(env_override) bdconfig.apply_birth_override(birth_override) bdconfig.apply_hazard_override(hazard_override) + gopsconfig.apply_params_override(gops_params_override) # Load models birth_fn, hazard_fn = bdconfig.load_models() diff --git a/src/emevo/exp_utils.py b/src/emevo/exp_utils.py index 9d61be9..d90ee38 100644 --- a/src/emevo/exp_utils.py +++ b/src/emevo/exp_utils.py @@ -147,6 +147,12 @@ def load_model(self) -> gops.Mutation | gops.Crossover: params[k] = v return _load_cls(self.path)(**params) + def apply_params_override(self, override: str) -> None: + if 0 < len(override): + override_dict = json.loads(override) + for key, value in override_dict.items(): + self.params[key] = value + @chex.dataclass class Log: