From b3f0f13efdd93808c3585c837eefc4bef98ae4c4 Mon Sep 17 00:00:00 2001
From: Luca Grillotti
Date: Thu, 5 Sep 2024 20:26:40 +0000
Subject: [PATCH] add reward offset to DCRLME test

---
 tests/baselines_test/dcrlme_test.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/tests/baselines_test/dcrlme_test.py b/tests/baselines_test/dcrlme_test.py
index 9b602698..313d3623 100644
--- a/tests/baselines_test/dcrlme_test.py
+++ b/tests/baselines_test/dcrlme_test.py
@@ -4,6 +4,7 @@
 import jax
 import jax.numpy as jnp
 import pytest
+from brax.envs import Env, State, Wrapper
 
 from qdax import environments
 from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
@@ -18,6 +19,27 @@
 from qdax.utils.metrics import default_qd_metrics
 
 
+class RewardOffsetEnvWrapper(Wrapper):
+    """Wraps ant_omni environment to add and scale position."""
+
+    def __init__(self, env: Env, env_name: str) -> None:
+        super().__init__(env)
+        self._env_name = env_name
+
+    @property
+    def name(self) -> str:
+        return self._env_name
+
+    def reset(self, rng: jnp.ndarray) -> State:
+        state = self.env.reset(rng)
+        return state
+
+    def step(self, state: State, action: jnp.ndarray) -> State:
+        state = self.env.step(state, action)
+        new_reward = state.reward + environments.reward_offset[self._env_name]
+        return state.replace(reward=new_reward)
+
+
 def test_dcrlme() -> None:
     seed = 42
 
@@ -64,6 +86,10 @@ def test_dcrlme() -> None:
 
     # Init environment
     env = environments.create(env_name, episode_length=episode_length)
+    env = RewardOffsetEnvWrapper(
+        env, env_name
+    )  # apply reward offset as DCG needs positive rewards
+
     reset_fn = jax.jit(env.reset)
 
     # Compute the centroids