Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maxencefaldor committed Jan 8, 2024
1 parent dd72df7 commit ec90733
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
13 changes: 10 additions & 3 deletions qdax/core/emitters/cma_pool_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,14 @@ def scan_emitter_init(
carry: RNGKey, unused: Any
) -> Tuple[RNGKey, CMAEmitterState]:
random_key = carry
emitter_state, random_key = self._emitter.init(genotypes, random_key)
emitter_state, random_key = self._emitter.init(
random_key,
repertoire,
genotypes,
fitnesses,
descriptors,
extra_scores,
)
return random_key, emitter_state

# init all the emitter states
Expand Down Expand Up @@ -117,11 +124,11 @@ def emit(
)

# use it to emit offsprings
offsprings, random_key = self._emitter.emit(
offsprings, extra_info, random_key = self._emitter.emit(
repertoire, used_emitter_state, random_key
)

return offsprings, {}, random_key
return offsprings, extra_info, random_key

@partial(
jax.jit,
Expand Down
2 changes: 1 addition & 1 deletion tests/baselines_test/me_pbt_sac_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def scoring_function(genotypes, random_key): # type: ignore
lambda x: jnp.repeat(x, population_size, axis=0), first_states
)
population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)
return population_returns, population_bds, None, random_key
return population_returns, population_bds, {}, random_key

# Get minimum reward value to make sure qd_score are positive
reward_offset = environments.reward_offset[env_name]
Expand Down

0 comments on commit ec90733

Please sign in to comment.