From 6fe9efb4a392e8e0b8a58af1bd993d058c3c0a7f Mon Sep 17 00:00:00 2001
From: Manon Flageat
Date: Mon, 19 Aug 2024 16:52:00 +0000
Subject: [PATCH 1/3] feat: add reevaluation function to compute corrected archives in uncertain domains

---
 qdax/utils/uncertainty_metrics.py            | 318 +++++++++++++++++++
 tests/utils_test/uncertainty_metrics_test.py | 241 ++++++++++++++
 2 files changed, 559 insertions(+)
 create mode 100644 qdax/utils/uncertainty_metrics.py
 create mode 100644 tests/utils_test/uncertainty_metrics_test.py

diff --git a/qdax/utils/uncertainty_metrics.py b/qdax/utils/uncertainty_metrics.py
new file mode 100644
index 00000000..1a2f615a
--- /dev/null
+++ b/qdax/utils/uncertainty_metrics.py
@@ -0,0 +1,318 @@
+from functools import partial
+from typing import Callable, Tuple
+
+import jax
+import jax.numpy as jnp
+
+from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire
+from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey
+from qdax.utils.sampling import (
+    dummy_extra_scores_extractor,
+    median,
+    multi_sample_scoring_function,
+    std,
+)
+
+
+@partial(
+    jax.jit,
+    static_argnames=(
+        "scoring_fn",
+        "num_reevals",
+        "fitness_extractor",
+        "descriptor_extractor",
+        "extra_scores_extractor",
+        "scan_size",
+    ),
+)
+def reevaluation_function(
+    repertoire: MapElitesRepertoire,
+    random_key: RNGKey,
+    empty_corrected_repertoire: MapElitesRepertoire,
+    scoring_fn: Callable[
+        [Genotype, RNGKey],
+        Tuple[Fitness, Descriptor, ExtraScores, RNGKey],
+    ],
+    num_reevals: int,
+    fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = median,
+    descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = median,
+    extra_scores_extractor: Callable[
+        [ExtraScores, int], ExtraScores
+    ] = dummy_extra_scores_extractor,
+    scan_size: int = 0,
+) -> Tuple[MapElitesRepertoire, RNGKey]:
+    """
+    Perform reevaluation of a repertoire and construct a corrected repertoire from it.
+
+    Args:
+        repertoire: repertoire to reevaluate.
+        empty_corrected_repertoire: repertoire to be filled with reevaluated solutions;
+            allows using a different type of repertoire than the one used by the algorithm.
+        random_key: JAX random key.
+        scoring_fn: scoring function used for evaluation.
+        num_reevals: number of samples to generate for each individual.
+        fitness_extractor: function to extract the final fitness from
+            multiple samples of the same solution (default: median).
+        descriptor_extractor: function to extract the final descriptor from
+            multiple samples of the same solution (default: median).
+        extra_scores_extractor: function to extract the extra_scores from
+            multiple samples of the same solution (default: no effect).
+        scan_size: allows splitting the reevaluations into multiple batches to
+            reduce the memory load of the reevaluation.
+    Returns:
+        The corrected repertoire and a random key.
+ """ + + # If no reevaluations, return copies of the original container + if num_reevals == 0: + return repertoire, random_key + + # Perform reevaluation + ( + all_fitnesses, + all_descriptors, + all_extra_scores, + random_key, + ) = _perform_reevaluation( + policies_params=repertoire.genotypes, + random_key=random_key, + scoring_fn=scoring_fn, + num_reevals=num_reevals, + scan_size=scan_size, + ) + + # Extract the final scores + extra_scores = extra_scores_extractor(all_extra_scores, num_reevals) + fitnesses = fitness_extractor(all_fitnesses) + descriptors = descriptor_extractor(all_descriptors) + + # Set -inf fitness for all unexisting indivs + fitnesses = jnp.where(repertoire.fitnesses == -jnp.inf, -jnp.inf, fitnesses) + + # Fill-in the corrected repertoire + corrected_repertoire = empty_corrected_repertoire.add( + batch_of_genotypes=repertoire.genotypes, + batch_of_descriptors=descriptors, + batch_of_fitnesses=fitnesses, + batch_of_extra_scores=extra_scores, + ) + + return corrected_repertoire, random_key + + +@partial( + jax.jit, + static_argnames=( + "scoring_fn", + "num_reevals", + "fitness_extractor", + "fitness_reproducibility_extractor", + "descriptor_extractor", + "descriptor_reproducibility_extractor", + "extra_scores_extractor", + "scan_size", + ), +) +def reevaluation_reproducibility_function( + repertoire: MapElitesRepertoire, + random_key: RNGKey, + empty_corrected_repertoire: MapElitesRepertoire, + scoring_fn: Callable[ + [Genotype, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + ], + num_reevals: int, + fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = median, + fitness_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std, + descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = median, + descriptor_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std, + extra_scores_extractor: Callable[ + [ExtraScores, int], ExtraScores + ] = dummy_extra_scores_extractor, + scan_size: int = 0, +) -> Tuple[MapElitesRepertoire, MapElitesRepertoire, MapElitesRepertoire, RNGKey]: + """ + Perform reevaluation of a repertoire and construct a corrected repertoire and a + reproducibility repertoire from it. + + Args: + repertoire: repertoire to reevaluate. + empty_corrected_repertoire: repertoire to be filled with reevaluated solutions, + allow to use a different type of repertoire than the one from the algorithm. + random_key: JAX random key. + scoring_fn: scoring function used for evaluation. + num_reevals: number of samples to generate for each individual. + fitness_extractor: function to extract the final fitness from + multiple samples of the same solution (default: median). + fitness_reproducibility_extractor: function to extract the fitness + reproducibility from multiple samples of the same solution (default: std). + descriptor_extractor: function to extract the final descriptor from + multiple samples of the same solution (default: median). + descriptor_reproducibility_extractor: function to extract the descriptor + reproducibility from multiple samples of the same solution (default: std). + extra_scores_extractor: function to extract the extra_scores from + multiple samples of the same solution (default: no effect). + scan_size: allow to split the reevaluations in multiple batch to reduce + the memory load of the reevaluation. + Returns: + The corrected repertoire. + A repertoire storing reproducibility in fitness. + A repertoire storing reproducibility in descriptor. + A random key. 
+ """ + + # If no reevaluations, return copies of the original container + if num_reevals == 0: + return ( + repertoire, + repertoire, + repertoire, + random_key, + ) + + # Perform reevaluation + ( + all_fitnesses, + all_descriptors, + all_extra_scores, + random_key, + ) = _perform_reevaluation( + policies_params=repertoire.genotypes, + random_key=random_key, + scoring_fn=scoring_fn, + num_reevals=num_reevals, + scan_size=scan_size, + ) + + # Extract the final scores + extra_scores = extra_scores_extractor(all_extra_scores, num_reevals) + fitnesses = fitness_extractor(all_fitnesses) + fitnesses_reproducibility = fitness_reproducibility_extractor(all_fitnesses) + descriptors = descriptor_extractor(all_descriptors) + descriptors_reproducibility = descriptor_reproducibility_extractor(all_descriptors) + + # WARNING: in the case of descriptors_reproducibility, take average over dimensions + descriptors_reproducibility = jnp.average(descriptors_reproducibility, axis=-1) + + # Set -inf fitness for all unexisting indivs + fitnesses = jnp.where(repertoire.fitnesses == -jnp.inf, -jnp.inf, fitnesses) + fitnesses_reproducibility = jnp.where( + repertoire.fitnesses == -jnp.inf, -jnp.inf, fitnesses_reproducibility + ) + descriptors_reproducibility = jnp.where( + repertoire.fitnesses == -jnp.inf, -jnp.inf, descriptors_reproducibility + ) + + # Fill-in corrected repertoire + corrected_repertoire = empty_corrected_repertoire.add( + batch_of_genotypes=repertoire.genotypes, + batch_of_descriptors=descriptors, + batch_of_fitnesses=fitnesses, + batch_of_extra_scores=extra_scores, + ) + + # Fill-in fit_reproducibility repertoire + fit_reproducibility_repertoire = empty_corrected_repertoire.add( + batch_of_genotypes=repertoire.genotypes, + batch_of_descriptors=repertoire.descriptors, + batch_of_fitnesses=fitnesses_reproducibility, + batch_of_extra_scores=extra_scores, + ) + + # Fill-in desc_reproducibility repertoire + desc_reproducibility_repertoire = empty_corrected_repertoire.add( + batch_of_genotypes=repertoire.genotypes, + batch_of_descriptors=repertoire.descriptors, + batch_of_fitnesses=descriptors_reproducibility, + batch_of_extra_scores=extra_scores, + ) + + return ( + corrected_repertoire, + fit_reproducibility_repertoire, + desc_reproducibility_repertoire, + random_key, + ) + + +@partial( + jax.jit, + static_argnames=( + "scoring_fn", + "num_reevals", + "scan_size", + ), +) +def _perform_reevaluation( + policies_params: Genotype, + random_key: RNGKey, + scoring_fn: Callable[ + [Genotype, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + ], + num_reevals: int, + scan_size: int = 0, +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """ + Sub-function used to perform reevaluation of a repertoire in uncertain applications. + + Args: + policies_params: genotypes to reevaluate. + random_key: JAX random key. + scoring_fn: scoring function used for evaluation. + num_reevals: number of samples to generate for each individual. + scan_size: allow to split the reevaluations in multiple batch to reduce + the memory load of the reevaluation. + Returns: + The fitnesses, descriptors and extra score from the reevaluation, + and a randon key. 
+ """ + + # If no need for scan, call the sampling function + if scan_size == 0: + ( + all_fitnesses, + all_descriptors, + all_extra_scores, + random_key, + ) = multi_sample_scoring_function( + policies_params=policies_params, + random_key=random_key, + scoring_fn=scoring_fn, + num_samples=num_reevals, + ) + + # If need for scan, call the sampling function multiple times + else: + num_loops = num_reevals // scan_size + + def _sampling_scan( + random_key: RNGKey, + unused: Tuple[()], + ) -> Tuple[Tuple[RNGKey], Tuple[Fitness, Descriptor, ExtraScores]]: + ( + all_fitnesses, + all_descriptors, + all_extra_scores, + random_key, + ) = multi_sample_scoring_function( + policies_params=policies_params, + random_key=random_key, + scoring_fn=scoring_fn, + num_samples=scan_size, + ) + return (random_key), ( + all_fitnesses, + all_descriptors, + all_extra_scores, + ) + + (random_key), ( + all_fitnesses, + all_descriptors, + all_extra_scores, + ) = jax.lax.scan(_sampling_scan, (random_key), (), length=num_loops) + all_fitnesses = jnp.hstack(all_fitnesses) + all_descriptors = jnp.hstack(all_descriptors) + + return all_fitnesses, all_descriptors, all_extra_scores, random_key diff --git a/tests/utils_test/uncertainty_metrics_test.py b/tests/utils_test/uncertainty_metrics_test.py new file mode 100644 index 00000000..810006e8 --- /dev/null +++ b/tests/utils_test/uncertainty_metrics_test.py @@ -0,0 +1,241 @@ +import functools +from typing import Tuple + +import jax +import jax.numpy as jnp +import pytest + +from qdax import environments +from qdax.core.containers.mapelites_repertoire import ( + MapElitesRepertoire, + compute_cvt_centroids, +) +from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.tasks.arm import arm_scoring_function +from qdax.tasks.brax_envs import scoring_function_brax_envs +from qdax.types import EnvState, Params, RNGKey +from qdax.utils.uncertainty_metrics import ( + reevaluation_function, + reevaluation_reproducibility_function, +) + + +def test_uncertainty_metrics() -> None: + seed = 42 + num_reevals = 512 + batch_size = 512 + num_init_cvt_samples = 50000 + num_centroids = 1024 + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + # First, init a deterministic environment + scoring_fn = arm_scoring_function + + # Init policies + init_policies = jax.random.uniform( + random_key, shape=(batch_size, 8), minval=0, maxval=1 + ) + + # Evaluate in the deterministic environment + fitnesses, descriptors, extra_scores, random_key = scoring_fn( + init_policies, random_key + ) + + # Initialise a container + centroids, random_key = compute_cvt_centroids( + num_descriptors=2, + num_init_cvt_samples=num_init_cvt_samples, + num_centroids=num_centroids, + minval=jnp.array([0.0, 0.0]), + maxval=jnp.array([1.0, 1.0]), + random_key=random_key, + ) + repertoire = MapElitesRepertoire.init( + genotypes=init_policies, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + extra_scores=extra_scores, + ) + + # Initialise an empty container for corrected repertoire + fitnesses = jnp.full_like(fitnesses, -jnp.inf) + empty_corrected_repertoire = MapElitesRepertoire.init( + genotypes=init_policies, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + extra_scores=extra_scores, + ) + + # Test that reevaluation_function accurately predicts no change + corrected_repertoire, random_key = reevaluation_function( + repertoire=repertoire, + 
empty_corrected_repertoire=empty_corrected_repertoire, + scoring_fn=scoring_fn, + num_reevals=num_reevals, + random_key=random_key, + ) + pytest.assume( + jnp.allclose( + corrected_repertoire.fitnesses, repertoire.fitnesses, rtol=1e-05, atol=1e-05 + ) + ) + + # Test that reevaluation_reproducibility_function accurately predicts no change + ( + corrected_repertoire, + fit_reproducibility_repertoire, + desc_reproducibility_repertoire, + random_key, + ) = reevaluation_reproducibility_function( + repertoire=repertoire, + empty_corrected_repertoire=empty_corrected_repertoire, + scoring_fn=scoring_fn, + num_reevals=num_reevals, + random_key=random_key, + ) + pytest.assume( + jnp.allclose( + corrected_repertoire.fitnesses, repertoire.fitnesses, rtol=1e-05, atol=1e-05 + ) + ) + zero_fitnesses = jnp.where( + repertoire.fitnesses > -jnp.inf, + 0.0, + -jnp.inf, + ) + pytest.assume( + jnp.allclose( + fit_reproducibility_repertoire.fitnesses, + zero_fitnesses, + rtol=1e-05, + atol=1e-05, + ) + ) + pytest.assume( + jnp.allclose( + desc_reproducibility_repertoire.fitnesses, + zero_fitnesses, + rtol=1e-05, + atol=1e-05, + ) + ) + + # Second, init a Brax environment + env_name = "walker2d_uni" + episode_length = 100 + policy_hidden_layer_sizes = (64, 64) + env = environments.create(env_name, episode_length=episode_length) + + # Init policy network + policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) + policy_network = MLP( + layer_sizes=policy_layer_sizes, + kernel_init=jax.nn.initializers.lecun_uniform(), + final_activation=jnp.tanh, + ) + + # Init population of controllers + random_key, subkey = jax.random.split(random_key) + keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0) + fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) + init_policies = jax.vmap(policy_network.init)(keys, fake_batch) + + # Define the fonction to play a step with the policy in the environment + def play_step_fn( + env_state: EnvState, + policy_params: Params, + random_key: RNGKey, + ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: + + actions = policy_network.apply(policy_params, env_state.obs) + + state_desc = env_state.info["state_descriptor"] + next_state = env.step(env_state, actions) + + transition = QDTransition( + obs=env_state.obs, + next_obs=next_state.obs, + rewards=next_state.reward, + dones=next_state.done, + actions=actions, + truncations=next_state.info["truncation"], + state_desc=state_desc, + next_state_desc=next_state.info["state_descriptor"], + ) + + return next_state, policy_params, random_key, transition + + # Create the initial environment states for samples and final indivs + reset_fn = jax.jit(jax.vmap(env.reset)) + random_key, subkey = jax.random.split(random_key) + keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0) + init_states = reset_fn(keys) + + # Create the scoring function + bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] + scoring_fn = functools.partial( + scoring_function_brax_envs, + init_states=init_states, + episode_length=episode_length, + play_step_fn=play_step_fn, + behavior_descriptor_extractor=bd_extraction_fn, + ) + + # Evaluate in the Brax environment + fitnesses, descriptors, extra_scores, random_key = scoring_fn( + init_policies, random_key + ) + + # Initialise a container + min_bd, max_bd = env.behavior_descriptor_limits + centroids, random_key = compute_cvt_centroids( + num_descriptors=env.behavior_descriptor_length, + num_init_cvt_samples=num_init_cvt_samples, + 
num_centroids=num_centroids, + minval=min_bd, + maxval=max_bd, + random_key=random_key, + ) + repertoire = MapElitesRepertoire.init( + genotypes=init_policies, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + extra_scores=extra_scores, + ) + + # Initialise an empty container for corrected repertoire + fitnesses = jnp.full_like(fitnesses, -jnp.inf) + empty_corrected_repertoire = MapElitesRepertoire.init( + genotypes=init_policies, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + extra_scores=extra_scores, + ) + + # Test that reevaluation_function runs and keeps at least one solution + keys = jnp.repeat( + jnp.expand_dims(subkey, axis=0), repeats=num_centroids, axis=0 + ) + init_states = reset_fn(keys) + reeval_scoring_fn = functools.partial( + scoring_function_brax_envs, + init_states=init_states, + episode_length=episode_length, + play_step_fn=play_step_fn, + behavior_descriptor_extractor=bd_extraction_fn, + ) + corrected_repertoire, random_key = reevaluation_function( + repertoire=repertoire, + empty_corrected_repertoire=empty_corrected_repertoire, + scoring_fn=reeval_scoring_fn, + num_reevals=num_reevals, + random_key=random_key, + ) + pytest.assume(jnp.any(fit_reproducibility_repertoire.fitnesses > -jnp.inf)) From d8ee81971c899960e3257d112e6db866dc8f83d4 Mon Sep 17 00:00:00 2001 From: Manon Flageat Date: Wed, 21 Aug 2024 09:44:54 +0000 Subject: [PATCH 2/3] fix: some typing and names redefinition issues --- tests/utils_test/uncertainty_metrics_test.py | 24 +++++++++----------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/tests/utils_test/uncertainty_metrics_test.py b/tests/utils_test/uncertainty_metrics_test.py index 810006e8..5d1abbea 100644 --- a/tests/utils_test/uncertainty_metrics_test.py +++ b/tests/utils_test/uncertainty_metrics_test.py @@ -1,5 +1,5 @@ import functools -from typing import Tuple +from typing import Callable, Tuple import jax import jax.numpy as jnp @@ -32,15 +32,15 @@ def test_uncertainty_metrics() -> None: random_key = jax.random.PRNGKey(seed) # First, init a deterministic environment - scoring_fn = arm_scoring_function + genotype_dim = 8 # Init policies init_policies = jax.random.uniform( - random_key, shape=(batch_size, 8), minval=0, maxval=1 + random_key, shape=(batch_size, genotype_dim), minval=0, maxval=1 ) # Evaluate in the deterministic environment - fitnesses, descriptors, extra_scores, random_key = scoring_fn( + fitnesses, descriptors, extra_scores, random_key = arm_scoring_function( init_policies, random_key ) @@ -75,7 +75,7 @@ def test_uncertainty_metrics() -> None: corrected_repertoire, random_key = reevaluation_function( repertoire=repertoire, empty_corrected_repertoire=empty_corrected_repertoire, - scoring_fn=scoring_fn, + scoring_fn=arm_scoring_function, num_reevals=num_reevals, random_key=random_key, ) @@ -94,7 +94,7 @@ def test_uncertainty_metrics() -> None: ) = reevaluation_reproducibility_function( repertoire=repertoire, empty_corrected_repertoire=empty_corrected_repertoire, - scoring_fn=scoring_fn, + scoring_fn=arm_scoring_function, num_reevals=num_reevals, random_key=random_key, ) @@ -178,7 +178,7 @@ def play_step_fn( # Create the scoring function bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] - scoring_fn = functools.partial( + brax_scoring_fn: Callable = functools.partial( scoring_function_brax_envs, init_states=init_states, episode_length=episode_length, @@ -187,7 +187,7 @@ def play_step_fn( ) # Evaluate in the Brax environment - fitnesses, 
descriptors, extra_scores, random_key = scoring_fn( + fitnesses, descriptors, extra_scores, random_key = brax_scoring_fn( init_policies, random_key ) @@ -220,11 +220,9 @@ def play_step_fn( ) # Test that reevaluation_function runs and keeps at least one solution - keys = jnp.repeat( - jnp.expand_dims(subkey, axis=0), repeats=num_centroids, axis=0 - ) + keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=num_centroids, axis=0) init_states = reset_fn(keys) - reeval_scoring_fn = functools.partial( + reeval_brax_scoring_fn: Callable = functools.partial( scoring_function_brax_envs, init_states=init_states, episode_length=episode_length, @@ -234,7 +232,7 @@ def play_step_fn( corrected_repertoire, random_key = reevaluation_function( repertoire=repertoire, empty_corrected_repertoire=empty_corrected_repertoire, - scoring_fn=reeval_scoring_fn, + scoring_fn=reeval_brax_scoring_fn, num_reevals=num_reevals, random_key=random_key, ) From 2e14e2e26bb3a3c202a18c0f1032051205426229 Mon Sep 17 00:00:00 2001 From: Manon Flageat Date: Mon, 2 Sep 2024 13:24:18 +0000 Subject: [PATCH 3/3] tests: include scan_reeval and remove brax environment --- qdax/utils/uncertainty_metrics.py | 5 + tests/utils_test/uncertainty_metrics_test.py | 127 ++++++------------- 2 files changed, 44 insertions(+), 88 deletions(-) diff --git a/qdax/utils/uncertainty_metrics.py b/qdax/utils/uncertainty_metrics.py index 1a2f615a..2dd61c18 100644 --- a/qdax/utils/uncertainty_metrics.py +++ b/qdax/utils/uncertainty_metrics.py @@ -284,6 +284,11 @@ def _perform_reevaluation( # If need for scan, call the sampling function multiple times else: + + # Ensure that num_reevals is a multiple of scan_size + assert ( + num_reevals % scan_size == 0 + ), "num_reevals should be a multiple of scan_size to be able to scan." 
num_loops = num_reevals // scan_size def _sampling_scan( diff --git a/tests/utils_test/uncertainty_metrics_test.py b/tests/utils_test/uncertainty_metrics_test.py index 5d1abbea..d49e2527 100644 --- a/tests/utils_test/uncertainty_metrics_test.py +++ b/tests/utils_test/uncertainty_metrics_test.py @@ -1,20 +1,14 @@ import functools -from typing import Callable, Tuple import jax import jax.numpy as jnp import pytest -from qdax import environments from qdax.core.containers.mapelites_repertoire import ( MapElitesRepertoire, compute_cvt_centroids, ) -from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.core.neuroevolution.networks.networks import MLP -from qdax.tasks.arm import arm_scoring_function -from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey +from qdax.tasks.arm import arm_scoring_function, noisy_arm_scoring_function from qdax.utils.uncertainty_metrics import ( reevaluation_function, reevaluation_reproducibility_function, @@ -24,22 +18,19 @@ def test_uncertainty_metrics() -> None: seed = 42 num_reevals = 512 + scan_size = 128 batch_size = 512 num_init_cvt_samples = 50000 num_centroids = 1024 + genotype_dim = 8 # Init a random key random_key = jax.random.PRNGKey(seed) # First, init a deterministic environment - genotype_dim = 8 - - # Init policies init_policies = jax.random.uniform( random_key, shape=(batch_size, genotype_dim), minval=0, maxval=1 ) - - # Evaluate in the deterministic environment fitnesses, descriptors, extra_scores, random_key = arm_scoring_function( init_policies, random_key ) @@ -85,6 +76,21 @@ def test_uncertainty_metrics() -> None: ) ) + # Test that scanned reevaluation_function accurately predicts no change + corrected_repertoire, random_key = reevaluation_function( + repertoire=repertoire, + empty_corrected_repertoire=empty_corrected_repertoire, + scoring_fn=arm_scoring_function, + num_reevals=num_reevals, + random_key=random_key, + scan_size=scan_size, + ) + pytest.assume( + jnp.allclose( + corrected_repertoire.fitnesses, repertoire.fitnesses, rtol=1e-05, atol=1e-05 + ) + ) + # Test that reevaluation_reproducibility_function accurately predicts no change ( corrected_repertoire, @@ -125,80 +131,27 @@ def test_uncertainty_metrics() -> None: ) ) - # Second, init a Brax environment - env_name = "walker2d_uni" - episode_length = 100 - policy_hidden_layer_sizes = (64, 64) - env = environments.create(env_name, episode_length=episode_length) - - # Init policy network - policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) - policy_network = MLP( - layer_sizes=policy_layer_sizes, - kernel_init=jax.nn.initializers.lecun_uniform(), - final_activation=jnp.tanh, + # Second, init a stochastic environment + init_policies = jax.random.uniform( + random_key, shape=(batch_size, genotype_dim), minval=0, maxval=1 ) - - # Init population of controllers - random_key, subkey = jax.random.split(random_key) - keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0) - fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) - init_policies = jax.vmap(policy_network.init)(keys, fake_batch) - - # Define the fonction to play a step with the policy in the environment - def play_step_fn( - env_state: EnvState, - policy_params: Params, - random_key: RNGKey, - ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: - - actions = policy_network.apply(policy_params, env_state.obs) - - state_desc = env_state.info["state_descriptor"] - next_state = env.step(env_state, actions) - - 
transition = QDTransition( - obs=env_state.obs, - next_obs=next_state.obs, - rewards=next_state.reward, - dones=next_state.done, - actions=actions, - truncations=next_state.info["truncation"], - state_desc=state_desc, - next_state_desc=next_state.info["state_descriptor"], - ) - - return next_state, policy_params, random_key, transition - - # Create the initial environment states for samples and final indivs - reset_fn = jax.jit(jax.vmap(env.reset)) - random_key, subkey = jax.random.split(random_key) - keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0) - init_states = reset_fn(keys) - - # Create the scoring function - bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] - brax_scoring_fn: Callable = functools.partial( - scoring_function_brax_envs, - init_states=init_states, - episode_length=episode_length, - play_step_fn=play_step_fn, - behavior_descriptor_extractor=bd_extraction_fn, + noisy_scoring_function = functools.partial( + noisy_arm_scoring_function, + fit_variance=0.01, + desc_variance=0.01, + params_variance=0.0, ) - - # Evaluate in the Brax environment - fitnesses, descriptors, extra_scores, random_key = brax_scoring_fn( + fitnesses, descriptors, extra_scores, random_key = noisy_scoring_function( init_policies, random_key ) # Initialise a container - min_bd, max_bd = env.behavior_descriptor_limits centroids, random_key = compute_cvt_centroids( - num_descriptors=env.behavior_descriptor_length, + num_descriptors=2, num_init_cvt_samples=num_init_cvt_samples, num_centroids=num_centroids, - minval=min_bd, - maxval=max_bd, + minval=jnp.array([0.0, 0.0]), + maxval=jnp.array([1.0, 1.0]), random_key=random_key, ) repertoire = MapElitesRepertoire.init( @@ -220,20 +173,18 @@ def play_step_fn( ) # Test that reevaluation_function runs and keeps at least one solution - keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=num_centroids, axis=0) - init_states = reset_fn(keys) - reeval_brax_scoring_fn: Callable = functools.partial( - scoring_function_brax_envs, - init_states=init_states, - episode_length=episode_length, - play_step_fn=play_step_fn, - behavior_descriptor_extractor=bd_extraction_fn, - ) - corrected_repertoire, random_key = reevaluation_function( + ( + corrected_repertoire, + fit_reproducibility_repertoire, + desc_reproducibility_repertoire, + random_key, + ) = reevaluation_reproducibility_function( repertoire=repertoire, empty_corrected_repertoire=empty_corrected_repertoire, - scoring_fn=reeval_brax_scoring_fn, + scoring_fn=noisy_scoring_function, num_reevals=num_reevals, random_key=random_key, ) + pytest.assume(jnp.any(corrected_repertoire.fitnesses > -jnp.inf)) pytest.assume(jnp.any(fit_reproducibility_repertoire.fitnesses > -jnp.inf)) + pytest.assume(jnp.any(desc_reproducibility_repertoire.fitnesses > -jnp.inf))
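
For reference, a minimal usage sketch of the reevaluation utilities added by this patch series, mirroring the noisy-arm test in PATCH 3/3. The hyperparameter values below (seed, population size, number of centroids, num_reevals, scan_size) are illustrative only, not taken from the patches.

# Sketch: correcting a repertoire built on the noisy arm task.
import functools

import jax
import jax.numpy as jnp

from qdax.core.containers.mapelites_repertoire import (
    MapElitesRepertoire,
    compute_cvt_centroids,
)
from qdax.tasks.arm import noisy_arm_scoring_function
from qdax.utils.uncertainty_metrics import reevaluation_reproducibility_function

random_key = jax.random.PRNGKey(0)

# Evaluate a random batch of genotypes on the noisy arm task.
genotypes = jax.random.uniform(random_key, shape=(256, 8), minval=0.0, maxval=1.0)
scoring_fn = functools.partial(
    noisy_arm_scoring_function,
    fit_variance=0.01,
    desc_variance=0.01,
    params_variance=0.0,
)
fitnesses, descriptors, extra_scores, random_key = scoring_fn(genotypes, random_key)

# Build the (noisy) repertoire and an empty container for the corrected version.
centroids, random_key = compute_cvt_centroids(
    num_descriptors=2,
    num_init_cvt_samples=10000,
    num_centroids=512,
    minval=jnp.array([0.0, 0.0]),
    maxval=jnp.array([1.0, 1.0]),
    random_key=random_key,
)
repertoire = MapElitesRepertoire.init(
    genotypes=genotypes,
    fitnesses=fitnesses,
    descriptors=descriptors,
    centroids=centroids,
    extra_scores=extra_scores,
)
empty_corrected_repertoire = MapElitesRepertoire.init(
    genotypes=genotypes,
    fitnesses=jnp.full_like(fitnesses, -jnp.inf),
    descriptors=descriptors,
    centroids=centroids,
    extra_scores=extra_scores,
)

# Reevaluate every elite 256 times, in chunks of 64 to bound memory usage,
# and build the corrected and reproducibility repertoires.
(
    corrected_repertoire,
    fit_reproducibility_repertoire,
    desc_reproducibility_repertoire,
    random_key,
) = reevaluation_reproducibility_function(
    repertoire=repertoire,
    random_key=random_key,
    empty_corrected_repertoire=empty_corrected_repertoire,
    scoring_fn=scoring_fn,
    num_reevals=256,
    scan_size=64,
)

With a deterministic scoring function such as arm_scoring_function, the corrected fitnesses should match the original ones and both reproducibility repertoires should hold zeros for every occupied cell, which is exactly what the test above checks.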