From a97a95f8184c822126393e26718554e69251d302 Mon Sep 17 00:00:00 2001 From: templierpaul Date: Mon, 20 Nov 2023 15:49:06 +0100 Subject: [PATCH] Added evosax in requirements Added docstrings --- qdax/core/emitters/cma_me_policies.py | 35 +++++++------ qdax/core/emitters/evosax_cma_me.py | 73 +++++++++++++++------------ requirements.txt | 1 + 3 files changed, 61 insertions(+), 48 deletions(-) diff --git a/qdax/core/emitters/cma_me_policies.py b/qdax/core/emitters/cma_me_policies.py index ee4440b8..96ba4504 100644 --- a/qdax/core/emitters/cma_me_policies.py +++ b/qdax/core/emitters/cma_me_policies.py @@ -26,6 +26,7 @@ from evosax import Strategies except: import warnings + warnings.warn("evosax not installed, custom CMA_ME will not work") from qdax.core.emitters.termination import cma_criterion @@ -38,7 +39,8 @@ EvosaxCMAOptimizingEmitter, EvosaxCMARndEmitter, EvosaxCMARndEmitterState, -) +) + def net_shape(net): return jax.tree_map(lambda x: x.shape, net) @@ -51,14 +53,14 @@ def __init__( centroids: Centroid, min_count: Optional[int] = None, max_count: Optional[float] = None, - es_params = None, + es_params=None, es_type="Sep_CMA_ES", ): """ Class for the emitter of CMA ME from "Covariance Matrix Adaptation for the Rapid Illumination of Behavior Space" by Fontaine et al. - This implementation relies on the Evosax library for ES and adds a wrapper to optimize + This implementation relies on the Evosax library for ES and adds a wrapper to optimize QDax neural networks. Args: @@ -119,7 +121,7 @@ def init( lambda x: x[0], init_genotypes, ) - + self.reshaper = QDaxReshaper.init(init_genotypes) self.es = Strategies[self.es_type]( @@ -135,16 +137,14 @@ def init( # Initialize the ES state random_key, init_key = jax.random.split(random_key) es_params = self.es.default_params - es_state = self.es.initialize( - init_key, params=es_params - ) + es_state = self.es.initialize(init_key, params=es_params) # return the initial state random_key, subkey = jax.random.split(random_key) return ( EvosaxCMAEmitterState( random_key=subkey, - es_state=es_state, + es_state=es_state, es_params=es_params, previous_fitnesses=default_fitnesses, emit_count=0, @@ -258,18 +258,17 @@ def _update_and_init_emitter_state( new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype) es_state = emitter_state.es_state.replace( - mean = new_mean, + mean=new_mean, ) - emitter_state = emitter_state.replace( - es_state=es_state, emit_count=0 - ) + emitter_state = emitter_state.replace(es_state=es_state, emit_count=0) return emitter_state, random_key class PolicyCMAPoolEmitter(CMAPoolEmitter): """CMA-ME pool emitter for policies""" + def init( self, init_genotypes: Genotype, random_key: RNGKey ) -> Tuple[CMAPoolEmitterState, RNGKey]: @@ -304,14 +303,20 @@ def init( class PolicyCMAOptimizingEmitter(CMAMEPolicies, EvosaxCMAOptimizingEmitter): + """CMA-ME optimizing emitter for policies""" + pass class PolicyCMAImprovementEmitter(CMAMEPolicies, EvosaxCMAImprovementEmitter): + """CMA-ME improvement emitter for policies""" + pass class PolicyCMARndEmitter(CMAMEPolicies, EvosaxCMARndEmitter): + """CMA-ME RND emitter for policies""" + def init( self, init_genotypes: Genotype, random_key: RNGKey ) -> Tuple[CMAEmitterState, RNGKey]: @@ -348,9 +353,7 @@ def init( # Initialize the ES state random_key, init_key = jax.random.split(random_key) es_params = self.es.default_params - es_state = self.es.initialize( - init_key, params=es_params - ) + es_state = self.es.initialize(init_key, params=es_params) # take a random direction random_key, subkey = jax.random.split(random_key) @@ -364,7 +367,7 @@ def init( return ( EvosaxCMARndEmitterState( random_key=subkey, - es_state=es_state, + es_state=es_state, es_params=es_params, previous_fitnesses=default_fitnesses, emit_count=0, diff --git a/qdax/core/emitters/evosax_cma_me.py b/qdax/core/emitters/evosax_cma_me.py index da938b5a..92cebcf2 100644 --- a/qdax/core/emitters/evosax_cma_me.py +++ b/qdax/core/emitters/evosax_cma_me.py @@ -25,8 +25,10 @@ from evosax import EvoState, EvoParams, Strategies except: import warnings + warnings.warn("evosax not installed, custom CMA_ME will not work") + class EvosaxCMAEmitterState(EmitterState): """ Emitter state for the CMA-ME emitter. @@ -48,7 +50,12 @@ class EvosaxCMAEmitterState(EmitterState): previous_fitnesses: Fitness emit_count: int + class EvosaxCMARndEmitterState(EvosaxCMAEmitterState): + """ + Emitter state for the CMA-ME RND emitter. + """ + random_direction: Descriptor @@ -60,7 +67,7 @@ def __init__( centroids: Centroid, min_count: Optional[int] = None, max_count: Optional[float] = None, - es_params = {}, + es_params={}, es_type="CMA_ES", ): """ @@ -107,12 +114,22 @@ def __init__( self.stop_condition = cma_criterion else: self.stop_condition = lambda f, s, p: False - @partial(jax.jit, static_argnames=("self",)) def init( self, init_genotypes: Genotype, random_key: RNGKey ) -> Tuple[CMAEmitterState, RNGKey]: + """ + Initializes the CMA-ME emitter + + Args: + init_genotypes: initial genotypes to add to the grid. + random_key: a random key to handle stochastic operations. + + Returns: + The initial state of the emitter. + """ + # Initialize repertoire with default values num_centroids = self._centroids.shape[0] default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) @@ -120,16 +137,14 @@ def init( # Initialize the ES state random_key, init_key = jax.random.split(random_key) es_params = self.es.default_params - es_state = self.es.initialize( - init_key, params=es_params - ) + es_state = self.es.initialize(init_key, params=es_params) # return the initial state random_key, subkey = jax.random.split(random_key) return ( EvosaxCMAEmitterState( random_key=subkey, - es_state=es_state, + es_state=es_state, es_params=es_params, previous_fitnesses=default_fitnesses, emit_count=0, @@ -165,7 +180,7 @@ def emit( offsprings, es_state = self.es.ask(subkey, es_state, es_params) return offsprings, random_key - + @partial( jax.jit, static_argnames=("self",), @@ -230,40 +245,32 @@ def state_update( reinitialize = ( jnp.all(improvements < 0) * (emit_count > self._min_count) + (emit_count > self._max_count) - + self.stop_condition( - None, - emitter_state.es_state, - emitter_state.es_params - ) + + self.stop_condition(None, emitter_state.es_state, emitter_state.es_params) + flat_criteria_condition ) # If true, draw randomly and re-initialize parameters def update_and_reinit( - operand: Tuple[ - CMAEmitterState, MapElitesRepertoire, int, RNGKey - ], + operand: Tuple[CMAEmitterState, MapElitesRepertoire, int, RNGKey], ) -> Tuple[CMAEmitterState, RNGKey]: return self._update_and_init_emitter_state(*operand) def update_wo_reinit( - operand: Tuple[ - CMAEmitterState, MapElitesRepertoire, int, RNGKey - ], + operand: Tuple[CMAEmitterState, MapElitesRepertoire, int, RNGKey], ) -> Tuple[CMAEmitterState, RNGKey]: """Update the emitter when no reinit event happened. - The QDax implementation with custom CMA-ES bypasses the masked update - of the CMAES, so we remove it too too. + The QDax implementation with custom CMA-ES bypasses the masked update + of the CMAES, so we remove it too too. """ (emitter_state, repertoire, emit_count, random_key) = operand es_state = emitter_state.es_state # Update CMA Parameters - + # Flip the sign of the improvements flipped_sorted_improvements = -sorted_improvements - + es_state = self.es.tell( sorted_candidates, flipped_sorted_improvements, @@ -326,25 +333,29 @@ def _update_and_init_emitter_state( new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype) es_state = emitter_state.es_state.replace( - mean = new_mean, + mean=new_mean, ) - emitter_state = emitter_state.replace( - es_state=es_state, emit_count=0 - ) + emitter_state = emitter_state.replace(es_state=es_state, emit_count=0) return emitter_state, random_key - + class EvosaxCMAOptimizingEmitter(EvosaxCMAMEEmitter, CMAOptimizingEmitter): + """CMA-ME Optimizing Emitter using Evosax""" + pass class EvosaxCMAImprovementEmitter(EvosaxCMAMEEmitter, CMAImprovementEmitter): + """CMA-ME Improvement Emitter using Evosax""" + pass class EvosaxCMARndEmitter(EvosaxCMAMEEmitter, CMARndEmitter): + """CMA-ME RND Emitter using Evosax""" + @partial(jax.jit, static_argnames=("self",)) def init( self, init_genotypes: Genotype, random_key: RNGKey @@ -359,16 +370,14 @@ def init( Returns: The initial state of the emitter. """ - # Initialize repertoire with default values + # Initialize repertoire with default values num_centroids = self._centroids.shape[0] default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids) # Initialize the ES state random_key, init_key = jax.random.split(random_key) es_params = self.es.default_params - es_state = self.es.initialize( - init_key, params=es_params - ) + es_state = self.es.initialize(init_key, params=es_params) # take a random direction random_key, direction_key = jax.random.split(random_key) @@ -382,7 +391,7 @@ def init( return ( EvosaxCMARndEmitterState( random_key=subkey, - es_state=es_state, + es_state=es_state, es_params=es_params, previous_fitnesses=default_fitnesses, emit_count=0, diff --git a/requirements.txt b/requirements.txt index 31008a59..338aedfe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ absl-py==1.0.0 brax==0.9.2 chex==0.1.83 dm-haiku==0.0.9 +evosax==0.1.4 flax==0.7.4 gym==0.26.2 ipython