1 parent 82d87c2 · commit f1541e7 · 1 changed file with 234 additions and 0 deletions.
@@ -0,0 +1,234 @@
from typing import Any, Dict, Tuple
import functools
import pytest

import jax
import jax.numpy as jnp

from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from qdax.tasks.brax_envs import reset_based_scoring_actor_dc_function_brax_envs as scoring_actor_dc_function
from qdax import environments
from qdax.environments import behavior_descriptor_extractor
from qdax.core.map_elites import MAPElites
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.dcrl_me_emitter import DCRLMEConfig, DCRLMEEmitter
from qdax.core.neuroevolution.buffers.buffer import DCRLTransition
from qdax.core.neuroevolution.networks.networks import MLP, MLPDC
from qdax.utils.metrics import default_qd_metrics


def test_dcrlme() -> None:
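    """Smoke test for the DCRL-ME emitter.

    Runs a few MAP-Elites iterations with the DCRL-ME emitter on the ant_omni
    Brax environment and checks that a repertoire is produced.
    """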
    seed = 42

    env_name = "ant_omni"
    episode_length = 100
    min_bd = -30.0
    max_bd = 30.0

    num_iterations = 5
    batch_size = 256

    # Archive
    num_init_cvt_samples = 50000
    num_centroids = 1024
    policy_hidden_layer_sizes = (128, 128)

    # DCRL-ME
    ga_batch_size = 128
    dcg_batch_size = 64
    ai_batch_size = 64
    lengthscale = 0.1

    # GA emitter
    iso_sigma = 0.005
    line_sigma = 0.05

    # DCRL emitter
    critic_hidden_layer_size = (256, 256)
    num_critic_training_steps = 3000
    num_pg_training_steps = 150
    pg_batch_size = 100
    replay_buffer_size = 1_000_000
    discount = 0.99
    reward_scaling = 1.0
    critic_learning_rate = 3e-4
    actor_learning_rate = 3e-4
    policy_learning_rate = 5e-3
    noise_clip = 0.5
    policy_noise = 0.2
    soft_tau_update = 0.005
    policy_delay = 2

    # Init a random key
    random_key = jax.random.PRNGKey(seed)

    # Init environment
    env = environments.create(env_name, episode_length=episode_length)
    reset_fn = jax.jit(env.reset)

    # Compute the centroids
    centroids, random_key = compute_cvt_centroids(
        num_descriptors=env.behavior_descriptor_length,
        num_init_cvt_samples=num_init_cvt_samples,
        num_centroids=num_centroids,
        minval=min_bd,
        maxval=max_bd,
        random_key=random_key,
    )

    # Init policy network
    policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
    policy_network = MLP(
        layer_sizes=policy_layer_sizes,
        kernel_init=jax.nn.initializers.lecun_uniform(),
        final_activation=jnp.tanh,
    )
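    # Descriptor-conditioned actor: same layer sizes as the policy, but MLPDC also
    # takes the (normalized) goal descriptor as input (see play_step_actor_dc_fn).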
    actor_dc_network = MLPDC(
        layer_sizes=policy_layer_sizes,
        kernel_init=jax.nn.initializers.lecun_uniform(),
        final_activation=jnp.tanh,
    )

    # Init population of controllers
    random_key, subkey = jax.random.split(random_key)
    keys = jax.random.split(subkey, num=batch_size)
    fake_batch_obs = jnp.zeros(shape=(batch_size, env.observation_size))
    init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

    # Define the function to play a step with the policy in the environment
    def play_step_fn(env_state, policy_params, random_key):
        actions = policy_network.apply(policy_params, env_state.obs)
        state_desc = env_state.info["state_descriptor"]
        next_state = env.step(env_state, actions)

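        # desc and desc_prime are per-transition NaN placeholders; the episode-level
        # behavior descriptor is extracted separately by bd_extraction_fn.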
        transition = DCRLTransition(
            obs=env_state.obs,
            next_obs=next_state.obs,
            rewards=next_state.reward,
            dones=next_state.done,
            truncations=next_state.info["truncation"],
            actions=actions,
            state_desc=state_desc,
            next_state_desc=next_state.info["state_descriptor"],
            desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
            desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
        )

        return next_state, policy_params, random_key, transition

    # Prepare the scoring function
    bd_extraction_fn = behavior_descriptor_extractor[env_name]
    scoring_fn = functools.partial(
        scoring_function,
        episode_length=episode_length,
        play_reset_fn=reset_fn,
        play_step_fn=play_step_fn,
        behavior_descriptor_extractor=bd_extraction_fn,
    )

    def play_step_actor_dc_fn(env_state, actor_dc_params, desc, random_key):
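        # Normalize the target descriptor with the DCRL emitter's helper before
        # conditioning the actor on it; dcg_emitter is defined further below and is
        # captured here by closure.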
        desc_prime_normalized = dcg_emitter.emitters[0]._normalize_desc(desc)
        actions = actor_dc_network.apply(actor_dc_params, env_state.obs, desc_prime_normalized)
        state_desc = env_state.info["state_descriptor"]
        next_state = env.step(env_state, actions)

        transition = DCRLTransition(
            obs=env_state.obs,
            next_obs=next_state.obs,
            rewards=next_state.reward,
            dones=next_state.done,
            truncations=next_state.info["truncation"],
            actions=actions,
            state_desc=state_desc,
            next_state_desc=next_state.info["state_descriptor"],
            desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
            desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
        )

        return next_state, actor_dc_params, desc, random_key, transition

    # Prepare the scoring function for the descriptor-conditioned actor
    scoring_actor_dc_fn = jax.jit(functools.partial(
        scoring_actor_dc_function,
        episode_length=episode_length,
        play_reset_fn=reset_fn,
        play_step_actor_dc_fn=play_step_actor_dc_fn,
        behavior_descriptor_extractor=bd_extraction_fn,
    ))

    # Minimum reward value, used to make sure the QD score is positive
    reward_offset = 0

    # Define a metrics function
    metrics_function = functools.partial(
        default_qd_metrics,
        qd_offset=reward_offset * episode_length,
    )
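    # default_qd_metrics reports the QD score, coverage and max fitness of the
    # repertoire at each iteration.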

    # Define the DCG-emitter config
    dcg_emitter_config = DCRLMEConfig(
        ga_batch_size=ga_batch_size,
        dcg_batch_size=dcg_batch_size,
        ai_batch_size=ai_batch_size,
        lengthscale=lengthscale,
        critic_hidden_layer_size=critic_hidden_layer_size,
        num_critic_training_steps=num_critic_training_steps,
        num_pg_training_steps=num_pg_training_steps,
        batch_size=batch_size,
        replay_buffer_size=replay_buffer_size,
        discount=discount,
        reward_scaling=reward_scaling,
        critic_learning_rate=critic_learning_rate,
        actor_learning_rate=actor_learning_rate,
        policy_learning_rate=policy_learning_rate,
        noise_clip=noise_clip,
        policy_noise=policy_noise,
        soft_tau_update=soft_tau_update,
        policy_delay=policy_delay,
    )
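    # Note: ga_batch_size + dcg_batch_size + ai_batch_size = 128 + 64 + 64 = 256,
    # which matches batch_size used to initialize the population above.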

    # Get the emitter
    variation_fn = functools.partial(
        isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma
    )

    dcg_emitter = DCRLMEEmitter(
        config=dcg_emitter_config,
        policy_network=policy_network,
        actor_network=actor_dc_network,
        env=env,
        variation_fn=variation_fn,
    )

    # Instantiate MAP-Elites
    map_elites = MAPElites(
        scoring_function=scoring_fn,
        emitter=dcg_emitter,
        metrics_function=metrics_function,
    )

    # Compute the initial repertoire and emitter state
    repertoire, emitter_state, random_key = map_elites.init(init_params, centroids, random_key)

    @jax.jit
    def update_scan_fn(carry: Any, unused: Any) -> Any:
        # One MAP-Elites update: emit offspring, evaluate them and update the repertoire
        repertoire, emitter_state, metrics, random_key = map_elites.update(*carry)

        return (repertoire, emitter_state, random_key), metrics

    # Run the algorithm
    (
        repertoire,
        emitter_state,
        random_key,
    ), metrics = jax.lax.scan(
        update_scan_fn,
        (repertoire, emitter_state, random_key),
        (),
        length=num_iterations,
    )
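    # jax.lax.scan runs num_iterations updates; metrics is a pytree of per-iteration
    # metrics stacked along the leading axis.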
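    # pytest.assume (from the pytest-assume plugin) records the check without
    # aborting the test.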
    pytest.assume(repertoire is not None)