Skip to content

Commit

Permalink
DCG-MAP-Elites
Browse files Browse the repository at this point in the history
  • Loading branch information
maxencefaldor committed Dec 21, 2023
1 parent 82c0437 commit c6d33a6
Show file tree
Hide file tree
Showing 21 changed files with 1,471 additions and 141 deletions.
28 changes: 28 additions & 0 deletions qdax/core/containers/mapelites_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,34 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey

return samples, random_key

@partial(jax.jit, static_argnames=("num_samples",))
def sample_with_descs(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
"""Sample elements in the repertoire.
Args:
random_key: a jax PRNG random key
num_samples: the number of elements to be sampled
Returns:
samples: a batch of genotypes sampled in the repertoire
random_key: an updated jax PRNG random key
"""

repertoire_empty = self.fitnesses == -jnp.inf
p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)

random_key, subkey = jax.random.split(random_key)
samples = jax.tree_util.tree_map(
lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p),
self.genotypes,
)
descs = jax.tree_util.tree_map(
lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p),
self.descriptors,
)

return samples, descs, random_key

@jax.jit
def add(
self,
Expand Down
12 changes: 9 additions & 3 deletions qdax/core/emitters/cma_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,20 @@ def batch_size(self) -> int:

@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
self,
random_key: RNGKey,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> Tuple[CMAEmitterState, RNGKey]:
"""
Initializes the CMA-MEGA emitter
Args:
init_genotypes: initial genotypes to add to the grid.
genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.
Returns:
Expand Down Expand Up @@ -154,7 +160,7 @@ def emit(
cmaes_state=emitter_state.cmaes_state, random_key=random_key
)

return offsprings, random_key
return offsprings, {}, random_key

@partial(
jax.jit,
Expand Down
14 changes: 10 additions & 4 deletions qdax/core/emitters/cma_mega_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,20 @@ def __init__(

@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
self,
random_key: RNGKey,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> Tuple[CMAMEGAState, RNGKey]:
"""
Initializes the CMA-MEGA emitter.
Args:
init_genotypes: initial genotypes to add to the grid.
genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.
Returns:
Expand All @@ -117,7 +123,7 @@ def init(
# define init theta as 0
theta = jax.tree_util.tree_map(
lambda x: jnp.zeros_like(x[:1, ...]),
init_genotypes,
genotypes,
)

# score it
Expand Down Expand Up @@ -181,7 +187,7 @@ def emit(
# Compute new candidates
new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad)

return new_thetas, random_key
return new_thetas, {}, random_key

@partial(
jax.jit,
Expand Down
14 changes: 10 additions & 4 deletions qdax/core/emitters/cma_pool_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,20 @@ def batch_size(self) -> int:

@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
self,
random_key: RNGKey,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> Tuple[CMAPoolEmitterState, RNGKey]:
"""
Initializes the CMA-MEGA emitter
Args:
init_genotypes: initial genotypes to add to the grid.
genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.
Returns:
Expand All @@ -67,7 +73,7 @@ def scan_emitter_init(
carry: RNGKey, unused: Any
) -> Tuple[RNGKey, CMAEmitterState]:
random_key = carry
emitter_state, random_key = self._emitter.init(init_genotypes, random_key)
emitter_state, random_key = self._emitter.init(genotypes, random_key)
return random_key, emitter_state

# init all the emitter states
Expand Down Expand Up @@ -115,7 +121,7 @@ def emit(
repertoire, used_emitter_state, random_key
)

return offsprings, random_key
return offsprings, {}, random_key

@partial(
jax.jit,
Expand Down
88 changes: 88 additions & 0 deletions qdax/core/emitters/dcg_me_emitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from dataclasses import dataclass
from typing import Callable, Tuple

import flax.linen as nn

from qdax.core.emitters.multi_emitter import MultiEmitter
from qdax.core.emitters.qdcg_emitter import QualityDCGConfig, QualityDCGEmitter
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.environments.base_wrappers import QDEnv
from qdax.types import Params, RNGKey


@dataclass
class DCGMEConfig:
"""Configuration for DCGME Algorithm"""

ga_batch_size: int = 128
qpg_batch_size: int = 64
ai_batch_size: int = 64
lengthscale: float = 0.1

# PG emitter
critic_hidden_layer_size: Tuple[int, ...] = (256, 256)
num_critic_training_steps: int = 3000
num_pg_training_steps: int = 150
batch_size: int = 100
replay_buffer_size: int = 1_000_000
discount: float = 0.99
reward_scaling: float = 1.0
critic_learning_rate: float = 3e-4
actor_learning_rate: float = 3e-4
policy_learning_rate: float = 1e-3
noise_clip: float = 0.5
policy_noise: float = 0.2
soft_tau_update: float = 0.005
policy_delay: int = 2


class DCGMEEmitter(MultiEmitter):
def __init__(
self,
config: DCGMEConfig,
policy_network: nn.Module,
actor_network: nn.Module,
env: QDEnv,
variation_fn: Callable[[Params, Params, RNGKey], Tuple[Params, RNGKey]],
) -> None:
self._config = config
self._env = env
self._variation_fn = variation_fn

qdcg_config = QualityDCGConfig(
qpg_batch_size=config.qpg_batch_size,
ai_batch_size=config.ai_batch_size,
lengthscale=config.lengthscale,
critic_hidden_layer_size=config.critic_hidden_layer_size,
num_critic_training_steps=config.num_critic_training_steps,
num_pg_training_steps=config.num_pg_training_steps,
batch_size=config.batch_size,
replay_buffer_size=config.replay_buffer_size,
discount=config.discount,
reward_scaling=config.reward_scaling,
critic_learning_rate=config.critic_learning_rate,
actor_learning_rate=config.actor_learning_rate,
policy_learning_rate=config.policy_learning_rate,
noise_clip=config.noise_clip,
policy_noise=config.policy_noise,
soft_tau_update=config.soft_tau_update,
policy_delay=config.policy_delay,
)

# define the quality emitter
q_emitter = QualityDCGEmitter(
config=qdcg_config,
policy_network=policy_network,
actor_network=actor_network,
env=env
)

# define the GA emitter
ga_emitter = MixingEmitter(
mutation_fn=lambda x, r: (x, r),
variation_fn=variation_fn,
variation_percentage=1.0,
batch_size=config.ga_batch_size,
)

super().__init__(emitters=(q_emitter, ga_emitter))
16 changes: 11 additions & 5 deletions qdax/core/emitters/dpg_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import optax

from qdax.core.containers.archive import Archive
from qdax.core.containers.repertoire import Repertoire
from qdax.core.containers.repertoire import MapElitesRepertoire
from qdax.core.emitters.qpg_emitter import (
QualityPGConfig,
QualityPGEmitter,
Expand Down Expand Up @@ -77,20 +77,26 @@ def __init__(
self._score_novelty = score_novelty

def init(
self, init_genotypes: Genotype, random_key: RNGKey
self,
random_key: RNGKey,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> Tuple[DiversityPGEmitterState, RNGKey]:
"""Initializes the emitter state.
Args:
init_genotypes: The initial population.
genotypes: The initial population.
random_key: A random key.
Returns:
The initial state of the PGAMEEmitter, a new random key.
"""

# init elements of diversity emitter state with QualityEmitterState.init()
diversity_emitter_state, random_key = super().init(init_genotypes, random_key)
diversity_emitter_state, random_key = super().init(genotypes, random_key)

# store elements in a dictionary
attributes_dict = vars(diversity_emitter_state)
Expand All @@ -116,7 +122,7 @@ def init(
def state_update(
self,
emitter_state: DiversityPGEmitterState,
repertoire: Optional[Repertoire],
repertoire: Optional[MapElitesRepertoire],
genotypes: Optional[Genotype],
fitnesses: Optional[Fitness],
descriptors: Optional[Descriptor],
Expand Down
8 changes: 7 additions & 1 deletion qdax/core/emitters/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ class EmitterState(PyTreeNode):

class Emitter(ABC):
def init(
self, init_genotypes: Optional[Genotype], random_key: RNGKey
self,
random_key: RNGKey,
repertoire: Repertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> Tuple[Optional[EmitterState], RNGKey]:
"""Initialises the state of the emitter. Some emitters do
not need a state, in which case, the value None can be
Expand Down
24 changes: 15 additions & 9 deletions qdax/core/emitters/mees_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,26 +236,32 @@ def batch_size(self) -> int:
static_argnames=("self",),
)
def init(
self, init_genotypes: Genotype, random_key: RNGKey
self,
random_key: RNGKey,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> Tuple[MEESEmitterState, RNGKey]:
"""Initializes the emitter state.
Args:
init_genotypes: The initial population.
genotypes: The initial population.
random_key: A random key.
Returns:
The initial state of the MEESEmitter, a new random key.
"""
# Initialisation requires one initial genotype
if jax.tree_util.tree_leaves(init_genotypes)[0].shape[0] > 1:
init_genotypes = jax.tree_util.tree_map(
if jax.tree_util.tree_leaves(genotypes)[0].shape[0] > 1:
genotypes = jax.tree_util.tree_map(
lambda x: x[0],
init_genotypes,
genotypes,
)

# Initialise optimizer
initial_optimizer_state = self._optimizer.init(init_genotypes)
initial_optimizer_state = self._optimizer.init(genotypes)

# Create empty Novelty archive
if self._config.use_explore:
Expand All @@ -270,7 +276,7 @@ def init(
# Create empty updated genotypes and fitness
last_updated_genotypes = jax.tree_util.tree_map(
lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]),
init_genotypes,
genotypes,
)
last_updated_fitnesses = -jnp.inf * jnp.ones(
shape=self._config.last_updated_size
Expand All @@ -280,7 +286,7 @@ def init(
MEESEmitterState(
initial_optimizer_state=initial_optimizer_state,
optimizer_state=initial_optimizer_state,
offspring=init_genotypes,
offspring=genotypes,
generation_count=0,
novelty_archive=novelty_archive,
last_updated_genotypes=last_updated_genotypes,
Expand Down Expand Up @@ -313,7 +319,7 @@ def emit(
a new jax PRNG key
"""

return emitter_state.offspring, random_key
return emitter_state.offspring, {}, random_key

@partial(
jax.jit,
Expand Down
Loading

0 comments on commit c6d33a6

Please sign in to comment.