Add test for DCRL-ME
maxencefaldor committed Sep 5, 2024
1 parent 82d87c2 commit f1541e7
Showing 1 changed file with 234 additions and 0 deletions.
234 changes: 234 additions & 0 deletions tests/baselines_test/dcrlme_test.py
@@ -0,0 +1,234 @@
from typing import Any, Dict, Tuple
import functools
import pytest

import jax
import jax.numpy as jnp

from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from qdax.tasks.brax_envs import reset_based_scoring_actor_dc_function_brax_envs as scoring_actor_dc_function
from qdax import environments
from qdax.environments import behavior_descriptor_extractor
from qdax.core.map_elites import MAPElites
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.dcrl_me_emitter import DCRLMEConfig, DCRLMEEmitter
from qdax.core.neuroevolution.buffers.buffer import DCRLTransition
from qdax.core.neuroevolution.networks.networks import MLP, MLPDC
from qdax.utils.metrics import default_qd_metrics


def test_dcrlme() -> None:
    seed = 42

    env_name = "ant_omni"
    episode_length = 100
    min_bd = -30.
    max_bd = 30.

    num_iterations = 5
    batch_size = 256

    # Archive
    num_init_cvt_samples = 50000
    num_centroids = 1024
    policy_hidden_layer_sizes = (128, 128)

    # DCRL-ME
    ga_batch_size = 128
    dcg_batch_size = 64
    ai_batch_size = 64
    lengthscale = 0.1
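    # Assumption: ga_batch_size (GA variation), dcg_batch_size (descriptor-conditioned
    # policy gradient) and ai_batch_size (actor injection) are the per-branch offspring
    # counts and are expected to sum to batch_size (128 + 64 + 64 = 256).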

    # GA emitter
    iso_sigma = 0.005
    line_sigma = 0.05

    # DCRL emitter
    critic_hidden_layer_size = (256, 256)
    num_critic_training_steps = 3000
    num_pg_training_steps = 150
    pg_batch_size = 100
    replay_buffer_size = 1_000_000
    discount = 0.99
    reward_scaling = 1.0
    critic_learning_rate = 3e-4
    actor_learning_rate = 3e-4
    policy_learning_rate = 5e-3
    noise_clip = 0.5
    policy_noise = 0.2
    soft_tau_update = 0.005
    policy_delay = 2

    # Init a random key
    random_key = jax.random.PRNGKey(seed)

    # Init environment
    env = environments.create(env_name, episode_length=episode_length)
    reset_fn = jax.jit(env.reset)

    # Compute the centroids
    centroids, random_key = compute_cvt_centroids(
        num_descriptors=env.behavior_descriptor_length,
        num_init_cvt_samples=num_init_cvt_samples,
        num_centroids=num_centroids,
        minval=min_bd,
        maxval=max_bd,
        random_key=random_key,
    )
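    # The CVT centroids tessellate the descriptor space (for ant_omni, the final
    # x-y position of the ant, bounded by min_bd/max_bd) into num_centroids cells.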

    # Init policy network
    policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
    policy_network = MLP(
        layer_sizes=policy_layer_sizes,
        kernel_init=jax.nn.initializers.lecun_uniform(),
        final_activation=jnp.tanh,
    )
    actor_dc_network = MLPDC(
        layer_sizes=policy_layer_sizes,
        kernel_init=jax.nn.initializers.lecun_uniform(),
        final_activation=jnp.tanh,
    )
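    # MLPDC is the descriptor-conditioned actor: unlike MLP, its apply method takes
    # a target descriptor in addition to the observation.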


    # Init population of controllers
    random_key, subkey = jax.random.split(random_key)
    keys = jax.random.split(subkey, num=batch_size)
    fake_batch_obs = jnp.zeros(shape=(batch_size, env.observation_size))
    init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

    # Define the function to play a step with the policy in the environment
    def play_step_fn(env_state, policy_params, random_key):
        actions = policy_network.apply(policy_params, env_state.obs)
        state_desc = env_state.info["state_descriptor"]
        next_state = env.step(env_state, actions)

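        # Episode-level descriptors are unknown at the step level, so desc and
        # desc_prime are filled with NaN placeholders in the transition below.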
        transition = DCRLTransition(
            obs=env_state.obs,
            next_obs=next_state.obs,
            rewards=next_state.reward,
            dones=next_state.done,
            truncations=next_state.info["truncation"],
            actions=actions,
            state_desc=state_desc,
            next_state_desc=next_state.info["state_descriptor"],
            desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
            desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
        )

        return next_state, policy_params, random_key, transition

    # Prepare the scoring function
    bd_extraction_fn = behavior_descriptor_extractor[env_name]
    scoring_fn = functools.partial(
        scoring_function,
        episode_length=episode_length,
        play_reset_fn=reset_fn,
        play_step_fn=play_step_fn,
        behavior_descriptor_extractor=bd_extraction_fn,
    )

    def play_step_actor_dc_fn(env_state, actor_dc_params, desc, random_key):
        desc_prime_normalized = dcg_emitter.emitters[0]._normalize_desc(desc)
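        # Note: dcg_emitter is defined further below; this closure is only called
        # after the emitter exists, so the forward reference resolves at call time.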
        actions = actor_dc_network.apply(actor_dc_params, env_state.obs, desc_prime_normalized)
        state_desc = env_state.info["state_descriptor"]
        next_state = env.step(env_state, actions)

        transition = DCRLTransition(
            obs=env_state.obs,
            next_obs=next_state.obs,
            rewards=next_state.reward,
            dones=next_state.done,
            truncations=next_state.info["truncation"],
            actions=actions,
            state_desc=state_desc,
            next_state_desc=next_state.info["state_descriptor"],
            desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
            desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
        )

        return next_state, actor_dc_params, desc, random_key, transition

    # Prepare the descriptor-conditioned actor scoring function
    scoring_actor_dc_fn = jax.jit(functools.partial(
        scoring_actor_dc_function,
        episode_length=episode_length,
        play_reset_fn=reset_fn,
        play_step_actor_dc_fn=play_step_actor_dc_fn,
        behavior_descriptor_extractor=bd_extraction_fn,
    ))

    # Get the minimum reward value to make sure the QD score is positive
    reward_offset = 0

    # Define a metrics function
    metrics_function = functools.partial(
        default_qd_metrics,
        qd_offset=reward_offset * episode_length,
    )

    # Define the DCRL-ME emitter config
    dcg_emitter_config = DCRLMEConfig(
        ga_batch_size=ga_batch_size,
        dcg_batch_size=dcg_batch_size,
        ai_batch_size=ai_batch_size,
        lengthscale=lengthscale,
        critic_hidden_layer_size=critic_hidden_layer_size,
        num_critic_training_steps=num_critic_training_steps,
        num_pg_training_steps=num_pg_training_steps,
        batch_size=batch_size,
        replay_buffer_size=replay_buffer_size,
        discount=discount,
        reward_scaling=reward_scaling,
        critic_learning_rate=critic_learning_rate,
        actor_learning_rate=actor_learning_rate,
        policy_learning_rate=policy_learning_rate,
        noise_clip=noise_clip,
        policy_noise=policy_noise,
        soft_tau_update=soft_tau_update,
        policy_delay=policy_delay,
    )
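    # noise_clip, policy_noise, soft_tau_update and policy_delay follow the usual
    # TD3 conventions (target policy smoothing, soft target updates, delayed actor
    # updates), presumably used by the critic/actor training inside the emitter.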

    # Get the emitter
    variation_fn = functools.partial(
        isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma
    )

    dcg_emitter = DCRLMEEmitter(
        config=dcg_emitter_config,
        policy_network=policy_network,
        actor_network=actor_dc_network,
        env=env,
        variation_fn=variation_fn,
    )

    # Instantiate MAP-Elites
    map_elites = MAPElites(
        scoring_function=scoring_fn,
        emitter=dcg_emitter,
        metrics_function=metrics_function,
    )

    # Compute the initial repertoire
    repertoire, emitter_state, random_key = map_elites.init(init_params, centroids, random_key)

    @jax.jit
    def update_scan_fn(carry: Any, unused: Any) -> Any:
        # Iterate over the grid
        repertoire, emitter_state, metrics, random_key = map_elites.update(*carry)

        return (repertoire, emitter_state, random_key), metrics

    # Run the algorithm
    (
        repertoire,
        emitter_state,
        random_key,
    ), metrics = jax.lax.scan(
        update_scan_fn,
        (repertoire, emitter_state, random_key),
        (),
        length=num_iterations,
    )
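    # metrics is a dict of arrays with leading dimension num_iterations, one value
    # per MAP-Elites update (e.g. qd_score, coverage, max_fitness).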

    pytest.assume(repertoire is not None)
