diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index 3d90144e..fed716e3 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -199,7 +199,7 @@ def update( repertoire: unstructured repertoire emitter_state: state of the emitter random_key: a jax PRNG random key - aurora_extra_info: extra info for the encoding # TODO + aurora_extra_info: extra info for computing encodings Results: the updated MAP-Elites repertoire diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index a6a495a0..af1d51ba 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -42,10 +42,27 @@ def get_feet_contact_proportion(data: QDTransition, mask: jnp.ndarray) -> Descri class AuroraExtraInfo(flax.struct.PyTreeNode): + """ + Information specific to the AURORA algorithm. + + Args: + model_params: the parameters of the dimensionality reduction model + """ + model_params: Params class AuroraExtraInfoNormalization(AuroraExtraInfo): + """ + Information specific to the AURORA algorithm. In particular, it contains + the normalization parameters for the observations. + + Args: + model_params: the parameters of the dimensionality reduction model + mean_observations: the mean of observations + std_observations: the std of observations + """ + mean_observations: jnp.ndarray std_observations: jnp.ndarray