From c6d33a6a9fc7fc409742ef899a284f66abe9b6b2 Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Thu, 21 Dec 2023 09:32:47 +0000 Subject: [PATCH 01/11] DCG-MAP-Elites --- qdax/core/containers/mapelites_repertoire.py | 28 + qdax/core/emitters/cma_emitter.py | 12 +- qdax/core/emitters/cma_mega_emitter.py | 14 +- qdax/core/emitters/cma_pool_emitter.py | 14 +- qdax/core/emitters/dcg_me_emitter.py | 88 +++ qdax/core/emitters/dpg_emitter.py | 16 +- qdax/core/emitters/emitter.py | 8 +- qdax/core/emitters/mees_emitter.py | 24 +- qdax/core/emitters/multi_emitter.py | 18 +- qdax/core/emitters/omg_mega_emitter.py | 14 +- qdax/core/emitters/pbt_me_emitter.py | 18 +- qdax/core/emitters/qdcg_emitter.py | 658 ++++++++++++++++++ qdax/core/emitters/qpg_emitter.py | 22 +- qdax/core/emitters/standard_emitters.py | 2 +- qdax/core/map_elites.py | 26 +- qdax/core/neuroevolution/buffers/buffer.py | 148 +++- qdax/core/neuroevolution/losses/td3_loss.py | 93 ++- qdax/core/neuroevolution/mdp_utils.py | 50 +- qdax/core/neuroevolution/networks/networks.py | 100 ++- qdax/environments/wrappers.py | 64 +- qdax/tasks/brax_envs.py | 195 ++++-- 21 files changed, 1471 insertions(+), 141 deletions(-) create mode 100644 qdax/core/emitters/dcg_me_emitter.py create mode 100644 qdax/core/emitters/qdcg_emitter.py diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index aed74c78..5968b03f 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -228,6 +228,34 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey return samples, random_key + @partial(jax.jit, static_argnames=("num_samples",)) + def sample_with_descs(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]: + """Sample elements in the repertoire. + + Args: + random_key: a jax PRNG random key + num_samples: the number of elements to be sampled + + Returns: + samples: a batch of genotypes sampled in the repertoire + random_key: an updated jax PRNG random key + """ + + repertoire_empty = self.fitnesses == -jnp.inf + p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) + + random_key, subkey = jax.random.split(random_key) + samples = jax.tree_util.tree_map( + lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + self.genotypes, + ) + descs = jax.tree_util.tree_map( + lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + self.descriptors, + ) + + return samples, descs, random_key + @jax.jit def add( self, diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py index f9d58caa..c090e448 100644 --- a/qdax/core/emitters/cma_emitter.py +++ b/qdax/core/emitters/cma_emitter.py @@ -99,14 +99,20 @@ def batch_size(self) -> int: @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMAEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. 
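Note on the new MapElitesRepertoire.sample_with_descs above: it returns a triple (samples, descs, random_key), and the same subkey is reused for both jax.random.choice calls so each sampled genotype stays paired with its stored descriptor. A self-contained toy sketch of that mechanism in plain JAX (hypothetical data, not repertoire code):

import jax
import jax.numpy as jnp

# Hypothetical 5-cell repertoire: genotypes, descriptors and fitnesses per cell.
genotypes = jnp.arange(5.0).reshape(5, 1)
descriptors = jnp.stack([jnp.arange(5.0), jnp.arange(5.0)], axis=1)
fitnesses = jnp.array([1.0, -jnp.inf, 2.0, 3.0, -jnp.inf])

# Sample only from filled cells, as in the repertoire implementation.
empty = fitnesses == -jnp.inf
p = (1.0 - empty) / jnp.sum(1.0 - empty)

random_key, subkey = jax.random.split(jax.random.PRNGKey(0))
samples = jax.random.choice(subkey, genotypes, shape=(3,), p=p)
descs = jax.random.choice(subkey, descriptors, shape=(3,), p=p)

# Identical subkey and probabilities give identical cell indices, so pairs stay aligned.
assert jnp.array_equal(samples[:, 0], descs[:, 0])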
Returns: @@ -154,7 +160,7 @@ def emit( cmaes_state=emitter_state.cmaes_state, random_key=random_key ) - return offsprings, random_key + return offsprings, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py index f63654fd..f79579dd 100644 --- a/qdax/core/emitters/cma_mega_emitter.py +++ b/qdax/core/emitters/cma_mega_emitter.py @@ -100,14 +100,20 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMAMEGAState, RNGKey]: """ Initializes the CMA-MEGA emitter. Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. Returns: @@ -117,7 +123,7 @@ def init( # define init theta as 0 theta = jax.tree_util.tree_map( lambda x: jnp.zeros_like(x[:1, ...]), - init_genotypes, + genotypes, ) # score it @@ -181,7 +187,7 @@ def emit( # Compute new candidates new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad) - return new_thetas, random_key + return new_thetas, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py index d5424a01..67034d71 100644 --- a/qdax/core/emitters/cma_pool_emitter.py +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -49,14 +49,20 @@ def batch_size(self) -> int: @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMAPoolEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. Returns: @@ -67,7 +73,7 @@ def scan_emitter_init( carry: RNGKey, unused: Any ) -> Tuple[RNGKey, CMAEmitterState]: random_key = carry - emitter_state, random_key = self._emitter.init(init_genotypes, random_key) + emitter_state, random_key = self._emitter.init(genotypes, random_key) return random_key, emitter_state # init all the emitter states @@ -115,7 +121,7 @@ def emit( repertoire, used_emitter_state, random_key ) - return offsprings, random_key + return offsprings, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/dcg_me_emitter.py b/qdax/core/emitters/dcg_me_emitter.py new file mode 100644 index 00000000..b9ae628b --- /dev/null +++ b/qdax/core/emitters/dcg_me_emitter.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass +from typing import Callable, Tuple + +import flax.linen as nn + +from qdax.core.emitters.multi_emitter import MultiEmitter +from qdax.core.emitters.qdcg_emitter import QualityDCGConfig, QualityDCGEmitter +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.environments.base_wrappers import QDEnv +from qdax.types import Params, RNGKey + + +@dataclass +class DCGMEConfig: + """Configuration for DCGME Algorithm""" + + ga_batch_size: int = 128 + qpg_batch_size: int = 64 + ai_batch_size: int = 64 + lengthscale: float = 0.1 + + # PG emitter + critic_hidden_layer_size: Tuple[int, ...] 
= (256, 256) + num_critic_training_steps: int = 3000 + num_pg_training_steps: int = 150 + batch_size: int = 100 + replay_buffer_size: int = 1_000_000 + discount: float = 0.99 + reward_scaling: float = 1.0 + critic_learning_rate: float = 3e-4 + actor_learning_rate: float = 3e-4 + policy_learning_rate: float = 1e-3 + noise_clip: float = 0.5 + policy_noise: float = 0.2 + soft_tau_update: float = 0.005 + policy_delay: int = 2 + + +class DCGMEEmitter(MultiEmitter): + def __init__( + self, + config: DCGMEConfig, + policy_network: nn.Module, + actor_network: nn.Module, + env: QDEnv, + variation_fn: Callable[[Params, Params, RNGKey], Tuple[Params, RNGKey]], + ) -> None: + self._config = config + self._env = env + self._variation_fn = variation_fn + + qdcg_config = QualityDCGConfig( + qpg_batch_size=config.qpg_batch_size, + ai_batch_size=config.ai_batch_size, + lengthscale=config.lengthscale, + critic_hidden_layer_size=config.critic_hidden_layer_size, + num_critic_training_steps=config.num_critic_training_steps, + num_pg_training_steps=config.num_pg_training_steps, + batch_size=config.batch_size, + replay_buffer_size=config.replay_buffer_size, + discount=config.discount, + reward_scaling=config.reward_scaling, + critic_learning_rate=config.critic_learning_rate, + actor_learning_rate=config.actor_learning_rate, + policy_learning_rate=config.policy_learning_rate, + noise_clip=config.noise_clip, + policy_noise=config.policy_noise, + soft_tau_update=config.soft_tau_update, + policy_delay=config.policy_delay, + ) + + # define the quality emitter + q_emitter = QualityDCGEmitter( + config=qdcg_config, + policy_network=policy_network, + actor_network=actor_network, + env=env + ) + + # define the GA emitter + ga_emitter = MixingEmitter( + mutation_fn=lambda x, r: (x, r), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=config.ga_batch_size, + ) + + super().__init__(emitters=(q_emitter, ga_emitter)) diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index 8b858db4..c266be10 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -10,7 +10,7 @@ import optax from qdax.core.containers.archive import Archive -from qdax.core.containers.repertoire import Repertoire +from qdax.core.containers.repertoire import MapElitesRepertoire from qdax.core.emitters.qpg_emitter import ( QualityPGConfig, QualityPGEmitter, @@ -77,12 +77,18 @@ def __init__( self._score_novelty = score_novelty def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[DiversityPGEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. 
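For context on how the DCGMEConfig/DCGMEEmitter added in qdax/core/emitters/dcg_me_emitter.py above fit together, here is a wiring sketch under stated assumptions: `env` is an existing QDax QDEnv for a Brax task (not created here), `isoline_variation` refers to QDax's iso+line GA operator, and the layer sizes are illustrative.

import functools

import jax.numpy as jnp

from qdax.core.emitters.dcg_me_emitter import DCGMEConfig, DCGMEEmitter
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.neuroevolution.networks.networks import MLP, MLPDC

# `env` is assumed to be a QDEnv created elsewhere (e.g. an omnidirectional Brax task).
policy_layer_sizes = (128, 128) + (env.action_size,)
policy_network = MLP(layer_sizes=policy_layer_sizes, final_activation=jnp.tanh)
actor_dc_network = MLPDC(layer_sizes=policy_layer_sizes, final_activation=jnp.tanh)

dcg_me_emitter = DCGMEEmitter(
    config=DCGMEConfig(ga_batch_size=128, qpg_batch_size=64, ai_batch_size=64),
    policy_network=policy_network,
    actor_network=actor_dc_network,
    env=env,
    variation_fn=functools.partial(isoline_variation, iso_sigma=0.005, line_sigma=0.05),
)
# Offspring per generation = ga + qpg + ai batch sizes (128 + 64 + 64 = 256).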
Returns: @@ -90,7 +96,7 @@ def init( """ # init elements of diversity emitter state with QualityEmitterState.init() - diversity_emitter_state, random_key = super().init(init_genotypes, random_key) + diversity_emitter_state, random_key = super().init(genotypes, random_key) # store elements in a dictionary attributes_dict = vars(diversity_emitter_state) @@ -116,7 +122,7 @@ def init( def state_update( self, emitter_state: DiversityPGEmitterState, - repertoire: Optional[Repertoire], + repertoire: Optional[MapElitesRepertoire], genotypes: Optional[Genotype], fitnesses: Optional[Fitness], descriptors: Optional[Descriptor], diff --git a/qdax/core/emitters/emitter.py b/qdax/core/emitters/emitter.py index d32ed981..14c6277a 100644 --- a/qdax/core/emitters/emitter.py +++ b/qdax/core/emitters/emitter.py @@ -30,7 +30,13 @@ class EmitterState(PyTreeNode): class Emitter(ABC): def init( - self, init_genotypes: Optional[Genotype], random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[Optional[EmitterState], RNGKey]: """Initialises the state of the emitter. Some emitters do not need a state, in which case, the value None can be diff --git a/qdax/core/emitters/mees_emitter.py b/qdax/core/emitters/mees_emitter.py index b5bb1ada..9641a613 100644 --- a/qdax/core/emitters/mees_emitter.py +++ b/qdax/core/emitters/mees_emitter.py @@ -236,26 +236,32 @@ def batch_size(self) -> int: static_argnames=("self",), ) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[MEESEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. Returns: The initial state of the MEESEmitter, a new random key. 
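For third-party emitters, the interface change in qdax/core/emitters/emitter.py above (init now receives the scored initial population and the freshly built repertoire; emit returns an extra-info dict) looks like this in practice. A hypothetical, minimal emitter sketched against the new signatures; the class and its behaviour are illustrative only, not part of this patch:

import jax

from qdax.core.emitters.emitter import Emitter


class RandomNoiseEmitter(Emitter):
    """Hypothetical sketch: sample elites and perturb them with Gaussian noise."""

    def __init__(self, batch_size: int, sigma: float = 0.01) -> None:
        self._batch_size = batch_size
        self._sigma = sigma

    @property
    def batch_size(self) -> int:
        return self._batch_size

    def init(self, random_key, repertoire, genotypes, fitnesses, descriptors, extra_scores):
        # Stateless emitter: nothing to keep besides the key handed back to the caller.
        return None, random_key

    def emit(self, repertoire, emitter_state, random_key):
        samples, random_key = repertoire.sample(random_key, self._batch_size)
        random_key, subkey = jax.random.split(random_key)
        # One shared subkey across leaves is enough for a sketch.
        offspring = jax.tree_util.tree_map(
            lambda x: x + self._sigma * jax.random.normal(subkey, x.shape), samples
        )
        # New contract: genotypes, extra-info dict, key.
        return offspring, {}, random_key

    def state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores):
        return emitter_state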
""" # Initialisation requires one initial genotype - if jax.tree_util.tree_leaves(init_genotypes)[0].shape[0] > 1: - init_genotypes = jax.tree_util.tree_map( + if jax.tree_util.tree_leaves(genotypes)[0].shape[0] > 1: + genotypes = jax.tree_util.tree_map( lambda x: x[0], - init_genotypes, + genotypes, ) # Initialise optimizer - initial_optimizer_state = self._optimizer.init(init_genotypes) + initial_optimizer_state = self._optimizer.init(genotypes) # Create empty Novelty archive if self._config.use_explore: @@ -270,7 +276,7 @@ def init( # Create empty updated genotypes and fitness last_updated_genotypes = jax.tree_util.tree_map( lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]), - init_genotypes, + genotypes, ) last_updated_fitnesses = -jnp.inf * jnp.ones( shape=self._config.last_updated_size @@ -280,7 +286,7 @@ def init( MEESEmitterState( initial_optimizer_state=initial_optimizer_state, optimizer_state=initial_optimizer_state, - offspring=init_genotypes, + offspring=genotypes, generation_count=0, novelty_archive=novelty_archive, last_updated_genotypes=last_updated_genotypes, @@ -313,7 +319,7 @@ def emit( a new jax PRNG key """ - return emitter_state.offspring, random_key + return emitter_state.offspring, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/multi_emitter.py b/qdax/core/emitters/multi_emitter.py index 2da46639..a0789f33 100644 --- a/qdax/core/emitters/multi_emitter.py +++ b/qdax/core/emitters/multi_emitter.py @@ -56,13 +56,19 @@ def get_indexes_separation_batches( return tuple(indexes_separation_batches) def init( - self, init_genotypes: Optional[Genotype], random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[Optional[EmitterState], RNGKey]: """ Initialize the state of the emitter. Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. 
Returns: @@ -76,7 +82,7 @@ def init( # init all emitter states - gather them emitter_states = [] for emitter, subkey_emitter in zip(self.emitters, subkeys): - emitter_state, _ = emitter.init(init_genotypes, subkey_emitter) + emitter_state, _ = emitter.init(subkey_emitter, repertoire, genotypes, fitnesses, descriptors, extra_scores) emitter_states.append(emitter_state) return MultiEmitterState(tuple(emitter_states)), random_key @@ -108,21 +114,23 @@ def emit( # emit from all emitters and gather offsprings all_offsprings = [] + all_extra_info = {} for emitter, sub_emitter_state, subkey_emitter in zip( self.emitters, emitter_state.emitter_states, subkeys, ): - genotype, _ = emitter.emit(repertoire, sub_emitter_state, subkey_emitter) + genotype, extra_info, _ = emitter.emit(repertoire, sub_emitter_state, subkey_emitter) batch_size = jax.tree_util.tree_leaves(genotype)[0].shape[0] assert batch_size == emitter.batch_size all_offsprings.append(genotype) + all_extra_info = all_extra_info | extra_info # concatenate offsprings together offsprings = jax.tree_util.tree_map( lambda *x: jnp.concatenate(x, axis=0), *all_offsprings ) - return offsprings, random_key + return offsprings, all_extra_info, random_key @partial(jax.jit, static_argnames=("self",)) def state_update( diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index 7336750d..2380a85c 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -84,20 +84,26 @@ def __init__( self._num_descriptors = num_descriptors def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[OMGMEGAEmitterState, RNGKey]: """Initialises the state of the emitter. Creates an empty repertoire that will later contain the gradients of the individuals. Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. Returns: The initial emitter state. """ # retrieve one genotype from the population - first_genotype = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) + first_genotype = jax.tree_util.tree_map(lambda x: x[0], genotypes) # add a dimension of size num descriptors + 1 gradient_genotype = jax.tree_util.tree_map( @@ -190,7 +196,7 @@ def emit( lambda x, y: x + y, genotypes, update_grad ) - return new_genotypes, random_key + return new_genotypes, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/pbt_me_emitter.py b/qdax/core/emitters/pbt_me_emitter.py index 3fdb4418..64c05f16 100644 --- a/qdax/core/emitters/pbt_me_emitter.py +++ b/qdax/core/emitters/pbt_me_emitter.py @@ -91,12 +91,18 @@ def __init__( ) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[PBTEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. 
Returns: @@ -145,13 +151,13 @@ def init( # Create emitter state # keep only pg population size training states if more are provided - init_genotypes = jax.tree_util.tree_map( - lambda x: x[: self._config.pg_population_size_per_device], init_genotypes + genotypes = jax.tree_util.tree_map( + lambda x: x[: self._config.pg_population_size_per_device], genotypes ) emitter_state = PBTEmitterState( replay_buffers=replay_buffers, env_states=env_states, - training_states=init_genotypes, + training_states=genotypes, random_key=subkey2, ) @@ -199,7 +205,7 @@ def emit( else: genotypes = x_mutation_pg - return genotypes, random_key + return genotypes, {}, random_key @property def batch_size(self) -> int: diff --git a/qdax/core/emitters/qdcg_emitter.py b/qdax/core/emitters/qdcg_emitter.py new file mode 100644 index 00000000..745773bd --- /dev/null +++ b/qdax/core/emitters/qdcg_emitter.py @@ -0,0 +1,658 @@ +"""Implements the PG Emitter and Actor Injection from DCG-ME algorithm in JAX for Brax environments. +""" + +from dataclasses import dataclass +from functools import partial +from typing import Any, Optional, Tuple, Callable + +import jax +from jax import numpy as jnp +import flax.linen as nn +from flax.core.frozen_dict import freeze +import optax + +from qdax.core.containers.repertoire import Repertoire +from qdax.core.emitters.emitter import Emitter, EmitterState +from qdax.core.neuroevolution.buffers.buffer import DCGTransition, ReplayBuffer +from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_dc_fn +from qdax.core.neuroevolution.networks.networks import QModuleDC +from qdax.environments.base_wrappers import QDEnv +from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey + + +@dataclass +class QualityDCGConfig: + """Configuration for QualityDCG Emitter""" + + qpg_batch_size: int = 64 + ai_batch_size: int = 64 + lengthscale: float = 0.1 + + critic_hidden_layer_size: Tuple[int, ...] = (256, 256) + num_critic_training_steps: int = 3000 + num_pg_training_steps: int = 150 + batch_size: int = 100 + replay_buffer_size: int = 1_000_000 + discount: float = 0.99 + reward_scaling: float = 1.0 + critic_learning_rate: float = 3e-4 + actor_learning_rate: float = 3e-4 + policy_learning_rate: float = 1e-3 + noise_clip: float = 0.5 + policy_noise: float = 0.2 + soft_tau_update: float = 0.005 + policy_delay: int = 2 + + +class QualityDCGEmitterState(EmitterState): + """Contains training state for the learner.""" + + critic_params: Params + critic_opt_state: optax.OptState + actor_params: Params + actor_opt_state: optax.OptState + target_critic_params: Params + target_actor_params: Params + replay_buffer: ReplayBuffer + random_key: RNGKey + steps: jnp.ndarray + + +class QualityDCGEmitter(Emitter): + """ + A policy gradient emitter used to implement the Policy Gradient Assisted MAP-Elites + (PGA-Map-Elites) algorithm. 
+ """ + + def __init__( + self, + config: QualityDCGConfig, + policy_network: nn.Module, + actor_network: nn.Module, + env: QDEnv, + ) -> None: + self._config = config + self._env = env + self._policy_network = policy_network + self._actor_network = actor_network + + # Init Critics + critic_network = QModuleDC( + n_critics=2, hidden_layer_sizes=self._config.critic_hidden_layer_size + ) + self._critic_network = critic_network + + # Set up the losses and optimizers - return the opt states + self._policy_loss_fn, self._actor_loss_fn, self._critic_loss_fn = make_td3_loss_dc_fn( + policy_fn=policy_network.apply, + actor_fn=actor_network.apply, + critic_fn=critic_network.apply, + reward_scaling=self._config.reward_scaling, + discount=self._config.discount, + noise_clip=self._config.noise_clip, + policy_noise=self._config.policy_noise, + ) + + # Init optimizers + self._actor_optimizer = optax.adam( + learning_rate=self._config.actor_learning_rate + ) + self._critic_optimizer = optax.adam( + learning_rate=self._config.critic_learning_rate + ) + self._policies_optimizer = optax.adam( + learning_rate=self._config.policy_learning_rate + ) + + @property + def batch_size(self) -> int: + """ + Returns: + the batch size emitted by the emitter. + """ + return self._config.qpg_batch_size + self._config.ai_batch_size + + @property + def use_all_data(self) -> bool: + """Whether to use all data or not when used along other emitters. + + QualityPGEmitter uses the transitions from the genotypes that were generated + by other emitters. + """ + return True + + def init( + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, + ) -> Tuple[QualityDCGEmitterState, RNGKey]: + """Initializes the emitter state. + + Args: + genotypes: The initial population. + random_key: A random key. + + Returns: + The initial state of the PGAMEEmitter, a new random key. 
+ """ + + observation_size = jax.tree_util.tree_leaves(genotypes)[1].shape[1] + descriptor_size = self._env.behavior_descriptor_length + action_size = self._env.action_size + + # Initialise critic, greedy actor and population + random_key, subkey = jax.random.split(random_key) + fake_obs = jnp.zeros(shape=(observation_size,)) + fake_desc = jnp.zeros(shape=(descriptor_size,)) + fake_action = jnp.zeros(shape=(action_size,)) + + critic_params = self._critic_network.init( + subkey, obs=fake_obs, actions=fake_action, desc=fake_desc + ) + target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) + + random_key, subkey = jax.random.split(random_key) + actor_params = self._actor_network.init( + subkey, obs=fake_obs, desc=fake_desc) + target_actor_params = jax.tree_util.tree_map(lambda x: x, actor_params) + + # Prepare init optimizer states + critic_opt_state = self._critic_optimizer.init(critic_params) + actor_opt_state = self._actor_optimizer.init(actor_params) + + # Initialize replay buffer + dummy_transition = DCGTransition.init_dummy( + observation_dim=self._env.observation_size, + action_dim=action_size, + descriptor_dim=descriptor_size, + ) + + replay_buffer = ReplayBuffer.init( + buffer_size=self._config.replay_buffer_size, transition=dummy_transition + ) + + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + episode_length = transitions.obs.shape[1] + + desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) + desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) + + transitions = transitions.replace(desc=desc_normalized, desc_prime=desc_normalized) + replay_buffer = replay_buffer.insert(transitions) + + # Initial training state + random_key, subkey = jax.random.split(random_key) + emitter_state = QualityDCGEmitterState( + critic_params=critic_params, + critic_opt_state=critic_opt_state, + actor_params=actor_params, + actor_opt_state=actor_opt_state, + target_critic_params=target_critic_params, + target_actor_params=target_actor_params, + replay_buffer=replay_buffer, + random_key=subkey, + steps=jnp.array(0), + ) + + return emitter_state, random_key + + @partial(jax.jit, static_argnames=("self",)) + def _similarity(self, descs_1, descs_2): + """Compute the similarity between two batches of descriptors. + Args: + descs_1: batch of descriptors, representing the observed descriptors of the trajectories. + descs_2: batch of descriptors, representing the sampled descriptors. + Returns: + batch of similarity measures. + """ + return jnp.exp(-jnp.linalg.norm(descs_1 - descs_2, axis=-1)/self._config.lengthscale) + + @partial(jax.jit, static_argnames=("self",)) + def _normalize_desc(self, desc): + return 2*(desc - self._env.behavior_descriptor_limits[0])/(self._env.behavior_descriptor_limits[1] - self._env.behavior_descriptor_limits[0]) - 1 + + @partial(jax.jit, static_argnames=("self",)) + def _unnormalize_desc(self, desc_normalized): + return 0.5 * (self._env.behavior_descriptor_limits[1] - self._env.behavior_descriptor_limits[0]) * desc_normalized + \ + 0.5 * (self._env.behavior_descriptor_limits[1] + self._env.behavior_descriptor_limits[0]) + + @partial(jax.jit, static_argnames=("self",)) + def _compute_equivalent_kernel_bias_with_desc(self, actor_dc_params, desc): + """ + Compute the equivalent bias of the first layer of the actor network + given a descriptor. 
+ """ + # Extract kernel and bias of the first layer + kernel = actor_dc_params["params"]["Dense_0"]["kernel"] + bias = actor_dc_params["params"]["Dense_0"]["bias"] + + # Compute the equivalent bias + equivalent_kernel = kernel[:-desc.shape[0], :] + equivalent_bias = bias + jnp.dot(desc, kernel[-desc.shape[0]:]) + + return equivalent_kernel, equivalent_bias + + @partial(jax.jit, static_argnames=("self",)) + def _compute_equivalent_params_with_desc(self, actor_dc_params, desc): + desc_normalized = self._normalize_desc(desc) + equivalent_kernel, equivalent_bias = self._compute_equivalent_kernel_bias_with_desc(actor_dc_params, desc_normalized) + actor_dc_params["params"]["Dense_0"]["kernel"] = equivalent_kernel + actor_dc_params["params"]["Dense_0"]["bias"] = equivalent_bias + return actor_dc_params + + @partial(jax.jit, static_argnames=("self",),) + def emit( + self, + repertoire: Repertoire, + emitter_state: QualityDCGEmitterState, + random_key: RNGKey, + ) -> Tuple[Genotype, RNGKey]: + """Do a step of PG emission. + + Args: + repertoire: the current repertoire of genotypes + emitter_state: the state of the emitter used + random_key: a random key + + Returns: + A batch of offspring, the new emitter state and a new key. + """ + # PG emitter + parents_pg, descs_pg, random_key = repertoire.sample_with_descs(random_key, self._config.qpg_batch_size) + genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg) + + # Actor injection emitter + _, descs_ai, random_key = repertoire.sample_with_descs(random_key, self._config.ai_batch_size) + descs_ai = descs_ai.reshape(descs_ai.shape[0], self._env.behavior_descriptor_length) + genotypes_ai = self.emit_ai(emitter_state, descs_ai) + + # Concatenate PG and AI genotypes + genotypes = jax.tree_util.tree_map(lambda x1, x2: jnp.concatenate((x1, x2), axis=0), genotypes_pg, genotypes_ai) + + return genotypes, {"desc_prime": jnp.concatenate([descs_pg, descs_ai], axis=0)}, random_key + + @partial(jax.jit, static_argnames=("self",),) + def emit_pg( + self, emitter_state: QualityDCGEmitterState, parents: Genotype, descs: Descriptor) -> Genotype: + """Emit the offsprings generated through pg mutation. + + Args: + emitter_state: current emitter state, contains critic and + replay buffer. + parents: the parents selected to be applied gradients in order + to mutate towards better performance. + descs: the descriptors of the parents. + + Returns: + A new set of offsprings. + """ + mutation_fn = partial( + self._mutation_function_pg, + emitter_state=emitter_state, + ) + offsprings = jax.vmap(mutation_fn)(parents, descs) + + return offsprings + + @partial(jax.jit, static_argnames=("self",),) + def emit_ai( + self, emitter_state: QualityDCGEmitterState, descs: Descriptor + ) -> Genotype: + """Emit the offsprings generated through pg mutation. + + Args: + emitter_state: current emitter state, contains critic and + replay buffer. + parents: the parents selected to be applied gradients in order + to mutate towards better performance. + descs: the descriptors of the parents. + + Returns: + A new set of offsprings. + """ + offsprings = jax.vmap(self._compute_equivalent_params_with_desc, in_axes=(None, 0))(emitter_state.actor_params, descs) + + return offsprings + + @partial(jax.jit, static_argnames=("self",)) + def emit_actor(self, emitter_state: QualityDCGEmitterState) -> Genotype: + """Emit the greedy actor. + + Simply needs to be retrieved from the emitter state. + + Args: + emitter_state: the current emitter state, it stores the + greedy actor. 
+ + Returns: + The parameters of the actor. + """ + return emitter_state.actor_params + + @partial(jax.jit, static_argnames=("self",),) + def state_update( + self, + emitter_state: QualityDCGEmitterState, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, + ) -> QualityDCGEmitterState: + """This function gives an opportunity to update the emitter state + after the genotypes have been scored. + + Here it is used to fill the Replay Buffer with the transitions + from the scoring of the genotypes, and then the training of the + critic/actor happens. Hence the params of critic/actor are updated, + as well as their optimizer states. + + Args: + emitter_state: current emitter state. + repertoire: the current genotypes repertoire + genotypes: unused here - but compulsory in the signature. + fitnesses: unused here - but compulsory in the signature. + descriptors: unused here - but compulsory in the signature. + extra_scores: extra information coming from the scoring function, + this contains the transitions added to the replay buffer. + + Returns: + New emitter state where the replay buffer has been filled with + the new experienced transitions. + """ + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + episode_length = transitions.obs.shape[1] + + desc_prime = jnp.concatenate([extra_scores["desc_prime"], descriptors[self._config.qpg_batch_size+self._config.ai_batch_size:]], axis=0) + desc_prime = jnp.repeat(desc_prime[:, jnp.newaxis, :], episode_length, axis=1) + desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) + + desc_prime_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc_prime) + desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) + transitions = transitions.replace(desc=desc_normalized, desc_prime=desc_prime_normalized) + + # Add transitions to replay buffer + replay_buffer = emitter_state.replay_buffer.insert(transitions) + emitter_state = emitter_state.replace(replay_buffer=replay_buffer) + + # sample transitions from the replay buffer + random_key, subkey = jax.random.split(emitter_state.random_key) + transitions, random_key = replay_buffer.sample(subkey, self._config.num_critic_training_steps*self._config.batch_size) + transitions = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (self._config.num_critic_training_steps, self._config.batch_size, *x.shape[1:])), transitions) + transitions = transitions.replace(rewards=self._similarity(transitions.desc, transitions.desc_prime)*transitions.rewards) + emitter_state = emitter_state.replace(random_key=random_key) + + def scan_train_critics( + carry: QualityDCGEmitterState, transitions + ) -> Tuple[QualityDCGEmitterState, Any]: + emitter_state = carry + new_emitter_state = self._train_critics(emitter_state, transitions) + return new_emitter_state, () + + # Train critics and greedy actor + emitter_state, _ = jax.lax.scan( + scan_train_critics, + emitter_state, + transitions, + length=self._config.num_critic_training_steps, + ) + + return emitter_state + + @partial(jax.jit, static_argnames=("self",)) + def _train_critics( + self, emitter_state: QualityDCGEmitterState, transitions + ) -> QualityDCGEmitterState: + """Apply one gradient step to critics and to the greedy actor + (contained in carry in training_state), then soft update target critics + and target actor. + + Those updates are very similar to those made in TD3. 
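The state_update above rescales rewards with _similarity over descriptors normalised by _normalize_desc. A self-contained numeric sketch of that kernel, with toy descriptor limits and the default lengthscale (plain JAX, not emitter code):

import jax.numpy as jnp

# Hypothetical 2-D descriptor space with limits [-30, 30] per axis, lengthscale 0.1.
desc_min, desc_max = jnp.array([-30.0, -30.0]), jnp.array([30.0, 30.0])
lengthscale = 0.1

def normalize_desc(desc):
    return 2.0 * (desc - desc_min) / (desc_max - desc_min) - 1.0

def similarity(d1, d2):
    return jnp.exp(-jnp.linalg.norm(d1 - d2, axis=-1) / lengthscale)

observed = normalize_desc(jnp.array([3.0, 0.0]))   # descriptor actually reached
targeted = normalize_desc(jnp.array([0.0, 0.0]))   # desc_prime the actor was asked for

# 3 raw units of error become 0.1 after normalisation, so the reward fed to the
# critic update, reward * similarity, is discounted by exp(-1), roughly 0.37.
print(similarity(observed, targeted))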
+ + Args: + emitter_state: actual emitter state + + Returns: + New emitter state where the critic and the greedy actor have been + updated. Optimizer states have also been updated in the process. + """ + # Update Critic + ( + critic_opt_state, + critic_params, + target_critic_params, + random_key, + ) = self._update_critic( + critic_params=emitter_state.critic_params, + target_critic_params=emitter_state.target_critic_params, + target_actor_params=emitter_state.target_actor_params, + critic_opt_state=emitter_state.critic_opt_state, + transitions=transitions, + random_key=emitter_state.random_key, + ) + + # Update greedy actor + (actor_opt_state, actor_params, target_actor_params,) = jax.lax.cond( + emitter_state.steps % self._config.policy_delay == 0, + lambda x: self._update_actor(*x), + lambda _: ( + emitter_state.actor_opt_state, + emitter_state.actor_params, + emitter_state.target_actor_params, + ), + operand=( + emitter_state.actor_params, + emitter_state.actor_opt_state, + emitter_state.target_actor_params, + emitter_state.critic_params, + transitions, + ), + ) + + # Create new training state + new_emitter_state = emitter_state.replace( + critic_params=critic_params, + critic_opt_state=critic_opt_state, + actor_params=actor_params, + actor_opt_state=actor_opt_state, + target_critic_params=target_critic_params, + target_actor_params=target_actor_params, + random_key=random_key, + steps=emitter_state.steps + 1, + ) + + return new_emitter_state # type: ignore + + @partial(jax.jit, static_argnames=("self",)) + def _update_critic( + self, + critic_params: Params, + target_critic_params: Params, + target_actor_params: Params, + critic_opt_state: Params, + transitions: DCGTransition, + random_key: RNGKey, + ) -> Tuple[Params, Params, Params, RNGKey]: + + # compute loss and gradients + random_key, subkey = jax.random.split(random_key) + critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)( + critic_params, + target_actor_params, + target_critic_params, + transitions, + subkey, + ) + critic_updates, critic_opt_state = self._critic_optimizer.update( + critic_gradient, critic_opt_state + ) + + # update critic + critic_params = optax.apply_updates(critic_params, critic_updates) + + # Soft update of target critic network + target_critic_params = jax.tree_map( + lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + + self._config.soft_tau_update * x2, + target_critic_params, + critic_params, + ) + + return critic_opt_state, critic_params, target_critic_params, random_key + + @partial(jax.jit, static_argnames=("self",)) + def _update_actor( + self, + actor_params: Params, + actor_opt_state: optax.OptState, + target_actor_params: Params, + critic_params: Params, + transitions: DCGTransition, + ) -> Tuple[optax.OptState, Params, Params]: + + # Update greedy actor + policy_loss, policy_gradient = jax.value_and_grad(self._actor_loss_fn)( + actor_params, + critic_params, + transitions, + ) + ( + policy_updates, + actor_opt_state, + ) = self._actor_optimizer.update(policy_gradient, actor_opt_state) + actor_params = optax.apply_updates(actor_params, policy_updates) + + # Soft update of target greedy actor + target_actor_params = jax.tree_map( + lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + + self._config.soft_tau_update * x2, + target_actor_params, + actor_params, + ) + + return ( + actor_opt_state, + actor_params, + target_actor_params, + ) + + @partial(jax.jit, static_argnames=("self",),) + def _mutation_function_pg( + self, + policy_params: Genotype, + descs: 
Descriptor, + emitter_state: QualityDCGEmitterState, + ) -> Genotype: + """Apply pg mutation to a policy via multiple steps of gradient descent. + First, update the rewards to be diversity rewards, then apply the gradient + steps. + + Args: + policy_params: a policy, supposed to be a differentiable neural + network. + emitter_state: the current state of the emitter, containing among others, + the replay buffer, the critic. + + Returns: + The updated params of the neural network. + """ + # Get transitions + transitions, random_key = emitter_state.replay_buffer.sample(emitter_state.random_key, sample_size=self._config.num_pg_training_steps*self._config.batch_size) + descs_prime = jnp.tile(descs, (self._config.num_pg_training_steps*self._config.batch_size, 1)) + descs_prime_normalized = jax.vmap(self._normalize_desc)(descs_prime) + transitions = transitions.replace(rewards=self._similarity(transitions.desc, descs_prime_normalized)*transitions.rewards, desc_prime=descs_prime_normalized) + transitions = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (self._config.num_pg_training_steps, self._config.batch_size, *x.shape[1:])), transitions) + + # Replace random_key + emitter_state = emitter_state.replace(random_key=random_key) + + # Define new policy optimizer state + policy_opt_state = self._policies_optimizer.init(policy_params) + + def scan_train_policy( + carry: Tuple[QualityDCGEmitterState, Genotype, optax.OptState], + transitions, + ) -> Tuple[Tuple[QualityDCGEmitterState, Genotype, optax.OptState], Any]: + emitter_state, policy_params, policy_opt_state = carry + ( + new_emitter_state, + new_policy_params, + new_policy_opt_state, + ) = self._train_policy( + emitter_state, + policy_params, + policy_opt_state, + transitions, + ) + return ( + new_emitter_state, + new_policy_params, + new_policy_opt_state, + ), () + + (emitter_state, policy_params, policy_opt_state,), _ = jax.lax.scan( + scan_train_policy, + (emitter_state, policy_params, policy_opt_state), + transitions, + length=self._config.num_pg_training_steps, + ) + + return policy_params + + @partial(jax.jit, static_argnames=("self",)) + def _train_policy( + self, + emitter_state: QualityDCGEmitterState, + policy_params: Params, + policy_opt_state: optax.OptState, + transitions, + ) -> Tuple[QualityDCGEmitterState, Params, optax.OptState]: + """Apply one gradient step to a policy (called policy_params). + + Args: + emitter_state: current state of the emitter. + policy_params: parameters corresponding to the weights and bias of + the neural network that defines the policy. + + Returns: + The new emitter state and new params of the NN. 
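The actor-injection path (emit_ai and _compute_equivalent_kernel_bias_with_desc above) rests on a linear identity: a first Dense layer applied to concat([obs, desc]) equals a smaller Dense layer on obs whose bias absorbs the descriptor contribution. A self-contained numerical check with toy shapes (plain JAX, hypothetical dimensions):

import jax
import jax.numpy as jnp

obs_dim, desc_dim, hidden_dim = 4, 2, 3
key_w, key_b, key_o, key_d = jax.random.split(jax.random.PRNGKey(0), 4)

kernel = jax.random.normal(key_w, (obs_dim + desc_dim, hidden_dim))
bias = jax.random.normal(key_b, (hidden_dim,))
obs = jax.random.normal(key_o, (obs_dim,))
desc = jax.random.normal(key_d, (desc_dim,))

full = jnp.concatenate([obs, desc]) @ kernel + bias

# Same split as _compute_equivalent_kernel_bias_with_desc: keep the obs rows of the
# kernel and fold the desc rows into an equivalent bias.
equivalent_kernel = kernel[:-desc_dim, :]
equivalent_bias = bias + jnp.dot(desc, kernel[-desc_dim:])
injected = obs @ equivalent_kernel + equivalent_bias

assert jnp.allclose(full, injected, atol=1e-6)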
+ """ + # update policy + policy_opt_state, policy_params = self._update_policy( + critic_params=emitter_state.critic_params, + policy_opt_state=policy_opt_state, + policy_params=policy_params, + transitions=transitions, + ) + + return emitter_state, policy_params, policy_opt_state + + @partial(jax.jit, static_argnames=("self",)) + def _update_policy( + self, + critic_params: Params, + policy_opt_state: optax.OptState, + policy_params: Params, + transitions: DCGTransition, + ) -> Tuple[optax.OptState, Params]: + + # compute loss + _policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)( + policy_params, + critic_params, + transitions, + ) + # Compute gradient and update policies + ( + policy_updates, + policy_opt_state, + ) = self._policies_optimizer.update(policy_gradient, policy_opt_state) + policy_params = optax.apply_updates(policy_params, policy_updates) + + return policy_opt_state, policy_params diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index c07e3b18..d43827df 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -12,7 +12,7 @@ import optax from jax import numpy as jnp -from qdax.core.containers.repertoire import Repertoire +from qdax.core.containers.repertoire import MapElitesRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_fn @@ -119,12 +119,18 @@ def use_all_data(self) -> bool: return True def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[QualityPGEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. 
Returns: @@ -144,8 +150,8 @@ def init( ) target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) - actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) - target_actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) + actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes) + target_actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes) # Prepare init optimizer states critic_optimizer_state = self._critic_optimizer.init(critic_params) @@ -184,7 +190,7 @@ def init( ) def emit( self, - repertoire: Repertoire, + repertoire: MapElitesRepertoire, emitter_state: QualityPGEmitterState, random_key: RNGKey, ) -> Tuple[Genotype, RNGKey]: @@ -223,7 +229,7 @@ def emit( offspring_actor, ) - return genotypes, random_key + return genotypes, {}, random_key @partial( jax.jit, @@ -273,7 +279,7 @@ def emit_actor(self, emitter_state: QualityPGEmitterState) -> Genotype: def state_update( self, emitter_state: QualityPGEmitterState, - repertoire: Optional[Repertoire], + repertoire: Optional[MapElitesRepertoire], genotypes: Optional[Genotype], fitnesses: Optional[Fitness], descriptors: Optional[Descriptor], diff --git a/qdax/core/emitters/standard_emitters.py b/qdax/core/emitters/standard_emitters.py index 8b877792..740aafa5 100644 --- a/qdax/core/emitters/standard_emitters.py +++ b/qdax/core/emitters/standard_emitters.py @@ -75,7 +75,7 @@ def emit( x_mutation, ) - return genotypes, random_key + return genotypes, {}, random_key @property def batch_size(self) -> int: diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index c71b0013..77dd437e 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -23,7 +23,7 @@ class MAPElites: """Core elements of the MAP-Elites algorithm. Note: Although very similar to the GeneticAlgorithm, we decided to keep the - MAPElites class independent of the GeneticAlgorithm class at the moment to keep + MAPElites class independant of the GeneticAlgorithm class at the moment to keep elements explicit. Args: @@ -52,7 +52,7 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, centroids: Centroid, random_key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: @@ -62,9 +62,9 @@ def init( such as CVT or Euclidean mapping. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) - centroids: tessellation centroids of shape (batch_size, num_descriptors) + centroids: tesselation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. 
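A hedged construction sketch for the MAPElites class and its updated init above; `scoring_fn`, `metrics_fn`, `dcg_me_emitter`, `init_variables`, `centroids` and `random_key` are assumed to exist elsewhere, and only the call shapes are taken from this patch:

from qdax.core.map_elites import MAPElites

map_elites = MAPElites(
    scoring_function=scoring_fn,
    emitter=dcg_me_emitter,
    metrics_function=metrics_fn,
)

# init scores the initial genotypes, builds the repertoire and passes the scored
# population straight to emitter.init (no separate state_update call anymore).
repertoire, emitter_state, random_key = map_elites.init(init_variables, centroids, random_key)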
Returns: @@ -73,12 +73,12 @@ def init( """ # score initial genotypes fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = MapElitesRepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, centroids=centroids, @@ -87,14 +87,9 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, @@ -129,9 +124,10 @@ def update( a new jax PRNG key """ # generate offsprings with the emitter - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) + # scores the offsprings fitnesses, descriptors, extra_scores, random_key = self._scoring_function( genotypes, random_key @@ -147,7 +143,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores, + extra_scores=extra_scores | extra_info, ) # update the metrics diff --git a/qdax/core/neuroevolution/buffers/buffer.py b/qdax/core/neuroevolution/buffers/buffer.py index 42ed7552..3fe777d4 100644 --- a/qdax/core/neuroevolution/buffers/buffer.py +++ b/qdax/core/neuroevolution/buffers/buffer.py @@ -7,7 +7,7 @@ import jax import jax.numpy as jnp -from qdax.types import Action, Done, Observation, Reward, RNGKey, StateDescriptor +from qdax.types import Action, Done, Observation, Reward, RNGKey, StateDescriptor, Descriptor class Transition(flax.struct.PyTreeNode): @@ -262,6 +262,152 @@ def init_dummy( # type: ignore return dummy_transition +class DCGTransition(QDTransition): + """Stores data corresponding to a transition collected by a QD algorithm.""" + + desc: Descriptor + desc_prime: Descriptor + + @property + def descriptor_dim(self) -> int: + """ + Returns: + the dimension of the descriptors. + """ + return self.state_desc.shape[-1] # type: ignore + + @property + def flatten_dim(self) -> int: + """ + Returns: + the dimension of the transition once flattened. + """ + flatten_dim = ( + 2 * self.observation_dim + + self.action_dim + + 3 + + 2 * self.state_descriptor_dim + + 2 * self.descriptor_dim + ) + return flatten_dim + + def flatten(self) -> jnp.ndarray: + """ + Returns: + a jnp.ndarray that corresponds to the flattened transition. + """ + flatten_transition = jnp.concatenate( + [ + self.obs, + self.next_obs, + jnp.expand_dims(self.rewards, axis=-1), + jnp.expand_dims(self.dones, axis=-1), + jnp.expand_dims(self.truncations, axis=-1), + self.actions, + self.state_desc, + self.next_state_desc, + self.desc, + self.desc_prime, + ], + axis=-1, + ) + return flatten_transition + + @classmethod + def from_flatten( + cls, + flattened_transition: jnp.ndarray, + transition: QDTransition, + ) -> QDTransition: + """ + Creates a transition from a flattened transition in a jnp.ndarray. 
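Iterating the update() method shown above is the outer QD loop; a hedged sketch that continues the init sketch earlier (`map_elites`, `repertoire`, `emitter_state`, `random_key` come from there, and `num_iterations` is a placeholder budget):

num_iterations = 1000  # placeholder budget

for _ in range(num_iterations):
    # update() now emits offspring plus an extra-info dict (e.g. the sampled
    # desc_prime of the DCG emitter), scores them, adds them to the repertoire,
    # and forwards extra_scores | extra_info to the emitter's state_update.
    repertoire, emitter_state, metrics, random_key = map_elites.update(
        repertoire, emitter_state, random_key
    )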
+ Args: + flattened_transition: flattened transition in a jnp.ndarray of shape + (batch_size, flatten_dim) + transition: a transition object (might be a dummy one) to + get the dimensions right + Returns: + a Transition object + """ + obs_dim = transition.observation_dim + action_dim = transition.action_dim + state_desc_dim = transition.state_descriptor_dim + desc_dim = transition.descriptor_dim + + obs = flattened_transition[:, :obs_dim] + next_obs = flattened_transition[:, obs_dim : (2 * obs_dim)] + rewards = jnp.ravel(flattened_transition[:, (2 * obs_dim) : (2 * obs_dim + 1)]) + dones = jnp.ravel( + flattened_transition[:, (2 * obs_dim + 1) : (2 * obs_dim + 2)] + ) + truncations = jnp.ravel( + flattened_transition[:, (2 * obs_dim + 2) : (2 * obs_dim + 3)] + ) + actions = flattened_transition[ + :, (2 * obs_dim + 3) : (2 * obs_dim + 3 + action_dim) + ] + state_desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim) : (2 * obs_dim + 3 + action_dim + state_desc_dim), + ] + next_state_desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + state_desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + ), + ] + desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + 2 * state_desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + desc_dim + ), + ] + desc_prime = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + 2 * desc_dim + ), + ] + return cls( + obs=obs, + next_obs=next_obs, + rewards=rewards, + dones=dones, + truncations=truncations, + actions=actions, + state_desc=state_desc, + next_state_desc=next_state_desc, + desc=desc, + desc_prime=desc_prime, + ) + + @classmethod + def init_dummy( # type: ignore + cls, observation_dim: int, action_dim: int, descriptor_dim: int) -> QDTransition: + """ + Initialize a dummy transition that then can be passed to constructors to get + all shapes right. + Args: + observation_dim: observation dimension + action_dim: action dimension + Returns: + a dummy transition + """ + dummy_transition = DCGTransition( + obs=jnp.zeros(shape=(1, observation_dim)), + next_obs=jnp.zeros(shape=(1, observation_dim)), + rewards=jnp.zeros(shape=(1,)), + dones=jnp.zeros(shape=(1,)), + truncations=jnp.zeros(shape=(1,)), + actions=jnp.zeros(shape=(1, action_dim)), + state_desc=jnp.zeros(shape=(1, descriptor_dim)), + next_state_desc=jnp.zeros(shape=(1, descriptor_dim)), + desc=jnp.zeros(shape=(1, descriptor_dim)), + desc_prime=jnp.zeros(shape=(1, descriptor_dim)), + ) + return dummy_transition + + class ReplayBuffer(flax.struct.PyTreeNode): """ A replay buffer where transitions are flattened before being stored. 
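A round-trip sketch for the new DCGTransition above (toy dimensions; assumes this patch is installed). The flattened width follows flatten_dim = 2*obs + action + 3 + 2*state_desc + 2*desc, and from_flatten recovers the fields in the same order flatten() wrote them:

import jax.numpy as jnp

from qdax.core.neuroevolution.buffers.buffer import DCGTransition

obs_dim, action_dim, desc_dim = 8, 2, 2
dummy = DCGTransition.init_dummy(
    observation_dim=obs_dim, action_dim=action_dim, descriptor_dim=desc_dim
)

flat = dummy.flatten()
assert flat.shape == (1, dummy.flatten_dim)       # 2*8 + 2 + 3 + 2*2 + 2*2 = 29

rebuilt = DCGTransition.from_flatten(flat, dummy)
assert rebuilt.desc_prime.shape == (1, desc_dim)  # descriptor fields recovered last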
diff --git a/qdax/core/neuroevolution/losses/td3_loss.py b/qdax/core/neuroevolution/losses/td3_loss.py index 7f34a036..b360267c 100644 --- a/qdax/core/neuroevolution/losses/td3_loss.py +++ b/qdax/core/neuroevolution/losses/td3_loss.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Action, Observation, Params, RNGKey +from qdax.types import Action, Observation, Descriptor, Params, RNGKey def make_td3_loss_fn( @@ -94,6 +94,97 @@ def _critic_loss_fn( return _policy_loss_fn, _critic_loss_fn +def make_td3_loss_dc_fn( + policy_fn: Callable[[Params, Observation], jnp.ndarray], + actor_fn: Callable[[Params, Observation, Descriptor], jnp.ndarray], + critic_fn: Callable[[Params, Observation, Action, Descriptor], jnp.ndarray], + reward_scaling: float, + discount: float, + noise_clip: float, + policy_noise: float, +) -> Tuple[ + Callable[[Params, Params, Transition], jnp.ndarray], + Callable[[Params, Params, Params, Transition, RNGKey], jnp.ndarray], +]: + """Creates the loss functions for TD3. + Args: + policy_fn: forward pass through the neural network defining the policy. + actor_fn: forward pass through the neural network defining the + descriptor-conditioned policy. + critic_fn: forward pass through the neural network defining the + descriptor-conditioned critic. + reward_scaling: value to multiply the reward given by the environment. + discount: discount factor. + noise_clip: value that clips the noise to avoid extreme values. + policy_noise: noise applied to smooth the bootstrapping. + Returns: + Return the loss functions used to train the policy and the critic in TD3. + """ + + @jax.jit + def _policy_loss_fn( + policy_params: Params, + critic_params: Params, + transitions: Transition, + ) -> jnp.ndarray: + """Policy loss function for TD3 agent""" + action = policy_fn(policy_params, obs=transitions.obs) + q_value = critic_fn(critic_params, obs=transitions.obs, actions=action, desc=transitions.desc_prime) + q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) + policy_loss = -jnp.mean(q1_action) + return policy_loss + + @jax.jit + def _actor_loss_fn( + actor_params: Params, + critic_params: Params, + transitions: Transition, + ) -> jnp.ndarray: + """Descriptor-conditioned policy loss function for TD3 agent""" + action = actor_fn(actor_params, obs=transitions.obs, desc=transitions.desc_prime) + q_value = critic_fn(critic_params, obs=transitions.obs, actions=action, desc=transitions.desc_prime) + q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) + policy_loss = -jnp.mean(q1_action) + return policy_loss + + @jax.jit + def _critic_loss_fn( + critic_params: Params, + target_actor_params: Params, + target_critic_params: Params, + transitions: Transition, + random_key: RNGKey, + ) -> jnp.ndarray: + """Descriptor-conditioned critic loss function for TD3 agent""" + noise = ( + jax.random.normal(random_key, shape=transitions.actions.shape) + * policy_noise + ).clip(-noise_clip, noise_clip) + + next_action = ( + actor_fn(target_actor_params, obs=transitions.next_obs, desc=transitions.desc_prime) + noise + ).clip(-1.0, 1.0) + next_q = critic_fn(target_critic_params, obs=transitions.next_obs, actions=next_action, desc=transitions.desc_prime) + next_v = jnp.min(next_q, axis=-1) + target_q = jax.lax.stop_gradient( + transitions.rewards * reward_scaling + + (1.0 - transitions.dones) * discount * next_v + ) + q_old_action = critic_fn(critic_params, obs=transitions.obs, actions=transitions.actions, 
desc=transitions.desc_prime) + q_error = q_old_action - jnp.expand_dims(target_q, -1) + + # Better bootstrapping for truncated episodes. + q_error = q_error * jnp.expand_dims(1.0 - transitions.truncations, -1) + + # compute the loss + q_losses = jnp.mean(jnp.square(q_error), axis=-2) + q_loss = jnp.sum(q_losses, axis=-1) + + return q_loss + + return _policy_loss_fn, _actor_loss_fn, _critic_loss_fn + + def td3_policy_loss_fn( policy_params: Params, critic_params: Params, diff --git a/qdax/core/neuroevolution/mdp_utils.py b/qdax/core/neuroevolution/mdp_utils.py index 984d1aeb..fe936a57 100644 --- a/qdax/core/neuroevolution/mdp_utils.py +++ b/qdax/core/neuroevolution/mdp_utils.py @@ -9,7 +9,7 @@ from flax.struct import PyTreeNode from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Genotype, Params, RNGKey +from qdax.types import Genotype, Params, RNGKey, Descriptor class TrainingState(PyTreeNode): @@ -67,6 +67,54 @@ def _scan_play_step_fn( return state, transitions +@partial(jax.jit, static_argnames=("play_step_actor_dc_fn", "episode_length")) +def generate_unroll_actor_dc( + init_state: EnvState, + actor_dc_params: Params, + desc: Descriptor, + random_key: RNGKey, + episode_length: int, + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], + Tuple[ + EnvState, + Descriptor, + Params, + RNGKey, + Transition, + ], + ], +) -> Tuple[EnvState, Transition]: + """Generates an episode according to the agent's policy and descriptor, returns the final state of + the episode and the transitions of the episode. + + Args: + init_state: first state of the rollout. + policy_dc_params: descriptor-conditioned policy params. + desc: descriptor the policy attempts to achieve. + random_key: random key for stochasiticity handling. + episode_length: length of the rollout. + play_step_fn: function describing how a step need to be taken. + + Returns: + A new state, the experienced transition. + """ + + def _scan_play_step_fn( + carry: Tuple[EnvState, Params, Descriptor, RNGKey], unused_arg: Any + ) -> Tuple[Tuple[EnvState, Params, Descriptor, RNGKey], Transition]: + env_state, actor_dc_params, desc, random_key, transitions = play_step_actor_dc_fn(*carry) + return (env_state, actor_dc_params, desc, random_key), transitions + + (state, _, _, _), transitions = jax.lax.scan( + _scan_play_step_fn, + (init_state, actor_dc_params, desc, random_key), + (), + length=episode_length, + ) + return state, transitions + + @jax.jit def get_first_episode(transition: Transition) -> Transition: """Extracts the first episode from a batch of transitions, returns the batch of diff --git a/qdax/core/neuroevolution/networks/networks.py b/qdax/core/neuroevolution/networks/networks.py index b2b176ef..d4b2ab3a 100644 --- a/qdax/core/neuroevolution/networks/networks.py +++ b/qdax/core/neuroevolution/networks/networks.py @@ -8,29 +8,47 @@ from brax.training import networks -class QModule(nn.Module): - """Q Module.""" - - hidden_layer_sizes: Tuple[int, ...] - n_critics: int = 2 +class MLP(nn.Module): + """MLP module.""" + layer_sizes: Tuple[int, ...] 
+ activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform() + final_activation: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None + bias: bool = True + kernel_init_final: Optional[Callable[..., Any]] = None @nn.compact - def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: - hidden = jnp.concatenate([obs, actions], axis=-1) - res = [] - for _ in range(self.n_critics): - q = networks.MLP( - layer_sizes=self.hidden_layer_sizes + (1,), - activation=nn.relu, - kernel_init=jax.nn.initializers.lecun_uniform(), - )(hidden) - res.append(q) - return jnp.concatenate(res, axis=-1) + def __call__(self, obs: jnp.ndarray) -> jnp.ndarray: + hidden = obs + for i, hidden_size in enumerate(self.layer_sizes): + if i != len(self.layer_sizes) - 1: + hidden = nn.Dense( + hidden_size, + kernel_init=self.kernel_init, + use_bias=self.bias, + )(hidden) + hidden = self.activation(hidden) # type: ignore -class MLP(nn.Module): - """MLP module.""" + else: + if self.kernel_init_final is not None: + kernel_init = self.kernel_init_final + else: + kernel_init = self.kernel_init + hidden = nn.Dense( + hidden_size, + kernel_init=kernel_init, + use_bias=self.bias, + )(hidden) + + if self.final_activation is not None: + hidden = self.final_activation(hidden) + + return hidden + +class MLPDC(nn.Module): + """Descriptor-conditioned MLP module.""" layer_sizes: Tuple[int, ...] activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform() @@ -39,15 +57,13 @@ class MLP(nn.Module): kernel_init_final: Optional[Callable[..., Any]] = None @nn.compact - def __call__(self, data: jnp.ndarray) -> jnp.ndarray: - hidden = data + def __call__(self, obs: jnp.ndarray, desc: jnp.ndarray) -> jnp.ndarray: + hidden = jnp.concatenate([obs, desc], axis=-1) for i, hidden_size in enumerate(self.layer_sizes): if i != len(self.layer_sizes) - 1: hidden = nn.Dense( hidden_size, - # name=f"hidden_{i}", with this version of flax, changing the name - # changes the initialization kernel_init=self.kernel_init, use_bias=self.bias, )(hidden) @@ -61,7 +77,6 @@ def __call__(self, data: jnp.ndarray) -> jnp.ndarray: hidden = nn.Dense( hidden_size, - # name=f"hidden_{i}", kernel_init=kernel_init, use_bias=self.bias, )(hidden) @@ -70,3 +85,42 @@ def __call__(self, data: jnp.ndarray) -> jnp.ndarray: hidden = self.final_activation(hidden) return hidden + + +class QModule(nn.Module): + """Q Module.""" + + hidden_layer_sizes: Tuple[int, ...] + n_critics: int = 2 + + @nn.compact + def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: + hidden = jnp.concatenate([obs, actions], axis=-1) + res = [] + for _ in range(self.n_critics): + q = MLP( + layer_sizes=self.hidden_layer_sizes + (1,), + activation=nn.relu, + kernel_init=jax.nn.initializers.lecun_uniform(), + )(hidden) + res.append(q) + return jnp.concatenate(res, axis=-1) + +class QModuleDC(nn.Module): + """Q Module.""" + + hidden_layer_sizes: Tuple[int, ...] 
+ n_critics: int = 2 + + @nn.compact + def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray, desc: jnp.ndarray) -> jnp.ndarray: + hidden = jnp.concatenate([obs, actions], axis=-1) + res = [] + for _ in range(self.n_critics): + q = MLPDC( + layer_sizes=self.hidden_layer_sizes + (1,), + activation=nn.relu, + kernel_init=jax.nn.initializers.lecun_uniform(), + )(hidden, desc) + res.append(q) + return jnp.concatenate(res, axis=-1) diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py index 720f662a..ed1b4387 100644 --- a/qdax/environments/wrappers.py +++ b/qdax/environments/wrappers.py @@ -3,7 +3,7 @@ import flax.struct import jax from brax.v1 import jumpy as jp -from brax.v1.envs import State, Wrapper +from brax.v1.envs import State, Wrapper, Env class CompletedEvalMetrics(flax.struct.PyTreeNode): @@ -69,3 +69,65 @@ def step(self, state: State, action: jp.ndarray) -> State: ) nstate.info[self.STATE_INFO_KEY] = eval_metrics return nstate + +class ClipRewardWrapper(Wrapper): + """Wraps gym environments to clip the reward to be greater than 0. + + Utilisation is simple: create an environment with Brax, pass + it to the wrapper with the name of the environment, and it will + work like before and will simply clip the reward to be greater than 0. + """ + + def __init__(self, env: Env, clip_min=None, clip_max=None) -> None: + super().__init__(env) + self._clip_min = clip_min + self._clip_max = clip_max + + def reset(self, rng: jp.ndarray) -> State: + state = self.env.reset(rng) + return state.replace(reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + + def step(self, state: State, action: jp.ndarray) -> State: + state = self.env.step(state, action) + return state.replace(reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + +class AffineRewardWrapper(Wrapper): + """Wraps gym environments to clip the reward. + + Utilisation is simple: create an environment with Brax, pass + it to the wrapper with the name of the environment, and it will + work like before and will simply clip the reward to be greater than 0. + """ + + def __init__(self, env: Env, clip_min=None, clip_max=None) -> None: + super().__init__(env) + self._clip_min = clip_min + self._clip_max = clip_max + + def reset(self, rng: jp.ndarray) -> State: + state = self.env.reset(rng) + return state.replace(reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + + def step(self, state: State, action: jp.ndarray) -> State: + state = self.env.step(state, action) + return state.replace(reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + +class OffsetRewardWrapper(Wrapper): + """Wraps gym environments to offset the reward to be greater than 0. + + Utilisation is simple: create an environment with Brax, pass + it to the wrapper with the name of the environment, and it will + work like before and will simply clip the reward to be greater than 0. + """ + + def __init__(self, env: Env, offset=0.) 
-> None: + super().__init__(env) + self._offset = offset + + def reset(self, rng: jp.ndarray) -> State: + state = self.env.reset(rng) + return state.replace(reward=state.reward + self._offset) + + def step(self, state: State, action: jp.ndarray) -> State: + state = self.env.step(state, action) + return state.replace(reward=state.reward + self._offset) diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 931ee9d3..845c329b 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -9,8 +9,8 @@ import qdax.environments from qdax import environments -from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition -from qdax.core.neuroevolution.mdp_utils import generate_unroll +from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.core.neuroevolution.mdp_utils import generate_unroll, generate_unroll_actor_dc from qdax.core.neuroevolution.networks.networks import MLP from qdax.types import ( Descriptor, @@ -18,7 +18,6 @@ ExtraScores, Fitness, Genotype, - Observation, Params, RNGKey, ) @@ -83,15 +82,6 @@ def default_play_step_fn( return default_play_step_fn -def get_mask_from_transitions( - data: Transition, -) -> jnp.ndarray: - is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) - mask = jnp.roll(is_done, 1, axis=1) - mask = mask.at[:, 0].set(0) - return mask - - @partial( jax.jit, static_argnames=( @@ -144,7 +134,9 @@ def scoring_function_brax_envs( _final_state, data = jax.vmap(unroll_fn)(init_states, policies_params) # create a mask to extract data properly - mask = get_mask_from_transitions(data) + is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) + mask = jnp.roll(is_done, 1, axis=1) + mask = mask.at[:, 0].set(0) # scores fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) @@ -160,6 +152,78 @@ def scoring_function_brax_envs( ) +@partial( + jax.jit, + static_argnames=( + "episode_length", + "play_step_actor_dc_fn", + "behavior_descriptor_extractor", + ), +) +def scoring_actor_dc_function_brax_envs( + actors_dc_params: Genotype, + descs: Descriptor, + random_key: RNGKey, + init_states: EnvState, + episode_length: int, + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition] + ], + behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """Evaluates policies contained in policy_dc_params in parallel in + deterministic or pseudo-deterministic environments. + + This rollout is only deterministic when all the init states are the same. + If the init states are fixed but different, as a policy is not necessarily + evaluated with the same environment everytime, this won't be determinist. + When the init states are different, this is not purely stochastic. + + Args: + policy_dc_params: The parameters of closed-loop descriptor-conditioned policy to evaluate. + descriptors: The descriptors the descriptor-conditioned policy attempts to achieve. + random_key: A jax random key + episode_length: The maximal rollout length. + play_step_fn: The function to play a step of the environment. + behavior_descriptor_extractor: The function to extract the behavior descriptor. + + Returns: + fitness: Array of fitnesses of all evaluated policies + descriptor: Behavioural descriptors of all evaluated policies + extra_scores: Additional information resulting from evaluation + random_key: The updated random key. 
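
A usage sketch for the reward wrappers added above (not part of the patch; the environment name, wrapper order and constants are illustrative, and `qdax.environments.create` is assumed as the usual entry point): `OffsetRewardWrapper` shifts every reward by a constant, while `ClipRewardWrapper` clips rewards to the given range, which is convenient for keeping fitnesses positive in MAP-Elites.

```python
import qdax.environments
from qdax.environments.wrappers import ClipRewardWrapper, OffsetRewardWrapper

# Illustrative environment; any Brax-based qdax environment would do.
env = qdax.environments.create("walker2d_uni", episode_length=1000)

# Shift every reward by a constant so that returns stay positive.
env = OffsetRewardWrapper(env, offset=1.0)

# Alternatively (or additionally), clip rewards from below.
env = ClipRewardWrapper(env, clip_min=0.0)
```
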
+ """ + + # Perform rollouts with each policy + random_key, subkey = jax.random.split(random_key) + unroll_fn = partial( + generate_unroll_actor_dc, + episode_length=episode_length, + play_step_actor_dc_fn=play_step_actor_dc_fn, + random_key=subkey, + ) + + _final_state, data = jax.vmap(unroll_fn)(init_states, actors_dc_params, descs) + + # create a mask to extract data properly + is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) + mask = jnp.roll(is_done, 1, axis=1) + mask = mask.at[:, 0].set(0) + + # Scores - add offset to ensure positive fitness (through positive rewards) + fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) + descriptors = behavior_descriptor_extractor(data, mask) + + return ( + fitnesses, + descriptors, + { + "transitions": data, + }, + random_key, + ) + + @partial( jax.jit, static_argnames=( @@ -225,6 +289,72 @@ def reset_based_scoring_function_brax_envs( return fitnesses, descriptors, extra_scores, random_key +@partial( + jax.jit, + static_argnames=( + "episode_length", + "play_reset_fn", + "play_step_actor_dc_fn", + "behavior_descriptor_extractor", + ), +) +def reset_based_scoring_actor_dc_function_brax_envs( + actors_dc_params: Genotype, + descs: Descriptor, + random_key: RNGKey, + episode_length: int, + play_reset_fn: Callable[[RNGKey], EnvState], + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition] + ], + behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """Evaluates policies contained in policy_dc_params in parallel. + The play_reset_fn function allows for a more general scoring_function that can be + called with different batch-size and not only with a batch-size of the same + dimension as init_states. + + To define purely stochastic environments, using the reset function from the + environment, use "play_reset_fn = env.reset". + + To define purely deterministic environments, as in "scoring_function", generate + a single init_state using "init_state = env.reset(random_key)", then use + "play_reset_fn = lambda random_key: init_state". + + Args: + policy_dc_params: The parameters of closed-loop descriptor-conditioned policy to evaluate. + descriptors: The descriptors the descriptor-conditioned policy attempts to achieve. + random_key: A jax random key + episode_length: The maximal rollout length. + play_reset_fn: The function to reset the environment and obtain initial states. + play_step_fn: The function to play a step of the environment. + behavior_descriptor_extractor: The function to extract the behavior descriptor. + + Returns: + fitness: Array of fitnesses of all evaluated policies + descriptor: Behavioural descriptors of all evaluated policies + extra_scores: Additional information resulting from the evaluation + random_key: The updated random key. 
+ """ + + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split(subkey, jax.tree_util.tree_leaves(actors_dc_params)[0].shape[0]) + reset_fn = jax.vmap(play_reset_fn) + init_states = reset_fn(keys) + + fitnesses, descriptors, extra_scores, random_key = scoring_actor_dc_function_brax_envs( + actors_dc_params=actors_dc_params, + descs=descs, + random_key=random_key, + init_states=init_states, + episode_length=episode_length, + play_step_actor_dc_fn=play_step_actor_dc_fn, + behavior_descriptor_extractor=behavior_descriptor_extractor, + ) + + return fitnesses, descriptors, extra_scores, random_key + + def create_brax_scoring_fn( env: brax.envs.Env, policy_network: nn.Module, @@ -275,10 +405,10 @@ def create_brax_scoring_fn( init_state = env.reset(subkey) # Define the function to deterministically reset the environment - def deterministic_reset(_: RNGKey, _init_state: EnvState) -> EnvState: - return _init_state + def deterministic_reset(key: RNGKey, init_state: EnvState) -> EnvState: + return init_state - play_reset_fn = partial(deterministic_reset, _init_state=init_state) + play_reset_fn = partial(deterministic_reset, init_state=init_state) # Stochastic case elif play_reset_fn is None: @@ -348,36 +478,3 @@ def create_default_brax_task_components( ) return env, policy_network, scoring_fn, random_key - - -def get_aurora_scoring_fn( - scoring_fn: Callable[ - [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] - ], - observation_extractor_fn: Callable[[Transition], Observation], -) -> Callable[ - [Genotype, RNGKey], Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey] -]: - """Evaluates policies contained in flatten_variables in parallel - - This rollout is only deterministic when all the init states are the same. - If the init states are fixed but different, as a policy is not necessarly - evaluated with the same environment everytime, this won't be determinist. - - When the init states are different, this is not purely stochastic. This - choice was made for performance reason, as the reset function of brax envs - is quite time-consuming. If pure stochasticity of the environment is needed - for a use case, please open an issue. 
- """ - - @functools.wraps(scoring_fn) - def _wrapper( - params: Params, random_key: RNGKey # Perform rollouts with each policy - ) -> Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey]: - fitnesses, _, extra_scores, random_key = scoring_fn(params, random_key) - data = extra_scores["transitions"] - observation = observation_extractor_fn(data) # type: ignore - extra_scores["last_valid_observations"] = observation - return fitnesses, None, extra_scores, random_key - - return _wrapper From 5fb2b65eefaea287a75d7dd6b0dc2f49c84fe6c5 Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Thu, 21 Dec 2023 09:53:01 +0000 Subject: [PATCH 02/11] Fix brax_envs.py --- qdax/tasks/brax_envs.py | 55 +++++++++++++++++++++++++++++++++++------ 1 file changed, 48 insertions(+), 7 deletions(-) diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 845c329b..721567a1 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -9,7 +9,7 @@ import qdax.environments from qdax import environments -from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition from qdax.core.neuroevolution.mdp_utils import generate_unroll, generate_unroll_actor_dc from qdax.core.neuroevolution.networks.networks import MLP from qdax.types import ( @@ -18,6 +18,7 @@ ExtraScores, Fitness, Genotype, + Observation, Params, RNGKey, ) @@ -82,6 +83,15 @@ def default_play_step_fn( return default_play_step_fn +def get_mask_from_transitions( + data: Transition, +) -> jnp.ndarray: + is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) + mask = jnp.roll(is_done, 1, axis=1) + mask = mask.at[:, 0].set(0) + return mask + + @partial( jax.jit, static_argnames=( @@ -134,9 +144,7 @@ def scoring_function_brax_envs( _final_state, data = jax.vmap(unroll_fn)(init_states, policies_params) # create a mask to extract data properly - is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) - mask = jnp.roll(is_done, 1, axis=1) - mask = mask.at[:, 0].set(0) + mask = get_mask_from_transitions(data) # scores fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) @@ -405,10 +413,10 @@ def create_brax_scoring_fn( init_state = env.reset(subkey) # Define the function to deterministically reset the environment - def deterministic_reset(key: RNGKey, init_state: EnvState) -> EnvState: - return init_state + def deterministic_reset(_: RNGKey, _init_state: EnvState) -> EnvState: + return _init_state - play_reset_fn = partial(deterministic_reset, init_state=init_state) + play_reset_fn = partial(deterministic_reset, _init_state=init_state) # Stochastic case elif play_reset_fn is None: @@ -478,3 +486,36 @@ def create_default_brax_task_components( ) return env, policy_network, scoring_fn, random_key + + +def get_aurora_scoring_fn( + scoring_fn: Callable[ + [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] + ], + observation_extractor_fn: Callable[[Transition], Observation], +) -> Callable[ + [Genotype, RNGKey], Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey] +]: + """Evaluates policies contained in flatten_variables in parallel + + This rollout is only deterministic when all the init states are the same. + If the init states are fixed but different, as a policy is not necessarly + evaluated with the same environment everytime, this won't be determinist. + + When the init states are different, this is not purely stochastic. This + choice was made for performance reason, as the reset function of brax envs + is quite time-consuming. 
If pure stochasticity of the environment is needed + for a use case, please open an issue. + """ + + @functools.wraps(scoring_fn) + def _wrapper( + params: Params, random_key: RNGKey # Perform rollouts with each policy + ) -> Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey]: + fitnesses, _, extra_scores, random_key = scoring_fn(params, random_key) + data = extra_scores["transitions"] + observation = observation_extractor_fn(data) # type: ignore + extra_scores["last_valid_observations"] = observation + return fitnesses, None, extra_scores, random_key + + return _wrapper From dd729cfb081fe26ceae2ba813bc5ba43ed7fc6c8 Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Fri, 29 Dec 2023 17:50:56 +0000 Subject: [PATCH 03/11] Fix PGA-MAP-Elites --- qdax/core/emitters/qpg_emitter.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index d43827df..a0b5c62d 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -12,7 +12,7 @@ import optax from jax import numpy as jnp -from qdax.core.containers.repertoire import MapElitesRepertoire +from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_fn @@ -121,7 +121,7 @@ def use_all_data(self) -> bool: def init( self, random_key: RNGKey, - repertoire: MapElitesRepertoire, + repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, @@ -168,6 +168,13 @@ def init( buffer_size=self._config.replay_buffer_size, transition=dummy_transition ) + # get the transitions out of the dictionary + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + + # add transitions in the replay buffer + replay_buffer = replay_buffer.insert(transitions) + # Initial training state random_key, subkey = jax.random.split(random_key) emitter_state = QualityPGEmitterState( @@ -177,9 +184,9 @@ def init( actor_opt_state=actor_optimizer_state, target_critic_params=target_critic_params, target_actor_params=target_actor_params, + replay_buffer=replay_buffer, random_key=subkey, steps=jnp.array(0), - replay_buffer=replay_buffer, ) return emitter_state, random_key @@ -190,7 +197,7 @@ def init( ) def emit( self, - repertoire: MapElitesRepertoire, + repertoire: Repertoire, emitter_state: QualityPGEmitterState, random_key: RNGKey, ) -> Tuple[Genotype, RNGKey]: @@ -279,7 +286,7 @@ def emit_actor(self, emitter_state: QualityPGEmitterState) -> Genotype: def state_update( self, emitter_state: QualityPGEmitterState, - repertoire: Optional[MapElitesRepertoire], + repertoire: Optional[Repertoire], genotypes: Optional[Genotype], fitnesses: Optional[Fitness], descriptors: Optional[Descriptor], From f7f277f2fa120bbae0e8daa8f4064ca8d368aad9 Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Fri, 29 Dec 2023 18:43:57 +0000 Subject: [PATCH 04/11] Fix QD-PG --- qdax/core/emitters/dpg_emitter.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index c266be10..8bc223df 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -10,7 +10,7 @@ import optax from qdax.core.containers.archive import Archive -from 
qdax.core.containers.repertoire import MapElitesRepertoire +from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.qpg_emitter import ( QualityPGConfig, QualityPGEmitter, @@ -79,7 +79,7 @@ def __init__( def init( self, random_key: RNGKey, - repertoire: MapElitesRepertoire, + repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, @@ -96,7 +96,13 @@ def init( """ # init elements of diversity emitter state with QualityEmitterState.init() - diversity_emitter_state, random_key = super().init(genotypes, random_key) + diversity_emitter_state, random_key = super().init( + random_key, + repertoire, + genotypes, + fitnesses, + descriptors, + extra_scores,) # store elements in a dictionary attributes_dict = vars(diversity_emitter_state) @@ -108,6 +114,12 @@ def init( max_size=self._config.archive_max_size, ) + # get the transitions out of the dictionary + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + + archive = archive.insert(transitions.state_desc) + # init emitter state emitter_state = DiversityPGEmitterState( # retrieve all attributes from the QualityPGEmitterState @@ -122,7 +134,7 @@ def init( def state_update( self, emitter_state: DiversityPGEmitterState, - repertoire: Optional[MapElitesRepertoire], + repertoire: Optional[Repertoire], genotypes: Optional[Genotype], fitnesses: Optional[Fitness], descriptors: Optional[Descriptor], From c3aa1112922a2c0f6fef06a975bc19a159edac64 Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Fri, 29 Dec 2023 21:05:20 +0000 Subject: [PATCH 05/11] Rename init_genotypes into genotypes --- examples/distributed_mapelites.ipynb | 2 +- examples/me_sac_pbt.ipynb | 2 +- examples/me_td3_pbt.ipynb | 2 +- examples/mome.ipynb | 4 ++-- examples/nsga2_spea2.ipynb | 6 +++--- qdax/baselines/genetic_algorithm.py | 12 ++++++------ qdax/baselines/nsga2.py | 10 +++++----- qdax/baselines/spea2.py | 10 +++++----- qdax/core/aurora.py | 12 ++++++------ qdax/core/emitters/cma_mega_emitter.py | 2 +- qdax/core/emitters/cma_rnd_emitter.py | 4 ++-- qdax/core/mels.py | 12 ++++++------ qdax/core/mome.py | 12 ++++++------ requirements.txt | 5 +++-- tests/baselines_test/ga_test.py | 6 +++--- tests/baselines_test/me_pbt_sac_test.py | 2 +- tests/baselines_test/me_pbt_td3_test.py | 2 +- tests/core_test/mome_test.py | 4 ++-- 18 files changed, 55 insertions(+), 54 deletions(-) diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index 434725a3..18d4f0f3 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -348,7 +348,7 @@ "repertoire, emitter_state, random_key = map_elites.get_distributed_init_fn(\n", " centroids=centroids,\n", " devices=devices,\n", - ")(init_genotypes=init_variables, random_key=random_key)" + ")(genotypes=init_variables, random_key=random_key)" ] }, { diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index 6d4dfdfe..6b4ae0b5 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -311,7 +311,7 @@ "# initialize map-elites\n", "repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n", " devices=devices, centroids=centroids\n", - ")(init_genotypes=training_states, random_key=keys)" + ")(genotypes=training_states, random_key=keys)" ] }, { diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index 1cf17c5e..ca127e72 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -314,7 +314,7 @@ "# 
initialize map-elites\n", "repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n", " devices=devices, centroids=centroids\n", - ")(init_genotypes=training_states, random_key=keys)" + ")(genotypes=training_states, random_key=keys)" ] }, { diff --git a/examples/mome.ipynb b/examples/mome.ipynb index a4ca36a6..05387158 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -212,7 +212,7 @@ "# initial population\n", "random_key = jax.random.PRNGKey(42)\n", "random_key, subkey = jax.random.split(random_key)\n", - "init_genotypes = jax.random.uniform(\n", + "genotypes = jax.random.uniform(\n", " random_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n", ")\n", "\n", @@ -303,7 +303,7 @@ "outputs": [], "source": [ "repertoire, emitter_state, random_key = mome.init(\n", - " init_genotypes,\n", + " genotypes,\n", " centroids,\n", " pareto_front_max_length,\n", " random_key\n", diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index 51c5f5bd..e10c0d91 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -189,7 +189,7 @@ "# Initial population\n", "random_key = jax.random.PRNGKey(0)\n", "random_key, subkey = jax.random.split(random_key)\n", - "init_genotypes = jax.random.uniform(\n", + "genotypes = jax.random.uniform(\n", " subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n", ")\n", "\n", @@ -238,7 +238,7 @@ "\n", "# init nsga2\n", "repertoire, emitter_state, random_key = nsga2.init(\n", - " init_genotypes,\n", + " genotypes,\n", " population_size,\n", " random_key\n", ")" @@ -303,7 +303,7 @@ "\n", "# init spea2\n", "repertoire, emitter_state, random_key = spea2.init(\n", - " init_genotypes,\n", + " genotypes,\n", " population_size,\n", " num_neighbours,\n", " random_key\n", diff --git a/qdax/baselines/genetic_algorithm.py b/qdax/baselines/genetic_algorithm.py index 0714fb6c..10b2e6d7 100644 --- a/qdax/baselines/genetic_algorithm.py +++ b/qdax/baselines/genetic_algorithm.py @@ -39,12 +39,12 @@ def __init__( @partial(jax.jit, static_argnames=("self", "population_size")) def init( - self, init_genotypes: Genotype, population_size: int, random_key: RNGKey + self, genotypes: Genotype, population_size: int, random_key: RNGKey ) -> Tuple[GARepertoire, Optional[EmitterState], RNGKey]: """Initialize a GARepertoire with an initial population of genotypes. 
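
Since the rename, the public `init` methods take the initial population through a `genotypes` argument. A hedged call sketch for the genetic-algorithm baseline is shown below; `algo` stands for an already-constructed `GeneticAlgorithm` (scoring function, emitter and metrics function are elided), and the sizes are illustrative.

```python
import jax

random_key = jax.random.PRNGKey(0)
random_key, subkey = jax.random.split(random_key)
genotypes = jax.random.uniform(subkey, (128, 10), minval=-1.0, maxval=1.0)

# `algo` is an already-constructed GeneticAlgorithm instance (assumption).
repertoire, emitter_state, random_key = algo.init(genotypes, 512, random_key)
```
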
Args: - init_genotypes: the initial population of genotypes + genotypes: the initial population of genotypes population_size: the maximal size of the repertoire random_key: a random key to handle stochastic operations @@ -54,26 +54,26 @@ def init( # score initial genotypes fitnesses, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = GARepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, population_size=population_size, ) # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + genotypes=genotypes, random_key=random_key ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=None, extra_scores=extra_scores, diff --git a/qdax/baselines/nsga2.py b/qdax/baselines/nsga2.py index 663d6f0e..cf72257d 100644 --- a/qdax/baselines/nsga2.py +++ b/qdax/baselines/nsga2.py @@ -28,31 +28,31 @@ class NSGA2(GeneticAlgorithm): @partial(jax.jit, static_argnames=("self", "population_size")) def init( - self, init_genotypes: Genotype, population_size: int, random_key: RNGKey + self, genotypes: Genotype, population_size: int, random_key: RNGKey ) -> Tuple[NSGA2Repertoire, Optional[EmitterState], RNGKey]: # score initial genotypes fitnesses, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = NSGA2Repertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, population_size=population_size, ) # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + genotypes=genotypes, random_key=random_key ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, extra_scores=extra_scores, ) diff --git a/qdax/baselines/spea2.py b/qdax/baselines/spea2.py index 72ec2791..f0c996db 100644 --- a/qdax/baselines/spea2.py +++ b/qdax/baselines/spea2.py @@ -40,7 +40,7 @@ class SPEA2(GeneticAlgorithm): ) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, population_size: int, num_neighbours: int, random_key: RNGKey, @@ -48,12 +48,12 @@ def init( # score initial genotypes fitnesses, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = SPEA2Repertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, population_size=population_size, num_neighbours=num_neighbours, @@ -61,14 +61,14 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + genotypes=genotypes, random_key=random_key ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, extra_scores=extra_scores, ) diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index fed716e3..edd80240 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -118,7 +118,7 @@ def container_size_control( def init( self, - init_genotypes: Genotype, + genotypes: 
Genotype, aurora_extra_info: AuroraExtraInfo, l_value: jnp.ndarray, max_size: int, @@ -128,7 +128,7 @@ def init( genotypes. Also performs the first training of the AURORA encoder. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) aurora_extra_info: information to perform AURORA encodings, such as the encoder parameters @@ -141,7 +141,7 @@ def init( the emitter, and the updated information to perform AURORA encodings """ fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, + genotypes, random_key, ) @@ -150,7 +150,7 @@ def init( descriptors = self._encoder_fn(observations, aurora_extra_info) repertoire = UnstructuredRepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, observations=observations, @@ -160,13 +160,13 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + genotypes=genotypes, random_key=random_key ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py index f79579dd..1fd0e1e6 100644 --- a/qdax/core/emitters/cma_mega_emitter.py +++ b/qdax/core/emitters/cma_mega_emitter.py @@ -153,7 +153,7 @@ def emit( repertoire: Optional[MapElitesRepertoire], emitter_state: CMAMEGAState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emits new individuals. Interestingly, this method does not directly modifies individuals from the repertoire but sample from a distribution. Hence the diff --git a/qdax/core/emitters/cma_rnd_emitter.py b/qdax/core/emitters/cma_rnd_emitter.py index 4afb2f5d..2b9928a0 100644 --- a/qdax/core/emitters/cma_rnd_emitter.py +++ b/qdax/core/emitters/cma_rnd_emitter.py @@ -35,14 +35,14 @@ class CMARndEmitterState(CMAEmitterState): class CMARndEmitter(CMAEmitter): @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, genotypes: Genotype, random_key: RNGKey ) -> Tuple[CMARndEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. Returns: diff --git a/qdax/core/mels.py b/qdax/core/mels.py index 6c06b785..3969e6d0 100644 --- a/qdax/core/mels.py +++ b/qdax/core/mels.py @@ -55,7 +55,7 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, centroids: Centroid, random_key: RNGKey, ) -> Tuple[MELSRepertoire, Optional[EmitterState], RNGKey]: @@ -64,7 +64,7 @@ def init( be computed with any method such as CVT or Euclidean mapping. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tessellation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. 
@@ -75,12 +75,12 @@ def init( """ # score initial genotypes fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = MELSRepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, centroids=centroids, @@ -89,14 +89,14 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + genotypes=genotypes, random_key=random_key ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, diff --git a/qdax/core/mome.py b/qdax/core/mome.py index 2a004f59..98693e5a 100644 --- a/qdax/core/mome.py +++ b/qdax/core/mome.py @@ -23,7 +23,7 @@ class MOME(MAPElites): @partial(jax.jit, static_argnames=("self", "pareto_front_max_length")) def init( self, - init_genotypes: jnp.ndarray, + genotypes: jnp.ndarray, centroids: Centroid, pareto_front_max_length: int, random_key: RNGKey, @@ -33,7 +33,7 @@ def init( CVT or Euclidean mapping. Args: - init_genotypes: genotypes of the initial population. + genotypes: genotypes of the initial population. centroids: centroids of the repertoire. pareto_front_max_length: maximum size of the pareto front. This is necessary to respect jax.jit fixed shape size constraint. @@ -45,12 +45,12 @@ def init( # first score fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = MOMERepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, centroids=centroids, @@ -60,14 +60,14 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + genotypes=genotypes, random_key=random_key ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, diff --git a/requirements.txt b/requirements.txt index 978a1c87..50d7899c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ +--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + absl-py==1.0.0 brax==0.9.2 chex==0.1.83 @@ -5,8 +7,7 @@ dm-haiku==0.0.10 flax==0.7.4 gym==0.26.2 ipython -jax==0.4.16 -jaxlib==0.4.16 +jax[cuda12_pip] jumanji==0.3.1 jupyter numpy==1.24.1 diff --git a/tests/baselines_test/ga_test.py b/tests/baselines_test/ga_test.py index 4e11370b..5f9ec5f7 100644 --- a/tests/baselines_test/ga_test.py +++ b/tests/baselines_test/ga_test.py @@ -73,7 +73,7 @@ def scoring_fn( # initial population random_key = jax.random.PRNGKey(42) random_key, subkey = jax.random.split(random_key) - init_genotypes = jax.random.uniform( + genotypes = jax.random.uniform( subkey, (batch_size, genotype_dim), minval=minval, @@ -111,11 +111,11 @@ def scoring_fn( if isinstance(algo_instance, SPEA2): repertoire, emitter_state, random_key = algo_instance.init( - init_genotypes, population_size, num_neighbours, random_key + genotypes, population_size, num_neighbours, random_key ) else: repertoire, emitter_state, random_key = algo_instance.init( - init_genotypes, 
population_size, random_key + genotypes, population_size, random_key ) # Run the algorithm diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index 079fde45..c4ab259e 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -178,7 +178,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys = map_elites.get_distributed_init_fn( devices=devices, centroids=centroids )( - init_genotypes=training_states, random_key=keys + genotypes=training_states, random_key=keys ) # type: ignore update_fn = map_elites.get_distributed_update_fn(num_iterations=1, devices=devices) diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index 5c6fbb0a..510743c1 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -176,7 +176,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys = map_elites.get_distributed_init_fn( devices=devices, centroids=centroids )( - init_genotypes=training_states, random_key=keys + genotypes=training_states, random_key=keys ) # type: ignore update_fn = map_elites.get_distributed_update_fn(num_iterations=1, devices=devices) diff --git a/tests/core_test/mome_test.py b/tests/core_test/mome_test.py index c70683ef..103f9489 100644 --- a/tests/core_test/mome_test.py +++ b/tests/core_test/mome_test.py @@ -81,7 +81,7 @@ def scoring_fn( # initial population random_key = jax.random.PRNGKey(42) random_key, subkey = jax.random.split(random_key) - init_genotypes = jax.random.uniform( + genotypes = jax.random.uniform( subkey, (batch_size, num_variables), minval=minval, @@ -127,7 +127,7 @@ def scoring_fn( ) repertoire, emitter_state, random_key = mome.init( - init_genotypes, centroids, pareto_front_max_length, random_key + genotypes, centroids, pareto_front_max_length, random_key ) # Run the algorithm From 5278ada81fece40f49ea579afef7393cadbfa682 Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Fri, 29 Dec 2023 21:33:19 +0000 Subject: [PATCH 06/11] Fix pre-commit errors --- qdax/baselines/genetic_algorithm.py | 11 +- qdax/baselines/nsga2.py | 7 +- qdax/baselines/spea2.py | 7 +- qdax/core/aurora.py | 14 +- qdax/core/containers/mapelites_repertoire.py | 6 +- qdax/core/distributed_map_elites.py | 21 ++- qdax/core/emitters/cma_emitter.py | 2 +- qdax/core/emitters/cma_pool_emitter.py | 2 +- qdax/core/emitters/emitter.py | 4 +- qdax/core/emitters/mees_emitter.py | 2 +- qdax/core/emitters/multi_emitter.py | 17 +- qdax/core/emitters/omg_mega_emitter.py | 2 +- qdax/core/emitters/pbt_me_emitter.py | 2 +- qdax/core/emitters/qdcg_emitter.py | 151 ++++++++++++------ qdax/core/emitters/qpg_emitter.py | 2 +- qdax/core/emitters/standard_emitters.py | 4 +- qdax/core/map_elites.py | 2 +- qdax/core/mels.py | 7 +- qdax/core/mome.py | 7 +- qdax/core/neuroevolution/buffers/buffer.py | 19 ++- qdax/core/neuroevolution/losses/td3_loss.py | 31 ++-- qdax/core/neuroevolution/mdp_utils.py | 14 +- qdax/core/neuroevolution/networks/networks.py | 10 +- qdax/environments/base_wrappers.py | 11 +- qdax/environments/wrappers.py | 32 +++- qdax/tasks/brax_envs.py | 31 ++-- 26 files changed, 283 insertions(+), 135 deletions(-) diff --git a/qdax/baselines/genetic_algorithm.py b/qdax/baselines/genetic_algorithm.py index 10b2e6d7..c8891f11 100644 --- a/qdax/baselines/genetic_algorithm.py +++ b/qdax/baselines/genetic_algorithm.py @@ -66,12 +66,7 @@ def init( # get initial state of the 
emitter emitter_state, random_key = self._emitter.init( - genotypes=genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key=random_key, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, @@ -108,7 +103,7 @@ def update( """ # generate offsprings - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) @@ -127,7 +122,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=None, - extra_scores=extra_scores, + extra_scores=extra_scores | extra_info, ) # update the metrics diff --git a/qdax/baselines/nsga2.py b/qdax/baselines/nsga2.py index cf72257d..a889eadc 100644 --- a/qdax/baselines/nsga2.py +++ b/qdax/baselines/nsga2.py @@ -45,7 +45,12 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - genotypes=genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=None, + extra_scores=extra_scores, ) # update emitter state diff --git a/qdax/baselines/spea2.py b/qdax/baselines/spea2.py index f0c996db..c52063b6 100644 --- a/qdax/baselines/spea2.py +++ b/qdax/baselines/spea2.py @@ -61,7 +61,12 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - genotypes=genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=None, + extra_scores=extra_scores, ) # update emitter state diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index edd80240..a0968ccc 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -160,12 +160,8 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - genotypes=genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key=random_key, + repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, @@ -208,9 +204,10 @@ def update( a new key """ # generate offsprings with the emitter - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) + # scores the offsprings fitnesses, descriptors, extra_scores, random_key = self._scoring_function( genotypes, @@ -232,10 +229,11 @@ def update( # update emitter state after scoring is made emitter_state = self._emitter.state_update( emitter_state=emitter_state, + repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores, + extra_scores=extra_scores | extra_info, ) # update the metrics diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index 5968b03f..b1145c34 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -229,7 +229,11 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey return samples, random_key @partial(jax.jit, static_argnames=("num_samples",)) - def sample_with_descs(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]: + def sample_with_descs( + self, + random_key: RNGKey, + num_samples: int, + ) -> Tuple[Genotype, Descriptor, RNGKey]: """Sample elements in the repertoire. 
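
A usage sketch for the extended sampling method (a populated `MapElitesRepertoire` named `repertoire` is assumed): because the same key and probabilities are used for both trees, the returned descriptors are those of the cells the sampled genotypes come from, which is exactly what the descriptor-conditioned emitter needs as target descriptors.

```python
import jax

random_key = jax.random.PRNGKey(0)
parents, descs, random_key = repertoire.sample_with_descs(
    random_key, num_samples=64
)
# `descs` can then serve as the target descriptors (desc_prime) for the
# descriptor-conditioned policy-gradient mutations.
```
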
Args: diff --git a/qdax/core/distributed_map_elites.py b/qdax/core/distributed_map_elites.py index c8a1ea44..b6d116a4 100644 --- a/qdax/core/distributed_map_elites.py +++ b/qdax/core/distributed_map_elites.py @@ -17,7 +17,7 @@ class DistributedMAPElites(MAPElites): @partial(jax.jit, static_argnames=("self",)) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, centroids: Centroid, random_key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: @@ -30,7 +30,7 @@ def init( devices. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tessellation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. @@ -41,7 +41,7 @@ def init( """ # score initial genotypes fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # gather across all devices @@ -51,7 +51,7 @@ def init( gathered_descriptors, ) = jax.tree_util.tree_map( lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0), - (init_genotypes, fitnesses, descriptors), + (genotypes, fitnesses, descriptors), ) # init the repertoire @@ -64,14 +64,19 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, @@ -108,7 +113,7 @@ def update( a new jax PRNG key """ # generate offsprings with the emitter - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) # scores the offsprings @@ -138,7 +143,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores, + extra_scores=extra_scores | extra_info, ) # update the metrics diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py index c090e448..66e5677a 100644 --- a/qdax/core/emitters/cma_emitter.py +++ b/qdax/core/emitters/cma_emitter.py @@ -141,7 +141,7 @@ def emit( repertoire: Optional[MapElitesRepertoire], emitter_state: CMAEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emits new individuals. Interestingly, this method does not directly modifies individuals from the repertoire but sample from a distribution. Hence the diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py index 67034d71..d5af2181 100644 --- a/qdax/core/emitters/cma_pool_emitter.py +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -97,7 +97,7 @@ def emit( repertoire: Optional[MapElitesRepertoire], emitter_state: CMAPoolEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emits new individuals. diff --git a/qdax/core/emitters/emitter.py b/qdax/core/emitters/emitter.py index 14c6277a..056798ba 100644 --- a/qdax/core/emitters/emitter.py +++ b/qdax/core/emitters/emitter.py @@ -43,7 +43,7 @@ def init( outputted. 
Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. Returns: @@ -57,7 +57,7 @@ def emit( repertoire: Optional[Repertoire], emitter_state: Optional[EmitterState], random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Function used to emit a population of offspring by any possible mean. New population can be sampled from a distribution or obtained through mutations of individuals sampled from the repertoire. diff --git a/qdax/core/emitters/mees_emitter.py b/qdax/core/emitters/mees_emitter.py index 9641a613..0a03a6ba 100644 --- a/qdax/core/emitters/mees_emitter.py +++ b/qdax/core/emitters/mees_emitter.py @@ -306,7 +306,7 @@ def emit( repertoire: MapElitesRepertoire, emitter_state: MEESEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Return the offspring generated through gradient update. Params: diff --git a/qdax/core/emitters/multi_emitter.py b/qdax/core/emitters/multi_emitter.py index a0789f33..e59e2037 100644 --- a/qdax/core/emitters/multi_emitter.py +++ b/qdax/core/emitters/multi_emitter.py @@ -82,7 +82,13 @@ def init( # init all emitter states - gather them emitter_states = [] for emitter, subkey_emitter in zip(self.emitters, subkeys): - emitter_state, _ = emitter.init(subkey_emitter, repertoire, genotypes, fitnesses, descriptors, extra_scores) + emitter_state, _ = emitter.init( + subkey_emitter, + repertoire, + genotypes, + fitnesses, + descriptors, + extra_scores) emitter_states.append(emitter_state) return MultiEmitterState(tuple(emitter_states)), random_key @@ -93,7 +99,7 @@ def emit( repertoire: Optional[Repertoire], emitter_state: Optional[MultiEmitterState], random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Emit new population. Use all the sub emitters to emit subpopulation and gather them. @@ -114,13 +120,16 @@ def emit( # emit from all emitters and gather offsprings all_offsprings = [] - all_extra_info = {} + all_extra_info: ExtraScores = {} for emitter, sub_emitter_state, subkey_emitter in zip( self.emitters, emitter_state.emitter_states, subkeys, ): - genotype, extra_info, _ = emitter.emit(repertoire, sub_emitter_state, subkey_emitter) + genotype, extra_info, _ = emitter.emit( + repertoire, + sub_emitter_state, + subkey_emitter) batch_size = jax.tree_util.tree_leaves(genotype)[0].shape[0] assert batch_size == emitter.batch_size all_offsprings.append(genotype) diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index 2380a85c..7a480e06 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -132,7 +132,7 @@ def emit( repertoire: MapElitesRepertoire, emitter_state: OMGMEGAEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ OMG emitter function that samples elements in the repertoire and does a gradient update with random coefficients to create new candidates. 
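
A minimal sketch of an emitter written against the updated interface threaded through the emitters above, where `init` receives the repertoire and the first evaluation results and `emit` returns an extra-information dictionary alongside the offspring. The emitter itself (a Gaussian perturbation of sampled elites) is illustrative only, not part of the patch.

```python
from typing import Optional, Tuple

import jax
import jax.numpy as jnp

from qdax.core.containers.repertoire import Repertoire
from qdax.core.emitters.emitter import Emitter, EmitterState
from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey


class GaussianEmitter(Emitter):
    """Illustrative emitter: perturbs sampled elites with Gaussian noise."""

    def __init__(self, batch_size: int, sigma: float = 0.1) -> None:
        self._batch_size = batch_size
        self._sigma = sigma

    @property
    def batch_size(self) -> int:
        return self._batch_size

    def init(
        self,
        random_key: RNGKey,
        repertoire: Repertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> Tuple[Optional[EmitterState], RNGKey]:
        # This emitter keeps no state, so the first evaluation is unused here.
        return None, random_key

    def emit(
        self,
        repertoire: Repertoire,
        emitter_state: Optional[EmitterState],
        random_key: RNGKey,
    ) -> Tuple[Genotype, ExtraScores, RNGKey]:
        samples, random_key = repertoire.sample(random_key, self._batch_size)
        random_key, subkey = jax.random.split(random_key)
        offspring = jax.tree_util.tree_map(
            lambda x: x + self._sigma * jax.random.normal(subkey, x.shape),
            samples,
        )
        # The middle element is merged into extra_scores by the algorithm loop.
        return offspring, {}, random_key
```
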
diff --git a/qdax/core/emitters/pbt_me_emitter.py b/qdax/core/emitters/pbt_me_emitter.py index 64c05f16..a2266bfa 100644 --- a/qdax/core/emitters/pbt_me_emitter.py +++ b/qdax/core/emitters/pbt_me_emitter.py @@ -172,7 +172,7 @@ def emit( repertoire: Repertoire, emitter_state: PBTEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Do a single PGA-ME iteration: train critics and greedy policy, make mutations (evo and pg), score solution, fill replay buffer and insert back in the MAP-Elites grid. diff --git a/qdax/core/emitters/qdcg_emitter.py b/qdax/core/emitters/qdcg_emitter.py index 745773bd..b2b782a0 100644 --- a/qdax/core/emitters/qdcg_emitter.py +++ b/qdax/core/emitters/qdcg_emitter.py @@ -1,15 +1,15 @@ -"""Implements the PG Emitter and Actor Injection from DCG-ME algorithm in JAX for Brax environments. +"""Implements the PG Emitter and Actor Injection from DCG-ME algorithm +in JAX for Brax environments. """ from dataclasses import dataclass from functools import partial -from typing import Any, Optional, Tuple, Callable +from typing import Any, Tuple -import jax -from jax import numpy as jnp import flax.linen as nn -from flax.core.frozen_dict import freeze +import jax import optax +from jax import numpy as jnp from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState @@ -83,7 +83,11 @@ def __init__( self._critic_network = critic_network # Set up the losses and optimizers - return the opt states - self._policy_loss_fn, self._actor_loss_fn, self._critic_loss_fn = make_td3_loss_dc_fn( + ( + self._policy_loss_fn, + self._actor_loss_fn, + self._critic_loss_fn, + ) = make_td3_loss_dc_fn( policy_fn=policy_network.apply, actor_fn=actor_network.apply, critic_fn=critic_network.apply, @@ -153,7 +157,8 @@ def init( critic_params = self._critic_network.init( subkey, obs=fake_obs, actions=fake_action, desc=fake_desc ) - target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) + target_critic_params = jax.tree_util.tree_map( + lambda x: x, critic_params) random_key, subkey = jax.random.split(random_key) actor_params = self._actor_network.init( @@ -179,10 +184,12 @@ def init( transitions = extra_scores["transitions"] episode_length = transitions.obs.shape[1] - desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) + desc = jnp.repeat( + descriptors[:, jnp.newaxis, :], episode_length, axis=1) desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) - transitions = transitions.replace(desc=desc_normalized, desc_prime=desc_normalized) + transitions = transitions.replace( + desc=desc_normalized, desc_prime=desc_normalized) replay_buffer = replay_buffer.insert(transitions) # Initial training state @@ -202,27 +209,36 @@ def init( return emitter_state, random_key @partial(jax.jit, static_argnames=("self",)) - def _similarity(self, descs_1, descs_2): + def _similarity(self, descs_1: Descriptor, descs_2: Descriptor) -> jnp.array: """Compute the similarity between two batches of descriptors. Args: - descs_1: batch of descriptors, representing the observed descriptors of the trajectories. - descs_2: batch of descriptors, representing the sampled descriptors. + descs_1: batch of descriptors. + descs_2: batch of descriptors. Returns: batch of similarity measures. 
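
The similarity above is exp(-||d1 - d2|| / lengthscale); it is used to rescale rewards so that transitions whose observed descriptor is far from the sampled target descriptor contribute little to the descriptor-conditioned critic. A small numeric sketch (descriptor values are arbitrary):

```python
import jax.numpy as jnp

lengthscale = 0.1

def similarity(descs_1, descs_2):
    return jnp.exp(-jnp.linalg.norm(descs_1 - descs_2, axis=-1) / lengthscale)

observed = jnp.array([[0.2, 0.2], [0.8, 0.2]])  # descriptors reached by rollouts
target = jnp.array([[0.2, 0.2], [0.2, 0.2]])    # descriptors the policy was asked to reach
rewards = jnp.array([1.0, 1.0])

weighted = similarity(observed, target) * rewards
# First reward is kept (similarity = 1); the second is scaled by exp(-6) ~ 0.0025.
```
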
""" - return jnp.exp(-jnp.linalg.norm(descs_1 - descs_2, axis=-1)/self._config.lengthscale) + return jnp.exp(-jnp.linalg.norm( + descs_1 - descs_2, axis=-1)/self._config.lengthscale) @partial(jax.jit, static_argnames=("self",)) - def _normalize_desc(self, desc): - return 2*(desc - self._env.behavior_descriptor_limits[0])/(self._env.behavior_descriptor_limits[1] - self._env.behavior_descriptor_limits[0]) - 1 + def _normalize_desc(self, desc: Descriptor) -> Descriptor: + return 2*(desc - self._env.behavior_descriptor_limits[0])/( + self._env.behavior_descriptor_limits[1] - + self._env.behavior_descriptor_limits[0]) - 1 @partial(jax.jit, static_argnames=("self",)) - def _unnormalize_desc(self, desc_normalized): - return 0.5 * (self._env.behavior_descriptor_limits[1] - self._env.behavior_descriptor_limits[0]) * desc_normalized + \ - 0.5 * (self._env.behavior_descriptor_limits[1] + self._env.behavior_descriptor_limits[0]) + def _unnormalize_desc(self, desc_normalized: Descriptor) -> Descriptor: + return 0.5 * (self._env.behavior_descriptor_limits[1] - + self._env.behavior_descriptor_limits[0]) * desc_normalized + \ + 0.5 * (self._env.behavior_descriptor_limits[1] + + self._env.behavior_descriptor_limits[0]) @partial(jax.jit, static_argnames=("self",)) - def _compute_equivalent_kernel_bias_with_desc(self, actor_dc_params, desc): + def _compute_equivalent_kernel_bias_with_desc( + self, + actor_dc_params: Params, + desc: Descriptor + ) -> Tuple[Params, Params]: """ Compute the equivalent bias of the first layer of the actor network given a descriptor. @@ -238,9 +254,17 @@ def _compute_equivalent_kernel_bias_with_desc(self, actor_dc_params, desc): return equivalent_kernel, equivalent_bias @partial(jax.jit, static_argnames=("self",)) - def _compute_equivalent_params_with_desc(self, actor_dc_params, desc): + def _compute_equivalent_params_with_desc( + self, + actor_dc_params: Params, + desc: Descriptor + ) -> Params: desc_normalized = self._normalize_desc(desc) - equivalent_kernel, equivalent_bias = self._compute_equivalent_kernel_bias_with_desc(actor_dc_params, desc_normalized) + ( + equivalent_kernel, + equivalent_bias, + ) = self._compute_equivalent_kernel_bias_with_desc( + actor_dc_params, desc_normalized) actor_dc_params["params"]["Dense_0"]["kernel"] = equivalent_kernel actor_dc_params["params"]["Dense_0"]["bias"] = equivalent_bias return actor_dc_params @@ -251,7 +275,7 @@ def emit( repertoire: Repertoire, emitter_state: QualityDCGEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Do a step of PG emission. Args: @@ -263,22 +287,31 @@ def emit( A batch of offspring, the new emitter state and a new key. 
""" # PG emitter - parents_pg, descs_pg, random_key = repertoire.sample_with_descs(random_key, self._config.qpg_batch_size) + parents_pg, descs_pg, random_key = repertoire.sample_with_descs( + random_key, self._config.qpg_batch_size) genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg) # Actor injection emitter - _, descs_ai, random_key = repertoire.sample_with_descs(random_key, self._config.ai_batch_size) - descs_ai = descs_ai.reshape(descs_ai.shape[0], self._env.behavior_descriptor_length) + _, descs_ai, random_key = repertoire.sample_with_descs( + random_key, self._config.ai_batch_size) + descs_ai = descs_ai.reshape( + descs_ai.shape[0], self._env.behavior_descriptor_length) genotypes_ai = self.emit_ai(emitter_state, descs_ai) # Concatenate PG and AI genotypes - genotypes = jax.tree_util.tree_map(lambda x1, x2: jnp.concatenate((x1, x2), axis=0), genotypes_pg, genotypes_ai) + genotypes = jax.tree_util.tree_map(lambda x1, x2: jnp.concatenate( + (x1, x2), axis=0), genotypes_pg, genotypes_ai) - return genotypes, {"desc_prime": jnp.concatenate([descs_pg, descs_ai], axis=0)}, random_key + return genotypes, { + "desc_prime": jnp.concatenate([descs_pg, descs_ai], axis=0)}, random_key @partial(jax.jit, static_argnames=("self",),) def emit_pg( - self, emitter_state: QualityDCGEmitterState, parents: Genotype, descs: Descriptor) -> Genotype: + self, + emitter_state: QualityDCGEmitterState, + parents: Genotype, + descs: Descriptor + ) -> Genotype: """Emit the offsprings generated through pg mutation. Args: @@ -315,7 +348,8 @@ def emit_ai( Returns: A new set of offsprings. """ - offsprings = jax.vmap(self._compute_equivalent_params_with_desc, in_axes=(None, 0))(emitter_state.actor_params, descs) + offsprings = jax.vmap(self._compute_equivalent_params_with_desc, in_axes=( + None, 0))(emitter_state.actor_params, descs) return offsprings @@ -369,13 +403,20 @@ def state_update( transitions = extra_scores["transitions"] episode_length = transitions.obs.shape[1] - desc_prime = jnp.concatenate([extra_scores["desc_prime"], descriptors[self._config.qpg_batch_size+self._config.ai_batch_size:]], axis=0) - desc_prime = jnp.repeat(desc_prime[:, jnp.newaxis, :], episode_length, axis=1) - desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) - - desc_prime_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc_prime) + desc_prime = jnp.concatenate( + [extra_scores["desc_prime"], + descriptors[self._config.qpg_batch_size+self._config.ai_batch_size:]], + axis=0) + desc_prime = jnp.repeat( + desc_prime[:, jnp.newaxis, :], episode_length, axis=1) + desc = jnp.repeat( + descriptors[:, jnp.newaxis, :], episode_length, axis=1) + + desc_prime_normalized = jax.vmap( + jax.vmap(self._normalize_desc))(desc_prime) desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) - transitions = transitions.replace(desc=desc_normalized, desc_prime=desc_prime_normalized) + transitions = transitions.replace( + desc=desc_normalized, desc_prime=desc_prime_normalized) # Add transitions to replay buffer replay_buffer = emitter_state.replay_buffer.insert(transitions) @@ -383,13 +424,19 @@ def state_update( # sample transitions from the replay buffer random_key, subkey = jax.random.split(emitter_state.random_key) - transitions, random_key = replay_buffer.sample(subkey, self._config.num_critic_training_steps*self._config.batch_size) - transitions = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (self._config.num_critic_training_steps, self._config.batch_size, *x.shape[1:])), transitions) - transitions 
= transitions.replace(rewards=self._similarity(transitions.desc, transitions.desc_prime)*transitions.rewards) + transitions, random_key = replay_buffer.sample( + subkey, self._config.num_critic_training_steps*self._config.batch_size) + transitions = jax.tree_util.tree_map( + lambda x: jnp.reshape(x, + (self._config.num_critic_training_steps, + self._config.batch_size, *x.shape[1:])), + transitions) + transitions = transitions.replace(rewards=self._similarity( + transitions.desc, transitions.desc_prime)*transitions.rewards) emitter_state = emitter_state.replace(random_key=random_key) def scan_train_critics( - carry: QualityDCGEmitterState, transitions + carry: QualityDCGEmitterState, transitions: DCGTransition, ) -> Tuple[QualityDCGEmitterState, Any]: emitter_state = carry new_emitter_state = self._train_critics(emitter_state, transitions) @@ -407,7 +454,7 @@ def scan_train_critics( @partial(jax.jit, static_argnames=("self",)) def _train_critics( - self, emitter_state: QualityDCGEmitterState, transitions + self, emitter_state: QualityDCGEmitterState, transitions: DCGTransition ) -> QualityDCGEmitterState: """Apply one gradient step to critics and to the greedy actor (contained in carry in training_state), then soft update target critics @@ -563,11 +610,23 @@ def _mutation_function_pg( The updated params of the neural network. """ # Get transitions - transitions, random_key = emitter_state.replay_buffer.sample(emitter_state.random_key, sample_size=self._config.num_pg_training_steps*self._config.batch_size) - descs_prime = jnp.tile(descs, (self._config.num_pg_training_steps*self._config.batch_size, 1)) + transitions, random_key = emitter_state.replay_buffer.sample( + emitter_state.random_key, + sample_size=self._config.num_pg_training_steps*self._config.batch_size) + descs_prime = jnp.tile( + descs, (self._config.num_pg_training_steps*self._config.batch_size, 1)) descs_prime_normalized = jax.vmap(self._normalize_desc)(descs_prime) - transitions = transitions.replace(rewards=self._similarity(transitions.desc, descs_prime_normalized)*transitions.rewards, desc_prime=descs_prime_normalized) - transitions = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (self._config.num_pg_training_steps, self._config.batch_size, *x.shape[1:])), transitions) + transitions = transitions.replace( + rewards=self._similarity( + transitions.desc, + descs_prime_normalized)*transitions.rewards, + desc_prime=descs_prime_normalized) + transitions = jax.tree_util.tree_map( + lambda x: jnp.reshape( + x, + (self._config.num_pg_training_steps, + self._config.batch_size, *x.shape[1:])), + transitions) # Replace random_key emitter_state = emitter_state.replace(random_key=random_key) @@ -577,7 +636,7 @@ def _mutation_function_pg( def scan_train_policy( carry: Tuple[QualityDCGEmitterState, Genotype, optax.OptState], - transitions, + transitions: DCGTransition, ) -> Tuple[Tuple[QualityDCGEmitterState, Genotype, optax.OptState], Any]: emitter_state, policy_params, policy_opt_state = carry ( @@ -611,7 +670,7 @@ def _train_policy( emitter_state: QualityDCGEmitterState, policy_params: Params, policy_opt_state: optax.OptState, - transitions, + transitions: DCGTransition, ) -> Tuple[QualityDCGEmitterState, Params, optax.OptState]: """Apply one gradient step to a policy (called policy_params). 
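Note on the reward reweighting used in state_update and _mutation_function_pg above: rewards are scaled by a similarity kernel exp(-||desc - desc_prime|| / lengthscale) between the observed descriptor and the commanded (prime) descriptor, so transitions that land far from the commanded descriptor contribute little to the critic and policy-gradient updates. A minimal standalone sketch, with toy shapes and an illustrative lengthscale (both are assumptions, not values taken from this patch):

    import jax.numpy as jnp

    def similarity(desc_1: jnp.ndarray, desc_2: jnp.ndarray, lengthscale: float) -> jnp.ndarray:
        # exp(-||d1 - d2|| / lengthscale), one value per transition
        return jnp.exp(-jnp.linalg.norm(desc_1 - desc_2, axis=-1) / lengthscale)

    rewards = jnp.ones(4)
    desc = jnp.array([[0.0, 0.0], [0.1, 0.0], [0.5, 0.5], [1.0, 1.0]])  # observed
    desc_prime = jnp.zeros_like(desc)                                   # commanded
    scaled_rewards = similarity(desc, desc_prime, lengthscale=0.1) * rewards
    # close to the commanded descriptor -> reward kept, far away -> reward ~ 0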
diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index a0b5c62d..4a173b51 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -200,7 +200,7 @@ def emit( repertoire: Repertoire, emitter_state: QualityPGEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Do a step of PG emission. Args: diff --git a/qdax/core/emitters/standard_emitters.py b/qdax/core/emitters/standard_emitters.py index 740aafa5..860962d4 100644 --- a/qdax/core/emitters/standard_emitters.py +++ b/qdax/core/emitters/standard_emitters.py @@ -6,7 +6,7 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Genotype, RNGKey +from qdax.types import ExtraScores, Genotype, RNGKey class MixingEmitter(Emitter): @@ -31,7 +31,7 @@ def emit( repertoire: Repertoire, emitter_state: Optional[EmitterState], random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emitter that performs both mutation and variation. Two batches of variation_percentage * batch_size genotypes are sampled in the repertoire, diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index 77dd437e..a12fe2a0 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -87,7 +87,7 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - random_key, + random_key=random_key, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, diff --git a/qdax/core/mels.py b/qdax/core/mels.py index 3969e6d0..6dc8f551 100644 --- a/qdax/core/mels.py +++ b/qdax/core/mels.py @@ -89,7 +89,12 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - genotypes=genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, ) # update emitter state diff --git a/qdax/core/mome.py b/qdax/core/mome.py index 98693e5a..db450b9a 100644 --- a/qdax/core/mome.py +++ b/qdax/core/mome.py @@ -60,12 +60,7 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - genotypes=genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key=random_key, repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, diff --git a/qdax/core/neuroevolution/buffers/buffer.py b/qdax/core/neuroevolution/buffers/buffer.py index 3fe777d4..a80fe739 100644 --- a/qdax/core/neuroevolution/buffers/buffer.py +++ b/qdax/core/neuroevolution/buffers/buffer.py @@ -7,7 +7,15 @@ import jax import jax.numpy as jnp -from qdax.types import Action, Done, Observation, Reward, RNGKey, StateDescriptor, Descriptor +from qdax.types import ( + Action, + Descriptor, + Done, + Observation, + Reward, + RNGKey, + StateDescriptor, +) class Transition(flax.struct.PyTreeNode): @@ -348,7 +356,8 @@ def from_flatten( ] state_desc = flattened_transition[ :, - (2 * obs_dim + 3 + action_dim) : (2 * obs_dim + 3 + action_dim + state_desc_dim), + (2 * obs_dim + 3 + action_dim) : ( + 2 * obs_dim + 3 + action_dim + state_desc_dim), ] next_state_desc = flattened_transition[ :, @@ -383,7 +392,11 @@ def from_flatten( @classmethod def init_dummy( # type: ignore - cls, observation_dim: int, action_dim: int, descriptor_dim: int) -> QDTransition: + cls, + 
observation_dim: int, + action_dim: int, + descriptor_dim: int + ) -> QDTransition: """ Initialize a dummy transition that then can be passed to constructors to get all shapes right. diff --git a/qdax/core/neuroevolution/losses/td3_loss.py b/qdax/core/neuroevolution/losses/td3_loss.py index b360267c..9866a9cd 100644 --- a/qdax/core/neuroevolution/losses/td3_loss.py +++ b/qdax/core/neuroevolution/losses/td3_loss.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Action, Observation, Descriptor, Params, RNGKey +from qdax.types import Action, Descriptor, Observation, Params, RNGKey def make_td3_loss_fn( @@ -103,6 +103,7 @@ def make_td3_loss_dc_fn( noise_clip: float, policy_noise: float, ) -> Tuple[ + Callable[[Params, Params, Transition], jnp.ndarray], Callable[[Params, Params, Transition], jnp.ndarray], Callable[[Params, Params, Params, Transition, RNGKey], jnp.ndarray], ]: @@ -128,8 +129,12 @@ def _policy_loss_fn( transitions: Transition, ) -> jnp.ndarray: """Policy loss function for TD3 agent""" - action = policy_fn(policy_params, obs=transitions.obs) - q_value = critic_fn(critic_params, obs=transitions.obs, actions=action, desc=transitions.desc_prime) + action = policy_fn(policy_params, transitions.obs) + q_value = critic_fn( + critic_params, + transitions.obs, + action, + transitions.desc_prime) q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) policy_loss = -jnp.mean(q1_action) return policy_loss @@ -141,8 +146,13 @@ def _actor_loss_fn( transitions: Transition, ) -> jnp.ndarray: """Descriptor-conditioned policy loss function for TD3 agent""" - action = actor_fn(actor_params, obs=transitions.obs, desc=transitions.desc_prime) - q_value = critic_fn(critic_params, obs=transitions.obs, actions=action, desc=transitions.desc_prime) + action = actor_fn(actor_params, transitions.obs, transitions.desc_prime) + q_value = critic_fn( + critic_params, + transitions.obs, + action, + transitions.desc_prime + ) q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) policy_loss = -jnp.mean(q1_action) return policy_loss @@ -162,15 +172,18 @@ def _critic_loss_fn( ).clip(-noise_clip, noise_clip) next_action = ( - actor_fn(target_actor_params, obs=transitions.next_obs, desc=transitions.desc_prime) + noise + actor_fn(target_actor_params, transitions.next_obs, + transitions.desc_prime) + noise ).clip(-1.0, 1.0) - next_q = critic_fn(target_critic_params, obs=transitions.next_obs, actions=next_action, desc=transitions.desc_prime) + next_q = critic_fn(target_critic_params, transitions.next_obs, + next_action, transitions.desc_prime) next_v = jnp.min(next_q, axis=-1) target_q = jax.lax.stop_gradient( transitions.rewards * reward_scaling + (1.0 - transitions.dones) * discount * next_v ) - q_old_action = critic_fn(critic_params, obs=transitions.obs, actions=transitions.actions, desc=transitions.desc_prime) + q_old_action = critic_fn(critic_params, transitions.obs, + transitions.actions, transitions.desc_prime) q_error = q_old_action - jnp.expand_dims(target_q, -1) # Better bootstrapping for truncated episodes. 
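The descriptor-conditioned critic loss above keeps the standard TD3 clipped double-Q target; the only change is that the actor and critics also receive desc_prime. A self-contained numerical sketch of the target computation (toy values standing in for a sampled batch, no networks involved):

    import jax
    import jax.numpy as jnp

    rewards = jnp.array([1.0, 0.5])
    dones = jnp.array([0.0, 1.0])
    next_q = jnp.array([[2.0, 1.8], [0.7, 0.9]])  # two critic heads per transition
    reward_scaling, discount = 1.0, 0.99

    next_v = jnp.min(next_q, axis=-1)  # pessimistic estimate over the critic heads
    target_q = jax.lax.stop_gradient(
        rewards * reward_scaling + (1.0 - dones) * discount * next_v
    )
    # target_q == [1.0 + 0.99 * 1.8, 0.5]; the terminal transition does not bootstrap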
@@ -207,7 +220,7 @@ def td3_policy_loss_fn( action = policy_fn(policy_params, transitions.obs) q_value = critic_fn( - critic_params, obs=transitions.obs, actions=action # type: ignore + critic_params, transitions.obs, action # type: ignore ) q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) policy_loss = -jnp.mean(q1_action) diff --git a/qdax/core/neuroevolution/mdp_utils.py b/qdax/core/neuroevolution/mdp_utils.py index fe936a57..3b077069 100644 --- a/qdax/core/neuroevolution/mdp_utils.py +++ b/qdax/core/neuroevolution/mdp_utils.py @@ -9,7 +9,7 @@ from flax.struct import PyTreeNode from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Genotype, Params, RNGKey, Descriptor +from qdax.types import Descriptor, Genotype, Params, RNGKey class TrainingState(PyTreeNode): @@ -85,8 +85,8 @@ def generate_unroll_actor_dc( ], ], ) -> Tuple[EnvState, Transition]: - """Generates an episode according to the agent's policy and descriptor, returns the final state of - the episode and the transitions of the episode. + """Generates an episode according to the agent's policy and descriptor, + returns the final state of the episode and the transitions of the episode. Args: init_state: first state of the rollout. @@ -103,7 +103,13 @@ def generate_unroll_actor_dc( def _scan_play_step_fn( carry: Tuple[EnvState, Params, Descriptor, RNGKey], unused_arg: Any ) -> Tuple[Tuple[EnvState, Params, Descriptor, RNGKey], Transition]: - env_state, actor_dc_params, desc, random_key, transitions = play_step_actor_dc_fn(*carry) + ( + env_state, + actor_dc_params, + desc, + random_key, + transitions, + ) = play_step_actor_dc_fn(*carry) return (env_state, actor_dc_params, desc, random_key), transitions (state, _, _, _), transitions = jax.lax.scan( diff --git a/qdax/core/neuroevolution/networks/networks.py b/qdax/core/neuroevolution/networks/networks.py index d4b2ab3a..fea7c1ac 100644 --- a/qdax/core/neuroevolution/networks/networks.py +++ b/qdax/core/neuroevolution/networks/networks.py @@ -5,7 +5,6 @@ import flax.linen as nn import jax import jax.numpy as jnp -from brax.training import networks class MLP(nn.Module): @@ -47,6 +46,7 @@ def __call__(self, obs: jnp.ndarray) -> jnp.ndarray: return hidden + class MLPDC(nn.Module): """Descriptor-conditioned MLP module.""" layer_sizes: Tuple[int, ...] 
@@ -106,6 +106,7 @@ def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: res.append(q) return jnp.concatenate(res, axis=-1) + class QModuleDC(nn.Module): """Q Module.""" @@ -113,7 +114,12 @@ class QModuleDC(nn.Module): n_critics: int = 2 @nn.compact - def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray, desc: jnp.ndarray) -> jnp.ndarray: + def __call__( + self, + obs: jnp.ndarray, + actions: jnp.ndarray, + desc: jnp.ndarray + ) -> jnp.ndarray: hidden = jnp.concatenate([obs, actions], axis=-1) res = [] for _ in range(self.n_critics): diff --git a/qdax/environments/base_wrappers.py b/qdax/environments/base_wrappers.py index 6f317e7f..3f709fa7 100644 --- a/qdax/environments/base_wrappers.py +++ b/qdax/environments/base_wrappers.py @@ -1,6 +1,7 @@ from abc import abstractmethod -from typing import Any, List, Tuple +from typing import Any, Tuple +import jax from brax.v1 import jumpy as jp from brax.v1.envs import Env, State @@ -22,7 +23,7 @@ def state_descriptor_name(self) -> str: @property @abstractmethod - def state_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def state_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: pass @property @@ -32,7 +33,7 @@ def behavior_descriptor_length(self) -> int: @property @abstractmethod - def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def behavior_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: pass @property @@ -71,7 +72,7 @@ def state_descriptor_name(self) -> str: return self.env.state_descriptor_name @property - def state_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def state_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: return self.env.state_descriptor_limits @property @@ -79,7 +80,7 @@ def behavior_descriptor_length(self) -> int: return self.env.behavior_descriptor_length @property - def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def behavior_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: return self.env.behavior_descriptor_limits @property diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py index ed1b4387..1aa7bcab 100644 --- a/qdax/environments/wrappers.py +++ b/qdax/environments/wrappers.py @@ -3,7 +3,7 @@ import flax.struct import jax from brax.v1 import jumpy as jp -from brax.v1.envs import State, Wrapper, Env +from brax.v1.envs import Env, State, Wrapper class CompletedEvalMetrics(flax.struct.PyTreeNode): @@ -70,6 +70,7 @@ def step(self, state: State, action: jp.ndarray) -> State: nstate.info[self.STATE_INFO_KEY] = eval_metrics return nstate + class ClipRewardWrapper(Wrapper): """Wraps gym environments to clip the reward to be greater than 0. @@ -78,18 +79,26 @@ class ClipRewardWrapper(Wrapper): work like before and will simply clip the reward to be greater than 0. 
""" - def __init__(self, env: Env, clip_min=None, clip_max=None) -> None: + def __init__( + self, + env: Env, + clip_min: float = None, + clip_max: float = None + ) -> None: super().__init__(env) self._clip_min = clip_min self._clip_max = clip_max def reset(self, rng: jp.ndarray) -> State: state = self.env.reset(rng) - return state.replace(reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) def step(self, state: State, action: jp.ndarray) -> State: state = self.env.step(state, action) - return state.replace(reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + class AffineRewardWrapper(Wrapper): """Wraps gym environments to clip the reward. @@ -99,18 +108,25 @@ class AffineRewardWrapper(Wrapper): work like before and will simply clip the reward to be greater than 0. """ - def __init__(self, env: Env, clip_min=None, clip_max=None) -> None: + def __init__( + self, env: Env, + clip_min: float = None, + clip_max: float = None + ) -> None: super().__init__(env) self._clip_min = clip_min self._clip_max = clip_max def reset(self, rng: jp.ndarray) -> State: state = self.env.reset(rng) - return state.replace(reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) def step(self, state: State, action: jp.ndarray) -> State: state = self.env.step(state, action) - return state.replace(reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + class OffsetRewardWrapper(Wrapper): """Wraps gym environments to offset the reward to be greater than 0. @@ -120,7 +136,7 @@ class OffsetRewardWrapper(Wrapper): work like before and will simply clip the reward to be greater than 0. """ - def __init__(self, env: Env, offset=0.) -> None: + def __init__(self, env: Env, offset: float = 0.) -> None: super().__init__(env) self._offset = offset diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 721567a1..5d7eda5a 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -175,7 +175,8 @@ def scoring_actor_dc_function_brax_envs( init_states: EnvState, episode_length: int, play_step_actor_dc_fn: Callable[ - [EnvState, Descriptor, Params, RNGKey], Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition] + [EnvState, Descriptor, Params, RNGKey], + Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition] ], behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: @@ -188,8 +189,10 @@ def scoring_actor_dc_function_brax_envs( When the init states are different, this is not purely stochastic. Args: - policy_dc_params: The parameters of closed-loop descriptor-conditioned policy to evaluate. - descriptors: The descriptors the descriptor-conditioned policy attempts to achieve. + policy_dc_params: The parameters of closed-loop + descriptor-conditioned policy to evaluate. + descriptors: The descriptors the + descriptor-conditioned policy attempts to achieve. random_key: A jax random key episode_length: The maximal rollout length. play_step_fn: The function to play a step of the environment. 
@@ -313,7 +316,8 @@ def reset_based_scoring_actor_dc_function_brax_envs( episode_length: int, play_reset_fn: Callable[[RNGKey], EnvState], play_step_actor_dc_fn: Callable[ - [EnvState, Descriptor, Params, RNGKey], Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition] + [EnvState, Descriptor, Params, RNGKey], + Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition] ], behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: @@ -330,11 +334,14 @@ def reset_based_scoring_actor_dc_function_brax_envs( "play_reset_fn = lambda random_key: init_state". Args: - policy_dc_params: The parameters of closed-loop descriptor-conditioned policy to evaluate. - descriptors: The descriptors the descriptor-conditioned policy attempts to achieve. + policy_dc_params: The parameters of closed-loop + descriptor-conditioned policy to evaluate. + descriptors: The descriptors the + descriptor-conditioned policy attempts to achieve. random_key: A jax random key episode_length: The maximal rollout length. - play_reset_fn: The function to reset the environment and obtain initial states. + play_reset_fn: The function to reset the environment + and obtain initial states. play_step_fn: The function to play a step of the environment. behavior_descriptor_extractor: The function to extract the behavior descriptor. @@ -346,11 +353,17 @@ def reset_based_scoring_actor_dc_function_brax_envs( """ random_key, subkey = jax.random.split(random_key) - keys = jax.random.split(subkey, jax.tree_util.tree_leaves(actors_dc_params)[0].shape[0]) + keys = jax.random.split(subkey, + jax.tree_util.tree_leaves(actors_dc_params)[0].shape[0]) reset_fn = jax.vmap(play_reset_fn) init_states = reset_fn(keys) - fitnesses, descriptors, extra_scores, random_key = scoring_actor_dc_function_brax_envs( + ( + fitnesses, + descriptors, + extra_scores, + random_key, + ) = scoring_actor_dc_function_brax_envs( actors_dc_params=actors_dc_params, descs=descs, random_key=random_key, From dd72df731230f570035df5efc23489a312069bad Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Mon, 8 Jan 2024 16:56:11 +0000 Subject: [PATCH 07/11] Solve pre-commit errors --- qdax/baselines/genetic_algorithm.py | 2 +- qdax/core/distributed_map_elites.py | 2 +- qdax/core/emitters/multi_emitter.py | 2 +- qdax/core/emitters/qdcg_emitter.py | 2 +- qdax/core/map_elites.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/qdax/baselines/genetic_algorithm.py b/qdax/baselines/genetic_algorithm.py index c8891f11..a01a13b1 100644 --- a/qdax/baselines/genetic_algorithm.py +++ b/qdax/baselines/genetic_algorithm.py @@ -122,7 +122,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=None, - extra_scores=extra_scores | extra_info, + extra_scores={**extra_scores, **extra_info}, ) # update the metrics diff --git a/qdax/core/distributed_map_elites.py b/qdax/core/distributed_map_elites.py index b6d116a4..7b5609f2 100644 --- a/qdax/core/distributed_map_elites.py +++ b/qdax/core/distributed_map_elites.py @@ -143,7 +143,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores | extra_info, + extra_scores={**extra_scores, **extra_info}, ) # update the metrics diff --git a/qdax/core/emitters/multi_emitter.py b/qdax/core/emitters/multi_emitter.py index e59e2037..d0142d79 100644 --- a/qdax/core/emitters/multi_emitter.py +++ b/qdax/core/emitters/multi_emitter.py @@ -133,7 +133,7 @@ def emit( batch_size = 
jax.tree_util.tree_leaves(genotype)[0].shape[0] assert batch_size == emitter.batch_size all_offsprings.append(genotype) - all_extra_info = all_extra_info | extra_info + all_extra_info = {**all_extra_info, **extra_info} # concatenate offsprings together offsprings = jax.tree_util.tree_map( diff --git a/qdax/core/emitters/qdcg_emitter.py b/qdax/core/emitters/qdcg_emitter.py index b2b782a0..9677cb2b 100644 --- a/qdax/core/emitters/qdcg_emitter.py +++ b/qdax/core/emitters/qdcg_emitter.py @@ -450,7 +450,7 @@ def scan_train_critics( length=self._config.num_critic_training_steps, ) - return emitter_state + return emitter_state # type: ignore @partial(jax.jit, static_argnames=("self",)) def _train_critics( diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index a12fe2a0..8b649d0c 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -143,7 +143,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores | extra_info, + extra_scores={**extra_scores, **extra_info}, ) # update the metrics From ec907338b24ba4f21c27aec22ae8cf7ec83cc536 Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Mon, 8 Jan 2024 17:33:12 +0000 Subject: [PATCH 08/11] Fix tests --- qdax/core/emitters/cma_pool_emitter.py | 13 ++++++++++--- tests/baselines_test/me_pbt_sac_test.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py index d5af2181..24556f8b 100644 --- a/qdax/core/emitters/cma_pool_emitter.py +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -73,7 +73,14 @@ def scan_emitter_init( carry: RNGKey, unused: Any ) -> Tuple[RNGKey, CMAEmitterState]: random_key = carry - emitter_state, random_key = self._emitter.init(genotypes, random_key) + emitter_state, random_key = self._emitter.init( + random_key, + repertoire, + genotypes, + fitnesses, + descriptors, + extra_scores, + ) return random_key, emitter_state # init all the emitter states @@ -117,11 +124,11 @@ def emit( ) # use it to emit offsprings - offsprings, random_key = self._emitter.emit( + offsprings, extra_info, random_key = self._emitter.emit( repertoire, used_emitter_state, random_key ) - return offsprings, {}, random_key + return offsprings, extra_info, random_key @partial( jax.jit, diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index c4ab259e..98a5b960 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -126,7 +126,7 @@ def scoring_function(genotypes, random_key): # type: ignore lambda x: jnp.repeat(x, population_size, axis=0), first_states ) population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) - return population_returns, population_bds, None, random_key + return population_returns, population_bds, {}, random_key # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] From a25bedbabcfef310999c2494a161c82a352613ab Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Mon, 8 Jan 2024 18:23:11 +0000 Subject: [PATCH 09/11] Fix scoring_fn that were outputing extra_scores of the wrong type --- qdax/core/emitters/cma_rnd_emitter.py | 8 +++++++- tests/baselines_test/me_pbt_td3_test.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/qdax/core/emitters/cma_rnd_emitter.py b/qdax/core/emitters/cma_rnd_emitter.py index 2b9928a0..e05cc453 100644 --- a/qdax/core/emitters/cma_rnd_emitter.py +++ 
b/qdax/core/emitters/cma_rnd_emitter.py @@ -35,7 +35,13 @@ class CMARndEmitterState(CMAEmitterState): class CMARndEmitter(CMAEmitter): @partial(jax.jit, static_argnames=("self",)) def init( - self, genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMARndEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index 510743c1..fc2e89b0 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -124,7 +124,7 @@ def scoring_function(genotypes, random_key): # type: ignore lambda x: jnp.repeat(x, population_size, axis=0), first_states ) population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) - return population_returns, population_bds, None, random_key + return population_returns, population_bds, {}, random_key # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] From 84a8f9758a586444ef20dc6a97d6ea0cda2f27a7 Mon Sep 17 00:00:00 2001 From: maxence Date: Mon, 8 Jan 2024 21:11:53 +0000 Subject: [PATCH 10/11] 9 files reformatted by black hook --- qdax/core/emitters/dcg_me_emitter.py | 2 +- qdax/core/emitters/dpg_emitter.py | 3 +- qdax/core/emitters/multi_emitter.py | 8 +- qdax/core/emitters/qdcg_emitter.py | 184 +++++++++++------- qdax/core/neuroevolution/buffers/buffer.py | 8 +- qdax/core/neuroevolution/losses/td3_loss.py | 32 ++- qdax/core/neuroevolution/networks/networks.py | 7 +- qdax/environments/wrappers.py | 23 ++- qdax/tasks/brax_envs.py | 9 +- 9 files changed, 159 insertions(+), 117 deletions(-) diff --git a/qdax/core/emitters/dcg_me_emitter.py b/qdax/core/emitters/dcg_me_emitter.py index b9ae628b..94e0bb9d 100644 --- a/qdax/core/emitters/dcg_me_emitter.py +++ b/qdax/core/emitters/dcg_me_emitter.py @@ -74,7 +74,7 @@ def __init__( config=qdcg_config, policy_network=policy_network, actor_network=actor_network, - env=env + env=env, ) # define the GA emitter diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index 8bc223df..2c55cbd2 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -102,7 +102,8 @@ def init( genotypes, fitnesses, descriptors, - extra_scores,) + extra_scores, + ) # store elements in a dictionary attributes_dict = vars(diversity_emitter_state) diff --git a/qdax/core/emitters/multi_emitter.py b/qdax/core/emitters/multi_emitter.py index d0142d79..b3ad23c6 100644 --- a/qdax/core/emitters/multi_emitter.py +++ b/qdax/core/emitters/multi_emitter.py @@ -88,7 +88,8 @@ def init( genotypes, fitnesses, descriptors, - extra_scores) + extra_scores, + ) emitter_states.append(emitter_state) return MultiEmitterState(tuple(emitter_states)), random_key @@ -127,9 +128,8 @@ def emit( subkeys, ): genotype, extra_info, _ = emitter.emit( - repertoire, - sub_emitter_state, - subkey_emitter) + repertoire, sub_emitter_state, subkey_emitter + ) batch_size = jax.tree_util.tree_leaves(genotype)[0].shape[0] assert batch_size == emitter.batch_size all_offsprings.append(genotype) diff --git a/qdax/core/emitters/qdcg_emitter.py b/qdax/core/emitters/qdcg_emitter.py index 9677cb2b..0d560cbb 100644 --- a/qdax/core/emitters/qdcg_emitter.py +++ b/qdax/core/emitters/qdcg_emitter.py @@ -157,12 +157,10 @@ def init( critic_params = self._critic_network.init( 
subkey, obs=fake_obs, actions=fake_action, desc=fake_desc ) - target_critic_params = jax.tree_util.tree_map( - lambda x: x, critic_params) + target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) random_key, subkey = jax.random.split(random_key) - actor_params = self._actor_network.init( - subkey, obs=fake_obs, desc=fake_desc) + actor_params = self._actor_network.init(subkey, obs=fake_obs, desc=fake_desc) target_actor_params = jax.tree_util.tree_map(lambda x: x, actor_params) # Prepare init optimizer states @@ -184,12 +182,12 @@ def init( transitions = extra_scores["transitions"] episode_length = transitions.obs.shape[1] - desc = jnp.repeat( - descriptors[:, jnp.newaxis, :], episode_length, axis=1) + desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) transitions = transitions.replace( - desc=desc_normalized, desc_prime=desc_normalized) + desc=desc_normalized, desc_prime=desc_normalized + ) replay_buffer = replay_buffer.insert(transitions) # Initial training state @@ -217,27 +215,35 @@ def _similarity(self, descs_1: Descriptor, descs_2: Descriptor) -> jnp.array: Returns: batch of similarity measures. """ - return jnp.exp(-jnp.linalg.norm( - descs_1 - descs_2, axis=-1)/self._config.lengthscale) + return jnp.exp( + -jnp.linalg.norm(descs_1 - descs_2, axis=-1) / self._config.lengthscale + ) @partial(jax.jit, static_argnames=("self",)) def _normalize_desc(self, desc: Descriptor) -> Descriptor: - return 2*(desc - self._env.behavior_descriptor_limits[0])/( - self._env.behavior_descriptor_limits[1] - - self._env.behavior_descriptor_limits[0]) - 1 + return ( + 2 + * (desc - self._env.behavior_descriptor_limits[0]) + / ( + self._env.behavior_descriptor_limits[1] + - self._env.behavior_descriptor_limits[0] + ) + - 1 + ) @partial(jax.jit, static_argnames=("self",)) def _unnormalize_desc(self, desc_normalized: Descriptor) -> Descriptor: - return 0.5 * (self._env.behavior_descriptor_limits[1] - - self._env.behavior_descriptor_limits[0]) * desc_normalized + \ - 0.5 * (self._env.behavior_descriptor_limits[1] + - self._env.behavior_descriptor_limits[0]) + return 0.5 * ( + self._env.behavior_descriptor_limits[1] + - self._env.behavior_descriptor_limits[0] + ) * desc_normalized + 0.5 * ( + self._env.behavior_descriptor_limits[1] + + self._env.behavior_descriptor_limits[0] + ) @partial(jax.jit, static_argnames=("self",)) def _compute_equivalent_kernel_bias_with_desc( - self, - actor_dc_params: Params, - desc: Descriptor + self, actor_dc_params: Params, desc: Descriptor ) -> Tuple[Params, Params]: """ Compute the equivalent bias of the first layer of the actor network @@ -248,28 +254,30 @@ def _compute_equivalent_kernel_bias_with_desc( bias = actor_dc_params["params"]["Dense_0"]["bias"] # Compute the equivalent bias - equivalent_kernel = kernel[:-desc.shape[0], :] - equivalent_bias = bias + jnp.dot(desc, kernel[-desc.shape[0]:]) + equivalent_kernel = kernel[: -desc.shape[0], :] + equivalent_bias = bias + jnp.dot(desc, kernel[-desc.shape[0] :]) return equivalent_kernel, equivalent_bias @partial(jax.jit, static_argnames=("self",)) def _compute_equivalent_params_with_desc( - self, - actor_dc_params: Params, - desc: Descriptor + self, actor_dc_params: Params, desc: Descriptor ) -> Params: desc_normalized = self._normalize_desc(desc) ( equivalent_kernel, equivalent_bias, ) = self._compute_equivalent_kernel_bias_with_desc( - actor_dc_params, desc_normalized) + actor_dc_params, desc_normalized + ) 
actor_dc_params["params"]["Dense_0"]["kernel"] = equivalent_kernel actor_dc_params["params"]["Dense_0"]["bias"] = equivalent_bias return actor_dc_params - @partial(jax.jit, static_argnames=("self",),) + @partial( + jax.jit, + static_argnames=("self",), + ) def emit( self, repertoire: Repertoire, @@ -288,29 +296,39 @@ def emit( """ # PG emitter parents_pg, descs_pg, random_key = repertoire.sample_with_descs( - random_key, self._config.qpg_batch_size) + random_key, self._config.qpg_batch_size + ) genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg) # Actor injection emitter _, descs_ai, random_key = repertoire.sample_with_descs( - random_key, self._config.ai_batch_size) + random_key, self._config.ai_batch_size + ) descs_ai = descs_ai.reshape( - descs_ai.shape[0], self._env.behavior_descriptor_length) + descs_ai.shape[0], self._env.behavior_descriptor_length + ) genotypes_ai = self.emit_ai(emitter_state, descs_ai) # Concatenate PG and AI genotypes - genotypes = jax.tree_util.tree_map(lambda x1, x2: jnp.concatenate( - (x1, x2), axis=0), genotypes_pg, genotypes_ai) + genotypes = jax.tree_util.tree_map( + lambda x1, x2: jnp.concatenate((x1, x2), axis=0), genotypes_pg, genotypes_ai + ) - return genotypes, { - "desc_prime": jnp.concatenate([descs_pg, descs_ai], axis=0)}, random_key + return ( + genotypes, + {"desc_prime": jnp.concatenate([descs_pg, descs_ai], axis=0)}, + random_key, + ) - @partial(jax.jit, static_argnames=("self",),) + @partial( + jax.jit, + static_argnames=("self",), + ) def emit_pg( self, emitter_state: QualityDCGEmitterState, parents: Genotype, - descs: Descriptor + descs: Descriptor, ) -> Genotype: """Emit the offsprings generated through pg mutation. @@ -332,7 +350,10 @@ def emit_pg( return offsprings - @partial(jax.jit, static_argnames=("self",),) + @partial( + jax.jit, + static_argnames=("self",), + ) def emit_ai( self, emitter_state: QualityDCGEmitterState, descs: Descriptor ) -> Genotype: @@ -348,8 +369,9 @@ def emit_ai( Returns: A new set of offsprings. 
""" - offsprings = jax.vmap(self._compute_equivalent_params_with_desc, in_axes=( - None, 0))(emitter_state.actor_params, descs) + offsprings = jax.vmap( + self._compute_equivalent_params_with_desc, in_axes=(None, 0) + )(emitter_state.actor_params, descs) return offsprings @@ -368,7 +390,10 @@ def emit_actor(self, emitter_state: QualityDCGEmitterState) -> Genotype: """ return emitter_state.actor_params - @partial(jax.jit, static_argnames=("self",),) + @partial( + jax.jit, + static_argnames=("self",), + ) def state_update( self, emitter_state: QualityDCGEmitterState, @@ -404,19 +429,20 @@ def state_update( episode_length = transitions.obs.shape[1] desc_prime = jnp.concatenate( - [extra_scores["desc_prime"], - descriptors[self._config.qpg_batch_size+self._config.ai_batch_size:]], - axis=0) - desc_prime = jnp.repeat( - desc_prime[:, jnp.newaxis, :], episode_length, axis=1) - desc = jnp.repeat( - descriptors[:, jnp.newaxis, :], episode_length, axis=1) - - desc_prime_normalized = jax.vmap( - jax.vmap(self._normalize_desc))(desc_prime) + [ + extra_scores["desc_prime"], + descriptors[self._config.qpg_batch_size + self._config.ai_batch_size :], + ], + axis=0, + ) + desc_prime = jnp.repeat(desc_prime[:, jnp.newaxis, :], episode_length, axis=1) + desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) + + desc_prime_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc_prime) desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) transitions = transitions.replace( - desc=desc_normalized, desc_prime=desc_prime_normalized) + desc=desc_normalized, desc_prime=desc_prime_normalized + ) # Add transitions to replay buffer replay_buffer = emitter_state.replay_buffer.insert(transitions) @@ -425,18 +451,28 @@ def state_update( # sample transitions from the replay buffer random_key, subkey = jax.random.split(emitter_state.random_key) transitions, random_key = replay_buffer.sample( - subkey, self._config.num_critic_training_steps*self._config.batch_size) + subkey, self._config.num_critic_training_steps * self._config.batch_size + ) transitions = jax.tree_util.tree_map( - lambda x: jnp.reshape(x, - (self._config.num_critic_training_steps, - self._config.batch_size, *x.shape[1:])), - transitions) - transitions = transitions.replace(rewards=self._similarity( - transitions.desc, transitions.desc_prime)*transitions.rewards) + lambda x: jnp.reshape( + x, + ( + self._config.num_critic_training_steps, + self._config.batch_size, + *x.shape[1:], + ), + ), + transitions, + ) + transitions = transitions.replace( + rewards=self._similarity(transitions.desc, transitions.desc_prime) + * transitions.rewards + ) emitter_state = emitter_state.replace(random_key=random_key) def scan_train_critics( - carry: QualityDCGEmitterState, transitions: DCGTransition, + carry: QualityDCGEmitterState, + transitions: DCGTransition, ) -> Tuple[QualityDCGEmitterState, Any]: emitter_state = carry new_emitter_state = self._train_critics(emitter_state, transitions) @@ -450,7 +486,7 @@ def scan_train_critics( length=self._config.num_critic_training_steps, ) - return emitter_state # type: ignore + return emitter_state # type: ignore @partial(jax.jit, static_argnames=("self",)) def _train_critics( @@ -589,7 +625,10 @@ def _update_actor( target_actor_params, ) - @partial(jax.jit, static_argnames=("self",),) + @partial( + jax.jit, + static_argnames=("self",), + ) def _mutation_function_pg( self, policy_params: Genotype, @@ -612,21 +651,28 @@ def _mutation_function_pg( # Get transitions transitions, random_key = 
emitter_state.replay_buffer.sample( emitter_state.random_key, - sample_size=self._config.num_pg_training_steps*self._config.batch_size) + sample_size=self._config.num_pg_training_steps * self._config.batch_size, + ) descs_prime = jnp.tile( - descs, (self._config.num_pg_training_steps*self._config.batch_size, 1)) + descs, (self._config.num_pg_training_steps * self._config.batch_size, 1) + ) descs_prime_normalized = jax.vmap(self._normalize_desc)(descs_prime) transitions = transitions.replace( - rewards=self._similarity( - transitions.desc, - descs_prime_normalized)*transitions.rewards, - desc_prime=descs_prime_normalized) + rewards=self._similarity(transitions.desc, descs_prime_normalized) + * transitions.rewards, + desc_prime=descs_prime_normalized, + ) transitions = jax.tree_util.tree_map( lambda x: jnp.reshape( x, - (self._config.num_pg_training_steps, - self._config.batch_size, *x.shape[1:])), - transitions) + ( + self._config.num_pg_training_steps, + self._config.batch_size, + *x.shape[1:], + ), + ), + transitions, + ) # Replace random_key emitter_state = emitter_state.replace(random_key=random_key) diff --git a/qdax/core/neuroevolution/buffers/buffer.py b/qdax/core/neuroevolution/buffers/buffer.py index a80fe739..d25c8c6c 100644 --- a/qdax/core/neuroevolution/buffers/buffer.py +++ b/qdax/core/neuroevolution/buffers/buffer.py @@ -357,7 +357,8 @@ def from_flatten( state_desc = flattened_transition[ :, (2 * obs_dim + 3 + action_dim) : ( - 2 * obs_dim + 3 + action_dim + state_desc_dim), + 2 * obs_dim + 3 + action_dim + state_desc_dim + ), ] next_state_desc = flattened_transition[ :, @@ -392,10 +393,7 @@ def from_flatten( @classmethod def init_dummy( # type: ignore - cls, - observation_dim: int, - action_dim: int, - descriptor_dim: int + cls, observation_dim: int, action_dim: int, descriptor_dim: int ) -> QDTransition: """ Initialize a dummy transition that then can be passed to constructors to get diff --git a/qdax/core/neuroevolution/losses/td3_loss.py b/qdax/core/neuroevolution/losses/td3_loss.py index 9866a9cd..e12797b9 100644 --- a/qdax/core/neuroevolution/losses/td3_loss.py +++ b/qdax/core/neuroevolution/losses/td3_loss.py @@ -131,10 +131,8 @@ def _policy_loss_fn( """Policy loss function for TD3 agent""" action = policy_fn(policy_params, transitions.obs) q_value = critic_fn( - critic_params, - transitions.obs, - action, - transitions.desc_prime) + critic_params, transitions.obs, action, transitions.desc_prime + ) q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) policy_loss = -jnp.mean(q1_action) return policy_loss @@ -148,10 +146,7 @@ def _actor_loss_fn( """Descriptor-conditioned policy loss function for TD3 agent""" action = actor_fn(actor_params, transitions.obs, transitions.desc_prime) q_value = critic_fn( - critic_params, - transitions.obs, - action, - transitions.desc_prime + critic_params, transitions.obs, action, transitions.desc_prime ) q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) policy_loss = -jnp.mean(q1_action) @@ -172,18 +167,23 @@ def _critic_loss_fn( ).clip(-noise_clip, noise_clip) next_action = ( - actor_fn(target_actor_params, transitions.next_obs, - transitions.desc_prime) + noise + actor_fn(target_actor_params, transitions.next_obs, transitions.desc_prime) + + noise ).clip(-1.0, 1.0) - next_q = critic_fn(target_critic_params, transitions.next_obs, - next_action, transitions.desc_prime) + next_q = critic_fn( + target_critic_params, + transitions.next_obs, + next_action, + transitions.desc_prime, + ) next_v = jnp.min(next_q, axis=-1) target_q = 
jax.lax.stop_gradient( transitions.rewards * reward_scaling + (1.0 - transitions.dones) * discount * next_v ) - q_old_action = critic_fn(critic_params, transitions.obs, - transitions.actions, transitions.desc_prime) + q_old_action = critic_fn( + critic_params, transitions.obs, transitions.actions, transitions.desc_prime + ) q_error = q_old_action - jnp.expand_dims(target_q, -1) # Better bootstrapping for truncated episodes. @@ -219,9 +219,7 @@ def td3_policy_loss_fn( """ action = policy_fn(policy_params, transitions.obs) - q_value = critic_fn( - critic_params, transitions.obs, action # type: ignore - ) + q_value = critic_fn(critic_params, transitions.obs, action) # type: ignore q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) policy_loss = -jnp.mean(q1_action) return policy_loss diff --git a/qdax/core/neuroevolution/networks/networks.py b/qdax/core/neuroevolution/networks/networks.py index fea7c1ac..365c8d56 100644 --- a/qdax/core/neuroevolution/networks/networks.py +++ b/qdax/core/neuroevolution/networks/networks.py @@ -9,6 +9,7 @@ class MLP(nn.Module): """MLP module.""" + layer_sizes: Tuple[int, ...] activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform() @@ -49,6 +50,7 @@ def __call__(self, obs: jnp.ndarray) -> jnp.ndarray: class MLPDC(nn.Module): """Descriptor-conditioned MLP module.""" + layer_sizes: Tuple[int, ...] activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform() @@ -115,10 +117,7 @@ class QModuleDC(nn.Module): @nn.compact def __call__( - self, - obs: jnp.ndarray, - actions: jnp.ndarray, - desc: jnp.ndarray + self, obs: jnp.ndarray, actions: jnp.ndarray, desc: jnp.ndarray ) -> jnp.ndarray: hidden = jnp.concatenate([obs, actions], axis=-1) res = [] diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py index 1aa7bcab..cf0c3336 100644 --- a/qdax/environments/wrappers.py +++ b/qdax/environments/wrappers.py @@ -80,10 +80,7 @@ class ClipRewardWrapper(Wrapper): """ def __init__( - self, - env: Env, - clip_min: float = None, - clip_max: float = None + self, env: Env, clip_min: float = None, clip_max: float = None ) -> None: super().__init__(env) self._clip_min = clip_min @@ -92,12 +89,14 @@ def __init__( def reset(self, rng: jp.ndarray) -> State: state = self.env.reset(rng) return state.replace( - reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) def step(self, state: State, action: jp.ndarray) -> State: state = self.env.step(state, action) return state.replace( - reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) class AffineRewardWrapper(Wrapper): @@ -109,9 +108,7 @@ class AffineRewardWrapper(Wrapper): """ def __init__( - self, env: Env, - clip_min: float = None, - clip_max: float = None + self, env: Env, clip_min: float = None, clip_max: float = None ) -> None: super().__init__(env) self._clip_min = clip_min @@ -120,12 +117,14 @@ def __init__( def reset(self, rng: jp.ndarray) -> State: state = self.env.reset(rng) return state.replace( - reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) def step(self, state: State, action: jp.ndarray) -> State: state = self.env.step(state, action) return state.replace( - 
reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max)) + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) class OffsetRewardWrapper(Wrapper): @@ -136,7 +135,7 @@ class OffsetRewardWrapper(Wrapper): work like before and will simply clip the reward to be greater than 0. """ - def __init__(self, env: Env, offset: float = 0.) -> None: + def __init__(self, env: Env, offset: float = 0.0) -> None: super().__init__(env) self._offset = offset diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 5d7eda5a..1a588e52 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -176,7 +176,7 @@ def scoring_actor_dc_function_brax_envs( episode_length: int, play_step_actor_dc_fn: Callable[ [EnvState, Descriptor, Params, RNGKey], - Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition] + Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition], ], behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: @@ -317,7 +317,7 @@ def reset_based_scoring_actor_dc_function_brax_envs( play_reset_fn: Callable[[RNGKey], EnvState], play_step_actor_dc_fn: Callable[ [EnvState, Descriptor, Params, RNGKey], - Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition] + Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition], ], behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: @@ -353,8 +353,9 @@ def reset_based_scoring_actor_dc_function_brax_envs( """ random_key, subkey = jax.random.split(random_key) - keys = jax.random.split(subkey, - jax.tree_util.tree_leaves(actors_dc_params)[0].shape[0]) + keys = jax.random.split( + subkey, jax.tree_util.tree_leaves(actors_dc_params)[0].shape[0] + ) reset_fn = jax.vmap(play_reset_fn) init_states = reset_fn(keys) From 9e75ded46b82362109d31ee211685826b2d3452a Mon Sep 17 00:00:00 2001 From: maxence Date: Mon, 8 Jan 2024 21:34:00 +0000 Subject: [PATCH 11/11] Fix omg_mega emitter --- qdax/core/emitters/omg_mega_emitter.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index 7a480e06..54766152 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -118,6 +118,18 @@ def init( genotype=gradient_genotype, centroids=self._centroids ) + # get gradients out of the extra scores + assert "gradients" in extra_scores.keys(), "Missing gradients or wrong key" + gradients = extra_scores["gradients"] + + # update the gradients repertoire + gradients_repertoire = gradients_repertoire.add( + gradients, + descriptors, + fitnesses, + extra_scores, + ) + return ( OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire), random_key,
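Finally, a worked check of the actor-injection trick implemented by _compute_equivalent_kernel_bias_with_desc earlier in this patch: because the descriptor-conditioned actor concatenates [obs, desc] before its first Dense layer, fixing the descriptor is equivalent to dropping the descriptor columns of the kernel and folding them into the bias. A self-contained sketch with toy dimensions (all values are illustrative):

    import jax
    import jax.numpy as jnp

    obs_dim, desc_dim, hidden_dim = 3, 2, 4
    kernel = jax.random.normal(jax.random.PRNGKey(0), (obs_dim + desc_dim, hidden_dim))
    bias = jnp.zeros(hidden_dim)
    obs = jnp.ones(obs_dim)
    desc = jnp.array([0.5, -0.5])  # already normalized descriptor

    # First layer of the descriptor-conditioned actor: Dense over concat([obs, desc])
    out_conditioned = jnp.concatenate([obs, desc]) @ kernel + bias

    # Equivalent unconditioned parameters, as computed in the emitter
    equivalent_kernel = kernel[:-desc_dim, :]
    equivalent_bias = bias + jnp.dot(desc, kernel[-desc_dim:])
    out_injected = obs @ equivalent_kernel + equivalent_bias

    assert jnp.allclose(out_conditioned, out_injected)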