diff --git a/qdax/core/containers/mome_repertoire.py b/qdax/core/containers/mome_repertoire.py index 43be3835..58a089a6 100644 --- a/qdax/core/containers/mome_repertoire.py +++ b/qdax/core/containers/mome_repertoire.py @@ -178,9 +178,6 @@ def _update_masked_pareto_front( pareto_front_len = pareto_front_fitnesses.shape[0] # type: ignore - first_leaf = jax.tree_util.tree_leaves(new_batch_of_genotypes)[0] - genotypes_dim = first_leaf.shape[1] - descriptors_dim = new_batch_of_descriptors.shape[1] # gather all data @@ -235,14 +232,11 @@ def _update_masked_pareto_front( front_size = len(pareto_front_fitnesses) # type: ignore new_front_fitness = new_front_fitness[:front_size, :] - genotypes_mask = jnp.repeat( - jnp.expand_dims(new_mask, axis=-1), genotypes_dim, axis=-1 - ) new_front_genotypes = jax.tree_util.tree_map( - lambda x: x * genotypes_mask, new_front_genotypes + lambda x: x * new_mask_indices[0], new_front_genotypes ) new_front_genotypes = jax.tree_util.tree_map( - lambda x: x[:front_size, :], new_front_genotypes + lambda x: x[:front_size], new_front_genotypes ) descriptors_mask = jnp.repeat( @@ -297,25 +291,31 @@ def _add_one( index = index.astype(jnp.int32) - # get cell data - cell_genotype = jax.tree_util.tree_map(lambda x: x[index], carry.genotypes) - cell_fitness = carry.fitnesses[index] - cell_descriptor = carry.descriptors[index] + # get current repertoire cell data + cell_genotype = jax.tree_util.tree_map( + lambda x: x[index][0], carry.genotypes + ) + cell_fitness = carry.fitnesses[index][0] + cell_descriptor = carry.descriptors[index][0] cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1) + new_genotypes = jax.tree_util.tree_map( + lambda x: jnp.expand_dims(x, axis=0), genotype + ) + # update pareto front ( cell_fitness, - cell_genotype, + cell_genotype, # new pf for cell cell_descriptor, cell_mask, ) = self._update_masked_pareto_front( - pareto_front_fitnesses=cell_fitness.squeeze(axis=0), - pareto_front_genotypes=cell_genotype.squeeze(axis=0), - pareto_front_descriptors=cell_descriptor.squeeze(axis=0), - mask=cell_mask.squeeze(axis=0), + pareto_front_fitnesses=cell_fitness, + pareto_front_genotypes=cell_genotype, + pareto_front_descriptors=cell_descriptor, + mask=cell_mask, new_batch_of_fitnesses=jnp.expand_dims(fitness, axis=0), - new_batch_of_genotypes=jnp.expand_dims(genotype, axis=0), + new_batch_of_genotypes=new_genotypes, new_batch_of_descriptors=jnp.expand_dims(descriptors, axis=0), new_mask=jnp.zeros(shape=(1,), dtype=bool), )