From 82d87c2418a4fce45545085c3eef1f5d7788c59d Mon Sep 17 00:00:00 2001
From: Maxence Faldor
Date: Thu, 5 Sep 2024 14:14:31 +0000
Subject: [PATCH] Rename DCG-ME to DCRL-ME

---
 .../{qdcg_emitter.py => dcrl_emitter.py}      | 134 +++++++++---------
 .../{dcg_me_emitter.py => dcrl_me_emitter.py} |  22 +--
 qdax/core/emitters/qpg_emitter.py             |   6 +-
 qdax/core/neuroevolution/buffers/buffer.py    |  10 +-
 4 files changed, 86 insertions(+), 86 deletions(-)
 rename qdax/core/emitters/{qdcg_emitter.py => dcrl_emitter.py} (87%)
 rename qdax/core/emitters/{dcg_me_emitter.py => dcrl_me_emitter.py} (84%)

diff --git a/qdax/core/emitters/qdcg_emitter.py b/qdax/core/emitters/dcrl_emitter.py
similarity index 87%
rename from qdax/core/emitters/qdcg_emitter.py
rename to qdax/core/emitters/dcrl_emitter.py
index 0fb19c4b..e7bb011d 100644
--- a/qdax/core/emitters/qdcg_emitter.py
+++ b/qdax/core/emitters/dcrl_emitter.py
@@ -1,4 +1,4 @@
-"""Implements the PG Emitter and Actor Injection from DCG-ME algorithm
+"""Implements the DCRL Emitter from the DCRL-MAP-Elites algorithm
 in JAX for Brax environments.
 """
 
@@ -13,7 +13,7 @@
 
 from qdax.core.containers.repertoire import Repertoire
 from qdax.core.emitters.emitter import Emitter, EmitterState
-from qdax.core.neuroevolution.buffers.buffer import DCGTransition, ReplayBuffer
+from qdax.core.neuroevolution.buffers.buffer import DCRLTransition, ReplayBuffer
 from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_dc_fn
 from qdax.core.neuroevolution.networks.networks import QModuleDC
 from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey
@@ -21,10 +21,10 @@
 
 
 @dataclass
-class QualityDCGConfig:
-    """Configuration for QualityDCG Emitter"""
+class DCRLConfig:
+    """Configuration for DCRL Emitter"""
 
-    qpg_batch_size: int = 64
+    dcg_batch_size: int = 64
     ai_batch_size: int = 64
     lengthscale: float = 0.1
 
@@ -44,7 +44,7 @@ class QualityDCGConfig:
     policy_delay: int = 2
 
 
-class QualityDCGEmitterState(EmitterState):
+class DCRLEmitterState(EmitterState):
    """Contains training state for the learner."""
 
     critic_params: Params
@@ -54,19 +54,19 @@ class QualityDCGEmitterState(EmitterState):
     target_critic_params: Params
     target_actor_params: Params
     replay_buffer: ReplayBuffer
-    random_key: RNGKey
+    key: RNGKey
     steps: jnp.ndarray
 
 
-class QualityDCGEmitter(Emitter):
+class DCRLEmitter(Emitter):
     """
-    A policy gradient emitter used to implement the Policy Gradient Assisted MAP-Elites
-    (PGA-Map-Elites) algorithm.
+    A descriptor-conditioned reinforcement learning emitter used to implement
+    the DCRL-MAP-Elites algorithm.
     """
 
     def __init__(
         self,
-        config: QualityDCGConfig,
+        config: DCRLConfig,
         policy_network: nn.Module,
         actor_network: nn.Module,
         env: QDEnv,
@@ -114,7 +114,7 @@ def batch_size(self) -> int:
         Returns:
             the batch size emitted by the emitter.
         """
-        return self._config.qpg_batch_size + self._config.ai_batch_size
+        return self._config.dcg_batch_size + self._config.ai_batch_size
 
     @property
     def use_all_data(self) -> bool:
@@ -127,18 +127,18 @@ def use_all_data(self) -> bool:
 
     def init(
         self,
-        random_key: RNGKey,
+        key: RNGKey,
         repertoire: Repertoire,
         genotypes: Genotype,
         fitnesses: Fitness,
         descriptors: Descriptor,
         extra_scores: ExtraScores,
-    ) -> Tuple[QualityDCGEmitterState, RNGKey]:
+    ) -> Tuple[DCRLEmitterState, RNGKey]:
         """Initializes the emitter state.
 
         Args:
             genotypes: The initial population.
-            random_key: A random key.
+            key: A random key.
 
         Returns:
-            The initial state of the PGAMEEmitter, a new random key.
+            The initial state of the DCRLEmitter, a new random key.
@@ -149,7 +149,7 @@ def init(
         action_size = self._env.action_size
 
         # Initialise critic, greedy actor and population
-        random_key, subkey = jax.random.split(random_key)
+        key, subkey = jax.random.split(key)
         fake_obs = jnp.zeros(shape=(observation_size,))
         fake_desc = jnp.zeros(shape=(descriptor_size,))
         fake_action = jnp.zeros(shape=(action_size,))
@@ -159,7 +159,7 @@
         )
         target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params)
 
-        random_key, subkey = jax.random.split(random_key)
+        key, subkey = jax.random.split(key)
         actor_params = self._actor_network.init(subkey, obs=fake_obs, desc=fake_desc)
         target_actor_params = jax.tree_util.tree_map(lambda x: x, actor_params)
 
@@ -168,7 +168,7 @@
         actor_opt_state = self._actor_optimizer.init(actor_params)
 
         # Initialize replay buffer
-        dummy_transition = DCGTransition.init_dummy(
+        dummy_transition = DCRLTransition.init_dummy(
             observation_dim=self._env.observation_size,
             action_dim=action_size,
             descriptor_dim=descriptor_size,
@@ -191,8 +191,8 @@
         replay_buffer = replay_buffer.insert(transitions)
 
         # Initial training state
-        random_key, subkey = jax.random.split(random_key)
-        emitter_state = QualityDCGEmitterState(
+        key, subkey = jax.random.split(key)
+        emitter_state = DCRLEmitterState(
             critic_params=critic_params,
             critic_opt_state=critic_opt_state,
             actor_params=actor_params,
@@ -200,11 +200,11 @@
             target_critic_params=target_critic_params,
             target_actor_params=target_actor_params,
             replay_buffer=replay_buffer,
-            random_key=subkey,
+            key=subkey,
             steps=jnp.array(0),
         )
 
-        return emitter_state, random_key
+        return emitter_state, key
 
     @partial(jax.jit, static_argnames=("self",))
     def _similarity(self, descs_1: Descriptor, descs_2: Descriptor) -> jnp.array:
@@ -281,28 +281,28 @@ def _compute_equivalent_params_with_desc(
     def emit(
         self,
         repertoire: Repertoire,
-        emitter_state: QualityDCGEmitterState,
-        random_key: RNGKey,
+        emitter_state: DCRLEmitterState,
+        key: RNGKey,
     ) -> Tuple[Genotype, ExtraScores, RNGKey]:
-        """Do a step of PG emission.
+        """Do a step of PG emission and actor injection.
 
         Args:
             repertoire: the current repertoire of genotypes
             emitter_state: the state of the emitter used
-            random_key: a random key
+            key: a random key
 
         Returns:
-            A batch of offspring, the new emitter state and a new key.
+            A batch of offspring, the sampled target descriptors and a new key.
         """
         # PG emitter
-        parents_pg, descs_pg, random_key = repertoire.sample_with_descs(
-            random_key, self._config.qpg_batch_size
+        parents_pg, descs_pg, key = repertoire.sample_with_descs(
+            key, self._config.dcg_batch_size
         )
         genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg)
 
         # Actor injection emitter
-        _, descs_ai, random_key = repertoire.sample_with_descs(
-            random_key, self._config.ai_batch_size
+        _, descs_ai, key = repertoire.sample_with_descs(
+            key, self._config.ai_batch_size
         )
         descs_ai = descs_ai.reshape(
             descs_ai.shape[0], self._env.behavior_descriptor_length
@@ -317,7 +317,7 @@ def emit(
         return (
             genotypes,
             {"desc_prime": jnp.concatenate([descs_pg, descs_ai], axis=0)},
-            random_key,
+            key,
         )
 
     @partial(
@@ -326,7 +326,7 @@ def emit(
     )
     def emit_pg(
         self,
-        emitter_state: QualityDCGEmitterState,
+        emitter_state: DCRLEmitterState,
         parents: Genotype,
         descs: Descriptor,
     ) -> Genotype:
@@ -355,7 +355,7 @@ def emit_pg(
         static_argnames=("self",),
     )
     def emit_ai(
-        self, emitter_state: QualityDCGEmitterState, descs: Descriptor
+        self, emitter_state: DCRLEmitterState, descs: Descriptor
     ) -> Genotype:
-        """Emit the offsprings generated through pg mutation.
+        """Emit the offsprings generated through actor injection.
@@ -376,7 +376,7 @@ def emit_ai( return offsprings @partial(jax.jit, static_argnames=("self",)) - def emit_actor(self, emitter_state: QualityDCGEmitterState) -> Genotype: + def emit_actor(self, emitter_state: DCRLEmitterState) -> Genotype: """Emit the greedy actor. Simply needs to be retrieved from the emitter state. @@ -396,13 +396,13 @@ def emit_actor(self, emitter_state: QualityDCGEmitterState) -> Genotype: ) def state_update( self, - emitter_state: QualityDCGEmitterState, + emitter_state: DCRLEmitterState, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> QualityDCGEmitterState: + ) -> DCRLEmitterState: """This function gives an opportunity to update the emitter state after the genotypes have been scored. @@ -431,7 +431,7 @@ def state_update( desc_prime = jnp.concatenate( [ extra_scores["desc_prime"], - descriptors[self._config.qpg_batch_size + self._config.ai_batch_size :], + descriptors[self._config.dcg_batch_size + self._config.ai_batch_size :], ], axis=0, ) @@ -449,8 +449,8 @@ def state_update( emitter_state = emitter_state.replace(replay_buffer=replay_buffer) # sample transitions from the replay buffer - random_key, subkey = jax.random.split(emitter_state.random_key) - transitions, random_key = replay_buffer.sample( + key, subkey = jax.random.split(emitter_state.key) + transitions, key = replay_buffer.sample( subkey, self._config.num_critic_training_steps * self._config.batch_size ) transitions = jax.tree_util.tree_map( @@ -468,12 +468,12 @@ def state_update( rewards=self._similarity(transitions.desc, transitions.desc_prime) * transitions.rewards ) - emitter_state = emitter_state.replace(random_key=random_key) + emitter_state = emitter_state.replace(key=key) def scan_train_critics( - carry: QualityDCGEmitterState, - transitions: DCGTransition, - ) -> Tuple[QualityDCGEmitterState, Any]: + carry: DCRLEmitterState, + transitions: DCRLTransition, + ) -> Tuple[DCRLEmitterState, Any]: emitter_state = carry new_emitter_state = self._train_critics(emitter_state, transitions) return new_emitter_state, () @@ -490,8 +490,8 @@ def scan_train_critics( @partial(jax.jit, static_argnames=("self",)) def _train_critics( - self, emitter_state: QualityDCGEmitterState, transitions: DCGTransition - ) -> QualityDCGEmitterState: + self, emitter_state: DCRLEmitterState, transitions: DCRLTransition + ) -> DCRLEmitterState: """Apply one gradient step to critics and to the greedy actor (contained in carry in training_state), then soft update target critics and target actor. 
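The relabelling in state_update above scales every sampled reward by
_similarity(transitions.desc, transitions.desc_prime), turning the task reward
into a descriptor-conditioned reward. The kernel body itself is not touched by
this diff; for reference, a minimal sketch of the exponential kernel used by
DCRL-MAP-Elites, assuming the lengthscale field of DCRLConfig plays its usual
role (the exact form may differ from the real _similarity):

    import jax.numpy as jnp

    def similarity(descs_1: jnp.ndarray, descs_2: jnp.ndarray, lengthscale: float) -> jnp.ndarray:
        # 1.0 when the achieved descriptor matches the target descriptor,
        # decaying exponentially with their Euclidean distance.
        return jnp.exp(-jnp.linalg.norm(descs_2 - descs_1, axis=-1) / lengthscale)

With the default lengthscale of 0.1, a descriptor error of 0.1 already scales
the reward by exp(-1), i.e. roughly 0.37.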
@@ -510,14 +510,14 @@ def _train_critics(
             critic_opt_state,
             critic_params,
             target_critic_params,
-            random_key,
+            key,
         ) = self._update_critic(
             critic_params=emitter_state.critic_params,
             target_critic_params=emitter_state.target_critic_params,
             target_actor_params=emitter_state.target_actor_params,
             critic_opt_state=emitter_state.critic_opt_state,
             transitions=transitions,
-            random_key=emitter_state.random_key,
+            key=emitter_state.key,
         )
 
         # Update greedy actor
@@ -550,7 +550,7 @@ def _train_critics(
             actor_opt_state=actor_opt_state,
             target_critic_params=target_critic_params,
             target_actor_params=target_actor_params,
-            random_key=random_key,
+            key=key,
             steps=emitter_state.steps + 1,
         )
 
@@ -563,13 +563,13 @@ def _update_critic(
         target_critic_params: Params,
         target_actor_params: Params,
         critic_opt_state: Params,
-        transitions: DCGTransition,
-        random_key: RNGKey,
+        transitions: DCRLTransition,
+        key: RNGKey,
     ) -> Tuple[Params, Params, Params, RNGKey]:
 
-        # compute loss and gradients
-        random_key, subkey = jax.random.split(random_key)
-        critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)(
+        # compute gradients
+        key, subkey = jax.random.split(key)
+        critic_gradient = jax.grad(self._critic_loss_fn)(
             critic_params,
             target_actor_params,
             target_critic_params,
@@ -591,7 +591,7 @@ def _update_critic(
             critic_params,
         )
 
-        return critic_opt_state, critic_params, target_critic_params, random_key
+        return critic_opt_state, critic_params, target_critic_params, key
 
     @partial(jax.jit, static_argnames=("self",))
     def _update_actor(
@@ -600,11 +600,11 @@ def _update_actor(
         actor_opt_state: optax.OptState,
         target_actor_params: Params,
         critic_params: Params,
-        transitions: DCGTransition,
+        transitions: DCRLTransition,
     ) -> Tuple[optax.OptState, Params, Params]:
 
         # Update greedy actor
-        policy_loss, policy_gradient = jax.value_and_grad(self._actor_loss_fn)(
+        policy_gradient = jax.grad(self._actor_loss_fn)(
             actor_params,
             critic_params,
             transitions,
@@ -637,7 +637,7 @@ def _mutation_function_pg(
         self,
         policy_params: Genotype,
         descs: Descriptor,
-        emitter_state: QualityDCGEmitterState,
+        emitter_state: DCRLEmitterState,
     ) -> Genotype:
         """Apply pg mutation to a policy via multiple steps of gradient descent.
         First, update the rewards to be diversity rewards, then apply the gradient
@@ -653,8 +653,8 @@ def _mutation_function_pg(
             The updated params of the neural network.
""" # Get transitions - transitions, random_key = emitter_state.replay_buffer.sample( - emitter_state.random_key, + transitions, key = emitter_state.replay_buffer.sample( + emitter_state.key, sample_size=self._config.num_pg_training_steps * self._config.batch_size, ) descs_prime = jnp.tile( @@ -678,16 +678,16 @@ def _mutation_function_pg( transitions, ) - # Replace random_key - emitter_state = emitter_state.replace(random_key=random_key) + # Replace key + emitter_state = emitter_state.replace(key=key) # Define new policy optimizer state policy_opt_state = self._policies_optimizer.init(policy_params) def scan_train_policy( - carry: Tuple[QualityDCGEmitterState, Genotype, optax.OptState], - transitions: DCGTransition, - ) -> Tuple[Tuple[QualityDCGEmitterState, Genotype, optax.OptState], Any]: + carry: Tuple[DCRLEmitterState, Genotype, optax.OptState], + transitions: DCRLTransition, + ) -> Tuple[Tuple[DCRLEmitterState, Genotype, optax.OptState], Any]: emitter_state, policy_params, policy_opt_state = carry ( new_emitter_state, @@ -721,11 +721,11 @@ def scan_train_policy( @partial(jax.jit, static_argnames=("self",)) def _train_policy( self, - emitter_state: QualityDCGEmitterState, + emitter_state: DCRLEmitterState, policy_params: Params, policy_opt_state: optax.OptState, - transitions: DCGTransition, - ) -> Tuple[QualityDCGEmitterState, Params, optax.OptState]: + transitions: DCRLTransition, + ) -> Tuple[DCRLEmitterState, Params, optax.OptState]: """Apply one gradient step to a policy (called policy_params). Args: @@ -752,11 +752,11 @@ def _update_policy( critic_params: Params, policy_opt_state: optax.OptState, policy_params: Params, - transitions: DCGTransition, + transitions: DCRLTransition, ) -> Tuple[optax.OptState, Params]: # compute loss - _policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)( + policy_gradient = jax.grad(self._policy_loss_fn)( policy_params, critic_params, transitions, diff --git a/qdax/core/emitters/dcg_me_emitter.py b/qdax/core/emitters/dcrl_me_emitter.py similarity index 84% rename from qdax/core/emitters/dcg_me_emitter.py rename to qdax/core/emitters/dcrl_me_emitter.py index fea237c6..36a9f03d 100644 --- a/qdax/core/emitters/dcg_me_emitter.py +++ b/qdax/core/emitters/dcrl_me_emitter.py @@ -4,18 +4,18 @@ import flax.linen as nn from qdax.core.emitters.multi_emitter import MultiEmitter -from qdax.core.emitters.qdcg_emitter import QualityDCGConfig, QualityDCGEmitter +from qdax.core.emitters.dcrl_emitter import DCRLConfig, DCRLEmitter from qdax.core.emitters.standard_emitters import MixingEmitter from qdax.custom_types import Params, RNGKey from qdax.environments.base_wrappers import QDEnv @dataclass -class DCGMEConfig: - """Configuration for DCGME Algorithm""" +class DCRLMEConfig: + """Configuration for DCRL-MAP-Elites Algorithm""" ga_batch_size: int = 128 - qpg_batch_size: int = 64 + dcg_batch_size: int = 64 ai_batch_size: int = 64 lengthscale: float = 0.1 @@ -36,10 +36,10 @@ class DCGMEConfig: policy_delay: int = 2 -class DCGMEEmitter(MultiEmitter): +class DCRLMEEmitter(MultiEmitter): def __init__( self, - config: DCGMEConfig, + config: DCRLMEConfig, policy_network: nn.Module, actor_network: nn.Module, env: QDEnv, @@ -49,8 +49,8 @@ def __init__( self._env = env self._variation_fn = variation_fn - qdcg_config = QualityDCGConfig( - qpg_batch_size=config.qpg_batch_size, + dcrl_config = DCRLConfig( + dcg_batch_size=config.dcg_batch_size, ai_batch_size=config.ai_batch_size, lengthscale=config.lengthscale, 
            critic_hidden_layer_size=config.critic_hidden_layer_size,
@@ -70,8 +70,8 @@ def __init__(
         )
 
         # define the quality emitter
-        q_emitter = QualityDCGEmitter(
-            config=qdcg_config,
+        dcrl_emitter = DCRLEmitter(
+            config=dcrl_config,
             policy_network=policy_network,
             actor_network=actor_network,
             env=env,
@@ -85,4 +85,4 @@ def __init__(
             batch_size=config.ga_batch_size,
         )
 
-        super().__init__(emitters=(q_emitter, ga_emitter))
+        super().__init__(emitters=(dcrl_emitter, ga_emitter))
diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py
index c6e2df7e..63373494 100644
--- a/qdax/core/emitters/qpg_emitter.py
+++ b/qdax/core/emitters/qpg_emitter.py
@@ -428,7 +428,7 @@ def _update_critic(
 
-        # compute loss and gradients
+        # compute gradients
         random_key, subkey = jax.random.split(random_key)
-        critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)(
+        critic_gradient = jax.grad(self._critic_loss_fn)(
             critic_params,
             target_actor_params,
             target_critic_params,
@@ -463,7 +463,7 @@ def _update_actor(
     ) -> Tuple[optax.OptState, Params, Params]:
 
         # Update greedy actor
-        policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)(
+        policy_gradient = jax.grad(self._policy_loss_fn)(
             actor_params,
             critic_params,
             transitions,
@@ -595,7 +595,7 @@ def _update_policy(
     ) -> Tuple[optax.OptState, Params]:
 
-        # compute loss
-        _policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)(
+        # compute gradients
+        policy_gradient = jax.grad(self._policy_loss_fn)(
             policy_params,
             critic_params,
             transitions,
diff --git a/qdax/core/neuroevolution/buffers/buffer.py b/qdax/core/neuroevolution/buffers/buffer.py
index 5057e5e2..81f1e896 100644
--- a/qdax/core/neuroevolution/buffers/buffer.py
+++ b/qdax/core/neuroevolution/buffers/buffer.py
@@ -270,7 +270,7 @@ def init_dummy(  # type: ignore
         return dummy_transition
 
 
-class DCGTransition(QDTransition):
+class DCRLTransition(QDTransition):
     """Stores data corresponding to a transition collected by a QD algorithm."""
 
     desc: Descriptor
@@ -325,8 +325,8 @@ def flatten(self) -> jnp.ndarray:
     def from_flatten(
         cls,
         flattened_transition: jnp.ndarray,
-        transition: QDTransition,
-    ) -> QDTransition:
+        transition: DCRLTransition,
+    ) -> DCRLTransition:
         """
         Creates a transition from a flattened transition in a jnp.ndarray.
         Args:
@@ -394,7 +394,7 @@ def from_flatten(
     @classmethod
     def init_dummy(  # type: ignore
         cls, observation_dim: int, action_dim: int, descriptor_dim: int
-    ) -> QDTransition:
+    ) -> DCRLTransition:
         """
         Initialize a dummy transition that then can be
         passed to constructors to get all shapes right.
@@ -404,7 +404,7 @@ def init_dummy(  # type: ignore
         Returns:
             a dummy transition
         """
-        dummy_transition = DCGTransition(
+        dummy_transition = DCRLTransition(
             obs=jnp.zeros(shape=(1, observation_dim)),
             next_obs=jnp.zeros(shape=(1, observation_dim)),
             rewards=jnp.zeros(shape=(1,)),
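For typical downstream code, the visible change is the renamed public API
(DCGMEConfig -> DCRLMEConfig, DCGMEEmitter -> DCRLMEEmitter, qpg_batch_size ->
dcg_batch_size, and emitter_state.random_key -> emitter_state.key). A minimal
migration sketch, assuming the usual QDax helpers (environments.create, the
MLP/MLPDC network modules and isoline_variation); the environment name,
network sizes and variation parameters below are illustrative, not part of
this patch:

    import functools

    import jax.numpy as jnp

    from qdax import environments
    from qdax.core.emitters.dcrl_me_emitter import DCRLMEConfig, DCRLMEEmitter  # was dcg_me_emitter
    from qdax.core.emitters.mutation_operators import isoline_variation
    from qdax.core.neuroevolution.networks.networks import MLP, MLPDC

    env = environments.create("ant_omni", episode_length=250)

    # Policies stored in the repertoire, and the descriptor-conditioned actor.
    policy_network = MLP(layer_sizes=(128, 128, env.action_size), final_activation=jnp.tanh)
    actor_network = MLPDC(layer_sizes=(128, 128, env.action_size), final_activation=jnp.tanh)

    config = DCRLMEConfig(   # was DCGMEConfig
        ga_batch_size=128,   # GA offspring per generation
        dcg_batch_size=64,   # was qpg_batch_size: descriptor-conditioned PG offspring
        ai_batch_size=64,    # actor-injection offspring
        lengthscale=0.1,     # similarity kernel lengthscale
    )

    emitter = DCRLMEEmitter(  # was DCGMEEmitter
        config=config,
        policy_network=policy_network,
        actor_network=actor_network,
        env=env,
        variation_fn=functools.partial(isoline_variation, iso_sigma=0.005, line_sigma=0.05),
    )

The batch sizes mirror the emitter composition in the diff: the MultiEmitter
combines the DCRLEmitter (dcg_batch_size + ai_batch_size offspring) with the
GA-based MixingEmitter (ga_batch_size offspring).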