Skip to content

Commit

Permalink
Fix PGA-MAP-Elites
Browse files Browse the repository at this point in the history
  • Loading branch information
maxencefaldor committed Dec 29, 2023
1 parent 5fb2b65 commit dd729cf
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions qdax/core/emitters/qpg_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import optax
from jax import numpy as jnp

-from qdax.core.containers.repertoire import MapElitesRepertoire
+from qdax.core.containers.repertoire import Repertoire
from qdax.core.emitters.emitter import Emitter, EmitterState
from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer
from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_fn
Expand Down Expand Up @@ -121,7 +121,7 @@ def use_all_data(self) -> bool:
def init(
self,
random_key: RNGKey,
-repertoire: MapElitesRepertoire,
+repertoire: Repertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
Expand Down Expand Up @@ -168,6 +168,13 @@ def init(
buffer_size=self._config.replay_buffer_size, transition=dummy_transition
)

+# get the transitions out of the dictionary
+assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key"
+transitions = extra_scores["transitions"]
+
+# add transitions in the replay buffer
+replay_buffer = replay_buffer.insert(transitions)
+
# Initial training state
random_key, subkey = jax.random.split(random_key)
emitter_state = QualityPGEmitterState(
Expand All @@ -177,9 +184,9 @@ def init(
actor_opt_state=actor_optimizer_state,
target_critic_params=target_critic_params,
target_actor_params=target_actor_params,
replay_buffer=replay_buffer,
random_key=subkey,
steps=jnp.array(0),
replay_buffer=replay_buffer,
)

return emitter_state, random_key
Expand All @@ -190,7 +197,7 @@ def init(
)
def emit(
self,
-repertoire: MapElitesRepertoire,
+repertoire: Repertoire,
emitter_state: QualityPGEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
Expand Down Expand Up @@ -279,7 +286,7 @@ def emit_actor(self, emitter_state: QualityPGEmitterState) -> Genotype:
def state_update(
self,
emitter_state: QualityPGEmitterState,
-repertoire: Optional[MapElitesRepertoire],
+repertoire: Optional[Repertoire],
genotypes: Optional[Genotype],
fitnesses: Optional[Fitness],
descriptors: Optional[Descriptor],
Expand Down

0 comments on commit dd729cf

Please sign in to comment.