Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DCG-MAP-Elites #166

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions qdax/core/containers/mapelites_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,34 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey

return samples, random_key

@partial(jax.jit, static_argnames=("num_samples",))
def sample_with_descs(
    self, random_key: RNGKey, num_samples: int
) -> Tuple[Genotype, Descriptor, RNGKey]:
    """Sample genotypes in the repertoire together with their descriptors.

    Cells are drawn with replacement, with equal probability among the
    cells whose fitness is not ``-jnp.inf``. The cell indices are drawn
    once and reused for both the genotypes and the descriptors, so the
    i-th descriptor returned always comes from the same cell as the i-th
    genotype.

    Args:
        random_key: a jax PRNG random key
        num_samples: the number of elements to be sampled

    Returns:
        samples: a batch of genotypes sampled in the repertoire
        descs: the descriptors stored in the same cells as ``samples``
        random_key: an updated jax PRNG random key
    """

    # A cell is considered empty when its fitness is -inf; empty cells get
    # probability zero and the remaining cells share the mass equally.
    repertoire_empty = self.fitnesses == -jnp.inf
    p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)

    random_key, subkey = jax.random.split(random_key)

    # Draw the cell indices a single time instead of relying on two
    # separate `choice` calls that happen to share the same key.
    indices = jax.random.choice(
        subkey, self.fitnesses.shape[0], shape=(num_samples,), p=p
    )

    def gather(x):
        # Select the sampled cells along the leading (cell) axis.
        return jnp.take(x, indices, axis=0)

    samples = jax.tree_util.tree_map(gather, self.genotypes)
    descs = jax.tree_util.tree_map(gather, self.descriptors)

    return samples, descs, random_key

@jax.jit
def add(
self,
Expand Down
12 changes: 9 additions & 3 deletions qdax/core/emitters/cma_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,20 @@ def batch_size(self) -> int:

@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
self,
random_key: RNGKey,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> Tuple[CMAEmitterState, RNGKey]:
"""
Initializes the CMA-MEGA emitter


Args:
init_genotypes: initial genotypes to add to the grid.
genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.

Returns:
Expand Down Expand Up @@ -154,7 +160,7 @@ def emit(
cmaes_state=emitter_state.cmaes_state, random_key=random_key
)

return offsprings, random_key
return offsprings, {}, random_key

@partial(
jax.jit,
Expand Down
14 changes: 10 additions & 4 deletions qdax/core/emitters/cma_mega_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,20 @@ def __init__(

@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
self,
random_key: RNGKey,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> Tuple[CMAMEGAState, RNGKey]:
"""
Initializes the CMA-MEGA emitter.


Args:
init_genotypes: initial genotypes to add to the grid.
genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.

Returns:
Expand All @@ -117,7 +123,7 @@ def init(
# define init theta as 0
theta = jax.tree_util.tree_map(
lambda x: jnp.zeros_like(x[:1, ...]),
init_genotypes,
genotypes,
)

# score it
Expand Down Expand Up @@ -181,7 +187,7 @@ def emit(
# Compute new candidates
new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad)

return new_thetas, random_key
return new_thetas, {}, random_key

@partial(
jax.jit,
Expand Down
14 changes: 10 additions & 4 deletions qdax/core/emitters/cma_pool_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,20 @@ def batch_size(self) -> int:

@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
self,
random_key: RNGKey,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> Tuple[CMAPoolEmitterState, RNGKey]:
"""
Initializes the CMA-MEGA emitter


Args:
init_genotypes: initial genotypes to add to the grid.
genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.

Returns:
Expand All @@ -67,7 +73,7 @@ def scan_emitter_init(
carry: RNGKey, unused: Any
) -> Tuple[RNGKey, CMAEmitterState]:
random_key = carry
emitter_state, random_key = self._emitter.init(init_genotypes, random_key)
emitter_state, random_key = self._emitter.init(genotypes, random_key)
return random_key, emitter_state

# init all the emitter states
Expand Down Expand Up @@ -115,7 +121,7 @@ def emit(
repertoire, used_emitter_state, random_key
)

return offsprings, random_key
return offsprings, {}, random_key

@partial(
jax.jit,
Expand Down
88 changes: 88 additions & 0 deletions qdax/core/emitters/dcg_me_emitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from dataclasses import dataclass
from typing import Callable, Tuple

import flax.linen as nn

from qdax.core.emitters.multi_emitter import MultiEmitter
from qdax.core.emitters.qdcg_emitter import QualityDCGConfig, QualityDCGEmitter
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.environments.base_wrappers import QDEnv
from qdax.types import Params, RNGKey


@dataclass
class DCGMEConfig:
    """Configuration for DCGME Algorithm

    Holds the settings read by DCGMEEmitter: ``ga_batch_size`` is used as
    the batch size of its MixingEmitter, and every other field is copied
    unchanged into the QualityDCGConfig given to its QualityDCGEmitter.
    """

    # batch size of the GA (MixingEmitter) sub-emitter
    ga_batch_size: int = 128
    # forwarded to QualityDCGConfig; presumably the number of offspring
    # produced by each of its two emission modes - TODO confirm
    qpg_batch_size: int = 64
    ai_batch_size: int = 64
    # forwarded to QualityDCGConfig; NOTE(review): its exact role is not
    # visible from this file - confirm against QualityDCGEmitter
    lengthscale: float = 0.1

    # PG emitter
    # (all fields below are forwarded as-is to QualityDCGConfig)
    critic_hidden_layer_size: Tuple[int, ...] = (256, 256)
    num_critic_training_steps: int = 3000
    num_pg_training_steps: int = 150
    batch_size: int = 100
    replay_buffer_size: int = 1_000_000
    discount: float = 0.99
    reward_scaling: float = 1.0
    critic_learning_rate: float = 3e-4
    actor_learning_rate: float = 3e-4
    policy_learning_rate: float = 1e-3
    noise_clip: float = 0.5
    policy_noise: float = 0.2
    soft_tau_update: float = 0.005
    policy_delay: int = 2


class DCGMEEmitter(MultiEmitter):
    """Multi-emitter made of a QualityDCGEmitter and a variation-only
    MixingEmitter, both configured from a single DCGMEConfig.
    """

    def __init__(
        self,
        config: DCGMEConfig,
        policy_network: nn.Module,
        actor_network: nn.Module,
        env: QDEnv,
        variation_fn: Callable[[Params, Params, RNGKey], Tuple[Params, RNGKey]],
    ) -> None:
        """Build the two sub-emitters and hand them to MultiEmitter.

        Args:
            config: settings for both sub-emitters.
            policy_network: network passed to the QualityDCGEmitter.
            actor_network: network passed to the QualityDCGEmitter.
            env: environment passed to the QualityDCGEmitter.
            variation_fn: variation operator used by the MixingEmitter.
        """
        self._config = config
        self._env = env
        self._variation_fn = variation_fn

        # Fields copied verbatim from the DCGME config into the config of
        # the quality emitter.
        forwarded_names = (
            "qpg_batch_size",
            "ai_batch_size",
            "lengthscale",
            "critic_hidden_layer_size",
            "num_critic_training_steps",
            "num_pg_training_steps",
            "batch_size",
            "replay_buffer_size",
            "discount",
            "reward_scaling",
            "critic_learning_rate",
            "actor_learning_rate",
            "policy_learning_rate",
            "noise_clip",
            "policy_noise",
            "soft_tau_update",
            "policy_delay",
        )
        quality_config = QualityDCGConfig(
            **{name: getattr(config, name) for name in forwarded_names}
        )

        # Quality emitter.
        quality_emitter = QualityDCGEmitter(
            config=quality_config,
            policy_network=policy_network,
            actor_network=actor_network,
            env=env,
        )

        def _no_mutation(x, r):
            # Mutation is a pass-through: genotype and key come back untouched.
            return (x, r)

        # GA emitter: variation_percentage is 1.0, so only variation_fn is
        # meant to alter the genotypes.
        genetic_emitter = MixingEmitter(
            mutation_fn=_no_mutation,
            variation_fn=variation_fn,
            variation_percentage=1.0,
            batch_size=config.ga_batch_size,
        )

        super().__init__(emitters=(quality_emitter, genetic_emitter))
16 changes: 11 additions & 5 deletions qdax/core/emitters/dpg_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import optax

from qdax.core.containers.archive import Archive
from qdax.core.containers.repertoire import Repertoire
from qdax.core.containers.repertoire import MapElitesRepertoire
from qdax.core.emitters.qpg_emitter import (
QualityPGConfig,
QualityPGEmitter,
Expand Down Expand Up @@ -77,20 +77,26 @@ def __init__(
self._score_novelty = score_novelty

def init(
self, init_genotypes: Genotype, random_key: RNGKey
self,
random_key: RNGKey,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> Tuple[DiversityPGEmitterState, RNGKey]:
"""Initializes the emitter state.

Args:
init_genotypes: The initial population.
genotypes: The initial population.
random_key: A random key.

Returns:
The initial state of the PGAMEEmitter, a new random key.
"""

# init elements of diversity emitter state with QualityEmitterState.init()
diversity_emitter_state, random_key = super().init(init_genotypes, random_key)
diversity_emitter_state, random_key = super().init(genotypes, random_key)

# store elements in a dictionary
attributes_dict = vars(diversity_emitter_state)
Expand All @@ -116,7 +122,7 @@ def init(
def state_update(
self,
emitter_state: DiversityPGEmitterState,
repertoire: Optional[Repertoire],
repertoire: Optional[MapElitesRepertoire],
genotypes: Optional[Genotype],
fitnesses: Optional[Fitness],
descriptors: Optional[Descriptor],
Expand Down
8 changes: 7 additions & 1 deletion qdax/core/emitters/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ class EmitterState(PyTreeNode):

class Emitter(ABC):
def init(
self, init_genotypes: Optional[Genotype], random_key: RNGKey
self,
random_key: RNGKey,
repertoire: Repertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> Tuple[Optional[EmitterState], RNGKey]:
"""Initialises the state of the emitter. Some emitters do
not need a state, in which case, the value None can be
Expand Down
24 changes: 15 additions & 9 deletions qdax/core/emitters/mees_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,26 +236,32 @@ def batch_size(self) -> int:
static_argnames=("self",),
)
def init(
self, init_genotypes: Genotype, random_key: RNGKey
self,
random_key: RNGKey,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> Tuple[MEESEmitterState, RNGKey]:
"""Initializes the emitter state.

Args:
init_genotypes: The initial population.
genotypes: The initial population.
random_key: A random key.

Returns:
The initial state of the MEESEmitter, a new random key.
"""
# Initialisation requires one initial genotype
if jax.tree_util.tree_leaves(init_genotypes)[0].shape[0] > 1:
init_genotypes = jax.tree_util.tree_map(
if jax.tree_util.tree_leaves(genotypes)[0].shape[0] > 1:
genotypes = jax.tree_util.tree_map(
lambda x: x[0],
init_genotypes,
genotypes,
)

# Initialise optimizer
initial_optimizer_state = self._optimizer.init(init_genotypes)
initial_optimizer_state = self._optimizer.init(genotypes)

# Create empty Novelty archive
if self._config.use_explore:
Expand All @@ -270,7 +276,7 @@ def init(
# Create empty updated genotypes and fitness
last_updated_genotypes = jax.tree_util.tree_map(
lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]),
init_genotypes,
genotypes,
)
last_updated_fitnesses = -jnp.inf * jnp.ones(
shape=self._config.last_updated_size
Expand All @@ -280,7 +286,7 @@ def init(
MEESEmitterState(
initial_optimizer_state=initial_optimizer_state,
optimizer_state=initial_optimizer_state,
offspring=init_genotypes,
offspring=genotypes,
generation_count=0,
novelty_archive=novelty_archive,
last_updated_genotypes=last_updated_genotypes,
Expand Down Expand Up @@ -313,7 +319,7 @@ def emit(
a new jax PRNG key
"""

return emitter_state.offspring, random_key
return emitter_state.offspring, {}, random_key

@partial(
jax.jit,
Expand Down
Loading
Loading