From d72c99e93363286d09eb0c908b150ab458cd263e Mon Sep 17 00:00:00 2001 From: Theresa Eimer Date: Wed, 5 Jun 2024 13:10:09 +0200 Subject: [PATCH] fix: merge issues --- arlbench/core/algorithms/dqn/dqn.py | 47 ----------------------------- 1 file changed, 47 deletions(-) diff --git a/arlbench/core/algorithms/dqn/dqn.py b/arlbench/core/algorithms/dqn/dqn.py index fbd0f9f31..f5f4655cd 100644 --- a/arlbench/core/algorithms/dqn/dqn.py +++ b/arlbench/core/algorithms/dqn/dqn.py @@ -252,53 +252,6 @@ def get_hpo_config_space(seed: int | None = None) -> ConfigurationSpace: return cs - @staticmethod - def get_hpo_search_space(seed: int | None = None) -> ConfigurationSpace: - """Returns the hyperparameter search space for DQN.""" - cs = ConfigurationSpace( - name="DQNConfigSpace", - seed=seed, - space={ - "buffer_size": Integer( - "buffer_size", (1024, int(1e7)), default=1000000 - ), - "buffer_batch_size": Categorical( - "buffer_batch_size", [4, 8, 16, 32, 64], default=16 - ), - "buffer_prio_sampling": Categorical( - "buffer_prio_sampling", [True, False], default=False - ), - "buffer_alpha": Float("buffer_alpha", (0.01, 1.0), default=0.9), - "buffer_beta": Float("buffer_beta", (0.01, 1.0), default=0.9), - "buffer_epsilon": Float("buffer_epsilon", (1e-7, 1e-3), default=1e-6), - "learning_rate": Float( - "learning_rate", (1e-6, 0.1), default=3e-4, log=True - ), - "tau": Float("tau", (0.01, 1.0), default=1.0), - "initial_epsilon": Float("initial_epsilon", (0.5, 1.0), default=1.0), - "target_epsilon": Float("target_epsilon", (0.001, 0.2), default=0.05), - "use_target_network": Categorical( - "use_target_network", [True, False], default=True - ), - "train_freq": Integer("train_freq", (1, 256), default=4), - "gradient steps": Integer("gradient_steps", (1, 256), default=1), - "learning_starts": Integer("learning_starts", (0, 32768), default=1024), - "target_update_interval": Integer( - "target_update_interval", (1, 2000), default=1000 - ), - }, - ) - cs.add_conditions( - [ - EqualsCondition( - cs["target_update_interval"], cs["use_target_network"], True - ), - EqualsCondition(cs["tau"], cs["use_target_network"], True), - ] - ) - - return cs - @staticmethod def get_default_hpo_config() -> Configuration: """Returns the default hyperparameter configuration for DQN."""