diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index 434725a3..18d4f0f3 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -348,7 +348,7 @@ "repertoire, emitter_state, random_key = map_elites.get_distributed_init_fn(\n", " centroids=centroids,\n", " devices=devices,\n", - ")(init_genotypes=init_variables, random_key=random_key)" + ")(genotypes=init_variables, random_key=random_key)" ] }, { diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index 6d4dfdfe..6b4ae0b5 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -311,7 +311,7 @@ "# initialize map-elites\n", "repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n", " devices=devices, centroids=centroids\n", - ")(init_genotypes=training_states, random_key=keys)" + ")(genotypes=training_states, random_key=keys)" ] }, { diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index 1cf17c5e..ca127e72 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -314,7 +314,7 @@ "# initialize map-elites\n", "repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n", " devices=devices, centroids=centroids\n", - ")(init_genotypes=training_states, random_key=keys)" + ")(genotypes=training_states, random_key=keys)" ] }, { diff --git a/examples/mome.ipynb b/examples/mome.ipynb index a4ca36a6..05387158 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -212,7 +212,7 @@ "# initial population\n", "random_key = jax.random.PRNGKey(42)\n", "random_key, subkey = jax.random.split(random_key)\n", - "init_genotypes = jax.random.uniform(\n", + "genotypes = jax.random.uniform(\n", " random_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n", ")\n", "\n", @@ -303,7 +303,7 @@ "outputs": [], "source": [ "repertoire, emitter_state, random_key = mome.init(\n", - " init_genotypes,\n", + " genotypes,\n", " centroids,\n", " pareto_front_max_length,\n", " random_key\n", diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index 51c5f5bd..e10c0d91 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -189,7 +189,7 @@ "# Initial population\n", "random_key = jax.random.PRNGKey(0)\n", "random_key, subkey = jax.random.split(random_key)\n", - "init_genotypes = jax.random.uniform(\n", + "genotypes = jax.random.uniform(\n", " subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n", ")\n", "\n", @@ -238,7 +238,7 @@ "\n", "# init nsga2\n", "repertoire, emitter_state, random_key = nsga2.init(\n", - " init_genotypes,\n", + " genotypes,\n", " population_size,\n", " random_key\n", ")" @@ -303,7 +303,7 @@ "\n", "# init spea2\n", "repertoire, emitter_state, random_key = spea2.init(\n", - " init_genotypes,\n", + " genotypes,\n", " population_size,\n", " num_neighbours,\n", " random_key\n", diff --git a/qdax/baselines/genetic_algorithm.py b/qdax/baselines/genetic_algorithm.py index 0714fb6c..a01a13b1 100644 --- a/qdax/baselines/genetic_algorithm.py +++ b/qdax/baselines/genetic_algorithm.py @@ -39,12 +39,12 @@ def __init__( @partial(jax.jit, static_argnames=("self", "population_size")) def init( - self, init_genotypes: Genotype, population_size: int, random_key: RNGKey + self, genotypes: Genotype, population_size: int, random_key: RNGKey ) -> Tuple[GARepertoire, Optional[EmitterState], RNGKey]: """Initialize a GARepertoire with an initial population of genotypes. 
Args: - init_genotypes: the initial population of genotypes + genotypes: the initial population of genotypes population_size: the maximal size of the repertoire random_key: a random key to handle stochastic operations @@ -54,26 +54,21 @@ def init( # score initial genotypes fitnesses, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = GARepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, population_size=population_size, ) # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key=random_key, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=None, extra_scores=extra_scores, @@ -108,7 +103,7 @@ def update( """ # generate offsprings - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) @@ -127,7 +122,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=None, - extra_scores=extra_scores, + extra_scores={**extra_scores, **extra_info}, ) # update the metrics diff --git a/qdax/baselines/nsga2.py b/qdax/baselines/nsga2.py index 663d6f0e..a889eadc 100644 --- a/qdax/baselines/nsga2.py +++ b/qdax/baselines/nsga2.py @@ -28,31 +28,36 @@ class NSGA2(GeneticAlgorithm): @partial(jax.jit, static_argnames=("self", "population_size")) def init( - self, init_genotypes: Genotype, population_size: int, random_key: RNGKey + self, genotypes: Genotype, population_size: int, random_key: RNGKey ) -> Tuple[NSGA2Repertoire, Optional[EmitterState], RNGKey]: # score initial genotypes fitnesses, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = NSGA2Repertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, population_size=population_size, ) # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=None, + extra_scores=extra_scores, ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, extra_scores=extra_scores, ) diff --git a/qdax/baselines/spea2.py b/qdax/baselines/spea2.py index 72ec2791..c52063b6 100644 --- a/qdax/baselines/spea2.py +++ b/qdax/baselines/spea2.py @@ -40,7 +40,7 @@ class SPEA2(GeneticAlgorithm): ) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, population_size: int, num_neighbours: int, random_key: RNGKey, @@ -48,12 +48,12 @@ def init( # score initial genotypes fitnesses, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = SPEA2Repertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, population_size=population_size, num_neighbours=num_neighbours, @@ -61,14 +61,19 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + random_key=random_key, + 
repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=None, + extra_scores=extra_scores, ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, extra_scores=extra_scores, ) diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index fed716e3..a0968ccc 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -118,7 +118,7 @@ def container_size_control( def init( self, - init_genotypes: Genotype, + genotypes: Genotype, aurora_extra_info: AuroraExtraInfo, l_value: jnp.ndarray, max_size: int, @@ -128,7 +128,7 @@ def init( genotypes. Also performs the first training of the AURORA encoder. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) aurora_extra_info: information to perform AURORA encodings, such as the encoder parameters @@ -141,7 +141,7 @@ def init( the emitter, and the updated information to perform AURORA encodings """ fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, + genotypes, random_key, ) @@ -150,7 +150,7 @@ def init( descriptors = self._encoder_fn(observations, aurora_extra_info) repertoire = UnstructuredRepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, observations=observations, @@ -160,13 +160,9 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, - genotypes=init_genotypes, + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, @@ -208,9 +204,10 @@ def update( a new key """ # generate offsprings with the emitter - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) + # scores the offsprings fitnesses, descriptors, extra_scores, random_key = self._scoring_function( genotypes, @@ -232,10 +229,11 @@ def update( # update emitter state after scoring is made emitter_state = self._emitter.state_update( emitter_state=emitter_state, + repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores, + extra_scores=extra_scores | extra_info, ) # update the metrics diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index aed74c78..b1145c34 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -228,6 +228,38 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey return samples, random_key + @partial(jax.jit, static_argnames=("num_samples",)) + def sample_with_descs( + self, + random_key: RNGKey, + num_samples: int, + ) -> Tuple[Genotype, Descriptor, RNGKey]: + """Sample elements in the repertoire. 
+ + Args: + random_key: a jax PRNG random key + num_samples: the number of elements to be sampled + + Returns: + samples: a batch of genotypes sampled in the repertoire + random_key: an updated jax PRNG random key + """ + + repertoire_empty = self.fitnesses == -jnp.inf + p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) + + random_key, subkey = jax.random.split(random_key) + samples = jax.tree_util.tree_map( + lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + self.genotypes, + ) + descs = jax.tree_util.tree_map( + lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + self.descriptors, + ) + + return samples, descs, random_key + @jax.jit def add( self, diff --git a/qdax/core/distributed_map_elites.py b/qdax/core/distributed_map_elites.py index c8a1ea44..7b5609f2 100644 --- a/qdax/core/distributed_map_elites.py +++ b/qdax/core/distributed_map_elites.py @@ -17,7 +17,7 @@ class DistributedMAPElites(MAPElites): @partial(jax.jit, static_argnames=("self",)) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, centroids: Centroid, random_key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: @@ -30,7 +30,7 @@ def init( devices. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tessellation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. @@ -41,7 +41,7 @@ def init( """ # score initial genotypes fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # gather across all devices @@ -51,7 +51,7 @@ def init( gathered_descriptors, ) = jax.tree_util.tree_map( lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0), - (init_genotypes, fitnesses, descriptors), + (genotypes, fitnesses, descriptors), ) # init the repertoire @@ -64,14 +64,19 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, @@ -108,7 +113,7 @@ def update( a new jax PRNG key """ # generate offsprings with the emitter - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) # scores the offsprings @@ -138,7 +143,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores, + extra_scores={**extra_scores, **extra_info}, ) # update the metrics diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py index f9d58caa..66e5677a 100644 --- a/qdax/core/emitters/cma_emitter.py +++ b/qdax/core/emitters/cma_emitter.py @@ -99,14 +99,20 @@ def batch_size(self) -> int: @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> 
Tuple[CMAEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. Returns: @@ -135,7 +141,7 @@ def emit( repertoire: Optional[MapElitesRepertoire], emitter_state: CMAEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emits new individuals. Interestingly, this method does not directly modifies individuals from the repertoire but sample from a distribution. Hence the @@ -154,7 +160,7 @@ def emit( cmaes_state=emitter_state.cmaes_state, random_key=random_key ) - return offsprings, random_key + return offsprings, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py index f63654fd..1fd0e1e6 100644 --- a/qdax/core/emitters/cma_mega_emitter.py +++ b/qdax/core/emitters/cma_mega_emitter.py @@ -100,14 +100,20 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMAMEGAState, RNGKey]: """ Initializes the CMA-MEGA emitter. Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. Returns: @@ -117,7 +123,7 @@ def init( # define init theta as 0 theta = jax.tree_util.tree_map( lambda x: jnp.zeros_like(x[:1, ...]), - init_genotypes, + genotypes, ) # score it @@ -147,7 +153,7 @@ def emit( repertoire: Optional[MapElitesRepertoire], emitter_state: CMAMEGAState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emits new individuals. Interestingly, this method does not directly modifies individuals from the repertoire but sample from a distribution. Hence the @@ -181,7 +187,7 @@ def emit( # Compute new candidates new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad) - return new_thetas, random_key + return new_thetas, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py index d5424a01..24556f8b 100644 --- a/qdax/core/emitters/cma_pool_emitter.py +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -49,14 +49,20 @@ def batch_size(self) -> int: @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMAPoolEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. 
Returns: @@ -67,7 +73,14 @@ def scan_emitter_init( carry: RNGKey, unused: Any ) -> Tuple[RNGKey, CMAEmitterState]: random_key = carry - emitter_state, random_key = self._emitter.init(init_genotypes, random_key) + emitter_state, random_key = self._emitter.init( + random_key, + repertoire, + genotypes, + fitnesses, + descriptors, + extra_scores, + ) return random_key, emitter_state # init all the emitter states @@ -91,7 +104,7 @@ def emit( repertoire: Optional[MapElitesRepertoire], emitter_state: CMAPoolEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emits new individuals. @@ -111,11 +124,11 @@ def emit( ) # use it to emit offsprings - offsprings, random_key = self._emitter.emit( + offsprings, extra_info, random_key = self._emitter.emit( repertoire, used_emitter_state, random_key ) - return offsprings, random_key + return offsprings, extra_info, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/cma_rnd_emitter.py b/qdax/core/emitters/cma_rnd_emitter.py index 4afb2f5d..e05cc453 100644 --- a/qdax/core/emitters/cma_rnd_emitter.py +++ b/qdax/core/emitters/cma_rnd_emitter.py @@ -35,14 +35,20 @@ class CMARndEmitterState(CMAEmitterState): class CMARndEmitter(CMAEmitter): @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMARndEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. Returns: diff --git a/qdax/core/emitters/dcg_me_emitter.py b/qdax/core/emitters/dcg_me_emitter.py new file mode 100644 index 00000000..94e0bb9d --- /dev/null +++ b/qdax/core/emitters/dcg_me_emitter.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass +from typing import Callable, Tuple + +import flax.linen as nn + +from qdax.core.emitters.multi_emitter import MultiEmitter +from qdax.core.emitters.qdcg_emitter import QualityDCGConfig, QualityDCGEmitter +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.environments.base_wrappers import QDEnv +from qdax.types import Params, RNGKey + + +@dataclass +class DCGMEConfig: + """Configuration for DCGME Algorithm""" + + ga_batch_size: int = 128 + qpg_batch_size: int = 64 + ai_batch_size: int = 64 + lengthscale: float = 0.1 + + # PG emitter + critic_hidden_layer_size: Tuple[int, ...] 
= (256, 256) + num_critic_training_steps: int = 3000 + num_pg_training_steps: int = 150 + batch_size: int = 100 + replay_buffer_size: int = 1_000_000 + discount: float = 0.99 + reward_scaling: float = 1.0 + critic_learning_rate: float = 3e-4 + actor_learning_rate: float = 3e-4 + policy_learning_rate: float = 1e-3 + noise_clip: float = 0.5 + policy_noise: float = 0.2 + soft_tau_update: float = 0.005 + policy_delay: int = 2 + + +class DCGMEEmitter(MultiEmitter): + def __init__( + self, + config: DCGMEConfig, + policy_network: nn.Module, + actor_network: nn.Module, + env: QDEnv, + variation_fn: Callable[[Params, Params, RNGKey], Tuple[Params, RNGKey]], + ) -> None: + self._config = config + self._env = env + self._variation_fn = variation_fn + + qdcg_config = QualityDCGConfig( + qpg_batch_size=config.qpg_batch_size, + ai_batch_size=config.ai_batch_size, + lengthscale=config.lengthscale, + critic_hidden_layer_size=config.critic_hidden_layer_size, + num_critic_training_steps=config.num_critic_training_steps, + num_pg_training_steps=config.num_pg_training_steps, + batch_size=config.batch_size, + replay_buffer_size=config.replay_buffer_size, + discount=config.discount, + reward_scaling=config.reward_scaling, + critic_learning_rate=config.critic_learning_rate, + actor_learning_rate=config.actor_learning_rate, + policy_learning_rate=config.policy_learning_rate, + noise_clip=config.noise_clip, + policy_noise=config.policy_noise, + soft_tau_update=config.soft_tau_update, + policy_delay=config.policy_delay, + ) + + # define the quality emitter + q_emitter = QualityDCGEmitter( + config=qdcg_config, + policy_network=policy_network, + actor_network=actor_network, + env=env, + ) + + # define the GA emitter + ga_emitter = MixingEmitter( + mutation_fn=lambda x, r: (x, r), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=config.ga_batch_size, + ) + + super().__init__(emitters=(q_emitter, ga_emitter)) diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index 8b858db4..2c55cbd2 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -77,12 +77,18 @@ def __init__( self._score_novelty = score_novelty def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[DiversityPGEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. 
Returns: @@ -90,7 +96,14 @@ def init( """ # init elements of diversity emitter state with QualityEmitterState.init() - diversity_emitter_state, random_key = super().init(init_genotypes, random_key) + diversity_emitter_state, random_key = super().init( + random_key, + repertoire, + genotypes, + fitnesses, + descriptors, + extra_scores, + ) # store elements in a dictionary attributes_dict = vars(diversity_emitter_state) @@ -102,6 +115,12 @@ def init( max_size=self._config.archive_max_size, ) + # get the transitions out of the dictionary + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + + archive = archive.insert(transitions.state_desc) + # init emitter state emitter_state = DiversityPGEmitterState( # retrieve all attributes from the QualityPGEmitterState diff --git a/qdax/core/emitters/emitter.py b/qdax/core/emitters/emitter.py index d32ed981..056798ba 100644 --- a/qdax/core/emitters/emitter.py +++ b/qdax/core/emitters/emitter.py @@ -30,14 +30,20 @@ class EmitterState(PyTreeNode): class Emitter(ABC): def init( - self, init_genotypes: Optional[Genotype], random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[Optional[EmitterState], RNGKey]: """Initialises the state of the emitter. Some emitters do not need a state, in which case, the value None can be outputted. Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. Returns: @@ -51,7 +57,7 @@ def emit( repertoire: Optional[Repertoire], emitter_state: Optional[EmitterState], random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Function used to emit a population of offspring by any possible mean. New population can be sampled from a distribution or obtained through mutations of individuals sampled from the repertoire. diff --git a/qdax/core/emitters/mees_emitter.py b/qdax/core/emitters/mees_emitter.py index b5bb1ada..0a03a6ba 100644 --- a/qdax/core/emitters/mees_emitter.py +++ b/qdax/core/emitters/mees_emitter.py @@ -236,26 +236,32 @@ def batch_size(self) -> int: static_argnames=("self",), ) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[MEESEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. Returns: The initial state of the MEESEmitter, a new random key. 
""" # Initialisation requires one initial genotype - if jax.tree_util.tree_leaves(init_genotypes)[0].shape[0] > 1: - init_genotypes = jax.tree_util.tree_map( + if jax.tree_util.tree_leaves(genotypes)[0].shape[0] > 1: + genotypes = jax.tree_util.tree_map( lambda x: x[0], - init_genotypes, + genotypes, ) # Initialise optimizer - initial_optimizer_state = self._optimizer.init(init_genotypes) + initial_optimizer_state = self._optimizer.init(genotypes) # Create empty Novelty archive if self._config.use_explore: @@ -270,7 +276,7 @@ def init( # Create empty updated genotypes and fitness last_updated_genotypes = jax.tree_util.tree_map( lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]), - init_genotypes, + genotypes, ) last_updated_fitnesses = -jnp.inf * jnp.ones( shape=self._config.last_updated_size @@ -280,7 +286,7 @@ def init( MEESEmitterState( initial_optimizer_state=initial_optimizer_state, optimizer_state=initial_optimizer_state, - offspring=init_genotypes, + offspring=genotypes, generation_count=0, novelty_archive=novelty_archive, last_updated_genotypes=last_updated_genotypes, @@ -300,7 +306,7 @@ def emit( repertoire: MapElitesRepertoire, emitter_state: MEESEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Return the offspring generated through gradient update. Params: @@ -313,7 +319,7 @@ def emit( a new jax PRNG key """ - return emitter_state.offspring, random_key + return emitter_state.offspring, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/multi_emitter.py b/qdax/core/emitters/multi_emitter.py index 2da46639..b3ad23c6 100644 --- a/qdax/core/emitters/multi_emitter.py +++ b/qdax/core/emitters/multi_emitter.py @@ -56,13 +56,19 @@ def get_indexes_separation_batches( return tuple(indexes_separation_batches) def init( - self, init_genotypes: Optional[Genotype], random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[Optional[EmitterState], RNGKey]: """ Initialize the state of the emitter. Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. Returns: @@ -76,7 +82,14 @@ def init( # init all emitter states - gather them emitter_states = [] for emitter, subkey_emitter in zip(self.emitters, subkeys): - emitter_state, _ = emitter.init(init_genotypes, subkey_emitter) + emitter_state, _ = emitter.init( + subkey_emitter, + repertoire, + genotypes, + fitnesses, + descriptors, + extra_scores, + ) emitter_states.append(emitter_state) return MultiEmitterState(tuple(emitter_states)), random_key @@ -87,7 +100,7 @@ def emit( repertoire: Optional[Repertoire], emitter_state: Optional[MultiEmitterState], random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Emit new population. Use all the sub emitters to emit subpopulation and gather them. 
@@ -108,21 +121,25 @@ def emit( # emit from all emitters and gather offsprings all_offsprings = [] + all_extra_info: ExtraScores = {} for emitter, sub_emitter_state, subkey_emitter in zip( self.emitters, emitter_state.emitter_states, subkeys, ): - genotype, _ = emitter.emit(repertoire, sub_emitter_state, subkey_emitter) + genotype, extra_info, _ = emitter.emit( + repertoire, sub_emitter_state, subkey_emitter + ) batch_size = jax.tree_util.tree_leaves(genotype)[0].shape[0] assert batch_size == emitter.batch_size all_offsprings.append(genotype) + all_extra_info = {**all_extra_info, **extra_info} # concatenate offsprings together offsprings = jax.tree_util.tree_map( lambda *x: jnp.concatenate(x, axis=0), *all_offsprings ) - return offsprings, random_key + return offsprings, all_extra_info, random_key @partial(jax.jit, static_argnames=("self",)) def state_update( diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index 7336750d..54766152 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -84,20 +84,26 @@ def __init__( self._num_descriptors = num_descriptors def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[OMGMEGAEmitterState, RNGKey]: """Initialises the state of the emitter. Creates an empty repertoire that will later contain the gradients of the individuals. Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. Returns: The initial emitter state. """ # retrieve one genotype from the population - first_genotype = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) + first_genotype = jax.tree_util.tree_map(lambda x: x[0], genotypes) # add a dimension of size num descriptors + 1 gradient_genotype = jax.tree_util.tree_map( @@ -112,6 +118,18 @@ def init( genotype=gradient_genotype, centroids=self._centroids ) + # get gradients out of the extra scores + assert "gradients" in extra_scores.keys(), "Missing gradients or wrong key" + gradients = extra_scores["gradients"] + + # update the gradients repertoire + gradients_repertoire = gradients_repertoire.add( + gradients, + descriptors, + fitnesses, + extra_scores, + ) + return ( OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire), random_key, @@ -126,7 +144,7 @@ def emit( repertoire: MapElitesRepertoire, emitter_state: OMGMEGAEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ OMG emitter function that samples elements in the repertoire and does a gradient update with random coefficients to create new candidates. 
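Because emitter `init` and `state_update` now read from `extra_scores` directly (the OMG-MEGA emitter above asserts a `"gradients"` entry; the PG-style emitters assert `"transitions"`), the scoring function has to provide those keys. A hedged sketch of a user-side wrapper satisfying the OMG-MEGA requirement, where `scoring_fn` and `compute_gradients` are hypothetical user-provided callables:

```python
from typing import Tuple

from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey


def scoring_with_gradients(
    genotypes: Genotype, random_key: RNGKey
) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:
    # scoring_fn and compute_gradients are placeholders for user code.
    fitnesses, descriptors, extra_scores, random_key = scoring_fn(
        genotypes, random_key
    )
    # The OMG-MEGA emitter's init/state_update look up this exact key.
    extra_scores = {**extra_scores, "gradients": compute_gradients(genotypes)}
    return fitnesses, descriptors, extra_scores, random_key
```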
@@ -190,7 +208,7 @@ def emit( lambda x, y: x + y, genotypes, update_grad ) - return new_genotypes, random_key + return new_genotypes, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/pbt_me_emitter.py b/qdax/core/emitters/pbt_me_emitter.py index 3fdb4418..a2266bfa 100644 --- a/qdax/core/emitters/pbt_me_emitter.py +++ b/qdax/core/emitters/pbt_me_emitter.py @@ -91,12 +91,18 @@ def __init__( ) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[PBTEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. Returns: @@ -145,13 +151,13 @@ def init( # Create emitter state # keep only pg population size training states if more are provided - init_genotypes = jax.tree_util.tree_map( - lambda x: x[: self._config.pg_population_size_per_device], init_genotypes + genotypes = jax.tree_util.tree_map( + lambda x: x[: self._config.pg_population_size_per_device], genotypes ) emitter_state = PBTEmitterState( replay_buffers=replay_buffers, env_states=env_states, - training_states=init_genotypes, + training_states=genotypes, random_key=subkey2, ) @@ -166,7 +172,7 @@ def emit( repertoire: Repertoire, emitter_state: PBTEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Do a single PGA-ME iteration: train critics and greedy policy, make mutations (evo and pg), score solution, fill replay buffer and insert back in the MAP-Elites grid. @@ -199,7 +205,7 @@ def emit( else: genotypes = x_mutation_pg - return genotypes, random_key + return genotypes, {}, random_key @property def batch_size(self) -> int: diff --git a/qdax/core/emitters/qdcg_emitter.py b/qdax/core/emitters/qdcg_emitter.py new file mode 100644 index 00000000..0d560cbb --- /dev/null +++ b/qdax/core/emitters/qdcg_emitter.py @@ -0,0 +1,763 @@ +"""Implements the PG Emitter and Actor Injection from DCG-ME algorithm +in JAX for Brax environments. +""" + +from dataclasses import dataclass +from functools import partial +from typing import Any, Tuple + +import flax.linen as nn +import jax +import optax +from jax import numpy as jnp + +from qdax.core.containers.repertoire import Repertoire +from qdax.core.emitters.emitter import Emitter, EmitterState +from qdax.core.neuroevolution.buffers.buffer import DCGTransition, ReplayBuffer +from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_dc_fn +from qdax.core.neuroevolution.networks.networks import QModuleDC +from qdax.environments.base_wrappers import QDEnv +from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey + + +@dataclass +class QualityDCGConfig: + """Configuration for QualityDCG Emitter""" + + qpg_batch_size: int = 64 + ai_batch_size: int = 64 + lengthscale: float = 0.1 + + critic_hidden_layer_size: Tuple[int, ...] 
= (256, 256) + num_critic_training_steps: int = 3000 + num_pg_training_steps: int = 150 + batch_size: int = 100 + replay_buffer_size: int = 1_000_000 + discount: float = 0.99 + reward_scaling: float = 1.0 + critic_learning_rate: float = 3e-4 + actor_learning_rate: float = 3e-4 + policy_learning_rate: float = 1e-3 + noise_clip: float = 0.5 + policy_noise: float = 0.2 + soft_tau_update: float = 0.005 + policy_delay: int = 2 + + +class QualityDCGEmitterState(EmitterState): + """Contains training state for the learner.""" + + critic_params: Params + critic_opt_state: optax.OptState + actor_params: Params + actor_opt_state: optax.OptState + target_critic_params: Params + target_actor_params: Params + replay_buffer: ReplayBuffer + random_key: RNGKey + steps: jnp.ndarray + + +class QualityDCGEmitter(Emitter): + """ + A policy gradient emitter used to implement the Policy Gradient Assisted MAP-Elites + (PGA-Map-Elites) algorithm. + """ + + def __init__( + self, + config: QualityDCGConfig, + policy_network: nn.Module, + actor_network: nn.Module, + env: QDEnv, + ) -> None: + self._config = config + self._env = env + self._policy_network = policy_network + self._actor_network = actor_network + + # Init Critics + critic_network = QModuleDC( + n_critics=2, hidden_layer_sizes=self._config.critic_hidden_layer_size + ) + self._critic_network = critic_network + + # Set up the losses and optimizers - return the opt states + ( + self._policy_loss_fn, + self._actor_loss_fn, + self._critic_loss_fn, + ) = make_td3_loss_dc_fn( + policy_fn=policy_network.apply, + actor_fn=actor_network.apply, + critic_fn=critic_network.apply, + reward_scaling=self._config.reward_scaling, + discount=self._config.discount, + noise_clip=self._config.noise_clip, + policy_noise=self._config.policy_noise, + ) + + # Init optimizers + self._actor_optimizer = optax.adam( + learning_rate=self._config.actor_learning_rate + ) + self._critic_optimizer = optax.adam( + learning_rate=self._config.critic_learning_rate + ) + self._policies_optimizer = optax.adam( + learning_rate=self._config.policy_learning_rate + ) + + @property + def batch_size(self) -> int: + """ + Returns: + the batch size emitted by the emitter. + """ + return self._config.qpg_batch_size + self._config.ai_batch_size + + @property + def use_all_data(self) -> bool: + """Whether to use all data or not when used along other emitters. + + QualityPGEmitter uses the transitions from the genotypes that were generated + by other emitters. + """ + return True + + def init( + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, + ) -> Tuple[QualityDCGEmitterState, RNGKey]: + """Initializes the emitter state. + + Args: + genotypes: The initial population. + random_key: A random key. + + Returns: + The initial state of the PGAMEEmitter, a new random key. 
+ """ + + observation_size = jax.tree_util.tree_leaves(genotypes)[1].shape[1] + descriptor_size = self._env.behavior_descriptor_length + action_size = self._env.action_size + + # Initialise critic, greedy actor and population + random_key, subkey = jax.random.split(random_key) + fake_obs = jnp.zeros(shape=(observation_size,)) + fake_desc = jnp.zeros(shape=(descriptor_size,)) + fake_action = jnp.zeros(shape=(action_size,)) + + critic_params = self._critic_network.init( + subkey, obs=fake_obs, actions=fake_action, desc=fake_desc + ) + target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) + + random_key, subkey = jax.random.split(random_key) + actor_params = self._actor_network.init(subkey, obs=fake_obs, desc=fake_desc) + target_actor_params = jax.tree_util.tree_map(lambda x: x, actor_params) + + # Prepare init optimizer states + critic_opt_state = self._critic_optimizer.init(critic_params) + actor_opt_state = self._actor_optimizer.init(actor_params) + + # Initialize replay buffer + dummy_transition = DCGTransition.init_dummy( + observation_dim=self._env.observation_size, + action_dim=action_size, + descriptor_dim=descriptor_size, + ) + + replay_buffer = ReplayBuffer.init( + buffer_size=self._config.replay_buffer_size, transition=dummy_transition + ) + + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + episode_length = transitions.obs.shape[1] + + desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) + desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) + + transitions = transitions.replace( + desc=desc_normalized, desc_prime=desc_normalized + ) + replay_buffer = replay_buffer.insert(transitions) + + # Initial training state + random_key, subkey = jax.random.split(random_key) + emitter_state = QualityDCGEmitterState( + critic_params=critic_params, + critic_opt_state=critic_opt_state, + actor_params=actor_params, + actor_opt_state=actor_opt_state, + target_critic_params=target_critic_params, + target_actor_params=target_actor_params, + replay_buffer=replay_buffer, + random_key=subkey, + steps=jnp.array(0), + ) + + return emitter_state, random_key + + @partial(jax.jit, static_argnames=("self",)) + def _similarity(self, descs_1: Descriptor, descs_2: Descriptor) -> jnp.array: + """Compute the similarity between two batches of descriptors. + Args: + descs_1: batch of descriptors. + descs_2: batch of descriptors. + Returns: + batch of similarity measures. + """ + return jnp.exp( + -jnp.linalg.norm(descs_1 - descs_2, axis=-1) / self._config.lengthscale + ) + + @partial(jax.jit, static_argnames=("self",)) + def _normalize_desc(self, desc: Descriptor) -> Descriptor: + return ( + 2 + * (desc - self._env.behavior_descriptor_limits[0]) + / ( + self._env.behavior_descriptor_limits[1] + - self._env.behavior_descriptor_limits[0] + ) + - 1 + ) + + @partial(jax.jit, static_argnames=("self",)) + def _unnormalize_desc(self, desc_normalized: Descriptor) -> Descriptor: + return 0.5 * ( + self._env.behavior_descriptor_limits[1] + - self._env.behavior_descriptor_limits[0] + ) * desc_normalized + 0.5 * ( + self._env.behavior_descriptor_limits[1] + + self._env.behavior_descriptor_limits[0] + ) + + @partial(jax.jit, static_argnames=("self",)) + def _compute_equivalent_kernel_bias_with_desc( + self, actor_dc_params: Params, desc: Descriptor + ) -> Tuple[Params, Params]: + """ + Compute the equivalent bias of the first layer of the actor network + given a descriptor. 
+ """ + # Extract kernel and bias of the first layer + kernel = actor_dc_params["params"]["Dense_0"]["kernel"] + bias = actor_dc_params["params"]["Dense_0"]["bias"] + + # Compute the equivalent bias + equivalent_kernel = kernel[: -desc.shape[0], :] + equivalent_bias = bias + jnp.dot(desc, kernel[-desc.shape[0] :]) + + return equivalent_kernel, equivalent_bias + + @partial(jax.jit, static_argnames=("self",)) + def _compute_equivalent_params_with_desc( + self, actor_dc_params: Params, desc: Descriptor + ) -> Params: + desc_normalized = self._normalize_desc(desc) + ( + equivalent_kernel, + equivalent_bias, + ) = self._compute_equivalent_kernel_bias_with_desc( + actor_dc_params, desc_normalized + ) + actor_dc_params["params"]["Dense_0"]["kernel"] = equivalent_kernel + actor_dc_params["params"]["Dense_0"]["bias"] = equivalent_bias + return actor_dc_params + + @partial( + jax.jit, + static_argnames=("self",), + ) + def emit( + self, + repertoire: Repertoire, + emitter_state: QualityDCGEmitterState, + random_key: RNGKey, + ) -> Tuple[Genotype, ExtraScores, RNGKey]: + """Do a step of PG emission. + + Args: + repertoire: the current repertoire of genotypes + emitter_state: the state of the emitter used + random_key: a random key + + Returns: + A batch of offspring, the new emitter state and a new key. + """ + # PG emitter + parents_pg, descs_pg, random_key = repertoire.sample_with_descs( + random_key, self._config.qpg_batch_size + ) + genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg) + + # Actor injection emitter + _, descs_ai, random_key = repertoire.sample_with_descs( + random_key, self._config.ai_batch_size + ) + descs_ai = descs_ai.reshape( + descs_ai.shape[0], self._env.behavior_descriptor_length + ) + genotypes_ai = self.emit_ai(emitter_state, descs_ai) + + # Concatenate PG and AI genotypes + genotypes = jax.tree_util.tree_map( + lambda x1, x2: jnp.concatenate((x1, x2), axis=0), genotypes_pg, genotypes_ai + ) + + return ( + genotypes, + {"desc_prime": jnp.concatenate([descs_pg, descs_ai], axis=0)}, + random_key, + ) + + @partial( + jax.jit, + static_argnames=("self",), + ) + def emit_pg( + self, + emitter_state: QualityDCGEmitterState, + parents: Genotype, + descs: Descriptor, + ) -> Genotype: + """Emit the offsprings generated through pg mutation. + + Args: + emitter_state: current emitter state, contains critic and + replay buffer. + parents: the parents selected to be applied gradients in order + to mutate towards better performance. + descs: the descriptors of the parents. + + Returns: + A new set of offsprings. + """ + mutation_fn = partial( + self._mutation_function_pg, + emitter_state=emitter_state, + ) + offsprings = jax.vmap(mutation_fn)(parents, descs) + + return offsprings + + @partial( + jax.jit, + static_argnames=("self",), + ) + def emit_ai( + self, emitter_state: QualityDCGEmitterState, descs: Descriptor + ) -> Genotype: + """Emit the offsprings generated through pg mutation. + + Args: + emitter_state: current emitter state, contains critic and + replay buffer. + parents: the parents selected to be applied gradients in order + to mutate towards better performance. + descs: the descriptors of the parents. + + Returns: + A new set of offsprings. + """ + offsprings = jax.vmap( + self._compute_equivalent_params_with_desc, in_axes=(None, 0) + )(emitter_state.actor_params, descs) + + return offsprings + + @partial(jax.jit, static_argnames=("self",)) + def emit_actor(self, emitter_state: QualityDCGEmitterState) -> Genotype: + """Emit the greedy actor. 
+ + Simply needs to be retrieved from the emitter state. + + Args: + emitter_state: the current emitter state, it stores the + greedy actor. + + Returns: + The parameters of the actor. + """ + return emitter_state.actor_params + + @partial( + jax.jit, + static_argnames=("self",), + ) + def state_update( + self, + emitter_state: QualityDCGEmitterState, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, + ) -> QualityDCGEmitterState: + """This function gives an opportunity to update the emitter state + after the genotypes have been scored. + + Here it is used to fill the Replay Buffer with the transitions + from the scoring of the genotypes, and then the training of the + critic/actor happens. Hence the params of critic/actor are updated, + as well as their optimizer states. + + Args: + emitter_state: current emitter state. + repertoire: the current genotypes repertoire + genotypes: unused here - but compulsory in the signature. + fitnesses: unused here - but compulsory in the signature. + descriptors: unused here - but compulsory in the signature. + extra_scores: extra information coming from the scoring function, + this contains the transitions added to the replay buffer. + + Returns: + New emitter state where the replay buffer has been filled with + the new experienced transitions. + """ + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + episode_length = transitions.obs.shape[1] + + desc_prime = jnp.concatenate( + [ + extra_scores["desc_prime"], + descriptors[self._config.qpg_batch_size + self._config.ai_batch_size :], + ], + axis=0, + ) + desc_prime = jnp.repeat(desc_prime[:, jnp.newaxis, :], episode_length, axis=1) + desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) + + desc_prime_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc_prime) + desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) + transitions = transitions.replace( + desc=desc_normalized, desc_prime=desc_prime_normalized + ) + + # Add transitions to replay buffer + replay_buffer = emitter_state.replay_buffer.insert(transitions) + emitter_state = emitter_state.replace(replay_buffer=replay_buffer) + + # sample transitions from the replay buffer + random_key, subkey = jax.random.split(emitter_state.random_key) + transitions, random_key = replay_buffer.sample( + subkey, self._config.num_critic_training_steps * self._config.batch_size + ) + transitions = jax.tree_util.tree_map( + lambda x: jnp.reshape( + x, + ( + self._config.num_critic_training_steps, + self._config.batch_size, + *x.shape[1:], + ), + ), + transitions, + ) + transitions = transitions.replace( + rewards=self._similarity(transitions.desc, transitions.desc_prime) + * transitions.rewards + ) + emitter_state = emitter_state.replace(random_key=random_key) + + def scan_train_critics( + carry: QualityDCGEmitterState, + transitions: DCGTransition, + ) -> Tuple[QualityDCGEmitterState, Any]: + emitter_state = carry + new_emitter_state = self._train_critics(emitter_state, transitions) + return new_emitter_state, () + + # Train critics and greedy actor + emitter_state, _ = jax.lax.scan( + scan_train_critics, + emitter_state, + transitions, + length=self._config.num_critic_training_steps, + ) + + return emitter_state # type: ignore + + @partial(jax.jit, static_argnames=("self",)) + def _train_critics( + self, emitter_state: QualityDCGEmitterState, transitions: DCGTransition + ) -> 
QualityDCGEmitterState: + """Apply one gradient step to critics and to the greedy actor + (contained in carry in training_state), then soft update target critics + and target actor. + + Those updates are very similar to those made in TD3. + + Args: + emitter_state: actual emitter state + + Returns: + New emitter state where the critic and the greedy actor have been + updated. Optimizer states have also been updated in the process. + """ + # Update Critic + ( + critic_opt_state, + critic_params, + target_critic_params, + random_key, + ) = self._update_critic( + critic_params=emitter_state.critic_params, + target_critic_params=emitter_state.target_critic_params, + target_actor_params=emitter_state.target_actor_params, + critic_opt_state=emitter_state.critic_opt_state, + transitions=transitions, + random_key=emitter_state.random_key, + ) + + # Update greedy actor + (actor_opt_state, actor_params, target_actor_params,) = jax.lax.cond( + emitter_state.steps % self._config.policy_delay == 0, + lambda x: self._update_actor(*x), + lambda _: ( + emitter_state.actor_opt_state, + emitter_state.actor_params, + emitter_state.target_actor_params, + ), + operand=( + emitter_state.actor_params, + emitter_state.actor_opt_state, + emitter_state.target_actor_params, + emitter_state.critic_params, + transitions, + ), + ) + + # Create new training state + new_emitter_state = emitter_state.replace( + critic_params=critic_params, + critic_opt_state=critic_opt_state, + actor_params=actor_params, + actor_opt_state=actor_opt_state, + target_critic_params=target_critic_params, + target_actor_params=target_actor_params, + random_key=random_key, + steps=emitter_state.steps + 1, + ) + + return new_emitter_state # type: ignore + + @partial(jax.jit, static_argnames=("self",)) + def _update_critic( + self, + critic_params: Params, + target_critic_params: Params, + target_actor_params: Params, + critic_opt_state: Params, + transitions: DCGTransition, + random_key: RNGKey, + ) -> Tuple[Params, Params, Params, RNGKey]: + + # compute loss and gradients + random_key, subkey = jax.random.split(random_key) + critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)( + critic_params, + target_actor_params, + target_critic_params, + transitions, + subkey, + ) + critic_updates, critic_opt_state = self._critic_optimizer.update( + critic_gradient, critic_opt_state + ) + + # update critic + critic_params = optax.apply_updates(critic_params, critic_updates) + + # Soft update of target critic network + target_critic_params = jax.tree_map( + lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + + self._config.soft_tau_update * x2, + target_critic_params, + critic_params, + ) + + return critic_opt_state, critic_params, target_critic_params, random_key + + @partial(jax.jit, static_argnames=("self",)) + def _update_actor( + self, + actor_params: Params, + actor_opt_state: optax.OptState, + target_actor_params: Params, + critic_params: Params, + transitions: DCGTransition, + ) -> Tuple[optax.OptState, Params, Params]: + + # Update greedy actor + policy_loss, policy_gradient = jax.value_and_grad(self._actor_loss_fn)( + actor_params, + critic_params, + transitions, + ) + ( + policy_updates, + actor_opt_state, + ) = self._actor_optimizer.update(policy_gradient, actor_opt_state) + actor_params = optax.apply_updates(actor_params, policy_updates) + + # Soft update of target greedy actor + target_actor_params = jax.tree_map( + lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + + self._config.soft_tau_update * x2, + 
target_actor_params, + actor_params, + ) + + return ( + actor_opt_state, + actor_params, + target_actor_params, + ) + + @partial( + jax.jit, + static_argnames=("self",), + ) + def _mutation_function_pg( + self, + policy_params: Genotype, + descs: Descriptor, + emitter_state: QualityDCGEmitterState, + ) -> Genotype: + """Apply pg mutation to a policy via multiple steps of gradient descent. + First, update the rewards to be diversity rewards, then apply the gradient + steps. + + Args: + policy_params: a policy, supposed to be a differentiable neural + network. + emitter_state: the current state of the emitter, containing among others, + the replay buffer, the critic. + + Returns: + The updated params of the neural network. + """ + # Get transitions + transitions, random_key = emitter_state.replay_buffer.sample( + emitter_state.random_key, + sample_size=self._config.num_pg_training_steps * self._config.batch_size, + ) + descs_prime = jnp.tile( + descs, (self._config.num_pg_training_steps * self._config.batch_size, 1) + ) + descs_prime_normalized = jax.vmap(self._normalize_desc)(descs_prime) + transitions = transitions.replace( + rewards=self._similarity(transitions.desc, descs_prime_normalized) + * transitions.rewards, + desc_prime=descs_prime_normalized, + ) + transitions = jax.tree_util.tree_map( + lambda x: jnp.reshape( + x, + ( + self._config.num_pg_training_steps, + self._config.batch_size, + *x.shape[1:], + ), + ), + transitions, + ) + + # Replace random_key + emitter_state = emitter_state.replace(random_key=random_key) + + # Define new policy optimizer state + policy_opt_state = self._policies_optimizer.init(policy_params) + + def scan_train_policy( + carry: Tuple[QualityDCGEmitterState, Genotype, optax.OptState], + transitions: DCGTransition, + ) -> Tuple[Tuple[QualityDCGEmitterState, Genotype, optax.OptState], Any]: + emitter_state, policy_params, policy_opt_state = carry + ( + new_emitter_state, + new_policy_params, + new_policy_opt_state, + ) = self._train_policy( + emitter_state, + policy_params, + policy_opt_state, + transitions, + ) + return ( + new_emitter_state, + new_policy_params, + new_policy_opt_state, + ), () + + (emitter_state, policy_params, policy_opt_state,), _ = jax.lax.scan( + scan_train_policy, + (emitter_state, policy_params, policy_opt_state), + transitions, + length=self._config.num_pg_training_steps, + ) + + return policy_params + + @partial(jax.jit, static_argnames=("self",)) + def _train_policy( + self, + emitter_state: QualityDCGEmitterState, + policy_params: Params, + policy_opt_state: optax.OptState, + transitions: DCGTransition, + ) -> Tuple[QualityDCGEmitterState, Params, optax.OptState]: + """Apply one gradient step to a policy (called policy_params). + + Args: + emitter_state: current state of the emitter. + policy_params: parameters corresponding to the weights and bias of + the neural network that defines the policy. + + Returns: + The new emitter state and new params of the NN. 
+ """ + # update policy + policy_opt_state, policy_params = self._update_policy( + critic_params=emitter_state.critic_params, + policy_opt_state=policy_opt_state, + policy_params=policy_params, + transitions=transitions, + ) + + return emitter_state, policy_params, policy_opt_state + + @partial(jax.jit, static_argnames=("self",)) + def _update_policy( + self, + critic_params: Params, + policy_opt_state: optax.OptState, + policy_params: Params, + transitions: DCGTransition, + ) -> Tuple[optax.OptState, Params]: + + # compute loss + _policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)( + policy_params, + critic_params, + transitions, + ) + # Compute gradient and update policies + ( + policy_updates, + policy_opt_state, + ) = self._policies_optimizer.update(policy_gradient, policy_opt_state) + policy_params = optax.apply_updates(policy_params, policy_updates) + + return policy_opt_state, policy_params diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index c07e3b18..4a173b51 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -119,12 +119,18 @@ def use_all_data(self) -> bool: return True def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[QualityPGEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. Returns: @@ -144,8 +150,8 @@ def init( ) target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) - actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) - target_actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) + actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes) + target_actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes) # Prepare init optimizer states critic_optimizer_state = self._critic_optimizer.init(critic_params) @@ -162,6 +168,13 @@ def init( buffer_size=self._config.replay_buffer_size, transition=dummy_transition ) + # get the transitions out of the dictionary + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + + # add transitions in the replay buffer + replay_buffer = replay_buffer.insert(transitions) + # Initial training state random_key, subkey = jax.random.split(random_key) emitter_state = QualityPGEmitterState( @@ -171,9 +184,9 @@ def init( actor_opt_state=actor_optimizer_state, target_critic_params=target_critic_params, target_actor_params=target_actor_params, + replay_buffer=replay_buffer, random_key=subkey, steps=jnp.array(0), - replay_buffer=replay_buffer, ) return emitter_state, random_key @@ -187,7 +200,7 @@ def emit( repertoire: Repertoire, emitter_state: QualityPGEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Do a step of PG emission. 
Args: @@ -223,7 +236,7 @@ def emit( offspring_actor, ) - return genotypes, random_key + return genotypes, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/standard_emitters.py b/qdax/core/emitters/standard_emitters.py index 8b877792..860962d4 100644 --- a/qdax/core/emitters/standard_emitters.py +++ b/qdax/core/emitters/standard_emitters.py @@ -6,7 +6,7 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Genotype, RNGKey +from qdax.types import ExtraScores, Genotype, RNGKey class MixingEmitter(Emitter): @@ -31,7 +31,7 @@ def emit( repertoire: Repertoire, emitter_state: Optional[EmitterState], random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emitter that performs both mutation and variation. Two batches of variation_percentage * batch_size genotypes are sampled in the repertoire, @@ -75,7 +75,7 @@ def emit( x_mutation, ) - return genotypes, random_key + return genotypes, {}, random_key @property def batch_size(self) -> int: diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index c71b0013..8b649d0c 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -23,7 +23,7 @@ class MAPElites: """Core elements of the MAP-Elites algorithm. Note: Although very similar to the GeneticAlgorithm, we decided to keep the - MAPElites class independent of the GeneticAlgorithm class at the moment to keep + MAPElites class independant of the GeneticAlgorithm class at the moment to keep elements explicit. Args: @@ -52,7 +52,7 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, centroids: Centroid, random_key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: @@ -62,9 +62,9 @@ def init( such as CVT or Euclidean mapping. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) - centroids: tessellation centroids of shape (batch_size, num_descriptors) + centroids: tesselation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. 
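# Sketch of the updated Emitter.emit contract introduced in this diff: emit now
# returns (genotypes, extra_info, random_key) and the caller merges extra_info
# into the scoring extra_scores before calling state_update. The stub emitter
# below is an illustrative placeholder, not QDax code.
import jax
import jax.numpy as jnp

def stub_emit(repertoire, emitter_state, random_key):
    offspring = jnp.zeros((4, 2))
    return offspring, {}, random_key          # extra_info may be an empty dict

random_key = jax.random.PRNGKey(0)
offspring, extra_info, random_key = stub_emit(None, None, random_key)
extra_scores = {"transitions": None}          # e.g. produced by a scoring function
merged_extra_scores = {**extra_scores, **extra_info}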
Returns: @@ -73,12 +73,12 @@ def init( """ # score initial genotypes fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = MapElitesRepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, centroids=centroids, @@ -87,14 +87,9 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key=random_key, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, @@ -129,9 +124,10 @@ def update( a new jax PRNG key """ # generate offsprings with the emitter - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) + # scores the offsprings fitnesses, descriptors, extra_scores, random_key = self._scoring_function( genotypes, random_key @@ -147,7 +143,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores, + extra_scores={**extra_scores, **extra_info}, ) # update the metrics diff --git a/qdax/core/mels.py b/qdax/core/mels.py index 6c06b785..6dc8f551 100644 --- a/qdax/core/mels.py +++ b/qdax/core/mels.py @@ -55,7 +55,7 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, centroids: Centroid, random_key: RNGKey, ) -> Tuple[MELSRepertoire, Optional[EmitterState], RNGKey]: @@ -64,7 +64,7 @@ def init( be computed with any method such as CVT or Euclidean mapping. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tessellation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. @@ -75,12 +75,12 @@ def init( """ # score initial genotypes fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = MELSRepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, centroids=centroids, @@ -89,14 +89,19 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, diff --git a/qdax/core/mome.py b/qdax/core/mome.py index 2a004f59..db450b9a 100644 --- a/qdax/core/mome.py +++ b/qdax/core/mome.py @@ -23,7 +23,7 @@ class MOME(MAPElites): @partial(jax.jit, static_argnames=("self", "pareto_front_max_length")) def init( self, - init_genotypes: jnp.ndarray, + genotypes: jnp.ndarray, centroids: Centroid, pareto_front_max_length: int, random_key: RNGKey, @@ -33,7 +33,7 @@ def init( CVT or Euclidean mapping. 
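# Usage sketch for the renamed init argument above: the initial population is
# now passed to MAPElites.init as `genotypes` (previously `init_genotypes`).
# `map_elites`, `centroids`, `batch_size` and `num_features` are assumed to be
# defined elsewhere.
import jax

random_key = jax.random.PRNGKey(0)
random_key, subkey = jax.random.split(random_key)
genotypes = jax.random.uniform(subkey, (batch_size, num_features), minval=-1.0, maxval=1.0)
repertoire, emitter_state, random_key = map_elites.init(genotypes, centroids, random_key)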
Args: - init_genotypes: genotypes of the initial population. + genotypes: genotypes of the initial population. centroids: centroids of the repertoire. pareto_front_max_length: maximum size of the pareto front. This is necessary to respect jax.jit fixed shape size constraint. @@ -45,12 +45,12 @@ def init( # first score fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = MOMERepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, centroids=centroids, @@ -60,14 +60,9 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key=random_key, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, diff --git a/qdax/core/neuroevolution/buffers/buffer.py b/qdax/core/neuroevolution/buffers/buffer.py index 42ed7552..d25c8c6c 100644 --- a/qdax/core/neuroevolution/buffers/buffer.py +++ b/qdax/core/neuroevolution/buffers/buffer.py @@ -7,7 +7,15 @@ import jax import jax.numpy as jnp -from qdax.types import Action, Done, Observation, Reward, RNGKey, StateDescriptor +from qdax.types import ( + Action, + Descriptor, + Done, + Observation, + Reward, + RNGKey, + StateDescriptor, +) class Transition(flax.struct.PyTreeNode): @@ -262,6 +270,155 @@ def init_dummy( # type: ignore return dummy_transition +class DCGTransition(QDTransition): + """Stores data corresponding to a transition collected by a QD algorithm.""" + + desc: Descriptor + desc_prime: Descriptor + + @property + def descriptor_dim(self) -> int: + """ + Returns: + the dimension of the descriptors. + """ + return self.state_desc.shape[-1] # type: ignore + + @property + def flatten_dim(self) -> int: + """ + Returns: + the dimension of the transition once flattened. + """ + flatten_dim = ( + 2 * self.observation_dim + + self.action_dim + + 3 + + 2 * self.state_descriptor_dim + + 2 * self.descriptor_dim + ) + return flatten_dim + + def flatten(self) -> jnp.ndarray: + """ + Returns: + a jnp.ndarray that corresponds to the flattened transition. + """ + flatten_transition = jnp.concatenate( + [ + self.obs, + self.next_obs, + jnp.expand_dims(self.rewards, axis=-1), + jnp.expand_dims(self.dones, axis=-1), + jnp.expand_dims(self.truncations, axis=-1), + self.actions, + self.state_desc, + self.next_state_desc, + self.desc, + self.desc_prime, + ], + axis=-1, + ) + return flatten_transition + + @classmethod + def from_flatten( + cls, + flattened_transition: jnp.ndarray, + transition: QDTransition, + ) -> QDTransition: + """ + Creates a transition from a flattened transition in a jnp.ndarray. 
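# Worked example of the flatten_dim arithmetic defined for DCGTransition above:
# 2*obs + action + 3 (reward, done, truncation) + 2*state_desc + 2*desc.
# The dimensions chosen here are arbitrary.
obs_dim, action_dim, state_desc_dim, desc_dim = 10, 4, 2, 2
flatten_dim = 2 * obs_dim + action_dim + 3 + 2 * state_desc_dim + 2 * desc_dim
assert flatten_dim == 35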
+ Args: + flattened_transition: flattened transition in a jnp.ndarray of shape + (batch_size, flatten_dim) + transition: a transition object (might be a dummy one) to + get the dimensions right + Returns: + a Transition object + """ + obs_dim = transition.observation_dim + action_dim = transition.action_dim + state_desc_dim = transition.state_descriptor_dim + desc_dim = transition.descriptor_dim + + obs = flattened_transition[:, :obs_dim] + next_obs = flattened_transition[:, obs_dim : (2 * obs_dim)] + rewards = jnp.ravel(flattened_transition[:, (2 * obs_dim) : (2 * obs_dim + 1)]) + dones = jnp.ravel( + flattened_transition[:, (2 * obs_dim + 1) : (2 * obs_dim + 2)] + ) + truncations = jnp.ravel( + flattened_transition[:, (2 * obs_dim + 2) : (2 * obs_dim + 3)] + ) + actions = flattened_transition[ + :, (2 * obs_dim + 3) : (2 * obs_dim + 3 + action_dim) + ] + state_desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim) : ( + 2 * obs_dim + 3 + action_dim + state_desc_dim + ), + ] + next_state_desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + state_desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + ), + ] + desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + 2 * state_desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + desc_dim + ), + ] + desc_prime = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + 2 * desc_dim + ), + ] + return cls( + obs=obs, + next_obs=next_obs, + rewards=rewards, + dones=dones, + truncations=truncations, + actions=actions, + state_desc=state_desc, + next_state_desc=next_state_desc, + desc=desc, + desc_prime=desc_prime, + ) + + @classmethod + def init_dummy( # type: ignore + cls, observation_dim: int, action_dim: int, descriptor_dim: int + ) -> QDTransition: + """ + Initialize a dummy transition that then can be passed to constructors to get + all shapes right. + Args: + observation_dim: observation dimension + action_dim: action dimension + Returns: + a dummy transition + """ + dummy_transition = DCGTransition( + obs=jnp.zeros(shape=(1, observation_dim)), + next_obs=jnp.zeros(shape=(1, observation_dim)), + rewards=jnp.zeros(shape=(1,)), + dones=jnp.zeros(shape=(1,)), + truncations=jnp.zeros(shape=(1,)), + actions=jnp.zeros(shape=(1, action_dim)), + state_desc=jnp.zeros(shape=(1, descriptor_dim)), + next_state_desc=jnp.zeros(shape=(1, descriptor_dim)), + desc=jnp.zeros(shape=(1, descriptor_dim)), + desc_prime=jnp.zeros(shape=(1, descriptor_dim)), + ) + return dummy_transition + + class ReplayBuffer(flax.struct.PyTreeNode): """ A replay buffer where transitions are flattened before being stored. 
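# Round-trip sketch for the DCGTransition helpers above: build a dummy
# transition, flatten it, then recover the fields with from_flatten. Dimensions
# are arbitrary; this assumes the DCGTransition class added in this diff.
from qdax.core.neuroevolution.buffers.buffer import DCGTransition

dummy = DCGTransition.init_dummy(observation_dim=10, action_dim=4, descriptor_dim=2)
flat = dummy.flatten()                             # shape (1, dummy.flatten_dim)
recovered = DCGTransition.from_flatten(flat, dummy)
assert recovered.obs.shape == dummy.obs.shape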
diff --git a/qdax/core/neuroevolution/losses/td3_loss.py b/qdax/core/neuroevolution/losses/td3_loss.py index 7f34a036..e12797b9 100644 --- a/qdax/core/neuroevolution/losses/td3_loss.py +++ b/qdax/core/neuroevolution/losses/td3_loss.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Action, Observation, Params, RNGKey +from qdax.types import Action, Descriptor, Observation, Params, RNGKey def make_td3_loss_fn( @@ -94,6 +94,110 @@ def _critic_loss_fn( return _policy_loss_fn, _critic_loss_fn +def make_td3_loss_dc_fn( + policy_fn: Callable[[Params, Observation], jnp.ndarray], + actor_fn: Callable[[Params, Observation, Descriptor], jnp.ndarray], + critic_fn: Callable[[Params, Observation, Action, Descriptor], jnp.ndarray], + reward_scaling: float, + discount: float, + noise_clip: float, + policy_noise: float, +) -> Tuple[ + Callable[[Params, Params, Transition], jnp.ndarray], + Callable[[Params, Params, Transition], jnp.ndarray], + Callable[[Params, Params, Params, Transition, RNGKey], jnp.ndarray], +]: + """Creates the loss functions for TD3. + Args: + policy_fn: forward pass through the neural network defining the policy. + actor_fn: forward pass through the neural network defining the + descriptor-conditioned policy. + critic_fn: forward pass through the neural network defining the + descriptor-conditioned critic. + reward_scaling: value to multiply the reward given by the environment. + discount: discount factor. + noise_clip: value that clips the noise to avoid extreme values. + policy_noise: noise applied to smooth the bootstrapping. + Returns: + Return the loss functions used to train the policy and the critic in TD3. + """ + + @jax.jit + def _policy_loss_fn( + policy_params: Params, + critic_params: Params, + transitions: Transition, + ) -> jnp.ndarray: + """Policy loss function for TD3 agent""" + action = policy_fn(policy_params, transitions.obs) + q_value = critic_fn( + critic_params, transitions.obs, action, transitions.desc_prime + ) + q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) + policy_loss = -jnp.mean(q1_action) + return policy_loss + + @jax.jit + def _actor_loss_fn( + actor_params: Params, + critic_params: Params, + transitions: Transition, + ) -> jnp.ndarray: + """Descriptor-conditioned policy loss function for TD3 agent""" + action = actor_fn(actor_params, transitions.obs, transitions.desc_prime) + q_value = critic_fn( + critic_params, transitions.obs, action, transitions.desc_prime + ) + q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) + policy_loss = -jnp.mean(q1_action) + return policy_loss + + @jax.jit + def _critic_loss_fn( + critic_params: Params, + target_actor_params: Params, + target_critic_params: Params, + transitions: Transition, + random_key: RNGKey, + ) -> jnp.ndarray: + """Descriptor-conditioned critic loss function for TD3 agent""" + noise = ( + jax.random.normal(random_key, shape=transitions.actions.shape) + * policy_noise + ).clip(-noise_clip, noise_clip) + + next_action = ( + actor_fn(target_actor_params, transitions.next_obs, transitions.desc_prime) + + noise + ).clip(-1.0, 1.0) + next_q = critic_fn( + target_critic_params, + transitions.next_obs, + next_action, + transitions.desc_prime, + ) + next_v = jnp.min(next_q, axis=-1) + target_q = jax.lax.stop_gradient( + transitions.rewards * reward_scaling + + (1.0 - transitions.dones) * discount * next_v + ) + q_old_action = critic_fn( + critic_params, transitions.obs, transitions.actions, transitions.desc_prime + 
) + q_error = q_old_action - jnp.expand_dims(target_q, -1) + + # Better bootstrapping for truncated episodes. + q_error = q_error * jnp.expand_dims(1.0 - transitions.truncations, -1) + + # compute the loss + q_losses = jnp.mean(jnp.square(q_error), axis=-2) + q_loss = jnp.sum(q_losses, axis=-1) + + return q_loss + + return _policy_loss_fn, _actor_loss_fn, _critic_loss_fn + + def td3_policy_loss_fn( policy_params: Params, critic_params: Params, @@ -115,9 +219,7 @@ def td3_policy_loss_fn( """ action = policy_fn(policy_params, transitions.obs) - q_value = critic_fn( - critic_params, obs=transitions.obs, actions=action # type: ignore - ) + q_value = critic_fn(critic_params, transitions.obs, action) # type: ignore q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) policy_loss = -jnp.mean(q1_action) return policy_loss diff --git a/qdax/core/neuroevolution/mdp_utils.py b/qdax/core/neuroevolution/mdp_utils.py index 984d1aeb..3b077069 100644 --- a/qdax/core/neuroevolution/mdp_utils.py +++ b/qdax/core/neuroevolution/mdp_utils.py @@ -9,7 +9,7 @@ from flax.struct import PyTreeNode from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Genotype, Params, RNGKey +from qdax.types import Descriptor, Genotype, Params, RNGKey class TrainingState(PyTreeNode): @@ -67,6 +67,60 @@ def _scan_play_step_fn( return state, transitions +@partial(jax.jit, static_argnames=("play_step_actor_dc_fn", "episode_length")) +def generate_unroll_actor_dc( + init_state: EnvState, + actor_dc_params: Params, + desc: Descriptor, + random_key: RNGKey, + episode_length: int, + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], + Tuple[ + EnvState, + Descriptor, + Params, + RNGKey, + Transition, + ], + ], +) -> Tuple[EnvState, Transition]: + """Generates an episode according to the agent's policy and descriptor, + returns the final state of the episode and the transitions of the episode. + + Args: + init_state: first state of the rollout. + policy_dc_params: descriptor-conditioned policy params. + desc: descriptor the policy attempts to achieve. + random_key: random key for stochasiticity handling. + episode_length: length of the rollout. + play_step_fn: function describing how a step need to be taken. + + Returns: + A new state, the experienced transition. + """ + + def _scan_play_step_fn( + carry: Tuple[EnvState, Params, Descriptor, RNGKey], unused_arg: Any + ) -> Tuple[Tuple[EnvState, Params, Descriptor, RNGKey], Transition]: + ( + env_state, + actor_dc_params, + desc, + random_key, + transitions, + ) = play_step_actor_dc_fn(*carry) + return (env_state, actor_dc_params, desc, random_key), transitions + + (state, _, _, _), transitions = jax.lax.scan( + _scan_play_step_fn, + (init_state, actor_dc_params, desc, random_key), + (), + length=episode_length, + ) + return state, transitions + + @jax.jit def get_first_episode(transition: Transition) -> Transition: """Extracts the first episode from a batch of transitions, returns the batch of diff --git a/qdax/core/neuroevolution/networks/networks.py b/qdax/core/neuroevolution/networks/networks.py index b2b176ef..365c8d56 100644 --- a/qdax/core/neuroevolution/networks/networks.py +++ b/qdax/core/neuroevolution/networks/networks.py @@ -5,31 +5,51 @@ import flax.linen as nn import jax import jax.numpy as jnp -from brax.training import networks -class QModule(nn.Module): - """Q Module.""" +class MLP(nn.Module): + """MLP module.""" - hidden_layer_sizes: Tuple[int, ...] - n_critics: int = 2 + layer_sizes: Tuple[int, ...] 
+ activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform() + final_activation: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None + bias: bool = True + kernel_init_final: Optional[Callable[..., Any]] = None @nn.compact - def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: - hidden = jnp.concatenate([obs, actions], axis=-1) - res = [] - for _ in range(self.n_critics): - q = networks.MLP( - layer_sizes=self.hidden_layer_sizes + (1,), - activation=nn.relu, - kernel_init=jax.nn.initializers.lecun_uniform(), - )(hidden) - res.append(q) - return jnp.concatenate(res, axis=-1) + def __call__(self, obs: jnp.ndarray) -> jnp.ndarray: + hidden = obs + for i, hidden_size in enumerate(self.layer_sizes): + if i != len(self.layer_sizes) - 1: + hidden = nn.Dense( + hidden_size, + kernel_init=self.kernel_init, + use_bias=self.bias, + )(hidden) + hidden = self.activation(hidden) # type: ignore -class MLP(nn.Module): - """MLP module.""" + else: + if self.kernel_init_final is not None: + kernel_init = self.kernel_init_final + else: + kernel_init = self.kernel_init + + hidden = nn.Dense( + hidden_size, + kernel_init=kernel_init, + use_bias=self.bias, + )(hidden) + + if self.final_activation is not None: + hidden = self.final_activation(hidden) + + return hidden + + +class MLPDC(nn.Module): + """Descriptor-conditioned MLP module.""" layer_sizes: Tuple[int, ...] activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @@ -39,15 +59,13 @@ class MLP(nn.Module): kernel_init_final: Optional[Callable[..., Any]] = None @nn.compact - def __call__(self, data: jnp.ndarray) -> jnp.ndarray: - hidden = data + def __call__(self, obs: jnp.ndarray, desc: jnp.ndarray) -> jnp.ndarray: + hidden = jnp.concatenate([obs, desc], axis=-1) for i, hidden_size in enumerate(self.layer_sizes): if i != len(self.layer_sizes) - 1: hidden = nn.Dense( hidden_size, - # name=f"hidden_{i}", with this version of flax, changing the name - # changes the initialization kernel_init=self.kernel_init, use_bias=self.bias, )(hidden) @@ -61,7 +79,6 @@ def __call__(self, data: jnp.ndarray) -> jnp.ndarray: hidden = nn.Dense( hidden_size, - # name=f"hidden_{i}", kernel_init=kernel_init, use_bias=self.bias, )(hidden) @@ -70,3 +87,45 @@ def __call__(self, data: jnp.ndarray) -> jnp.ndarray: hidden = self.final_activation(hidden) return hidden + + +class QModule(nn.Module): + """Q Module.""" + + hidden_layer_sizes: Tuple[int, ...] + n_critics: int = 2 + + @nn.compact + def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: + hidden = jnp.concatenate([obs, actions], axis=-1) + res = [] + for _ in range(self.n_critics): + q = MLP( + layer_sizes=self.hidden_layer_sizes + (1,), + activation=nn.relu, + kernel_init=jax.nn.initializers.lecun_uniform(), + )(hidden) + res.append(q) + return jnp.concatenate(res, axis=-1) + + +class QModuleDC(nn.Module): + """Q Module.""" + + hidden_layer_sizes: Tuple[int, ...] 
+ n_critics: int = 2 + + @nn.compact + def __call__( + self, obs: jnp.ndarray, actions: jnp.ndarray, desc: jnp.ndarray + ) -> jnp.ndarray: + hidden = jnp.concatenate([obs, actions], axis=-1) + res = [] + for _ in range(self.n_critics): + q = MLPDC( + layer_sizes=self.hidden_layer_sizes + (1,), + activation=nn.relu, + kernel_init=jax.nn.initializers.lecun_uniform(), + )(hidden, desc) + res.append(q) + return jnp.concatenate(res, axis=-1) diff --git a/qdax/environments/base_wrappers.py b/qdax/environments/base_wrappers.py index 6f317e7f..3f709fa7 100644 --- a/qdax/environments/base_wrappers.py +++ b/qdax/environments/base_wrappers.py @@ -1,6 +1,7 @@ from abc import abstractmethod -from typing import Any, List, Tuple +from typing import Any, Tuple +import jax from brax.v1 import jumpy as jp from brax.v1.envs import Env, State @@ -22,7 +23,7 @@ def state_descriptor_name(self) -> str: @property @abstractmethod - def state_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def state_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: pass @property @@ -32,7 +33,7 @@ def behavior_descriptor_length(self) -> int: @property @abstractmethod - def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def behavior_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: pass @property @@ -71,7 +72,7 @@ def state_descriptor_name(self) -> str: return self.env.state_descriptor_name @property - def state_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def state_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: return self.env.state_descriptor_limits @property @@ -79,7 +80,7 @@ def behavior_descriptor_length(self) -> int: return self.env.behavior_descriptor_length @property - def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def behavior_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: return self.env.behavior_descriptor_limits @property diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py index 720f662a..cf0c3336 100644 --- a/qdax/environments/wrappers.py +++ b/qdax/environments/wrappers.py @@ -3,7 +3,7 @@ import flax.struct import jax from brax.v1 import jumpy as jp -from brax.v1.envs import State, Wrapper +from brax.v1.envs import Env, State, Wrapper class CompletedEvalMetrics(flax.struct.PyTreeNode): @@ -69,3 +69,80 @@ def step(self, state: State, action: jp.ndarray) -> State: ) nstate.info[self.STATE_INFO_KEY] = eval_metrics return nstate + + +class ClipRewardWrapper(Wrapper): + """Wraps gym environments to clip the reward to be greater than 0. + + Utilisation is simple: create an environment with Brax, pass + it to the wrapper with the name of the environment, and it will + work like before and will simply clip the reward to be greater than 0. + """ + + def __init__( + self, env: Env, clip_min: float = None, clip_max: float = None + ) -> None: + super().__init__(env) + self._clip_min = clip_min + self._clip_max = clip_max + + def reset(self, rng: jp.ndarray) -> State: + state = self.env.reset(rng) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) + + def step(self, state: State, action: jp.ndarray) -> State: + state = self.env.step(state, action) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) + + +class AffineRewardWrapper(Wrapper): + """Wraps gym environments to clip the reward. 
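# Minimal init/apply sketch for the MLP and descriptor-conditioned MLPDC modules
# defined above; shapes and layer sizes are illustrative.
import jax
import jax.numpy as jnp
from qdax.core.neuroevolution.networks.networks import MLP, MLPDC

obs = jnp.zeros((1, 10))
desc = jnp.zeros((1, 2))

policy = MLP(layer_sizes=(64, 64, 4))
policy_dc = MLPDC(layer_sizes=(64, 64, 4))

key = jax.random.PRNGKey(0)
policy_params = policy.init(key, obs)
policy_dc_params = policy_dc.init(key, obs, desc)

action = policy.apply(policy_params, obs)                   # shape (1, 4)
action_dc = policy_dc.apply(policy_dc_params, obs, desc)    # shape (1, 4)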
+ + Utilisation is simple: create an environment with Brax, pass + it to the wrapper with the name of the environment, and it will + work like before and will simply clip the reward to be greater than 0. + """ + + def __init__( + self, env: Env, clip_min: float = None, clip_max: float = None + ) -> None: + super().__init__(env) + self._clip_min = clip_min + self._clip_max = clip_max + + def reset(self, rng: jp.ndarray) -> State: + state = self.env.reset(rng) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) + + def step(self, state: State, action: jp.ndarray) -> State: + state = self.env.step(state, action) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) + + +class OffsetRewardWrapper(Wrapper): + """Wraps gym environments to offset the reward to be greater than 0. + + Utilisation is simple: create an environment with Brax, pass + it to the wrapper with the name of the environment, and it will + work like before and will simply clip the reward to be greater than 0. + """ + + def __init__(self, env: Env, offset: float = 0.0) -> None: + super().__init__(env) + self._offset = offset + + def reset(self, rng: jp.ndarray) -> State: + state = self.env.reset(rng) + return state.replace(reward=state.reward + self._offset) + + def step(self, state: State, action: jp.ndarray) -> State: + state = self.env.step(state, action) + return state.replace(reward=state.reward + self._offset) diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 931ee9d3..1a588e52 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -10,7 +10,7 @@ import qdax.environments from qdax import environments from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition -from qdax.core.neuroevolution.mdp_utils import generate_unroll +from qdax.core.neuroevolution.mdp_utils import generate_unroll, generate_unroll_actor_dc from qdax.core.neuroevolution.networks.networks import MLP from qdax.types import ( Descriptor, @@ -160,6 +160,81 @@ def scoring_function_brax_envs( ) +@partial( + jax.jit, + static_argnames=( + "episode_length", + "play_step_actor_dc_fn", + "behavior_descriptor_extractor", + ), +) +def scoring_actor_dc_function_brax_envs( + actors_dc_params: Genotype, + descs: Descriptor, + random_key: RNGKey, + init_states: EnvState, + episode_length: int, + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], + Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition], + ], + behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """Evaluates policies contained in policy_dc_params in parallel in + deterministic or pseudo-deterministic environments. + + This rollout is only deterministic when all the init states are the same. + If the init states are fixed but different, as a policy is not necessarily + evaluated with the same environment everytime, this won't be determinist. + When the init states are different, this is not purely stochastic. + + Args: + policy_dc_params: The parameters of closed-loop + descriptor-conditioned policy to evaluate. + descriptors: The descriptors the + descriptor-conditioned policy attempts to achieve. + random_key: A jax random key + episode_length: The maximal rollout length. + play_step_fn: The function to play a step of the environment. + behavior_descriptor_extractor: The function to extract the behavior descriptor. 
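# Usage sketch for the reward wrappers defined above: clip or offset the reward
# of a Brax v1 environment so fitnesses stay positive. The environment name and
# bounds are illustrative; environments.create and reward_offset are used as in
# the QDax tests further below.
from qdax import environments
from qdax.environments.wrappers import ClipRewardWrapper, OffsetRewardWrapper

env = environments.create("walker2d_uni", episode_length=250)
clipped_env = ClipRewardWrapper(env, clip_min=0.0)
offset_env = OffsetRewardWrapper(env, offset=environments.reward_offset["walker2d_uni"])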
+ + Returns: + fitness: Array of fitnesses of all evaluated policies + descriptor: Behavioural descriptors of all evaluated policies + extra_scores: Additional information resulting from evaluation + random_key: The updated random key. + """ + + # Perform rollouts with each policy + random_key, subkey = jax.random.split(random_key) + unroll_fn = partial( + generate_unroll_actor_dc, + episode_length=episode_length, + play_step_actor_dc_fn=play_step_actor_dc_fn, + random_key=subkey, + ) + + _final_state, data = jax.vmap(unroll_fn)(init_states, actors_dc_params, descs) + + # create a mask to extract data properly + is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) + mask = jnp.roll(is_done, 1, axis=1) + mask = mask.at[:, 0].set(0) + + # Scores - add offset to ensure positive fitness (through positive rewards) + fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) + descriptors = behavior_descriptor_extractor(data, mask) + + return ( + fitnesses, + descriptors, + { + "transitions": data, + }, + random_key, + ) + + @partial( jax.jit, static_argnames=( @@ -225,6 +300,83 @@ def reset_based_scoring_function_brax_envs( return fitnesses, descriptors, extra_scores, random_key +@partial( + jax.jit, + static_argnames=( + "episode_length", + "play_reset_fn", + "play_step_actor_dc_fn", + "behavior_descriptor_extractor", + ), +) +def reset_based_scoring_actor_dc_function_brax_envs( + actors_dc_params: Genotype, + descs: Descriptor, + random_key: RNGKey, + episode_length: int, + play_reset_fn: Callable[[RNGKey], EnvState], + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], + Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition], + ], + behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """Evaluates policies contained in policy_dc_params in parallel. + The play_reset_fn function allows for a more general scoring_function that can be + called with different batch-size and not only with a batch-size of the same + dimension as init_states. + + To define purely stochastic environments, using the reset function from the + environment, use "play_reset_fn = env.reset". + + To define purely deterministic environments, as in "scoring_function", generate + a single init_state using "init_state = env.reset(random_key)", then use + "play_reset_fn = lambda random_key: init_state". + + Args: + policy_dc_params: The parameters of closed-loop + descriptor-conditioned policy to evaluate. + descriptors: The descriptors the + descriptor-conditioned policy attempts to achieve. + random_key: A jax random key + episode_length: The maximal rollout length. + play_reset_fn: The function to reset the environment + and obtain initial states. + play_step_fn: The function to play a step of the environment. + behavior_descriptor_extractor: The function to extract the behavior descriptor. + + Returns: + fitness: Array of fitnesses of all evaluated policies + descriptor: Behavioural descriptors of all evaluated policies + extra_scores: Additional information resulting from the evaluation + random_key: The updated random key. 
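# Worked example of the done-mask trick used in the scoring function above:
# every step after the first done flag is masked out of the fitness sum (the
# step that produced the done is kept). Values are illustrative.
import jax.numpy as jnp

dones = jnp.array([[0.0, 0.0, 1.0, 0.0, 0.0]])         # one rollout of length 5
is_done = jnp.clip(jnp.cumsum(dones, axis=1), 0, 1)    # [[0, 0, 1, 1, 1]]
mask = jnp.roll(is_done, 1, axis=1)
mask = mask.at[:, 0].set(0)                            # [[0, 0, 0, 1, 1]]
rewards = jnp.ones_like(dones)
fitness = jnp.sum(rewards * (1.0 - mask), axis=1)      # 3.0: post-done steps ignored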
+ """ + + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split( + subkey, jax.tree_util.tree_leaves(actors_dc_params)[0].shape[0] + ) + reset_fn = jax.vmap(play_reset_fn) + init_states = reset_fn(keys) + + ( + fitnesses, + descriptors, + extra_scores, + random_key, + ) = scoring_actor_dc_function_brax_envs( + actors_dc_params=actors_dc_params, + descs=descs, + random_key=random_key, + init_states=init_states, + episode_length=episode_length, + play_step_actor_dc_fn=play_step_actor_dc_fn, + behavior_descriptor_extractor=behavior_descriptor_extractor, + ) + + return fitnesses, descriptors, extra_scores, random_key + + def create_brax_scoring_fn( env: brax.envs.Env, policy_network: nn.Module, diff --git a/requirements.txt b/requirements.txt index 978a1c87..50d7899c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ +--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + absl-py==1.0.0 brax==0.9.2 chex==0.1.83 @@ -5,8 +7,7 @@ dm-haiku==0.0.10 flax==0.7.4 gym==0.26.2 ipython -jax==0.4.16 -jaxlib==0.4.16 +jax[cuda12_pip] jumanji==0.3.1 jupyter numpy==1.24.1 diff --git a/tests/baselines_test/ga_test.py b/tests/baselines_test/ga_test.py index 4e11370b..5f9ec5f7 100644 --- a/tests/baselines_test/ga_test.py +++ b/tests/baselines_test/ga_test.py @@ -73,7 +73,7 @@ def scoring_fn( # initial population random_key = jax.random.PRNGKey(42) random_key, subkey = jax.random.split(random_key) - init_genotypes = jax.random.uniform( + genotypes = jax.random.uniform( subkey, (batch_size, genotype_dim), minval=minval, @@ -111,11 +111,11 @@ def scoring_fn( if isinstance(algo_instance, SPEA2): repertoire, emitter_state, random_key = algo_instance.init( - init_genotypes, population_size, num_neighbours, random_key + genotypes, population_size, num_neighbours, random_key ) else: repertoire, emitter_state, random_key = algo_instance.init( - init_genotypes, population_size, random_key + genotypes, population_size, random_key ) # Run the algorithm diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index 079fde45..98a5b960 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -126,7 +126,7 @@ def scoring_function(genotypes, random_key): # type: ignore lambda x: jnp.repeat(x, population_size, axis=0), first_states ) population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) - return population_returns, population_bds, None, random_key + return population_returns, population_bds, {}, random_key # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -178,7 +178,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys = map_elites.get_distributed_init_fn( devices=devices, centroids=centroids )( - init_genotypes=training_states, random_key=keys + genotypes=training_states, random_key=keys ) # type: ignore update_fn = map_elites.get_distributed_update_fn(num_iterations=1, devices=devices) diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index 5c6fbb0a..fc2e89b0 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -124,7 +124,7 @@ def scoring_function(genotypes, random_key): # type: ignore lambda x: jnp.repeat(x, population_size, axis=0), first_states ) population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) - return 
population_returns, population_bds, None, random_key + return population_returns, population_bds, {}, random_key # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -176,7 +176,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys = map_elites.get_distributed_init_fn( devices=devices, centroids=centroids )( - init_genotypes=training_states, random_key=keys + genotypes=training_states, random_key=keys ) # type: ignore update_fn = map_elites.get_distributed_update_fn(num_iterations=1, devices=devices) diff --git a/tests/core_test/mome_test.py b/tests/core_test/mome_test.py index c70683ef..103f9489 100644 --- a/tests/core_test/mome_test.py +++ b/tests/core_test/mome_test.py @@ -81,7 +81,7 @@ def scoring_fn( # initial population random_key = jax.random.PRNGKey(42) random_key, subkey = jax.random.split(random_key) - init_genotypes = jax.random.uniform( + genotypes = jax.random.uniform( subkey, (batch_size, num_variables), minval=minval, @@ -127,7 +127,7 @@ def scoring_fn( ) repertoire, emitter_state, random_key = mome.init( - init_genotypes, centroids, pareto_front_max_length, random_key + genotypes, centroids, pareto_front_max_length, random_key ) # Run the algorithm
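# Sketch of the scoring-function contract the updated PBT tests above rely on:
# the extra_scores element must now be a dict (possibly empty), never None.
# The returned values are placeholders.
import jax.numpy as jnp

def scoring_function(genotypes, random_key):
    population_returns = jnp.zeros((genotypes.shape[0],))
    population_bds = jnp.zeros((genotypes.shape[0], 2))
    return population_returns, population_bds, {}, random_key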