From ec907338b24ba4f21c27aec22ae8cf7ec83cc536 Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Mon, 8 Jan 2024 17:33:12 +0000 Subject: [PATCH] Fix tests --- qdax/core/emitters/cma_pool_emitter.py | 13 ++++++++++--- tests/baselines_test/me_pbt_sac_test.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py index d5af2181..24556f8b 100644 --- a/qdax/core/emitters/cma_pool_emitter.py +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -73,7 +73,14 @@ def scan_emitter_init( carry: RNGKey, unused: Any ) -> Tuple[RNGKey, CMAEmitterState]: random_key = carry - emitter_state, random_key = self._emitter.init(genotypes, random_key) + emitter_state, random_key = self._emitter.init( + random_key, + repertoire, + genotypes, + fitnesses, + descriptors, + extra_scores, + ) return random_key, emitter_state # init all the emitter states @@ -117,11 +124,11 @@ def emit( ) # use it to emit offsprings - offsprings, random_key = self._emitter.emit( + offsprings, extra_info, random_key = self._emitter.emit( repertoire, used_emitter_state, random_key ) - return offsprings, {}, random_key + return offsprings, extra_info, random_key @partial( jax.jit, diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index c4ab259e..98a5b960 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -126,7 +126,7 @@ def scoring_function(genotypes, random_key): # type: ignore lambda x: jnp.repeat(x, population_size, axis=0), first_states ) population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) - return population_returns, population_bds, None, random_key + return population_returns, population_bds, {}, random_key # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name]