feat: Formatting and more docstrings
becktepe committed May 23, 2024
1 parent 1446035 commit e507ea9
Showing 26 changed files with 218 additions and 319 deletions.
12 changes: 7 additions & 5 deletions arlbench/arlbench.py
@@ -10,24 +10,28 @@

def run_arlbench(cfg: DictConfig, logger: Logger | None = None) -> float | tuple | list:
"""Run ARLBench using the given config and return objective(s)."""
# We check if we need to load a checkpoint for HyperPBT
# If so, we load the first episode and first step of ARLBench since we always run only
# one iteration
if "load" in cfg and cfg.load:
print(f"### ATTEMPTING TO LOAD {cfg.load} ###")
checkpoint_path = os.path.join(
cfg.load,
cfg.autorl.checkpoint_name,
"default_checkpoint_c_episode_1_step_1",
)
print(f"### CHECKPOINT PATH = {checkpoint_path} ###")
else:
checkpoint_path = None

# We check if we need to save a checkpoint for HyperPBT
# If so, we need to adapt the autorl config accordingly
if "save" in cfg and cfg.save:
cfg.autorl.checkpoint_dir = str(cfg.save).replace(".pt", "")
if cfg.algorithm == "PPO":
cfg.autorl.checkpoint = ["opt_state", "params"]
else:
cfg.autorl.checkpoint = ["opt_state", "params", "buffer"]

# Here, we define how the AutoRLEnv should behave
env = AutoRLEnv(cfg.autorl)
_ = env.reset()

@@ -39,11 +43,9 @@ def run_arlbench(cfg: DictConfig, logger: Logger | None = None) -> float | tuple
if logger:
logger.info("Training finished.")

# Additionally, we store the evaluation rewards we had during training
info["train_info_df"].to_csv("evaluation.csv", index=False)

if "reward_curves" in cfg and cfg.reward_curves:
return list(info["train_info_df"]["returns"])

if len(objectives) == 1:
return objectives[next(iter(objectives.keys()))]
else:
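For orientation, a minimal driver sketch for run_arlbench (editor's illustration, not part of this commit): real runs go through Hydra, and every config key beyond those visible in the snippet above is an assumption.

```python
# Hypothetical driver; the autorl values are placeholders, not project defaults.
from omegaconf import OmegaConf

from arlbench.arlbench import run_arlbench

cfg = OmegaConf.create({
    "algorithm": "dqn",
    "autorl": {"checkpoint_name": "demo_run"},  # passed to AutoRLEnv(cfg.autorl)
    "save": "checkpoints/demo_run.pt",          # enables checkpoint saving for HyperPBT
    "reward_curves": True,                      # return the evaluation returns as a list
})

returns = run_arlbench(cfg)
print(returns)
```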
4 changes: 2 additions & 2 deletions arlbench/autorl/autorl_env.py
@@ -245,7 +245,7 @@ def _make_algorithm(self) -> Algorithm:
"""Instantiated the RL algorithm given the current AutoRL config and hyperparameter configuration.
Returns:
Algorithm: RL algorithm instance.
"""
return self._algorithm_cls(
self._hpo_config,
@@ -303,7 +303,7 @@ def step(

self._algorithm = self._make_algorithm()

# First, we check if there is a checkpoint to load. If not, we have to check
# whether this is the first iteration, i.e., the first call of env.step(). In that case,
# we have to initialize the algorithm state.
# Otherwise, we use the state from previous iteration(s).
11 changes: 6 additions & 5 deletions arlbench/autorl/checkpointing.py
@@ -3,8 +3,9 @@
import json
import os
import warnings
from collections.abc import Callable
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

import jax
import jax.numpy as jnp
@@ -87,7 +88,7 @@ def save(
"""Saves the current state of a AutoRL environment.
Args:
algorithm (str): Name of the algorithm.
algorithm_state (AlgorithmState): Current algorithm state.
autorl_config (dict): AutoRL configuration.
hp_config (Configuration): Hyperparameter configuration of the algorithm.
@@ -164,7 +165,7 @@ def save(
# Get actual checkpoint by calling factory function
ckpt[key] = algorithm_ckpt[
key
]()
else:
# Only use selected checkpoint options
for key in checkpoint:
@@ -184,7 +185,7 @@ def save(
# Get actual checkpoint by calling factory function
ckpt[algorithm_key] = algorithm_ckpt[
algorithm_key
]()
if not found_key:
warnings.warn(
f"Invalid checkpoint for algorithm {algorithm}: {key}. Valid keys are {list(algorithm_ckpt.keys())!s}. Skipping key."
@@ -439,7 +440,7 @@ def load_buffer(
PrioritisedTrajectoryBufferState: The buffer state that was loaded from disk.
"""
# Using the vault we can easily load the data of the buffer
# As described in the part of saving the buffer, this does
# not contain the priorities
v = Vault(
vault_name="buffer_state_vault",
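As an aside, the save() hunks above rely on a factory-function pattern: each checkpoint option maps to a zero-argument callable, and only the selected keys are materialised. A simplified, self-contained sketch (editor's illustration; the key names and contents are made up):

```python
import warnings

def collect_checkpoint(algorithm: str, algorithm_ckpt: dict, checkpoint: list) -> dict:
    ckpt = {}
    for key in checkpoint:
        if key in algorithm_ckpt:
            ckpt[key] = algorithm_ckpt[key]()  # call the factory only when selected
        else:
            warnings.warn(
                f"Invalid checkpoint for algorithm {algorithm}: {key}. "
                f"Valid keys are {list(algorithm_ckpt.keys())!s}. Skipping key."
            )
    return ckpt

# Only the "params" factory is called; the (potentially large) buffer is skipped.
selected = collect_checkpoint(
    "dqn",
    {"params": lambda: {"w": 1.0}, "buffer": lambda: list(range(10_000))},
    ["params"],
)
```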
10 changes: 5 additions & 5 deletions arlbench/autorl/objectives.py
@@ -20,7 +20,7 @@

class Objective(ABC):
"""An abstract optimization objective for the AutoRL environment.
It can be wrapped around the training function to calculate the objective.
We do this by overriding the __new__() function. It allows us to imitate
the behaviour of a basic function while keeping the advantages of a static class.
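A minimal sketch of the __new__ pattern described in this docstring (editor's illustration, not the actual ARLBench code): overriding __new__ makes "calling" the class wrap a training function, while KEY and RANK stay available as static attributes.

```python
class ExampleObjective:
    KEY = "reward_mean"
    RANK = 2

    def __new__(cls, train_fn, objectives: dict):
        def wrapper(*args, **kwargs):
            result = train_fn(*args, **kwargs)
            objectives[cls.KEY] = cls.compute(result)  # record the objective value
            return result
        return wrapper

    @staticmethod
    def compute(result):
        return float(result)  # placeholder aggregation

objectives = {}
wrapped = ExampleObjective(lambda: 42, objectives)
wrapped()          # runs the training function
print(objectives)  # {'reward_mean': 42.0}
```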
@@ -73,7 +73,7 @@ def __lt__(self, other: Objective) -> bool:
other (Objective): Other Objective to compare to.
Returns:
bool: Whether this Objective is less than the other Objective.
"""
return self.RANK < other.RANK

@@ -109,7 +109,7 @@ def get_spec() -> dict:

class RewardMean(Objective):
"""Reward objective for the AutoRL environment. It measures the mean of the last evaluation rewards."""

KEY = "reward_mean"
RANK = 2

@@ -144,7 +144,7 @@ def get_spec() -> dict:

class RewardStd(Objective):
"""Reward objective for the AutoRL environment. It measures the standard deviation of the last evaluation rewards."""

KEY = "reward_std"
RANK = 2

@@ -175,7 +175,7 @@ def get_spec() -> dict:

class Emissions(Objective):
"""Emissions objective for the AutoRL environment. It measures the emissions during the training using code carbon."""

KEY = "emissions"
RANK = 1

2 changes: 1 addition & 1 deletion arlbench/autorl/state_features.py
@@ -15,7 +15,7 @@

class StateFeature(ABC):
"""An abstract state features for the AutoRL environment.
It can be wrapped around the training function to calculate the state features.
We do this be overriding the __new__() function. It allows us to imitate
the behaviour of a basic function while keeping the advantages of a static class.
12 changes: 12 additions & 0 deletions arlbench/core/algorithms/algorithm.py
@@ -97,6 +97,18 @@ def get_hpo_config_space(seed: int | None = None) -> ConfigurationSpace:
ConfigurationSpace: Hyperparameter configuration space of the algorithm.
"""

@staticmethod
@abstractmethod
def get_hpo_search_space(seed: int | None = None) -> ConfigurationSpace:
"""Returns the hyperparameter search space of the algorithm. However, this can be adapted to fit a given HPO method.
Args:
seed (int | None, optional): Random generator seed that is used to sample configurations. Defaults to None.
Returns:
ConfigurationSpace: Hyperparameter search space of the algorithm.
"""

@staticmethod
@abstractmethod
def get_default_hpo_config() -> Configuration:
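For illustration, a hypothetical concrete implementation of the new get_hpo_search_space() hook (editor's sketch; the hyperparameter names and ranges are invented, not ARLBench's actual spaces):

```python
from ConfigSpace import ConfigurationSpace, Float, Integer

class ExampleAlgorithm:  # stands in for a concrete Algorithm subclass
    @staticmethod
    def get_hpo_search_space(seed: int | None = None) -> ConfigurationSpace:
        return ConfigurationSpace(
            name="ExampleSearchSpace",
            seed=seed,
            space={
                "learning_rate": Float("learning_rate", (1e-5, 1e-1), log=True),
                "buffer_size": Integer("buffer_size", (1024, 1_000_000), log=True),
            },
        )

cs = ExampleAlgorithm.get_hpo_search_space(seed=0)
print(cs.sample_configuration())
```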
11 changes: 9 additions & 2 deletions arlbench/core/algorithms/dqn/dqn.py
@@ -151,6 +151,7 @@ def __init__(
track_metrics=track_metrics,
)

# For the network, we need the properties of the action space
action_size, discrete = self.action_type
network_cls = CNNQ if cnn_policy else MLPQ
self.network = network_cls(
@@ -169,6 +170,10 @@ def __init__(
priority_exponent=self.hpo_config["buffer_alpha"],
device=jax.default_backend(),
)

# This is how we turn prioritized sampling on/off for dynamic HPO:
# we always use the prioritized replay buffer, but if "buffer_prio_sampling"
# is disabled, we replace the sampling function with uniform sampling.
if self.hpo_config["buffer_prio_sampling"] is False:
sample_fn = functools.partial(
uniform_sample,
@@ -231,7 +236,6 @@ def get_hpo_config_space(seed: int | None = None) -> ConfigurationSpace:

@staticmethod
def get_hpo_search_space(seed: int | None = None) -> ConfigurationSpace:
# defaults from https://stable-baselines3.readthedocs.io/en/master/modules/dqn.html
cs = ConfigurationSpace(
name="DQNConfigSpace",
seed=seed,
@@ -358,15 +362,18 @@ def init(
DQNState: DQN state.
"""
rng, reset_rng = jax.random.split(rng)

env_state, obs = self.env.reset(reset_rng)

# If any of these is not defined, we need a dummy environment transition
# to initialize them
if buffer_state is None or network_params is None or target_params is None:
dummy_rng = jax.random.PRNGKey(0)
_action = self.env.sample_actions(dummy_rng)
_, (_obs, _reward, _done, _) = self.env.step(env_state, _action, dummy_rng)

if buffer_state is None:
# This is what transitions look like during training, so we need to pass one
# to the buffer once to estimate and allocate the required buffer size
_timestep = TimeStep(
last_obs=_obs[0],
obs=_obs[0],
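The "dummy transition" comment above follows a common flashbax idiom: the buffer infers shapes, dtypes and storage from one example item. A hedged, self-contained sketch (editor's illustration; buffer sizes, field names and the observation shape are placeholders, not ARLBench's TimeStep):

```python
import jax.numpy as jnp
import flashbax as fbx

buffer = fbx.make_flat_buffer(
    max_length=10_000, min_length=64, sample_batch_size=32
)
dummy_timestep = {
    "last_obs": jnp.zeros((4,), dtype=jnp.float32),
    "obs": jnp.zeros((4,), dtype=jnp.float32),
    "action": jnp.int32(0),
    "reward": jnp.float32(0.0),
    "done": jnp.bool_(False),
}
buffer_state = buffer.init(dummy_timestep)  # allocation driven by the dummy item
```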
6 changes: 3 additions & 3 deletions arlbench/core/environments/autorl_env.py
@@ -6,9 +6,9 @@

import jax
import jax.numpy as jnp
import gymnax

if TYPE_CHECKING:
import gymnax
from chex import PRNGKey


@@ -63,7 +63,7 @@ def step(
rng (PRNGKey): Random number generator key.
Returns:
tuple[Any, Any]: Returns a tuple containing the environment state as well as the actual return of the step() function.
"""
raise NotImplementedError

@@ -87,7 +87,7 @@ def observation_space(self) -> gymnax.environments.spaces.Space:

@functools.partial(jax.jit, static_argnums=0)
def sample_actions(self, rng: PRNGKey) -> jnp.ndarray:
"""Samples a random action for each environment.
"""Samples a random action for each environment.
Args:
rng (PRNGKey): Random number generator key.
Expand Down
4 changes: 2 additions & 2 deletions arlbench/core/environments/brax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@

class BraxEnv(Environment):
"""A brax-based RL environment."""

def __init__(
self, env_name: str, n_envs: int, env_kwargs: dict[str, Any] | None = None
):
"""Creates a brax environment for JAX-based RL training.
"""Creates a brax environment for JAX-based RL training.
Args:
env_name (str): Name/id of the brax environment.
Expand Down
6 changes: 3 additions & 3 deletions arlbench/core/environments/envpool_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,15 @@ def numpy_to_jax(x: np.ndarray) -> jnp.ndarray:

class EnvpoolEnv(Environment):
"""An envpool-based RL environment."""

def __init__(
self,
env_name: str,
n_envs: int,
seed: int,
env_kwargs: dict[str, Any] | None = None,
):
"""Creates an envpool environment for JAX-based RL training.
"""Creates an envpool environment for JAX-based RL training.
Args:
env_name (str): Name/id of the brax environment.
Expand Down Expand Up @@ -225,7 +225,7 @@ def step(self, env_state: Any, action: Any, _):
env_state, _ = env_state
else:
env_state, lives = env_state

# Here, we perform the actual step in the envpool environment
env_state, (obs, reward, term, trunc, info) = self._xla_step(env_state, action)
done = jnp.logical_or(term, trunc)
Expand Down
4 changes: 2 additions & 2 deletions arlbench/core/environments/gymnasium_env.py
@@ -23,8 +23,8 @@ class GymnasiumEnv(Environment):

def __init__(
self, env_name: str, seed: int, env_kwargs: dict[str, Any] | None = None
):
"""Creates a gymnasium environment for JAX-based RL training.
Args:
env_name (str): Name/id of the gymnasium environment.
2 changes: 1 addition & 1 deletion arlbench/core/environments/gymnax_env.py
@@ -18,7 +18,7 @@ class GymnaxEnv(Environment):
def __init__(
self, env_name: str, n_envs: int, env_kwargs: dict[str, Any] | None = None
):
"""Creates a gymnax environment for JAX-based RL training.
"""Creates a gymnax environment for JAX-based RL training.
Args:
env_name (str): Name/id of the gymnax environment.
3 changes: 2 additions & 1 deletion arlbench/core/environments/make_env.py
@@ -3,10 +3,11 @@
import warnings
from typing import TYPE_CHECKING

from arlbench.core.wrappers import Wrapper, FlattenObservationWrapper
from arlbench.core.wrappers import FlattenObservationWrapper, Wrapper

if TYPE_CHECKING:
from typing import Any

from .autorl_env import Environment


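A minimal illustration of the `if TYPE_CHECKING:` pattern used in make_env.py (editor's sketch): the guarded import is only evaluated by static type checkers, never at runtime, which avoids circular imports between the environment modules. The `n_envs` attribute is assumed here purely for illustration.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from arlbench.core.environments.autorl_env import Environment

def describe(env: Environment) -> str:
    # n_envs is a hypothetical attribute used only for this example.
    return f"environment with {env.n_envs} parallel instances"
```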
4 changes: 2 additions & 2 deletions arlbench/core/environments/xland_env.py
@@ -23,7 +23,7 @@ def __init__(
env_kwargs: dict[str, Any] | None = None,
cnn_policy: bool = False,
):
"""Creates an xland environment for JAX-based RL training.
"""Creates an xland environment for JAX-based RL training.
Args:
env_name (str): Name/id of the xland environment.
@@ -64,7 +64,7 @@ def step(self, env_state: Any, action: Any, rng: PRNGKey):
# (as referred to in the xland documentation)
timestep = jax.vmap(self._env.step, in_axes=(None, 0, 0))(
self.env_params, env_state, action
)

return timestep, (timestep.observation, timestep.reward, timestep.last(), {})

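A toy illustration of the jax.vmap pattern in the xland step() above (editor's sketch, not the xland API): the shared params are broadcast with in_axes=None, while state and action are batched along axis 0, one entry per parallel environment.

```python
import jax
import jax.numpy as jnp

def single_step(params, state, action):
    new_state = state + params * action
    reward = -jnp.abs(new_state)
    return new_state, reward

batched_step = jax.vmap(single_step, in_axes=(None, 0, 0))

states = jnp.zeros((8,))   # 8 parallel environments
actions = jnp.ones((8,))
new_states, rewards = batched_step(0.5, states, actions)
print(new_states.shape, rewards.shape)  # (8,) (8,)
```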
5 changes: 2 additions & 3 deletions arlbench/core/wrappers/__init__.py
@@ -1,5 +1,4 @@
from .wrapper import Wrapper
from .flatten_observation import FlattenObservationWrapper
from .image_extraction import ImageExtractionWrapper
from .wrapper import Wrapper

__all__ = ["Wrapper", "ImageExtractionWrapper", "FlattenObservationWrapper"]
__all__ = ["Wrapper", "FlattenObservationWrapper"]