diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index aed74c78..5968b03f 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -228,6 +228,35 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey] return samples, random_key + @partial(jax.jit, static_argnames=("num_samples",)) + def sample_with_descs(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, Descriptor, RNGKey]: + """Sample elements in the repertoire, along with their descriptors. + + Args: + random_key: a jax PRNG random key + num_samples: the number of elements to be sampled + + Returns: + samples: a batch of genotypes sampled in the repertoire + descs: the descriptors of the sampled genotypes + random_key: an updated jax PRNG random key + """ + + repertoire_empty = self.fitnesses == -jnp.inf + p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) + + random_key, subkey = jax.random.split(random_key) + samples = jax.tree_util.tree_map( + lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + self.genotypes, + ) + descs = jax.tree_util.tree_map( + lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + self.descriptors, + ) + + return samples, descs, random_key + @jax.jit def add( self, diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py index f9d58caa..c090e448 100644 --- a/qdax/core/emitters/cma_emitter.py +++ b/qdax/core/emitters/cma_emitter.py @@ -99,14 +99,20 @@ def batch_size(self) -> int: @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMAEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. Returns: @@ -154,7 +160,7 @@ def emit( cmaes_state=emitter_state.cmaes_state, random_key=random_key ) - return offsprings, random_key + return offsprings, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py index f63654fd..f79579dd 100644 --- a/qdax/core/emitters/cma_mega_emitter.py +++ b/qdax/core/emitters/cma_mega_emitter.py @@ -100,14 +100,20 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMAMEGAState, RNGKey]: """ Initializes the CMA-MEGA emitter. Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations.
Returns: @@ -117,7 +123,7 @@ def init( # define init theta as 0 theta = jax.tree_util.tree_map( lambda x: jnp.zeros_like(x[:1, ...]), - init_genotypes, + genotypes, ) # score it @@ -181,7 +187,7 @@ def emit( # Compute new candidates new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad) - return new_thetas, random_key + return new_thetas, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py index d5424a01..67034d71 100644 --- a/qdax/core/emitters/cma_pool_emitter.py +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -49,14 +49,20 @@ def batch_size(self) -> int: @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMAPoolEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. Returns: @@ -67,7 +73,7 @@ def scan_emitter_init( carry: RNGKey, unused: Any ) -> Tuple[RNGKey, CMAEmitterState]: random_key = carry - emitter_state, random_key = self._emitter.init(init_genotypes, random_key) + emitter_state, random_key = self._emitter.init(random_key, repertoire, genotypes, fitnesses, descriptors, extra_scores) return random_key, emitter_state # init all the emitter states @@ -115,7 +121,7 @@ def emit( repertoire, used_emitter_state, random_key ) - return offsprings, random_key + return offsprings, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/dcg_me_emitter.py b/qdax/core/emitters/dcg_me_emitter.py new file mode 100644 index 00000000..b9ae628b --- /dev/null +++ b/qdax/core/emitters/dcg_me_emitter.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass +from typing import Callable, Tuple + +import flax.linen as nn + +from qdax.core.emitters.multi_emitter import MultiEmitter +from qdax.core.emitters.qdcg_emitter import QualityDCGConfig, QualityDCGEmitter +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.environments.base_wrappers import QDEnv +from qdax.types import Params, RNGKey + + +@dataclass +class DCGMEConfig: + """Configuration for DCGME Algorithm""" + + ga_batch_size: int = 128 + qpg_batch_size: int = 64 + ai_batch_size: int = 64 + lengthscale: float = 0.1 + + # PG emitter + critic_hidden_layer_size: Tuple[int, ...]
= (256, 256) + num_critic_training_steps: int = 3000 + num_pg_training_steps: int = 150 + batch_size: int = 100 + replay_buffer_size: int = 1_000_000 + discount: float = 0.99 + reward_scaling: float = 1.0 + critic_learning_rate: float = 3e-4 + actor_learning_rate: float = 3e-4 + policy_learning_rate: float = 1e-3 + noise_clip: float = 0.5 + policy_noise: float = 0.2 + soft_tau_update: float = 0.005 + policy_delay: int = 2 + + +class DCGMEEmitter(MultiEmitter): + def __init__( + self, + config: DCGMEConfig, + policy_network: nn.Module, + actor_network: nn.Module, + env: QDEnv, + variation_fn: Callable[[Params, Params, RNGKey], Tuple[Params, RNGKey]], + ) -> None: + self._config = config + self._env = env + self._variation_fn = variation_fn + + qdcg_config = QualityDCGConfig( + qpg_batch_size=config.qpg_batch_size, + ai_batch_size=config.ai_batch_size, + lengthscale=config.lengthscale, + critic_hidden_layer_size=config.critic_hidden_layer_size, + num_critic_training_steps=config.num_critic_training_steps, + num_pg_training_steps=config.num_pg_training_steps, + batch_size=config.batch_size, + replay_buffer_size=config.replay_buffer_size, + discount=config.discount, + reward_scaling=config.reward_scaling, + critic_learning_rate=config.critic_learning_rate, + actor_learning_rate=config.actor_learning_rate, + policy_learning_rate=config.policy_learning_rate, + noise_clip=config.noise_clip, + policy_noise=config.policy_noise, + soft_tau_update=config.soft_tau_update, + policy_delay=config.policy_delay, + ) + + # define the quality emitter + q_emitter = QualityDCGEmitter( + config=qdcg_config, + policy_network=policy_network, + actor_network=actor_network, + env=env + ) + + # define the GA emitter + ga_emitter = MixingEmitter( + mutation_fn=lambda x, r: (x, r), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=config.ga_batch_size, + ) + + super().__init__(emitters=(q_emitter, ga_emitter)) diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index 8b858db4..c266be10 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -10,7 +10,7 @@ import optax from qdax.core.containers.archive import Archive -from qdax.core.containers.repertoire import Repertoire +from qdax.core.containers.repertoire import MapElitesRepertoire from qdax.core.emitters.qpg_emitter import ( QualityPGConfig, QualityPGEmitter, @@ -77,12 +77,18 @@ def __init__( self._score_novelty = score_novelty def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[DiversityPGEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. 
Returns: @@ -90,7 +96,7 @@ def init( """ # init elements of diversity emitter state with QualityEmitterState.init() - diversity_emitter_state, random_key = super().init(init_genotypes, random_key) + diversity_emitter_state, random_key = super().init(random_key, repertoire, genotypes, fitnesses, descriptors, extra_scores) # store elements in a dictionary attributes_dict = vars(diversity_emitter_state) @@ -116,7 +122,7 @@ def init( def state_update( self, emitter_state: DiversityPGEmitterState, - repertoire: Optional[Repertoire], + repertoire: Optional[MapElitesRepertoire], genotypes: Optional[Genotype], fitnesses: Optional[Fitness], descriptors: Optional[Descriptor], diff --git a/qdax/core/emitters/emitter.py b/qdax/core/emitters/emitter.py index d32ed981..14c6277a 100644 --- a/qdax/core/emitters/emitter.py +++ b/qdax/core/emitters/emitter.py @@ -30,7 +30,13 @@ class EmitterState(PyTreeNode): class Emitter(ABC): def init( - self, init_genotypes: Optional[Genotype], random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[Optional[EmitterState], RNGKey]: """Initialises the state of the emitter. Some emitters do not need a state, in which case, the value None can be diff --git a/qdax/core/emitters/mees_emitter.py b/qdax/core/emitters/mees_emitter.py index b5bb1ada..9641a613 100644 --- a/qdax/core/emitters/mees_emitter.py +++ b/qdax/core/emitters/mees_emitter.py @@ -236,26 +236,32 @@ def batch_size(self) -> int: static_argnames=("self",), ) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[MEESEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. Returns: The initial state of the MEESEmitter, a new random key.
""" # Initialisation requires one initial genotype - if jax.tree_util.tree_leaves(init_genotypes)[0].shape[0] > 1: - init_genotypes = jax.tree_util.tree_map( + if jax.tree_util.tree_leaves(genotypes)[0].shape[0] > 1: + genotypes = jax.tree_util.tree_map( lambda x: x[0], - init_genotypes, + genotypes, ) # Initialise optimizer - initial_optimizer_state = self._optimizer.init(init_genotypes) + initial_optimizer_state = self._optimizer.init(genotypes) # Create empty Novelty archive if self._config.use_explore: @@ -270,7 +276,7 @@ def init( # Create empty updated genotypes and fitness last_updated_genotypes = jax.tree_util.tree_map( lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]), - init_genotypes, + genotypes, ) last_updated_fitnesses = -jnp.inf * jnp.ones( shape=self._config.last_updated_size @@ -280,7 +286,7 @@ def init( MEESEmitterState( initial_optimizer_state=initial_optimizer_state, optimizer_state=initial_optimizer_state, - offspring=init_genotypes, + offspring=genotypes, generation_count=0, novelty_archive=novelty_archive, last_updated_genotypes=last_updated_genotypes, @@ -313,7 +319,7 @@ def emit( a new jax PRNG key """ - return emitter_state.offspring, random_key + return emitter_state.offspring, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/multi_emitter.py b/qdax/core/emitters/multi_emitter.py index 2da46639..a0789f33 100644 --- a/qdax/core/emitters/multi_emitter.py +++ b/qdax/core/emitters/multi_emitter.py @@ -56,13 +56,19 @@ def get_indexes_separation_batches( return tuple(indexes_separation_batches) def init( - self, init_genotypes: Optional[Genotype], random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[Optional[EmitterState], RNGKey]: """ Initialize the state of the emitter. Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. 
Returns: @@ -76,7 +82,7 @@ def init( # init all emitter states - gather them emitter_states = [] for emitter, subkey_emitter in zip(self.emitters, subkeys): - emitter_state, _ = emitter.init(init_genotypes, subkey_emitter) + emitter_state, _ = emitter.init(subkey_emitter, repertoire, genotypes, fitnesses, descriptors, extra_scores) emitter_states.append(emitter_state) return MultiEmitterState(tuple(emitter_states)), random_key @@ -108,21 +114,23 @@ def emit( # emit from all emitters and gather offsprings all_offsprings = [] + all_extra_info = {} for emitter, sub_emitter_state, subkey_emitter in zip( self.emitters, emitter_state.emitter_states, subkeys, ): - genotype, _ = emitter.emit(repertoire, sub_emitter_state, subkey_emitter) + genotype, extra_info, _ = emitter.emit(repertoire, sub_emitter_state, subkey_emitter) batch_size = jax.tree_util.tree_leaves(genotype)[0].shape[0] assert batch_size == emitter.batch_size all_offsprings.append(genotype) + all_extra_info = all_extra_info | extra_info # concatenate offsprings together offsprings = jax.tree_util.tree_map( lambda *x: jnp.concatenate(x, axis=0), *all_offsprings ) - return offsprings, random_key + return offsprings, all_extra_info, random_key @partial(jax.jit, static_argnames=("self",)) def state_update( diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index 7336750d..2380a85c 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -84,20 +84,26 @@ def __init__( self._num_descriptors = num_descriptors def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[OMGMEGAEmitterState, RNGKey]: """Initialises the state of the emitter. Creates an empty repertoire that will later contain the gradients of the individuals. Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. Returns: The initial emitter state. """ # retrieve one genotype from the population - first_genotype = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) + first_genotype = jax.tree_util.tree_map(lambda x: x[0], genotypes) # add a dimension of size num descriptors + 1 gradient_genotype = jax.tree_util.tree_map( @@ -190,7 +196,7 @@ def emit( lambda x, y: x + y, genotypes, update_grad ) - return new_genotypes, random_key + return new_genotypes, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/pbt_me_emitter.py b/qdax/core/emitters/pbt_me_emitter.py index 3fdb4418..64c05f16 100644 --- a/qdax/core/emitters/pbt_me_emitter.py +++ b/qdax/core/emitters/pbt_me_emitter.py @@ -91,12 +91,18 @@ def __init__( ) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[PBTEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. 
Returns: @@ -145,13 +151,13 @@ def init( # Create emitter state # keep only pg population size training states if more are provided - init_genotypes = jax.tree_util.tree_map( - lambda x: x[: self._config.pg_population_size_per_device], init_genotypes + genotypes = jax.tree_util.tree_map( + lambda x: x[: self._config.pg_population_size_per_device], genotypes ) emitter_state = PBTEmitterState( replay_buffers=replay_buffers, env_states=env_states, - training_states=init_genotypes, + training_states=genotypes, random_key=subkey2, ) @@ -199,7 +205,7 @@ def emit( else: genotypes = x_mutation_pg - return genotypes, random_key + return genotypes, {}, random_key @property def batch_size(self) -> int: diff --git a/qdax/core/emitters/qdcg_emitter.py b/qdax/core/emitters/qdcg_emitter.py new file mode 100644 index 00000000..745773bd --- /dev/null +++ b/qdax/core/emitters/qdcg_emitter.py @@ -0,0 +1,658 @@ +"""Implements the PG Emitter and Actor Injection from DCG-ME algorithm in JAX for Brax environments. +""" + +from dataclasses import dataclass +from functools import partial +from typing import Any, Optional, Tuple, Callable + +import jax +from jax import numpy as jnp +import flax.linen as nn +from flax.core.frozen_dict import freeze +import optax + +from qdax.core.containers.repertoire import Repertoire +from qdax.core.emitters.emitter import Emitter, EmitterState +from qdax.core.neuroevolution.buffers.buffer import DCGTransition, ReplayBuffer +from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_dc_fn +from qdax.core.neuroevolution.networks.networks import QModuleDC +from qdax.environments.base_wrappers import QDEnv +from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey + + +@dataclass +class QualityDCGConfig: + """Configuration for QualityDCG Emitter""" + + qpg_batch_size: int = 64 + ai_batch_size: int = 64 + lengthscale: float = 0.1 + + critic_hidden_layer_size: Tuple[int, ...] = (256, 256) + num_critic_training_steps: int = 3000 + num_pg_training_steps: int = 150 + batch_size: int = 100 + replay_buffer_size: int = 1_000_000 + discount: float = 0.99 + reward_scaling: float = 1.0 + critic_learning_rate: float = 3e-4 + actor_learning_rate: float = 3e-4 + policy_learning_rate: float = 1e-3 + noise_clip: float = 0.5 + policy_noise: float = 0.2 + soft_tau_update: float = 0.005 + policy_delay: int = 2 + + +class QualityDCGEmitterState(EmitterState): + """Contains training state for the learner.""" + + critic_params: Params + critic_opt_state: optax.OptState + actor_params: Params + actor_opt_state: optax.OptState + target_critic_params: Params + target_actor_params: Params + replay_buffer: ReplayBuffer + random_key: RNGKey + steps: jnp.ndarray + + +class QualityDCGEmitter(Emitter): + """ + A policy gradient emitter used to implement the Policy Gradient Assisted MAP-Elites + (PGA-Map-Elites) algorithm. 
+ """ + + def __init__( + self, + config: QualityDCGConfig, + policy_network: nn.Module, + actor_network: nn.Module, + env: QDEnv, + ) -> None: + self._config = config + self._env = env + self._policy_network = policy_network + self._actor_network = actor_network + + # Init Critics + critic_network = QModuleDC( + n_critics=2, hidden_layer_sizes=self._config.critic_hidden_layer_size + ) + self._critic_network = critic_network + + # Set up the losses and optimizers - return the opt states + self._policy_loss_fn, self._actor_loss_fn, self._critic_loss_fn = make_td3_loss_dc_fn( + policy_fn=policy_network.apply, + actor_fn=actor_network.apply, + critic_fn=critic_network.apply, + reward_scaling=self._config.reward_scaling, + discount=self._config.discount, + noise_clip=self._config.noise_clip, + policy_noise=self._config.policy_noise, + ) + + # Init optimizers + self._actor_optimizer = optax.adam( + learning_rate=self._config.actor_learning_rate + ) + self._critic_optimizer = optax.adam( + learning_rate=self._config.critic_learning_rate + ) + self._policies_optimizer = optax.adam( + learning_rate=self._config.policy_learning_rate + ) + + @property + def batch_size(self) -> int: + """ + Returns: + the batch size emitted by the emitter. + """ + return self._config.qpg_batch_size + self._config.ai_batch_size + + @property + def use_all_data(self) -> bool: + """Whether to use all data or not when used along other emitters. + + QualityPGEmitter uses the transitions from the genotypes that were generated + by other emitters. + """ + return True + + def init( + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, + ) -> Tuple[QualityDCGEmitterState, RNGKey]: + """Initializes the emitter state. + + Args: + genotypes: The initial population. + random_key: A random key. + + Returns: + The initial state of the PGAMEEmitter, a new random key. 
+ """ + + observation_size = jax.tree_util.tree_leaves(genotypes)[1].shape[1] + descriptor_size = self._env.behavior_descriptor_length + action_size = self._env.action_size + + # Initialise critic, greedy actor and population + random_key, subkey = jax.random.split(random_key) + fake_obs = jnp.zeros(shape=(observation_size,)) + fake_desc = jnp.zeros(shape=(descriptor_size,)) + fake_action = jnp.zeros(shape=(action_size,)) + + critic_params = self._critic_network.init( + subkey, obs=fake_obs, actions=fake_action, desc=fake_desc + ) + target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) + + random_key, subkey = jax.random.split(random_key) + actor_params = self._actor_network.init( + subkey, obs=fake_obs, desc=fake_desc) + target_actor_params = jax.tree_util.tree_map(lambda x: x, actor_params) + + # Prepare init optimizer states + critic_opt_state = self._critic_optimizer.init(critic_params) + actor_opt_state = self._actor_optimizer.init(actor_params) + + # Initialize replay buffer + dummy_transition = DCGTransition.init_dummy( + observation_dim=self._env.observation_size, + action_dim=action_size, + descriptor_dim=descriptor_size, + ) + + replay_buffer = ReplayBuffer.init( + buffer_size=self._config.replay_buffer_size, transition=dummy_transition + ) + + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + episode_length = transitions.obs.shape[1] + + desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) + desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) + + transitions = transitions.replace(desc=desc_normalized, desc_prime=desc_normalized) + replay_buffer = replay_buffer.insert(transitions) + + # Initial training state + random_key, subkey = jax.random.split(random_key) + emitter_state = QualityDCGEmitterState( + critic_params=critic_params, + critic_opt_state=critic_opt_state, + actor_params=actor_params, + actor_opt_state=actor_opt_state, + target_critic_params=target_critic_params, + target_actor_params=target_actor_params, + replay_buffer=replay_buffer, + random_key=subkey, + steps=jnp.array(0), + ) + + return emitter_state, random_key + + @partial(jax.jit, static_argnames=("self",)) + def _similarity(self, descs_1, descs_2): + """Compute the similarity between two batches of descriptors. + Args: + descs_1: batch of descriptors, representing the observed descriptors of the trajectories. + descs_2: batch of descriptors, representing the sampled descriptors. + Returns: + batch of similarity measures. + """ + return jnp.exp(-jnp.linalg.norm(descs_1 - descs_2, axis=-1)/self._config.lengthscale) + + @partial(jax.jit, static_argnames=("self",)) + def _normalize_desc(self, desc): + return 2*(desc - self._env.behavior_descriptor_limits[0])/(self._env.behavior_descriptor_limits[1] - self._env.behavior_descriptor_limits[0]) - 1 + + @partial(jax.jit, static_argnames=("self",)) + def _unnormalize_desc(self, desc_normalized): + return 0.5 * (self._env.behavior_descriptor_limits[1] - self._env.behavior_descriptor_limits[0]) * desc_normalized + \ + 0.5 * (self._env.behavior_descriptor_limits[1] + self._env.behavior_descriptor_limits[0]) + + @partial(jax.jit, static_argnames=("self",)) + def _compute_equivalent_kernel_bias_with_desc(self, actor_dc_params, desc): + """ + Compute the equivalent bias of the first layer of the actor network + given a descriptor. 
+ """ + # Extract kernel and bias of the first layer + kernel = actor_dc_params["params"]["Dense_0"]["kernel"] + bias = actor_dc_params["params"]["Dense_0"]["bias"] + + # Compute the equivalent bias + equivalent_kernel = kernel[:-desc.shape[0], :] + equivalent_bias = bias + jnp.dot(desc, kernel[-desc.shape[0]:]) + + return equivalent_kernel, equivalent_bias + + @partial(jax.jit, static_argnames=("self",)) + def _compute_equivalent_params_with_desc(self, actor_dc_params, desc): + desc_normalized = self._normalize_desc(desc) + equivalent_kernel, equivalent_bias = self._compute_equivalent_kernel_bias_with_desc(actor_dc_params, desc_normalized) + actor_dc_params["params"]["Dense_0"]["kernel"] = equivalent_kernel + actor_dc_params["params"]["Dense_0"]["bias"] = equivalent_bias + return actor_dc_params + + @partial(jax.jit, static_argnames=("self",),) + def emit( + self, + repertoire: Repertoire, + emitter_state: QualityDCGEmitterState, + random_key: RNGKey, + ) -> Tuple[Genotype, RNGKey]: + """Do a step of PG emission. + + Args: + repertoire: the current repertoire of genotypes + emitter_state: the state of the emitter used + random_key: a random key + + Returns: + A batch of offspring, the new emitter state and a new key. + """ + # PG emitter + parents_pg, descs_pg, random_key = repertoire.sample_with_descs(random_key, self._config.qpg_batch_size) + genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg) + + # Actor injection emitter + _, descs_ai, random_key = repertoire.sample_with_descs(random_key, self._config.ai_batch_size) + descs_ai = descs_ai.reshape(descs_ai.shape[0], self._env.behavior_descriptor_length) + genotypes_ai = self.emit_ai(emitter_state, descs_ai) + + # Concatenate PG and AI genotypes + genotypes = jax.tree_util.tree_map(lambda x1, x2: jnp.concatenate((x1, x2), axis=0), genotypes_pg, genotypes_ai) + + return genotypes, {"desc_prime": jnp.concatenate([descs_pg, descs_ai], axis=0)}, random_key + + @partial(jax.jit, static_argnames=("self",),) + def emit_pg( + self, emitter_state: QualityDCGEmitterState, parents: Genotype, descs: Descriptor) -> Genotype: + """Emit the offsprings generated through pg mutation. + + Args: + emitter_state: current emitter state, contains critic and + replay buffer. + parents: the parents selected to be applied gradients in order + to mutate towards better performance. + descs: the descriptors of the parents. + + Returns: + A new set of offsprings. + """ + mutation_fn = partial( + self._mutation_function_pg, + emitter_state=emitter_state, + ) + offsprings = jax.vmap(mutation_fn)(parents, descs) + + return offsprings + + @partial(jax.jit, static_argnames=("self",),) + def emit_ai( + self, emitter_state: QualityDCGEmitterState, descs: Descriptor + ) -> Genotype: + """Emit the offsprings generated through pg mutation. + + Args: + emitter_state: current emitter state, contains critic and + replay buffer. + parents: the parents selected to be applied gradients in order + to mutate towards better performance. + descs: the descriptors of the parents. + + Returns: + A new set of offsprings. + """ + offsprings = jax.vmap(self._compute_equivalent_params_with_desc, in_axes=(None, 0))(emitter_state.actor_params, descs) + + return offsprings + + @partial(jax.jit, static_argnames=("self",)) + def emit_actor(self, emitter_state: QualityDCGEmitterState) -> Genotype: + """Emit the greedy actor. + + Simply needs to be retrieved from the emitter state. + + Args: + emitter_state: the current emitter state, it stores the + greedy actor. 
+ + Returns: + The parameters of the actor. + """ + return emitter_state.actor_params + + @partial(jax.jit, static_argnames=("self",),) + def state_update( + self, + emitter_state: QualityDCGEmitterState, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, + ) -> QualityDCGEmitterState: + """This function gives an opportunity to update the emitter state + after the genotypes have been scored. + + Here it is used to fill the Replay Buffer with the transitions + from the scoring of the genotypes, and then the training of the + critic/actor happens. Hence the params of critic/actor are updated, + as well as their optimizer states. + + Args: + emitter_state: current emitter state. + repertoire: the current genotypes repertoire + genotypes: unused here - but compulsory in the signature. + fitnesses: unused here - but compulsory in the signature. + descriptors: unused here - but compulsory in the signature. + extra_scores: extra information coming from the scoring function, + this contains the transitions added to the replay buffer. + + Returns: + New emitter state where the replay buffer has been filled with + the new experienced transitions. + """ + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + episode_length = transitions.obs.shape[1] + + desc_prime = jnp.concatenate([extra_scores["desc_prime"], descriptors[self._config.qpg_batch_size+self._config.ai_batch_size:]], axis=0) + desc_prime = jnp.repeat(desc_prime[:, jnp.newaxis, :], episode_length, axis=1) + desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) + + desc_prime_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc_prime) + desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) + transitions = transitions.replace(desc=desc_normalized, desc_prime=desc_prime_normalized) + + # Add transitions to replay buffer + replay_buffer = emitter_state.replay_buffer.insert(transitions) + emitter_state = emitter_state.replace(replay_buffer=replay_buffer) + + # sample transitions from the replay buffer + random_key, subkey = jax.random.split(emitter_state.random_key) + transitions, random_key = replay_buffer.sample(subkey, self._config.num_critic_training_steps*self._config.batch_size) + transitions = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (self._config.num_critic_training_steps, self._config.batch_size, *x.shape[1:])), transitions) + transitions = transitions.replace(rewards=self._similarity(transitions.desc, transitions.desc_prime)*transitions.rewards) + emitter_state = emitter_state.replace(random_key=random_key) + + def scan_train_critics( + carry: QualityDCGEmitterState, transitions + ) -> Tuple[QualityDCGEmitterState, Any]: + emitter_state = carry + new_emitter_state = self._train_critics(emitter_state, transitions) + return new_emitter_state, () + + # Train critics and greedy actor + emitter_state, _ = jax.lax.scan( + scan_train_critics, + emitter_state, + transitions, + length=self._config.num_critic_training_steps, + ) + + return emitter_state + + @partial(jax.jit, static_argnames=("self",)) + def _train_critics( + self, emitter_state: QualityDCGEmitterState, transitions + ) -> QualityDCGEmitterState: + """Apply one gradient step to critics and to the greedy actor + (contained in carry in training_state), then soft update target critics + and target actor. + + Those updates are very similar to those made in TD3. 
+ + Args: + emitter_state: actual emitter state + + Returns: + New emitter state where the critic and the greedy actor have been + updated. Optimizer states have also been updated in the process. + """ + # Update Critic + ( + critic_opt_state, + critic_params, + target_critic_params, + random_key, + ) = self._update_critic( + critic_params=emitter_state.critic_params, + target_critic_params=emitter_state.target_critic_params, + target_actor_params=emitter_state.target_actor_params, + critic_opt_state=emitter_state.critic_opt_state, + transitions=transitions, + random_key=emitter_state.random_key, + ) + + # Update greedy actor + (actor_opt_state, actor_params, target_actor_params,) = jax.lax.cond( + emitter_state.steps % self._config.policy_delay == 0, + lambda x: self._update_actor(*x), + lambda _: ( + emitter_state.actor_opt_state, + emitter_state.actor_params, + emitter_state.target_actor_params, + ), + operand=( + emitter_state.actor_params, + emitter_state.actor_opt_state, + emitter_state.target_actor_params, + emitter_state.critic_params, + transitions, + ), + ) + + # Create new training state + new_emitter_state = emitter_state.replace( + critic_params=critic_params, + critic_opt_state=critic_opt_state, + actor_params=actor_params, + actor_opt_state=actor_opt_state, + target_critic_params=target_critic_params, + target_actor_params=target_actor_params, + random_key=random_key, + steps=emitter_state.steps + 1, + ) + + return new_emitter_state # type: ignore + + @partial(jax.jit, static_argnames=("self",)) + def _update_critic( + self, + critic_params: Params, + target_critic_params: Params, + target_actor_params: Params, + critic_opt_state: Params, + transitions: DCGTransition, + random_key: RNGKey, + ) -> Tuple[Params, Params, Params, RNGKey]: + + # compute loss and gradients + random_key, subkey = jax.random.split(random_key) + critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)( + critic_params, + target_actor_params, + target_critic_params, + transitions, + subkey, + ) + critic_updates, critic_opt_state = self._critic_optimizer.update( + critic_gradient, critic_opt_state + ) + + # update critic + critic_params = optax.apply_updates(critic_params, critic_updates) + + # Soft update of target critic network + target_critic_params = jax.tree_map( + lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + + self._config.soft_tau_update * x2, + target_critic_params, + critic_params, + ) + + return critic_opt_state, critic_params, target_critic_params, random_key + + @partial(jax.jit, static_argnames=("self",)) + def _update_actor( + self, + actor_params: Params, + actor_opt_state: optax.OptState, + target_actor_params: Params, + critic_params: Params, + transitions: DCGTransition, + ) -> Tuple[optax.OptState, Params, Params]: + + # Update greedy actor + policy_loss, policy_gradient = jax.value_and_grad(self._actor_loss_fn)( + actor_params, + critic_params, + transitions, + ) + ( + policy_updates, + actor_opt_state, + ) = self._actor_optimizer.update(policy_gradient, actor_opt_state) + actor_params = optax.apply_updates(actor_params, policy_updates) + + # Soft update of target greedy actor + target_actor_params = jax.tree_map( + lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + + self._config.soft_tau_update * x2, + target_actor_params, + actor_params, + ) + + return ( + actor_opt_state, + actor_params, + target_actor_params, + ) + + @partial(jax.jit, static_argnames=("self",),) + def _mutation_function_pg( + self, + policy_params: Genotype, + descs: 
Descriptor, + emitter_state: QualityDCGEmitterState, + ) -> Genotype: + """Apply pg mutation to a policy via multiple steps of gradient descent. + First, update the rewards to be diversity rewards, then apply the gradient + steps. + + Args: + policy_params: a policy, supposed to be a differentiable neural + network. + emitter_state: the current state of the emitter, containing among others, + the replay buffer, the critic. + + Returns: + The updated params of the neural network. + """ + # Get transitions + transitions, random_key = emitter_state.replay_buffer.sample(emitter_state.random_key, sample_size=self._config.num_pg_training_steps*self._config.batch_size) + descs_prime = jnp.tile(descs, (self._config.num_pg_training_steps*self._config.batch_size, 1)) + descs_prime_normalized = jax.vmap(self._normalize_desc)(descs_prime) + transitions = transitions.replace(rewards=self._similarity(transitions.desc, descs_prime_normalized)*transitions.rewards, desc_prime=descs_prime_normalized) + transitions = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (self._config.num_pg_training_steps, self._config.batch_size, *x.shape[1:])), transitions) + + # Replace random_key + emitter_state = emitter_state.replace(random_key=random_key) + + # Define new policy optimizer state + policy_opt_state = self._policies_optimizer.init(policy_params) + + def scan_train_policy( + carry: Tuple[QualityDCGEmitterState, Genotype, optax.OptState], + transitions, + ) -> Tuple[Tuple[QualityDCGEmitterState, Genotype, optax.OptState], Any]: + emitter_state, policy_params, policy_opt_state = carry + ( + new_emitter_state, + new_policy_params, + new_policy_opt_state, + ) = self._train_policy( + emitter_state, + policy_params, + policy_opt_state, + transitions, + ) + return ( + new_emitter_state, + new_policy_params, + new_policy_opt_state, + ), () + + (emitter_state, policy_params, policy_opt_state,), _ = jax.lax.scan( + scan_train_policy, + (emitter_state, policy_params, policy_opt_state), + transitions, + length=self._config.num_pg_training_steps, + ) + + return policy_params + + @partial(jax.jit, static_argnames=("self",)) + def _train_policy( + self, + emitter_state: QualityDCGEmitterState, + policy_params: Params, + policy_opt_state: optax.OptState, + transitions, + ) -> Tuple[QualityDCGEmitterState, Params, optax.OptState]: + """Apply one gradient step to a policy (called policy_params). + + Args: + emitter_state: current state of the emitter. + policy_params: parameters corresponding to the weights and bias of + the neural network that defines the policy. + + Returns: + The new emitter state and new params of the NN. 
+ """ + # update policy + policy_opt_state, policy_params = self._update_policy( + critic_params=emitter_state.critic_params, + policy_opt_state=policy_opt_state, + policy_params=policy_params, + transitions=transitions, + ) + + return emitter_state, policy_params, policy_opt_state + + @partial(jax.jit, static_argnames=("self",)) + def _update_policy( + self, + critic_params: Params, + policy_opt_state: optax.OptState, + policy_params: Params, + transitions: DCGTransition, + ) -> Tuple[optax.OptState, Params]: + + # compute loss + _policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)( + policy_params, + critic_params, + transitions, + ) + # Compute gradient and update policies + ( + policy_updates, + policy_opt_state, + ) = self._policies_optimizer.update(policy_gradient, policy_opt_state) + policy_params = optax.apply_updates(policy_params, policy_updates) + + return policy_opt_state, policy_params diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index c07e3b18..d43827df 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -12,7 +12,7 @@ import optax from jax import numpy as jnp -from qdax.core.containers.repertoire import Repertoire +from qdax.core.containers.repertoire import MapElitesRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_fn @@ -119,12 +119,18 @@ def use_all_data(self) -> bool: return True def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[QualityPGEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. 
Returns: @@ -144,8 +150,8 @@ def init( ) target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) - actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) - target_actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) + actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes) + target_actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes) # Prepare init optimizer states critic_optimizer_state = self._critic_optimizer.init(critic_params) @@ -184,7 +190,7 @@ def init( ) def emit( self, - repertoire: Repertoire, + repertoire: MapElitesRepertoire, emitter_state: QualityPGEmitterState, random_key: RNGKey, ) -> Tuple[Genotype, RNGKey]: @@ -223,7 +229,7 @@ def emit( offspring_actor, ) - return genotypes, random_key + return genotypes, {}, random_key @partial( jax.jit, @@ -273,7 +279,7 @@ def emit_actor(self, emitter_state: QualityPGEmitterState) -> Genotype: def state_update( self, emitter_state: QualityPGEmitterState, - repertoire: Optional[Repertoire], + repertoire: Optional[MapElitesRepertoire], genotypes: Optional[Genotype], fitnesses: Optional[Fitness], descriptors: Optional[Descriptor], diff --git a/qdax/core/emitters/standard_emitters.py b/qdax/core/emitters/standard_emitters.py index 8b877792..740aafa5 100644 --- a/qdax/core/emitters/standard_emitters.py +++ b/qdax/core/emitters/standard_emitters.py @@ -75,7 +75,7 @@ def emit( x_mutation, ) - return genotypes, random_key + return genotypes, {}, random_key @property def batch_size(self) -> int: diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index c71b0013..77dd437e 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -23,7 +23,7 @@ class MAPElites: """Core elements of the MAP-Elites algorithm. Note: Although very similar to the GeneticAlgorithm, we decided to keep the - MAPElites class independent of the GeneticAlgorithm class at the moment to keep + MAPElites class independent of the GeneticAlgorithm class at the moment to keep elements explicit. Args: @@ -52,7 +52,7 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, centroids: Centroid, random_key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: @@ -62,9 +62,9 @@ def init( such as CVT or Euclidean mapping. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) - centroids: tessellation centroids of shape (batch_size, num_descriptors) + centroids: tessellation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations.
Returns: @@ -73,12 +73,12 @@ def init( """ # score initial genotypes fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = MapElitesRepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, centroids=centroids, @@ -87,14 +87,9 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, @@ -129,9 +124,10 @@ def update( a new jax PRNG key """ # generate offsprings with the emitter - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) + # scores the offsprings fitnesses, descriptors, extra_scores, random_key = self._scoring_function( genotypes, random_key @@ -147,7 +143,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores, + extra_scores=extra_scores | extra_info, ) # update the metrics diff --git a/qdax/core/neuroevolution/buffers/buffer.py b/qdax/core/neuroevolution/buffers/buffer.py index 42ed7552..3fe777d4 100644 --- a/qdax/core/neuroevolution/buffers/buffer.py +++ b/qdax/core/neuroevolution/buffers/buffer.py @@ -7,7 +7,7 @@ import jax import jax.numpy as jnp -from qdax.types import Action, Done, Observation, Reward, RNGKey, StateDescriptor +from qdax.types import Action, Done, Observation, Reward, RNGKey, StateDescriptor, Descriptor class Transition(flax.struct.PyTreeNode): @@ -262,6 +262,152 @@ def init_dummy( # type: ignore return dummy_transition +class DCGTransition(QDTransition): + """Stores data corresponding to a transition collected by a QD algorithm.""" + + desc: Descriptor + desc_prime: Descriptor + + @property + def descriptor_dim(self) -> int: + """ + Returns: + the dimension of the descriptors. + """ + return self.state_desc.shape[-1] # type: ignore + + @property + def flatten_dim(self) -> int: + """ + Returns: + the dimension of the transition once flattened. + """ + flatten_dim = ( + 2 * self.observation_dim + + self.action_dim + + 3 + + 2 * self.state_descriptor_dim + + 2 * self.descriptor_dim + ) + return flatten_dim + + def flatten(self) -> jnp.ndarray: + """ + Returns: + a jnp.ndarray that corresponds to the flattened transition. + """ + flatten_transition = jnp.concatenate( + [ + self.obs, + self.next_obs, + jnp.expand_dims(self.rewards, axis=-1), + jnp.expand_dims(self.dones, axis=-1), + jnp.expand_dims(self.truncations, axis=-1), + self.actions, + self.state_desc, + self.next_state_desc, + self.desc, + self.desc_prime, + ], + axis=-1, + ) + return flatten_transition + + @classmethod + def from_flatten( + cls, + flattened_transition: jnp.ndarray, + transition: QDTransition, + ) -> QDTransition: + """ + Creates a transition from a flattened transition in a jnp.ndarray. 
+ Args: + flattened_transition: flattened transition in a jnp.ndarray of shape + (batch_size, flatten_dim) + transition: a transition object (might be a dummy one) to + get the dimensions right + Returns: + a Transition object + """ + obs_dim = transition.observation_dim + action_dim = transition.action_dim + state_desc_dim = transition.state_descriptor_dim + desc_dim = transition.descriptor_dim + + obs = flattened_transition[:, :obs_dim] + next_obs = flattened_transition[:, obs_dim : (2 * obs_dim)] + rewards = jnp.ravel(flattened_transition[:, (2 * obs_dim) : (2 * obs_dim + 1)]) + dones = jnp.ravel( + flattened_transition[:, (2 * obs_dim + 1) : (2 * obs_dim + 2)] + ) + truncations = jnp.ravel( + flattened_transition[:, (2 * obs_dim + 2) : (2 * obs_dim + 3)] + ) + actions = flattened_transition[ + :, (2 * obs_dim + 3) : (2 * obs_dim + 3 + action_dim) + ] + state_desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim) : (2 * obs_dim + 3 + action_dim + state_desc_dim), + ] + next_state_desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + state_desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + ), + ] + desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + 2 * state_desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + desc_dim + ), + ] + desc_prime = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + 2 * desc_dim + ), + ] + return cls( + obs=obs, + next_obs=next_obs, + rewards=rewards, + dones=dones, + truncations=truncations, + actions=actions, + state_desc=state_desc, + next_state_desc=next_state_desc, + desc=desc, + desc_prime=desc_prime, + ) + + @classmethod + def init_dummy( # type: ignore + cls, observation_dim: int, action_dim: int, descriptor_dim: int) -> QDTransition: + """ + Initialize a dummy transition that then can be passed to constructors to get + all shapes right. + Args: + observation_dim: observation dimension + action_dim: action dimension + Returns: + a dummy transition + """ + dummy_transition = DCGTransition( + obs=jnp.zeros(shape=(1, observation_dim)), + next_obs=jnp.zeros(shape=(1, observation_dim)), + rewards=jnp.zeros(shape=(1,)), + dones=jnp.zeros(shape=(1,)), + truncations=jnp.zeros(shape=(1,)), + actions=jnp.zeros(shape=(1, action_dim)), + state_desc=jnp.zeros(shape=(1, descriptor_dim)), + next_state_desc=jnp.zeros(shape=(1, descriptor_dim)), + desc=jnp.zeros(shape=(1, descriptor_dim)), + desc_prime=jnp.zeros(shape=(1, descriptor_dim)), + ) + return dummy_transition + + class ReplayBuffer(flax.struct.PyTreeNode): """ A replay buffer where transitions are flattened before being stored. 
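Note on the desc/desc_prime fields added by DCGTransition above: they are what QualityDCGEmitter uses to reweight rewards before its TD3-style updates (see _similarity and _normalize_desc in qdcg_emitter.py, and the rewards=self._similarity(...)*transitions.rewards lines in state_update and _mutation_function_pg). A minimal, illustrative sketch of that reweighting, not part of the patch, assuming descriptors already normalized to [-1, 1] and the default lengthscale of 0.1; the standalone similarity helper and the toy arrays below are made up for the example:

import jax.numpy as jnp

lengthscale = 0.1

def similarity(desc, desc_prime):
    # exp(-||d - d'|| / lengthscale): close to 1 when the observed descriptor
    # matches the sampled target descriptor, decaying quickly as they move apart.
    return jnp.exp(-jnp.linalg.norm(desc - desc_prime, axis=-1) / lengthscale)

desc = jnp.array([[0.0, 0.0], [0.5, 0.5]])        # observed descriptors
desc_prime = jnp.array([[0.0, 0.0], [0.0, 0.0]])  # sampled target descriptors
rewards = jnp.array([1.0, 1.0])

weighted = similarity(desc, desc_prime) * rewards  # approx. [1.0, 8.5e-4]

In words: transitions whose observed descriptor matches the descriptor the policy was conditioned on keep their reward, while mismatching transitions are strongly down-weighted, which is what steers the descriptor-conditioned critic and actor.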
diff --git a/qdax/core/neuroevolution/losses/td3_loss.py b/qdax/core/neuroevolution/losses/td3_loss.py index 7f34a036..b360267c 100644 --- a/qdax/core/neuroevolution/losses/td3_loss.py +++ b/qdax/core/neuroevolution/losses/td3_loss.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Action, Observation, Params, RNGKey +from qdax.types import Action, Observation, Descriptor, Params, RNGKey def make_td3_loss_fn( @@ -94,6 +94,97 @@ def _critic_loss_fn( return _policy_loss_fn, _critic_loss_fn +def make_td3_loss_dc_fn( + policy_fn: Callable[[Params, Observation], jnp.ndarray], + actor_fn: Callable[[Params, Observation, Descriptor], jnp.ndarray], + critic_fn: Callable[[Params, Observation, Action, Descriptor], jnp.ndarray], + reward_scaling: float, + discount: float, + noise_clip: float, + policy_noise: float, +) -> Tuple[ + Callable[[Params, Params, Transition], jnp.ndarray], + Callable[[Params, Params, Params, Transition, RNGKey], jnp.ndarray], +]: + """Creates the loss functions for TD3. + Args: + policy_fn: forward pass through the neural network defining the policy. + actor_fn: forward pass through the neural network defining the + descriptor-conditioned policy. + critic_fn: forward pass through the neural network defining the + descriptor-conditioned critic. + reward_scaling: value to multiply the reward given by the environment. + discount: discount factor. + noise_clip: value that clips the noise to avoid extreme values. + policy_noise: noise applied to smooth the bootstrapping. + Returns: + Return the loss functions used to train the policy and the critic in TD3. + """ + + @jax.jit + def _policy_loss_fn( + policy_params: Params, + critic_params: Params, + transitions: Transition, + ) -> jnp.ndarray: + """Policy loss function for TD3 agent""" + action = policy_fn(policy_params, obs=transitions.obs) + q_value = critic_fn(critic_params, obs=transitions.obs, actions=action, desc=transitions.desc_prime) + q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) + policy_loss = -jnp.mean(q1_action) + return policy_loss + + @jax.jit + def _actor_loss_fn( + actor_params: Params, + critic_params: Params, + transitions: Transition, + ) -> jnp.ndarray: + """Descriptor-conditioned policy loss function for TD3 agent""" + action = actor_fn(actor_params, obs=transitions.obs, desc=transitions.desc_prime) + q_value = critic_fn(critic_params, obs=transitions.obs, actions=action, desc=transitions.desc_prime) + q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) + policy_loss = -jnp.mean(q1_action) + return policy_loss + + @jax.jit + def _critic_loss_fn( + critic_params: Params, + target_actor_params: Params, + target_critic_params: Params, + transitions: Transition, + random_key: RNGKey, + ) -> jnp.ndarray: + """Descriptor-conditioned critic loss function for TD3 agent""" + noise = ( + jax.random.normal(random_key, shape=transitions.actions.shape) + * policy_noise + ).clip(-noise_clip, noise_clip) + + next_action = ( + actor_fn(target_actor_params, obs=transitions.next_obs, desc=transitions.desc_prime) + noise + ).clip(-1.0, 1.0) + next_q = critic_fn(target_critic_params, obs=transitions.next_obs, actions=next_action, desc=transitions.desc_prime) + next_v = jnp.min(next_q, axis=-1) + target_q = jax.lax.stop_gradient( + transitions.rewards * reward_scaling + + (1.0 - transitions.dones) * discount * next_v + ) + q_old_action = critic_fn(critic_params, obs=transitions.obs, actions=transitions.actions, 
desc=transitions.desc_prime) + q_error = q_old_action - jnp.expand_dims(target_q, -1) + + # Better bootstrapping for truncated episodes. + q_error = q_error * jnp.expand_dims(1.0 - transitions.truncations, -1) + + # compute the loss + q_losses = jnp.mean(jnp.square(q_error), axis=-2) + q_loss = jnp.sum(q_losses, axis=-1) + + return q_loss + + return _policy_loss_fn, _actor_loss_fn, _critic_loss_fn + + def td3_policy_loss_fn( policy_params: Params, critic_params: Params, diff --git a/qdax/core/neuroevolution/mdp_utils.py b/qdax/core/neuroevolution/mdp_utils.py index 984d1aeb..fe936a57 100644 --- a/qdax/core/neuroevolution/mdp_utils.py +++ b/qdax/core/neuroevolution/mdp_utils.py @@ -9,7 +9,7 @@ from flax.struct import PyTreeNode from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Genotype, Params, RNGKey +from qdax.types import Genotype, Params, RNGKey, Descriptor class TrainingState(PyTreeNode): @@ -67,6 +67,54 @@ def _scan_play_step_fn( return state, transitions +@partial(jax.jit, static_argnames=("play_step_actor_dc_fn", "episode_length")) +def generate_unroll_actor_dc( + init_state: EnvState, + actor_dc_params: Params, + desc: Descriptor, + random_key: RNGKey, + episode_length: int, + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], + Tuple[ + EnvState, + Descriptor, + Params, + RNGKey, + Transition, + ], + ], +) -> Tuple[EnvState, Transition]: + """Generates an episode according to the agent's policy and descriptor, returns the final state of + the episode and the transitions of the episode. + + Args: + init_state: first state of the rollout. + policy_dc_params: descriptor-conditioned policy params. + desc: descriptor the policy attempts to achieve. + random_key: random key for stochasiticity handling. + episode_length: length of the rollout. + play_step_fn: function describing how a step need to be taken. + + Returns: + A new state, the experienced transition. + """ + + def _scan_play_step_fn( + carry: Tuple[EnvState, Params, Descriptor, RNGKey], unused_arg: Any + ) -> Tuple[Tuple[EnvState, Params, Descriptor, RNGKey], Transition]: + env_state, actor_dc_params, desc, random_key, transitions = play_step_actor_dc_fn(*carry) + return (env_state, actor_dc_params, desc, random_key), transitions + + (state, _, _, _), transitions = jax.lax.scan( + _scan_play_step_fn, + (init_state, actor_dc_params, desc, random_key), + (), + length=episode_length, + ) + return state, transitions + + @jax.jit def get_first_episode(transition: Transition) -> Transition: """Extracts the first episode from a batch of transitions, returns the batch of diff --git a/qdax/core/neuroevolution/networks/networks.py b/qdax/core/neuroevolution/networks/networks.py index b2b176ef..d4b2ab3a 100644 --- a/qdax/core/neuroevolution/networks/networks.py +++ b/qdax/core/neuroevolution/networks/networks.py @@ -8,29 +8,47 @@ from brax.training import networks -class QModule(nn.Module): - """Q Module.""" - - hidden_layer_sizes: Tuple[int, ...] - n_critics: int = 2 +class MLP(nn.Module): + """MLP module.""" + layer_sizes: Tuple[int, ...] 
+ activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform() + final_activation: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None + bias: bool = True + kernel_init_final: Optional[Callable[..., Any]] = None @nn.compact - def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: - hidden = jnp.concatenate([obs, actions], axis=-1) - res = [] - for _ in range(self.n_critics): - q = networks.MLP( - layer_sizes=self.hidden_layer_sizes + (1,), - activation=nn.relu, - kernel_init=jax.nn.initializers.lecun_uniform(), - )(hidden) - res.append(q) - return jnp.concatenate(res, axis=-1) + def __call__(self, obs: jnp.ndarray) -> jnp.ndarray: + hidden = obs + for i, hidden_size in enumerate(self.layer_sizes): + if i != len(self.layer_sizes) - 1: + hidden = nn.Dense( + hidden_size, + kernel_init=self.kernel_init, + use_bias=self.bias, + )(hidden) + hidden = self.activation(hidden) # type: ignore -class MLP(nn.Module): - """MLP module.""" + else: + if self.kernel_init_final is not None: + kernel_init = self.kernel_init_final + else: + kernel_init = self.kernel_init + hidden = nn.Dense( + hidden_size, + kernel_init=kernel_init, + use_bias=self.bias, + )(hidden) + + if self.final_activation is not None: + hidden = self.final_activation(hidden) + + return hidden + +class MLPDC(nn.Module): + """Descriptor-conditioned MLP module.""" layer_sizes: Tuple[int, ...] activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform() @@ -39,15 +57,13 @@ class MLP(nn.Module): kernel_init_final: Optional[Callable[..., Any]] = None @nn.compact - def __call__(self, data: jnp.ndarray) -> jnp.ndarray: - hidden = data + def __call__(self, obs: jnp.ndarray, desc: jnp.ndarray) -> jnp.ndarray: + hidden = jnp.concatenate([obs, desc], axis=-1) for i, hidden_size in enumerate(self.layer_sizes): if i != len(self.layer_sizes) - 1: hidden = nn.Dense( hidden_size, - # name=f"hidden_{i}", with this version of flax, changing the name - # changes the initialization kernel_init=self.kernel_init, use_bias=self.bias, )(hidden) @@ -61,7 +77,6 @@ def __call__(self, data: jnp.ndarray) -> jnp.ndarray: hidden = nn.Dense( hidden_size, - # name=f"hidden_{i}", kernel_init=kernel_init, use_bias=self.bias, )(hidden) @@ -70,3 +85,42 @@ def __call__(self, data: jnp.ndarray) -> jnp.ndarray: hidden = self.final_activation(hidden) return hidden + + +class QModule(nn.Module): + """Q Module.""" + + hidden_layer_sizes: Tuple[int, ...] + n_critics: int = 2 + + @nn.compact + def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: + hidden = jnp.concatenate([obs, actions], axis=-1) + res = [] + for _ in range(self.n_critics): + q = MLP( + layer_sizes=self.hidden_layer_sizes + (1,), + activation=nn.relu, + kernel_init=jax.nn.initializers.lecun_uniform(), + )(hidden) + res.append(q) + return jnp.concatenate(res, axis=-1) + +class QModuleDC(nn.Module): + """Q Module.""" + + hidden_layer_sizes: Tuple[int, ...] 
diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py
index 720f662a..ed1b4387 100644
--- a/qdax/environments/wrappers.py
+++ b/qdax/environments/wrappers.py
@@ -3,7 +3,7 @@
 import flax.struct
 import jax
 from brax.v1 import jumpy as jp
-from brax.v1.envs import State, Wrapper
+from brax.v1.envs import Env, State, Wrapper
 
 
 class CompletedEvalMetrics(flax.struct.PyTreeNode):
@@ -69,3 +69,65 @@ def step(self, state: State, action: jp.ndarray) -> State:
         )
         nstate.info[self.STATE_INFO_KEY] = eval_metrics
         return nstate
+
+class ClipRewardWrapper(Wrapper):
+    """Wraps Brax environments to clip the reward.
+
+    Usage is simple: create an environment with Brax, pass it to the
+    wrapper, and it will behave as before, except that the reward is
+    clipped to the range [clip_min, clip_max] at reset and at every step.
+    """
+
+    def __init__(self, env: Env, clip_min=None, clip_max=None) -> None:
+        super().__init__(env)
+        self._clip_min = clip_min
+        self._clip_max = clip_max
+
+    def reset(self, rng: jp.ndarray) -> State:
+        state = self.env.reset(rng)
+        return state.replace(reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max))
+
+    def step(self, state: State, action: jp.ndarray) -> State:
+        state = self.env.step(state, action)
+        return state.replace(reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max))
+
+class AffineRewardWrapper(Wrapper):
+    """Wraps Brax environments to clip the reward (currently identical to ClipRewardWrapper).
+
+    Usage is simple: create an environment with Brax, pass it to the
+    wrapper, and it will behave as before, except that the reward is
+    clipped to the range [clip_min, clip_max] at reset and at every step.
+    """
+
+    def __init__(self, env: Env, clip_min=None, clip_max=None) -> None:
+        super().__init__(env)
+        self._clip_min = clip_min
+        self._clip_max = clip_max
+
+    def reset(self, rng: jp.ndarray) -> State:
+        state = self.env.reset(rng)
+        return state.replace(reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max))
+
+    def step(self, state: State, action: jp.ndarray) -> State:
+        state = self.env.step(state, action)
+        return state.replace(reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max))
+
+class OffsetRewardWrapper(Wrapper):
+    """Wraps Brax environments to offset the reward, e.g. to make it positive.
+
+    Usage is simple: create an environment with Brax, pass it to the
+    wrapper, and it will behave as before, except that a constant offset
+    is added to the reward at reset and at every step.
+    """
+
+    def __init__(self, env: Env, offset=0.) -> None:
+        super().__init__(env)
+        self._offset = offset
+
+    def reset(self, rng: jp.ndarray) -> State:
+        state = self.env.reset(rng)
+        return state.replace(reward=state.reward + self._offset)
+
+    def step(self, state: State, action: jp.ndarray) -> State:
+        state = self.env.step(state, action)
+        return state.replace(reward=state.reward + self._offset)
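A sketch (not part of the diff) of how these wrappers could be composed around a Brax v1 environment. The environment name, the offset and the clip bound are illustrative assumptions, and it is assumed that the Brax v1 registry is available as `brax.v1.envs`:

```python
import jax
import jax.numpy as jnp
from brax.v1 import envs

env = envs.create("walker2d")
# The inner wrapper is applied first: shift rewards, then clip from below.
env = OffsetRewardWrapper(env, offset=1.0)
env = ClipRewardWrapper(env, clip_min=0.0)

state = env.reset(jax.random.PRNGKey(0))
state = env.step(state, jnp.zeros(env.action_size))
```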
diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py
index 931ee9d3..845c329b 100644
--- a/qdax/tasks/brax_envs.py
+++ b/qdax/tasks/brax_envs.py
@@ -9,8 +9,8 @@
 
 import qdax.environments
 from qdax import environments
-from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition
-from qdax.core.neuroevolution.mdp_utils import generate_unroll
+from qdax.core.neuroevolution.buffers.buffer import QDTransition
+from qdax.core.neuroevolution.mdp_utils import generate_unroll, generate_unroll_actor_dc
 from qdax.core.neuroevolution.networks.networks import MLP
 from qdax.types import (
     Descriptor,
@@ -18,7 +18,6 @@
     ExtraScores,
     Fitness,
     Genotype,
-    Observation,
     Params,
     RNGKey,
 )
@@ -83,15 +82,6 @@ def default_play_step_fn(
     return default_play_step_fn
 
 
-def get_mask_from_transitions(
-    data: Transition,
-) -> jnp.ndarray:
-    is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1)
-    mask = jnp.roll(is_done, 1, axis=1)
-    mask = mask.at[:, 0].set(0)
-    return mask
-
-
 @partial(
     jax.jit,
     static_argnames=(
@@ -144,7 +134,9 @@ def scoring_function_brax_envs(
     _final_state, data = jax.vmap(unroll_fn)(init_states, policies_params)
 
     # create a mask to extract data properly
-    mask = get_mask_from_transitions(data)
+    is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1)
+    mask = jnp.roll(is_done, 1, axis=1)
+    mask = mask.at[:, 0].set(0)
 
     # scores
     fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1)
@@ -160,6 +152,78 @@ def scoring_function_brax_envs(
     )
 
 
+@partial(
+    jax.jit,
+    static_argnames=(
+        "episode_length",
+        "play_step_actor_dc_fn",
+        "behavior_descriptor_extractor",
+    ),
+)
+def scoring_actor_dc_function_brax_envs(
+    actors_dc_params: Genotype,
+    descs: Descriptor,
+    random_key: RNGKey,
+    init_states: EnvState,
+    episode_length: int,
+    play_step_actor_dc_fn: Callable[
+        [EnvState, Descriptor, Params, RNGKey],
+        Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition],
+    ],
+    behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor],
+) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:
+    """Evaluates policies contained in actors_dc_params in parallel in
+    deterministic or pseudo-deterministic environments.
+
+    This rollout is only deterministic when all the init states are the same.
+    If the init states are fixed but different, as a policy is not necessarily
+    evaluated with the same environment every time, this won't be deterministic.
+    When the init states are different, this is not purely stochastic.
+
+    Args:
+        actors_dc_params: The parameters of the closed-loop descriptor-conditioned
+            policies to evaluate.
+        descs: The descriptors the descriptor-conditioned policies attempt to achieve.
+        random_key: A jax random key.
+        init_states: The initial environment states, one per policy to evaluate.
+        episode_length: The maximal rollout length.
+        play_step_actor_dc_fn: The function to play a step of the environment.
+        behavior_descriptor_extractor: The function to extract the behavior descriptor.
+
+    Returns:
+        fitness: Array of fitnesses of all evaluated policies
+        descriptor: Behavioural descriptors of all evaluated policies
+        extra_scores: Additional information resulting from evaluation
+        random_key: The updated random key.
+    """
+
+    # Perform rollouts with each policy
+    random_key, subkey = jax.random.split(random_key)
+    unroll_fn = partial(
+        generate_unroll_actor_dc,
+        episode_length=episode_length,
+        play_step_actor_dc_fn=play_step_actor_dc_fn,
+        random_key=subkey,
+    )
+
+    _final_state, data = jax.vmap(unroll_fn)(init_states, actors_dc_params, descs)
+
+    # create a mask to extract data properly
+    is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1)
+    mask = jnp.roll(is_done, 1, axis=1)
+    mask = mask.at[:, 0].set(0)
+
+    # scores
+    fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1)
+    descriptors = behavior_descriptor_extractor(data, mask)
+
+    return (
+        fitnesses,
+        descriptors,
+        {
+            "transitions": data,
+        },
+        random_key,
+    )
+
+
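To make the masking logic used in both scoring functions above concrete, here is what it produces on a toy `dones` array of shape (batch, time); this is a sketch, not part of the diff. The terminal step itself still counts, and only the steps strictly after it are masked out:

```python
import jax.numpy as jnp

dones = jnp.array([[0.0, 0.0, 1.0, 0.0, 0.0]])

is_done = jnp.clip(jnp.cumsum(dones, axis=1), 0, 1)  # [[0., 0., 1., 1., 1.]]
mask = jnp.roll(is_done, 1, axis=1)                  # [[1., 0., 0., 1., 1.]]
mask = mask.at[:, 0].set(0)                          # [[0., 0., 0., 1., 1.]]

# rewards * (1.0 - mask) therefore keeps t = 0, 1, 2 and zeroes out t = 3, 4.
```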
+ """ + + # Perform rollouts with each policy + random_key, subkey = jax.random.split(random_key) + unroll_fn = partial( + generate_unroll_actor_dc, + episode_length=episode_length, + play_step_actor_dc_fn=play_step_actor_dc_fn, + random_key=subkey, + ) + + _final_state, data = jax.vmap(unroll_fn)(init_states, actors_dc_params, descs) + + # create a mask to extract data properly + is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) + mask = jnp.roll(is_done, 1, axis=1) + mask = mask.at[:, 0].set(0) + + # Scores - add offset to ensure positive fitness (through positive rewards) + fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) + descriptors = behavior_descriptor_extractor(data, mask) + + return ( + fitnesses, + descriptors, + { + "transitions": data, + }, + random_key, + ) + + @partial( jax.jit, static_argnames=( @@ -225,6 +289,72 @@ def reset_based_scoring_function_brax_envs( return fitnesses, descriptors, extra_scores, random_key +@partial( + jax.jit, + static_argnames=( + "episode_length", + "play_reset_fn", + "play_step_actor_dc_fn", + "behavior_descriptor_extractor", + ), +) +def reset_based_scoring_actor_dc_function_brax_envs( + actors_dc_params: Genotype, + descs: Descriptor, + random_key: RNGKey, + episode_length: int, + play_reset_fn: Callable[[RNGKey], EnvState], + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition] + ], + behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """Evaluates policies contained in policy_dc_params in parallel. + The play_reset_fn function allows for a more general scoring_function that can be + called with different batch-size and not only with a batch-size of the same + dimension as init_states. + + To define purely stochastic environments, using the reset function from the + environment, use "play_reset_fn = env.reset". + + To define purely deterministic environments, as in "scoring_function", generate + a single init_state using "init_state = env.reset(random_key)", then use + "play_reset_fn = lambda random_key: init_state". + + Args: + policy_dc_params: The parameters of closed-loop descriptor-conditioned policy to evaluate. + descriptors: The descriptors the descriptor-conditioned policy attempts to achieve. + random_key: A jax random key + episode_length: The maximal rollout length. + play_reset_fn: The function to reset the environment and obtain initial states. + play_step_fn: The function to play a step of the environment. + behavior_descriptor_extractor: The function to extract the behavior descriptor. + + Returns: + fitness: Array of fitnesses of all evaluated policies + descriptor: Behavioural descriptors of all evaluated policies + extra_scores: Additional information resulting from the evaluation + random_key: The updated random key. 
+ """ + + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split(subkey, jax.tree_util.tree_leaves(actors_dc_params)[0].shape[0]) + reset_fn = jax.vmap(play_reset_fn) + init_states = reset_fn(keys) + + fitnesses, descriptors, extra_scores, random_key = scoring_actor_dc_function_brax_envs( + actors_dc_params=actors_dc_params, + descs=descs, + random_key=random_key, + init_states=init_states, + episode_length=episode_length, + play_step_actor_dc_fn=play_step_actor_dc_fn, + behavior_descriptor_extractor=behavior_descriptor_extractor, + ) + + return fitnesses, descriptors, extra_scores, random_key + + def create_brax_scoring_fn( env: brax.envs.Env, policy_network: nn.Module, @@ -275,10 +405,10 @@ def create_brax_scoring_fn( init_state = env.reset(subkey) # Define the function to deterministically reset the environment - def deterministic_reset(_: RNGKey, _init_state: EnvState) -> EnvState: - return _init_state + def deterministic_reset(key: RNGKey, init_state: EnvState) -> EnvState: + return init_state - play_reset_fn = partial(deterministic_reset, _init_state=init_state) + play_reset_fn = partial(deterministic_reset, init_state=init_state) # Stochastic case elif play_reset_fn is None: @@ -348,36 +478,3 @@ def create_default_brax_task_components( ) return env, policy_network, scoring_fn, random_key - - -def get_aurora_scoring_fn( - scoring_fn: Callable[ - [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey] - ], - observation_extractor_fn: Callable[[Transition], Observation], -) -> Callable[ - [Genotype, RNGKey], Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey] -]: - """Evaluates policies contained in flatten_variables in parallel - - This rollout is only deterministic when all the init states are the same. - If the init states are fixed but different, as a policy is not necessarly - evaluated with the same environment everytime, this won't be determinist. - - When the init states are different, this is not purely stochastic. This - choice was made for performance reason, as the reset function of brax envs - is quite time-consuming. If pure stochasticity of the environment is needed - for a use case, please open an issue. - """ - - @functools.wraps(scoring_fn) - def _wrapper( - params: Params, random_key: RNGKey # Perform rollouts with each policy - ) -> Tuple[Fitness, Optional[Descriptor], ExtraScores, RNGKey]: - fitnesses, _, extra_scores, random_key = scoring_fn(params, random_key) - data = extra_scores["transitions"] - observation = observation_extractor_fn(data) # type: ignore - extra_scores["last_valid_observations"] = observation - return fitnesses, None, extra_scores, random_key - - return _wrapper