Commit

Fix typing in dcrlme_test.py
maxencefaldor committed Sep 5, 2024
1 parent f1541e7 commit d7fd89e
Showing 3 changed files with 28 additions and 54 deletions.
8 changes: 2 additions & 6 deletions qdax/core/emitters/dcrl_emitter.py
@@ -301,9 +301,7 @@ def emit(
genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg)

# Actor injection emitter
_, descs_ai, key = repertoire.sample_with_descs(
key, self._config.ai_batch_size
)
_, descs_ai, key = repertoire.sample_with_descs(key, self._config.ai_batch_size)
descs_ai = descs_ai.reshape(
descs_ai.shape[0], self._env.behavior_descriptor_length
)
@@ -354,9 +352,7 @@ def emit_pg(
jax.jit,
static_argnames=("self",),
)
def emit_ai(
self, emitter_state: DCRLEmitterState, descs: Descriptor
) -> Genotype:
def emit_ai(self, emitter_state: DCRLEmitterState, descs: Descriptor) -> Genotype:
"""Emit the offsprings generated through pg mutation.
Args:
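For context on the call being reformatted above: sample_with_descs returns genotypes, their stored descriptors, and a new random key, and emit() then reshapes the descriptors to (batch, behavior_descriptor_length). A minimal, self-contained JAX sketch of that reshape, using stand-in arrays instead of a real repertoire (the ai_batch_size and descriptor length values are assumptions for illustration only):

import jax
import jax.numpy as jnp

# Stand-in for the descriptors returned by repertoire.sample_with_descs
ai_batch_size = 8
behavior_descriptor_length = 2
key = jax.random.PRNGKey(0)
descs_ai = jax.random.uniform(key, (ai_batch_size, behavior_descriptor_length))

# Same reshape as in emit(): force a (batch, descriptor_length) layout
descs_ai = descs_ai.reshape(descs_ai.shape[0], behavior_descriptor_length)
assert descs_ai.shape == (ai_batch_size, behavior_descriptor_length)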
2 changes: 1 addition & 1 deletion qdax/core/emitters/dcrl_me_emitter.py
@@ -3,8 +3,8 @@

import flax.linen as nn

from qdax.core.emitters.multi_emitter import MultiEmitter
from qdax.core.emitters.dcrl_emitter import DCRLConfig, DCRLEmitter
from qdax.core.emitters.multi_emitter import MultiEmitter
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.custom_types import Params, RNGKey
from qdax.environments.base_wrappers import QDEnv
72 changes: 25 additions & 47 deletions tests/baselines_test/dcrlme_test.py
@@ -1,20 +1,20 @@
from typing import Any, Dict, Tuple
import functools
import pytest
from typing import Any, Tuple

import jax
import jax.numpy as jnp
import pytest

from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from qdax.tasks.brax_envs import reset_based_scoring_actor_dc_function_brax_envs as scoring_actor_dc_function
from qdax import environments
from qdax.environments import behavior_descriptor_extractor
from qdax.core.map_elites import MAPElites
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.emitters.dcrl_me_emitter import DCRLMEConfig, DCRLMEEmitter
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.map_elites import MAPElites
from qdax.core.neuroevolution.buffers.buffer import DCRLTransition
from qdax.core.neuroevolution.networks.networks import MLP, MLPDC
from qdax.custom_types import EnvState, Params, RNGKey
from qdax.environments import behavior_descriptor_extractor
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs
from qdax.utils.metrics import default_qd_metrics


@@ -23,8 +23,8 @@ def test_dcrlme() -> None:

env_name = "ant_omni"
episode_length = 100
min_bd = -30.
max_bd = 30.
min_bd = -30.0
max_bd = 30.0

num_iterations = 5
batch_size = 256
@@ -48,7 +48,6 @@ def test_dcrlme() -> None:
critic_hidden_layer_size = (256, 256)
num_critic_training_steps = 3000
num_pg_training_steps = 150
pg_batch_size = 100
replay_buffer_size = 1_000_000
discount = 0.99
reward_scaling = 1.0
@@ -90,15 +89,16 @@ def test_dcrlme() -> None:
final_activation=jnp.tanh,
)


# Init population of controllers
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=batch_size)
fake_batch_obs = jnp.zeros(shape=(batch_size, env.observation_size))
init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

# Define the function to play a step with the policy in the environment
def play_step_fn(env_state, policy_params, random_key):
def play_step_fn(
env_state: EnvState, policy_params: Params, random_key: RNGKey
) -> Tuple[EnvState, Params, RNGKey, DCRLTransition]:
actions = policy_network.apply(policy_params, env_state.obs)
state_desc = env_state.info["state_descriptor"]
next_state = env.step(env_state, actions)
@@ -112,52 +112,28 @@ def play_step_actor_dc_fn(env_state, actor_dc_params, desc, random_key):
actions=actions,
state_desc=state_desc,
next_state_desc=next_state.info["state_descriptor"],
desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
desc=jnp.zeros(
env.behavior_descriptor_length,
)
* jnp.nan,
desc_prime=jnp.zeros(
env.behavior_descriptor_length,
)
* jnp.nan,
)

return next_state, policy_params, random_key, transition
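The desc and desc_prime placeholders in the transition above are NaN vectors of length behavior_descriptor_length; multiplying zeros by jnp.nan is the same as filling with NaN. A tiny self-contained check of that idiom (the length value here is a stand-in, not the env's actual descriptor length):

import jax.numpy as jnp

behavior_descriptor_length = 2  # stand-in value
desc = jnp.zeros((behavior_descriptor_length,)) * jnp.nan
assert desc.shape == (behavior_descriptor_length,)
assert bool(jnp.all(jnp.isnan(desc)))  # same result as jnp.full(..., jnp.nan)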

# Prepare the scoring function
bd_extraction_fn = behavior_descriptor_extractor[env_name]
scoring_fn = functools.partial(
scoring_function,
reset_based_scoring_function_brax_envs,
episode_length=episode_length,
play_reset_fn=reset_fn,
play_step_fn=play_step_fn,
behavior_descriptor_extractor=bd_extraction_fn,
)

def play_step_actor_dc_fn(env_state, actor_dc_params, desc, random_key):
desc_prime_normalized = dcg_emitter.emitters[0]._normalize_desc(desc)
actions = actor_dc_network.apply(actor_dc_params, env_state.obs, desc_prime_normalized)
state_desc = env_state.info["state_descriptor"]
next_state = env.step(env_state, actions)

transition = DCRLTransition(
obs=env_state.obs,
next_obs=next_state.obs,
rewards=next_state.reward,
dones=next_state.done,
truncations=next_state.info["truncation"],
actions=actions,
state_desc=state_desc,
next_state_desc=next_state.info["state_descriptor"],
desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
)

return next_state, actor_dc_params, desc, random_key, transition

# Prepare the scoring function
scoring_actor_dc_fn = jax.jit(functools.partial(
scoring_actor_dc_function,
episode_length=episode_length,
play_reset_fn=reset_fn,
play_step_actor_dc_fn=play_step_actor_dc_fn,
behavior_descriptor_extractor=bd_extraction_fn,
))

# Get minimum reward value to make sure qd_score are positive
reward_offset = 0

@@ -210,7 +186,9 @@ def play_step_actor_dc_fn(env_state, actor_dc_params, desc, random_key):
)

# compute initial repertoire
repertoire, emitter_state, random_key = map_elites.init(init_params, centroids, random_key)
repertoire, emitter_state, random_key = map_elites.init(
init_params, centroids, random_key
)

@jax.jit
def update_scan_fn(carry: Any, unused: Any) -> Any:
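The update_scan_fn whose header appears above is typically driven with jax.lax.scan over num_iterations steps. A generic, runnable sketch of that pattern, with a scalar counter standing in for the test's real carry of repertoire, emitter state, and random key:

import jax
import jax.numpy as jnp

num_iterations = 5  # same value as set earlier in the test

@jax.jit
def update_scan_fn(carry, unused):
    # Stand-in update; the real test would call map_elites.update here.
    counter = carry
    return counter + 1, counter

final_counter, per_step = jax.lax.scan(
    update_scan_fn, jnp.int32(0), None, length=num_iterations
)
assert int(final_counter) == num_iterations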
