From fbbaa5e4631e7f12897f19195bcbb61c61749de4 Mon Sep 17 00:00:00 2001
From: Luca Grillotti
Date: Fri, 6 Sep 2024 10:40:07 +0100
Subject: [PATCH] use appropriate reward offset and clip wrappers in test and nb

---
 examples/dcrlme.ipynb               | 58 ++++-------------------------
 qdax/environments/wrappers.py       | 31 ---------------
 tests/baselines_test/dcrlme_test.py | 33 ++++------------
 3 files changed, 15 insertions(+), 107 deletions(-)

diff --git a/examples/dcrlme.ipynb b/examples/dcrlme.ipynb
index 12a147b6..5c367e37 100644
--- a/examples/dcrlme.ipynb
+++ b/examples/dcrlme.ipynb
@@ -89,6 +89,7 @@
  "from qdax.core.neuroevolution.networks.networks import MLP, MLPDC\n",
  "from qdax.custom_types import EnvState, Params, RNGKey\n",
  "from qdax.environments import behavior_descriptor_extractor\n",
+ "from qdax.environments.wrappers import OffsetRewardWrapper, ClipRewardWrapper\n",
  "from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs\n",
  "from qdax.utils.plotting import plot_map_elites_results\n",
  "\n",
@@ -102,33 +103,6 @@
  "clear_output()"
  ]
 },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "class RewardOffsetEnvWrapper(Wrapper):\n",
- "    \"\"\"Wraps ant_omni environment to add and scale position.\"\"\"\n",
- "\n",
- "    def __init__(self, env: Env, env_name: str) -> None:\n",
- "        super().__init__(env)\n",
- "        self._env_name = env_name\n",
- "\n",
- "    @property\n",
- "    def name(self) -> str:\n",
- "        return self._env_name\n",
- "\n",
- "    def reset(self, rng: jnp.ndarray) -> State:\n",
- "        state = self.env.reset(rng)\n",
- "        return state\n",
- "\n",
- "    def step(self, state: State, action: jnp.ndarray) -> State:\n",
- "        state = self.env.step(state, action)\n",
- "        new_reward = state.reward + environments.reward_offset[self._env_name]\n",
- "        return state.replace(reward=new_reward)"
- ]
- },
 {
  "cell_type": "code",
  "execution_count": null,
@@ -199,9 +173,12 @@
  "\n",
  "# Init environment\n",
  "env = environments.create(env_name, episode_length=episode_length)\n",
- "env = RewardOffsetEnvWrapper(\n",
- "    env, env_name\n",
- ") # apply reward offset as DCG needs positive rewards\n",
+ "env = OffsetRewardWrapper(\n",
+ "    env, offset=environments.reward_offset[env_name]\n",
+ ") # apply reward offset as DCRL needs positive rewards\n",
+ "env = ClipRewardWrapper(\n",
+ "    env, clip_min=0.,\n",
+ ") # apply reward clip as DCRL needs positive rewards\n",
  "\n",
  "reset_fn = jax.jit(env.reset)\n",
  "\n",
@@ -472,27 +449,6 @@
  "# create the plots and the grid\n",
  "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)"
  ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
 }
 ],
 "metadata": {
diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py
index e5e40e4b..babedaed 100644
--- a/qdax/environments/wrappers.py
+++ b/qdax/environments/wrappers.py
@@ -102,37 +102,6 @@ def step(self, state: State, action: jp.ndarray) -> State:
         )


-class AffineRewardWrapper(Wrapper):
-    """Wraps gym environments to clip the reward.
-
-    Utilisation is simple: create an environment with Brax, pass
-    it to the wrapper with the name of the environment, and it will
-    work like before and will simply clip the reward to be greater than 0.
-    """
-
-    def __init__(
-        self,
-        env: Env,
-        clip_min: Optional[float] = None,
-        clip_max: Optional[float] = None,
-    ) -> None:
-        super().__init__(env)
-        self._clip_min = clip_min
-        self._clip_max = clip_max
-
-    def reset(self, rng: jp.ndarray) -> State:
-        state = self.env.reset(rng)
-        return state.replace(
-            reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)
-        )
-
-    def step(self, state: State, action: jp.ndarray) -> State:
-        state = self.env.step(state, action)
-        return state.replace(
-            reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)
-        )
-
-
 class OffsetRewardWrapper(Wrapper):
     """Wraps gym environments to offset the reward to be greater than 0.

diff --git a/tests/baselines_test/dcrlme_test.py b/tests/baselines_test/dcrlme_test.py
index eead8104..05304944 100644
--- a/tests/baselines_test/dcrlme_test.py
+++ b/tests/baselines_test/dcrlme_test.py
@@ -4,7 +4,6 @@
 import jax
 import jax.numpy as jnp
 import pytest
-from brax.envs import Env, State, Wrapper

 from qdax import environments
 from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
@@ -15,31 +14,11 @@
 from qdax.core.neuroevolution.networks.networks import MLP, MLPDC
 from qdax.custom_types import EnvState, Params, RNGKey
 from qdax.environments import behavior_descriptor_extractor
+from qdax.environments.wrappers import ClipRewardWrapper, OffsetRewardWrapper
 from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs
 from qdax.utils.metrics import default_qd_metrics


-class RewardOffsetEnvWrapper(Wrapper):
-    """Wraps ant_omni environment to add and scale position."""
-
-    def __init__(self, env: Env, env_name: str) -> None:
-        super().__init__(env)
-        self._env_name = env_name
-
-    @property
-    def name(self) -> str:
-        return self._env_name
-
-    def reset(self, rng: jnp.ndarray) -> State:
-        state = self.env.reset(rng)
-        return state
-
-    def step(self, state: State, action: jnp.ndarray) -> State:
-        state = self.env.step(state, action)
-        new_reward = state.reward + environments.reward_offset[self._env_name]
-        return state.replace(reward=new_reward)
-
-
 def test_dcrlme() -> None:
     seed = 42

@@ -86,9 +65,13 @@ def test_dcrlme() -> None:

     # Init environment
     env = environments.create(env_name, episode_length=episode_length)
-    env = RewardOffsetEnvWrapper(
-        env, env_name
-    ) # apply reward offset as DCG needs positive rewards
+    env = OffsetRewardWrapper(
+        env, offset=environments.reward_offset[env_name]
+    ) # apply reward offset as DCRL needs positive rewards
+    env = ClipRewardWrapper(
+        env,
+        clip_min=0.0,
+    ) # apply reward clip as DCRL needs positive rewards

     reset_fn = jax.jit(env.reset)
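
Usage note (illustrative, not part of the patch): the sketch below shows how the two wrappers introduced by this change compose, mirroring the lines added to examples/dcrlme.ipynb and tests/baselines_test/dcrlme_test.py. The env_name and episode_length values are assumed placeholders, not values taken from the diff.

import jax

from qdax import environments
from qdax.environments.wrappers import ClipRewardWrapper, OffsetRewardWrapper

env_name = "ant_omni"  # assumed example environment, not specified in the diff
episode_length = 250  # assumed example value

# Create the base Brax environment.
env = environments.create(env_name, episode_length=episode_length)
# Offset rewards by the per-environment constant so DCRL sees positive rewards.
env = OffsetRewardWrapper(env, offset=environments.reward_offset[env_name])
# Clip any reward that is still negative up to 0.
env = ClipRewardWrapper(env, clip_min=0.0)

reset_fn = jax.jit(env.reset)

Applying the offset before the clip means the clip only removes whatever residual negative reward the fixed per-environment offset does not cover.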