From 1aba76e19b3bbf26ec23aed5acdd7eb105fb9c22 Mon Sep 17 00:00:00 2001 From: Felix Date: Mon, 27 Feb 2023 20:21:57 +0200 Subject: [PATCH 01/26] WIP: add all aurora components - passes tests --- qdax/core/aurora.py | 160 ++++ .../containers/unstructured_repertoire.py | 754 ++++++++++++++++++ qdax/environments/bd_extractors.py | 69 +- qdax/environments/exploration_wrappers.py | 4 +- qdax/tasks/brax_envs.py | 62 +- qdax/utils/seq2seq_model.py | 212 +++++ qdax/utils/train_seq2seq.py | 244 ++++++ requirements.txt | 2 +- setup.py | 2 +- tests/core_test/aurora_test.py | 309 +++++++ 10 files changed, 1813 insertions(+), 5 deletions(-) create mode 100644 qdax/core/aurora.py create mode 100644 qdax/core/containers/unstructured_repertoire.py create mode 100644 qdax/utils/seq2seq_model.py create mode 100644 qdax/utils/train_seq2seq.py create mode 100644 tests/core_test/aurora_test.py diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py new file mode 100644 index 00000000..51766635 --- /dev/null +++ b/qdax/core/aurora.py @@ -0,0 +1,160 @@ +"""Core class of the AURORA algorithm.""" + +from __future__ import annotations + +from functools import partial +from typing import Callable, Optional, Tuple + +import jax +import jax.numpy as jnp +from chex import ArrayTree + +from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire +from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire +from qdax.core.emitters.emitter import Emitter, EmitterState +from qdax.types import Centroid, Descriptor, Fitness, Genotype, Metrics, Params, RNGKey + + +class AURORA: + """ + Core elements of the AURORA algorithm. + + Args: + scoring_function: a function that takes a batch of genotypes and compute + their fitnesses and descriptors + emitter: an emitter is used to suggest offsprings given a MAPELites + repertoire. It has two compulsory functions. A function that takes + emits a new population, and a function that update the internal state + of the emitter. + metrics_function: a function that takes a MAP-Elites repertoire and compute + any useful metric to track its evolution + """ + + def __init__( + self, + scoring_function: Callable[ + [Genotype, RNGKey, Params, jnp.ndarray, jnp.ndarray], + Tuple[Fitness, Descriptor, ArrayTree, RNGKey], + ], + emitter: Emitter, + metrics_function: Callable[[MapElitesRepertoire], Metrics], + ) -> None: + self._scoring_function = scoring_function + self._emitter = emitter + self._metrics_function = metrics_function + + @partial(jax.jit, static_argnames=("self",)) + def init( + self, + init_genotypes: Genotype, + centroids: Centroid, + random_key: RNGKey, + model_params: Params, + mean_observations: jnp.ndarray, + std_observations: jnp.ndarray, + l_value: jnp.ndarray, + ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: + """ + Initialize a Map-Elites grid with an initial population of genotypes. Requires + the definition of centroids that can be computed with any method such as + CVT or Euclidean mapping. + + Args: + init_genotypes: initial genotypes, pytree in which leaves + have shape (batch_size, num_features) + centroids: tesselation centroids of shape (batch_size, num_descriptors) + random_key: a random key used for stochastic operations. + + Returns: + an initialized MAP-Elite repertoire with the initial state of the emitter. 
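+            The updated random key is also returned alongside them.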
+ """ + fitnesses, descriptors, extra_scores, random_key = self._scoring_function( + init_genotypes, + random_key, + model_params, + mean_observations, + std_observations, + ) + + repertoire = UnstructuredRepertoire.init( + genotypes=init_genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + observations=extra_scores["last_valid_observations"], # type: ignore + l_value=l_value, + ) + # get initial state of the emitter + emitter_state, random_key = self._emitter.init( + init_genotypes=init_genotypes, random_key=random_key + ) + + # update emitter state + emitter_state = self._emitter.state_update( + emitter_state=emitter_state, + genotypes=init_genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, + ) + + return repertoire, emitter_state, random_key + + @partial(jax.jit, static_argnames=("self",)) + def update( + self, + repertoire: MapElitesRepertoire, + emitter_state: Optional[EmitterState], + random_key: RNGKey, + model_params: Params, + mean_observations: jnp.ndarray, + std_observations: jnp.ndarray, + ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]: + """ + Performs one iteration of the MAP-Elites algorithm. + 1. A batch of genotypes is sampled in the archive and the genotypes are copied. + 2. The copies are mutated and crossed-over + 3. The obtained offsprings are scored and then added to the archive. + + Args: + repertoire: the MAP-Elites repertoire + emitter_state: state of the emitter + random_key: a jax PRNG random key + + Results: + the updated MAP-Elites repertoire + the updated (if needed) emitter state + metrics about the updated repertoire + a new jax PRNG key + """ + # generate offsprings with the emitter + genotypes, random_key = self._emitter.emit( + repertoire, emitter_state, random_key + ) + # scores the offsprings + fitnesses, descriptors, extra_scores, random_key = self._scoring_function( + genotypes, + random_key, + model_params, + mean_observations, + std_observations, + ) + + # add genotypes and observations in the repertoire + repertoire = repertoire.add( + genotypes, descriptors, fitnesses, extra_scores["last_valid_observations"] + ) + + # update emitter state after scoring is made + emitter_state = self._emitter.state_update( + emitter_state=emitter_state, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, + ) + + # update the metrics + metrics = self._metrics_function(repertoire) + + return repertoire, emitter_state, metrics, random_key diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py new file mode 100644 index 00000000..9c0fe15a --- /dev/null +++ b/qdax/core/containers/unstructured_repertoire.py @@ -0,0 +1,754 @@ +from __future__ import annotations + +from functools import partial +from typing import Callable, Optional, Tuple + +import flax +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree + +from qdax.types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + Observation, + RNGKey, +) + + +@partial(jax.jit, static_argnames=("k_nn",)) +def get_cells_indices( + batch_of_descriptors: jnp.ndarray, centroids: jnp.ndarray, k_nn: int +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Returns the array of cells indices for a batch of descriptors + given the centroids of the grid. 
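+    For each descriptor, the k_nn closest centroids (in Euclidean distance) are
+    returned together with the corresponding distances.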
+ + Args: + batch_of_descriptors: a batch of descriptors + of shape (batch_size, num_descriptors) + centroids: centroids array of shape (num_centroids, num_descriptors) + + Returns: + the indices of the centroids corresponding to each vector of descriptors + in the batch with shape (batch_size,) + """ + + def _get_cells_indices( + descriptors: jnp.ndarray, + centroids: jnp.ndarray, + k_nn: int, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + set_of_descriptors of shape (1, num_descriptors) + centroids of shape (num_centroids, num_descriptors) + """ + + # distances = jnp.sum(jnp.square(jnp.subtract(descriptors, centroids)), axis=-1) + distances = jax.vmap(jnp.linalg.norm)(descriptors - centroids) + ## Negating distances because we want the smallest ones + min_dist, min_args = jax.lax.top_k(-1 * distances, k_nn) + # return jnp.argmin(distances),jnp.min(distances) + return min_args, -1 * min_dist + + # func = jax.vmap(lambda x: _get_cells_indices(x, centroids,k_nn),in_axes=(0,None,None,)) + func = jax.vmap( + _get_cells_indices, + in_axes=( + 0, + None, + None, + ), + ) + + # return func(batch_of_descriptors) + return func(batch_of_descriptors, centroids, k_nn) + + +@jax.jit +def intra_batch_comp( + normed, + current_index, + normed_all, + eval_scores, + l_value, +): + + ## Check for individuals that are Nans, we remove them at the end + not_existent = jnp.where((jnp.isnan(normed)).any(), True, False) + ## Fill in Nans to do computations + normed = jnp.where(jnp.isnan(normed), jnp.full(normed.shape[-1], jnp.inf), normed) + eval_scores = jnp.where( + jnp.isinf(eval_scores), jnp.full(eval_scores.shape[-1], jnp.nan), eval_scores + ) + ## If we do not use a fitness (i.e same fitness everywhere, we create a virtual fitness function to add individuals with the same bd) + additional_score = jnp.where( + jnp.nanmax(eval_scores) == jnp.nanmin(eval_scores), 1.0, 0.0 + ) + additional_scores = jnp.linspace(0.0, additional_score, num=eval_scores.shape[0]) + ## Add scores to empty individuals + eval_scores = jnp.where( + jnp.isnan(eval_scores), jnp.full(eval_scores.shape[0], -jnp.inf), eval_scores + ) + ##Virtual eval_scores + eval_scores = eval_scores + additional_scores + ## For each point we check what other points are the closest ones. + knn_relevant_scores, knn_relevant_indices = jax.lax.top_k( + -1 * jax.vmap(jnp.linalg.norm)(normed - normed_all), eval_scores.shape[0] + ) + ## We negated the scores to use top_k so we reverse it. + knn_relevant_scores = knn_relevant_scores * -1 + + ##Check if the individual is close enough to compare (under l-value) + fitness = jnp.where(jnp.squeeze(knn_relevant_scores < l_value), True, False) + ## We want to eliminate the same individual (distance 0) + # fitness = jnp.where(knn_relevant_scores==0.0,False,fitness) + fitness = jnp.where(knn_relevant_indices == current_index, False, fitness) + current_fitness = jnp.squeeze( + eval_scores.at[knn_relevant_indices.at[0].get()].get() + ) + + ## Is the fitness of the other individual higher? + ## If both are True then we discard the current individual since this individual would be replaced by the better one. 
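+    # Concretely: an individual is discarded as soon as any other point of the
+    # batch lies within l_value of it (and is not the individual itself) and has
+    # a strictly higher virtual fitness.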
+ discard_indiv = jnp.logical_and( + jnp.where( + eval_scores.at[knn_relevant_indices].get() > current_fitness, True, False + ), + fitness, + ).any() + ## Discard Individuals with Nans as their BD (mainly for the readdition where we have NaN bds) + discard_indiv = jnp.logical_or(discard_indiv, not_existent) + + ## Negate to know if we keep the individual + return jnp.logical_not(discard_indiv) + + +@jax.jit +def intra_batch_comp_relevant( + normed, + current_index, + normed_all, + eval_scores, + relevant_l_values, +): + + ## Check for individuals that are Nans, we remove them at the end + not_existent = jnp.where((jnp.isnan(normed)).any(), True, False) + ## Fill in Nans to do computations + normed = jnp.where(jnp.isnan(normed), jnp.full(normed.shape[-1], jnp.inf), normed) + eval_scores = jnp.where( + jnp.isinf(eval_scores), jnp.full(eval_scores.shape[-1], jnp.nan), eval_scores + ) + ## If we do not use a fitness (i.e same fitness everywhere, we create a virtual fitness function to add individuals with the same bd) + additional_score = jnp.where( + jnp.nanmax(eval_scores) == jnp.nanmin(eval_scores), 1.0, 0.0 + ) + additional_scores = jnp.linspace(0.0, additional_score, num=eval_scores.shape[0]) + ## Add scores to empty individuals + eval_scores = jnp.where( + jnp.isnan(eval_scores), jnp.full(eval_scores.shape[0], -jnp.inf), eval_scores + ) + ##Virtual eval_scores + eval_scores = eval_scores + additional_scores + ## For each point we check what other points are the closest ones. + knn_relevant_scores, knn_relevant_indices = jax.lax.top_k( + -1 * jax.vmap(jnp.linalg.norm)(normed - normed_all), eval_scores.shape[0] + ) + ## We negated the scores to use top_k so we reverse it. + knn_relevant_scores = knn_relevant_scores * -1 + + ##Check if the individual is close enough to compare (under l-value) + fitness = jnp.where( + jnp.squeeze(knn_relevant_scores < relevant_l_values), True, False + ) + ## We want to eliminate the same individual (distance 0) + # fitness = jnp.where(knn_relevant_scores==0.0,False,fitness) + fitness = jnp.where(knn_relevant_indices == current_index, False, fitness) + current_fitness = jnp.squeeze( + eval_scores.at[knn_relevant_indices.at[0].get()].get() + ) + + ## Is the fitness of the other individual higher? + ## If both are True then we discard the current individual since this individual would be replaced by the better one. + discard_indiv = jnp.logical_and( + jnp.where( + eval_scores.at[knn_relevant_indices].get() > current_fitness, True, False + ), + fitness, + ).any() + ## Discard Individuals with Nans as their BD (mainly for the readdition where we have NaN bds) + discard_indiv = jnp.logical_or(discard_indiv, not_existent) + + ## Negate to know if we keep the individual + return jnp.logical_not(discard_indiv) + + +class UnstructuredRepertoire(flax.struct.PyTreeNode): + """ + Class for the unstructured repertoire in Map Elites. + + Args: + genotypes: a PyTree containing all the genotypes in the repertoire ordered + by the centroids. Each leaf has a shape (num_centroids, num_features). The + PyTree can be a simple Jax array or a more complex nested structure such + as to represent parameters of neural network in Flax. + fitnesses: an array that contains the fitness of solutions in each cell of the + repertoire, ordered by centroids. The array shape is (num_centroids,). + descriptors: an array that contains the descriptors of solutions in each cell + of the repertoire, ordered by centroids. The array shape + is (num_centroids, num_descriptors). 
+ centroids: an array the contains the centroids of the tesselation. The array + shape is (num_centroids, num_descriptors). + """ + + genotypes: Genotype + fitnesses: Fitness + descriptors: Descriptor + centroids: Centroid + observations: ExtraScores + ages: jnp.ndarray + l_value: jnp.ndarray + + def save(self, path: str = "./") -> None: + """Saves the grid on disk in the form of .npy files. + + Flattens the genotypes to store it with .npy format. Supposes that + a user will have access to the reconstruction function when loading + the genotypes. + + Args: + path: Path where the data will be saved. Defaults to "./". + """ + + def flatten_genotype(genotype: Genotype) -> jnp.ndarray: + flatten_genotype, _unravel_pytree = ravel_pytree(genotype) + return flatten_genotype + + # flatten all the genotypes + flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes) + + # save data + jnp.save(path + "genotypes.npy", flat_genotypes) + jnp.save(path + "fitnesses.npy", self.fitnesses) + jnp.save(path + "descriptors.npy", self.descriptors) + jnp.save(path + "centroids.npy", self.centroids) + jnp.save(path + "observations.npy", self.observations) + jnp.save(path + "l_value.npy", self.l_value) + jnp.save(path + "ages.npy", self.ages) + + @classmethod + def load( + cls, reconstruction_fn: Callable, path: str = "./" + ) -> UnstructuredRepertoire: + """Loads a MAP Elites Grid. + + Args: + reconstruction_fn: Function to reconstruct a PyTree + from a flat array. + path: Path where the data is saved. Defaults to "./". + + Returns: + A MAP Elites Repertoire. + """ + + flat_genotypes = jnp.load(path + "genotypes.npy") + genotypes = jax.vmap(reconstruction_fn)(flat_genotypes) + + fitnesses = jnp.load(path + "fitnesses.npy") + descriptors = jnp.load(path + "descriptors.npy") + centroids = jnp.load(path + "centroids.npy") + observations = jnp.load(path + "observations.npy") + l_value = jnp.load(path + "l_value.npy") + ages = jnp.load(path + "ages.npy") + + return UnstructuredRepertoire( + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + observations=observations, + l_value=l_value, + ages=ages, + ) + + @jax.jit + def add( + self, + batch_of_genotypes: Genotype, + batch_of_descriptors: Descriptor, + batch_of_fitnesses: Fitness, + batch_of_observations: Observation, + ) -> UnstructuredRepertoire: + + ## We need to replace all the descriptors that are not filled with jnp inf + filtered_descriptors = jnp.where( + jnp.expand_dims((self.fitnesses == -jnp.inf), axis=-1), + jnp.full(self.descriptors.shape[-1], fill_value=jnp.inf), + self.descriptors, + ) + + batch_of_indices, batch_of_distances = get_cells_indices( + batch_of_descriptors, filtered_descriptors, 2 + ) + + second_neighbours = batch_of_distances.at[ + ..., 1 + ].get() # Save the second nearest neighbours to check a condition + batch_of_indices = batch_of_indices.at[ + ..., 0 + ].get() ## Keep the Nearest neighbours + batch_of_distances = batch_of_distances.at[ + ..., 0 + ].get() ## Keep the Nearest neighbours + + # We remove individuals that are too close to the second nn. + # This avoids having clusters of individuals after adding them. 
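+        # An incoming individual is flagged as not novel enough when even its
+        # second nearest descriptor in the repertoire is already within l_value.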
+ not_novel_enough = jnp.where( + jnp.squeeze(second_neighbours <= self.l_value), True, False + ) + + # batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1) + batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1) + batch_of_observations = jnp.expand_dims(batch_of_observations, axis=-1) + + num_centroids = self.centroids.shape[0] + + ### TODO Doesn't Work if Archive is full. Need to use the closest individuals in that case. + empty_indexes = jnp.squeeze( + jnp.nonzero( + jnp.where(jnp.isinf(self.fitnesses), 1, 0), + size=batch_of_indices.shape[0], + fill_value=-1, + )[0] + ) + batch_of_indices = jnp.where( + jnp.squeeze(batch_of_distances <= self.l_value), + jnp.squeeze(batch_of_indices), + -1, + ) + + sorted_bds = jax.lax.top_k( + -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0] + )[ + 1 + ] ## We get all the indices of the empty bds first and then the filled ones (because of -1) + batch_of_indices = jnp.where( + jnp.squeeze(batch_of_distances.at[sorted_bds].get() <= self.l_value), + batch_of_indices.at[sorted_bds].get(), + empty_indexes, + ) + + batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1) + + ## ReIndexing of all the inputs to the correct sorted way + batch_of_distances = batch_of_distances.at[sorted_bds].get() + batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get() + batch_of_genotypes = jax.tree_map( + lambda x: x.at[sorted_bds].get(), batch_of_genotypes + ) + # obs = obs.at[sorted_bds].get() + batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get() + batch_of_observations = batch_of_observations.at[sorted_bds].get() + not_novel_enough = not_novel_enough.at[sorted_bds].get() + # dead = dead.at[sorted_bds].get() + + ## Check to find Individuals with same BD within the Batch + keep_indiv = jax.jit( + jax.vmap(intra_batch_comp, in_axes=(0, 0, None, None, None), out_axes=(0)) + )( + batch_of_descriptors.squeeze(), + jnp.arange( + 0, batch_of_descriptors.shape[0], 1 + ), ## We do this to keep track of where we are in the batch to assure right comparisons + batch_of_descriptors.squeeze(), + batch_of_fitnesses.squeeze(), + self.l_value, + ) + + keep_indiv = jnp.logical_and(keep_indiv, jnp.logical_not(not_novel_enough)) + + # keep_indiv = jax.vmap(intra_batch_comp, in_axes=(0,0,None,None,None), out_axes=(0))(batch_of_descriptors.squeeze(),jnp.arange(0,batch_of_descriptors.shape[0],1),batch_of_descriptors.squeeze(),batch_of_fitnesses.squeeze(),self.l_value) + # get fitness segment max + best_fitnesses = jax.ops.segment_max( + batch_of_fitnesses, + batch_of_indices.astype(jnp.int32).squeeze(), + num_segments=num_centroids, + ) + + cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0) + + # put dominated fitness to -jnp.inf + batch_of_fitnesses = jnp.where( + batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf + ) + + # get addition condition + grid_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1) + current_fitnesses = jnp.take_along_axis(grid_fitnesses, batch_of_indices, 0) + addition_condition = batch_of_fitnesses > current_fitnesses + addition_condition = jnp.logical_and( + addition_condition, jnp.expand_dims(keep_indiv, axis=-1) + ) + + # assign fake position when relevant : num_centroids is out of bounds + batch_of_indices = jnp.where( + addition_condition, x=batch_of_indices, y=num_centroids + ) + + # create new grid + new_grid_genotypes = jax.tree_map( + lambda grid_genotypes, new_genotypes: grid_genotypes.at[ + batch_of_indices.squeeze() + ].set(new_genotypes), + self.genotypes, + 
batch_of_genotypes, + ) + + # compute new fitness and descriptors + new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze()].set( + batch_of_fitnesses.squeeze() + ) + new_descriptors = self.descriptors.at[batch_of_indices.squeeze()].set( + batch_of_descriptors.squeeze() + ) + + new_observations = self.observations.at[batch_of_indices.squeeze()].set( + batch_of_observations.squeeze() + ) + + new_ages = self.ages.at[batch_of_indices.squeeze()].set(0.0) + 1 + + return UnstructuredRepertoire( + genotypes=new_grid_genotypes, + fitnesses=new_fitnesses.squeeze(), + descriptors=new_descriptors.squeeze(), + centroids=new_descriptors.squeeze(), + observations=new_observations.squeeze(), + l_value=self.l_value, + ages=new_ages, + ) + + @partial(jax.jit, static_argnames=("num_samples",)) + def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]: + """ + Sample elements in the grid. + + Args: + random_key: a jax PRNG random key + num_samples: the number of elements to be sampled + + Returns: + samples: a batch of genotypes sampled in the repertoire + random_key: an updated jax PRNG random key + """ + + random_key, sub_key = jax.random.split(random_key) + grid_empty = self.fitnesses == -jnp.inf + p = (1.0 - grid_empty) / jnp.sum(1.0 - grid_empty) + + samples = jax.tree_map( + lambda x: jax.random.choice(sub_key, x, shape=(num_samples,), p=p), + self.genotypes, + ) + + return samples, random_key + + @classmethod + def init( + cls, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + centroids: Centroid, + observations: ExtraScores, + l_value: jnp.ndarray, + ages: Optional[jnp.ndarray] = None, + ) -> UnstructuredRepertoire: + """ + Initialize a Map-Elites repertoire with an initial population of genotypes. + Requires the definition of centroids that can be computed with any method + such as CVT or Euclidean mapping. + + Note: this function has been kept outside of the object MapElites, so it can + be called easily called from other modules. 
+ + Args: + genotypes: initial genotypes, pytree in which leaves + have shape (batch_size, num_features) + fitnesses: fitness of the initial genotypes of shape (batch_size,) + descriptors: descriptors of the initial genotypes + of shape (batch_size, num_descriptors) + centroids: tesselation centroids of shape (batch_size, num_descriptors) + + Returns: + an initialized MAP-Elite repertoire + """ + + # Initialize grid with default values + num_centroids = centroids.shape[0] + default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) + default_genotypes = jax.tree_map( + lambda x: jnp.full( + shape=(num_centroids,) + x.shape[1:], fill_value=jnp.nan + ), + genotypes, + ) + default_descriptors = jnp.zeros(shape=(num_centroids, centroids.shape[-1])) + + default_observations = jnp.full( + shape=(num_centroids,) + observations.shape[1:], fill_value=jnp.nan + ) + + if ages is None: + ages = jnp.zeros(shape=num_centroids) + + repertoire = UnstructuredRepertoire( + genotypes=default_genotypes, + fitnesses=default_fitnesses, + descriptors=default_descriptors, + centroids=centroids, + observations=default_observations, + l_value=l_value, + ages=ages, + ) + + # return new_repertoire # type: ignore + return repertoire.add(genotypes, descriptors, fitnesses, observations) + + @jax.jit + def add_relevant( + self, + batch_of_genotypes: Genotype, + batch_of_descriptors: Descriptor, + batch_of_fitnesses: Fitness, + batch_of_observations: Observation, + proximity_scores: jnp.ndarray, + ) -> UnstructuredRepertoire: + + # Calculating new l values + new_l_values = self.l_value / proximity_scores + + # We need to replace all the descriptors that are not filled with jnp inf + filtered_descriptors = jnp.where( + jnp.expand_dims((self.fitnesses == -jnp.inf), axis=-1), + jnp.full(self.descriptors.shape[-1], fill_value=jnp.inf), + self.descriptors, + ) + + batch_of_indices, batch_of_distances = get_cells_indices( + batch_of_descriptors, filtered_descriptors, 2 + ) + + second_neighbours = batch_of_distances.at[ + ..., 1 + ].get() # Save the second nearest neighbours to check a condition + batch_of_indices = batch_of_indices.at[ + ..., 0 + ].get() ## Keep the Nearest neighbours + batch_of_distances = batch_of_distances.at[ + ..., 0 + ].get() ## Keep the Nearest neighbours + + # We remove individuals that are too close to the second nn. + # This avoids having clusters of individuals after adding them. + not_novel_enough = jnp.where( + jnp.squeeze(second_neighbours <= new_l_values), True, False + ) + + # batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1) + batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1) + batch_of_observations = jnp.expand_dims(batch_of_observations, axis=-1) + + num_centroids = self.centroids.shape[0] + + ### TODO Doesn't Work if Archive is full. Need to use the closest individuals in that case. 
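+        # Indices of currently empty cells (fitness still at -inf); individuals
+        # whose nearest descriptor is further away than their l-value threshold
+        # are inserted into these empty slots below.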
+ empty_indexes = jnp.squeeze( + jnp.nonzero( + jnp.where(jnp.isinf(self.fitnesses), 1, 0), + size=batch_of_indices.shape[0], + fill_value=-1, + )[0] + ) + batch_of_indices = jnp.where( + jnp.squeeze(batch_of_distances <= new_l_values), + jnp.squeeze(batch_of_indices), + -1, + ) + + sorted_bds = jax.lax.top_k( + -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0] + )[ + 1 + ] ## We get all the indices of the empty bds first and then the filled ones (because of -1) + batch_of_indices = jnp.where( + jnp.squeeze( + batch_of_distances.at[sorted_bds].get() + <= new_l_values.at[sorted_bds].get() + ), + batch_of_indices.at[sorted_bds].get(), + empty_indexes, + ) + + batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1) + + ## ReIndexing of all the inputs to the correct sorted way + batch_of_distances = batch_of_distances.at[sorted_bds].get() + batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get() + batch_of_genotypes = jax.tree_map( + lambda x: x.at[sorted_bds].get(), batch_of_genotypes + ) + # obs = obs.at[sorted_bds].get() + batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get() + batch_of_observations = batch_of_observations.at[sorted_bds].get() + not_novel_enough = not_novel_enough.at[sorted_bds].get() + new_l_values = new_l_values.at[sorted_bds].get() + # dead = dead.at[sorted_bds].get() + + # filtered_l = jnp.where(new_l_values>self.l_value,self.l_value,new_l_values) + ## Check to find Individuals with same BD within the Batch + keep_indiv = jit( + jax.vmap(intra_batch_comp, in_axes=(0, 0, None, None, 0), out_axes=(0)) + )( + batch_of_descriptors.squeeze(), + jnp.arange( + 0, batch_of_descriptors.shape[0], 1 + ), ## We do this to keep track of where we are in the batch to assure right comparisons + batch_of_descriptors.squeeze(), + batch_of_fitnesses.squeeze(), + new_l_values, + ) + + # keep_indiv = jnp.logical_and(keep_indiv,jnp.logical_not(not_novel_enough)) + + # keep_indiv = jax.vmap(intra_batch_comp, in_axes=(0,0,None,None,None), out_axes=(0))(batch_of_descriptors.squeeze(),jnp.arange(0,batch_of_descriptors.shape[0],1),batch_of_descriptors.squeeze(),batch_of_fitnesses.squeeze(),self.l_value) + # get fitness segment max + best_fitnesses = jax.ops.segment_max( + batch_of_fitnesses, + batch_of_indices.astype(jnp.int32).squeeze(), + num_segments=num_centroids, + ) + + cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0) + + # put dominated fitness to -jnp.inf + batch_of_fitnesses = jnp.where( + batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf + ) + + # get addition condition + grid_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1) + current_fitnesses = jnp.take_along_axis(grid_fitnesses, batch_of_indices, 0) + addition_condition = batch_of_fitnesses > current_fitnesses + addition_condition = jnp.logical_and( + addition_condition, jnp.expand_dims(keep_indiv, axis=-1) + ) + print(addition_condition) + print(batch_of_indices) + print(batch_of_descriptors) + print(batch_of_distances) + print(new_l_values) + + # assign fake position when relevant : num_centroids is out of bounds + batch_of_indices = jnp.where( + addition_condition, x=batch_of_indices, y=num_centroids + ) + + # create new grid + new_grid_genotypes = jax.tree_map( + lambda grid_genotypes, new_genotypes: grid_genotypes.at[ + batch_of_indices.squeeze() + ].set(new_genotypes), + self.genotypes, + batch_of_genotypes, + ) + + # compute new fitness and descriptors + new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze()].set( + batch_of_fitnesses.squeeze() + 
) + new_descriptors = self.descriptors.at[batch_of_indices.squeeze()].set( + batch_of_descriptors.squeeze() + ) + + new_observations = self.observations.at[batch_of_indices.squeeze()].set( + batch_of_observations.squeeze() + ) + + new_ages = self.ages.at[batch_of_indices.squeeze()].set(0.0) + 1 + + return UnstructuredRepertoire( + genotypes=new_grid_genotypes, + fitnesses=new_fitnesses.squeeze(), + descriptors=new_descriptors.squeeze(), + centroids=new_descriptors.squeeze(), + observations=new_observations.squeeze(), + l_value=self.l_value, + ages=new_ages, + ) + + @classmethod + def init_relevant( + cls, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + centroids: Centroid, + observations: ExtraScores, + l_value: float, + proximity_scores: jnp.ndarray, + ages: Optional[jnp.ndarray] = None, + ) -> UnstructuredRepertoire: + """ + Initialize a Map-Elites repertoire with an initial population of genotypes. + Requires the definition of centroids that can be computed with any method + such as CVT or Euclidean mapping. + + Note: this function has been kept outside of the object MapElites, so it can + be called easily called from other modules. + + Args: + genotypes: initial genotypes, pytree in which leaves + have shape (batch_size, num_features) + fitnesses: fitness of the initial genotypes of shape (batch_size,) + descriptors: descriptors of the initial genotypes + of shape (batch_size, num_descriptors) + centroids: tesselation centroids of shape (batch_size, num_descriptors) + + Returns: + an initialized MAP-Elite repertoire + """ + + # Initialize grid with default values + num_centroids = centroids.shape[0] + default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) + default_genotypes = jax.tree_map( + lambda x: jnp.full( + shape=(num_centroids,) + x.shape[1:], fill_value=jnp.nan + ), + genotypes, + ) + default_descriptors = jnp.zeros(shape=(num_centroids, centroids.shape[-1])) + + default_observations = jnp.full( + shape=(num_centroids,) + observations.shape[1:], fill_value=jnp.nan + ) + + if ages is None: + ages = jnp.zeros(shape=num_centroids) + repertoire = UnstructuredRepertoire( + genotypes=default_genotypes, + fitnesses=default_fitnesses, + descriptors=default_descriptors, + centroids=centroids, + observations=default_observations, + l_value=l_value, + ages=ages, + ) + + # return new_repertoire + return repertoire.add_relevant( + genotypes, descriptors, fitnesses, observations, proximity_scores + ) diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index 4ec159c6..e81348a1 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -2,7 +2,8 @@ import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.types import Descriptor +from qdax.types import Descriptor, Params +from qdax.utils import train_seq2seq def get_final_xy_position(data: QDTransition, mask: jnp.ndarray) -> Descriptor: @@ -36,3 +37,69 @@ def get_feet_contact_proportion(data: QDTransition, mask: jnp.ndarray) -> Descri descriptors = descriptors / jnp.sum(1.0 - mask, axis=1) return descriptors + + +def get_aurora_bd( + data: QDTransition, + mask: jnp.ndarray, + model_params: Params, + mean_observations: jnp.ndarray, + std_observations: jnp.ndarray, + option: str = "full", + hidden_size: int = 10, + padding: bool = False, +) -> Descriptor: + """Compute final aurora embedding. 
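+
+    The (subsampled) trajectory of observations is normalised and passed through
+    the encoder of the provided seq2seq model; both the resulting descriptors and
+    the observations used to compute them are returned.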
+ + This function suppose that state descriptor is the xy position, as it + just select the final one of the state descriptors given. + """ + # reshape mask for bd extraction + mask = jnp.expand_dims(mask, axis=-1) + + print("Mask: ", mask) + + # Get behavior descriptor + last_index = jnp.int32(jnp.sum(1.0 - mask, axis=1)) - 1 + + ## Doesn't Make Sense to take last valid Observation for Aurora, we take the full trajectory + # observations = jax.vmap(lambda x, y: x[y,:])(data.obs[:,::10,:15], last_index)) + + # TODO: try with all observations + # TODO: try with a padding + + state_obs = data.obs[:, ::10, :25] + filtered_mask = mask[:, ::10, :] + + # add the x/y position - (batch_size, traj_length, 2) + state_desc = data.state_desc[:, ::10] + + print("State Observations: ", state_obs) + print("XY positions: ", state_desc) + + if option == "full": + observations = jnp.concatenate([state_desc, state_obs], axis=-1) + print("New observations: ", observations) + elif option == "no_sd": + observations = state_obs + elif option == "only_sd": + observations = state_desc + + # add padding when the episode is done + if padding: + observations = jnp.where(filtered_mask, x=jnp.array(0.0), y=observations) + + # print("Observation: ", observations) + # print("Padded observation: ", padded_observations) + + model = train_seq2seq.get_model( + observations.shape[-1], True, hidden_size + ) ## lstm seq2seq + normalized_observations = (observations - mean_observations) / std_observations + descriptors = model.apply( + {"params": model_params}, normalized_observations, method=model.encode + ) + + print("Observations out of get aurora bd: ", observations) + + return descriptors.squeeze(), observations.squeeze() diff --git a/qdax/environments/exploration_wrappers.py b/qdax/environments/exploration_wrappers.py index b6754c28..80635960 100644 --- a/qdax/environments/exploration_wrappers.py +++ b/qdax/environments/exploration_wrappers.py @@ -89,7 +89,9 @@ # those are the configs from the official brax repo ENV_SYSTEM_CONFIG = { "ant": brax.envs.ant._SYSTEM_CONFIG, - "halfcheetah": brax.envs.half_cheetah._SYSTEM_CONFIG, + "halfcheetah": brax.envs.halfcheetah._SYSTEM_CONFIG + if brax.__version__ == "0.0.12" + else brax.envs.half_cheetah._SYSTEM_CONFIG, "walker2d": brax.envs.walker2d._SYSTEM_CONFIG, "hopper": brax.envs.hopper._SYSTEM_CONFIG, # "humanoid": brax.envs.humanoid._SYSTEM_CONFIG, diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 101d4a39..d94d49c1 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -1,6 +1,6 @@ import functools from functools import partial -from typing import Callable, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple, Union import brax.envs import flax.linen as nn @@ -338,3 +338,63 @@ def create_default_brax_task_components( ) return env, policy_network, scoring_fn, random_key + + +def scoring_aurora_function( + policies_params: Genotype, + random_key: RNGKey, + model_params: Params, + mean_observations: jnp.ndarray, + std_observations: jnp.ndarray, + init_states: brax.envs.State, + episode_length: int, + play_step_fn: Callable[ + [EnvState, Params, RNGKey], + Tuple[EnvState, Params, RNGKey, QDTransition], + ], + behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, Dict[str, Union[jnp.ndarray, QDTransition]], RNGKey]: + """Evaluates policies contained in flatten_variables in parallel + + This rollout is only deterministic when all the init states are the same. 
+ If the init states are fixed but different, as a policy is not necessarly + evaluated with the same environment everytime, this won't be determinist. + + When the init states are different, this is not purely stochastic. This + choice was made for performance reason, as the reset function of brax envs + is quite time consuming. If pure stochasticity of the environment is needed + for a use case, please open an issue. + + """ + + # Perform rollouts with each policy + random_key, subkey = jax.random.split(random_key) + unroll_fn = partial( + generate_unroll, + episode_length=episode_length, + play_step_fn=play_step_fn, + random_key=subkey, + ) + + _final_state, data = jax.vmap(unroll_fn)(init_states, policies_params) + + # create a mask to extract data properly + is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) + mask = jnp.roll(is_done, 1, axis=1) + mask = mask.at[:, 0].set(0) + + # scores - add offset to ensure positive fitness (through positive rewards) + fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) + descriptors, observations = behavior_descriptor_extractor( + data, mask, model_params, mean_observations, std_observations + ) + + return ( + fitnesses, + descriptors, + { + "transitions": data, + "last_valid_observations": observations, + }, + random_key, + ) diff --git a/qdax/utils/seq2seq_model.py b/qdax/utils/seq2seq_model.py new file mode 100644 index 00000000..0747bbb2 --- /dev/null +++ b/qdax/utils/seq2seq_model.py @@ -0,0 +1,212 @@ +# Copyright 2022 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""seq2seq example: Mode code.""" + +# See issue #620. +# pytype: disable=wrong-keyword-args + +import functools +from typing import Any, Tuple + +import jax +import jax.numpy as jnp +import numpy as np +from flax import linen as nn + +Array = Any +PRNGKey = Any + + +class EncoderLSTM(nn.Module): + """EncoderLSTM Module wrapped in a lifted scan transform.""" + + @functools.partial( + nn.scan, + variable_broadcast="params", + in_axes=1, + out_axes=1, + split_rngs={"params": False}, + ) + @nn.compact + def __call__( + self, carry: Tuple[Array, Array], x: Array + ) -> Tuple[Tuple[Array, Array], Array]: + """Applies the module.""" + lstm_state, is_eos = carry + new_lstm_state, y = nn.LSTMCell()(lstm_state, x) + + def select_carried_state(new_state, old_state): + return jnp.where(is_eos[:, np.newaxis], old_state, new_state) + + # LSTM state is a tuple (c, h). + carried_lstm_state = tuple( + select_carried_state(*s) for s in zip(new_lstm_state, lstm_state) + ) + # Update `is_eos`. + # is_eos = jnp.logical_or(is_eos, x[:, 8]) + return (carried_lstm_state, is_eos), y + + @staticmethod + def initialize_carry(batch_size: int, hidden_size: int): + # Use a dummy key since the default state init fn is just zeros. 
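+        # Note: the (rng, batch_dims, size) signature used here matches the
+        # flax>=0.6,<0.6.2 pin in setup.py; later flax releases changed the
+        # initialize_carry interface, so this call may need updating there.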
+ return nn.LSTMCell.initialize_carry( + jax.random.PRNGKey(0), (batch_size,), hidden_size + ) + + +class Encoder(nn.Module): + """LSTM encoder, returning state after finding the EOS token in the input.""" + + hidden_size: int + + @nn.compact + def __call__(self, inputs: Array): + # inputs.shape = (batch_size, seq_length, vocab_size). + batch_size = inputs.shape[0] + lstm = EncoderLSTM(name="encoder_lstm") + init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size) + # We use the `is_eos` array to determine whether the encoder should carry + # over the last lstm state, or apply the LSTM cell on the previous state. + init_is_eos = jnp.zeros(batch_size, dtype=bool) + init_carry = (init_lstm_state, init_is_eos) + (final_state, _), _ = lstm(init_carry, inputs) + return final_state + + +class DecoderLSTM(nn.Module): + """DecoderLSTM Module wrapped in a lifted scan transform. + + Attributes: + teacher_force: See docstring on Seq2seq module. + obs_size: Size of the observations. + """ + + teacher_force: bool + obs_size: int + + @functools.partial( + nn.scan, + variable_broadcast="params", + in_axes=1, + out_axes=1, + split_rngs={"params": False, "lstm": True}, + ) + @nn.compact + def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array: + """Applies the DecoderLSTM model.""" + lstm_state, last_prediction = carry + if not self.teacher_force: + x = last_prediction + lstm_state, y = nn.LSTMCell()(lstm_state, x) + logits = nn.Dense(features=self.obs_size)(y) + + return (lstm_state, logits), (logits, logits) + + +class Decoder(nn.Module): + """LSTM decoder. + + Attributes: + init_state: [batch_size, hidden_size] + Initial state of the decoder (i.e., the final state of the encoder). + teacher_force: See docstring on Seq2seq module. + obs_size: Size of the observations. + """ + + teacher_force: bool + obs_size: int + + @nn.compact + def __call__(self, inputs: Array, init_state: Any) -> Tuple[Array, Array]: + """Applies the decoder model. + + Args: + inputs: [batch_size, max_output_len-1, obs_size] + Contains the inputs to the decoder at each time step (only used when not + using teacher forcing). Since each token at position i is fed as input + to the decoder at position i+1, the last token is not provided. + + Returns: + Pair (logits, predictions), which are two arrays of respectively decoded + logits and predictions (in one hot-encoding format). + """ + lstm = DecoderLSTM(teacher_force=self.teacher_force, obs_size=self.obs_size) + init_carry = (init_state, inputs[:, 0]) + _, (logits, predictions) = lstm(init_carry, inputs) + return logits, predictions + + +class Seq2seq(nn.Module): + """Sequence-to-sequence class using encoder/decoder architecture. + + Attributes: + teacher_force: whether to use `decoder_inputs` as input to the decoder at + every step. If False, only the first input (i.e., the "=" token) is used, + followed by samples taken from the previous output logits. + hidden_size: int, the number of hidden dimensions in the encoder and decoder + LSTMs. + obs_size: the size of the observations. + eos_id: EOS id. + """ + + teacher_force: bool + hidden_size: int + obs_size: int + + def setup(self): + self.encoder = Encoder(hidden_size=self.hidden_size) + self.decoder = Decoder(teacher_force=self.teacher_force, obs_size=self.obs_size) + + @nn.compact + def __call__( + self, encoder_inputs: Array, decoder_inputs: Array + ) -> Tuple[Array, Array]: + """Applies the seq2seq model. + + Args: + encoder_inputs: [batch_size, max_input_length, obs_size]. 
+ padded batch of input sequences to encode. + decoder_inputs: [batch_size, max_output_length, obs_size]. + padded batch of expected decoded sequences for teacher forcing. + When sampling (i.e., `teacher_force = False`), only the first token is + input into the decoder (which is the token "="), and samples are used + for the following inputs. The second dimension of this tensor determines + how many steps will be decoded, regardless of the value of + `teacher_force`. + + Returns: + Pair (logits, predictions), which are two arrays of length `batch_size` + containing respectively decoded logits and predictions (in one hot + encoding format). + """ + # Encode inputs. + # print(encoder_inputs) + init_decoder_state = self.encoder(encoder_inputs) + # print(init_decoder_state) + # Encoder(hidden_size=self.hidden_size) + # Decode outputs. + logits, predictions = self.decoder(decoder_inputs, init_decoder_state) + # Decoder( + # init_state=init_decoder_state, + # teacher_force=self.teacher_force, + # obs_size=self.obs_size)(decoder_inputs[:, :-1],init_decoder_state) + + return logits, predictions + + def encode(self, encoder_inputs: Array): + # Encode inputs. + init_decoder_state = self.encoder(encoder_inputs) + final_output, hidden_state = init_decoder_state + return final_output diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py new file mode 100644 index 00000000..e2481b18 --- /dev/null +++ b/qdax/utils/train_seq2seq.py @@ -0,0 +1,244 @@ +# Copyright 2022 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""seq2seq addition example.""" + +# See issue #620. +# pytype: disable=wrong-keyword-args + +from typing import Any, Dict, Tuple + +import jax +import jax.numpy as jnp +import optax +from absl import flags +from flax.training import train_state + +from qdax.utils.seq2seq_model import Seq2seq + +Array = Any +FLAGS = flags.FLAGS +PRNGKey = Any + +flags.DEFINE_string("workdir", default=".", help="Where to store log output.") + +flags.DEFINE_float( + "learning_rate", default=0.003, help=("The learning rate for the Adam optimizer.") +) + +flags.DEFINE_integer("batch_size", default=128, help=("Batch size for training.")) + +flags.DEFINE_integer("hidden_size", default=16, help=("Hidden size of the LSTM.")) + +flags.DEFINE_integer("num_train_steps", default=10000, help=("Number of train steps.")) + +flags.DEFINE_integer( + "decode_frequency", + default=200, + help=("Frequency of decoding during training, e.g. 
every 1000 steps."), +) + +flags.DEFINE_integer( + "max_len_query_digit", default=3, help=("Maximum length of a single input digit.") +) + + +def get_model(obs_size, teacher_force: bool = False, hidden_size=10) -> Seq2seq: + return Seq2seq( + teacher_force=teacher_force, hidden_size=hidden_size, obs_size=obs_size + ) + + +def get_initial_params( + model: Seq2seq, rng: PRNGKey, encoder_input_shape +) -> Dict[str, Any]: + """Returns the initial parameters of a seq2seq model.""" + rng1, rng2, rng3 = jax.random.split(rng, 3) + variables = model.init( + {"params": rng1, "lstm": rng2, "dropout": rng3}, + jnp.ones(encoder_input_shape, jnp.float32), + jnp.ones(encoder_input_shape, jnp.float32), + ) + return variables["params"] + + +@jax.jit +def train_step( + state: train_state.TrainState, batch: Array, lstm_rng: PRNGKey +) -> Tuple[train_state.TrainState, Dict[str, float]]: + """Trains one step.""" + lstm_key = jax.random.fold_in(lstm_rng, state.step) + dropout_key, lstm_key = jax.random.split(lstm_key, 2) + # Shift Input by One to avoid leakage + batch_decoder = jnp.roll(batch, shift=1, axis=1) + ### Large number as zero token + batch_decoder = batch_decoder.at[:, 0, :].set(-1000) + + def loss_fn(params): + logits, _ = state.apply_fn( + {"params": params}, + batch, + batch_decoder, + rngs={"lstm": lstm_key, "dropout": dropout_key}, + ) + + def squared_error(x, y): + return jnp.inner(y - x, y - x) / 2.0 + + def mean_squared_error(x, y): + return jnp.inner(y - x, y - x) / x.shape[-1] + + # res = jax.vmap(squared_error)(logits, batch) + # res = jax.vmap(squared_error)(jnp.reshape(logits,(logits.shape[0],-1)),jnp.reshape(batch,(batch.shape[0],-1))) + res = jax.vmap(mean_squared_error)( + jnp.reshape(logits.at[:, :-1, ...].get(), (logits.shape[0], -1)), + jnp.reshape( + batch_decoder.at[:, 1:, ...].get(), (batch_decoder.shape[0], -1) + ), + ) + loss = jnp.mean(res, axis=0) + return loss, logits + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss_val, logits), grads = grad_fn(state.params) + state = state.apply_gradients(grads=grads) + + return state, loss_val + + +def lstm_ae_train(key, repertoire, params, epoch, hidden_size=10): + batch_size = 128 # 2048 + + if epoch > 100: + num_epochs = 25 + alpha = 0.0001 # Gradient step size + else: + num_epochs = 100 + alpha = 0.0001 # Gradient step size + + rng, key, key_selection = jax.random.split(key, 3) + dimensions_data = jnp.prod(jnp.asarray(repertoire.observations.shape[1:])) + + # get the model used (seq2seq) + model = get_model( + repertoire.observations.shape[-1], teacher_force=True, hidden_size=hidden_size + ) + + print("Beginning of the lstm ae training: ") + print("Repertoire observation: ", repertoire.observations) + + print("Repertoire fitnesses: ", repertoire.fitnesses) + + # compute mean/std of the obs for normalization + mean_obs = jnp.nanmean(repertoire.observations, axis=(0, 1)) + std_obs = jnp.nanstd(repertoire.observations, axis=(0, 1)) + + print("Mean obs - wo NaN: ", mean_obs) + print("Std obs - wo NaN: ", std_obs) + + # TODO: maybe we could just compute this data on the valid dataset + + # create optimizer and optimized state + tx = optax.adam(alpha) + state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx) + + # size of the repertoire + repertoire_size = repertoire.centroids.shape[0] + print("Repertoire size: ", repertoire_size) + + # number of individuals in the repertoire + num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) + print("Number of individuals: ", num_indivs) + + # select 
repertoire_size indexes going from 0 to num_indivs + # TODO: WHY?? + key_select_p1, rng = jax.random.split(key_selection, 2) + idx_p1 = jax.random.randint( + key_select_p1, shape=(repertoire_size,), minval=0, maxval=num_indivs + ) + print("idx p1: ", idx_p1) + + # TODO: what is the diff with repertoire_size?? + tot_indivs = repertoire.fitnesses.ravel().shape[0] + print("Total individuals: ", tot_indivs) + + # get indexes where fitness is not -inf?? + indexes = jnp.argwhere( + jnp.logical_not(jnp.isinf(repertoire.fitnesses)), size=tot_indivs + ) + indexes = jnp.transpose(indexes, axes=(1, 0)) + print("Indexes: ", indexes) + + # ??? + indiv_indices = jnp.array( + jnp.ravel_multi_index(indexes, repertoire.fitnesses.shape, mode="clip") + ).astype(int) + print("Indiv indices: ", indexes) + + # ??? + valid_indexes = indiv_indices.at[idx_p1].get() + print("Valid indexes: ", valid_indexes) + + # Normalising Dataset + # training_dataset = (repertoire.observations.at[valid_indexes].get()-mean_obs)/std_obs #jnp.where(std_obs==0,mean_obs,std_obs) + steps_per_epoch = repertoire.observations.shape[0] // batch_size + + loss_val = 0.0 + for epoch in range(num_epochs): + rng, shuffle_key = jax.random.split(rng, 2) + valid_indexes = jax.random.permutation(shuffle_key, valid_indexes, axis=0) + + # TODO: the std where they were NaNs is set to zero. But here we divide by the + # std, so NaNs appear here... + # std_obs += 1e-6 + + std_obs = jnp.where(std_obs == 0, x=jnp.inf, y=std_obs) + + # create dataset with the observation from the sample of valid indexes + training_dataset = ( + repertoire.observations.at[valid_indexes, ...].get() - mean_obs + ) / std_obs # jnp.where(std_obs==0,mean_obs,std_obs) + training_dataset = training_dataset.at[valid_indexes].get() + + if epoch == 0: + print("Training dataset for first epoch: ", training_dataset) + print("Training dataset first data for first epoch: ", training_dataset[0]) + + for i in range(steps_per_epoch): + batch = jnp.asarray( + training_dataset.at[ + (i * batch_size) : (i * batch_size) + batch_size, :, : + ].get() + ) + # print(batch) + if batch.shape[0] < batch_size: + # print(batch.shape) + continue + state, loss_val = train_step(state, batch, rng) + + ### To see the actual value we cannot jit this function (i.e. 
the _one_es_epoch function nor the train function) + print("Eval epoch: {}, loss: {:.4f}".format(epoch + 1, loss_val)) + + # TODO: put this in metrics so we can jit the function and see the metrics + # TODO: not urgent because the training is not that long + + # return repertoire.replace(ae_params=state.params,mean_obs=mean_obs,std_obs=std_obs) + + train_step.clear_cache() + del tx + del model + params = state.params + del state + + return params, mean_obs, std_obs diff --git a/requirements.txt b/requirements.txt index b97297fa..13bef5dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ absl-py==1.0.0 -brax==0.0.15 +brax==0.0.12 chex==0.1.5 dm-haiku==0.0.5 flax==0.6.0 diff --git a/setup.py b/setup.py index 2e50e0ea..c180ca3d 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ "jinja2<3.1.0", "jumanji>=0.1.3", "flax>=0.6, <0.6.2", - "brax>=0.0.15", + "brax>=0.0.12", "gym>=0.23.1", "numpy>=1.22.3", "scikit-learn>=1.0.2", diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py new file mode 100644 index 00000000..519007eb --- /dev/null +++ b/tests/core_test/aurora_test.py @@ -0,0 +1,309 @@ +"""Tests AURORA implementation""" + +import functools +from typing import Any, Dict, Tuple + +import jax +import jax.numpy as jnp +import pytest + +from qdax import environments +from qdax.core.aurora import AURORA +from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire +from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.environments.bd_extractors import get_aurora_bd +from qdax.tasks.brax_envs import scoring_aurora_function +from qdax.types import EnvState, Params, RNGKey +from qdax.utils import train_seq2seq + + +@pytest.mark.parametrize( + "env_name, batch_size", + [("halfcheetah_uni", 10), ("walker2d_uni", 10), ("hopper_uni", 10)], +) +def test_aurora(env_name: str, batch_size: int) -> None: + batch_size = batch_size + env_name = env_name + episode_length = 100 + num_iterations = 5 + seed = 42 + policy_hidden_layer_sizes = (64, 64) + num_centroids = 50 + + observation_option = "only_sd" + hidden_size = 5 + l_value_init = 0.2 + + log_freq = 5 + + # Init environment + env = environments.create(env_name, episode_length=episode_length) + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + # Init policy network + policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) + policy_network = MLP( + layer_sizes=policy_layer_sizes, + kernel_init=jax.nn.initializers.lecun_uniform(), + final_activation=jnp.tanh, + ) + + # Init population of controllers + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split(subkey, num=batch_size) + fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) + init_variables = jax.vmap(policy_network.init)(keys, fake_batch) + + # Create the initial environment states + random_key, subkey = jax.random.split(random_key) + keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0) + reset_fn = jax.jit(jax.vmap(env.reset)) + init_states = reset_fn(keys) + + # Define the fonction to play a step with the policy in the environment + def play_step_fn( + env_state: EnvState, + policy_params: Params, + random_key: RNGKey, + ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: + """ + Play an environment step and return the updated state and the 
transition. + """ + + actions = policy_network.apply(policy_params, env_state.obs) + + state_desc = env_state.info["state_descriptor"] + next_state = env.step(env_state, actions) + + transition = QDTransition( + obs=env_state.obs, + next_obs=next_state.obs, + rewards=next_state.reward, + dones=next_state.done, + actions=actions, + truncations=next_state.info["truncation"], + state_desc=state_desc, + next_state_desc=next_state.info["state_descriptor"], + ) + + return next_state, policy_params, random_key, transition + + # Prepare the scoring function + bd_extraction_fn = functools.partial( + get_aurora_bd, + option=observation_option, + hidden_size=hidden_size, + ) + scoring_fn = functools.partial( + scoring_aurora_function, + init_states=init_states, + episode_length=episode_length, + play_step_fn=play_step_fn, + behavior_descriptor_extractor=bd_extraction_fn, + ) + + # Define emitter + variation_fn = functools.partial(isoline_variation, iso_sigma=0.05, line_sigma=0.1) + mixing_emitter = MixingEmitter( + mutation_fn=lambda x, y: (x, y), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=batch_size, + ) + + # Get minimum reward value to make sure qd_score are positive + reward_offset = environments.reward_offset[env_name] + + # Define a metrics function + def metrics_fn(repertoire: MapElitesRepertoire) -> Dict: + + # Get metrics + grid_empty = repertoire.fitnesses == -jnp.inf + qd_score = jnp.sum(repertoire.fitnesses, where=~grid_empty) + # Add offset for positive qd_score + qd_score += reward_offset * episode_length * jnp.sum(1.0 - grid_empty) + coverage = 100 * jnp.mean(1.0 - grid_empty) + max_fitness = jnp.max(repertoire.fitnesses) + + return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} + + # Instantiate MAP-Elites + aurora = AURORA( + scoring_function=scoring_fn, + emitter=mixing_emitter, + metrics_function=metrics_fn, + ) + + aurora_dims = hidden_size + centroids = jnp.zeros(shape=(num_centroids, aurora_dims)) + + @jax.jit + def update_scan_fn(carry: Any, unused: Any) -> Any: + # iterate over grid + ( + repertoire, + random_key, + model_params, + mean_observations, + std_observations, + ) = carry + (repertoire, _, metrics, random_key,) = aurora.update( + repertoire, + None, + random_key, + model_params, + mean_observations, + std_observations, + ) + + return ( + (repertoire, random_key, model_params, mean_observations, std_observations), + metrics, + ) + + # Init algorithm + ## AutoEncoder Params and INIT + # observations_dims = (20, 25) + obs_dim = jnp.minimum(env.observation_size, 25) + if observation_option == "full": + observations_dims = (25, obs_dim + 2) # 250 / 10, 25 + 2 + if observation_option == "no_sd": + observations_dims = (25, obs_dim) # 250 / 10, 25 + if observation_option == "only_sd": + observations_dims = (25, 2) # 250 / 10, 2 + + model = train_seq2seq.get_model( + observations_dims[-1], True, hidden_size=hidden_size + ) + random_key, subkey = jax.random.split(random_key) + + # design aurora's schedule + default_update_base = 10 + update_base = int(jnp.ceil(default_update_base / log_freq)) + schedules = jnp.cumsum(jnp.arange(update_base, 1000, update_base)) + print("Schedules: ", schedules) + + model_params = train_seq2seq.get_initial_params( + model, subkey, (1, observations_dims[0], observations_dims[-1]) + ) + # model_params = train_seq2seq.get_initial_params(model,subkey,(1,repertoire.observations.shape[1],repertoire.observations.shape[-1])) + print(jax.tree_map(lambda x: x.shape, model_params)) + + mean_observations 
= jnp.zeros(observations_dims[-1]) + + std_observations = jnp.ones(observations_dims[-1]) + + repertoire, _, random_key = aurora.init( + init_variables, + centroids, + random_key, + model_params, + mean_observations, + std_observations, + l_value_init, + ) + + ## Initializing Means and stds and Aurora + random_key, subkey = jax.random.split(random_key) + model_params, mean_observations, std_observations = train_seq2seq.lstm_ae_train( + subkey, repertoire, model_params, 0, hidden_size=hidden_size + ) + + current_step_estimation = 0 + num_iterations = 0 + + # Main loop + n_target = 1024 + + previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target + + iteration = 1 # to be consistent with other exp scripts + while iteration < num_iterations: + + ( + (repertoire, random_key, model_params, mean_observations, std_observations), + metrics, + ) = jax.lax.scan( + update_scan_fn, + (repertoire, random_key, model_params, mean_observations, std_observations), + (), + length=log_freq, + ) + + num_iterations = iteration * log_freq + + # update nb steps estimation + current_step_estimation += batch_size * episode_length * log_freq + + ## Autoencoder Steps and CVC + # individuals_in_repo = jnp.sum(repertoire.fitnesses != -jnp.inf) + + if (iteration + 1) in schedules: + random_key, subkey = jax.random.split(random_key) + + ( + model_params, + mean_observations, + std_observations, + ) = train_seq2seq.lstm_ae_train( + subkey, + repertoire, + model_params, + iteration, + hidden_size=hidden_size, + ) + ### RE-ADDITION OF ALL THE NEW BEHAVIOURAL DESCRIPTORS WITH THE NEW AE + + # model = train_seq2seq.get_model(repertoire.observations.shape[-1],True) ## lstm seq2seq + normalized_observations = ( + repertoire.observations - mean_observations + ) / std_observations + new_descriptors = model.apply( + {"params": model_params}, normalized_observations, method=model.encode + ) + repertoire = repertoire.init( + genotypes=repertoire.genotypes, + centroids=repertoire.centroids, + fitnesses=repertoire.fitnesses, + descriptors=new_descriptors, + observations=repertoire.observations, + l_value=repertoire.l_value, + ) + num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) + + elif iteration % 2 == 0: + + num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) + + # l_value = repertoire.l_value * (1+1*10e-7*(num_indivs-n_target)) + current_error = num_indivs - n_target + change_rate = current_error - previous_error + prop_gain = 1 * 10e-6 + l_value = ( + repertoire.l_value + + (prop_gain * (current_error)) + + (prop_gain * change_rate) + ) + print(change_rate, current_error) + previous_error = current_error + ## CVC Implementation to keep a Constant number of individuals in the Archive + repertoire = repertoire.init( + genotypes=repertoire.genotypes, + centroids=repertoire.centroids, + fitnesses=repertoire.fitnesses, + descriptors=repertoire.descriptors, + observations=repertoire.observations, + l_value=l_value, + ) + new_num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) + + pytest.assume(repertoire is not None) + + +if __name__ == "__main__": + test_aurora(env_name="pointmaze", batch_size=10) From a86bdb0f7a9ad3fc883d3eb4d93aeb5c0d025e07 Mon Sep 17 00:00:00 2001 From: Felix Date: Tue, 28 Feb 2023 11:33:19 +0200 Subject: [PATCH 02/26] WIP - first cleaning --- .../containers/unstructured_repertoire.py | 149 +++++++++--------- 1 file changed, 73 insertions(+), 76 deletions(-) diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index 
9c0fe15a..9e5a1c5d 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -47,14 +47,13 @@ def _get_cells_indices( centroids of shape (num_centroids, num_descriptors) """ - # distances = jnp.sum(jnp.square(jnp.subtract(descriptors, centroids)), axis=-1) distances = jax.vmap(jnp.linalg.norm)(descriptors - centroids) - ## Negating distances because we want the smallest ones + + # Negating distances because we want the smallest ones min_dist, min_args = jax.lax.top_k(-1 * distances, k_nn) - # return jnp.argmin(distances),jnp.min(distances) + return min_args, -1 * min_dist - # func = jax.vmap(lambda x: _get_cells_indices(x, centroids,k_nn),in_axes=(0,None,None,)) func = jax.vmap( _get_cells_indices, in_axes=( @@ -64,7 +63,6 @@ def _get_cells_indices( ), ) - # return func(batch_of_descriptors) return func(batch_of_descriptors, centroids, k_nn) @@ -77,52 +75,59 @@ def intra_batch_comp( l_value, ): - ## Check for individuals that are Nans, we remove them at the end + # Check for individuals that are Nans, we remove them at the end not_existent = jnp.where((jnp.isnan(normed)).any(), True, False) - ## Fill in Nans to do computations + + # Fill in Nans to do computations normed = jnp.where(jnp.isnan(normed), jnp.full(normed.shape[-1], jnp.inf), normed) eval_scores = jnp.where( jnp.isinf(eval_scores), jnp.full(eval_scores.shape[-1], jnp.nan), eval_scores ) - ## If we do not use a fitness (i.e same fitness everywhere, we create a virtual fitness function to add individuals with the same bd) + + # If we do not use a fitness (i.e same fitness everywhere), we create a virtual fitness + # function to add individuals with the same bd additional_score = jnp.where( jnp.nanmax(eval_scores) == jnp.nanmin(eval_scores), 1.0, 0.0 ) additional_scores = jnp.linspace(0.0, additional_score, num=eval_scores.shape[0]) - ## Add scores to empty individuals + + # Add scores to empty individuals eval_scores = jnp.where( jnp.isnan(eval_scores), jnp.full(eval_scores.shape[0], -jnp.inf), eval_scores ) - ##Virtual eval_scores + # Virtual eval_scores eval_scores = eval_scores + additional_scores - ## For each point we check what other points are the closest ones. + + # For each point we check what other points are the closest ones. knn_relevant_scores, knn_relevant_indices = jax.lax.top_k( -1 * jax.vmap(jnp.linalg.norm)(normed - normed_all), eval_scores.shape[0] ) - ## We negated the scores to use top_k so we reverse it. + # We negated the scores to use top_k so we reverse it. knn_relevant_scores = knn_relevant_scores * -1 - ##Check if the individual is close enough to compare (under l-value) + # Check if the individual is close enough to compare (under l-value) fitness = jnp.where(jnp.squeeze(knn_relevant_scores < l_value), True, False) - ## We want to eliminate the same individual (distance 0) - # fitness = jnp.where(knn_relevant_scores==0.0,False,fitness) + + # We want to eliminate the same individual (distance 0) fitness = jnp.where(knn_relevant_indices == current_index, False, fitness) current_fitness = jnp.squeeze( eval_scores.at[knn_relevant_indices.at[0].get()].get() ) - ## Is the fitness of the other individual higher? - ## If both are True then we discard the current individual since this individual would be replaced by the better one. + # Is the fitness of the other individual higher? + # If both are True then we discard the current individual since this individual would be replaced + # by the better one. 
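+    # Example: with l_value = 0.5, a candidate with fitness 1.0 is discarded when
+    # another batch member sits at a descriptor distance below 0.5 and has a
+    # fitness above 1.0; when all fitnesses are identical, the virtual scores
+    # added above decide which of the near-duplicates is kept.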
discard_indiv = jnp.logical_and( jnp.where( eval_scores.at[knn_relevant_indices].get() > current_fitness, True, False ), fitness, ).any() - ## Discard Individuals with Nans as their BD (mainly for the readdition where we have NaN bds) + + # Discard Individuals with Nans as their BD (mainly for the readdition where we have NaN bds) discard_indiv = jnp.logical_or(discard_indiv, not_existent) - ## Negate to know if we keep the individual + # Negate to know if we keep the individual return jnp.logical_not(discard_indiv) @@ -135,54 +140,60 @@ def intra_batch_comp_relevant( relevant_l_values, ): - ## Check for individuals that are Nans, we remove them at the end + # Check for individuals that are Nans, we remove them at the end not_existent = jnp.where((jnp.isnan(normed)).any(), True, False) - ## Fill in Nans to do computations + + # Fill in Nans to do computations normed = jnp.where(jnp.isnan(normed), jnp.full(normed.shape[-1], jnp.inf), normed) eval_scores = jnp.where( jnp.isinf(eval_scores), jnp.full(eval_scores.shape[-1], jnp.nan), eval_scores ) - ## If we do not use a fitness (i.e same fitness everywhere, we create a virtual fitness function to add individuals with the same bd) + + # If we do not use a fitness (i.e same fitness everywhere, we create a virtual fitness function to add individuals with the same bd) additional_score = jnp.where( jnp.nanmax(eval_scores) == jnp.nanmin(eval_scores), 1.0, 0.0 ) additional_scores = jnp.linspace(0.0, additional_score, num=eval_scores.shape[0]) - ## Add scores to empty individuals + + # Add scores to empty individuals eval_scores = jnp.where( jnp.isnan(eval_scores), jnp.full(eval_scores.shape[0], -jnp.inf), eval_scores ) - ##Virtual eval_scores + + # Virtual eval_scores eval_scores = eval_scores + additional_scores - ## For each point we check what other points are the closest ones. + # For each point we check what other points are the closest ones. knn_relevant_scores, knn_relevant_indices = jax.lax.top_k( -1 * jax.vmap(jnp.linalg.norm)(normed - normed_all), eval_scores.shape[0] ) - ## We negated the scores to use top_k so we reverse it. + # We negated the scores to use top_k so we reverse it. knn_relevant_scores = knn_relevant_scores * -1 - ##Check if the individual is close enough to compare (under l-value) + # Check if the individual is close enough to compare (under l-value) fitness = jnp.where( jnp.squeeze(knn_relevant_scores < relevant_l_values), True, False ) - ## We want to eliminate the same individual (distance 0) - # fitness = jnp.where(knn_relevant_scores==0.0,False,fitness) + + # We want to eliminate the same individual (distance 0) fitness = jnp.where(knn_relevant_indices == current_index, False, fitness) current_fitness = jnp.squeeze( eval_scores.at[knn_relevant_indices.at[0].get()].get() ) - ## Is the fitness of the other individual higher? - ## If both are True then we discard the current individual since this individual would be replaced by the better one. + # Is the fitness of the other individual higher? + # If both are True then we discard the current individual since this individual would be replaced + # by the better one. 
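+    # Note: unlike intra_batch_comp above, this variant compares each candidate
+    # against its own threshold taken from relevant_l_values rather than the
+    # single scalar l_value shared by the whole container.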
discard_indiv = jnp.logical_and( jnp.where( eval_scores.at[knn_relevant_indices].get() > current_fitness, True, False ), fitness, ).any() - ## Discard Individuals with Nans as their BD (mainly for the readdition where we have NaN bds) + + # Discard Individuals with Nans as their BD (mainly for the readdition where we have NaN bds) discard_indiv = jnp.logical_or(discard_indiv, not_existent) - ## Negate to know if we keep the individual + # Negate to know if we keep the individual return jnp.logical_not(discard_indiv) @@ -283,7 +294,7 @@ def add( batch_of_observations: Observation, ) -> UnstructuredRepertoire: - ## We need to replace all the descriptors that are not filled with jnp inf + # We need to replace all the descriptors that are not filled with jnp inf filtered_descriptors = jnp.where( jnp.expand_dims((self.fitnesses == -jnp.inf), axis=-1), jnp.full(self.descriptors.shape[-1], fill_value=jnp.inf), @@ -294,15 +305,14 @@ def add( batch_of_descriptors, filtered_descriptors, 2 ) - second_neighbours = batch_of_distances.at[ - ..., 1 - ].get() # Save the second nearest neighbours to check a condition - batch_of_indices = batch_of_indices.at[ - ..., 0 - ].get() ## Keep the Nearest neighbours - batch_of_distances = batch_of_distances.at[ - ..., 0 - ].get() ## Keep the Nearest neighbours + # Save the second nearest neighbours to check a condition + second_neighbours = batch_of_distances.at[..., 1].get() + + # Keep the Nearest neighbours + batch_of_indices = batch_of_indices.at[..., 0].get() + + # Keep the Nearest neighbours + batch_of_distances = batch_of_distances.at[..., 0].get() # We remove individuals that are too close to the second nn. # This avoids having clusters of individuals after adding them. @@ -316,7 +326,7 @@ def add( num_centroids = self.centroids.shape[0] - ### TODO Doesn't Work if Archive is full. Need to use the closest individuals in that case. + # TODO: Doesn't Work if Archive is full. Need to use the closest individuals in that case. 
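+        # jnp.nonzero is called with a static `size` and `fill_value=-1` so it stays
+        # jit-compatible: it returns exactly batch_size indices of empty cells
+        # (fitness == -inf), padded with -1 when fewer empty cells exist, e.g.
+        # jnp.nonzero(jnp.array([0, 1, 0, 1]), size=3, fill_value=-1)[0] -> [1, 3, -1].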
empty_indexes = jnp.squeeze( jnp.nonzero( jnp.where(jnp.isinf(self.fitnesses), 1, 0), @@ -330,11 +340,10 @@ def add( -1, ) + # We get all the indices of the empty bds first and then the filled ones (because of -1) sorted_bds = jax.lax.top_k( -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0] - )[ - 1 - ] ## We get all the indices of the empty bds first and then the filled ones (because of -1) + )[1] batch_of_indices = jnp.where( jnp.squeeze(batch_of_distances.at[sorted_bds].get() <= self.l_value), batch_of_indices.at[sorted_bds].get(), @@ -343,26 +352,24 @@ def add( batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1) - ## ReIndexing of all the inputs to the correct sorted way + # ReIndexing of all the inputs to the correct sorted way batch_of_distances = batch_of_distances.at[sorted_bds].get() batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get() batch_of_genotypes = jax.tree_map( lambda x: x.at[sorted_bds].get(), batch_of_genotypes ) - # obs = obs.at[sorted_bds].get() batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get() batch_of_observations = batch_of_observations.at[sorted_bds].get() not_novel_enough = not_novel_enough.at[sorted_bds].get() - # dead = dead.at[sorted_bds].get() - ## Check to find Individuals with same BD within the Batch + # Check to find Individuals with same BD within the Batch keep_indiv = jax.jit( jax.vmap(intra_batch_comp, in_axes=(0, 0, None, None, None), out_axes=(0)) )( batch_of_descriptors.squeeze(), jnp.arange( 0, batch_of_descriptors.shape[0], 1 - ), ## We do this to keep track of where we are in the batch to assure right comparisons + ), # We do this to keep track of where we are in the batch to assure right comparisons batch_of_descriptors.squeeze(), batch_of_fitnesses.squeeze(), self.l_value, @@ -370,7 +377,6 @@ def add( keep_indiv = jnp.logical_and(keep_indiv, jnp.logical_not(not_novel_enough)) - # keep_indiv = jax.vmap(intra_batch_comp, in_axes=(0,0,None,None,None), out_axes=(0))(batch_of_descriptors.squeeze(),jnp.arange(0,batch_of_descriptors.shape[0],1),batch_of_descriptors.squeeze(),batch_of_fitnesses.squeeze(),self.l_value) # get fitness segment max best_fitnesses = jax.ops.segment_max( batch_of_fitnesses, @@ -542,15 +548,14 @@ def add_relevant( batch_of_descriptors, filtered_descriptors, 2 ) - second_neighbours = batch_of_distances.at[ - ..., 1 - ].get() # Save the second nearest neighbours to check a condition - batch_of_indices = batch_of_indices.at[ - ..., 0 - ].get() ## Keep the Nearest neighbours - batch_of_distances = batch_of_distances.at[ - ..., 0 - ].get() ## Keep the Nearest neighbours + # Save the second nearest neighbours to check a condition + second_neighbours = batch_of_distances.at[..., 1].get() + + # Keep the Nearest neighbours + batch_of_indices = batch_of_indices.at[..., 0].get() + + # Keep the Nearest neighbours + batch_of_distances = batch_of_distances.at[..., 0].get() # We remove individuals that are too close to the second nn. # This avoids having clusters of individuals after adding them. @@ -558,13 +563,12 @@ def add_relevant( jnp.squeeze(second_neighbours <= new_l_values), True, False ) - # batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1) batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1) batch_of_observations = jnp.expand_dims(batch_of_observations, axis=-1) num_centroids = self.centroids.shape[0] - ### TODO Doesn't Work if Archive is full. Need to use the closest individuals in that case. + # TODO: Doesn't Work if Archive is full. 
Need to use the closest individuals in that case. empty_indexes = jnp.squeeze( jnp.nonzero( jnp.where(jnp.isinf(self.fitnesses), 1, 0), @@ -578,11 +582,10 @@ def add_relevant( -1, ) + # We get all the indices of the empty bds first and then the filled ones (because of -1) sorted_bds = jax.lax.top_k( -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0] - )[ - 1 - ] ## We get all the indices of the empty bds first and then the filled ones (because of -1) + )[1] batch_of_indices = jnp.where( jnp.squeeze( batch_of_distances.at[sorted_bds].get() @@ -594,36 +597,30 @@ def add_relevant( batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1) - ## ReIndexing of all the inputs to the correct sorted way + # ReIndexing of all the inputs to the correct sorted way batch_of_distances = batch_of_distances.at[sorted_bds].get() batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get() batch_of_genotypes = jax.tree_map( lambda x: x.at[sorted_bds].get(), batch_of_genotypes ) - # obs = obs.at[sorted_bds].get() batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get() batch_of_observations = batch_of_observations.at[sorted_bds].get() not_novel_enough = not_novel_enough.at[sorted_bds].get() new_l_values = new_l_values.at[sorted_bds].get() - # dead = dead.at[sorted_bds].get() - # filtered_l = jnp.where(new_l_values>self.l_value,self.l_value,new_l_values) - ## Check to find Individuals with same BD within the Batch + # Check to find Individuals with same BD within the Batch keep_indiv = jit( jax.vmap(intra_batch_comp, in_axes=(0, 0, None, None, 0), out_axes=(0)) )( batch_of_descriptors.squeeze(), jnp.arange( 0, batch_of_descriptors.shape[0], 1 - ), ## We do this to keep track of where we are in the batch to assure right comparisons + ), # We do this to keep track of where we are in the batch to assure right comparisons batch_of_descriptors.squeeze(), batch_of_fitnesses.squeeze(), new_l_values, ) - # keep_indiv = jnp.logical_and(keep_indiv,jnp.logical_not(not_novel_enough)) - - # keep_indiv = jax.vmap(intra_batch_comp, in_axes=(0,0,None,None,None), out_axes=(0))(batch_of_descriptors.squeeze(),jnp.arange(0,batch_of_descriptors.shape[0],1),batch_of_descriptors.squeeze(),batch_of_fitnesses.squeeze(),self.l_value) # get fitness segment max best_fitnesses = jax.ops.segment_max( batch_of_fitnesses, From 4d4955c4147e6e9ea99f05d29d1c59ee39bd5e4f Mon Sep 17 00:00:00 2001 From: Felix Date: Tue, 28 Feb 2023 13:17:51 +0200 Subject: [PATCH 03/26] fix pre-commits --- .../containers/unstructured_repertoire.py | 92 ++++++++-------- qdax/environments/bd_extractors.py | 19 +--- qdax/tasks/brax_envs.py | 5 +- qdax/utils/seq2seq_model.py | 59 ++++------ qdax/utils/train_seq2seq.py | 102 +++++++----------- tests/core_test/aurora_test.py | 18 ++-- 6 files changed, 119 insertions(+), 176 deletions(-) diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index 9e5a1c5d..48648be9 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -8,15 +8,7 @@ import jax.numpy as jnp from jax.flatten_util import ravel_pytree -from qdax.types import ( - Centroid, - Descriptor, - ExtraScores, - Fitness, - Genotype, - Observation, - RNGKey, -) +from qdax.types import Centroid, Descriptor, Fitness, Genotype, Observation, RNGKey @partial(jax.jit, static_argnames=("k_nn",)) @@ -63,17 +55,17 @@ def _get_cells_indices( ), ) - return func(batch_of_descriptors, centroids, k_nn) + return func(batch_of_descriptors, 
centroids, k_nn) # type: ignore @jax.jit def intra_batch_comp( - normed, - current_index, - normed_all, - eval_scores, - l_value, -): + normed: jnp.ndarray, + current_index: jnp.ndarray, + normed_all: jnp.ndarray, + eval_scores: jnp.ndarray, + l_value: jnp.ndarray, +) -> jnp.ndarray: # Check for individuals that are Nans, we remove them at the end not_existent = jnp.where((jnp.isnan(normed)).any(), True, False) @@ -84,8 +76,8 @@ def intra_batch_comp( jnp.isinf(eval_scores), jnp.full(eval_scores.shape[-1], jnp.nan), eval_scores ) - # If we do not use a fitness (i.e same fitness everywhere), we create a virtual fitness - # function to add individuals with the same bd + # If we do not use a fitness (i.e same fitness everywhere), we create a virtual + # fitness function to add individuals with the same bd additional_score = jnp.where( jnp.nanmax(eval_scores) == jnp.nanmin(eval_scores), 1.0, 0.0 ) @@ -115,8 +107,8 @@ def intra_batch_comp( ) # Is the fitness of the other individual higher? - # If both are True then we discard the current individual since this individual would be replaced - # by the better one. + # If both are True then we discard the current individual since this individual + # would be replaced by the better one. discard_indiv = jnp.logical_and( jnp.where( eval_scores.at[knn_relevant_indices].get() > current_fitness, True, False @@ -124,7 +116,8 @@ def intra_batch_comp( fitness, ).any() - # Discard Individuals with Nans as their BD (mainly for the readdition where we have NaN bds) + # Discard Individuals with Nans as their BD (mainly for the readdition where we + # have NaN bds) discard_indiv = jnp.logical_or(discard_indiv, not_existent) # Negate to know if we keep the individual @@ -133,12 +126,12 @@ def intra_batch_comp( @jax.jit def intra_batch_comp_relevant( - normed, - current_index, - normed_all, - eval_scores, - relevant_l_values, -): + normed: jnp.ndarray, + current_index: jnp.ndarray, + normed_all: jnp.ndarray, + eval_scores: jnp.ndarray, + relevant_l_values: jnp.ndarray, +) -> jnp.ndarray: # Check for individuals that are Nans, we remove them at the end not_existent = jnp.where((jnp.isnan(normed)).any(), True, False) @@ -149,7 +142,8 @@ def intra_batch_comp_relevant( jnp.isinf(eval_scores), jnp.full(eval_scores.shape[-1], jnp.nan), eval_scores ) - # If we do not use a fitness (i.e same fitness everywhere, we create a virtual fitness function to add individuals with the same bd) + # If we do not use a fitness (i.e same fitness everywhere, we create a virtual + # fitness function to add individuals with the same bd) additional_score = jnp.where( jnp.nanmax(eval_scores) == jnp.nanmin(eval_scores), 1.0, 0.0 ) @@ -181,8 +175,8 @@ def intra_batch_comp_relevant( ) # Is the fitness of the other individual higher? - # If both are True then we discard the current individual since this individual would be replaced - # by the better one. + # If both are True then we discard the current individual since this individual + # would be replaced by the better one. 
discard_indiv = jnp.logical_and( jnp.where( eval_scores.at[knn_relevant_indices].get() > current_fitness, True, False @@ -190,7 +184,8 @@ def intra_batch_comp_relevant( fitness, ).any() - # Discard Individuals with Nans as their BD (mainly for the readdition where we have NaN bds) + # Discard Individuals with Nans as their BD (mainly for the readdition where we + # have NaN bds) discard_indiv = jnp.logical_or(discard_indiv, not_existent) # Negate to know if we keep the individual @@ -219,7 +214,7 @@ class UnstructuredRepertoire(flax.struct.PyTreeNode): fitnesses: Fitness descriptors: Descriptor centroids: Centroid - observations: ExtraScores + observations: Observation ages: jnp.ndarray l_value: jnp.ndarray @@ -326,7 +321,8 @@ def add( num_centroids = self.centroids.shape[0] - # TODO: Doesn't Work if Archive is full. Need to use the closest individuals in that case. + # TODO: Doesn't Work if Archive is full. Need to use the closest individuals + # in that case. empty_indexes = jnp.squeeze( jnp.nonzero( jnp.where(jnp.isinf(self.fitnesses), 1, 0), @@ -340,7 +336,8 @@ def add( -1, ) - # We get all the indices of the empty bds first and then the filled ones (because of -1) + # We get all the indices of the empty bds first and then the filled ones + # (because of -1) sorted_bds = jax.lax.top_k( -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0] )[1] @@ -369,7 +366,7 @@ def add( batch_of_descriptors.squeeze(), jnp.arange( 0, batch_of_descriptors.shape[0], 1 - ), # We do this to keep track of where we are in the batch to assure right comparisons + ), # keep track of where we are in the batch to assure right comparisons batch_of_descriptors.squeeze(), batch_of_fitnesses.squeeze(), self.l_value, @@ -469,7 +466,7 @@ def init( fitnesses: Fitness, descriptors: Descriptor, centroids: Centroid, - observations: ExtraScores, + observations: Observation, l_value: jnp.ndarray, ages: Optional[jnp.ndarray] = None, ) -> UnstructuredRepertoire: @@ -521,8 +518,9 @@ def init( ages=ages, ) - # return new_repertoire # type: ignore - return repertoire.add(genotypes, descriptors, fitnesses, observations) + return repertoire.add( # type: ignore + genotypes, descriptors, fitnesses, observations + ) @jax.jit def add_relevant( @@ -568,7 +566,7 @@ def add_relevant( num_centroids = self.centroids.shape[0] - # TODO: Doesn't Work if Archive is full. Need to use the closest individuals in that case. + # TODO: Doesn't Work if Archive is full. Use closest individuals in that case. 
empty_indexes = jnp.squeeze( jnp.nonzero( jnp.where(jnp.isinf(self.fitnesses), 1, 0), @@ -582,7 +580,8 @@ def add_relevant( -1, ) - # We get all the indices of the empty bds first and then the filled ones (because of -1) + # get all the indices of the empty bds first and then the filled ones + # (because of -1) sorted_bds = jax.lax.top_k( -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0] )[1] @@ -609,13 +608,13 @@ def add_relevant( new_l_values = new_l_values.at[sorted_bds].get() # Check to find Individuals with same BD within the Batch - keep_indiv = jit( + keep_indiv = jax.jit( jax.vmap(intra_batch_comp, in_axes=(0, 0, None, None, 0), out_axes=(0)) )( batch_of_descriptors.squeeze(), jnp.arange( 0, batch_of_descriptors.shape[0], 1 - ), # We do this to keep track of where we are in the batch to assure right comparisons + ), # keep track of where we are in the batch to assure right comparisons batch_of_descriptors.squeeze(), batch_of_fitnesses.squeeze(), new_l_values, @@ -642,11 +641,6 @@ def add_relevant( addition_condition = jnp.logical_and( addition_condition, jnp.expand_dims(keep_indiv, axis=-1) ) - print(addition_condition) - print(batch_of_indices) - print(batch_of_descriptors) - print(batch_of_distances) - print(new_l_values) # assign fake position when relevant : num_centroids is out of bounds batch_of_indices = jnp.where( @@ -693,7 +687,7 @@ def init_relevant( fitnesses: Fitness, descriptors: Descriptor, centroids: Centroid, - observations: ExtraScores, + observations: Observation, l_value: float, proximity_scores: jnp.ndarray, ages: Optional[jnp.ndarray] = None, @@ -746,6 +740,6 @@ def init_relevant( ) # return new_repertoire - return repertoire.add_relevant( + return repertoire.add_relevant( # type: ignore genotypes, descriptors, fitnesses, observations, proximity_scores ) diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index e81348a1..b41e6ace 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -57,17 +57,6 @@ def get_aurora_bd( # reshape mask for bd extraction mask = jnp.expand_dims(mask, axis=-1) - print("Mask: ", mask) - - # Get behavior descriptor - last_index = jnp.int32(jnp.sum(1.0 - mask, axis=1)) - 1 - - ## Doesn't Make Sense to take last valid Observation for Aurora, we take the full trajectory - # observations = jax.vmap(lambda x, y: x[y,:])(data.obs[:,::10,:15], last_index)) - - # TODO: try with all observations - # TODO: try with a padding - state_obs = data.obs[:, ::10, :25] filtered_mask = mask[:, ::10, :] @@ -89,12 +78,8 @@ def get_aurora_bd( if padding: observations = jnp.where(filtered_mask, x=jnp.array(0.0), y=observations) - # print("Observation: ", observations) - # print("Padded observation: ", padded_observations) - - model = train_seq2seq.get_model( - observations.shape[-1], True, hidden_size - ) ## lstm seq2seq + # lstm seq2seq + model = train_seq2seq.get_model(observations.shape[-1], True, hidden_size) normalized_observations = (observations - mean_observations) / std_observations descriptors = model.apply( {"params": model_params}, normalized_observations, method=model.encode diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index d94d49c1..8cc267c7 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -18,6 +18,7 @@ ExtraScores, Fitness, Genotype, + Observation, Params, RNGKey, ) @@ -352,7 +353,9 @@ def scoring_aurora_function( [EnvState, Params, RNGKey], Tuple[EnvState, Params, RNGKey, QDTransition], ], - behavior_descriptor_extractor: 
Callable[[QDTransition, jnp.ndarray], Descriptor], + behavior_descriptor_extractor: Callable[ + [QDTransition, jnp.ndarray, Params, Observation, Observation], Descriptor + ], ) -> Tuple[Fitness, Descriptor, Dict[str, Union[jnp.ndarray, QDTransition]], RNGKey]: """Evaluates policies contained in flatten_variables in parallel diff --git a/qdax/utils/seq2seq_model.py b/qdax/utils/seq2seq_model.py index 0747bbb2..4660de84 100644 --- a/qdax/utils/seq2seq_model.py +++ b/qdax/utils/seq2seq_model.py @@ -1,21 +1,12 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""seq2seq example: Mode code.""" - -# See issue #620. -# pytype: disable=wrong-keyword-args +"""seq2seq example: Mode code. + +Inspired by Flax library - +https://github.com/google/flax/blob/main/examples/seq2seq/models.py + +Copyright 2022 The Flax Authors. +Licensed under the Apache License, Version 2.0 (the "License") +""" + import functools from typing import Any, Tuple @@ -47,7 +38,7 @@ def __call__( lstm_state, is_eos = carry new_lstm_state, y = nn.LSTMCell()(lstm_state, x) - def select_carried_state(new_state, old_state): + def select_carried_state(new_state: Array, old_state: Array) -> Array: return jnp.where(is_eos[:, np.newaxis], old_state, new_state) # LSTM state is a tuple (c, h). @@ -59,9 +50,9 @@ def select_carried_state(new_state, old_state): return (carried_lstm_state, is_eos), y @staticmethod - def initialize_carry(batch_size: int, hidden_size: int): + def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]: # Use a dummy key since the default state init fn is just zeros. - return nn.LSTMCell.initialize_carry( + return nn.LSTMCell.initialize_carry( # type: ignore jax.random.PRNGKey(0), (batch_size,), hidden_size ) @@ -72,16 +63,17 @@ class Encoder(nn.Module): hidden_size: int @nn.compact - def __call__(self, inputs: Array): - # inputs.shape = (batch_size, seq_length, vocab_size). + def __call__(self, inputs: Array) -> Array: batch_size = inputs.shape[0] lstm = EncoderLSTM(name="encoder_lstm") init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size) + # We use the `is_eos` array to determine whether the encoder should carry # over the last lstm state, or apply the LSTM cell on the previous state. 
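+        # Concretely, the scan carry is (lstm_state, is_eos): for any sample whose
+        # is_eos flag is set, select_carried_state keeps the previous (c, h) pair,
+        # so that sample's encoding stays frozen at its last valid step.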
init_is_eos = jnp.zeros(batch_size, dtype=bool) init_carry = (init_lstm_state, init_is_eos) (final_state, _), _ = lstm(init_carry, inputs) + return final_state @@ -106,6 +98,7 @@ class DecoderLSTM(nn.Module): @nn.compact def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array: """Applies the DecoderLSTM model.""" + lstm_state, last_prediction = carry if not self.teacher_force: x = last_prediction @@ -165,7 +158,7 @@ class Seq2seq(nn.Module): hidden_size: int obs_size: int - def setup(self): + def setup(self) -> None: self.encoder = Encoder(hidden_size=self.hidden_size) self.decoder = Decoder(teacher_force=self.teacher_force, obs_size=self.obs_size) @@ -191,22 +184,16 @@ def __call__( containing respectively decoded logits and predictions (in one hot encoding format). """ - # Encode inputs. - # print(encoder_inputs) + # encode inputs init_decoder_state = self.encoder(encoder_inputs) - # print(init_decoder_state) - # Encoder(hidden_size=self.hidden_size) - # Decode outputs. + + # decode outputs logits, predictions = self.decoder(decoder_inputs, init_decoder_state) - # Decoder( - # init_state=init_decoder_state, - # teacher_force=self.teacher_force, - # obs_size=self.obs_size)(decoder_inputs[:, :-1],init_decoder_state) return logits, predictions - def encode(self, encoder_inputs: Array): - # Encode inputs. + def encode(self, encoder_inputs: Array) -> Array: + # encode inputs init_decoder_state = self.encoder(encoder_inputs) - final_output, hidden_state = init_decoder_state + final_output, _hidden_state = init_decoder_state return final_output diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index e2481b18..12c14c41 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -1,67 +1,37 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""seq2seq addition example.""" - -# See issue #620. -# pytype: disable=wrong-keyword-args +"""seq2seq addition example + +Inspired by Flax library - +https://github.com/google/flax/blob/main/examples/seq2seq/train.py + +Copyright 2022 The Flax Authors. 
+Licensed under the Apache License, Version 2.0 (the "License") +""" from typing import Any, Dict, Tuple import jax import jax.numpy as jnp import optax -from absl import flags from flax.training import train_state +from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire +from qdax.types import Observation, Params, RNGKey from qdax.utils.seq2seq_model import Seq2seq Array = Any -FLAGS = flags.FLAGS PRNGKey = Any -flags.DEFINE_string("workdir", default=".", help="Where to store log output.") - -flags.DEFINE_float( - "learning_rate", default=0.003, help=("The learning rate for the Adam optimizer.") -) - -flags.DEFINE_integer("batch_size", default=128, help=("Batch size for training.")) - -flags.DEFINE_integer("hidden_size", default=16, help=("Hidden size of the LSTM.")) - -flags.DEFINE_integer("num_train_steps", default=10000, help=("Number of train steps.")) - -flags.DEFINE_integer( - "decode_frequency", - default=200, - help=("Frequency of decoding during training, e.g. every 1000 steps."), -) - -flags.DEFINE_integer( - "max_len_query_digit", default=3, help=("Maximum length of a single input digit.") -) - -def get_model(obs_size, teacher_force: bool = False, hidden_size=10) -> Seq2seq: +def get_model( + obs_size: int, teacher_force: bool = False, hidden_size: int = 10 +) -> Seq2seq: return Seq2seq( teacher_force=teacher_force, hidden_size=hidden_size, obs_size=obs_size ) def get_initial_params( - model: Seq2seq, rng: PRNGKey, encoder_input_shape + model: Seq2seq, rng: PRNGKey, encoder_input_shape: Tuple[int, ...] ) -> Dict[str, Any]: """Returns the initial parameters of a seq2seq model.""" rng1, rng2, rng3 = jax.random.split(rng, 3) @@ -70,7 +40,7 @@ def get_initial_params( jnp.ones(encoder_input_shape, jnp.float32), jnp.ones(encoder_input_shape, jnp.float32), ) - return variables["params"] + return variables["params"] # type: ignore @jax.jit @@ -80,12 +50,14 @@ def train_step( """Trains one step.""" lstm_key = jax.random.fold_in(lstm_rng, state.step) dropout_key, lstm_key = jax.random.split(lstm_key, 2) - # Shift Input by One to avoid leakage + + # Shift input by one to avoid leakage batch_decoder = jnp.roll(batch, shift=1, axis=1) - ### Large number as zero token + + # Large number as zero token batch_decoder = batch_decoder.at[:, 0, :].set(-1000) - def loss_fn(params): + def loss_fn(params: Params) -> Tuple[jnp.ndarray, jnp.ndarray]: logits, _ = state.apply_fn( {"params": params}, batch, @@ -93,14 +65,9 @@ def loss_fn(params): rngs={"lstm": lstm_key, "dropout": dropout_key}, ) - def squared_error(x, y): - return jnp.inner(y - x, y - x) / 2.0 - - def mean_squared_error(x, y): + def mean_squared_error(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: return jnp.inner(y - x, y - x) / x.shape[-1] - # res = jax.vmap(squared_error)(logits, batch) - # res = jax.vmap(squared_error)(jnp.reshape(logits,(logits.shape[0],-1)),jnp.reshape(batch,(batch.shape[0],-1))) res = jax.vmap(mean_squared_error)( jnp.reshape(logits.at[:, :-1, ...].get(), (logits.shape[0], -1)), jnp.reshape( @@ -111,24 +78,33 @@ def mean_squared_error(x, y): return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (loss_val, logits), grads = grad_fn(state.params) + (loss_val, _logits), grads = grad_fn(state.params) state = state.apply_gradients(grads=grads) return state, loss_val -def lstm_ae_train(key, repertoire, params, epoch, hidden_size=10): +def lstm_ae_train( + key: RNGKey, + repertoire: UnstructuredRepertoire, + params: Params, + epoch: int, + hidden_size: int = 10, +) -> 
Tuple[Params, Observation, Observation]: batch_size = 128 # 2048 if epoch > 100: num_epochs = 25 - alpha = 0.0001 # Gradient step size + + # Gradient step size + alpha = 0.0001 else: num_epochs = 100 - alpha = 0.0001 # Gradient step size + + # Gradient step size + alpha = 0.0001 rng, key, key_selection = jax.random.split(key, 3) - dimensions_data = jnp.prod(jnp.asarray(repertoire.observations.shape[1:])) # get the model used (seq2seq) model = get_model( @@ -191,7 +167,6 @@ def lstm_ae_train(key, repertoire, params, epoch, hidden_size=10): print("Valid indexes: ", valid_indexes) # Normalising Dataset - # training_dataset = (repertoire.observations.at[valid_indexes].get()-mean_obs)/std_obs #jnp.where(std_obs==0,mean_obs,std_obs) steps_per_epoch = repertoire.observations.shape[0] // batch_size loss_val = 0.0 @@ -208,7 +183,7 @@ def lstm_ae_train(key, repertoire, params, epoch, hidden_size=10): # create dataset with the observation from the sample of valid indexes training_dataset = ( repertoire.observations.at[valid_indexes, ...].get() - mean_obs - ) / std_obs # jnp.where(std_obs==0,mean_obs,std_obs) + ) / std_obs training_dataset = training_dataset.at[valid_indexes].get() if epoch == 0: @@ -227,14 +202,13 @@ def lstm_ae_train(key, repertoire, params, epoch, hidden_size=10): continue state, loss_val = train_step(state, batch, rng) - ### To see the actual value we cannot jit this function (i.e. the _one_es_epoch function nor the train function) + # To see the actual value we cannot jit this function (i.e. the _one_es_epoch + # function nor the train function) print("Eval epoch: {}, loss: {:.4f}".format(epoch + 1, loss_val)) # TODO: put this in metrics so we can jit the function and see the metrics # TODO: not urgent because the training is not that long - # return repertoire.replace(ae_params=state.params,mean_obs=mean_obs,std_obs=std_obs) - train_step.clear_cache() del tx del model diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 519007eb..7b6e10e8 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -167,7 +167,7 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: ) # Init algorithm - ## AutoEncoder Params and INIT + # AutoEncoder Params and INIT # observations_dims = (20, 25) obs_dim = jnp.minimum(env.observation_size, 25) if observation_option == "full": @@ -191,7 +191,7 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: model_params = train_seq2seq.get_initial_params( model, subkey, (1, observations_dims[0], observations_dims[-1]) ) - # model_params = train_seq2seq.get_initial_params(model,subkey,(1,repertoire.observations.shape[1],repertoire.observations.shape[-1])) + print(jax.tree_map(lambda x: x.shape, model_params)) mean_observations = jnp.zeros(observations_dims[-1]) @@ -208,7 +208,7 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: l_value_init, ) - ## Initializing Means and stds and Aurora + # Initializing Means and stds and Aurora random_key, subkey = jax.random.split(random_key) model_params, mean_observations, std_observations = train_seq2seq.lstm_ae_train( subkey, repertoire, model_params, 0, hidden_size=hidden_size @@ -240,7 +240,7 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: # update nb steps estimation current_step_estimation += batch_size * episode_length * log_freq - ## Autoencoder Steps and CVC + # Autoencoder Steps and CVC # individuals_in_repo = jnp.sum(repertoire.fitnesses != -jnp.inf) if (iteration + 1) in schedules: @@ -257,12 +257,13 @@ def update_scan_fn(carry: Any, unused: Any) 
-> Any: iteration, hidden_size=hidden_size, ) - ### RE-ADDITION OF ALL THE NEW BEHAVIOURAL DESCRIPTORS WITH THE NEW AE - # model = train_seq2seq.get_model(repertoire.observations.shape[-1],True) ## lstm seq2seq + # re-addition of all the new behavioural descriotpors with the new ae + normalized_observations = ( repertoire.observations - mean_observations ) / std_observations + new_descriptors = model.apply( {"params": model_params}, normalized_observations, method=model.encode ) @@ -280,7 +281,6 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) - # l_value = repertoire.l_value * (1+1*10e-7*(num_indivs-n_target)) current_error = num_indivs - n_target change_rate = current_error - previous_error prop_gain = 1 * 10e-6 @@ -291,7 +291,8 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: ) print(change_rate, current_error) previous_error = current_error - ## CVC Implementation to keep a Constant number of individuals in the Archive + + # CVC Implementation to keep a Constant number of individuals in the Archive repertoire = repertoire.init( genotypes=repertoire.genotypes, centroids=repertoire.centroids, @@ -300,7 +301,6 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: observations=repertoire.observations, l_value=l_value, ) - new_num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) pytest.assume(repertoire is not None) From 27def80b6dd1a655d4198b8b72c65235530ba832 Mon Sep 17 00:00:00 2001 From: Felix Date: Tue, 28 Feb 2023 15:48:15 +0200 Subject: [PATCH 04/26] WIP - example notebook --- examples/aurora.ipynb | 540 +++++++++++++++++++++++++++++ qdax/environments/bd_extractors.py | 12 +- tests/core_test/aurora_test.py | 74 ++-- 3 files changed, 588 insertions(+), 38 deletions(-) create mode 100644 examples/aurora.ipynb diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb new file mode 100644 index 00000000..308fe504 --- /dev/null +++ b/examples/aurora.ipynb @@ -0,0 +1,540 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Optimizing with AURORA in Jax\n", + "\n", + "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [AURORA](https://arxiv.org/pdf/1905.11874.pdf).\n", + "It can be run locally or on Google Colab. We recommand to use a GPU. 
This notebook will show:\n", + "\n", + "- how to define the problem\n", + "- how to create an emitter\n", + "- how to create an AURORA instance\n", + "- which functions must be defined before training\n", + "- how to launch a certain number of training steps\n", + "- how to visualise the optimization process\n", + "- how to save/load a repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Installs and Imports\n", + "!pip install ipympl |tail -n 1\n", + "# %matplotlib widget\n", + "# from google.colab import output\n", + "# output.enable_custom_widget_manager()\n", + "\n", + "import os\n", + "\n", + "from IPython.display import clear_output\n", + "import functools\n", + "import time\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", + "\n", + "from qdax.core.aurora import AURORA\n", + "from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire\n", + "from qdax import environments\n", + "from qdax.tasks.brax_envs import scoring_aurora_function\n", + "from qdax.environments.bd_extractors import get_aurora_bd\n", + "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", + "from qdax.core.neuroevolution.networks.networks import MLP\n", + "from qdax.core.emitters.mutation_operators import isoline_variation\n", + "from qdax.core.emitters.standard_emitters import MixingEmitter\n", + "from qdax.utils.plotting import plot_map_elites_results\n", + "\n", + "from qdax.types import EnvState, Params, RNGKey\n", + "from qdax.utils import train_seq2seq\n", + "\n", + "from qdax.utils.metrics import CSVLogger, default_qd_metrics\n", + "\n", + "from jax.flatten_util import ravel_pytree\n", + "\n", + "from IPython.display import HTML\n", + "from brax.io import html\n", + "\n", + "\n", + "\n", + "if \"COLAB_TPU_ADDR\" in os.environ:\n", + " from jax.tools import colab_tpu\n", + " colab_tpu.setup_tpu()\n", + "\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title QD Training Definitions Fields\n", + "#@markdown ---\n", + "batch_size = 100 #@param {type:\"number\"}\n", + "env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", + "episode_length = 100 #@param {type:\"integer\"}\n", + "num_iterations = 1000 #@param {type:\"integer\"}\n", + "seed = 42 #@param {type:\"integer\"}\n", + "policy_hidden_layer_sizes = (64, 64) #@param {type:\"raw\"}\n", + "iso_sigma = 0.005 #@param {type:\"number\"}\n", + "line_sigma = 0.05 #@param {type:\"number\"}\n", + "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", + "num_centroids = 1024 #@param {type:\"integer\"}\n", + "min_bd = 0. #@param {type:\"number\"}\n", + "max_bd = 1.0 #@param {type:\"number\"}\n", + "#@markdown ---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Init environment, policy, population params, init states of the env\n", + "\n", + "Define the environment in which the policies will be trained. 
In this notebook, we focus on controllers learning to move a robot in a physical simulation. We also define the shared policy, that every individual in the population will use. Once the policy is defined, all individuals are defined by their parameters, that corresponds to their genotype." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Init environment\n", + "env = environments.create(env_name, episode_length=episode_length)\n", + "\n", + "# Init a random key\n", + "random_key = jax.random.PRNGKey(seed)\n", + "\n", + "# Init policy network\n", + "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", + "policy_network = MLP(\n", + " layer_sizes=policy_layer_sizes,\n", + " kernel_init=jax.nn.initializers.lecun_uniform(),\n", + " final_activation=jnp.tanh,\n", + ")\n", + "\n", + "# Init population of controllers\n", + "random_key, subkey = jax.random.split(random_key)\n", + "keys = jax.random.split(subkey, num=batch_size)\n", + "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n", + "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", + "\n", + "\n", + "# Create the initial environment states\n", + "random_key, subkey = jax.random.split(random_key)\n", + "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)\n", + "reset_fn = jax.jit(jax.vmap(env.reset))\n", + "init_states = reset_fn(keys)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the way the policy interacts with the env\n", + "\n", + "Now that the environment and policy has been defined, it is necessary to define a function that describes how the policy must be used to interact with the environment and to store transition data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the fonction to play a step with the policy in the environment\n", + "def play_step_fn(\n", + " env_state,\n", + " policy_params,\n", + " random_key,\n", + "):\n", + " \"\"\"\n", + " Play an environment step and return the updated state and the transition.\n", + " \"\"\"\n", + "\n", + " actions = policy_network.apply(policy_params, env_state.obs)\n", + " \n", + " state_desc = env_state.info[\"state_descriptor\"]\n", + " next_state = env.step(env_state, actions)\n", + "\n", + " transition = QDTransition(\n", + " obs=env_state.obs,\n", + " next_obs=next_state.obs,\n", + " rewards=next_state.reward,\n", + " dones=next_state.done,\n", + " actions=actions,\n", + " truncations=next_state.info[\"truncation\"],\n", + " state_desc=state_desc,\n", + " next_state_desc=next_state.info[\"state_descriptor\"],\n", + " )\n", + "\n", + " return next_state, policy_params, random_key, transition" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the scoring function and the way metrics are computed\n", + "\n", + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare the scoring function\n", + "bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]\n", + "scoring_fn = functools.partial(\n", + " scoring_function,\n", + " init_states=init_states,\n", + " episode_length=episode_length,\n", + " play_step_fn=play_step_fn,\n", + " behavior_descriptor_extractor=bd_extraction_fn,\n", + ")\n", + "\n", + "# Get minimum reward value to make sure qd_score are positive\n", + "reward_offset = environments.reward_offset[env_name]\n", + "\n", + "# Define a metrics function\n", + "metrics_function = functools.partial(\n", + " default_qd_metrics,\n", + " qd_offset=reward_offset * episode_length,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the emitter\n", + "\n", + "The emitter is used to evolve the population at each mutation step." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define emitter\n", + "variation_fn = functools.partial(\n", + " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", + ")\n", + "mixing_emitter = MixingEmitter(\n", + " mutation_fn=None, \n", + " variation_fn=variation_fn, \n", + " variation_percentage=1.0, \n", + " batch_size=batch_size\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiate and initialise the MAP Elites algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate MAP-Elites\n", + "map_elites = MAPElites(\n", + " scoring_function=scoring_fn,\n", + " emitter=mixing_emitter,\n", + " metrics_function=metrics_function,\n", + ")\n", + "\n", + "# Compute the centroids\n", + "centroids, random_key = compute_cvt_centroids(\n", + " num_descriptors=env.behavior_descriptor_length,\n", + " num_init_cvt_samples=num_init_cvt_samples,\n", + " num_centroids=num_centroids,\n", + " minval=min_bd,\n", + " maxval=max_bd,\n", + " random_key=random_key,\n", + ")\n", + "\n", + "# Compute initial repertoire and emitter state\n", + "repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Launch MAP-Elites iterations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "log_period = 10\n", + "num_loops = int(num_iterations / log_period)\n", + "\n", + "csv_logger = CSVLogger(\n", + " \"mapelites-logs.csv\",\n", + " header=[\"loop\", \"iteration\", \"qd_score\", \"max_fitness\", \"coverage\", \"time\"]\n", + ")\n", + "all_metrics = {}\n", + "\n", + "# main loop\n", + "map_elites_scan_update = map_elites.scan_update\n", + "for i in range(num_loops):\n", + " start_time = time.time()\n", + " # main iterations\n", + " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", + " map_elites_scan_update,\n", + " (repertoire, emitter_state, random_key),\n", + " (),\n", + " length=log_period,\n", + " )\n", + " timelapse = time.time() - start_time\n", + "\n", + " # log metrics\n", + " logged_metrics = {\"time\": timelapse, \"loop\": 1+i, \"iteration\": 1 + i*log_period}\n", + " for key, value in metrics.items():\n", + " # take last value\n", + " logged_metrics[key] = value[-1]\n", + "\n", + " # take all values\n", + " if key in all_metrics.keys():\n", + " all_metrics[key] = 
jnp.concatenate([all_metrics[key], value])\n", + " else:\n", + " all_metrics[key] = value\n", + "\n", + " csv_logger.log(logged_metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Visualization\n", + "\n", + "# create the x-axis array\n", + "env_steps = jnp.arange(num_iterations) * episode_length * batch_size\n", + "\n", + "# create the plots and the grid\n", + "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to save/load a repertoire\n", + "\n", + "The following cells show how to save or load a repertoire of individuals and add a few lines to visualise the best performing individual in a simulation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the final repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "repertoire_path = \"./last_repertoire/\"\n", + "os.makedirs(repertoire_path, exist_ok=True)\n", + "repertoire.save(path=repertoire_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build the reconstruction function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Init population of policies\n", + "random_key, subkey = jax.random.split(random_key)\n", + "fake_batch = jnp.zeros(shape=(env.observation_size,))\n", + "fake_params = policy_network.init(subkey, fake_batch)\n", + "\n", + "_, reconstruction_fn = ravel_pytree(fake_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use the reconstruction function to load and re-create the repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "repertoire = MapElitesRepertoire.load(reconstruction_fn=reconstruction_fn, path=repertoire_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get the best individual of the repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "best_idx = jnp.argmax(repertoire.fitnesses)\n", + "best_fitness = jnp.max(repertoire.fitnesses)\n", + "best_bd = repertoire.descriptors[best_idx]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " f\"Best fitness in the repertoire: {best_fitness:.2f}\\n\",\n", + " f\"Behavior descriptor of the best individual in the repertoire: {best_bd}\\n\",\n", + " f\"Index in the repertoire of this individual: {best_idx}\\n\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_params = jax.tree_util.tree_map(\n", + " lambda x: x[best_idx],\n", + " repertoire.genotypes\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Play some steps in the environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "jit_env_reset = jax.jit(env.reset)\n", + "jit_env_step = jax.jit(env.step)\n", + "jit_inference_fn = jax.jit(policy_network.apply)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rollout = []\n", + "rng = 
jax.random.PRNGKey(seed=1)\n", + "state = jit_env_reset(rng=rng)\n", + "while not state.done:\n", + " rollout.append(state)\n", + " action = jit_inference_fn(my_params, state.obs)\n", + " state = jit_env_step(state, action)\n", + "\n", + "print(f\"The trajectory of this individual contains {len(rollout)} transitions.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "HTML(html.render(env.sys, [s.qp for s in rollout[:500]]))" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index b41e6ace..a68abf2c 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -47,7 +47,8 @@ def get_aurora_bd( std_observations: jnp.ndarray, option: str = "full", hidden_size: int = 10, - padding: bool = False, + traj_sampling_freq: int = 10, + max_observation_size: int = 25, ) -> Descriptor: """Compute final aurora embedding. @@ -57,11 +58,10 @@ def get_aurora_bd( # reshape mask for bd extraction mask = jnp.expand_dims(mask, axis=-1) - state_obs = data.obs[:, ::10, :25] - filtered_mask = mask[:, ::10, :] + state_obs = data.obs[:, ::traj_sampling_freq, :max_observation_size] # add the x/y position - (batch_size, traj_length, 2) - state_desc = data.state_desc[:, ::10] + state_desc = data.state_desc[:, ::traj_sampling_freq] print("State Observations: ", state_obs) print("XY positions: ", state_desc) @@ -74,10 +74,6 @@ def get_aurora_bd( elif option == "only_sd": observations = state_desc - # add padding when the episode is done - if padding: - observations = jnp.where(filtered_mask, x=jnp.array(0.0), y=observations) - # lstm seq2seq model = train_seq2seq.get_model(observations.shape[-1], True, hidden_size) normalized_observations = (observations - mean_observations) / std_observations diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 7b6e10e8..e40b27bd 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -9,7 +9,7 @@ from qdax import environments from qdax.core.aurora import AURORA -from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire +from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.emitters.mutation_operators import isoline_variation from qdax.core.emitters.standard_emitters import MixingEmitter from qdax.core.neuroevolution.buffers.buffer import QDTransition @@ -27,7 +27,7 @@ def test_aurora(env_name: str, batch_size: int) -> None: batch_size = batch_size env_name = env_name - episode_length = 100 + episode_length = 250 num_iterations = 5 seed = 42 policy_hidden_layer_sizes = (64, 64) @@ -37,6 +37,10 @@ def test_aurora(env_name: str, batch_size: int) -> None: hidden_size = 5 l_value_init = 0.2 + traj_sampling_freq = 10 + max_observation_size = 25 + prior_descriptor_dim = 2 + log_freq = 5 # Init environment @@ -98,6 +102,8 @@ def play_step_fn( get_aurora_bd, option=observation_option, hidden_size=hidden_size, + 
traj_sampling_freq=traj_sampling_freq, + max_observation_size=max_observation_size, ) scoring_fn = functools.partial( scoring_aurora_function, @@ -120,7 +126,7 @@ def play_step_fn( reward_offset = environments.reward_offset[env_name] # Define a metrics function - def metrics_fn(repertoire: MapElitesRepertoire) -> Dict: + def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict: # Get metrics grid_empty = repertoire.fitnesses == -jnp.inf @@ -144,7 +150,7 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict: @jax.jit def update_scan_fn(carry: Any, unused: Any) -> Any: - # iterate over grid + """Scan the udpate function.""" ( repertoire, random_key, @@ -152,6 +158,8 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: mean_observations, std_observations, ) = carry + + # update (repertoire, _, metrics, random_key,) = aurora.update( repertoire, None, @@ -168,36 +176,40 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: # Init algorithm # AutoEncoder Params and INIT - # observations_dims = (20, 25) - obs_dim = jnp.minimum(env.observation_size, 25) + obs_dim = jnp.minimum(env.observation_size, max_observation_size) if observation_option == "full": - observations_dims = (25, obs_dim + 2) # 250 / 10, 25 + 2 - if observation_option == "no_sd": - observations_dims = (25, obs_dim) # 250 / 10, 25 - if observation_option == "only_sd": - observations_dims = (25, 2) # 250 / 10, 2 + observations_dims = ( + episode_length // traj_sampling_freq, + obs_dim + prior_descriptor_dim, + ) + elif observation_option == "no_sd": + observations_dims = ( + episode_length // traj_sampling_freq, + obs_dim, + ) + elif observation_option == "only_sd": + observations_dims = (episode_length // traj_sampling_freq, prior_descriptor_dim) + else: + ValueError("The chosen option is not correct.") + # define the seq2seq model model = train_seq2seq.get_model( observations_dims[-1], True, hidden_size=hidden_size ) - random_key, subkey = jax.random.split(random_key) - - # design aurora's schedule - default_update_base = 10 - update_base = int(jnp.ceil(default_update_base / log_freq)) - schedules = jnp.cumsum(jnp.arange(update_base, 1000, update_base)) - print("Schedules: ", schedules) + # init the model params + random_key, subkey = jax.random.split(random_key) model_params = train_seq2seq.get_initial_params( - model, subkey, (1, observations_dims[0], observations_dims[-1]) + model, subkey, (1, *observations_dims) ) print(jax.tree_map(lambda x: x.shape, model_params)) + # define arbitrary observation's mean/std mean_observations = jnp.zeros(observations_dims[-1]) - std_observations = jnp.ones(observations_dims[-1]) + # init step of the aurora algorithm repertoire, _, random_key = aurora.init( init_variables, centroids, @@ -208,12 +220,17 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: l_value_init, ) - # Initializing Means and stds and Aurora + # initializing means and stds and AURORA random_key, subkey = jax.random.split(random_key) model_params, mean_observations, std_observations = train_seq2seq.lstm_ae_train( subkey, repertoire, model_params, 0, hidden_size=hidden_size ) + # design aurora's schedule + default_update_base = 10 + update_base = int(jnp.ceil(default_update_base / log_freq)) + schedules = jnp.cumsum(jnp.arange(update_base, 1000, update_base)) + current_step_estimation = 0 num_iterations = 0 @@ -222,7 +239,7 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target - iteration = 1 # to be consistent with other exp scripts + 
iteration = 0 while iteration < num_iterations: ( @@ -240,12 +257,10 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: # update nb steps estimation current_step_estimation += batch_size * episode_length * log_freq - # Autoencoder Steps and CVC - # individuals_in_repo = jnp.sum(repertoire.fitnesses != -jnp.inf) - + # autoencoder steps and CVC if (iteration + 1) in schedules: + # train the autoencoder random_key, subkey = jax.random.split(random_key) - ( model_params, mean_observations, @@ -259,7 +274,6 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: ) # re-addition of all the new behavioural descriotpors with the new ae - normalized_observations = ( repertoire.observations - mean_observations ) / std_observations @@ -278,9 +292,10 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) elif iteration % 2 == 0: - + # update the l value num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) + # CVC Implementation to keep a constant number of individuals in the archive current_error = num_indivs - n_target change_rate = current_error - previous_error prop_gain = 1 * 10e-6 @@ -289,10 +304,9 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: + (prop_gain * (current_error)) + (prop_gain * change_rate) ) - print(change_rate, current_error) + previous_error = current_error - # CVC Implementation to keep a Constant number of individuals in the Archive repertoire = repertoire.init( genotypes=repertoire.genotypes, centroids=repertoire.centroids, From 86a19a0220ff8b780d27b1fbe65fe55f029f1a9c Mon Sep 17 00:00:00 2001 From: Felix Date: Tue, 28 Feb 2023 16:55:05 +0200 Subject: [PATCH 05/26] fix hyperparams issue - clean example notebook --- examples/aurora.ipynb | 403 ++++++++++++++++++++------------- qdax/utils/train_seq2seq.py | 6 +- tests/core_test/aurora_test.py | 23 +- 3 files changed, 266 insertions(+), 166 deletions(-) diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index 308fe504..a784f676 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/aurora.ipynb)" ] }, { @@ -41,7 +41,7 @@ "\n", "from IPython.display import clear_output\n", "import functools\n", - "import time\n", + "from typing import Dict, Any\n", "\n", "import jax\n", "import jax.numpy as jnp\n", @@ -68,19 +68,9 @@ "from qdax.core.neuroevolution.networks.networks import MLP\n", "from qdax.core.emitters.mutation_operators import isoline_variation\n", "from qdax.core.emitters.standard_emitters import MixingEmitter\n", - "from qdax.utils.plotting import plot_map_elites_results\n", "\n", - "from qdax.types import EnvState, Params, RNGKey\n", "from qdax.utils import train_seq2seq\n", "\n", - "from qdax.utils.metrics import CSVLogger, default_qd_metrics\n", - "\n", - "from jax.flatten_util import ravel_pytree\n", - "\n", - "from IPython.display import HTML\n", - "from brax.io import html\n", - "\n", - "\n", "\n", "if \"COLAB_TPU_ADDR\" in os.environ:\n", " from jax.tools import colab_tpu\n", @@ -98,18 +88,30 @@ "source": [ "#@title QD Training Definitions Fields\n", "#@markdown ---\n", - "batch_size = 100 #@param 
{type:\"number\"}\n", + "batch_size = 10 #@param {type:\"number\"}\n", "env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", - "episode_length = 100 #@param {type:\"integer\"}\n", - "num_iterations = 1000 #@param {type:\"integer\"}\n", + "episode_length = 250 #@param {type:\"integer\"}\n", + "max_iterations = 50 #@param {type:\"integer\"}\n", "seed = 42 #@param {type:\"integer\"}\n", "policy_hidden_layer_sizes = (64, 64) #@param {type:\"raw\"}\n", "iso_sigma = 0.005 #@param {type:\"number\"}\n", "line_sigma = 0.05 #@param {type:\"number\"}\n", "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", - "num_centroids = 1024 #@param {type:\"integer\"}\n", + "num_centroids = 50 #@param {type:\"integer\"}\n", "min_bd = 0. #@param {type:\"number\"}\n", "max_bd = 1.0 #@param {type:\"number\"}\n", + "\n", + "batch_size = 128 #@param {type:\"integer\"}\n", + "\n", + "observation_option = \"no_sd\" #@param['no_sd', 'only_sd', 'full']\n", + "hidden_size = 5 #@param {type:\"integer\"}\n", + "l_value_init = 0.2 #@param {type:\"number\"}\n", + "\n", + "traj_sampling_freq = 10 #@param {type:\"integer\"}\n", + "max_observation_size = 25 #@param {type:\"integer\"}\n", + "prior_descriptor_dim = 2 #@param {type:\"integer\"}\n", + "\n", + "log_freq = 5 #@param {type:\"integer\"}\n", "#@markdown ---" ] }, @@ -216,9 +218,15 @@ "outputs": [], "source": [ "# Prepare the scoring function\n", - "bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]\n", + "bd_extraction_fn = functools.partial(\n", + " get_aurora_bd,\n", + " option=observation_option,\n", + " hidden_size=hidden_size,\n", + " traj_sampling_freq=traj_sampling_freq,\n", + " max_observation_size=max_observation_size,\n", + ")\n", "scoring_fn = functools.partial(\n", - " scoring_function,\n", + " scoring_aurora_function,\n", " init_states=init_states,\n", " episode_length=episode_length,\n", " play_step_fn=play_step_fn,\n", @@ -229,10 +237,17 @@ "reward_offset = environments.reward_offset[env_name]\n", "\n", "# Define a metrics function\n", - "metrics_function = functools.partial(\n", - " default_qd_metrics,\n", - " qd_offset=reward_offset * episode_length,\n", - ")" + "def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict:\n", + "\n", + " # Get metrics\n", + " grid_empty = repertoire.fitnesses == -jnp.inf\n", + " qd_score = jnp.sum(repertoire.fitnesses, where=~grid_empty)\n", + " # Add offset for positive qd_score\n", + " qd_score += reward_offset * episode_length * jnp.sum(1.0 - grid_empty)\n", + " coverage = 100 * jnp.mean(1.0 - grid_empty)\n", + " max_fitness = jnp.max(repertoire.fitnesses)\n", + "\n", + " return {\"qd_score\": qd_score, \"max_fitness\": max_fitness, \"coverage\": coverage}\n" ] }, { @@ -255,10 +270,10 @@ " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", ")\n", "mixing_emitter = MixingEmitter(\n", - " mutation_fn=None, \n", - " variation_fn=variation_fn, \n", - " variation_percentage=1.0, \n", - " batch_size=batch_size\n", + " mutation_fn=lambda x, y: (x, y),\n", + " variation_fn=variation_fn,\n", + " variation_percentage=1.0,\n", + " batch_size=batch_size,\n", ")" ] }, @@ -275,32 +290,105 @@ "metadata": {}, "outputs": [], "source": [ - "# Instantiate MAP-Elites\n", - "map_elites = MAPElites(\n", + "# Instantiate AURORA\n", + "aurora = AURORA(\n", " scoring_function=scoring_fn,\n", " emitter=mixing_emitter,\n", - " metrics_function=metrics_function,\n", + " metrics_function=metrics_fn,\n", ")\n", "\n", 
- "# Compute the centroids\n", - "centroids, random_key = compute_cvt_centroids(\n", - " num_descriptors=env.behavior_descriptor_length,\n", - " num_init_cvt_samples=num_init_cvt_samples,\n", - " num_centroids=num_centroids,\n", - " minval=min_bd,\n", - " maxval=max_bd,\n", - " random_key=random_key,\n", + "aurora_dims = hidden_size\n", + "centroids = jnp.zeros(shape=(num_centroids, aurora_dims))\n", + "\n", + "@jax.jit\n", + "def update_scan_fn(carry: Any, unused: Any) -> Any:\n", + " \"\"\"Scan the udpate function.\"\"\"\n", + " (\n", + " repertoire,\n", + " random_key,\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " ) = carry\n", + "\n", + " # update\n", + " (repertoire, _, metrics, random_key,) = aurora.update(\n", + " repertoire,\n", + " None,\n", + " random_key,\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " )\n", + "\n", + " return (\n", + " (repertoire, random_key, model_params, mean_observations, std_observations),\n", + " metrics,\n", + " )\n", + "\n", + "# Init algorithm\n", + "# AutoEncoder Params and INIT\n", + "obs_dim = jnp.minimum(env.observation_size, max_observation_size)\n", + "if observation_option == \"full\":\n", + " observations_dims = (\n", + " episode_length // traj_sampling_freq,\n", + " obs_dim + prior_descriptor_dim,\n", + " )\n", + "elif observation_option == \"no_sd\":\n", + " observations_dims = (\n", + " episode_length // traj_sampling_freq,\n", + " obs_dim,\n", + " )\n", + "elif observation_option == \"only_sd\":\n", + " observations_dims = (episode_length // traj_sampling_freq, prior_descriptor_dim)\n", + "else:\n", + " ValueError(\"The chosen option is not correct.\")\n", + "\n", + "# define the seq2seq model\n", + "model = train_seq2seq.get_model(\n", + " observations_dims[-1], True, hidden_size=hidden_size\n", ")\n", "\n", - "# Compute initial repertoire and emitter state\n", - "repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key)" + "# init the model params\n", + "random_key, subkey = jax.random.split(random_key)\n", + "model_params = train_seq2seq.get_initial_params(\n", + " model, subkey, (1, *observations_dims)\n", + ")\n", + "\n", + "print(jax.tree_map(lambda x: x.shape, model_params))\n", + "\n", + "# define arbitrary observation's mean/std\n", + "mean_observations = jnp.zeros(observations_dims[-1])\n", + "std_observations = jnp.ones(observations_dims[-1])\n", + "\n", + "# init step of the aurora algorithm\n", + "repertoire, _, random_key = aurora.init(\n", + " init_variables,\n", + " centroids,\n", + " random_key,\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " l_value_init,\n", + ")\n", + "\n", + "# initializing means and stds and AURORA\n", + "random_key, subkey = jax.random.split(random_key)\n", + "model_params, mean_observations, std_observations = train_seq2seq.lstm_ae_train(\n", + " subkey, repertoire, model_params, 0, hidden_size=hidden_size, batch_size=lstm_batch_size\n", + ")\n", + "\n", + "# design aurora's schedule\n", + "default_update_base = 10\n", + "update_base = int(jnp.ceil(default_update_base / log_freq))\n", + "schedules = jnp.cumsum(jnp.arange(update_base, 1000, update_base))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Launch MAP-Elites iterations" + "## Launch AURORA iterations" ] }, { @@ -309,41 +397,93 @@ "metadata": {}, "outputs": [], "source": [ - "log_period = 10\n", - "num_loops = int(num_iterations / log_period)\n", + "current_step_estimation = 0\n", 
+ "num_iterations = 0\n", "\n", - "csv_logger = CSVLogger(\n", - " \"mapelites-logs.csv\",\n", - " header=[\"loop\", \"iteration\", \"qd_score\", \"max_fitness\", \"coverage\", \"time\"]\n", - ")\n", - "all_metrics = {}\n", - "\n", - "# main loop\n", - "map_elites_scan_update = map_elites.scan_update\n", - "for i in range(num_loops):\n", - " start_time = time.time()\n", - " # main iterations\n", - " (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n", - " map_elites_scan_update,\n", - " (repertoire, emitter_state, random_key),\n", - " (),\n", - " length=log_period,\n", - " )\n", - " timelapse = time.time() - start_time\n", + "# Main loop\n", + "n_target = 1024\n", + "\n", + "previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target\n", "\n", - " # log metrics\n", - " logged_metrics = {\"time\": timelapse, \"loop\": 1+i, \"iteration\": 1 + i*log_period}\n", - " for key, value in metrics.items():\n", - " # take last value\n", - " logged_metrics[key] = value[-1]\n", + "iteration = 0\n", + "while iteration < max_iterations:\n", "\n", - " # take all values\n", - " if key in all_metrics.keys():\n", - " all_metrics[key] = jnp.concatenate([all_metrics[key], value])\n", - " else:\n", - " all_metrics[key] = value\n", + " (\n", + " (repertoire, random_key, model_params, mean_observations, std_observations),\n", + " metrics,\n", + " ) = jax.lax.scan(\n", + " update_scan_fn,\n", + " (repertoire, random_key, model_params, mean_observations, std_observations),\n", + " (),\n", + " length=log_freq,\n", + " )\n", "\n", - " csv_logger.log(logged_metrics)" + " num_iterations = iteration * log_freq\n", + "\n", + " # update nb steps estimation\n", + " current_step_estimation += batch_size * episode_length * log_freq\n", + "\n", + " # autoencoder steps and CVC\n", + " if (iteration + 1) in schedules:\n", + " # train the autoencoder\n", + " random_key, subkey = jax.random.split(random_key)\n", + " (\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " ) = train_seq2seq.lstm_ae_train(\n", + " subkey,\n", + " repertoire,\n", + " model_params,\n", + " iteration,\n", + " hidden_size=hidden_size,\n", + " batch_size=lstm_batch_size\n", + " )\n", + "\n", + " # re-addition of all the new behavioural descriotpors with the new ae\n", + " normalized_observations = (\n", + " repertoire.observations - mean_observations\n", + " ) / std_observations\n", + "\n", + " new_descriptors = model.apply(\n", + " {\"params\": model_params}, normalized_observations, method=model.encode\n", + " )\n", + " repertoire = repertoire.init(\n", + " genotypes=repertoire.genotypes,\n", + " centroids=repertoire.centroids,\n", + " fitnesses=repertoire.fitnesses,\n", + " descriptors=new_descriptors,\n", + " observations=repertoire.observations,\n", + " l_value=repertoire.l_value,\n", + " )\n", + " num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n", + "\n", + " elif iteration % 2 == 0:\n", + " # update the l value\n", + " num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n", + "\n", + " # CVC Implementation to keep a constant number of individuals in the archive\n", + " current_error = num_indivs - n_target\n", + " change_rate = current_error - previous_error\n", + " prop_gain = 1 * 10e-6\n", + " l_value = (\n", + " repertoire.l_value\n", + " + (prop_gain * (current_error))\n", + " + (prop_gain * change_rate)\n", + " )\n", + "\n", + " previous_error = current_error\n", + "\n", + " repertoire = repertoire.init(\n", + " genotypes=repertoire.genotypes,\n", + " 
centroids=repertoire.centroids,\n", + " fitnesses=repertoire.fitnesses,\n", + " descriptors=repertoire.descriptors,\n", + " observations=repertoire.observations,\n", + " l_value=l_value,\n", + " )\n", + "\n", + " iteration += 1" ] }, { @@ -352,165 +492,114 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Visualization\n", - "\n", - "# create the x-axis array\n", - "env_steps = jnp.arange(num_iterations) * episode_length * batch_size\n", - "\n", - "# create the plots and the grid\n", - "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" + "for k, v in metrics.items():\n", + " print(k, \" - \", v[-1])" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# How to save/load a repertoire\n", - "\n", - "The following cells show how to save or load a repertoire of individuals and add a few lines to visualise the best performing individual in a simulation." - ] + "outputs": [], + "source": [] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "## Load the final repertoire" - ] + "outputs": [], + "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "repertoire_path = \"./last_repertoire/\"\n", - "os.makedirs(repertoire_path, exist_ok=True)\n", - "repertoire.save(path=repertoire_path)" - ] + "source": [] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "## Build the reconstruction function" - ] + "outputs": [], + "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# Init population of policies\n", - "random_key, subkey = jax.random.split(random_key)\n", - "fake_batch = jnp.zeros(shape=(env.observation_size,))\n", - "fake_params = policy_network.init(subkey, fake_batch)\n", - "\n", - "_, reconstruction_fn = ravel_pytree(fake_params)" - ] + "source": [] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "## Use the reconstruction function to load and re-create the repertoire" - ] + "outputs": [], + "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "repertoire = MapElitesRepertoire.load(reconstruction_fn=reconstruction_fn, path=repertoire_path)" - ] + "source": [] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "## Get the best individual of the repertoire" - ] + "outputs": [], + "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "best_idx = jnp.argmax(repertoire.fitnesses)\n", - "best_fitness = jnp.max(repertoire.fitnesses)\n", - "best_bd = repertoire.descriptors[best_idx]" - ] + "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "print(\n", - " f\"Best fitness in the repertoire: {best_fitness:.2f}\\n\",\n", - " f\"Behavior descriptor of the best individual in the repertoire: {best_bd}\\n\",\n", - " f\"Index in the repertoire of this individual: {best_idx}\\n\"\n", - ")" - ] + "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "my_params = jax.tree_util.tree_map(\n", - " lambda x: x[best_idx],\n", - " repertoire.genotypes\n", - ")" - ] + 
"source": [] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "## Play some steps in the environment" - ] + "outputs": [], + "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "jit_env_reset = jax.jit(env.reset)\n", - "jit_env_step = jax.jit(env.step)\n", - "jit_inference_fn = jax.jit(policy_network.apply)" - ] + "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "rollout = []\n", - "rng = jax.random.PRNGKey(seed=1)\n", - "state = jit_env_reset(rng=rng)\n", - "while not state.done:\n", - " rollout.append(state)\n", - " action = jit_inference_fn(my_params, state.obs)\n", - " state = jit_env_step(state, action)\n", - "\n", - "print(f\"The trajectory of this individual contains {len(rollout)} transitions.\")" - ] + "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "HTML(html.render(env.sys, [s.qp for s in rollout[:500]]))" - ] + "source": [] } ], "metadata": { diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index 12c14c41..edb4be3f 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -90,8 +90,8 @@ def lstm_ae_train( params: Params, epoch: int, hidden_size: int = 10, + batch_size: int = 128, ) -> Tuple[Params, Observation, Observation]: - batch_size = 128 # 2048 if epoch > 100: num_epochs = 25 @@ -169,6 +169,8 @@ def lstm_ae_train( # Normalising Dataset steps_per_epoch = repertoire.observations.shape[0] // batch_size + print("Steps per epoch: ", steps_per_epoch) + loss_val = 0.0 for epoch in range(num_epochs): rng, shuffle_key = jax.random.split(rng, 2) @@ -200,7 +202,9 @@ def lstm_ae_train( if batch.shape[0] < batch_size: # print(batch.shape) continue + state, loss_val = train_step(state, batch, rng) + print("Loss value has been updated, new value: ", loss_val) # To see the actual value we cannot jit this function (i.e. 
the _one_es_epoch # function nor the train function) diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index e40b27bd..6bbd7587 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -28,12 +28,14 @@ def test_aurora(env_name: str, batch_size: int) -> None: batch_size = batch_size env_name = env_name episode_length = 250 - num_iterations = 5 + max_iterations = 5 seed = 42 policy_hidden_layer_sizes = (64, 64) num_centroids = 50 - observation_option = "only_sd" + lstm_batch_size = 12 + + observation_option = "no_sd" # "full", "no_sd", "only_sd" hidden_size = 5 l_value_init = 0.2 @@ -138,7 +140,7 @@ def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict: return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} - # Instantiate MAP-Elites + # Instantiate AURORA aurora = AURORA( scoring_function=scoring_fn, emitter=mixing_emitter, @@ -223,7 +225,12 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: # initializing means and stds and AURORA random_key, subkey = jax.random.split(random_key) model_params, mean_observations, std_observations = train_seq2seq.lstm_ae_train( - subkey, repertoire, model_params, 0, hidden_size=hidden_size + subkey, + repertoire, + model_params, + 0, + hidden_size=hidden_size, + batch_size=lstm_batch_size, ) # design aurora's schedule @@ -232,7 +239,6 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: schedules = jnp.cumsum(jnp.arange(update_base, 1000, update_base)) current_step_estimation = 0 - num_iterations = 0 # Main loop n_target = 1024 @@ -240,7 +246,7 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target iteration = 0 - while iteration < num_iterations: + while iteration < max_iterations: ( (repertoire, random_key, model_params, mean_observations, std_observations), @@ -252,8 +258,6 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: length=log_freq, ) - num_iterations = iteration * log_freq - # update nb steps estimation current_step_estimation += batch_size * episode_length * log_freq @@ -271,6 +275,7 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: model_params, iteration, hidden_size=hidden_size, + batch_size=lstm_batch_size, ) # re-addition of all the new behavioural descriotpors with the new ae @@ -316,6 +321,8 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: l_value=l_value, ) + iteration += 1 + pytest.assume(repertoire is not None) From 955a0815aeb381b6548773c08c0b135f2f13c1e2 Mon Sep 17 00:00:00 2001 From: Felix Date: Wed, 1 Mar 2023 13:17:02 +0200 Subject: [PATCH 06/26] fix brax version issue with halfcheetah --- qdax/environments/exploration_wrappers.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/qdax/environments/exploration_wrappers.py b/qdax/environments/exploration_wrappers.py index 80635960..30ac426b 100644 --- a/qdax/environments/exploration_wrappers.py +++ b/qdax/environments/exploration_wrappers.py @@ -85,13 +85,16 @@ } """ +try: + HALFCHEETAH_SYSTEM_CONFIG = brax.envs.halfcheetah._SYSTEM_CONFIG +except AttributeError: + HALFCHEETAH_SYSTEM_CONFIG = brax.envs.half_cheetah._SYSTEM_CONFIG + # storing the classic env configurations # those are the configs from the official brax repo ENV_SYSTEM_CONFIG = { "ant": brax.envs.ant._SYSTEM_CONFIG, - "halfcheetah": brax.envs.halfcheetah._SYSTEM_CONFIG - if brax.__version__ == "0.0.12" - else brax.envs.half_cheetah._SYSTEM_CONFIG, + "halfcheetah": HALFCHEETAH_SYSTEM_CONFIG, "walker2d": 
brax.envs.walker2d._SYSTEM_CONFIG, "hopper": brax.envs.hopper._SYSTEM_CONFIG, # "humanoid": brax.envs.humanoid._SYSTEM_CONFIG, From 5adb3a403301ee080f6af56bccbe1d2971be154f Mon Sep 17 00:00:00 2001 From: Felix Date: Wed, 1 Mar 2023 16:48:23 +0200 Subject: [PATCH 07/26] update readme --- README.md | 2 +- examples/aurora.ipynb | 105 ------------------------------------------ 2 files changed, 1 insertion(+), 106 deletions(-) diff --git a/README.md b/README.md index 452950d1..fac1016a 100644 --- a/README.md +++ b/README.md @@ -133,7 +133,7 @@ QDax currently supports the following algorithms: | [CMA-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmamega.ipynb) | | [Multi-Objective MAP-Elites (MOME)](https://arxiv.org/abs/2202.03057) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb) | | [MAP-Elites Evolution Strategies (MEES)](https://dl.acm.org/doi/pdf/10.1145/3377930.3390217) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mees.ipynb) | - +| [AURORA](https://dl.acm.org/doi/abs/10.1145/3321707.3321804) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/aurora.ipynb) | ## QDax baseline algorithms The QDax library also provides implementations for some useful baseline algorithms: diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index a784f676..9e335f26 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -495,111 +495,6 @@ "for k, v in metrics.items():\n", " print(k, \" - \", v[-1])" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - 
"execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { From a67ef6ac4b1e58c32ee4697b7f6cac4613ebebbb Mon Sep 17 00:00:00 2001 From: Felix Date: Wed, 1 Mar 2023 17:44:28 +0200 Subject: [PATCH 08/26] add docs and pga aurora example --- README.md | 1 + docs/api_documentation/core/aurora.md | 7 + docs/api_documentation/core/pga_aurora.md | 5 + examples/aurora.ipynb | 6 +- examples/pga_aurora.ipynb | 571 ++++++++++++++++++++++ examples/pgame.ipynb | 3 +- mkdocs.yml | 4 + 7 files changed, 592 insertions(+), 5 deletions(-) create mode 100644 docs/api_documentation/core/aurora.md create mode 100644 docs/api_documentation/core/pga_aurora.md create mode 100644 examples/pga_aurora.ipynb diff --git a/README.md b/README.md index fac1016a..65a33e7d 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,7 @@ QDax currently supports the following algorithms: | [Multi-Objective MAP-Elites (MOME)](https://arxiv.org/abs/2202.03057) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb) | | [MAP-Elites Evolution Strategies (MEES)](https://dl.acm.org/doi/pdf/10.1145/3377930.3390217) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mees.ipynb) | | [AURORA](https://dl.acm.org/doi/abs/10.1145/3321707.3321804) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/aurora.ipynb) | +| [PGA-AURORA](https://arxiv.org/abs/2210.03516) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pga_aurora.ipynb) | ## QDax baseline algorithms The QDax library also provides implementations for some useful baseline algorithms: diff --git a/docs/api_documentation/core/aurora.md b/docs/api_documentation/core/aurora.md new file mode 100644 index 00000000..1088b2cd --- /dev/null +++ b/docs/api_documentation/core/aurora.md @@ -0,0 +1,7 @@ +# AURORA class + +This class implement the base mechanism of AURORA. It must be used with an emitter. To get the usual AURORA algorithm, one must use the [mixing emitter](emitters.md#qdax.core.emitters.standard_emitters.MixingEmitter). + +The AURORA class can be used with other emitters to create variants, like [PGA-AURORA](pga_aurora.md). + +::: qdax.core.aurora.AURORA diff --git a/docs/api_documentation/core/pga_aurora.md b/docs/api_documentation/core/pga_aurora.md new file mode 100644 index 00000000..dc4fd6d1 --- /dev/null +++ b/docs/api_documentation/core/pga_aurora.md @@ -0,0 +1,5 @@ +# Policy Gradient Assisted AURORA (PGA-AURORA) + +To create an instance of PGA-AURORA (introduced [in this paper](https://arxiv.org/abs/2210.03516)), one needs to use an instance of [AURORA](map_elites.md) with the PGAMEEmitter, detailed below. 
+ +::: qdax.core.emitters.pga_me_emitter.PGAMEEmitter diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index 9e335f26..a6a47894 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -88,7 +88,7 @@ "source": [ "#@title QD Training Definitions Fields\n", "#@markdown ---\n", - "batch_size = 10 #@param {type:\"number\"}\n", + "batch_size = 100 #@param {type:\"number\"}\n", "env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", "episode_length = 250 #@param {type:\"integer\"}\n", "max_iterations = 50 #@param {type:\"integer\"}\n", @@ -97,11 +97,11 @@ "iso_sigma = 0.005 #@param {type:\"number\"}\n", "line_sigma = 0.05 #@param {type:\"number\"}\n", "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", - "num_centroids = 50 #@param {type:\"integer\"}\n", + "num_centroids = 1024 #@param {type:\"integer\"}\n", "min_bd = 0. #@param {type:\"number\"}\n", "max_bd = 1.0 #@param {type:\"number\"}\n", "\n", - "batch_size = 128 #@param {type:\"integer\"}\n", + "lstm_batch_size = 128 #@param {type:\"integer\"}\n", "\n", "observation_option = \"no_sd\" #@param['no_sd', 'only_sd', 'full']\n", "hidden_size = 5 #@param {type:\"integer\"}\n", diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb new file mode 100644 index 00000000..11ed6afe --- /dev/null +++ b/examples/pga_aurora.ipynb @@ -0,0 +1,571 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pga_aurora.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Optimizing with PGA-AURORA in Jax\n", + "\n", + "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [PGA-AURORA](https://arxiv.org/abs/2210.03516).\n", + "It can be run locally or on Google Colab. We recommand to use a GPU. 
This notebook will show:\n", + "\n", + "- how to define the problem\n", + "- how to create an emitter\n", + "- how to create an AURORA instance and mix it with the right emitter to define PGA-AURORA\n", + "- which functions must be defined before training\n", + "- how to launch a certain number of training steps\n", + "- how to visualise the optimization process\n", + "- how to save/load a repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Installs and Imports\n", + "!pip install ipympl |tail -n 1\n", + "# %matplotlib widget\n", + "# from google.colab import output\n", + "# output.enable_custom_widget_manager()\n", + "\n", + "import os\n", + "\n", + "from IPython.display import clear_output\n", + "import functools\n", + "from typing import Dict, Any\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.0.15 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", + "\n", + "from qdax.core.aurora import AURORA\n", + "from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire\n", + "from qdax import environments\n", + "from qdax.tasks.brax_envs import scoring_aurora_function\n", + "from qdax.environments.bd_extractors import get_aurora_bd\n", + "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", + "from qdax.core.neuroevolution.networks.networks import MLP\n", + "from qdax.core.emitters.mutation_operators import isoline_variation\n", + "from qdax.core.emitters.pga_me_emitter import PGAMEConfig, PGAMEEmitter\n", + "\n", + "from qdax.utils import train_seq2seq\n", + "\n", + "\n", + "if \"COLAB_TPU_ADDR\" in os.environ:\n", + " from jax.tools import colab_tpu\n", + " colab_tpu.setup_tpu()\n", + "\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title QD Training Definitions Fields\n", + "#@markdown ---\n", + "env_batch_size = 100 #@param {type:\"number\"}\n", + "env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", + "episode_length = 250 #@param {type:\"integer\"}\n", + "max_iterations = 50 #@param {type:\"integer\"}\n", + "seed = 42 #@param {type:\"integer\"}\n", + "policy_hidden_layer_sizes = (64, 64) #@param {type:\"raw\"}\n", + "iso_sigma = 0.005 #@param {type:\"number\"}\n", + "line_sigma = 0.05 #@param {type:\"number\"}\n", + "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", + "num_centroids = 1024 #@param {type:\"integer\"}\n", + "min_bd = 0. 
#@param {type:\"number\"}\n", + "max_bd = 1.0 #@param {type:\"number\"}\n", + "\n", + "lstm_batch_size = 128 #@param {type:\"integer\"}\n", + "\n", + "observation_option = \"no_sd\" #@param['no_sd', 'only_sd', 'full']\n", + "hidden_size = 5 #@param {type:\"integer\"}\n", + "l_value_init = 0.2 #@param {type:\"number\"}\n", + "\n", + "traj_sampling_freq = 10 #@param {type:\"integer\"}\n", + "max_observation_size = 25 #@param {type:\"integer\"}\n", + "prior_descriptor_dim = 2 #@param {type:\"integer\"}\n", + "\n", + "proportion_mutation_ga = 0.5 #@param {type:\"number\"}\n", + "\n", + "# TD3 params\n", + "replay_buffer_size = 1000000 #@param {type:\"number\"}\n", + "critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", + "critic_learning_rate = 3e-4 #@param {type:\"number\"}\n", + "greedy_learning_rate = 3e-4 #@param {type:\"number\"}\n", + "policy_learning_rate = 1e-3 #@param {type:\"number\"}\n", + "noise_clip = 0.5 #@param {type:\"number\"}\n", + "policy_noise = 0.2 #@param {type:\"number\"}\n", + "discount = 0.99 #@param {type:\"number\"}\n", + "reward_scaling = 1.0 #@param {type:\"number\"}\n", + "transitions_batch_size = 256 #@param {type:\"number\"}\n", + "soft_tau_update = 0.005 #@param {type:\"number\"}\n", + "num_critic_training_steps = 300 #@param {type:\"number\"}\n", + "num_pg_training_steps = 100 #@param {type:\"number\"}\n", + "policy_delay = 2 #@param {type:\"number\"}\n", + "\n", + "log_freq = 5 #@param {type:\"integer\"}\n", + "#@markdown ---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Init environment, policy, population params, init states of the env\n", + "\n", + "Define the environment in which the policies will be trained. In this notebook, we focus on controllers learning to move a robot in a physical simulation. We also define the shared policy, that every individual in the population will use. Once the policy is defined, all individuals are defined by their parameters, that corresponds to their genotype." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Init environment\n", + "env = environments.create(env_name, episode_length=episode_length)\n", + "\n", + "# Init a random key\n", + "random_key = jax.random.PRNGKey(seed)\n", + "\n", + "# Init policy network\n", + "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", + "policy_network = MLP(\n", + " layer_sizes=policy_layer_sizes,\n", + " kernel_init=jax.nn.initializers.lecun_uniform(),\n", + " final_activation=jnp.tanh,\n", + ")\n", + "\n", + "# Init population of controllers\n", + "random_key, subkey = jax.random.split(random_key)\n", + "keys = jax.random.split(subkey, num=env_batch_size)\n", + "fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size))\n", + "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", + "\n", + "\n", + "# Create the initial environment states\n", + "random_key, subkey = jax.random.split(random_key)\n", + "keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0)\n", + "reset_fn = jax.jit(jax.vmap(env.reset))\n", + "init_states = reset_fn(keys)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the way the policy interacts with the env\n", + "\n", + "Now that the environment and policy has been defined, it is necessary to define a function that describes how the policy must be used to interact with the environment and to store transition data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the fonction to play a step with the policy in the environment\n", + "def play_step_fn(\n", + " env_state,\n", + " policy_params,\n", + " random_key,\n", + "):\n", + " \"\"\"\n", + " Play an environment step and return the updated state and the transition.\n", + " \"\"\"\n", + "\n", + " actions = policy_network.apply(policy_params, env_state.obs)\n", + " \n", + " state_desc = env_state.info[\"state_descriptor\"]\n", + " next_state = env.step(env_state, actions)\n", + "\n", + " transition = QDTransition(\n", + " obs=env_state.obs,\n", + " next_obs=next_state.obs,\n", + " rewards=next_state.reward,\n", + " dones=next_state.done,\n", + " actions=actions,\n", + " truncations=next_state.info[\"truncation\"],\n", + " state_desc=state_desc,\n", + " next_state_desc=next_state.info[\"state_descriptor\"],\n", + " )\n", + "\n", + " return next_state, policy_params, random_key, transition" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the scoring function and the way metrics are computed\n", + "\n", + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare the scoring function\n", + "bd_extraction_fn = functools.partial(\n", + " get_aurora_bd,\n", + " option=observation_option,\n", + " hidden_size=hidden_size,\n", + " traj_sampling_freq=traj_sampling_freq,\n", + " max_observation_size=max_observation_size,\n", + ")\n", + "scoring_fn = functools.partial(\n", + " scoring_aurora_function,\n", + " init_states=init_states,\n", + " episode_length=episode_length,\n", + " play_step_fn=play_step_fn,\n", + " behavior_descriptor_extractor=bd_extraction_fn,\n", + ")\n", + "\n", + "# Get minimum reward value to make sure qd_score are positive\n", + "reward_offset = environments.reward_offset[env_name]\n", + "\n", + "# Define a metrics function\n", + "def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict:\n", + "\n", + " # Get metrics\n", + " grid_empty = repertoire.fitnesses == -jnp.inf\n", + " qd_score = jnp.sum(repertoire.fitnesses, where=~grid_empty)\n", + " # Add offset for positive qd_score\n", + " qd_score += reward_offset * episode_length * jnp.sum(1.0 - grid_empty)\n", + " coverage = 100 * jnp.mean(1.0 - grid_empty)\n", + " max_fitness = jnp.max(repertoire.fitnesses)\n", + "\n", + " return {\"qd_score\": qd_score, \"max_fitness\": max_fitness, \"coverage\": coverage}\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the emitter\n", + "\n", + "The emitter is used to evolve the population at each mutation step." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the PG-emitter config\n", + "pga_emitter_config = PGAMEConfig(\n", + " env_batch_size=env_batch_size,\n", + " batch_size=transitions_batch_size,\n", + " proportion_mutation_ga=proportion_mutation_ga,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " critic_learning_rate=critic_learning_rate,\n", + " greedy_learning_rate=greedy_learning_rate,\n", + " policy_learning_rate=policy_learning_rate,\n", + " noise_clip=noise_clip,\n", + " policy_noise=policy_noise,\n", + " discount=discount,\n", + " reward_scaling=reward_scaling,\n", + " replay_buffer_size=replay_buffer_size,\n", + " soft_tau_update=soft_tau_update,\n", + " num_critic_training_steps=num_critic_training_steps,\n", + " num_pg_training_steps=num_pg_training_steps,\n", + " policy_delay=policy_delay,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the emitter\n", + "variation_fn = functools.partial(\n", + " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", + ")\n", + "\n", + "pg_emitter = PGAMEEmitter(\n", + " config=pga_emitter_config,\n", + " policy_network=policy_network,\n", + " env=env,\n", + " variation_fn=variation_fn,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiate and initialise the MAP Elites algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate AURORA\n", + "aurora = AURORA(\n", + " scoring_function=scoring_fn,\n", + " emitter=pg_emitter,\n", + " metrics_function=metrics_fn,\n", + ")\n", + "\n", + "aurora_dims = hidden_size\n", + "centroids = jnp.zeros(shape=(num_centroids, aurora_dims))\n", + "\n", + "@jax.jit\n", + "def update_scan_fn(carry: Any, unused: Any) -> Any:\n", + " \"\"\"Scan the udpate function.\"\"\"\n", + " (\n", + " repertoire,\n", + " emitter_state,\n", + " random_key,\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " ) = carry\n", + "\n", + " # update\n", + " (repertoire, emitter_state, metrics, random_key,) = aurora.update(\n", + " repertoire,\n", + " emitter_state,\n", + " random_key,\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " )\n", + "\n", + " return (\n", + " (repertoire, emitter_state, random_key, model_params, mean_observations, std_observations),\n", + " metrics,\n", + " )\n", + "\n", + "# Init algorithm\n", + "# AutoEncoder Params and INIT\n", + "obs_dim = jnp.minimum(env.observation_size, max_observation_size)\n", + "if observation_option == \"full\":\n", + " observations_dims = (\n", + " episode_length // traj_sampling_freq,\n", + " obs_dim + prior_descriptor_dim,\n", + " )\n", + "elif observation_option == \"no_sd\":\n", + " observations_dims = (\n", + " episode_length // traj_sampling_freq,\n", + " obs_dim,\n", + " )\n", + "elif observation_option == \"only_sd\":\n", + " observations_dims = (episode_length // traj_sampling_freq, prior_descriptor_dim)\n", + "else:\n", + " ValueError(\"The chosen option is not correct.\")\n", + "\n", + "# define the seq2seq model\n", + "model = train_seq2seq.get_model(\n", + " observations_dims[-1], True, hidden_size=hidden_size\n", + ")\n", + "\n", + "# init the model params\n", + "random_key, subkey = jax.random.split(random_key)\n", + "model_params = train_seq2seq.get_initial_params(\n", + " model, subkey, (1, 
*observations_dims)\n", + ")\n", + "\n", + "print(jax.tree_map(lambda x: x.shape, model_params))\n", + "\n", + "# define arbitrary observation's mean/std\n", + "mean_observations = jnp.zeros(observations_dims[-1])\n", + "std_observations = jnp.ones(observations_dims[-1])\n", + "\n", + "# init step of the aurora algorithm\n", + "repertoire, emitter_state, random_key = aurora.init(\n", + " init_variables,\n", + " centroids,\n", + " random_key,\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " l_value_init,\n", + ")\n", + "\n", + "# initializing means and stds and AURORA\n", + "random_key, subkey = jax.random.split(random_key)\n", + "model_params, mean_observations, std_observations = train_seq2seq.lstm_ae_train(\n", + " subkey, repertoire, model_params, 0, hidden_size=hidden_size, batch_size=lstm_batch_size\n", + ")\n", + "\n", + "# design aurora's schedule\n", + "default_update_base = 10\n", + "update_base = int(jnp.ceil(default_update_base / log_freq))\n", + "schedules = jnp.cumsum(jnp.arange(update_base, 1000, update_base))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Launch AURORA iterations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "current_step_estimation = 0\n", + "num_iterations = 0\n", + "\n", + "# Main loop\n", + "n_target = 1024\n", + "\n", + "previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target\n", + "\n", + "iteration = 0\n", + "while iteration < max_iterations:\n", + "\n", + " (\n", + " (repertoire, emitter_state, random_key, model_params, mean_observations, std_observations),\n", + " metrics,\n", + " ) = jax.lax.scan(\n", + " update_scan_fn,\n", + " (repertoire, emitter_state, random_key, model_params, mean_observations, std_observations),\n", + " (),\n", + " length=log_freq,\n", + " )\n", + "\n", + " num_iterations = iteration * log_freq\n", + "\n", + " # update nb steps estimation\n", + " current_step_estimation += env_batch_size * episode_length * log_freq\n", + "\n", + " # autoencoder steps and CVC\n", + " if (iteration + 1) in schedules:\n", + " # train the autoencoder\n", + " random_key, subkey = jax.random.split(random_key)\n", + " (\n", + " model_params,\n", + " mean_observations,\n", + " std_observations,\n", + " ) = train_seq2seq.lstm_ae_train(\n", + " subkey,\n", + " repertoire,\n", + " model_params,\n", + " iteration,\n", + " hidden_size=hidden_size,\n", + " batch_size=lstm_batch_size\n", + " )\n", + "\n", + " # re-addition of all the new behavioural descriotpors with the new ae\n", + " normalized_observations = (\n", + " repertoire.observations - mean_observations\n", + " ) / std_observations\n", + "\n", + " new_descriptors = model.apply(\n", + " {\"params\": model_params}, normalized_observations, method=model.encode\n", + " )\n", + " repertoire = repertoire.init(\n", + " genotypes=repertoire.genotypes,\n", + " centroids=repertoire.centroids,\n", + " fitnesses=repertoire.fitnesses,\n", + " descriptors=new_descriptors,\n", + " observations=repertoire.observations,\n", + " l_value=repertoire.l_value,\n", + " )\n", + " num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n", + "\n", + " elif iteration % 2 == 0:\n", + " # update the l value\n", + " num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf)\n", + "\n", + " # CVC Implementation to keep a constant number of individuals in the archive\n", + " current_error = num_indivs - n_target\n", + " change_rate = current_error - previous_error\n", + " prop_gain = 
1 * 10e-6\n", + " l_value = (\n", + " repertoire.l_value\n", + " + (prop_gain * (current_error))\n", + " + (prop_gain * change_rate)\n", + " )\n", + "\n", + " previous_error = current_error\n", + "\n", + " repertoire = repertoire.init(\n", + " genotypes=repertoire.genotypes,\n", + " centroids=repertoire.centroids,\n", + " fitnesses=repertoire.fitnesses,\n", + " descriptors=repertoire.descriptors,\n", + " observations=repertoire.observations,\n", + " l_value=l_value,\n", + " )\n", + "\n", + " iteration += 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for k, v in metrics.items():\n", + " print(k, \" - \", v[-1])" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 24222ddf..f1fb5e1c 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -98,8 +98,7 @@ "min_bd = 0. #@param {type:\"number\"}\n", "max_bd = 1.0 #@param {type:\"number\"}\n", "\n", - "#@title PGA-ME Emitter Definitions Fields\n", - "proportion_mutation_ga = 0.5\n", + "proportion_mutation_ga = 0.5 #@param {type:\"number\"}\n", "\n", "# TD3 params\n", "env_batch_size = 100 #@param {type:\"number\"}\n", diff --git a/mkdocs.yml b/mkdocs.yml index 5ab48d12..7c60aa8f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -125,6 +125,8 @@ nav: - SMERL: examples/smerl.ipynb - CMA ES: examples/cmaes.ipynb - NSGA2/SPEA2: examples/nsga2_spea2.ipynb + - AURORA: examples/aurora.ipynb + - PGA AURORA: examples/pga_aurora.ipynb - Jumanji Snake: examples/jumanji_snake.ipynb - API documentation: - Core: @@ -137,6 +139,8 @@ nav: - CMA MEGA: api_documentation/core/cma_mega.md - MOME: api_documentation/core/mome.md - ME ES: api_documentation/core/mees.md + - AURORA: api_documentation/core/aurora.md + - PGA AURORA: api_documentation/core/pga_aurora.md - Baseline algorithms: - SMERL: api_documentation/core/smerl.md - DIAYN: api_documentation/core/diayn.md From 03bda2c5ac5b879c39176091ef8e8dc41c1f290c Mon Sep 17 00:00:00 2001 From: Felix Date: Wed, 1 Mar 2023 21:19:48 +0200 Subject: [PATCH 09/26] add docstrings + clean + move seq2seq model --- qdax/core/aurora.py | 33 +++++++---- .../containers/unstructured_repertoire.py | 58 ++++++++++++------- .../networks/seq2seq_networks.py} | 3 +- qdax/utils/train_seq2seq.py | 28 +-------- 4 files changed, 61 insertions(+), 61 deletions(-) rename qdax/{utils/seq2seq_model.py => core/neuroevolution/networks/seq2seq_networks.py} (98%) diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index 51766635..13ba7749 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -16,8 +16,7 @@ class AURORA: - """ - Core elements of the AURORA algorithm. + """Core elements of the AURORA algorithm. Args: scoring_function: a function that takes a batch of genotypes and compute @@ -26,7 +25,7 @@ class AURORA: repertoire. It has two compulsory functions. A function that takes emits a new population, and a function that update the internal state of the emitter. 
- metrics_function: a function that takes a MAP-Elites repertoire and compute + metrics_function: a function that takes a repertoire and computes any useful metric to track its evolution """ @@ -54,19 +53,24 @@ def init( std_observations: jnp.ndarray, l_value: jnp.ndarray, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: - """ - Initialize a Map-Elites grid with an initial population of genotypes. Requires - the definition of centroids that can be computed with any method such as - CVT or Euclidean mapping. + """Initialize an unstructured repertoire with an initial population of + genotypes. Requires the definition of centroids that can be computed with + any method such as CVT or Euclidean mapping. Args: init_genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tesselation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. + model_params: parameters of the model used to define the behavior + descriptors. + mean_observations: mean of the observations gathered. + std_observations: standard deviation of the observations + gathered. Returns: - an initialized MAP-Elite repertoire with the initial state of the emitter. + an initialized unstructured repertoire with the initial state of + the emitter. """ fitnesses, descriptors, extra_scores, random_key = self._scoring_function( init_genotypes, @@ -84,6 +88,7 @@ def init( observations=extra_scores["last_valid_observations"], # type: ignore l_value=l_value, ) + # get initial state of the emitter emitter_state, random_key = self._emitter.init( init_genotypes=init_genotypes, random_key=random_key @@ -110,16 +115,22 @@ def update( mean_observations: jnp.ndarray, std_observations: jnp.ndarray, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]: - """ - Performs one iteration of the MAP-Elites algorithm. + """Main step of the AURORA algorithm. + + + Performs one iteration of the AURORA algorithm. 1. A batch of genotypes is sampled in the archive and the genotypes are copied. 2. The copies are mutated and crossed-over 3. The obtained offsprings are scored and then added to the archive. Args: - repertoire: the MAP-Elites repertoire + repertoire: unstructured repertoire emitter_state: state of the emitter random_key: a jax PRNG random key + model_params: params of the model used to define the behavior descriptor. + mean_observations: mean of the observations gathered. + std_observations: standard deviation of the observations + gathered. Results: the updated MAP-Elites repertoire diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index 48648be9..8e7136c5 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -13,10 +13,9 @@ @partial(jax.jit, static_argnames=("k_nn",)) def get_cells_indices( - batch_of_descriptors: jnp.ndarray, centroids: jnp.ndarray, k_nn: int + batch_of_descriptors: Descriptor, centroids: Centroid, k_nn: int ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ - Returns the array of cells indices for a batch of descriptors + """Returns the array of cells indices for a batch of descriptors given the centroids of the grid. Args: @@ -34,8 +33,9 @@ def _get_cells_indices( centroids: jnp.ndarray, k_nn: int, ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ - set_of_descriptors of shape (1, num_descriptors) + """Inner function. 
+ + descriptors of shape (1, num_descriptors) centroids of shape (num_centroids, num_descriptors) """ @@ -66,6 +66,7 @@ def intra_batch_comp( eval_scores: jnp.ndarray, l_value: jnp.ndarray, ) -> jnp.ndarray: + """Function to know if an individual should be kept or not.""" # Check for individuals that are Nans, we remove them at the end not_existent = jnp.where((jnp.isnan(normed)).any(), True, False) @@ -208,6 +209,8 @@ class UnstructuredRepertoire(flax.struct.PyTreeNode): is (num_centroids, num_descriptors). centroids: an array the contains the centroids of the tesselation. The array shape is (num_centroids, num_descriptors). + observations: observations that the genotype gathered in the environment. + ages: time spent by the genotype in the repertoire. """ genotypes: Genotype @@ -249,7 +252,7 @@ def flatten_genotype(genotype: Genotype) -> jnp.ndarray: def load( cls, reconstruction_fn: Callable, path: str = "./" ) -> UnstructuredRepertoire: - """Loads a MAP Elites Grid. + """Loads an unstructured repertoire. Args: reconstruction_fn: Function to reconstruct a PyTree @@ -257,7 +260,7 @@ def load( path: Path where the data is saved. Defaults to "./". Returns: - A MAP Elites Repertoire. + An unstructured repertoire. """ flat_genotypes = jnp.load(path + "genotypes.npy") @@ -288,6 +291,19 @@ def add( batch_of_fitnesses: Fitness, batch_of_observations: Observation, ) -> UnstructuredRepertoire: + """Adds a batch of genotypes to the repertoire. + + Args: + batch_of_genotypes: genotypes of the individuals to be considered + for addition in the repertoire. + batch_of_descriptors: associated descriptors. + batch_of_fitnesses: associated fitness. + batch_of_observations: associated observations. + + Returns: + A new unstructured repertoire where the relevant individuals have been + added. + """ # We need to replace all the descriptors that are not filled with jnp inf filtered_descriptors = jnp.where( @@ -436,8 +452,7 @@ def add( @partial(jax.jit, static_argnames=("num_samples",)) def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]: - """ - Sample elements in the grid. + """Sample elements in the repertoire. Args: random_key: a jax PRNG random key @@ -470,14 +485,10 @@ def init( l_value: jnp.ndarray, ages: Optional[jnp.ndarray] = None, ) -> UnstructuredRepertoire: - """ - Initialize a Map-Elites repertoire with an initial population of genotypes. + """Initialize a Map-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping. - Note: this function has been kept outside of the object MapElites, so it can - be called easily called from other modules. - Args: genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) @@ -485,9 +496,12 @@ def init( descriptors: descriptors of the initial genotypes of shape (batch_size, num_descriptors) centroids: tesselation centroids of shape (batch_size, num_descriptors) + observations: observations experienced in the evaluation task. + l_value: threshold distance of the repertoire. + ages: ages of the genotypes. Returns: - an initialized MAP-Elite repertoire + an initialized unstructured repertoire. """ # Initialize grid with default values @@ -692,14 +706,10 @@ def init_relevant( proximity_scores: jnp.ndarray, ages: Optional[jnp.ndarray] = None, ) -> UnstructuredRepertoire: - """ - Initialize a Map-Elites repertoire with an initial population of genotypes. 
+ """Initialize a Map-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping. - Note: this function has been kept outside of the object MapElites, so it can - be called easily called from other modules. - Args: genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) @@ -707,9 +717,14 @@ def init_relevant( descriptors: descriptors of the initial genotypes of shape (batch_size, num_descriptors) centroids: tesselation centroids of shape (batch_size, num_descriptors) + observations: observations gathered by the genotypes while being evaluated + on the task of interest. + l_value: threshold of the repertoire. + proximity_scores: measure of proximity of the individuals to the population. + ages: ages of the individuals. Returns: - an initialized MAP-Elite repertoire + an initialized unstructured repertoire """ # Initialize grid with default values @@ -729,6 +744,7 @@ def init_relevant( if ages is None: ages = jnp.zeros(shape=num_centroids) + repertoire = UnstructuredRepertoire( genotypes=default_genotypes, fitnesses=default_fitnesses, diff --git a/qdax/utils/seq2seq_model.py b/qdax/core/neuroevolution/networks/seq2seq_networks.py similarity index 98% rename from qdax/utils/seq2seq_model.py rename to qdax/core/neuroevolution/networks/seq2seq_networks.py index 4660de84..83070c2d 100644 --- a/qdax/utils/seq2seq_model.py +++ b/qdax/core/neuroevolution/networks/seq2seq_networks.py @@ -45,8 +45,7 @@ def select_carried_state(new_state: Array, old_state: Array) -> Array: carried_lstm_state = tuple( select_carried_state(*s) for s in zip(new_lstm_state, lstm_state) ) - # Update `is_eos`. - # is_eos = jnp.logical_or(is_eos, x[:, 8]) + return (carried_lstm_state, is_eos), y @staticmethod diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index edb4be3f..2360f812 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -111,18 +111,10 @@ def lstm_ae_train( repertoire.observations.shape[-1], teacher_force=True, hidden_size=hidden_size ) - print("Beginning of the lstm ae training: ") - print("Repertoire observation: ", repertoire.observations) - - print("Repertoire fitnesses: ", repertoire.fitnesses) - # compute mean/std of the obs for normalization mean_obs = jnp.nanmean(repertoire.observations, axis=(0, 1)) std_obs = jnp.nanstd(repertoire.observations, axis=(0, 1)) - print("Mean obs - wo NaN: ", mean_obs) - print("Std obs - wo NaN: ", std_obs) - # TODO: maybe we could just compute this data on the valid dataset # create optimizer and optimized state @@ -131,11 +123,9 @@ def lstm_ae_train( # size of the repertoire repertoire_size = repertoire.centroids.shape[0] - print("Repertoire size: ", repertoire_size) # number of individuals in the repertoire num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) - print("Number of individuals: ", num_indivs) # select repertoire_size indexes going from 0 to num_indivs # TODO: WHY?? @@ -143,34 +133,27 @@ def lstm_ae_train( idx_p1 = jax.random.randint( key_select_p1, shape=(repertoire_size,), minval=0, maxval=num_indivs ) - print("idx p1: ", idx_p1) # TODO: what is the diff with repertoire_size?? tot_indivs = repertoire.fitnesses.ravel().shape[0] - print("Total individuals: ", tot_indivs) # get indexes where fitness is not -inf?? 
indexes = jnp.argwhere( jnp.logical_not(jnp.isinf(repertoire.fitnesses)), size=tot_indivs ) indexes = jnp.transpose(indexes, axes=(1, 0)) - print("Indexes: ", indexes) # ??? indiv_indices = jnp.array( jnp.ravel_multi_index(indexes, repertoire.fitnesses.shape, mode="clip") ).astype(int) - print("Indiv indices: ", indexes) # ??? valid_indexes = indiv_indices.at[idx_p1].get() - print("Valid indexes: ", valid_indexes) # Normalising Dataset steps_per_epoch = repertoire.observations.shape[0] // batch_size - print("Steps per epoch: ", steps_per_epoch) - loss_val = 0.0 for epoch in range(num_epochs): rng, shuffle_key = jax.random.split(rng, 2) @@ -188,23 +171,18 @@ def lstm_ae_train( ) / std_obs training_dataset = training_dataset.at[valid_indexes].get() - if epoch == 0: - print("Training dataset for first epoch: ", training_dataset) - print("Training dataset first data for first epoch: ", training_dataset[0]) - for i in range(steps_per_epoch): batch = jnp.asarray( training_dataset.at[ (i * batch_size) : (i * batch_size) + batch_size, :, : ].get() ) - # print(batch) + if batch.shape[0] < batch_size: # print(batch.shape) continue state, loss_val = train_step(state, batch, rng) - print("Loss value has been updated, new value: ", loss_val) # To see the actual value we cannot jit this function (i.e. the _one_es_epoch # function nor the train function) @@ -213,10 +191,6 @@ def lstm_ae_train( # TODO: put this in metrics so we can jit the function and see the metrics # TODO: not urgent because the training is not that long - train_step.clear_cache() - del tx - del model params = state.params - del state return params, mean_obs, std_obs From 9bbcf0aceff49cd77d9397e864e1622a7ae077d6 Mon Sep 17 00:00:00 2001 From: Felix Date: Thu, 2 Mar 2023 16:46:09 +0200 Subject: [PATCH 10/26] minor fix - trigger ci --- qdax/core/aurora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index 13ba7749..9d4df338 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -136,7 +136,7 @@ def update( the updated MAP-Elites repertoire the updated (if needed) emitter state metrics about the updated repertoire - a new jax PRNG key + a new key """ # generate offsprings with the emitter genotypes, random_key = self._emitter.emit( From 4b6c4ec1b0e8c50e78338c473f4eeebe1220f57a Mon Sep 17 00:00:00 2001 From: Felix Date: Thu, 2 Mar 2023 17:39:16 +0200 Subject: [PATCH 11/26] upgrade action in ci --- .github/workflows/ci.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c9b979c2..27f55b44 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -60,7 +60,7 @@ jobs: # cache-to: type=gha,mode=max - name: Build and push test Docker image - uses: docker/build-push-action@v3 + uses: docker/build-push-action@v4 with: context: . file: dev.Dockerfile @@ -72,7 +72,7 @@ jobs: cache-to: type=gha,mode=max - name: Build and push tool Docker image - uses: docker/build-push-action@v3 + uses: docker/build-push-action@v4 with: context: . 
file: tool.Dockerfile From d13294449970d7fcb0f6c67f4da816308e7bc42b Mon Sep 17 00:00:00 2001 From: Felix Date: Thu, 2 Mar 2023 17:59:53 +0200 Subject: [PATCH 12/26] update isort --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d037820e..af8f2bc2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pycqa/isort - rev: 5.6.4 + rev: 5.12.0 hooks: - id: isort args: ["--profile", "black"] From 88bb2319b7712f763ce34b2940527106dc3c604f Mon Sep 17 00:00:00 2001 From: Felix Date: Thu, 2 Mar 2023 19:12:57 +0200 Subject: [PATCH 13/26] fix path to seq2seq --- qdax/utils/train_seq2seq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index 2360f812..ec5b3c7f 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -15,8 +15,8 @@ from flax.training import train_state from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire +from qdax.core.neuroevolution.networks.seq2seq_networks import Seq2seq from qdax.types import Observation, Params, RNGKey -from qdax.utils.seq2seq_model import Seq2seq Array = Any PRNGKey = Any From 5d93322c2c738c90b5b29040552ec3f7782caed3 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Thu, 20 Apr 2023 19:54:24 +0200 Subject: [PATCH 14/26] removing ages and centroids and add max_size --- qdax/core/aurora.py | 21 +- .../containers/unstructured_repertoire.py | 339 +----------------- tests/core_test/aurora_test.py | 14 +- 3 files changed, 40 insertions(+), 334 deletions(-) diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index 9d4df338..ec23a519 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -5,6 +5,7 @@ from functools import partial from typing import Callable, Optional, Tuple +import flax.struct import jax import jax.numpy as jnp from chex import ArrayTree @@ -12,7 +13,15 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Centroid, Descriptor, Fitness, Genotype, Metrics, Params, RNGKey +from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.types import Descriptor, Fitness, Genotype, Metrics, Params, RNGKey, Observation + + +@flax.struct.dataclass +class AuroraExtraInfo: + model_params: Params + mean_observations: jnp.ndarray + std_observations: jnp.ndarray class AURORA: @@ -32,26 +41,30 @@ class AURORA: def __init__( self, scoring_function: Callable[ - [Genotype, RNGKey, Params, jnp.ndarray, jnp.ndarray], + [Genotype, RNGKey], Tuple[Fitness, Descriptor, ArrayTree, RNGKey], ], emitter: Emitter, metrics_function: Callable[[MapElitesRepertoire], Metrics], + bd_extraction_fn: Callable[ + [QDTransition, jnp.ndarray, Params, Observation, Observation], Descriptor + ], ) -> None: self._scoring_function = scoring_function self._emitter = emitter self._metrics_function = metrics_function + self._bd_extraction_fn = bd_extraction_fn @partial(jax.jit, static_argnames=("self",)) def init( self, init_genotypes: Genotype, - centroids: Centroid, random_key: RNGKey, model_params: Params, mean_observations: jnp.ndarray, std_observations: jnp.ndarray, l_value: jnp.ndarray, + max_size: int, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: """Initialize an unstructured 
repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with @@ -84,9 +97,9 @@ def init( genotypes=init_genotypes, fitnesses=fitnesses, descriptors=descriptors, - centroids=centroids, observations=extra_scores["last_valid_observations"], # type: ignore l_value=l_value, + max_size=max_size, ) # get initial state of the emitter diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index 8e7136c5..1685d66b 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -3,7 +3,7 @@ from functools import partial from typing import Callable, Optional, Tuple -import flax +import flax.struct import jax import jax.numpy as jnp from jax.flatten_util import ravel_pytree @@ -125,74 +125,6 @@ def intra_batch_comp( return jnp.logical_not(discard_indiv) -@jax.jit -def intra_batch_comp_relevant( - normed: jnp.ndarray, - current_index: jnp.ndarray, - normed_all: jnp.ndarray, - eval_scores: jnp.ndarray, - relevant_l_values: jnp.ndarray, -) -> jnp.ndarray: - - # Check for individuals that are Nans, we remove them at the end - not_existent = jnp.where((jnp.isnan(normed)).any(), True, False) - - # Fill in Nans to do computations - normed = jnp.where(jnp.isnan(normed), jnp.full(normed.shape[-1], jnp.inf), normed) - eval_scores = jnp.where( - jnp.isinf(eval_scores), jnp.full(eval_scores.shape[-1], jnp.nan), eval_scores - ) - - # If we do not use a fitness (i.e same fitness everywhere, we create a virtual - # fitness function to add individuals with the same bd) - additional_score = jnp.where( - jnp.nanmax(eval_scores) == jnp.nanmin(eval_scores), 1.0, 0.0 - ) - additional_scores = jnp.linspace(0.0, additional_score, num=eval_scores.shape[0]) - - # Add scores to empty individuals - eval_scores = jnp.where( - jnp.isnan(eval_scores), jnp.full(eval_scores.shape[0], -jnp.inf), eval_scores - ) - - # Virtual eval_scores - eval_scores = eval_scores + additional_scores - # For each point we check what other points are the closest ones. - knn_relevant_scores, knn_relevant_indices = jax.lax.top_k( - -1 * jax.vmap(jnp.linalg.norm)(normed - normed_all), eval_scores.shape[0] - ) - # We negated the scores to use top_k so we reverse it. - knn_relevant_scores = knn_relevant_scores * -1 - - # Check if the individual is close enough to compare (under l-value) - fitness = jnp.where( - jnp.squeeze(knn_relevant_scores < relevant_l_values), True, False - ) - - # We want to eliminate the same individual (distance 0) - fitness = jnp.where(knn_relevant_indices == current_index, False, fitness) - current_fitness = jnp.squeeze( - eval_scores.at[knn_relevant_indices.at[0].get()].get() - ) - - # Is the fitness of the other individual higher? - # If both are True then we discard the current individual since this individual - # would be replaced by the better one. - discard_indiv = jnp.logical_and( - jnp.where( - eval_scores.at[knn_relevant_indices].get() > current_fitness, True, False - ), - fitness, - ).any() - - # Discard Individuals with Nans as their BD (mainly for the readdition where we - # have NaN bds) - discard_indiv = jnp.logical_or(discard_indiv, not_existent) - - # Negate to know if we keep the individual - return jnp.logical_not(discard_indiv) - - class UnstructuredRepertoire(flax.struct.PyTreeNode): """ Class for the unstructured repertoire in Map Elites. 
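A minimal sketch of the acceptance rule this container relies on (illustrative names, not the repertoire's API; it ignores archive capacity and the intra-batch competition handled in `add`): a candidate is kept if its descriptor lies further than `l_value` from every stored descriptor, or if it beats the fitness of its nearest stored neighbour.

    import jax.numpy as jnp

    def accepts(candidate_desc, candidate_fit, stored_descs, stored_fits, l_value):
        # distances from the candidate descriptor to all stored descriptors
        dists = jnp.linalg.norm(stored_descs - candidate_desc, axis=-1)
        nearest = jnp.argmin(dists)
        novel_enough = dists[nearest] > l_value          # far from everything stored
        better_than_nearest = candidate_fit > stored_fits[nearest]
        return jnp.logical_or(novel_enough, better_than_nearest)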
@@ -210,16 +142,14 @@ class UnstructuredRepertoire(flax.struct.PyTreeNode): centroids: an array the contains the centroids of the tesselation. The array shape is (num_centroids, num_descriptors). observations: observations that the genotype gathered in the environment. - ages: time spent by the genotype in the repertoire. """ genotypes: Genotype fitnesses: Fitness descriptors: Descriptor - centroids: Centroid observations: Observation - ages: jnp.ndarray l_value: jnp.ndarray + max_size: int = flax.struct.field(pytree_node=False) def save(self, path: str = "./") -> None: """Saves the grid on disk in the form of .npy files. @@ -243,10 +173,9 @@ def flatten_genotype(genotype: Genotype) -> jnp.ndarray: jnp.save(path + "genotypes.npy", flat_genotypes) jnp.save(path + "fitnesses.npy", self.fitnesses) jnp.save(path + "descriptors.npy", self.descriptors) - jnp.save(path + "centroids.npy", self.centroids) jnp.save(path + "observations.npy", self.observations) jnp.save(path + "l_value.npy", self.l_value) - jnp.save(path + "ages.npy", self.ages) + jnp.save(path + "max_size.npy", self.max_size) @classmethod def load( @@ -268,19 +197,17 @@ def load( fitnesses = jnp.load(path + "fitnesses.npy") descriptors = jnp.load(path + "descriptors.npy") - centroids = jnp.load(path + "centroids.npy") observations = jnp.load(path + "observations.npy") l_value = jnp.load(path + "l_value.npy") - ages = jnp.load(path + "ages.npy") + max_size = int(jnp.load(path + "max_size.npy").item()) return UnstructuredRepertoire( genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - centroids=centroids, observations=observations, l_value=l_value, - ages=ages, + max_size=max_size, ) @jax.jit @@ -335,8 +262,6 @@ def add( batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1) batch_of_observations = jnp.expand_dims(batch_of_observations, axis=-1) - num_centroids = self.centroids.shape[0] - # TODO: Doesn't Work if Archive is full. Need to use the closest individuals # in that case. empty_indexes = jnp.squeeze( @@ -394,7 +319,7 @@ def add( best_fitnesses = jax.ops.segment_max( batch_of_fitnesses, batch_of_indices.astype(jnp.int32).squeeze(), - num_segments=num_centroids, + num_segments=self.max_size, ) cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0) @@ -414,7 +339,7 @@ def add( # assign fake position when relevant : num_centroids is out of bounds batch_of_indices = jnp.where( - addition_condition, x=batch_of_indices, y=num_centroids + addition_condition, x=batch_of_indices, y=self.max_size, ) # create new grid @@ -438,16 +363,13 @@ def add( batch_of_observations.squeeze() ) - new_ages = self.ages.at[batch_of_indices.squeeze()].set(0.0) + 1 - return UnstructuredRepertoire( genotypes=new_grid_genotypes, fitnesses=new_fitnesses.squeeze(), descriptors=new_descriptors.squeeze(), - centroids=new_descriptors.squeeze(), observations=new_observations.squeeze(), l_value=self.l_value, - ages=new_ages, + max_size=self.max_size, ) @partial(jax.jit, static_argnames=("num_samples",)) @@ -480,10 +402,9 @@ def init( genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, - centroids: Centroid, observations: Observation, l_value: jnp.ndarray, - ages: Optional[jnp.ndarray] = None, + max_size: int, ) -> UnstructuredRepertoire: """Initialize a Map-Elites repertoire with an initial population of genotypes. 
Requires the definition of centroids that can be computed with any method @@ -498,264 +419,36 @@ def init( centroids: tesselation centroids of shape (batch_size, num_descriptors) observations: observations experienced in the evaluation task. l_value: threshold distance of the repertoire. - ages: ages of the genotypes. + max_size: maximal size of the container Returns: an initialized unstructured repertoire. """ + # Initialize grid with default values - num_centroids = centroids.shape[0] - default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) + default_fitnesses = -jnp.inf * jnp.ones(shape=max_size) default_genotypes = jax.tree_map( lambda x: jnp.full( - shape=(num_centroids,) + x.shape[1:], fill_value=jnp.nan + shape=(max_size,) + x.shape[1:], fill_value=jnp.nan ), genotypes, ) - default_descriptors = jnp.zeros(shape=(num_centroids, centroids.shape[-1])) + default_descriptors = jnp.zeros(shape=(max_size, descriptors.shape[-1])) default_observations = jnp.full( - shape=(num_centroids,) + observations.shape[1:], fill_value=jnp.nan + shape=(max_size,) + observations.shape[1:], fill_value=jnp.nan ) - if ages is None: - ages = jnp.zeros(shape=num_centroids) - repertoire = UnstructuredRepertoire( genotypes=default_genotypes, fitnesses=default_fitnesses, descriptors=default_descriptors, - centroids=centroids, observations=default_observations, l_value=l_value, - ages=ages, + max_size=max_size, ) return repertoire.add( # type: ignore genotypes, descriptors, fitnesses, observations ) - - @jax.jit - def add_relevant( - self, - batch_of_genotypes: Genotype, - batch_of_descriptors: Descriptor, - batch_of_fitnesses: Fitness, - batch_of_observations: Observation, - proximity_scores: jnp.ndarray, - ) -> UnstructuredRepertoire: - - # Calculating new l values - new_l_values = self.l_value / proximity_scores - - # We need to replace all the descriptors that are not filled with jnp inf - filtered_descriptors = jnp.where( - jnp.expand_dims((self.fitnesses == -jnp.inf), axis=-1), - jnp.full(self.descriptors.shape[-1], fill_value=jnp.inf), - self.descriptors, - ) - - batch_of_indices, batch_of_distances = get_cells_indices( - batch_of_descriptors, filtered_descriptors, 2 - ) - - # Save the second nearest neighbours to check a condition - second_neighbours = batch_of_distances.at[..., 1].get() - - # Keep the Nearest neighbours - batch_of_indices = batch_of_indices.at[..., 0].get() - - # Keep the Nearest neighbours - batch_of_distances = batch_of_distances.at[..., 0].get() - - # We remove individuals that are too close to the second nn. - # This avoids having clusters of individuals after adding them. - not_novel_enough = jnp.where( - jnp.squeeze(second_neighbours <= new_l_values), True, False - ) - - batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1) - batch_of_observations = jnp.expand_dims(batch_of_observations, axis=-1) - - num_centroids = self.centroids.shape[0] - - # TODO: Doesn't Work if Archive is full. Use closest individuals in that case. 
- empty_indexes = jnp.squeeze( - jnp.nonzero( - jnp.where(jnp.isinf(self.fitnesses), 1, 0), - size=batch_of_indices.shape[0], - fill_value=-1, - )[0] - ) - batch_of_indices = jnp.where( - jnp.squeeze(batch_of_distances <= new_l_values), - jnp.squeeze(batch_of_indices), - -1, - ) - - # get all the indices of the empty bds first and then the filled ones - # (because of -1) - sorted_bds = jax.lax.top_k( - -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0] - )[1] - batch_of_indices = jnp.where( - jnp.squeeze( - batch_of_distances.at[sorted_bds].get() - <= new_l_values.at[sorted_bds].get() - ), - batch_of_indices.at[sorted_bds].get(), - empty_indexes, - ) - - batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1) - - # ReIndexing of all the inputs to the correct sorted way - batch_of_distances = batch_of_distances.at[sorted_bds].get() - batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get() - batch_of_genotypes = jax.tree_map( - lambda x: x.at[sorted_bds].get(), batch_of_genotypes - ) - batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get() - batch_of_observations = batch_of_observations.at[sorted_bds].get() - not_novel_enough = not_novel_enough.at[sorted_bds].get() - new_l_values = new_l_values.at[sorted_bds].get() - - # Check to find Individuals with same BD within the Batch - keep_indiv = jax.jit( - jax.vmap(intra_batch_comp, in_axes=(0, 0, None, None, 0), out_axes=(0)) - )( - batch_of_descriptors.squeeze(), - jnp.arange( - 0, batch_of_descriptors.shape[0], 1 - ), # keep track of where we are in the batch to assure right comparisons - batch_of_descriptors.squeeze(), - batch_of_fitnesses.squeeze(), - new_l_values, - ) - - # get fitness segment max - best_fitnesses = jax.ops.segment_max( - batch_of_fitnesses, - batch_of_indices.astype(jnp.int32).squeeze(), - num_segments=num_centroids, - ) - - cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0) - - # put dominated fitness to -jnp.inf - batch_of_fitnesses = jnp.where( - batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf - ) - - # get addition condition - grid_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1) - current_fitnesses = jnp.take_along_axis(grid_fitnesses, batch_of_indices, 0) - addition_condition = batch_of_fitnesses > current_fitnesses - addition_condition = jnp.logical_and( - addition_condition, jnp.expand_dims(keep_indiv, axis=-1) - ) - - # assign fake position when relevant : num_centroids is out of bounds - batch_of_indices = jnp.where( - addition_condition, x=batch_of_indices, y=num_centroids - ) - - # create new grid - new_grid_genotypes = jax.tree_map( - lambda grid_genotypes, new_genotypes: grid_genotypes.at[ - batch_of_indices.squeeze() - ].set(new_genotypes), - self.genotypes, - batch_of_genotypes, - ) - - # compute new fitness and descriptors - new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze()].set( - batch_of_fitnesses.squeeze() - ) - new_descriptors = self.descriptors.at[batch_of_indices.squeeze()].set( - batch_of_descriptors.squeeze() - ) - - new_observations = self.observations.at[batch_of_indices.squeeze()].set( - batch_of_observations.squeeze() - ) - - new_ages = self.ages.at[batch_of_indices.squeeze()].set(0.0) + 1 - - return UnstructuredRepertoire( - genotypes=new_grid_genotypes, - fitnesses=new_fitnesses.squeeze(), - descriptors=new_descriptors.squeeze(), - centroids=new_descriptors.squeeze(), - observations=new_observations.squeeze(), - l_value=self.l_value, - ages=new_ages, - ) - - @classmethod - def init_relevant( - cls, - 
genotypes: Genotype, - fitnesses: Fitness, - descriptors: Descriptor, - centroids: Centroid, - observations: Observation, - l_value: float, - proximity_scores: jnp.ndarray, - ages: Optional[jnp.ndarray] = None, - ) -> UnstructuredRepertoire: - """Initialize a Map-Elites repertoire with an initial population of genotypes. - Requires the definition of centroids that can be computed with any method - such as CVT or Euclidean mapping. - - Args: - genotypes: initial genotypes, pytree in which leaves - have shape (batch_size, num_features) - fitnesses: fitness of the initial genotypes of shape (batch_size,) - descriptors: descriptors of the initial genotypes - of shape (batch_size, num_descriptors) - centroids: tesselation centroids of shape (batch_size, num_descriptors) - observations: observations gathered by the genotypes while being evaluated - on the task of interest. - l_value: threshold of the repertoire. - proximity_scores: measure of proximity of the individuals to the population. - ages: ages of the individuals. - - Returns: - an initialized unstructured repertoire - """ - - # Initialize grid with default values - num_centroids = centroids.shape[0] - default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) - default_genotypes = jax.tree_map( - lambda x: jnp.full( - shape=(num_centroids,) + x.shape[1:], fill_value=jnp.nan - ), - genotypes, - ) - default_descriptors = jnp.zeros(shape=(num_centroids, centroids.shape[-1])) - - default_observations = jnp.full( - shape=(num_centroids,) + observations.shape[1:], fill_value=jnp.nan - ) - - if ages is None: - ages = jnp.zeros(shape=num_centroids) - - repertoire = UnstructuredRepertoire( - genotypes=default_genotypes, - fitnesses=default_fitnesses, - descriptors=default_descriptors, - centroids=centroids, - observations=default_observations, - l_value=l_value, - ages=ages, - ) - - # return new_repertoire - return repertoire.add_relevant( # type: ignore - genotypes, descriptors, fitnesses, observations, proximity_scores - ) diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 6bbd7587..5b7fa514 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -31,7 +31,7 @@ def test_aurora(env_name: str, batch_size: int) -> None: max_iterations = 5 seed = 42 policy_hidden_layer_sizes = (64, 64) - num_centroids = 50 + max_size = 50 lstm_batch_size = 12 @@ -148,7 +148,7 @@ def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict: ) aurora_dims = hidden_size - centroids = jnp.zeros(shape=(num_centroids, aurora_dims)) + centroids = jnp.zeros(shape=(max_size, aurora_dims)) @jax.jit def update_scan_fn(carry: Any, unused: Any) -> Any: @@ -214,12 +214,12 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: # init step of the aurora algorithm repertoire, _, random_key = aurora.init( init_variables, - centroids, random_key, model_params, mean_observations, std_observations, - l_value_init, + jnp.array(l_value_init), + max_size, ) # initializing means and stds and AURORA @@ -288,11 +288,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: ) repertoire = repertoire.init( genotypes=repertoire.genotypes, - centroids=repertoire.centroids, fitnesses=repertoire.fitnesses, descriptors=new_descriptors, observations=repertoire.observations, l_value=repertoire.l_value, + max_size=repertoire.max_size, ) num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) @@ -306,7 +306,7 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: prop_gain = 1 * 10e-6 l_value = ( repertoire.l_value - + (prop_gain * 
(current_error)) + + (prop_gain * current_error) + (prop_gain * change_rate) ) @@ -314,11 +314,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: repertoire = repertoire.init( genotypes=repertoire.genotypes, - centroids=repertoire.centroids, fitnesses=repertoire.fitnesses, descriptors=repertoire.descriptors, observations=repertoire.observations, l_value=l_value, + max_size=repertoire.max_size, ) iteration += 1 From 8293057057f7302ae2886ed97338754e44825be2 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Sun, 23 Apr 2023 23:47:21 +0100 Subject: [PATCH 15/26] refactoring bd and obs generation --- qdax/core/aurora.py | 40 ++++---- qdax/environments/bd_extractors.py | 51 ++++------ qdax/tasks/brax_envs.py | 76 +++++---------- qdax/utils/train_seq2seq.py | 4 + tests/core_test/aurora_test.py | 151 +++++++++++++---------------- 5 files changed, 132 insertions(+), 190 deletions(-) diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index ec23a519..3e87907b 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -14,14 +14,11 @@ from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.environments.bd_extractors import AuroraExtraInfo from qdax.types import Descriptor, Fitness, Genotype, Metrics, Params, RNGKey, Observation -@flax.struct.dataclass -class AuroraExtraInfo: - model_params: Params - mean_observations: jnp.ndarray - std_observations: jnp.ndarray + class AURORA: @@ -46,23 +43,21 @@ def __init__( ], emitter: Emitter, metrics_function: Callable[[MapElitesRepertoire], Metrics], - bd_extraction_fn: Callable[ - [QDTransition, jnp.ndarray, Params, Observation, Observation], Descriptor + encoder_function: Callable[ + [Observation, AuroraExtraInfo], Descriptor ], ) -> None: self._scoring_function = scoring_function self._emitter = emitter self._metrics_function = metrics_function - self._bd_extraction_fn = bd_extraction_fn + self._encoder_fn = encoder_function @partial(jax.jit, static_argnames=("self",)) def init( self, init_genotypes: Genotype, random_key: RNGKey, - model_params: Params, - mean_observations: jnp.ndarray, - std_observations: jnp.ndarray, + aurora_extra_info: AuroraExtraInfo, l_value: jnp.ndarray, max_size: int, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: @@ -88,16 +83,19 @@ def init( fitnesses, descriptors, extra_scores, random_key = self._scoring_function( init_genotypes, random_key, - model_params, - mean_observations, - std_observations, ) + observations = extra_scores["last_valid_observations"] + + descriptors = self._encoder_fn(observations, + aurora_extra_info) + + repertoire = UnstructuredRepertoire.init( genotypes=init_genotypes, fitnesses=fitnesses, descriptors=descriptors, - observations=extra_scores["last_valid_observations"], # type: ignore + observations=observations, # type: ignore l_value=l_value, max_size=max_size, ) @@ -124,9 +122,7 @@ def update( repertoire: MapElitesRepertoire, emitter_state: Optional[EmitterState], random_key: RNGKey, - model_params: Params, - mean_observations: jnp.ndarray, - std_observations: jnp.ndarray, + aurora_extra_info: AuroraExtraInfo, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]: """Main step of the AURORA algorithm. 
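For orientation, a rough usage sketch of the signatures after this refactor (variable names are illustrative, an already-constructed `aurora` instance and trained model are assumed, and the construction below relies on the `AuroraExtraInfoNormalization` dataclass defined in `bd_extractors` further down): the model parameters and normalisation statistics now travel as one pytree instead of three separate arguments.

    aurora_extra_info = AuroraExtraInfoNormalization(
        model_params=model_params,
        mean_observations=mean_observations,
        std_observations=std_observations,
    )

    repertoire, emitter_state, random_key = aurora.init(
        init_genotypes, random_key, aurora_extra_info, l_value, max_size
    )

    repertoire, emitter_state, metrics, random_key = aurora.update(
        repertoire, emitter_state, random_key, aurora_extra_info
    )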
@@ -159,11 +155,13 @@ def update( fitnesses, descriptors, extra_scores, random_key = self._scoring_function( genotypes, random_key, - model_params, - mean_observations, - std_observations, ) + observations = extra_scores["last_valid_observations"] + + descriptors = self._encoder_fn(observations, + aurora_extra_info) + # add genotypes and observations in the repertoire repertoire = repertoire.add( genotypes, descriptors, fitnesses, extra_scores["last_valid_observations"] diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index a68abf2c..d7ca6507 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -1,3 +1,4 @@ +import flax.struct import jax import jax.numpy as jnp @@ -39,43 +40,31 @@ def get_feet_contact_proportion(data: QDTransition, mask: jnp.ndarray) -> Descri return descriptors -def get_aurora_bd( - data: QDTransition, - mask: jnp.ndarray, - model_params: Params, - mean_observations: jnp.ndarray, - std_observations: jnp.ndarray, - option: str = "full", - hidden_size: int = 10, - traj_sampling_freq: int = 10, - max_observation_size: int = 25, -) -> Descriptor: - """Compute final aurora embedding. - This function suppose that state descriptor is the xy position, as it - just select the final one of the state descriptors given. - """ - # reshape mask for bd extraction - mask = jnp.expand_dims(mask, axis=-1) +class AuroraExtraInfo(flax.struct.PyTreeNode): + model_params: Params - state_obs = data.obs[:, ::traj_sampling_freq, :max_observation_size] +class AuroraExtraInfoNormalization(AuroraExtraInfo): + mean_observations: jnp.ndarray + std_observations: jnp.ndarray - # add the x/y position - (batch_size, traj_length, 2) - state_desc = data.state_desc[:, ::traj_sampling_freq] - print("State Observations: ", state_obs) - print("XY positions: ", state_desc) +def get_aurora_encoding( + observations: jnp.ndarray, + model: flax.linen.Module, + aurora_extra_info: AuroraExtraInfoNormalization, +) -> Descriptor: + """ + Compute final aurora embedding. - if option == "full": - observations = jnp.concatenate([state_desc, state_obs], axis=-1) - print("New observations: ", observations) - elif option == "no_sd": - observations = state_obs - elif option == "only_sd": - observations = state_desc + This function suppose that state descriptor is the xy position, as it + just select the final one of the state descriptors given. 
+ """ + model_params = aurora_extra_info.model_params + mean_observations = aurora_extra_info.mean_observations + std_observations = aurora_extra_info.std_observations # lstm seq2seq - model = train_seq2seq.get_model(observations.shape[-1], True, hidden_size) normalized_observations = (observations - mean_observations) / std_observations descriptors = model.apply( {"params": model_params}, normalized_observations, method=model.encode @@ -83,4 +72,4 @@ def get_aurora_bd( print("Observations out of get aurora bd: ", observations) - return descriptors.squeeze(), observations.squeeze() + return descriptors.squeeze() diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 8cc267c7..05868a3e 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -9,7 +9,7 @@ import qdax.environments from qdax import environments -from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition from qdax.core.neuroevolution.mdp_utils import generate_unroll from qdax.core.neuroevolution.networks.networks import MLP from qdax.types import ( @@ -83,6 +83,15 @@ def default_play_step_fn( return default_play_step_fn +def get_mask_from_transitions( + data: Transition, +) -> jnp.ndarray: + is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) + mask = jnp.roll(is_done, 1, axis=1) + mask = mask.at[:, 0].set(0) + return mask + + @partial( jax.jit, static_argnames=( @@ -135,9 +144,7 @@ def scoring_function_brax_envs( _final_state, data = jax.vmap(unroll_fn)(init_states, policies_params) # create a mask to extract data properly - is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) - mask = jnp.roll(is_done, 1, axis=1) - mask = mask.at[:, 0].set(0) + mask = get_mask_from_transitions(data) # scores fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) @@ -341,22 +348,10 @@ def create_default_brax_task_components( return env, policy_network, scoring_fn, random_key -def scoring_aurora_function( - policies_params: Genotype, - random_key: RNGKey, - model_params: Params, - mean_observations: jnp.ndarray, - std_observations: jnp.ndarray, - init_states: brax.envs.State, - episode_length: int, - play_step_fn: Callable[ - [EnvState, Params, RNGKey], - Tuple[EnvState, Params, RNGKey, QDTransition], - ], - behavior_descriptor_extractor: Callable[ - [QDTransition, jnp.ndarray, Params, Observation, Observation], Descriptor - ], -) -> Tuple[Fitness, Descriptor, Dict[str, Union[jnp.ndarray, QDTransition]], RNGKey]: +def get_aurora_scoring_fn( + scoring_fn: Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]], + observation_extractor_fn: Callable[[Transition], Observation], +) -> Callable[[Genotype, RNGKey], Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey]]: """Evaluates policies contained in flatten_variables in parallel This rollout is only deterministic when all the init states are the same. @@ -367,37 +362,14 @@ def scoring_aurora_function( choice was made for performance reason, as the reset function of brax envs is quite time consuming. If pure stochasticity of the environment is needed for a use case, please open an issue. 
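A short usage sketch of this wrapper (illustrative names; it assumes the wrapped `scoring_fn` exposes its rollout transitions under `extra_scores["data"]`, as the code below expects):

    def observation_extractor_fn(data):
        # keep one observation every 10 steps as input to the learned encoder
        return data.obs[:, ::10, :]

    aurora_scoring_fn = get_aurora_scoring_fn(
        scoring_fn=scoring_fn,
        observation_extractor_fn=observation_extractor_fn,
    )

    fitnesses, descriptors, extra_scores, random_key = aurora_scoring_fn(
        genotypes, random_key
    )
    # descriptors is None here; the observations to encode are available under
    # extra_scores["last_valid_observations"].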
- """ - # Perform rollouts with each policy - random_key, subkey = jax.random.split(random_key) - unroll_fn = partial( - generate_unroll, - episode_length=episode_length, - play_step_fn=play_step_fn, - random_key=subkey, - ) - - _final_state, data = jax.vmap(unroll_fn)(init_states, policies_params) - - # create a mask to extract data properly - is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) - mask = jnp.roll(is_done, 1, axis=1) - mask = mask.at[:, 0].set(0) - - # scores - add offset to ensure positive fitness (through positive rewards) - fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) - descriptors, observations = behavior_descriptor_extractor( - data, mask, model_params, mean_observations, std_observations - ) - - return ( - fitnesses, - descriptors, - { - "transitions": data, - "last_valid_observations": observations, - }, - random_key, - ) + @functools.wraps(scoring_fn) + def _wrapper(params: Params, # Perform rollouts with each policy + random_key: RNGKey): + fitnesses, _, extra_scores, random_key = scoring_fn(params, random_key) + data = extra_scores["data"] + observation = observation_extractor_fn(data) # type: ignore + extra_scores["last_valid_observations"] = observation + return fitnesses, None, extra_scores, random_key + return _wrapper diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index ec5b3c7f..d8c74ae6 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -25,6 +25,7 @@ def get_model( obs_size: int, teacher_force: bool = False, hidden_size: int = 10 ) -> Seq2seq: + # TODO: add docstring return Seq2seq( teacher_force=teacher_force, hidden_size=hidden_size, obs_size=obs_size ) @@ -33,6 +34,7 @@ def get_model( def get_initial_params( model: Seq2seq, rng: PRNGKey, encoder_input_shape: Tuple[int, ...] 
) -> Dict[str, Any]: + # TODO: add docstring """Returns the initial parameters of a seq2seq model.""" rng1, rng2, rng3 = jax.random.split(rng, 3) variables = model.init( @@ -47,6 +49,8 @@ def get_initial_params( def train_step( state: train_state.TrainState, batch: Array, lstm_rng: PRNGKey ) -> Tuple[train_state.TrainState, Dict[str, float]]: + # TODO: add docstring + """Trains one step.""" lstm_key = jax.random.fold_in(lstm_rng, state.step) dropout_key, lstm_key = jax.random.split(lstm_key, 2) diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 5b7fa514..b665ef15 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -14,9 +14,9 @@ from qdax.core.emitters.standard_emitters import MixingEmitter from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP -from qdax.environments.bd_extractors import get_aurora_bd -from qdax.tasks.brax_envs import scoring_aurora_function -from qdax.types import EnvState, Params, RNGKey +from qdax.environments.bd_extractors import get_aurora_encoding +from qdax.tasks.brax_envs import get_aurora_scoring_fn, create_default_brax_task_components +from qdax.types import EnvState, Params, RNGKey, Observation from qdax.utils import train_seq2seq @@ -25,12 +25,9 @@ [("halfcheetah_uni", 10), ("walker2d_uni", 10), ("hopper_uni", 10)], ) def test_aurora(env_name: str, batch_size: int) -> None: - batch_size = batch_size - env_name = env_name episode_length = 250 max_iterations = 5 seed = 42 - policy_hidden_layer_sizes = (64, 64) max_size = 50 lstm_batch_size = 12 @@ -45,18 +42,13 @@ def test_aurora(env_name: str, batch_size: int) -> None: log_freq = 5 - # Init environment - env = environments.create(env_name, episode_length=episode_length) - # Init a random key random_key = jax.random.PRNGKey(seed) - # Init policy network - policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) - policy_network = MLP( - layer_sizes=policy_layer_sizes, - kernel_init=jax.nn.initializers.lecun_uniform(), - final_activation=jnp.tanh, + # Init environment + env, policy_network, scoring_fn, random_key = create_default_brax_task_components( + env_name=env_name, + random_key=random_key, ) # Init population of controllers @@ -65,54 +57,34 @@ def test_aurora(env_name: str, batch_size: int) -> None: fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) - # Create the initial environment states - random_key, subkey = jax.random.split(random_key) - keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0) - reset_fn = jax.jit(jax.vmap(env.reset)) - init_states = reset_fn(keys) - - # Define the fonction to play a step with the policy in the environment - def play_step_fn( - env_state: EnvState, - policy_params: Params, - random_key: RNGKey, - ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: - """ - Play an environment step and return the updated state and the transition. 
- """ - - actions = policy_network.apply(policy_params, env_state.obs) - - state_desc = env_state.info["state_descriptor"] - next_state = env.step(env_state, actions) - - transition = QDTransition( - obs=env_state.obs, - next_obs=next_state.obs, - rewards=next_state.reward, - dones=next_state.done, - actions=actions, - truncations=next_state.info["truncation"], - state_desc=state_desc, - next_state_desc=next_state.info["state_descriptor"], - ) + def observation_extractor_fn( + data: QDTransition, + ) -> Observation: + """Extract observation from the state.""" + state_obs = data.obs[:, ::traj_sampling_freq, :max_observation_size] + + # add the x/y position - (batch_size, traj_length, 2) + state_desc = data.state_desc[:, ::traj_sampling_freq] + + print("State Observations: ", state_obs) + print("XY positions: ", state_desc) - return next_state, policy_params, random_key, transition + if observation_option == "full": + observations = jnp.concatenate([state_desc, state_obs], axis=-1) + print("New observations: ", observations) + elif observation_option == "no_sd": + observations = state_obs + elif observation_option == "only_sd": + observations = state_desc + else: + raise ValueError("Unknown observation option.") + + return observations # Prepare the scoring function - bd_extraction_fn = functools.partial( - get_aurora_bd, - option=observation_option, - hidden_size=hidden_size, - traj_sampling_freq=traj_sampling_freq, - max_observation_size=max_observation_size, - ) - scoring_fn = functools.partial( - scoring_aurora_function, - init_states=init_states, - episode_length=episode_length, - play_step_fn=play_step_fn, - behavior_descriptor_extractor=bd_extraction_fn, + aurora_scoring_fn = get_aurora_scoring_fn( + scoring_fn=scoring_fn, + observation_extractor_fn=observation_extractor_fn, ) # Define emitter @@ -140,19 +112,46 @@ def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict: return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} + # Init algorithm + # AutoEncoder Params and INIT + obs_dim = jnp.minimum(env.observation_size, max_observation_size) + if observation_option == "full": + observations_dims = ( + episode_length // traj_sampling_freq, + obs_dim + prior_descriptor_dim, + ) + elif observation_option == "no_sd": + observations_dims = ( + episode_length // traj_sampling_freq, + obs_dim, + ) + elif observation_option == "only_sd": + observations_dims = (episode_length // traj_sampling_freq, prior_descriptor_dim) + else: + raise ValueError(f"Unknown observation option: {observation_option}") + + # define the seq2seq model + model = train_seq2seq.get_model( + observations_dims[-1], True, hidden_size=hidden_size + ) + + encoder_fn = functools.partial( + get_aurora_encoding, + model=model, + ) + # Instantiate AURORA aurora = AURORA( - scoring_function=scoring_fn, + scoring_function=aurora_scoring_fn, emitter=mixing_emitter, metrics_function=metrics_fn, + encoder_function=encoder_fn, ) - aurora_dims = hidden_size - centroids = jnp.zeros(shape=(max_size, aurora_dims)) - @jax.jit - def update_scan_fn(carry: Any, unused: Any) -> Any: + def update_scan_fn(carry: Any, _: Any) -> Any: """Scan the udpate function.""" + # TODO: fix shadowing names from outer scopes. 
( repertoire, random_key, @@ -176,28 +175,8 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: metrics, ) - # Init algorithm - # AutoEncoder Params and INIT - obs_dim = jnp.minimum(env.observation_size, max_observation_size) - if observation_option == "full": - observations_dims = ( - episode_length // traj_sampling_freq, - obs_dim + prior_descriptor_dim, - ) - elif observation_option == "no_sd": - observations_dims = ( - episode_length // traj_sampling_freq, - obs_dim, - ) - elif observation_option == "only_sd": - observations_dims = (episode_length // traj_sampling_freq, prior_descriptor_dim) - else: - ValueError("The chosen option is not correct.") - # define the seq2seq model - model = train_seq2seq.get_model( - observations_dims[-1], True, hidden_size=hidden_size - ) + # init the model params random_key, subkey = jax.random.split(random_key) From 7294c51ed499049be1927b2707d36fc4b81fb4c0 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Mon, 24 Apr 2023 00:23:10 +0100 Subject: [PATCH 16/26] put training and csc inside AURORA --- qdax/core/aurora.py | 73 +++++++++++++++-- qdax/environments/bd_extractors.py | 11 +++ qdax/utils/train_seq2seq.py | 5 +- tests/core_test/aurora_test.py | 126 +++++++---------------------- 4 files changed, 112 insertions(+), 103 deletions(-) diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index 3e87907b..a70bd9c9 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -46,11 +46,75 @@ def __init__( encoder_function: Callable[ [Observation, AuroraExtraInfo], Descriptor ], + training_function: Callable[ + [RNGKey, UnstructuredRepertoire, Params, int], AuroraExtraInfo + ], ) -> None: self._scoring_function = scoring_function self._emitter = emitter self._metrics_function = metrics_function self._encoder_fn = encoder_function + self._train_fn = training_function + + def train( + self, + repertoire: UnstructuredRepertoire, + model_params: Params, + iteration: int, + random_key: RNGKey, + ): + random_key, subkey = jax.random.split(random_key) + aurora_extra_info = self._train_fn( + random_key, + repertoire, + model_params, + iteration, + ) + + # re-addition of all the new behavioural descriptors with the new ae + new_descriptors = self._encoder_fn(repertoire.observations, aurora_extra_info) + + return repertoire.init( + genotypes=repertoire.genotypes, + fitnesses=repertoire.fitnesses, + descriptors=new_descriptors, + observations=repertoire.observations, + l_value=repertoire.l_value, + max_size=repertoire.max_size, + ) + return aurora_extra_info + + + @partial(jax.jit, static_argnames=("self",)) + def container_size_control( + self, + repertoire: UnstructuredRepertoire, + target_size: int, + previous_error: jnp.ndarray, + ): + # update the l value + num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) + + # CVC Implementation to keep a constant number of individuals in the archive + current_error = num_indivs - target_size + change_rate = current_error - previous_error + prop_gain = 1 * 10e-6 + l_value = ( + repertoire.l_value + + (prop_gain * current_error) + + (prop_gain * change_rate) + ) + + repertoire = repertoire.init( + genotypes=repertoire.genotypes, + fitnesses=repertoire.fitnesses, + descriptors=repertoire.descriptors, + observations=repertoire.observations, + l_value=l_value, + max_size=repertoire.max_size, + ) + + return repertoire, current_error @partial(jax.jit, static_argnames=("self",)) def init( @@ -60,7 +124,7 @@ def init( aurora_extra_info: AuroraExtraInfo, l_value: jnp.ndarray, max_size: int, - ) -> Tuple[MapElitesRepertoire, 
Optional[EmitterState], RNGKey]: + ) -> Tuple[UnstructuredRepertoire, Optional[EmitterState], RNGKey]: """Initialize an unstructured repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping. @@ -136,10 +200,7 @@ def update( repertoire: unstructured repertoire emitter_state: state of the emitter random_key: a jax PRNG random key - model_params: params of the model used to define the behavior descriptor. - mean_observations: mean of the observations gathered. - std_observations: standard deviation of the observations - gathered. + aurora_extra_info: extra info for the encoding # TODO Results: the updated MAP-Elites repertoire @@ -164,7 +225,7 @@ def update( # add genotypes and observations in the repertoire repertoire = repertoire.add( - genotypes, descriptors, fitnesses, extra_scores["last_valid_observations"] + genotypes, descriptors, fitnesses, observations, ) # update emitter state after scoring is made diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index d7ca6507..7a77eea8 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -48,6 +48,17 @@ class AuroraExtraInfoNormalization(AuroraExtraInfo): mean_observations: jnp.ndarray std_observations: jnp.ndarray + @classmethod + def create(cls, + model_params: Params, + mean_observations: jnp.ndarray, + std_observations: jnp.ndarray, + ): + return cls(model_params=model_params, + mean_observations=mean_observations, + std_observations=std_observations, + ) + def get_aurora_encoding( observations: jnp.ndarray, diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index d8c74ae6..4868e7f2 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -16,6 +16,7 @@ from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.neuroevolution.networks.seq2seq_networks import Seq2seq +from qdax.environments.bd_extractors import AuroraExtraInfoNormalization from qdax.types import Observation, Params, RNGKey Array = Any @@ -95,7 +96,7 @@ def lstm_ae_train( epoch: int, hidden_size: int = 10, batch_size: int = 128, -) -> Tuple[Params, Observation, Observation]: +) -> AuroraExtraInfoNormalization: if epoch > 100: num_epochs = 25 @@ -197,4 +198,4 @@ def lstm_ae_train( params = state.params - return params, mean_obs, std_obs + return AuroraExtraInfoNormalization.create(params, mean_obs, std_obs) diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index b665ef15..687d8f60 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -14,7 +14,7 @@ from qdax.core.emitters.standard_emitters import MixingEmitter from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP -from qdax.environments.bd_extractors import get_aurora_encoding +from qdax.environments.bd_extractors import get_aurora_encoding, AuroraExtraInfoNormalization from qdax.tasks.brax_envs import get_aurora_scoring_fn, create_default_brax_task_components from qdax.types import EnvState, Params, RNGKey, Observation from qdax.utils import train_seq2seq @@ -140,44 +140,21 @@ def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict: model=model, ) + train_fn = functools.partial( + train_seq2seq.lstm_ae_train, + hidden_size=hidden_size, + batch_size=lstm_batch_size, + ) + # Instantiate AURORA aurora = AURORA( scoring_function=aurora_scoring_fn, 
emitter=mixing_emitter, metrics_function=metrics_fn, encoder_function=encoder_fn, + training_function=train_fn, ) - @jax.jit - def update_scan_fn(carry: Any, _: Any) -> Any: - """Scan the udpate function.""" - # TODO: fix shadowing names from outer scopes. - ( - repertoire, - random_key, - model_params, - mean_observations, - std_observations, - ) = carry - - # update - (repertoire, _, metrics, random_key,) = aurora.update( - repertoire, - None, - random_key, - model_params, - mean_observations, - std_observations, - ) - - return ( - (repertoire, random_key, model_params, mean_observations, std_observations), - metrics, - ) - - - - # init the model params random_key, subkey = jax.random.split(random_key) model_params = train_seq2seq.get_initial_params( @@ -220,22 +197,30 @@ def update_scan_fn(carry: Any, _: Any) -> Any: current_step_estimation = 0 # Main loop - n_target = 1024 + target_repertoire_size = 1024 - previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target + previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - target_repertoire_size iteration = 0 - while iteration < max_iterations: - ( - (repertoire, random_key, model_params, mean_observations, std_observations), - metrics, - ) = jax.lax.scan( - update_scan_fn, - (repertoire, random_key, model_params, mean_observations, std_observations), - (), - length=log_freq, + emitter_state = None + + while iteration < max_iterations: + collected_metrics = [] + aurora_extra_info = AuroraExtraInfoNormalization.create( + model_params=model_params, + mean_observations=mean_observations, + std_observations=std_observations, ) + # update + for _ in range(log_freq): + repertoire, emitter_state, metrics, random_key = aurora.update( + repertoire, + emitter_state, + random_key, + aurora_extra_info=aurora_extra_info, + ) + collected_metrics.append(metrics) # update nb steps estimation current_step_estimation += batch_size * episode_length * log_freq @@ -244,61 +229,12 @@ def update_scan_fn(carry: Any, _: Any) -> Any: if (iteration + 1) in schedules: # train the autoencoder random_key, subkey = jax.random.split(random_key) - ( - model_params, - mean_observations, - std_observations, - ) = train_seq2seq.lstm_ae_train( - subkey, - repertoire, - model_params, - iteration, - hidden_size=hidden_size, - batch_size=lstm_batch_size, - ) - - # re-addition of all the new behavioural descriotpors with the new ae - normalized_observations = ( - repertoire.observations - mean_observations - ) / std_observations - - new_descriptors = model.apply( - {"params": model_params}, normalized_observations, method=model.encode - ) - repertoire = repertoire.init( - genotypes=repertoire.genotypes, - fitnesses=repertoire.fitnesses, - descriptors=new_descriptors, - observations=repertoire.observations, - l_value=repertoire.l_value, - max_size=repertoire.max_size, - ) - num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) + repertoire = aurora.train(repertoire, model_params, iteration, subkey) elif iteration % 2 == 0: - # update the l value - num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) - - # CVC Implementation to keep a constant number of individuals in the archive - current_error = num_indivs - n_target - change_rate = current_error - previous_error - prop_gain = 1 * 10e-6 - l_value = ( - repertoire.l_value - + (prop_gain * current_error) - + (prop_gain * change_rate) - ) - - previous_error = current_error - - repertoire = repertoire.init( - genotypes=repertoire.genotypes, - fitnesses=repertoire.fitnesses, - descriptors=repertoire.descriptors, - 
observations=repertoire.observations, - l_value=l_value, - max_size=repertoire.max_size, - ) + repertoire, previous_error = aurora.container_size_control(repertoire, + target_size=target_repertoire_size, + previous_error=previous_error) iteration += 1 From b0aba19e4310a35bfce7c8b693e70cd13cf65e36 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Tue, 25 Apr 2023 18:11:49 +0100 Subject: [PATCH 17/26] Add missing comments and refactor tests --- qdax/core/aurora.py | 43 +++--- .../containers/unstructured_repertoire.py | 25 ++-- qdax/utils/train_seq2seq.py | 96 ++++++------ tests/core_test/aurora_test.py | 137 ++++++++---------- tests/core_test/map_elites_test.py | 33 ++--- 5 files changed, 168 insertions(+), 166 deletions(-) diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index a70bd9c9..551a1966 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -62,7 +62,7 @@ def train( model_params: Params, iteration: int, random_key: RNGKey, - ): + ) -> Tuple[UnstructuredRepertoire, AuroraExtraInfo]: random_key, subkey = jax.random.split(random_key) aurora_extra_info = self._train_fn( random_key, @@ -74,15 +74,17 @@ def train( # re-addition of all the new behavioural descriptors with the new ae new_descriptors = self._encoder_fn(repertoire.observations, aurora_extra_info) - return repertoire.init( + return ( + repertoire.init( genotypes=repertoire.genotypes, fitnesses=repertoire.fitnesses, descriptors=new_descriptors, observations=repertoire.observations, l_value=repertoire.l_value, max_size=repertoire.max_size, - ) - return aurora_extra_info + ), + aurora_extra_info + ) @partial(jax.jit, static_argnames=("self",)) @@ -116,33 +118,29 @@ def container_size_control( return repertoire, current_error - @partial(jax.jit, static_argnames=("self",)) def init( self, init_genotypes: Genotype, - random_key: RNGKey, aurora_extra_info: AuroraExtraInfo, l_value: jnp.ndarray, max_size: int, - ) -> Tuple[UnstructuredRepertoire, Optional[EmitterState], RNGKey]: + random_key: RNGKey, + ) -> Tuple[UnstructuredRepertoire, Optional[EmitterState], AuroraExtraInfo, RNGKey]: """Initialize an unstructured repertoire with an initial population of - genotypes. Requires the definition of centroids that can be computed with - any method such as CVT or Euclidean mapping. + genotypes. Also performs the first training of the AURORA encoder. Args: init_genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) - centroids: tesselation centroids of shape (batch_size, num_descriptors) + aurora_extra_info: information to perform AURORA encodings, + such as the encoder parameters + l_value: threshold distance for the unstructured repertoire + max_size: maximum size of the repertoire random_key: a random key used for stochastic operations. - model_params: parameters of the model used to define the behavior - descriptors. - mean_observations: mean of the observations gathered. - std_observations: standard deviation of the observations - gathered. Returns: - an initialized unstructured repertoire with the initial state of - the emitter. 
+ an initialized unstructured repertoire, with the initial state of + the emitter, and the updated information to perform AURORA encodings """ fitnesses, descriptors, extra_scores, random_key = self._scoring_function( init_genotypes, @@ -154,12 +152,11 @@ def init( descriptors = self._encoder_fn(observations, aurora_extra_info) - repertoire = UnstructuredRepertoire.init( genotypes=init_genotypes, fitnesses=fitnesses, descriptors=descriptors, - observations=observations, # type: ignore + observations=observations, l_value=l_value, max_size=max_size, ) @@ -178,7 +175,13 @@ def init( extra_scores=extra_scores, ) - return repertoire, emitter_state, random_key + random_key, subkey = jax.random.split(random_key) + repertoire, updated_aurora_extra_info = self.train(repertoire, + aurora_extra_info.model_params, + iteration=0, + random_key=subkey) + + return repertoire, emitter_state, updated_aurora_extra_info, random_key @partial(jax.jit, static_argnames=("self",)) def update( diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index 1685d66b..59b64ce4 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import partial -from typing import Callable, Optional, Tuple +from typing import Callable, Tuple import flax.struct import jax @@ -29,9 +29,9 @@ def get_cells_indices( """ def _get_cells_indices( - descriptors: jnp.ndarray, - centroids: jnp.ndarray, - k_nn: int, + _descriptors: jnp.ndarray, + _centroids: jnp.ndarray, + _k_nn: int, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Inner function. @@ -39,10 +39,10 @@ def _get_cells_indices( centroids of shape (num_centroids, num_descriptors) """ - distances = jax.vmap(jnp.linalg.norm)(descriptors - centroids) + distances = jax.vmap(jnp.linalg.norm)(_descriptors - _centroids) # Negating distances because we want the smallest ones - min_dist, min_args = jax.lax.top_k(-1 * distances, k_nn) + min_dist, min_args = jax.lax.top_k(-1 * distances, _k_nn) return min_args, -1 * min_dist @@ -151,6 +151,14 @@ class UnstructuredRepertoire(flax.struct.PyTreeNode): l_value: jnp.ndarray max_size: int = flax.struct.field(pytree_node=False) + def get_maximal_size(self) -> int: + """Returns the maximal number of individuals in the repertoire.""" + return self.max_size + + def get_number_genotypes(self) -> jnp.ndarray: + """Returns the number of genotypes in the repertoire.""" + return jnp.sum(self.fitnesses != -jnp.inf) + def save(self, path: str = "./") -> None: """Saves the grid on disk in the form of .npy files. 
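
For context on the container size control (CSC) step that patch 16 moves into the AURORA class: it is a proportional controller on the repertoire's distance threshold l_value. It counts the valid genotypes (the quantity exposed by the get_number_genotypes helper added just above), compares that count with a target size, and adjusts l_value by a gain applied to both the error and its change rate. The sketch below is illustrative only and not part of the patch: the function name csc_update_l_value is made up, and the gain reproduces the 1 * 10e-6 value (i.e. 1e-5) used in container_size_control.

from typing import Tuple

import jax.numpy as jnp


def csc_update_l_value(
    fitnesses: jnp.ndarray,      # repertoire fitnesses, -inf marks empty slots
    l_value: jnp.ndarray,        # current distance threshold of the repertoire
    target_size: int,
    previous_error: jnp.ndarray,
    prop_gain: float = 1e-5,     # same value as 1 * 10e-6 in the patch
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    # number of valid individuals, as in UnstructuredRepertoire.get_number_genotypes
    num_indivs = jnp.sum(fitnesses != -jnp.inf)

    # proportional control on the archive size: react to the error and to its
    # change rate, mirroring AURORA.container_size_control in the patch
    current_error = num_indivs - target_size
    change_rate = current_error - previous_error
    new_l_value = l_value + prop_gain * current_error + prop_gain * change_rate
    return new_l_value, current_error


if __name__ == "__main__":
    # 1200 filled slots for a target of 1024: l_value grows, making additions
    # to the unstructured repertoire more selective.
    fitnesses = jnp.concatenate([jnp.zeros(1200), -jnp.inf * jnp.ones(824)])
    l_value, error = jnp.asarray(0.2), jnp.asarray(0.0)
    for _ in range(3):
        l_value, error = csc_update_l_value(fitnesses, l_value, 1024, error)
        print(float(l_value), int(error))

A positive error (too many individuals) increases l_value, so fewer offspring pass the distance test on subsequent additions and the archive drifts back towards the target size; a negative error does the opposite.
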
@@ -243,7 +251,7 @@ def add( batch_of_descriptors, filtered_descriptors, 2 ) - # Save the second nearest neighbours to check a condition + # Save the second-nearest neighbours to check a condition second_neighbours = batch_of_distances.at[..., 1].get() # Keep the Nearest neighbours @@ -291,7 +299,6 @@ def add( batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1) # ReIndexing of all the inputs to the correct sorted way - batch_of_distances = batch_of_distances.at[sorted_bds].get() batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get() batch_of_genotypes = jax.tree_map( lambda x: x.at[sorted_bds].get(), batch_of_genotypes @@ -416,7 +423,6 @@ def init( fitnesses: fitness of the initial genotypes of shape (batch_size,) descriptors: descriptors of the initial genotypes of shape (batch_size, num_descriptors) - centroids: tesselation centroids of shape (batch_size, num_descriptors) observations: observations experienced in the evaluation task. l_value: threshold distance of the repertoire. max_size: maximal size of the container @@ -425,7 +431,6 @@ def init( an initialized unstructured repertoire. """ - # Initialize grid with default values default_fitnesses = -jnp.inf * jnp.ones(shape=max_size) default_genotypes = jax.tree_map( diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index 4868e7f2..37c4834e 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -17,7 +17,7 @@ from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.neuroevolution.networks.seq2seq_networks import Seq2seq from qdax.environments.bd_extractors import AuroraExtraInfoNormalization -from qdax.types import Observation, Params, RNGKey +from qdax.types import Params, RNGKey Array = Any PRNGKey = Any @@ -26,18 +26,31 @@ def get_model( obs_size: int, teacher_force: bool = False, hidden_size: int = 10 ) -> Seq2seq: - # TODO: add docstring + """ + Returns a seq2seq model. + + Args: + obs_size: the size of the observation. + teacher_force: whether to use teacher forcing. + hidden_size: the size of the hidden layer (i.e. the encoding). + """ return Seq2seq( teacher_force=teacher_force, hidden_size=hidden_size, obs_size=obs_size ) def get_initial_params( - model: Seq2seq, rng: PRNGKey, encoder_input_shape: Tuple[int, ...] + model: Seq2seq, random_key: PRNGKey, encoder_input_shape: Tuple[int, ...] ) -> Dict[str, Any]: - # TODO: add docstring - """Returns the initial parameters of a seq2seq model.""" - rng1, rng2, rng3 = jax.random.split(rng, 3) + """ + Returns the initial parameters of a seq2seq model. + + Args: + model: the seq2seq model. + random_key: the random number generator. + encoder_input_shape: the shape of the encoder input. + """ + random_key, rng1, rng2, rng3 = jax.random.split(random_key, 4) variables = model.init( {"params": rng1, "lstm": rng2, "dropout": rng3}, jnp.ones(encoder_input_shape, jnp.float32), @@ -48,12 +61,21 @@ def get_initial_params( @jax.jit def train_step( - state: train_state.TrainState, batch: Array, lstm_rng: PRNGKey + state: train_state.TrainState, + batch: Array, + lstm_random_key: PRNGKey, ) -> Tuple[train_state.TrainState, Dict[str, float]]: - # TODO: add docstring + """ + Trains for one step. + + Args: + state: the training state. + batch: the batch of data. + lstm_random_key: the random number key. 
+ """ """Trains one step.""" - lstm_key = jax.random.fold_in(lstm_rng, state.step) + lstm_key = jax.random.fold_in(lstm_random_key, state.step) dropout_key, lstm_key = jax.random.split(lstm_key, 2) # Shift input by one to avoid leakage @@ -90,32 +112,21 @@ def mean_squared_error(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: def lstm_ae_train( - key: RNGKey, + random_key: RNGKey, repertoire: UnstructuredRepertoire, params: Params, epoch: int, - hidden_size: int = 10, + model, batch_size: int = 128, ) -> AuroraExtraInfoNormalization: if epoch > 100: num_epochs = 25 - - # Gradient step size - alpha = 0.0001 + alpha = 0.0001 # Gradient step size else: num_epochs = 100 - - # Gradient step size alpha = 0.0001 - rng, key, key_selection = jax.random.split(key, 3) - - # get the model used (seq2seq) - model = get_model( - repertoire.observations.shape[-1], teacher_force=True, hidden_size=hidden_size - ) - # compute mean/std of the obs for normalization mean_obs = jnp.nanmean(repertoire.observations, axis=(0, 1)) std_obs = jnp.nanstd(repertoire.observations, axis=(0, 1)) @@ -126,34 +137,36 @@ def lstm_ae_train( tx = optax.adam(alpha) state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx) + ########################################################################### + # Shuffling indexes of valid individuals in the repertoire + ########################################################################### + # size of the repertoire - repertoire_size = repertoire.centroids.shape[0] + repertoire_size = repertoire.max_size # number of individuals in the repertoire - num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) + num_indivs = repertoire.get_number_genotypes() - # select repertoire_size indexes going from 0 to num_indivs - # TODO: WHY?? - key_select_p1, rng = jax.random.split(key_selection, 2) + # select repertoire_size indexes going from 0 to the total number of + # valid individuals. Those indexes will be used to select the individuals + # in the training dataset. + random_key, key_select_p1 = jax.random.split(random_key, 2) idx_p1 = jax.random.randint( key_select_p1, shape=(repertoire_size,), minval=0, maxval=num_indivs ) - # TODO: what is the diff with repertoire_size?? - tot_indivs = repertoire.fitnesses.ravel().shape[0] - - # get indexes where fitness is not -inf?? + # get indexes where fitness is not -inf. Those are the valid individuals. indexes = jnp.argwhere( - jnp.logical_not(jnp.isinf(repertoire.fitnesses)), size=tot_indivs + jnp.logical_not(jnp.isinf(repertoire.fitnesses)), size=repertoire_size ) indexes = jnp.transpose(indexes, axes=(1, 0)) - # ??? + # get corresponding indices for the flattened repertoire fitnesses indiv_indices = jnp.array( jnp.ravel_multi_index(indexes, repertoire.fitnesses.shape, mode="clip") ).astype(int) - # ??? + # filter those indices to get only the indices of valid individuals valid_indexes = indiv_indices.at[idx_p1].get() # Normalising Dataset @@ -161,13 +174,11 @@ def lstm_ae_train( loss_val = 0.0 for epoch in range(num_epochs): - rng, shuffle_key = jax.random.split(rng, 2) + random_key, shuffle_key = jax.random.split(random_key, 2) valid_indexes = jax.random.permutation(shuffle_key, valid_indexes, axis=0) - # TODO: the std where they were NaNs is set to zero. But here we divide by the - # std, so NaNs appear here... - # std_obs += 1e-6 - + # the std where they were NaNs was set to zero. But here we divide by the + # std, so we replace the zeros by inf here. 
std_obs = jnp.where(std_obs == 0, x=jnp.inf, y=std_obs) # create dataset with the observation from the sample of valid indexes @@ -187,15 +198,12 @@ def lstm_ae_train( # print(batch.shape) continue - state, loss_val = train_step(state, batch, rng) + state, loss_val = train_step(state, batch, random_key) # To see the actual value we cannot jit this function (i.e. the _one_es_epoch # function nor the train function) print("Eval epoch: {}, loss: {:.4f}".format(epoch + 1, loss_val)) - # TODO: put this in metrics so we can jit the function and see the metrics - # TODO: not urgent because the training is not that long - params = state.params return AuroraExtraInfoNormalization.create(params, mean_obs, std_obs) diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 687d8f60..00239527 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -1,23 +1,48 @@ """Tests AURORA implementation""" import functools -from typing import Any, Dict, Tuple +from typing import Tuple +import brax.envs import jax import jax.numpy as jnp import pytest from qdax import environments from qdax.core.aurora import AURORA -from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire -from qdax.core.emitters.mutation_operators import isoline_variation -from qdax.core.emitters.standard_emitters import MixingEmitter from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.core.neuroevolution.networks.networks import MLP from qdax.environments.bd_extractors import get_aurora_encoding, AuroraExtraInfoNormalization from qdax.tasks.brax_envs import get_aurora_scoring_fn, create_default_brax_task_components -from qdax.types import EnvState, Params, RNGKey, Observation +from qdax.types import Observation from qdax.utils import train_seq2seq +from qdax.utils.metrics import default_qd_metrics +from tests.core_test.map_elites_test import get_mixing_emitter + + +def get_observation_dims(observation_option: str, + env: brax.envs.Env, + max_observation_size: int, + episode_length: int, + traj_sampling_freq: int, + prior_descriptor_dim: int, + ) -> Tuple[int, int]: + obs_dim = jnp.minimum(env.observation_size, max_observation_size) + if observation_option == "full": + observations_dims = ( + episode_length // traj_sampling_freq, + obs_dim + prior_descriptor_dim, + ) + elif observation_option == "no_sd": + observations_dims = ( + episode_length // traj_sampling_freq, + obs_dim, + ) + elif observation_option == "only_sd": + observations_dims = (episode_length // traj_sampling_freq, prior_descriptor_dim) + else: + raise ValueError(f"Unknown observation option: {observation_option}") + + return observations_dims @pytest.mark.parametrize( @@ -66,12 +91,8 @@ def observation_extractor_fn( # add the x/y position - (batch_size, traj_length, 2) state_desc = data.state_desc[:, ::traj_sampling_freq] - print("State Observations: ", state_obs) - print("XY positions: ", state_desc) - if observation_option == "full": observations = jnp.concatenate([state_desc, state_obs], axis=-1) - print("New observations: ", observations) elif observation_option == "no_sd": observations = state_obs elif observation_option == "only_sd": @@ -88,65 +109,43 @@ def observation_extractor_fn( ) # Define emitter - variation_fn = functools.partial(isoline_variation, iso_sigma=0.05, line_sigma=0.1) - mixing_emitter = MixingEmitter( - mutation_fn=lambda x, y: (x, y), - variation_fn=variation_fn, - variation_percentage=1.0, - batch_size=batch_size, - ) + mixing_emitter = 
get_mixing_emitter(batch_size) # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] # Define a metrics function - def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict: - - # Get metrics - grid_empty = repertoire.fitnesses == -jnp.inf - qd_score = jnp.sum(repertoire.fitnesses, where=~grid_empty) - # Add offset for positive qd_score - qd_score += reward_offset * episode_length * jnp.sum(1.0 - grid_empty) - coverage = 100 * jnp.mean(1.0 - grid_empty) - max_fitness = jnp.max(repertoire.fitnesses) - - return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} + metrics_fn = functools.partial(default_qd_metrics, qd_offset=reward_offset) # Init algorithm # AutoEncoder Params and INIT - obs_dim = jnp.minimum(env.observation_size, max_observation_size) - if observation_option == "full": - observations_dims = ( - episode_length // traj_sampling_freq, - obs_dim + prior_descriptor_dim, - ) - elif observation_option == "no_sd": - observations_dims = ( - episode_length // traj_sampling_freq, - obs_dim, - ) - elif observation_option == "only_sd": - observations_dims = (episode_length // traj_sampling_freq, prior_descriptor_dim) - else: - raise ValueError(f"Unknown observation option: {observation_option}") + observations_dims = get_observation_dims(observation_option=observation_option, + env=env, + max_observation_size=max_observation_size, + episode_length=episode_length, + traj_sampling_freq=traj_sampling_freq, + prior_descriptor_dim=prior_descriptor_dim, + ) # define the seq2seq model model = train_seq2seq.get_model( observations_dims[-1], True, hidden_size=hidden_size ) + # define the encoder function encoder_fn = functools.partial( get_aurora_encoding, model=model, ) + # define the training function train_fn = functools.partial( train_seq2seq.lstm_ae_train, - hidden_size=hidden_size, + model=model, batch_size=lstm_batch_size, ) - # Instantiate AURORA + # Instantiate AURORA algorithm aurora = AURORA( scoring_function=aurora_scoring_fn, emitter=mixing_emitter, @@ -161,33 +160,29 @@ def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict: model, subkey, (1, *observations_dims) ) - print(jax.tree_map(lambda x: x.shape, model_params)) - # define arbitrary observation's mean/std mean_observations = jnp.zeros(observations_dims[-1]) std_observations = jnp.ones(observations_dims[-1]) - # init step of the aurora algorithm - repertoire, _, random_key = aurora.init( - init_variables, - random_key, + # init all the information needed by AURORA to compute encodings + aurora_extra_info = AuroraExtraInfoNormalization.create( model_params, mean_observations, std_observations, - jnp.array(l_value_init), + ) + + # init step of the aurora algorithm + repertoire, emitter_state, aurora_extra_info, random_key = aurora.init( + init_variables, + aurora_extra_info, + jnp.asarray(l_value_init), max_size, + random_key, ) # initializing means and stds and AURORA random_key, subkey = jax.random.split(random_key) - model_params, mean_observations, std_observations = train_seq2seq.lstm_ae_train( - subkey, - repertoire, - model_params, - 0, - hidden_size=hidden_size, - batch_size=lstm_batch_size, - ) + repertoire, aurora_extra_info = aurora.train(repertoire, model_params, iteration=0, random_key=subkey) # design aurora's schedule default_update_base = 10 @@ -196,23 +191,17 @@ def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict: current_step_estimation = 0 + ############################ # Main loop + ############################ + 
target_repertoire_size = 1024 previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - target_repertoire_size iteration = 0 - - emitter_state = None - while iteration < max_iterations: - collected_metrics = [] - aurora_extra_info = AuroraExtraInfoNormalization.create( - model_params=model_params, - mean_observations=mean_observations, - std_observations=std_observations, - ) - # update + # standard MAP-Elites-like loop for _ in range(log_freq): repertoire, emitter_state, metrics, random_key = aurora.update( repertoire, @@ -220,18 +209,18 @@ def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict: random_key, aurora_extra_info=aurora_extra_info, ) - collected_metrics.append(metrics) # update nb steps estimation current_step_estimation += batch_size * episode_length * log_freq - # autoencoder steps and CVC + # autoencoder steps and Container Size Control (CSC) if (iteration + 1) in schedules: - # train the autoencoder + # train the autoencoder (includes the CSC) random_key, subkey = jax.random.split(random_key) - repertoire = aurora.train(repertoire, model_params, iteration, subkey) + repertoire, aurora_extra_info = aurora.train(repertoire, model_params, iteration, subkey) elif iteration % 2 == 0: + # only CSC repertoire, previous_error = aurora.container_size_control(repertoire, target_size=target_repertoire_size, previous_error=previous_error) diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index 66748079..91ee889d 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -19,6 +19,19 @@ from qdax.core.neuroevolution.networks.networks import MLP from qdax.tasks.brax_envs import scoring_function_brax_envs from qdax.types import EnvState, Params, RNGKey +from qdax.utils.metrics import default_qd_metrics + + +def get_mixing_emitter(batch_size) -> MixingEmitter: + """Create a mixing emitter with a given batch size.""" + variation_fn = functools.partial(isoline_variation, iso_sigma=0.05, line_sigma=0.1) + mixing_emitter = MixingEmitter( + mutation_fn=lambda x, y: (x, y), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=batch_size, + ) + return mixing_emitter @pytest.mark.parametrize( @@ -102,29 +115,13 @@ def play_step_fn( ) # Define emitter - variation_fn = functools.partial(isoline_variation, iso_sigma=0.05, line_sigma=0.1) - mixing_emitter = MixingEmitter( - mutation_fn=lambda x, y: (x, y), - variation_fn=variation_fn, - variation_percentage=1.0, - batch_size=batch_size, - ) + mixing_emitter = get_mixing_emitter(batch_size) # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] # Define a metrics function - def metrics_fn(repertoire: MapElitesRepertoire) -> Dict: - - # Get metrics - grid_empty = repertoire.fitnesses == -jnp.inf - qd_score = jnp.sum(repertoire.fitnesses, where=~grid_empty) - # Add offset for positive qd_score - qd_score += reward_offset * episode_length * jnp.sum(1.0 - grid_empty) - coverage = 100 * jnp.mean(1.0 - grid_empty) - max_fitness = jnp.max(repertoire.fitnesses) - - return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage} + metrics_fn = functools.partial(default_qd_metrics, qd_offset=reward_offset) # Instantiate MAP-Elites map_elites = MAPElites( From 9873f481d164bac4dd837a6127cae68683e4cb5e Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Tue, 25 Apr 2023 18:32:22 +0100 Subject: [PATCH 18/26] fix style --- qdax/core/aurora.py | 64 +++++++++---------- 
.../containers/unstructured_repertoire.py | 8 +-- qdax/environments/bd_extractors.py | 25 ++++---- qdax/tasks/brax_envs.py | 24 ++++--- qdax/utils/train_seq2seq.py | 2 +- tests/core_test/aurora_test.py | 58 ++++++++++------- tests/core_test/map_elites_test.py | 9 +-- 7 files changed, 104 insertions(+), 86 deletions(-) diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index 551a1966..3d90144e 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -5,7 +5,6 @@ from functools import partial from typing import Callable, Optional, Tuple -import flax.struct import jax import jax.numpy as jnp from chex import ArrayTree @@ -13,12 +12,16 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.environments.bd_extractors import AuroraExtraInfo -from qdax.types import Descriptor, Fitness, Genotype, Metrics, Params, RNGKey, Observation - - - +from qdax.types import ( + Descriptor, + Fitness, + Genotype, + Metrics, + Observation, + Params, + RNGKey, +) class AURORA: @@ -43,9 +46,7 @@ def __init__( ], emitter: Emitter, metrics_function: Callable[[MapElitesRepertoire], Metrics], - encoder_function: Callable[ - [Observation, AuroraExtraInfo], Descriptor - ], + encoder_function: Callable[[Observation, AuroraExtraInfo], Descriptor], training_function: Callable[ [RNGKey, UnstructuredRepertoire, Params, int], AuroraExtraInfo ], @@ -57,11 +58,11 @@ def __init__( self._train_fn = training_function def train( - self, - repertoire: UnstructuredRepertoire, - model_params: Params, - iteration: int, - random_key: RNGKey, + self, + repertoire: UnstructuredRepertoire, + model_params: Params, + iteration: int, + random_key: RNGKey, ) -> Tuple[UnstructuredRepertoire, AuroraExtraInfo]: random_key, subkey = jax.random.split(random_key) aurora_extra_info = self._train_fn( @@ -83,17 +84,16 @@ def train( l_value=repertoire.l_value, max_size=repertoire.max_size, ), - aurora_extra_info + aurora_extra_info, ) - @partial(jax.jit, static_argnames=("self",)) def container_size_control( - self, - repertoire: UnstructuredRepertoire, - target_size: int, - previous_error: jnp.ndarray, - ): + self, + repertoire: UnstructuredRepertoire, + target_size: int, + previous_error: jnp.ndarray, + ) -> Tuple[UnstructuredRepertoire, jnp.ndarray]: # update the l value num_indivs = jnp.sum(repertoire.fitnesses != -jnp.inf) @@ -102,9 +102,7 @@ def container_size_control( change_rate = current_error - previous_error prop_gain = 1 * 10e-6 l_value = ( - repertoire.l_value - + (prop_gain * current_error) - + (prop_gain * change_rate) + repertoire.l_value + (prop_gain * current_error) + (prop_gain * change_rate) ) repertoire = repertoire.init( @@ -149,8 +147,7 @@ def init( observations = extra_scores["last_valid_observations"] - descriptors = self._encoder_fn(observations, - aurora_extra_info) + descriptors = self._encoder_fn(observations, aurora_extra_info) repertoire = UnstructuredRepertoire.init( genotypes=init_genotypes, @@ -176,10 +173,9 @@ def init( ) random_key, subkey = jax.random.split(random_key) - repertoire, updated_aurora_extra_info = self.train(repertoire, - aurora_extra_info.model_params, - iteration=0, - random_key=subkey) + repertoire, updated_aurora_extra_info = self.train( + repertoire, aurora_extra_info.model_params, iteration=0, random_key=subkey + ) return repertoire, emitter_state, 
updated_aurora_extra_info, random_key @@ -223,12 +219,14 @@ def update( observations = extra_scores["last_valid_observations"] - descriptors = self._encoder_fn(observations, - aurora_extra_info) + descriptors = self._encoder_fn(observations, aurora_extra_info) # add genotypes and observations in the repertoire repertoire = repertoire.add( - genotypes, descriptors, fitnesses, observations, + genotypes, + descriptors, + fitnesses, + observations, ) # update emitter state after scoring is made diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index 59b64ce4..f4cc0c98 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -346,7 +346,9 @@ def add( # assign fake position when relevant : num_centroids is out of bounds batch_of_indices = jnp.where( - addition_condition, x=batch_of_indices, y=self.max_size, + addition_condition, + x=batch_of_indices, + y=self.max_size, ) # create new grid @@ -434,9 +436,7 @@ def init( # Initialize grid with default values default_fitnesses = -jnp.inf * jnp.ones(shape=max_size) default_genotypes = jax.tree_map( - lambda x: jnp.full( - shape=(max_size,) + x.shape[1:], fill_value=jnp.nan - ), + lambda x: jnp.full(shape=(max_size,) + x.shape[1:], fill_value=jnp.nan), genotypes, ) default_descriptors = jnp.zeros(shape=(max_size, descriptors.shape[-1])) diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index 7a77eea8..77cacf27 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import flax.struct import jax import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.types import Descriptor, Params -from qdax.utils import train_seq2seq def get_final_xy_position(data: QDTransition, mask: jnp.ndarray) -> Descriptor: @@ -40,24 +41,26 @@ def get_feet_contact_proportion(data: QDTransition, mask: jnp.ndarray) -> Descri return descriptors - class AuroraExtraInfo(flax.struct.PyTreeNode): model_params: Params + class AuroraExtraInfoNormalization(AuroraExtraInfo): mean_observations: jnp.ndarray std_observations: jnp.ndarray @classmethod - def create(cls, - model_params: Params, - mean_observations: jnp.ndarray, - std_observations: jnp.ndarray, - ): - return cls(model_params=model_params, - mean_observations=mean_observations, - std_observations=std_observations, - ) + def create( + cls, + model_params: Params, + mean_observations: jnp.ndarray, + std_observations: jnp.ndarray, + ) -> AuroraExtraInfoNormalization: + return cls( + model_params=model_params, + mean_observations=mean_observations, + std_observations=std_observations, + ) def get_aurora_encoding( diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index de10fc47..ae74c716 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -1,6 +1,6 @@ import functools from functools import partial -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Optional, Tuple import brax.envs import flax.linen as nn @@ -84,7 +84,7 @@ def default_play_step_fn( def get_mask_from_transitions( - data: Transition, + data: Transition, ) -> jnp.ndarray: is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) mask = jnp.roll(is_done, 1, axis=1) @@ -275,8 +275,8 @@ def create_brax_scoring_fn( init_state = env.reset(subkey) # Define the function to deterministically reset the environment - def 
deterministic_reset(key: RNGKey, init_state: EnvState) -> EnvState: - return init_state + def deterministic_reset(_: RNGKey, _init_state: EnvState) -> EnvState: + return _init_state play_reset_fn = partial(deterministic_reset, init_state=init_state) @@ -351,9 +351,13 @@ def create_default_brax_task_components( def get_aurora_scoring_fn( - scoring_fn: Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]], + scoring_fn: Callable[ + [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] + ], observation_extractor_fn: Callable[[Transition], Observation], -) -> Callable[[Genotype, RNGKey], Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey]]: +) -> Callable[ + [Genotype, RNGKey], Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey] +]: """Evaluates policies contained in flatten_variables in parallel This rollout is only deterministic when all the init states are the same. @@ -362,16 +366,18 @@ def get_aurora_scoring_fn( When the init states are different, this is not purely stochastic. This choice was made for performance reason, as the reset function of brax envs - is quite time consuming. If pure stochasticity of the environment is needed + is quite time-consuming. If pure stochasticity of the environment is needed for a use case, please open an issue. """ @functools.wraps(scoring_fn) - def _wrapper(params: Params, # Perform rollouts with each policy - random_key: RNGKey): + def _wrapper( + params: Params, random_key: RNGKey # Perform rollouts with each policy + ) -> Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey]: fitnesses, _, extra_scores, random_key = scoring_fn(params, random_key) data = extra_scores["data"] observation = observation_extractor_fn(data) # type: ignore extra_scores["last_valid_observations"] = observation return fitnesses, None, extra_scores, random_key + return _wrapper diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index 37c4834e..306b8a2e 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -116,7 +116,7 @@ def lstm_ae_train( repertoire: UnstructuredRepertoire, params: Params, epoch: int, - model, + model: Seq2seq, batch_size: int = 128, ) -> AuroraExtraInfoNormalization: diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 00239527..aef7c786 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -11,21 +11,28 @@ from qdax import environments from qdax.core.aurora import AURORA from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.environments.bd_extractors import get_aurora_encoding, AuroraExtraInfoNormalization -from qdax.tasks.brax_envs import get_aurora_scoring_fn, create_default_brax_task_components +from qdax.environments.bd_extractors import ( + AuroraExtraInfoNormalization, + get_aurora_encoding, +) +from qdax.tasks.brax_envs import ( + create_default_brax_task_components, + get_aurora_scoring_fn, +) from qdax.types import Observation from qdax.utils import train_seq2seq from qdax.utils.metrics import default_qd_metrics from tests.core_test.map_elites_test import get_mixing_emitter -def get_observation_dims(observation_option: str, - env: brax.envs.Env, - max_observation_size: int, - episode_length: int, - traj_sampling_freq: int, - prior_descriptor_dim: int, - ) -> Tuple[int, int]: +def get_observation_dims( + observation_option: str, + env: brax.envs.Env, + max_observation_size: int, + episode_length: int, + traj_sampling_freq: int, + prior_descriptor_dim: int, +) -> 
Tuple[int, int]: obs_dim = jnp.minimum(env.observation_size, max_observation_size) if observation_option == "full": observations_dims = ( @@ -83,7 +90,7 @@ def test_aurora(env_name: str, batch_size: int) -> None: init_variables = jax.vmap(policy_network.init)(keys, fake_batch) def observation_extractor_fn( - data: QDTransition, + data: QDTransition, ) -> Observation: """Extract observation from the state.""" state_obs = data.obs[:, ::traj_sampling_freq, :max_observation_size] @@ -119,13 +126,14 @@ def observation_extractor_fn( # Init algorithm # AutoEncoder Params and INIT - observations_dims = get_observation_dims(observation_option=observation_option, - env=env, - max_observation_size=max_observation_size, - episode_length=episode_length, - traj_sampling_freq=traj_sampling_freq, - prior_descriptor_dim=prior_descriptor_dim, - ) + observations_dims = get_observation_dims( + observation_option=observation_option, + env=env, + max_observation_size=max_observation_size, + episode_length=episode_length, + traj_sampling_freq=traj_sampling_freq, + prior_descriptor_dim=prior_descriptor_dim, + ) # define the seq2seq model model = train_seq2seq.get_model( @@ -182,7 +190,9 @@ def observation_extractor_fn( # initializing means and stds and AURORA random_key, subkey = jax.random.split(random_key) - repertoire, aurora_extra_info = aurora.train(repertoire, model_params, iteration=0, random_key=subkey) + repertoire, aurora_extra_info = aurora.train( + repertoire, model_params, iteration=0, random_key=subkey + ) # design aurora's schedule default_update_base = 10 @@ -217,13 +227,17 @@ def observation_extractor_fn( if (iteration + 1) in schedules: # train the autoencoder (includes the CSC) random_key, subkey = jax.random.split(random_key) - repertoire, aurora_extra_info = aurora.train(repertoire, model_params, iteration, subkey) + repertoire, aurora_extra_info = aurora.train( + repertoire, model_params, iteration, subkey + ) elif iteration % 2 == 0: # only CSC - repertoire, previous_error = aurora.container_size_control(repertoire, - target_size=target_repertoire_size, - previous_error=previous_error) + repertoire, previous_error = aurora.container_size_control( + repertoire, + target_size=target_repertoire_size, + previous_error=previous_error, + ) iteration += 1 diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index 91ee889d..b532aa65 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -1,17 +1,14 @@ """Tests MAP Elites implementation""" import functools -from typing import Dict, Tuple +from typing import Tuple import jax import jax.numpy as jnp import pytest from qdax import environments -from qdax.core.containers.mapelites_repertoire import ( - MapElitesRepertoire, - compute_cvt_centroids, -) +from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids from qdax.core.emitters.mutation_operators import isoline_variation from qdax.core.emitters.standard_emitters import MixingEmitter from qdax.core.map_elites import MAPElites @@ -22,7 +19,7 @@ from qdax.utils.metrics import default_qd_metrics -def get_mixing_emitter(batch_size) -> MixingEmitter: +def get_mixing_emitter(batch_size: int) -> MixingEmitter: """Create a mixing emitter with a given batch size.""" variation_fn = functools.partial(isoline_variation, iso_sigma=0.05, line_sigma=0.1) mixing_emitter = MixingEmitter( From 3626a18600f2af15f272d9a52d58464e56806a31 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Tue, 25 Apr 2023 17:55:36 +0000 Subject: [PATCH 
19/26] make it run --- qdax/environments/bd_extractors.py | 2 +- qdax/tasks/brax_envs.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index 77cacf27..d19fb953 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -65,8 +65,8 @@ def create( def get_aurora_encoding( observations: jnp.ndarray, - model: flax.linen.Module, aurora_extra_info: AuroraExtraInfoNormalization, + model: flax.linen.Module, ) -> Descriptor: """ Compute final aurora embedding. diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index ae74c716..931ee9d3 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -278,7 +278,7 @@ def create_brax_scoring_fn( def deterministic_reset(_: RNGKey, _init_state: EnvState) -> EnvState: return _init_state - play_reset_fn = partial(deterministic_reset, init_state=init_state) + play_reset_fn = partial(deterministic_reset, _init_state=init_state) # Stochastic case elif play_reset_fn is None: @@ -375,7 +375,7 @@ def _wrapper( params: Params, random_key: RNGKey # Perform rollouts with each policy ) -> Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey]: fitnesses, _, extra_scores, random_key = scoring_fn(params, random_key) - data = extra_scores["data"] + data = extra_scores["transitions"] observation = observation_extractor_fn(data) # type: ignore extra_scores["last_valid_observations"] = observation return fitnesses, None, extra_scores, random_key From 2606c864c69805490481edbe289da82a80a26cad Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Tue, 25 Apr 2023 18:20:12 +0000 Subject: [PATCH 20/26] make it work --- qdax/environments/bd_extractors.py | 4 ++-- qdax/utils/train_seq2seq.py | 7 +++---- tests/core_test/aurora_test.py | 8 ++++---- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index d19fb953..8a1830f3 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -1,5 +1,7 @@ from __future__ import annotations +import functools + import flax.struct import jax import jax.numpy as jnp @@ -84,6 +86,4 @@ def get_aurora_encoding( {"params": model_params}, normalized_observations, method=model.encode ) - print("Observations out of get aurora bd: ", observations) - return descriptors.squeeze() diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index 306b8a2e..acb14a9b 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -130,6 +130,9 @@ def lstm_ae_train( # compute mean/std of the obs for normalization mean_obs = jnp.nanmean(repertoire.observations, axis=(0, 1)) std_obs = jnp.nanstd(repertoire.observations, axis=(0, 1)) + # the std where they were NaNs was set to zero. But here we divide by the + # std, so we replace the zeros by inf here. + std_obs = jnp.where(std_obs == 0, x=jnp.inf, y=std_obs) # TODO: maybe we could just compute this data on the valid dataset @@ -177,10 +180,6 @@ def lstm_ae_train( random_key, shuffle_key = jax.random.split(random_key, 2) valid_indexes = jax.random.permutation(shuffle_key, valid_indexes, axis=0) - # the std where they were NaNs was set to zero. But here we divide by the - # std, so we replace the zeros by inf here. 
- std_obs = jnp.where(std_obs == 0, x=jnp.inf, y=std_obs) - # create dataset with the observation from the sample of valid indexes training_dataset = ( repertoire.observations.at[valid_indexes, ...].get() - mean_obs diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index aef7c786..916caa89 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -137,14 +137,14 @@ def observation_extractor_fn( # define the seq2seq model model = train_seq2seq.get_model( - observations_dims[-1], True, hidden_size=hidden_size + int(observations_dims[-1]), True, hidden_size=hidden_size ) # define the encoder function - encoder_fn = functools.partial( + encoder_fn = jax.jit(functools.partial( get_aurora_encoding, model=model, - ) + )) # define the training function train_fn = functools.partial( @@ -213,7 +213,7 @@ def observation_extractor_fn( while iteration < max_iterations: # standard MAP-Elites-like loop for _ in range(log_freq): - repertoire, emitter_state, metrics, random_key = aurora.update( + repertoire, emitter_state, _, random_key = aurora.update( repertoire, emitter_state, random_key, From 1b06020990b3c3e9b6a66f283c6160315385de75 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Tue, 25 Apr 2023 19:44:21 +0100 Subject: [PATCH 21/26] passes style and tests --- environment.yaml | 1 + qdax/environments/bd_extractors.py | 2 -- tests/core_test/aurora_test.py | 10 ++++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/environment.yaml b/environment.yaml index 78058b9d..3da7ff9e 100644 --- a/environment.yaml +++ b/environment.yaml @@ -9,5 +9,6 @@ dependencies: - pip: - --find-links https://storage.googleapis.com/jax-releases/jax_releases.html - jaxlib==0.3.15 + - optax==0.1.4 - -r requirements.txt - -r requirements-dev.txt diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index 8a1830f3..a6a495a0 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -1,7 +1,5 @@ from __future__ import annotations -import functools - import flax.struct import jax import jax.numpy as jnp diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 916caa89..2b238237 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -141,10 +141,12 @@ def observation_extractor_fn( ) # define the encoder function - encoder_fn = jax.jit(functools.partial( - get_aurora_encoding, - model=model, - )) + encoder_fn = jax.jit( + functools.partial( + get_aurora_encoding, + model=model, + ) + ) # define the training function train_fn = functools.partial( From fc2e3c0ddf9858f22b28f1a9e69ae3ee807b883e Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 26 Apr 2023 14:45:41 +0100 Subject: [PATCH 22/26] using right version of brax --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 13bef5dc..b97297fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ absl-py==1.0.0 -brax==0.0.12 +brax==0.0.15 chex==0.1.5 dm-haiku==0.0.5 flax==0.6.0 diff --git a/setup.py b/setup.py index c180ca3d..2e50e0ea 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ "jinja2<3.1.0", "jumanji>=0.1.3", "flax>=0.6, <0.6.2", - "brax>=0.0.12", + "brax>=0.0.15", "gym>=0.23.1", "numpy>=1.22.3", "scikit-learn>=1.0.2", From 999b2733c24a5cf8b4891402d920540c4535580d Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 26 Apr 2023 16:23:48 +0100 Subject: [PATCH 23/26] 
update dependencies and change order of docs dependencies installations --- .readthedocs.yaml | 5 +++-- environment.yaml | 1 - requirements.txt | 1 + setup.py | 1 + 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 2ef47062..82349aa1 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -18,7 +18,8 @@ mkdocs: # Optionally declare the Python requirements required to build your docs python: install: - - requirements: requirements.txt - - requirements: docs/requirements.txt - method: pip path: . + - requirements: requirements.txt + - requirements: docs/requirements.txt + diff --git a/environment.yaml b/environment.yaml index 3da7ff9e..78058b9d 100644 --- a/environment.yaml +++ b/environment.yaml @@ -9,6 +9,5 @@ dependencies: - pip: - --find-links https://storage.googleapis.com/jax-releases/jax_releases.html - jaxlib==0.3.15 - - optax==0.1.4 - -r requirements.txt - -r requirements-dev.txt diff --git a/requirements.txt b/requirements.txt index b97297fa..c1eb2d3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ jax==0.3.17 jumanji==0.1.3 jupyter numpy==1.22.3 +optax==0.1.4 protobuf==3.19.4 scikit-learn==1.0.2 scipy==1.8.0 diff --git a/setup.py b/setup.py index 2e50e0ea..a71f3174 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "brax>=0.0.15", "gym>=0.23.1", "numpy>=1.22.3", + "optax>=0.1, <0.1.5", "scikit-learn>=1.0.2", "scipy>=1.8.0", ], From a1bcd30cac23167f1311e91ff11beb0dd6bd7760 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Thu, 27 Apr 2023 11:26:47 +0100 Subject: [PATCH 24/26] adding jaxlib to requirements --- .readthedocs.yaml | 1 - environment.yaml | 1 - requirements.txt | 1 + 3 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 82349aa1..7eec359d 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -22,4 +22,3 @@ python: path: . 
- requirements: requirements.txt - requirements: docs/requirements.txt - diff --git a/environment.yaml b/environment.yaml index 78058b9d..e46c034e 100644 --- a/environment.yaml +++ b/environment.yaml @@ -8,6 +8,5 @@ dependencies: - conda>=4.9.2 - pip: - --find-links https://storage.googleapis.com/jax-releases/jax_releases.html - - jaxlib==0.3.15 - -r requirements.txt - -r requirements-dev.txt diff --git a/requirements.txt b/requirements.txt index c1eb2d3e..16c91bc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ flax==0.6.0 gym==0.23.1 ipython jax==0.3.17 +jaxlib==0.3.15 jumanji==0.1.3 jupyter numpy==1.22.3 From a0307a3b7f29c5215231431ca888a2292899605f Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Mon, 11 Dec 2023 19:01:27 +0900 Subject: [PATCH 25/26] adapt LSTMCell usage to new RNNCellBase Upgrade --- .../core/neuroevolution/networks/seq2seq_networks.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/qdax/core/neuroevolution/networks/seq2seq_networks.py b/qdax/core/neuroevolution/networks/seq2seq_networks.py index 83070c2d..ea7618ba 100644 --- a/qdax/core/neuroevolution/networks/seq2seq_networks.py +++ b/qdax/core/neuroevolution/networks/seq2seq_networks.py @@ -36,7 +36,8 @@ def __call__( ) -> Tuple[Tuple[Array, Array], Array]: """Applies the module.""" lstm_state, is_eos = carry - new_lstm_state, y = nn.LSTMCell()(lstm_state, x) + features = lstm_state[0].shape[-1] + new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x) def select_carried_state(new_state: Array, old_state: Array) -> Array: return jnp.where(is_eos[:, np.newaxis], old_state, new_state) @@ -51,8 +52,8 @@ def select_carried_state(new_state: Array, old_state: Array) -> Array: @staticmethod def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]: # Use a dummy key since the default state init fn is just zeros. 
- return nn.LSTMCell.initialize_carry( # type: ignore - jax.random.PRNGKey(0), (batch_size,), hidden_size + return nn.LSTMCell(hidden_size, parent=None).initialize_carry( # type: ignore + jax.random.PRNGKey(0), (batch_size, hidden_size) ) @@ -101,7 +102,10 @@ def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array: lstm_state, last_prediction = carry if not self.teacher_force: x = last_prediction - lstm_state, y = nn.LSTMCell()(lstm_state, x) + + features = lstm_state[0].shape[-1] + new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x) + logits = nn.Dense(features=self.obs_size)(y) return (lstm_state, logits), (logits, logits) From 6c5801e6ddb774d5e5c62b33587f7e478be2c8e2 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Mon, 11 Dec 2023 21:02:05 +0900 Subject: [PATCH 26/26] add missing docstrings --- qdax/core/aurora.py | 2 +- qdax/environments/bd_extractors.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index 3d90144e..fed716e3 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -199,7 +199,7 @@ def update( repertoire: unstructured repertoire emitter_state: state of the emitter random_key: a jax PRNG random key - aurora_extra_info: extra info for the encoding # TODO + aurora_extra_info: extra info for computing encodings Results: the updated MAP-Elites repertoire diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index a6a495a0..af1d51ba 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -42,10 +42,27 @@ def get_feet_contact_proportion(data: QDTransition, mask: jnp.ndarray) -> Descri class AuroraExtraInfo(flax.struct.PyTreeNode): + """ + Information specific to the AURORA algorithm. + + Args: + model_params: the parameters of the dimensionality reduction model + """ + model_params: Params class AuroraExtraInfoNormalization(AuroraExtraInfo): + """ + Information specific to the AURORA algorithm. In particular, it contains + the normalization parameters for the observations. + + Args: + model_params: the parameters of the dimensionality reduction model + mean_observations: the mean of observations + std_observations: the std of observations + """ + mean_observations: jnp.ndarray std_observations: jnp.ndarray
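
Patch 25 above ports the seq2seq LSTM cells to the newer Flax RNNCellBase interface: the cell is now constructed with its feature count, and initialize_carry receives the full (batch, features) input shape instead of a separate batch tuple plus hidden size. The snippet below is a minimal, self-contained sketch of that calling convention; it assumes a Flax release that includes the refactored RNNCellBase (the one this patch targets) and is illustrative rather than code from the repository.

import jax
import jax.numpy as jnp
import flax.linen as nn

batch_size, hidden_size = 4, 10

# New-style construction: the feature count is given to the cell up front.
cell = nn.LSTMCell(features=hidden_size)

# New-style carry initialization: pass the (batch, features) input shape.
carry = cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, hidden_size))

# One LSTM step on a dummy input; the carry is a (c, h) pair of arrays,
# each of shape (batch_size, hidden_size).
x = jnp.ones((batch_size, hidden_size))
variables = cell.init(jax.random.PRNGKey(1), carry, x)
new_carry, y = cell.apply(variables, carry, x)

print(jax.tree_util.tree_map(jnp.shape, new_carry))  # ((4, 10), (4, 10))
print(y.shape)                                       # (4, 10)

Inside the networks themselves, the patch recovers the feature count from the carry (features = lstm_state[0].shape[-1]), so the per-step cells keep working without threading hidden_size through every call.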