Skip to content

Commit

Permalink
Added evosax in requirements
Browse files Browse the repository at this point in the history
Added docstrings
  • Loading branch information
templierpaul committed Nov 20, 2023
1 parent 9bc1b06 commit a97a95f
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 48 deletions.
35 changes: 19 additions & 16 deletions qdax/core/emitters/cma_me_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from evosax import Strategies
except:
import warnings

warnings.warn("evosax not installed, custom CMA_ME will not work")

from qdax.core.emitters.termination import cma_criterion
Expand All @@ -38,7 +39,8 @@
EvosaxCMAOptimizingEmitter,
EvosaxCMARndEmitter,
EvosaxCMARndEmitterState,
)
)


def net_shape(net):
return jax.tree_map(lambda x: x.shape, net)
Expand All @@ -51,14 +53,14 @@ def __init__(
centroids: Centroid,
min_count: Optional[int] = None,
max_count: Optional[float] = None,
es_params = None,
es_params=None,
es_type="Sep_CMA_ES",
):
"""
Class for the emitter of CMA ME from "Covariance Matrix Adaptation for the
Rapid Illumination of Behavior Space" by Fontaine et al.
This implementation relies on the Evosax library for ES and adds a wrapper to optimize
This implementation relies on the Evosax library for ES and adds a wrapper to optimize
QDax neural networks.
Args:
Expand Down Expand Up @@ -119,7 +121,7 @@ def init(
lambda x: x[0],
init_genotypes,
)

self.reshaper = QDaxReshaper.init(init_genotypes)

self.es = Strategies[self.es_type](
Expand All @@ -135,16 +137,14 @@ def init(
# Initialize the ES state
random_key, init_key = jax.random.split(random_key)
es_params = self.es.default_params
es_state = self.es.initialize(
init_key, params=es_params
)
es_state = self.es.initialize(init_key, params=es_params)

# return the initial state
random_key, subkey = jax.random.split(random_key)
return (
EvosaxCMAEmitterState(
random_key=subkey,
es_state=es_state,
es_state=es_state,
es_params=es_params,
previous_fitnesses=default_fitnesses,
emit_count=0,
Expand Down Expand Up @@ -258,18 +258,17 @@ def _update_and_init_emitter_state(
new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype)

es_state = emitter_state.es_state.replace(
mean = new_mean,
mean=new_mean,
)

emitter_state = emitter_state.replace(
es_state=es_state, emit_count=0
)
emitter_state = emitter_state.replace(es_state=es_state, emit_count=0)

return emitter_state, random_key


class PolicyCMAPoolEmitter(CMAPoolEmitter):
"""CMA-ME pool emitter for policies"""

def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMAPoolEmitterState, RNGKey]:
Expand Down Expand Up @@ -304,14 +303,20 @@ def init(


class PolicyCMAOptimizingEmitter(CMAMEPolicies, EvosaxCMAOptimizingEmitter):
"""CMA-ME optimizing emitter for policies"""

pass


class PolicyCMAImprovementEmitter(CMAMEPolicies, EvosaxCMAImprovementEmitter):
"""CMA-ME improvement emitter for policies"""

pass


class PolicyCMARndEmitter(CMAMEPolicies, EvosaxCMARndEmitter):
"""CMA-ME RND emitter for policies"""

def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMAEmitterState, RNGKey]:
Expand Down Expand Up @@ -348,9 +353,7 @@ def init(
# Initialize the ES state
random_key, init_key = jax.random.split(random_key)
es_params = self.es.default_params
es_state = self.es.initialize(
init_key, params=es_params
)
es_state = self.es.initialize(init_key, params=es_params)

# take a random direction
random_key, subkey = jax.random.split(random_key)
Expand All @@ -364,7 +367,7 @@ def init(
return (
EvosaxCMARndEmitterState(
random_key=subkey,
es_state=es_state,
es_state=es_state,
es_params=es_params,
previous_fitnesses=default_fitnesses,
emit_count=0,
Expand Down
73 changes: 41 additions & 32 deletions qdax/core/emitters/evosax_cma_me.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
from evosax import EvoState, EvoParams, Strategies
except:
import warnings

warnings.warn("evosax not installed, custom CMA_ME will not work")


class EvosaxCMAEmitterState(EmitterState):
"""
Emitter state for the CMA-ME emitter.
Expand All @@ -48,7 +50,12 @@ class EvosaxCMAEmitterState(EmitterState):
previous_fitnesses: Fitness
emit_count: int


class EvosaxCMARndEmitterState(EvosaxCMAEmitterState):
"""
Emitter state for the CMA-ME RND emitter.
"""

random_direction: Descriptor


Expand All @@ -60,7 +67,7 @@ def __init__(
centroids: Centroid,
min_count: Optional[int] = None,
max_count: Optional[float] = None,
es_params = {},
es_params={},
es_type="CMA_ES",
):
"""
Expand Down Expand Up @@ -107,29 +114,37 @@ def __init__(
self.stop_condition = cma_criterion
else:
self.stop_condition = lambda f, s, p: False


@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMAEmitterState, RNGKey]:
"""
Initializes the CMA-ME emitter
Args:
init_genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.
Returns:
The initial state of the emitter.
"""

# Initialize repertoire with default values
num_centroids = self._centroids.shape[0]
default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

# Initialize the ES state
random_key, init_key = jax.random.split(random_key)
es_params = self.es.default_params
es_state = self.es.initialize(
init_key, params=es_params
)
es_state = self.es.initialize(init_key, params=es_params)

# return the initial state
random_key, subkey = jax.random.split(random_key)
return (
EvosaxCMAEmitterState(
random_key=subkey,
es_state=es_state,
es_state=es_state,
es_params=es_params,
previous_fitnesses=default_fitnesses,
emit_count=0,
Expand Down Expand Up @@ -165,7 +180,7 @@ def emit(
offsprings, es_state = self.es.ask(subkey, es_state, es_params)

return offsprings, random_key

@partial(
jax.jit,
static_argnames=("self",),
Expand Down Expand Up @@ -230,40 +245,32 @@ def state_update(
reinitialize = (
jnp.all(improvements < 0) * (emit_count > self._min_count)
+ (emit_count > self._max_count)
+ self.stop_condition(
None,
emitter_state.es_state,
emitter_state.es_params
)
+ self.stop_condition(None, emitter_state.es_state, emitter_state.es_params)
+ flat_criteria_condition
)

# If true, draw randomly and re-initialize parameters
def update_and_reinit(
operand: Tuple[
CMAEmitterState, MapElitesRepertoire, int, RNGKey
],
operand: Tuple[CMAEmitterState, MapElitesRepertoire, int, RNGKey],
) -> Tuple[CMAEmitterState, RNGKey]:
return self._update_and_init_emitter_state(*operand)

def update_wo_reinit(
operand: Tuple[
CMAEmitterState, MapElitesRepertoire, int, RNGKey
],
operand: Tuple[CMAEmitterState, MapElitesRepertoire, int, RNGKey],
) -> Tuple[CMAEmitterState, RNGKey]:
"""Update the emitter when no reinit event happened.
The QDax implementation with custom CMA-ES bypasses the masked update
of the CMAES, so we remove it too too.
The QDax implementation with custom CMA-ES bypasses the masked update
of the CMAES, so we remove it too too.
"""

(emitter_state, repertoire, emit_count, random_key) = operand

es_state = emitter_state.es_state
# Update CMA Parameters

# Flip the sign of the improvements
flipped_sorted_improvements = -sorted_improvements

es_state = self.es.tell(
sorted_candidates,
flipped_sorted_improvements,
Expand Down Expand Up @@ -326,25 +333,29 @@ def _update_and_init_emitter_state(
new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype)

es_state = emitter_state.es_state.replace(
mean = new_mean,
mean=new_mean,
)

emitter_state = emitter_state.replace(
es_state=es_state, emit_count=0
)
emitter_state = emitter_state.replace(es_state=es_state, emit_count=0)

return emitter_state, random_key


class EvosaxCMAOptimizingEmitter(EvosaxCMAMEEmitter, CMAOptimizingEmitter):
"""CMA-ME Optimizing Emitter using Evosax"""

pass


class EvosaxCMAImprovementEmitter(EvosaxCMAMEEmitter, CMAImprovementEmitter):
"""CMA-ME Improvement Emitter using Evosax"""

pass


class EvosaxCMARndEmitter(EvosaxCMAMEEmitter, CMARndEmitter):
"""CMA-ME RND Emitter using Evosax"""

@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
Expand All @@ -359,16 +370,14 @@ def init(
Returns:
The initial state of the emitter.
"""
# Initialize repertoire with default values
# Initialize repertoire with default values
num_centroids = self._centroids.shape[0]
default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

# Initialize the ES state
random_key, init_key = jax.random.split(random_key)
es_params = self.es.default_params
es_state = self.es.initialize(
init_key, params=es_params
)
es_state = self.es.initialize(init_key, params=es_params)

# take a random direction
random_key, direction_key = jax.random.split(random_key)
Expand All @@ -382,7 +391,7 @@ def init(
return (
EvosaxCMARndEmitterState(
random_key=subkey,
es_state=es_state,
es_state=es_state,
es_params=es_params,
previous_fitnesses=default_fitnesses,
emit_count=0,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ absl-py==1.0.0
brax==0.9.2
chex==0.1.83
dm-haiku==0.0.9
evosax==0.1.4
flax==0.7.4
gym==0.26.2
ipython
Expand Down

0 comments on commit a97a95f

Please sign in to comment.