From e507ea9379e83cccd72fd74435219bbb13bfcdb3 Mon Sep 17 00:00:00 2001 From: Jannis Becktepe <61006252+becktepe@users.noreply.github.com> Date: Thu, 23 May 2024 19:24:58 +0200 Subject: [PATCH] feat: Formatting and more docstrings --- arlbench/arlbench.py | 12 +- arlbench/autorl/autorl_env.py | 4 +- arlbench/autorl/checkpointing.py | 11 +- arlbench/autorl/objectives.py | 10 +- arlbench/autorl/state_features.py | 2 +- arlbench/core/algorithms/algorithm.py | 12 ++ arlbench/core/algorithms/dqn/dqn.py | 11 +- arlbench/core/environments/autorl_env.py | 6 +- arlbench/core/environments/brax_env.py | 4 +- arlbench/core/environments/envpool_env.py | 6 +- arlbench/core/environments/gymnasium_env.py | 4 +- arlbench/core/environments/gymnax_env.py | 2 +- arlbench/core/environments/make_env.py | 3 +- arlbench/core/environments/xland_env.py | 4 +- arlbench/core/wrappers/__init__.py | 5 +- arlbench/core/wrappers/flatten_observation.py | 60 ++++++- arlbench/core/wrappers/image_extraction.py | 17 -- arlbench/core/wrappers/wrapper.py | 5 + arlbench/utils/__init__.py | 2 - arlbench/utils/common.py | 104 +++++++---- arlbench/utils/handle_termination.py | 52 ------ arlbench/utils/hydra_utils.py | 162 ------------------ examples/run_arlbench.py | 10 +- examples/run_heuristic_schedule.py | 11 +- examples/run_reactive_schedule.py | 8 +- run_arlbench.py | 10 +- 26 files changed, 218 insertions(+), 319 deletions(-) delete mode 100644 arlbench/core/wrappers/image_extraction.py delete mode 100644 arlbench/utils/handle_termination.py delete mode 100644 arlbench/utils/hydra_utils.py diff --git a/arlbench/arlbench.py b/arlbench/arlbench.py index 3839faf13..3f59c3baa 100644 --- a/arlbench/arlbench.py +++ b/arlbench/arlbench.py @@ -10,17 +10,20 @@ def run_arlbench(cfg: DictConfig, logger: Logger | None = None) -> float | tuple | list: """Run ARLBench using the given config and return objective(s).""" + # We check if we need to load a checkpoint for HyperPBT + # If so, we load the first episode and first step of ARLBench since we always run only + # one iteration if "load" in cfg and cfg.load: - print(f"### ATTEMPTING TO LOAD {cfg.load} ###") checkpoint_path = os.path.join( cfg.load, cfg.autorl.checkpoint_name, "default_checkpoint_c_episode_1_step_1", ) - print(f"### CHECKPOINT PATH = {checkpoint_path} ###") else: checkpoint_path = None + # We check if we need to save a checkpoint for HyperPBT + # If so, we need to adapt the autorl config accordingly if "save" in cfg and cfg.save: cfg.autorl.checkpoint_dir = str(cfg.save).replace(".pt", "") if cfg.algorithm == "PPO": @@ -28,6 +31,7 @@ def run_arlbench(cfg: DictConfig, logger: Logger | None = None) -> float | tuple else: cfg.autorl.checkpoint = ["opt_state", "params", "buffer"] + # Here, we define how the AutoRLEnv should behave env = AutoRLEnv(cfg.autorl) _ = env.reset() @@ -39,11 +43,9 @@ def run_arlbench(cfg: DictConfig, logger: Logger | None = None) -> float | tuple if logger: logger.info("Training finished.") + # Additionally, we store the evaluation rewards we had during training info["train_info_df"].to_csv("evaluation.csv", index=False) - if "reward_curves" in cfg and cfg.reward_curves: - return list(info["train_info_df"]["returns"]) - if len(objectives) == 1: return objectives[next(iter(objectives.keys()))] else: diff --git a/arlbench/autorl/autorl_env.py b/arlbench/autorl/autorl_env.py index 6255fc080..19b266284 100644 --- a/arlbench/autorl/autorl_env.py +++ b/arlbench/autorl/autorl_env.py @@ -245,7 +245,7 @@ def _make_algorithm(self) -> Algorithm: """Instantiated 
the RL algorithm given the current AutoRL config and hyperparameter configuration. Returns: - Algorithm: RL algorithm instance. + Algorithm: RL algorithm instance. """ return self._algorithm_cls( self._hpo_config, @@ -303,7 +303,7 @@ def step( self._algorithm = self._make_algorithm() - # First, we check if there is a checkpoint to load. If not, we have to check + # First, we check if there is a checkpoint to load. If not, we have to check # whether this is the first iteration, i.e., call of env.step(). In that case, # we have to initialiaze the algorithm state. # Otherwise, we are using the state from previous iteration(s) diff --git a/arlbench/autorl/checkpointing.py b/arlbench/autorl/checkpointing.py index fb4680f65..eb615d6fb 100644 --- a/arlbench/autorl/checkpointing.py +++ b/arlbench/autorl/checkpointing.py @@ -3,8 +3,9 @@ import json import os import warnings +from collections.abc import Callable from datetime import datetime -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any import jax import jax.numpy as jnp @@ -87,7 +88,7 @@ def save( """Saves the current state of a AutoRL environment. Args: - algorithm (str): Name of the algorithm. + algorithm (str): Name of the algorithm. algorithm_state (AlgorithmState): Current algorithm state. autorl_config (dict): AutoRL configuration. hp_config (Configuration): Hyperparameter configuration of the algorithm. @@ -164,7 +165,7 @@ def save( # Get actual checkpoint by calling factory function ckpt[key] = algorithm_ckpt[ key - ]() + ]() else: # Only use selected checkpoint options for key in checkpoint: @@ -184,7 +185,7 @@ def save( # Get actual checkpoint by calling factory function ckpt[algorithm_key] = algorithm_ckpt[ algorithm_key - ]() + ]() if not found_key: warnings.warn( f"Invalid checkpoint for algorithm {algorithm}: {key}. Valid keys are {list(algorithm_ckpt.keys())!s}. Skipping key." @@ -439,7 +440,7 @@ def load_buffer( PrioritisedTrajectoryBufferState: The buffer state that was loaded from disk. """ # Using the vault we can easily load the data of the buffer - # As described in the part of saving the buffer, this does + # As described in the part of saving the buffer, this does # not contain the priorities v = Vault( vault_name="buffer_state_vault", diff --git a/arlbench/autorl/objectives.py b/arlbench/autorl/objectives.py index 4f2251f0e..92282035e 100644 --- a/arlbench/autorl/objectives.py +++ b/arlbench/autorl/objectives.py @@ -20,7 +20,7 @@ class Objective(ABC): """An abstract optimization objective for the AutoRL environment. - + It can be wrapped around the training function to calculate the objective. We do this be overriding the __new__() function. It allows us to imitate the behaviour of a basic function while keeping the advantages of a static class. @@ -73,7 +73,7 @@ def __lt__(self, other: Objective) -> bool: other (Objective): Other Objective to compare to. Returns: - bool: Whether this Objective is less than the other Objective. + bool: Whether this Objective is less than the other Objective. """ return self.RANK < other.RANK @@ -109,7 +109,7 @@ def get_spec() -> dict: class RewardMean(Objective): """Reward objective for the AutoRL environment. It measures the mean of the last evaluation rewards.""" - + KEY = "reward_mean" RANK = 2 @@ -144,7 +144,7 @@ def get_spec() -> dict: class RewardStd(Objective): """Reward objective for the AutoRL environment. 
It measures the standard deviation of the last evaluation rewards.""" - + KEY = "reward_std" RANK = 2 @@ -175,7 +175,7 @@ def get_spec() -> dict: class Emissions(Objective): """Emissions objective for the AutoRL environment. It measures the emissions during the training using code carbon.""" - + KEY = "emissions" RANK = 1 diff --git a/arlbench/autorl/state_features.py b/arlbench/autorl/state_features.py index b6458bf6c..4777e3e69 100644 --- a/arlbench/autorl/state_features.py +++ b/arlbench/autorl/state_features.py @@ -15,7 +15,7 @@ class StateFeature(ABC): """An abstract state features for the AutoRL environment. - + It can be wrapped around the training function to calculate the state features. We do this be overriding the __new__() function. It allows us to imitate the behaviour of a basic function while keeping the advantages of a static class. diff --git a/arlbench/core/algorithms/algorithm.py b/arlbench/core/algorithms/algorithm.py index bf30fbea4..5d9081e94 100644 --- a/arlbench/core/algorithms/algorithm.py +++ b/arlbench/core/algorithms/algorithm.py @@ -97,6 +97,18 @@ def get_hpo_config_space(seed: int | None = None) -> ConfigurationSpace: ConfigurationSpace: Hyperparameter configuration space of the algorithm. """ + @staticmethod + @abstractmethod + def get_hpo_search_space(seed: int | None = None) -> ConfigurationSpace: + """Returns the hyperparameter search space of the algorithm. However, this can be adapted to fit a given HPO method. + + Args: + seed (int | None, optional): Random generator seed that is used to sample configurations. Defaults to None. + + Returns: + ConfigurationSpace: Hyperparameter search space of the algorithm. + """ + @staticmethod @abstractmethod def get_default_hpo_config() -> Configuration: diff --git a/arlbench/core/algorithms/dqn/dqn.py b/arlbench/core/algorithms/dqn/dqn.py index 1e9f605b2..b425e18fd 100644 --- a/arlbench/core/algorithms/dqn/dqn.py +++ b/arlbench/core/algorithms/dqn/dqn.py @@ -151,6 +151,7 @@ def __init__( track_metrics=track_metrics, ) + # For the network, we need the properties of the action space action_size, discrete = self.action_type network_cls = CNNQ if cnn_policy else MLPQ self.network = network_cls( @@ -169,6 +170,10 @@ def __init__( priority_exponent=self.hpo_config["buffer_alpha"], device=jax.default_backend(), ) + + # This is how we can turn the prioritized sampling on/off for dynamic HPO + # We always use the prioritized replay buffer, but if "buffer_prio_sampling" + # is disabled, we replace the sampling function by the uniform sampling if self.hpo_config["buffer_prio_sampling"] is False: sample_fn = functools.partial( uniform_sample, @@ -231,7 +236,6 @@ def get_hpo_config_space(seed: int | None = None) -> ConfigurationSpace: @staticmethod def get_hpo_search_space(seed: int | None = None) -> ConfigurationSpace: - # defaults from https://stable-baselines3.readthedocs.io/en/master/modules/dqn.html cs = ConfigurationSpace( name="DQNConfigSpace", seed=seed, @@ -358,15 +362,18 @@ def init( DQNState: DQN state. 
""" rng, reset_rng = jax.random.split(rng) - env_state, obs = self.env.reset(reset_rng) + # If any of these if not defined, we need a dummy environment transition + # to initialize them if buffer_state is None or network_params is None or target_params is None: dummy_rng = jax.random.PRNGKey(0) _action = self.env.sample_actions(dummy_rng) _, (_obs, _reward, _done, _) = self.env.step(env_state, _action, dummy_rng) if buffer_state is None: + # This is how transitions will look like during training so we need to pass one + # once to the buffer to estimate and allocate the required buffer size _timestep = TimeStep( last_obs=_obs[0], obs=_obs[0], diff --git a/arlbench/core/environments/autorl_env.py b/arlbench/core/environments/autorl_env.py index 706f6eeb6..67876b6e2 100644 --- a/arlbench/core/environments/autorl_env.py +++ b/arlbench/core/environments/autorl_env.py @@ -6,9 +6,9 @@ import jax import jax.numpy as jnp -import gymnax if TYPE_CHECKING: + import gymnax from chex import PRNGKey @@ -63,7 +63,7 @@ def step( rng (PRNGKey): Random number generator key. Returns: - tuple[Any, Any]: Returns a tuple containing the environment state as well as the actual return of the step() function. + tuple[Any, Any]: Returns a tuple containing the environment state as well as the actual return of the step() function. """ raise NotImplementedError @@ -87,7 +87,7 @@ def observation_space(self) -> gymnax.environments.spaces.Space: @functools.partial(jax.jit, static_argnums=0) def sample_actions(self, rng: PRNGKey) -> jnp.ndarray: - """Samples a random action for each environment. + """Samples a random action for each environment. Args: rng (PRNGKey): Random number generator key. diff --git a/arlbench/core/environments/brax_env.py b/arlbench/core/environments/brax_env.py index e9ea2b8f4..8f52ba794 100644 --- a/arlbench/core/environments/brax_env.py +++ b/arlbench/core/environments/brax_env.py @@ -20,11 +20,11 @@ class BraxEnv(Environment): """A brax-based RL environment.""" - + def __init__( self, env_name: str, n_envs: int, env_kwargs: dict[str, Any] | None = None ): - """Creates a brax environment for JAX-based RL training. + """Creates a brax environment for JAX-based RL training. Args: env_name (str): Name/id of the brax environment. diff --git a/arlbench/core/environments/envpool_env.py b/arlbench/core/environments/envpool_env.py index 7a5224876..98788d5b7 100644 --- a/arlbench/core/environments/envpool_env.py +++ b/arlbench/core/environments/envpool_env.py @@ -129,7 +129,7 @@ def numpy_to_jax(x: np.ndarray) -> jnp.ndarray: class EnvpoolEnv(Environment): """An envpool-based RL environment.""" - + def __init__( self, env_name: str, @@ -137,7 +137,7 @@ def __init__( seed: int, env_kwargs: dict[str, Any] | None = None, ): - """Creates an envpool environment for JAX-based RL training. + """Creates an envpool environment for JAX-based RL training. Args: env_name (str): Name/id of the brax environment. 
@@ -225,7 +225,7 @@ def step(self, env_state: Any, action: Any, _): env_state, _ = env_state else: env_state, lives = env_state - + # Here, we perform the actual step in the envpool environment env_state, (obs, reward, term, trunc, info) = self._xla_step(env_state, action) done = jnp.logical_or(term, trunc) diff --git a/arlbench/core/environments/gymnasium_env.py b/arlbench/core/environments/gymnasium_env.py index 80fb600c6..0dbc4be8b 100644 --- a/arlbench/core/environments/gymnasium_env.py +++ b/arlbench/core/environments/gymnasium_env.py @@ -23,8 +23,8 @@ class GymnasiumEnv(Environment): def __init__( self, env_name: str, seed: int, env_kwargs: dict[str, Any] | None = None - ): - """Creates a gymnasium environment for JAX-based RL training. + ): + """Creates a gymnasium environment for JAX-based RL training. Args: env_name (str): Name/id of the brax environment. diff --git a/arlbench/core/environments/gymnax_env.py b/arlbench/core/environments/gymnax_env.py index 3bacbea60..e3102b2a9 100644 --- a/arlbench/core/environments/gymnax_env.py +++ b/arlbench/core/environments/gymnax_env.py @@ -18,7 +18,7 @@ class GymnaxEnv(Environment): def __init__( self, env_name: str, n_envs: int, env_kwargs: dict[str, Any] | None = None ): - """Creates a gymnax environment for JAX-based RL training. + """Creates a gymnax environment for JAX-based RL training. Args: env_name (str): Name/id of the brax environment. diff --git a/arlbench/core/environments/make_env.py b/arlbench/core/environments/make_env.py index f2450fd71..dcdaaeaeb 100644 --- a/arlbench/core/environments/make_env.py +++ b/arlbench/core/environments/make_env.py @@ -3,10 +3,11 @@ import warnings from typing import TYPE_CHECKING -from arlbench.core.wrappers import Wrapper, FlattenObservationWrapper +from arlbench.core.wrappers import FlattenObservationWrapper, Wrapper if TYPE_CHECKING: from typing import Any + from .autorl_env import Environment diff --git a/arlbench/core/environments/xland_env.py b/arlbench/core/environments/xland_env.py index 089c9b731..acc800cc8 100644 --- a/arlbench/core/environments/xland_env.py +++ b/arlbench/core/environments/xland_env.py @@ -23,7 +23,7 @@ def __init__( env_kwargs: dict[str, Any] | None = None, cnn_policy: bool = False, ): - """Creates an xland environment for JAX-based RL training. + """Creates an xland environment for JAX-based RL training. Args: env_name (str): Name/id of the brax environment. 
@@ -64,7 +64,7 @@ def step(self, env_state: Any, action: Any, rng: PRNGKey): # (as referred to in the xland documentation) timestep = jax.vmap(self._env.step, in_axes=(None, 0, 0))( self.env_params, env_state, action - ) + ) return timestep, (timestep.observation, timestep.reward, timestep.last(), {})
diff --git a/arlbench/core/wrappers/__init__.py b/arlbench/core/wrappers/__init__.py index 80ea4937e..cb91e9156 100644 --- a/arlbench/core/wrappers/__init__.py +++ b/arlbench/core/wrappers/__init__.py @@ -1,5 +1,4 @@ -from .wrapper import Wrapper from .flatten_observation import FlattenObservationWrapper -from .image_extraction import ImageExtractionWrapper +from .wrapper import Wrapper -__all__ = ["Wrapper", "ImageExtractionWrapper", "FlattenObservationWrapper"] +__all__ = ["Wrapper", "FlattenObservationWrapper"]
diff --git a/arlbench/core/wrappers/flatten_observation.py b/arlbench/core/wrappers/flatten_observation.py index a7bc5b3e0..9be4592b0 100644 --- a/arlbench/core/wrappers/flatten_observation.py +++ b/arlbench/core/wrappers/flatten_observation.py @@ -1,27 +1,40 @@ from __future__ import annotations import functools -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import jax +import jax.numpy as jnp import numpy as np from gymnax.environments import spaces from .wrapper import Wrapper if TYPE_CHECKING: + import chex + from arlbench.core.environments import Environment # TODO add test cases class FlattenObservationWrapper(Wrapper): - """Flatten the observations of the environment.""" + """Wraps the given environment to flatten its observations.""" def __init__(self, env: Environment): + """Wraps the given environment to flatten its observations. + + Args: + env (Environment): Environment to wrap. + """ super().__init__(env) @property def observation_space(self) -> spaces.Box: + """The flattened observation space of the environment. + + Returns: + spaces.Box: Flattened observation space. + """ assert isinstance( self._env.observation_space, spaces.Box ), "Only Box spaces are supported for now." @@ -33,20 +46,55 @@ def observation_space(self) -> spaces.Box: ) @functools.partial(jax.jit, static_argnums=(0,)) - def __flatten(self, obs): - # since we have a stack of observations, + def __flatten(self, obs: jnp.ndarray) -> jnp.ndarray: + """Flattens a batch of observations. + + Args: + obs (jnp.ndarray): The observations to flatten. + + Returns: + jnp.ndarray: The flattened observations. + """ + # Since we have a stack of observations, # we want to keep the first dimension return obs.reshape(obs.shape[0], -1) @functools.partial(jax.jit, static_argnums=(0,)) - def reset(self, rng): # TODO improve typing + def reset(self, rng: chex.PRNGKey) -> tuple[Any, jnp.ndarray]: + """Calls the reset() function of the environment and flattens the returned observations. + + + Args: + rng (chex.PRNGKey): Random number generator key. + + Returns: + tuple[Any, jnp.ndarray]: Result of the reset() function but observations are flattened. + """ env_state, obs = self._env.reset(rng) obs = self.__flatten(obs) return env_state, obs @functools.partial(jax.jit, static_argnums=(0,)) - def step(self, env_state, action, rng): # TODO improve typing + def step( + self, + env_state: Any, + action: + jnp.ndarray, + rng: chex.PRNGKey + ) -> tuple[Any, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, dict]]: + """Calls the step() function of the environment and flattens the returned + observations. + + Args: + env_state (Any): The internal environment state. + action (jnp.ndarray): The actions to take.
+ rng (chex.PRNGKey): Random number generator key. + + Returns: + tuple[Any, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, dict]]: Result of + the step function but observations are flattened. + """ env_state, (obs, reward, done, info) = self._env.step(env_state, action, rng) obs = self.__flatten(obs) diff --git a/arlbench/core/wrappers/image_extraction.py b/arlbench/core/wrappers/image_extraction.py deleted file mode 100644 index 06abf052b..000000000 --- a/arlbench/core/wrappers/image_extraction.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import gymnasium - - -class ImageExtractionWrapper(gymnasium.Wrapper): - def __init__(self, env): - super().__init__(env) - self.observation_space = self.env.observation_space["image"] - - def reset(self, **kwargs): - obs, info = self.env.reset(**kwargs) - return obs["image"], info - - def step(self, action): - obs, reward, tr, te, info = self.env.step(action) - return obs["image"], reward, tr, te, info diff --git a/arlbench/core/wrappers/wrapper.py b/arlbench/core/wrappers/wrapper.py index dfe9d9377..6f2b7bc86 100644 --- a/arlbench/core/wrappers/wrapper.py +++ b/arlbench/core/wrappers/wrapper.py @@ -10,6 +10,11 @@ class Wrapper: """Base class for ARLBench wrappers.""" def __init__(self, env: Environment): + """Wraps an ARLBench Environment. + + Args: + env (Environment): Environment to wrap + """ self._env = env # provide proxy access to regular attributes of wrapped object diff --git a/arlbench/utils/__init__.py b/arlbench/utils/__init__.py index 5ef8bfdc2..f08ca65db 100644 --- a/arlbench/utils/__init__.py +++ b/arlbench/utils/__init__.py @@ -1,7 +1,6 @@ from .common import (config_space_to_gymnasium_space, config_space_to_yaml, gymnasium_space_to_gymnax_space, recursive_concat, save_defaults_to_yaml, tuple_concat) -from .handle_termination import HandleTermination __all__ = [ "config_space_to_gymnasium_space", @@ -10,5 +9,4 @@ "gymnasium_space_to_gymnax_space", "recursive_concat", "tuple_concat", - "HandleTermination", ] diff --git a/arlbench/utils/common.py b/arlbench/utils/common.py index 3f9d99089..6e75ede99 100644 --- a/arlbench/utils/common.py +++ b/arlbench/utils/common.py @@ -7,26 +7,33 @@ import jax.numpy as jnp import numpy as np import yaml +from ConfigSpace import ConfigurationSpace +CAT_HP_CHOICES = 2 -def to_gymnasium_space(space): - import gym as old_gym - if isinstance(space, old_gym.spaces.Box): - new_space = gymnasium.spaces.Box( - low=space.low, high=space.high, dtype=space.low.dtype - ) - elif isinstance(space, old_gym.spaces.Discrete): - new_space = gymnasium.spaces.Discrete(space.n) - else: - raise NotImplementedError - return new_space +def save_defaults_to_yaml( + hp_config_space: ConfigurationSpace, + nas_config_sapce: ConfigurationSpace, + algorithm: str + ) -> str: + """Extracts the default values of the hp_config_space and nas_config_sapce and + returns a yaml file. + Args: + hp_config_space (ConfigurationSpace): The hyperparameter configuration space + of the algorithm. + nas_config_sapce (ConfigurationSpace): The neural architecture configuration + space of the algorithm. + algorithm (str): The name of the algorithm. -def save_defaults_to_yaml(hp_config_space, nas_config_sapce, algorithm: str): + Returns: + str: yaml string. 
+ """ yaml_dict = {"algorithm": algorithm, "hp_config": {}, "nas_config": {}} - def add_hps(config_space, config_key): + def add_hps(config_space: ConfigurationSpace, config_key: str) -> None: + """Adds hyperparameter defaults to a dictionary.""" for hp_name, hp in config_space.items(): if isinstance(hp, ConfigSpace.UniformIntegerHyperparameter): yaml_dict[config_key][hp_name] = int(hp.default_value) @@ -49,10 +56,22 @@ def add_hps(config_space, config_key): def config_space_to_yaml( - config_space: ConfigSpace.ConfigurationSpace, + config_space: ConfigurationSpace, config_key: str = "hp_config", seed: int = 0, -): +) -> str: + """Converts a ConfigSpace object to yaml. + + Args: + config_space (ConfigurationSpace): Configuration space object. + config_key (str, optional): Key for the hyperparameters. + Defaults to "hp_config". + seed (int, optional): Configuration space seed to write to yaml. Defaults to 0. + + + Returns: + _type_: _description_ + """ yaml_dict = {"seed": seed, "hyperparameters": {}, "conditions": []} for hp_name, hp in config_space.items(): if hp_name == "normalize_observations": @@ -77,7 +96,7 @@ def config_space_to_yaml( } elif isinstance(hp, ConfigSpace.CategoricalHyperparameter): try: - if len(hp.choices) == 2: # assume bool + if len(hp.choices) == CAT_HP_CHOICES: # assume bool param = { "type": "categorical", "choices": [bool(c) for c in hp.choices], @@ -89,7 +108,7 @@ def config_space_to_yaml( "choices": [int(c) for c in hp.choices], "default": int(hp.default_value), } - except: + except TypeError: param = { "type": "categorical", "choices": [str(c) for c in hp.choices], @@ -115,8 +134,17 @@ def config_space_to_yaml( def config_space_to_gymnasium_space( - config_space: ConfigSpace.ConfigurationSpace, seed=None + config_space: ConfigurationSpace, seed: int | None = None ) -> gymnasium.spaces.Dict: + """Converts a configuration space to a gymnasium space. + + Args: + config_space (ConfigurationSpace): Configuration space. + seed (int | None, optional): Seed for the gymnasium space. Defaults to None. + + Returns: + gymnasium.spaces.Dict: Gymnasium space. + """ spaces = {} for hp_name, hp in config_space._hyperparameters.items(): @@ -139,7 +167,14 @@ def config_space_to_gymnasium_space( def gymnasium_space_to_gymnax_space(space: gym_spaces.Space) -> gymnax_spaces.Space: - """Convert Gym space to equivalent Gymnax space.""" + """Converst a gymnasium space to a gymnax space. + + Args: + space (Space): Gymnasium space. + + Returns: + gymnax_spaces.Space: Gymnax space. + """ if isinstance(space, gym_spaces.Discrete): return gymnax_spaces.Discrete(int(space.n)) if isinstance(space, gym_spaces.Box): @@ -164,18 +199,17 @@ def gymnasium_space_to_gymnax_space(space: gym_spaces.Space) -> gymnax_spaces.Sp raise NotImplementedError(f"Conversion of {space.__class__.__name__} not supported") -def flatten_dict(d): - """Flatten a nested dictionary into a tuple containing all items.""" - values = [] - for value in d.values(): - if isinstance(value, dict): - values.extend(flatten_dict(value)) - else: - values.append(value) - return tuple(values) +def recursive_concat(dict1: dict, dict2: dict, axis: int = 0) -> dict: + """Recursively concatenates two dictionaries value-wise for same keys. + Args: + dict1 (dict): First dictionary. + dict2 (dict): Second dictionary. + axis (int, optional): Concat axis. Defaults to 0. -def recursive_concat(dict1: dict, dict2: dict, axis: int = 0): + Returns: + dict: Concatenated dictionary. 
+ """ concat_dict = {} assert dict1.keys() == dict2.keys(), "Dictionaries have different sets of keys" @@ -189,7 +223,17 @@ def recursive_concat(dict1: dict, dict2: dict, axis: int = 0): return concat_dict -def tuple_concat(tuple1: tuple, tuple2: tuple, axis: int = 0): +def tuple_concat(tuple1: tuple, tuple2: tuple, axis: int = 0) -> tuple: + """Concatenates two tuples element-wise. + + Args: + tuple1 (tuple): First tuple. + tuple2 (tuple): Second tuple. + axis (int, optional): Concat axis. Defaults to 0. + + Returns: + tuple: Concatenated tuple. + """ assert len(tuple1) == len(tuple2), "Tuples must be of the same length" return tuple( diff --git a/arlbench/utils/handle_termination.py b/arlbench/utils/handle_termination.py deleted file mode 100644 index 920b9ac62..000000000 --- a/arlbench/utils/handle_termination.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Main module.""" - -from __future__ import annotations - -import logging -import signal -import sys -from typing import Any - -logger = logging.getLogger(__name__) - - -class HandleTermination: - """This is a context manager that handles different termination signals. - The comments show how to save a model and optimizer states - even if there's an error in the code. - """ - - def __init__(self, env: Any): - """Initialize the context manager with logdir.""" - self.env = env - - def __enter__(self): - self.old_sigterm_handler = signal.signal(signal.SIGTERM, self.handle_sigterm) - return self - - def __exit__(self, exc_type, exc_value, traceback): - """An exception was raised, so save the model and optimizer states - with an exception tag before re-raising. - """ - signal.signal(signal.SIGTERM, self.old_sigterm_handler) - if exc_type is not None: - logger.info("Oh no, there was an exception!") - path = self.env._save(tag="exc") - logger.info(f"Saving checkpoint to {path}") - - # torch.save({ - # 'model_state_dict': self.model.state_dict(), - # 'optimizer_state_dict': self.optimizer.state_dict(), - # }, self.directory / f'checkpoint_exc.pth') - - # Everything ran perfectly, so save the final model and optimizer states - if exc_type is None: - logger.info("All clear!") - - return False - - def handle_sigterm(self, signum, frame): # noqa: ARG002 - """Save the model and optimizer states before exiting.""" - path = self.env.save(tag="sigterm") - logger.info(f"Saving checkpoint to {path}") - sys.exit(0) diff --git a/arlbench/utils/hydra_utils.py b/arlbench/utils/hydra_utils.py deleted file mode 100644 index 62fafb937..000000000 --- a/arlbench/utils/hydra_utils.py +++ /dev/null @@ -1,162 +0,0 @@ -import logging -import multiprocessing -import os -import subprocess -from itertools import islice -from pathlib import Path -from typing import Any, Callable, List - -import pandas as pd -from hydra.core.utils import setup_globals -from rich.logging import RichHandler -from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn - -FORMAT = "%(message)s" -logging.basicConfig( - level=logging.INFO, format=FORMAT, datefmt="[%X]", handlers=[RichHandler()] -) - -setup_globals() - -try: - import tables - - HDF = True -except: - HDF = False - - -def batched(iterable, n): - # batched('ABCDEFG', 3) --> ABC DEF G - if n < 1: - raise ValueError("n must be at least one") - it = iter(iterable) - while batch := tuple(islice(it, n)): - yield batch - - -def read_log( - paths: str, - loading_functions: List[Callable], - processing_functions: List[Callable], - outpath: Path, - batch_size: int = 10, - n_processes: int = 4, -) -> pd.DataFrame: - log = 
logging.getLogger("ReadLogs") - filenames = [] - if type(paths) == str: - paths = [paths] - paths = list(set(paths)) # filter duplicates - for path in paths: - path = Path(path) - filenames.extend(list(path.glob(f"**/*"))) - - hdf_key = "run_data" - batch_names = [] - for i, batch in enumerate(batched(filenames, batch_size)): - log.info(f"Batch {i}: Start reading {len(batch)} logs from {paths}") - df = map_multiprocessing( - task_functions=loading_functions, - task_params=batch, - task_string="Reading logs...", - n_processes=n_processes, - ) - log.info("Concatenating logs...") - df = pd.concat(df).reset_index(drop=True) - log.info("Postprocess logs...") - for f in processing_functions: - processed = f(df) - if processed is not None: - df = processed - data_fn_tmp = str(outpath) + f"_{i}" - batch_names.append(data_fn_tmp) - log.info(f"Dumping logs to '{data_fn_tmp}'.") - if HDF: - df.to_hdf(data_fn_tmp, hdf_key) - else: - df.to_csv(f"{data_fn_tmp}.csv") - log.info(f"Done with batch {i} 🙂") - - log.info("Collect all batches and save to disk") - if HDF: - df = pd.concat([pd.read_hdf(fn, hdf_key) for fn in batch_names]) - df.to_hdf(outpath, hdf_key) - for fn in batch_names: - subprocess.Popen(f"rm {fn}") - else: - df = pd.concat([pd.read_csv(f"{fn}.csv") for fn in batch_names]) - df.to_csv(f"{outpath}.csv") - for fn in batch_names: - os.remove(f"{fn}.csv") - log.info(f"Done 🙂") - return df - - -def map_multiprocessing( - task_functions: List[Callable], - task_params: list[Any], - n_processes: int = 4, - task_string: str = "Working...", -) -> list: - results = [] - - def get_results(path): - res = [] - for f in task_functions: - results = map(f, path) - results = list(filter(lambda x: x is not None, results)) - if len(results) > 0: - res.append(pd.concat(results).reset_index(drop=True)) - return res - - with Progress( - *Progress.get_default_columns(), - MofNCompleteColumn(), - TimeElapsedColumn(), - refresh_per_second=2, - ) as progress: - task_id = progress.add_task(f"[cyan]{task_string}", total=len(task_params)) - try: - with multiprocessing.Pool(processes=n_processes) as pool: - for result in pool.imap(get_results, task_params): - results.append(result) - progress.advance(task_id) - except AttributeError: - print("Pickling failed, falling back on sequential loading.") - for param in task_params: - results.extend(get_results([param])) - progress.advance(task_id) - return results - - -def read_logs( - data_path: str, - loading_functions: List[Callable], - processing_functions: List[Callable], - save_to: str, -) -> pd.DataFrame: - path = Path(data_path) - filenames = list(path.glob(f"**/*")) - outpath = Path(save_to) - return read_log(filenames, loading_functions, processing_functions, outpath) - - -def get_missing_jobs( - data_path: str, is_done: Callable, n_processes: int = 4 -) -> pd.DataFrame: - path = Path(data_path) - filenames = list(path.glob(f"**/*")) - results = [] - - def is_missing(filename): - if is_done(filename) is False and is_done(filename) is not None: - return filename - else: - return None - - results = list(filter(lambda x: x is not None, map(is_missing, filenames))) - log = logging.getLogger("CheckJobs") - log.info(f"Found {len(results)} missing jobs.") - log.info(f"That means {len(filenames) - len(results)} jobs are done.") - return results diff --git a/examples/run_arlbench.py b/examples/run_arlbench.py index 294200f29..6ceee66a8 100644 --- a/examples/run_arlbench.py +++ b/examples/run_arlbench.py @@ -3,16 +3,20 @@ from __future__ import annotations import warnings + 
warnings.filterwarnings("ignore") import csv import logging import sys import traceback +from typing import TYPE_CHECKING import hydra import jax from arlbench.arlbench import run_arlbench -from omegaconf import DictConfig + +if TYPE_CHECKING: + from omegaconf import DictConfig @hydra.main(version_base=None, config_path="configs", config_name="base") @@ -36,7 +40,6 @@ def execute(cfg: DictConfig): def run(cfg: DictConfig, logger: logging.Logger): """Console script for arlbench.""" - # check if file done exists and if so, return try: with open("./done.txt") as f: @@ -45,8 +48,7 @@ def run(cfg: DictConfig, logger: logging.Logger): with open("./performance.csv") as pf: csvreader = csv.reader(pf) performance = next(csvreader) - performance = float(performance[0]) - return performance + return float(performance[0]) except FileNotFoundError: pass diff --git a/examples/run_heuristic_schedule.py b/examples/run_heuristic_schedule.py index d85ea6d86..3de2808ee 100644 --- a/examples/run_heuristic_schedule.py +++ b/examples/run_heuristic_schedule.py @@ -3,19 +3,24 @@ from __future__ import annotations import warnings + warnings.filterwarnings("ignore") import logging import sys import traceback +from typing import TYPE_CHECKING import hydra import jax from arlbench.autorl import AutoRLEnv -from omegaconf import DictConfig + +if TYPE_CHECKING: + from omegaconf import DictConfig + def run(cfg: DictConfig, logger: logging.Logger): """Heuristic-based exploration schedule. Decrease epsilon in DQN when evaluation performance reaches a certain threshold.""" - logger.info(f"Starting run with epsilon value {str(cfg.hp_config.initial_epsilon)}") + logger.info(f"Starting run with epsilon value {cfg.hp_config.initial_epsilon!s}") # Initialize environment with general config env = AutoRLEnv(cfg.autorl) @@ -30,7 +35,7 @@ def run(cfg: DictConfig, logger: logging.Logger): cfg.hp_config.target_epsilon = 0.7 cfg.hp_config.initial_epsilon = 0.7 logger.info("Agent reached performance threshold, decreasing epsilon to 0.7") - + logger.info(f"Training finished with a total reward of {objectives['reward_mean']}") @hydra.main(version_base=None, config_path="configs", config_name="epsilon_heuristic") diff --git a/examples/run_reactive_schedule.py b/examples/run_reactive_schedule.py index c8f2c86f4..ac72d2d75 100644 --- a/examples/run_reactive_schedule.py +++ b/examples/run_reactive_schedule.py @@ -3,19 +3,23 @@ from __future__ import annotations import warnings + warnings.filterwarnings("ignore") import logging import sys import traceback +from typing import TYPE_CHECKING import hydra import jax from arlbench.autorl import AutoRLEnv -from omegaconf import DictConfig + +if TYPE_CHECKING: + from omegaconf import DictConfig + def run(cfg: DictConfig, logger: logging.Logger): """Gradient-based learning rate schedule. 
Spike the learning rate for one step if gradients stagnate.""" - # Initialize environment with general config env = AutoRLEnv(cfg.autorl) diff --git a/run_arlbench.py b/run_arlbench.py index edf1637c7..b0801b1a4 100644 --- a/run_arlbench.py +++ b/run_arlbench.py @@ -3,16 +3,20 @@ from __future__ import annotations import warnings + warnings.filterwarnings("ignore") import csv import logging import sys import traceback +from typing import TYPE_CHECKING import hydra import jax from arlbench.arlbench import run_arlbench -from omegaconf import DictConfig + +if TYPE_CHECKING: + from omegaconf import DictConfig @hydra.main(version_base=None, config_path="examples/configs", config_name="base") @@ -36,7 +40,6 @@ def execute(cfg: DictConfig): def run(cfg: DictConfig, logger: logging.Logger): """Console script for arlbench.""" - # check if file done exists and if so, return try: with open("./done.txt") as f: @@ -45,8 +48,7 @@ def run(cfg: DictConfig, logger: logging.Logger): with open("./performance.csv") as pf: csvreader = csv.reader(pf) performance = next(csvreader) - performance = float(performance[0]) - return performance + return float(performance[0]) except FileNotFoundError: pass