
Commit

use appropriate reward offset and clip wrappers in test and nb
Lookatator committed Sep 6, 2024
1 parent 38801ba commit fbbaa5e
Showing 3 changed files with 15 additions and 107 deletions.
58 changes: 7 additions & 51 deletions examples/dcrlme.ipynb
@@ -89,6 +89,7 @@
"from qdax.core.neuroevolution.networks.networks import MLP, MLPDC\n",
"from qdax.custom_types import EnvState, Params, RNGKey\n",
"from qdax.environments import behavior_descriptor_extractor\n",
"from qdax.environments.wrappers import OffsetRewardWrapper, ClipRewardWrapper\n",
"from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs\n",
"from qdax.utils.plotting import plot_map_elites_results\n",
"\n",
@@ -102,33 +103,6 @@
"clear_output()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class RewardOffsetEnvWrapper(Wrapper):\n",
" \"\"\"Wraps ant_omni environment to add and scale position.\"\"\"\n",
"\n",
" def __init__(self, env: Env, env_name: str) -> None:\n",
" super().__init__(env)\n",
" self._env_name = env_name\n",
"\n",
" @property\n",
" def name(self) -> str:\n",
" return self._env_name\n",
"\n",
" def reset(self, rng: jnp.ndarray) -> State:\n",
" state = self.env.reset(rng)\n",
" return state\n",
"\n",
" def step(self, state: State, action: jnp.ndarray) -> State:\n",
" state = self.env.step(state, action)\n",
" new_reward = state.reward + environments.reward_offset[self._env_name]\n",
" return state.replace(reward=new_reward)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -199,9 +173,12 @@
"\n",
"# Init environment\n",
"env = environments.create(env_name, episode_length=episode_length)\n",
"env = RewardOffsetEnvWrapper(\n",
" env, env_name\n",
") # apply reward offset as DCG needs positive rewards\n",
"env = OffsetRewardWrapper(\n",
" env, offset=environments.reward_offset[env_name]\n",
") # apply reward offset as DCRL needs positive rewards\n",
"env = ClipRewardWrapper(\n",
" env, clip_min=0.,\n",
") # apply reward clip as DCRL needs positive rewards\n",
"\n",
"reset_fn = jax.jit(env.reset)\n",
"\n",
@@ -472,27 +449,6 @@
"# create the plots and the grid\n",
"fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
31 changes: 0 additions & 31 deletions qdax/environments/wrappers.py
@@ -102,37 +102,6 @@ def step(self, state: State, action: jp.ndarray) -> State:
)


class AffineRewardWrapper(Wrapper):
"""Wraps gym environments to clip the reward.
Utilisation is simple: create an environment with Brax, pass
it to the wrapper with the name of the environment, and it will
work like before and will simply clip the reward to be greater than 0.
"""

def __init__(
self,
env: Env,
clip_min: Optional[float] = None,
clip_max: Optional[float] = None,
) -> None:
super().__init__(env)
self._clip_min = clip_min
self._clip_max = clip_max

def reset(self, rng: jp.ndarray) -> State:
state = self.env.reset(rng)
return state.replace(
reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)
)

def step(self, state: State, action: jp.ndarray) -> State:
state = self.env.step(state, action)
return state.replace(
reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)
)


class OffsetRewardWrapper(Wrapper):
"""Wraps gym environments to offset the reward to be greater than 0.
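
For context, here is a minimal sketch of what the two wrappers used in this commit presumably look like, pieced together from the removed AffineRewardWrapper above and from the constructor arguments visible in the diffs (offset=..., clip_min=...). The class names are suffixed with "Sketch" to make clear this is an illustration under those assumptions, not the actual qdax implementation, which may differ.

# Hedged sketch only: reconstructed from the removed AffineRewardWrapper and the
# constructor calls used in this commit; the real qdax wrappers may differ.
from typing import Optional

import jax.numpy as jp
from brax.envs import Env, State, Wrapper


class OffsetRewardWrapperSketch(Wrapper):
    """Adds a constant offset to the reward at reset and at every step."""

    def __init__(self, env: Env, offset: float = 0.0) -> None:
        super().__init__(env)
        self._offset = offset

    def reset(self, rng: jp.ndarray) -> State:
        state = self.env.reset(rng)
        return state.replace(reward=state.reward + self._offset)

    def step(self, state: State, action: jp.ndarray) -> State:
        state = self.env.step(state, action)
        return state.replace(reward=state.reward + self._offset)


class ClipRewardWrapperSketch(Wrapper):
    """Clips the reward to [clip_min, clip_max] at reset and at every step."""

    def __init__(
        self,
        env: Env,
        clip_min: Optional[float] = None,
        clip_max: Optional[float] = None,
    ) -> None:
        super().__init__(env)
        self._clip_min = clip_min
        self._clip_max = clip_max

    def reset(self, rng: jp.ndarray) -> State:
        state = self.env.reset(rng)
        return state.replace(
            reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)
        )

    def step(self, state: State, action: jp.ndarray) -> State:
        state = self.env.step(state, action)
        return state.replace(
            reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)
        )

In the notebook and the test, the offset wrapper is applied first so the reward becomes typically positive, and the clip wrapper then guarantees it never drops below zero, per the inline comments in the diff.
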
33 changes: 8 additions & 25 deletions tests/baselines_test/dcrlme_test.py
@@ -4,7 +4,6 @@
import jax
import jax.numpy as jnp
import pytest
from brax.envs import Env, State, Wrapper

from qdax import environments
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
@@ -15,31 +14,11 @@
from qdax.core.neuroevolution.networks.networks import MLP, MLPDC
from qdax.custom_types import EnvState, Params, RNGKey
from qdax.environments import behavior_descriptor_extractor
from qdax.environments.wrappers import ClipRewardWrapper, OffsetRewardWrapper
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs
from qdax.utils.metrics import default_qd_metrics


class RewardOffsetEnvWrapper(Wrapper):
"""Wraps ant_omni environment to add and scale position."""

def __init__(self, env: Env, env_name: str) -> None:
super().__init__(env)
self._env_name = env_name

@property
def name(self) -> str:
return self._env_name

def reset(self, rng: jnp.ndarray) -> State:
state = self.env.reset(rng)
return state

def step(self, state: State, action: jnp.ndarray) -> State:
state = self.env.step(state, action)
new_reward = state.reward + environments.reward_offset[self._env_name]
return state.replace(reward=new_reward)


def test_dcrlme() -> None:
seed = 42

@@ -86,9 +65,13 @@ def test_dcrlme() -> None:

# Init environment
env = environments.create(env_name, episode_length=episode_length)
env = RewardOffsetEnvWrapper(
env, env_name
) # apply reward offset as DCG needs positive rewards
env = OffsetRewardWrapper(
env, offset=environments.reward_offset[env_name]
) # apply reward offset as DCRL needs positive rewards
env = ClipRewardWrapper(
env,
clip_min=0.0,
) # apply reward clip as DCRL needs positive rewards

reset_fn = jax.jit(env.reset)

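
As a self-contained usage sketch of the new setup (the environment name and episode length below are placeholders; the actual values used by the notebook and the test are not visible in this diff):

import jax

from qdax import environments
from qdax.environments.wrappers import ClipRewardWrapper, OffsetRewardWrapper

# Placeholders: the diff does not show which env_name / episode_length are used.
env_name = "ant_omni"
episode_length = 250

env = environments.create(env_name, episode_length=episode_length)
# Offset the reward so it is typically positive, then clip at 0 so it can
# never be negative, as DCRL expects non-negative rewards.
env = OffsetRewardWrapper(env, offset=environments.reward_offset[env_name])
env = ClipRewardWrapper(env, clip_min=0.0)

reset_fn = jax.jit(env.reset)
state = reset_fn(jax.random.PRNGKey(0))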
