From dd729cfb081fe26ceae2ba813bc5ba43ed7fc6c8 Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Fri, 29 Dec 2023 17:50:56 +0000 Subject: [PATCH] Fix PGA-MAP-Elites --- qdax/core/emitters/qpg_emitter.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index d43827df..a0b5c62d 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -12,7 +12,7 @@ import optax from jax import numpy as jnp -from qdax.core.containers.repertoire import MapElitesRepertoire +from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_fn @@ -121,7 +121,7 @@ def use_all_data(self) -> bool: def init( self, random_key: RNGKey, - repertoire: MapElitesRepertoire, + repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, @@ -168,6 +168,13 @@ def init( buffer_size=self._config.replay_buffer_size, transition=dummy_transition ) + # get the transitions out of the dictionary + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + + # add transitions in the replay buffer + replay_buffer = replay_buffer.insert(transitions) + # Initial training state random_key, subkey = jax.random.split(random_key) emitter_state = QualityPGEmitterState( @@ -177,9 +184,9 @@ def init( actor_opt_state=actor_optimizer_state, target_critic_params=target_critic_params, target_actor_params=target_actor_params, + replay_buffer=replay_buffer, random_key=subkey, steps=jnp.array(0), - replay_buffer=replay_buffer, ) return emitter_state, random_key @@ -190,7 +197,7 @@ def init( ) def emit( self, - repertoire: MapElitesRepertoire, + repertoire: Repertoire, emitter_state: QualityPGEmitterState, random_key: RNGKey, ) -> Tuple[Genotype, RNGKey]: @@ -279,7 +286,7 @@ def emit_actor(self, emitter_state: QualityPGEmitterState) -> Genotype: def state_update( self, emitter_state: QualityPGEmitterState, - repertoire: Optional[MapElitesRepertoire], + repertoire: Optional[Repertoire], genotypes: Optional[Genotype], fitnesses: Optional[Fitness], descriptors: Optional[Descriptor],