diff --git a/qdax/utils/uncertainty_metrics.py b/qdax/utils/uncertainty_metrics.py new file mode 100644 index 00000000..2dd61c18 --- /dev/null +++ b/qdax/utils/uncertainty_metrics.py @@ -0,0 +1,323 @@ +from functools import partial +from typing import Callable, Tuple + +import jax +import jax.numpy as jnp + +from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire +from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.utils.sampling import ( + dummy_extra_scores_extractor, + median, + multi_sample_scoring_function, + std, +) + + +@partial( + jax.jit, + static_argnames=( + "scoring_fn", + "num_reevals", + "fitness_extractor", + "descriptor_extractor", + "extra_scores_extractor", + "scan_size", + ), +) +def reevaluation_function( + repertoire: MapElitesRepertoire, + random_key: RNGKey, + empty_corrected_repertoire: MapElitesRepertoire, + scoring_fn: Callable[ + [Genotype, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + ], + num_reevals: int, + fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = median, + descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = median, + extra_scores_extractor: Callable[ + [ExtraScores, int], ExtraScores + ] = dummy_extra_scores_extractor, + scan_size: int = 0, +) -> Tuple[MapElitesRepertoire, RNGKey]: + """ + Perform reevaluation of a repertoire and construct a corrected repertoire from it. + + Args: + repertoire: repertoire to reevaluate. + empty_corrected_repertoire: repertoire to be filled with reevaluated solutions, + allow to use a different type of repertoire than the one from the algorithm. + random_key: JAX random key. + scoring_fn: scoring function used for evaluation. + num_reevals: number of samples to generate for each individual. + fitness_extractor: function to extract the final fitness from + multiple samples of the same solution (default: median). + descriptor_extractor: function to extract the final descriptor from + multiple samples of the same solution (default: median). + extra_scores_extractor: function to extract the extra_scores from + multiple samples of the same solution (default: no effect). + scan_size: allow to split the reevaluations in multiple batch to reduce + the memory load of the reevaluation. + Returns: + The corrected repertoire and a random key. + """ + + # If no reevaluations, return copies of the original container + if num_reevals == 0: + return repertoire, random_key + + # Perform reevaluation + ( + all_fitnesses, + all_descriptors, + all_extra_scores, + random_key, + ) = _perform_reevaluation( + policies_params=repertoire.genotypes, + random_key=random_key, + scoring_fn=scoring_fn, + num_reevals=num_reevals, + scan_size=scan_size, + ) + + # Extract the final scores + extra_scores = extra_scores_extractor(all_extra_scores, num_reevals) + fitnesses = fitness_extractor(all_fitnesses) + descriptors = descriptor_extractor(all_descriptors) + + # Set -inf fitness for all unexisting indivs + fitnesses = jnp.where(repertoire.fitnesses == -jnp.inf, -jnp.inf, fitnesses) + + # Fill-in the corrected repertoire + corrected_repertoire = empty_corrected_repertoire.add( + batch_of_genotypes=repertoire.genotypes, + batch_of_descriptors=descriptors, + batch_of_fitnesses=fitnesses, + batch_of_extra_scores=extra_scores, + ) + + return corrected_repertoire, random_key + + +@partial( + jax.jit, + static_argnames=( + "scoring_fn", + "num_reevals", + "fitness_extractor", + "fitness_reproducibility_extractor", + "descriptor_extractor", + "descriptor_reproducibility_extractor", + "extra_scores_extractor", + "scan_size", + ), +) +def reevaluation_reproducibility_function( + repertoire: MapElitesRepertoire, + random_key: RNGKey, + empty_corrected_repertoire: MapElitesRepertoire, + scoring_fn: Callable[ + [Genotype, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + ], + num_reevals: int, + fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = median, + fitness_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std, + descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = median, + descriptor_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std, + extra_scores_extractor: Callable[ + [ExtraScores, int], ExtraScores + ] = dummy_extra_scores_extractor, + scan_size: int = 0, +) -> Tuple[MapElitesRepertoire, MapElitesRepertoire, MapElitesRepertoire, RNGKey]: + """ + Perform reevaluation of a repertoire and construct a corrected repertoire and a + reproducibility repertoire from it. + + Args: + repertoire: repertoire to reevaluate. + empty_corrected_repertoire: repertoire to be filled with reevaluated solutions, + allow to use a different type of repertoire than the one from the algorithm. + random_key: JAX random key. + scoring_fn: scoring function used for evaluation. + num_reevals: number of samples to generate for each individual. + fitness_extractor: function to extract the final fitness from + multiple samples of the same solution (default: median). + fitness_reproducibility_extractor: function to extract the fitness + reproducibility from multiple samples of the same solution (default: std). + descriptor_extractor: function to extract the final descriptor from + multiple samples of the same solution (default: median). + descriptor_reproducibility_extractor: function to extract the descriptor + reproducibility from multiple samples of the same solution (default: std). + extra_scores_extractor: function to extract the extra_scores from + multiple samples of the same solution (default: no effect). + scan_size: allow to split the reevaluations in multiple batch to reduce + the memory load of the reevaluation. + Returns: + The corrected repertoire. + A repertoire storing reproducibility in fitness. + A repertoire storing reproducibility in descriptor. + A random key. + """ + + # If no reevaluations, return copies of the original container + if num_reevals == 0: + return ( + repertoire, + repertoire, + repertoire, + random_key, + ) + + # Perform reevaluation + ( + all_fitnesses, + all_descriptors, + all_extra_scores, + random_key, + ) = _perform_reevaluation( + policies_params=repertoire.genotypes, + random_key=random_key, + scoring_fn=scoring_fn, + num_reevals=num_reevals, + scan_size=scan_size, + ) + + # Extract the final scores + extra_scores = extra_scores_extractor(all_extra_scores, num_reevals) + fitnesses = fitness_extractor(all_fitnesses) + fitnesses_reproducibility = fitness_reproducibility_extractor(all_fitnesses) + descriptors = descriptor_extractor(all_descriptors) + descriptors_reproducibility = descriptor_reproducibility_extractor(all_descriptors) + + # WARNING: in the case of descriptors_reproducibility, take average over dimensions + descriptors_reproducibility = jnp.average(descriptors_reproducibility, axis=-1) + + # Set -inf fitness for all unexisting indivs + fitnesses = jnp.where(repertoire.fitnesses == -jnp.inf, -jnp.inf, fitnesses) + fitnesses_reproducibility = jnp.where( + repertoire.fitnesses == -jnp.inf, -jnp.inf, fitnesses_reproducibility + ) + descriptors_reproducibility = jnp.where( + repertoire.fitnesses == -jnp.inf, -jnp.inf, descriptors_reproducibility + ) + + # Fill-in corrected repertoire + corrected_repertoire = empty_corrected_repertoire.add( + batch_of_genotypes=repertoire.genotypes, + batch_of_descriptors=descriptors, + batch_of_fitnesses=fitnesses, + batch_of_extra_scores=extra_scores, + ) + + # Fill-in fit_reproducibility repertoire + fit_reproducibility_repertoire = empty_corrected_repertoire.add( + batch_of_genotypes=repertoire.genotypes, + batch_of_descriptors=repertoire.descriptors, + batch_of_fitnesses=fitnesses_reproducibility, + batch_of_extra_scores=extra_scores, + ) + + # Fill-in desc_reproducibility repertoire + desc_reproducibility_repertoire = empty_corrected_repertoire.add( + batch_of_genotypes=repertoire.genotypes, + batch_of_descriptors=repertoire.descriptors, + batch_of_fitnesses=descriptors_reproducibility, + batch_of_extra_scores=extra_scores, + ) + + return ( + corrected_repertoire, + fit_reproducibility_repertoire, + desc_reproducibility_repertoire, + random_key, + ) + + +@partial( + jax.jit, + static_argnames=( + "scoring_fn", + "num_reevals", + "scan_size", + ), +) +def _perform_reevaluation( + policies_params: Genotype, + random_key: RNGKey, + scoring_fn: Callable[ + [Genotype, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + ], + num_reevals: int, + scan_size: int = 0, +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """ + Sub-function used to perform reevaluation of a repertoire in uncertain applications. + + Args: + policies_params: genotypes to reevaluate. + random_key: JAX random key. + scoring_fn: scoring function used for evaluation. + num_reevals: number of samples to generate for each individual. + scan_size: allow to split the reevaluations in multiple batch to reduce + the memory load of the reevaluation. + Returns: + The fitnesses, descriptors and extra score from the reevaluation, + and a randon key. + """ + + # If no need for scan, call the sampling function + if scan_size == 0: + ( + all_fitnesses, + all_descriptors, + all_extra_scores, + random_key, + ) = multi_sample_scoring_function( + policies_params=policies_params, + random_key=random_key, + scoring_fn=scoring_fn, + num_samples=num_reevals, + ) + + # If need for scan, call the sampling function multiple times + else: + + # Ensure that num_reevals is a multiple of scan_size + assert ( + num_reevals % scan_size == 0 + ), "num_reevals should be a multiple of scan_size to be able to scan." + num_loops = num_reevals // scan_size + + def _sampling_scan( + random_key: RNGKey, + unused: Tuple[()], + ) -> Tuple[Tuple[RNGKey], Tuple[Fitness, Descriptor, ExtraScores]]: + ( + all_fitnesses, + all_descriptors, + all_extra_scores, + random_key, + ) = multi_sample_scoring_function( + policies_params=policies_params, + random_key=random_key, + scoring_fn=scoring_fn, + num_samples=scan_size, + ) + return (random_key), ( + all_fitnesses, + all_descriptors, + all_extra_scores, + ) + + (random_key), ( + all_fitnesses, + all_descriptors, + all_extra_scores, + ) = jax.lax.scan(_sampling_scan, (random_key), (), length=num_loops) + all_fitnesses = jnp.hstack(all_fitnesses) + all_descriptors = jnp.hstack(all_descriptors) + + return all_fitnesses, all_descriptors, all_extra_scores, random_key diff --git a/tests/utils_test/uncertainty_metrics_test.py b/tests/utils_test/uncertainty_metrics_test.py new file mode 100644 index 00000000..d49e2527 --- /dev/null +++ b/tests/utils_test/uncertainty_metrics_test.py @@ -0,0 +1,190 @@ +import functools + +import jax +import jax.numpy as jnp +import pytest + +from qdax.core.containers.mapelites_repertoire import ( + MapElitesRepertoire, + compute_cvt_centroids, +) +from qdax.tasks.arm import arm_scoring_function, noisy_arm_scoring_function +from qdax.utils.uncertainty_metrics import ( + reevaluation_function, + reevaluation_reproducibility_function, +) + + +def test_uncertainty_metrics() -> None: + seed = 42 + num_reevals = 512 + scan_size = 128 + batch_size = 512 + num_init_cvt_samples = 50000 + num_centroids = 1024 + genotype_dim = 8 + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + # First, init a deterministic environment + init_policies = jax.random.uniform( + random_key, shape=(batch_size, genotype_dim), minval=0, maxval=1 + ) + fitnesses, descriptors, extra_scores, random_key = arm_scoring_function( + init_policies, random_key + ) + + # Initialise a container + centroids, random_key = compute_cvt_centroids( + num_descriptors=2, + num_init_cvt_samples=num_init_cvt_samples, + num_centroids=num_centroids, + minval=jnp.array([0.0, 0.0]), + maxval=jnp.array([1.0, 1.0]), + random_key=random_key, + ) + repertoire = MapElitesRepertoire.init( + genotypes=init_policies, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + extra_scores=extra_scores, + ) + + # Initialise an empty container for corrected repertoire + fitnesses = jnp.full_like(fitnesses, -jnp.inf) + empty_corrected_repertoire = MapElitesRepertoire.init( + genotypes=init_policies, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + extra_scores=extra_scores, + ) + + # Test that reevaluation_function accurately predicts no change + corrected_repertoire, random_key = reevaluation_function( + repertoire=repertoire, + empty_corrected_repertoire=empty_corrected_repertoire, + scoring_fn=arm_scoring_function, + num_reevals=num_reevals, + random_key=random_key, + ) + pytest.assume( + jnp.allclose( + corrected_repertoire.fitnesses, repertoire.fitnesses, rtol=1e-05, atol=1e-05 + ) + ) + + # Test that scanned reevaluation_function accurately predicts no change + corrected_repertoire, random_key = reevaluation_function( + repertoire=repertoire, + empty_corrected_repertoire=empty_corrected_repertoire, + scoring_fn=arm_scoring_function, + num_reevals=num_reevals, + random_key=random_key, + scan_size=scan_size, + ) + pytest.assume( + jnp.allclose( + corrected_repertoire.fitnesses, repertoire.fitnesses, rtol=1e-05, atol=1e-05 + ) + ) + + # Test that reevaluation_reproducibility_function accurately predicts no change + ( + corrected_repertoire, + fit_reproducibility_repertoire, + desc_reproducibility_repertoire, + random_key, + ) = reevaluation_reproducibility_function( + repertoire=repertoire, + empty_corrected_repertoire=empty_corrected_repertoire, + scoring_fn=arm_scoring_function, + num_reevals=num_reevals, + random_key=random_key, + ) + pytest.assume( + jnp.allclose( + corrected_repertoire.fitnesses, repertoire.fitnesses, rtol=1e-05, atol=1e-05 + ) + ) + zero_fitnesses = jnp.where( + repertoire.fitnesses > -jnp.inf, + 0.0, + -jnp.inf, + ) + pytest.assume( + jnp.allclose( + fit_reproducibility_repertoire.fitnesses, + zero_fitnesses, + rtol=1e-05, + atol=1e-05, + ) + ) + pytest.assume( + jnp.allclose( + desc_reproducibility_repertoire.fitnesses, + zero_fitnesses, + rtol=1e-05, + atol=1e-05, + ) + ) + + # Second, init a stochastic environment + init_policies = jax.random.uniform( + random_key, shape=(batch_size, genotype_dim), minval=0, maxval=1 + ) + noisy_scoring_function = functools.partial( + noisy_arm_scoring_function, + fit_variance=0.01, + desc_variance=0.01, + params_variance=0.0, + ) + fitnesses, descriptors, extra_scores, random_key = noisy_scoring_function( + init_policies, random_key + ) + + # Initialise a container + centroids, random_key = compute_cvt_centroids( + num_descriptors=2, + num_init_cvt_samples=num_init_cvt_samples, + num_centroids=num_centroids, + minval=jnp.array([0.0, 0.0]), + maxval=jnp.array([1.0, 1.0]), + random_key=random_key, + ) + repertoire = MapElitesRepertoire.init( + genotypes=init_policies, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + extra_scores=extra_scores, + ) + + # Initialise an empty container for corrected repertoire + fitnesses = jnp.full_like(fitnesses, -jnp.inf) + empty_corrected_repertoire = MapElitesRepertoire.init( + genotypes=init_policies, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + extra_scores=extra_scores, + ) + + # Test that reevaluation_function runs and keeps at least one solution + ( + corrected_repertoire, + fit_reproducibility_repertoire, + desc_reproducibility_repertoire, + random_key, + ) = reevaluation_reproducibility_function( + repertoire=repertoire, + empty_corrected_repertoire=empty_corrected_repertoire, + scoring_fn=noisy_scoring_function, + num_reevals=num_reevals, + random_key=random_key, + ) + pytest.assume(jnp.any(corrected_repertoire.fitnesses > -jnp.inf)) + pytest.assume(jnp.any(fit_reproducibility_repertoire.fitnesses > -jnp.inf)) + pytest.assume(jnp.any(desc_reproducibility_repertoire.fitnesses > -jnp.inf))