From a63df5b1cabcb3dd22ca112bc8b193bc6e6e7b64 Mon Sep 17 00:00:00 2001
From: Manon Flageat <61653012+manon-but-yes@users.noreply.github.com>
Date: Mon, 27 Nov 2023 18:07:56 +0000
Subject: [PATCH] Add multiple variants of sampling extractors (#158)

---
 qdax/utils/sampling.py            | 197 ++++++++++++++++++++++++++++--
 tests/utils_test/sampling_test.py | 175 +++++++++++++++++++++-----
 2 files changed, 334 insertions(+), 38 deletions(-)

diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py
index a25e190f..bf5c1ae4 100644
--- a/qdax/utils/sampling.py
+++ b/qdax/utils/sampling.py
@@ -8,6 +8,91 @@
 from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey
 
 
+@jax.jit
+def average(quantities: jnp.ndarray) -> jnp.ndarray:
+    """Default expectation extractor using average."""
+    return jnp.average(quantities, axis=1)
+
+
+@jax.jit
+def median(quantities: jnp.ndarray) -> jnp.ndarray:
+    """Alternative expectation extractor using median.
+    More robust to outliers than average."""
+    return jnp.median(quantities, axis=1)
+
+
+@jax.jit
+def mode(quantities: jnp.ndarray) -> jnp.ndarray:
+    """Alternative expectation extractor using mode.
+    More robust to outliers than average.
+    WARNING: for multidimensional objects such as descriptors, this
+    performs dimension-wise selection.
+    """
+
+    def _mode(quantity: jnp.ndarray) -> jnp.ndarray:
+
+        # Ensure correct dimensions for both single and multi-dimension
+        quantity = jnp.reshape(quantity, (quantity.shape[0], -1))
+
+        # Dimension-wise voting in case of multi-dimension
+        def _dim_mode(dim_quantity: jnp.ndarray) -> jnp.ndarray:
+            unique_vals, counts = jnp.unique(
+                dim_quantity, return_counts=True, size=dim_quantity.size
+            )
+            return unique_vals[jnp.argmax(counts)]
+
+        # vmap over dimensions
+        return jnp.squeeze(jax.vmap(_dim_mode)(jnp.transpose(quantity)))
+
+    # vmap over individuals
+    return jax.vmap(_mode)(quantities)
+
+
+@jax.jit
+def closest(quantities: jnp.ndarray) -> jnp.ndarray:
+    """Alternative expectation extractor selecting the individual
+    that has the minimum distance to all other individuals. This
+    is an approximation of the geometric median.
+    More robust to outliers than average."""
+
+    def _closest(values: jnp.ndarray) -> jnp.ndarray:
+        def distance(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+            return jnp.sqrt(jnp.sum(jnp.square(x - y)))
+
+        distances = jax.vmap(
+            jax.vmap(partial(distance), in_axes=(None, 0)), in_axes=(0, None)
+        )(values, values)
+        return values[jnp.argmin(jnp.mean(distances, axis=0))]
+
+    return jax.vmap(_closest)(quantities)
+
+
+@jax.jit
+def std(quantities: jnp.ndarray) -> jnp.ndarray:
+    """Default reproducibility extractor using standard deviation."""
+    return jnp.std(quantities, axis=1)
+
+
+@jax.jit
+def mad(quantities: jnp.ndarray) -> jnp.ndarray:
+    """Alternative reproducibility extractor using Median Absolute Deviation.
+    More robust to outliers than standard deviation."""
+    num_samples = quantities.shape[1]
+    median = jnp.repeat(
+        jnp.median(quantities, axis=1, keepdims=True), num_samples, axis=1
+    )
+    return jnp.median(jnp.abs(quantities - median), axis=1)
+
+
+@jax.jit
+def iqr(quantities: jnp.ndarray) -> jnp.ndarray:
+    """Alternative reproducibility extractor using Inter-Quartile Range.
+    More robust to outliers than standard deviation."""
+    q1 = jnp.quantile(quantities, 0.25, axis=1)
+    q4 = jnp.quantile(quantities, 0.75, axis=1)
+    return q4 - q1
+
+
 @partial(jax.jit, static_argnames=("num_samples",))
 def dummy_extra_scores_extractor(
     extra_scores: ExtraScores,
@@ -89,6 +174,8 @@ def multi_sample_scoring_function(
         "scoring_fn",
         "num_samples",
         "extra_scores_extractor",
+        "fitness_extractor",
+        "descriptor_extractor",
     ),
 )
 def sampling(
@@ -102,11 +189,14 @@ def sampling(
     extra_scores_extractor: Callable[
         [ExtraScores, int], ExtraScores
     ] = dummy_extra_scores_extractor,
+    fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average,
+    descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average,
 ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:
     """Wrap scoring_function to perform sampling.
 
-    This function averages the fitnesses and descriptors for each individual
-    over `num_samples` evaluations.
+    This function returns the expected fitnesses and descriptors for each
+    individual over `num_samples` evaluations, using the provided extractor
+    functions for the fitness and the descriptor.
 
     Args:
         policies_params: policies to evaluate
@@ -115,12 +205,17 @@ def sampling(
         num_samples: number of samples to generate for each individual
         extra_scores_extractor: function to extract the extra_scores from
             multiple samples of the same policy.
+        fitness_extractor: function to extract the fitness expectation from
+            multiple samples of the same policy.
+        descriptor_extractor: function to extract the descriptor expectation
+            from multiple samples of the same policy.
 
     Returns:
-        The average fitness and descriptor of the individuals
-        The extra_score extract from samples with extra_scores_extractor
+        The expected fitnesses, descriptors and extra_scores of the individuals
         A new random key
     """
+
+    # Perform sampling
     (
         all_fitnesses,
         all_descriptors,
@@ -130,11 +225,95 @@ def sampling(
         policies_params, random_key, scoring_fn, num_samples
     )
 
-    # average results
-    descriptors = jnp.average(all_descriptors, axis=1)
-    fitnesses = jnp.average(all_fitnesses, axis=1)
-
-    # extract extra scores and add number of evaluations to it
+    # Extract final scores
+    descriptors = descriptor_extractor(all_descriptors)
+    fitnesses = fitness_extractor(all_fitnesses)
     extra_scores = extra_scores_extractor(all_extra_scores, num_samples)
 
     return fitnesses, descriptors, extra_scores, random_key
+
+
+@partial(
+    jax.jit,
+    static_argnames=(
+        "scoring_fn",
+        "num_samples",
+        "extra_scores_extractor",
+        "fitness_extractor",
+        "descriptor_extractor",
+        "fitness_reproducibility_extractor",
+        "descriptor_reproducibility_extractor",
+    ),
+)
+def sampling_reproducibility(
+    policies_params: Genotype,
+    random_key: RNGKey,
+    scoring_fn: Callable[
+        [Genotype, RNGKey],
+        Tuple[Fitness, Descriptor, ExtraScores, RNGKey],
+    ],
+    num_samples: int,
+    extra_scores_extractor: Callable[
+        [ExtraScores, int], ExtraScores
+    ] = dummy_extra_scores_extractor,
+    fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average,
+    descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average,
+    fitness_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std,
+    descriptor_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std,
+) -> Tuple[Fitness, Descriptor, ExtraScores, Fitness, Descriptor, RNGKey]:
+    """Wrap scoring_function to perform sampling and compute the
+    expectation and reproducibility.
+
+    This function returns the reproducibility of fitnesses and descriptors for each
+    individual over `num_samples` evaluations, using the provided extractor
+    functions for the fitness and the descriptor.
+
+    Args:
+        policies_params: policies to evaluate
+        random_key: JAX random key
+        scoring_fn: scoring function used for evaluation
+        num_samples: number of samples to generate for each individual
+        extra_scores_extractor: function to extract the extra_scores from
+            multiple samples of the same policy.
+        fitness_extractor: function to extract the fitness expectation from
+            multiple samples of the same policy.
+        descriptor_extractor: function to extract the descriptor expectation
+            from multiple samples of the same policy.
+        fitness_reproducibility_extractor: function to extract the fitness
+            reproducibility from multiple samples of the same policy.
+        descriptor_reproducibility_extractor: function to extract the descriptor
+            reproducibility from multiple samples of the same policy.
+
+    Returns:
+        The expected fitnesses, descriptors and extra_scores of the individuals
+        The reproducibility of the fitnesses and descriptors of the individuals
+        A new random key
+    """
+
+    # Perform sampling
+    (
+        all_fitnesses,
+        all_descriptors,
+        all_extra_scores,
+        random_key,
+    ) = multi_sample_scoring_function(
+        policies_params, random_key, scoring_fn, num_samples
+    )
+
+    # Extract final scores
+    descriptors = descriptor_extractor(all_descriptors)
+    fitnesses = fitness_extractor(all_fitnesses)
+    extra_scores = extra_scores_extractor(all_extra_scores, num_samples)
+
+    # Extract reproducibility
+    descriptors_reproducibility = descriptor_reproducibility_extractor(all_descriptors)
+    fitnesses_reproducibility = fitness_reproducibility_extractor(all_fitnesses)
+
+    return (
+        fitnesses,
+        descriptors,
+        extra_scores,
+        fitnesses_reproducibility,
+        descriptors_reproducibility,
+        random_key,
+    )
diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py
index a7b2d15d..6ce6cbe9 100644
--- a/tests/utils_test/sampling_test.py
+++ b/tests/utils_test/sampling_test.py
@@ -1,5 +1,5 @@
 import functools
-from typing import Tuple
+from typing import Callable, Tuple
 
 import jax
 import jax.numpy as jnp
@@ -10,7 +10,17 @@
 from qdax.core.neuroevolution.networks.networks import MLP
 from qdax.tasks.brax_envs import scoring_function_brax_envs
 from qdax.types import EnvState, Params, RNGKey
-from qdax.utils.sampling import sampling
+from qdax.utils.sampling import (
+    average,
+    closest,
+    iqr,
+    mad,
+    median,
+    mode,
+    sampling,
+    sampling_reproducibility,
+    std,
+)
 
 
 def test_sampling() -> None:
@@ -74,7 +84,7 @@ def play_step_fn(
     keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=1, axis=0)
     init_states = reset_fn(keys)
 
-    # Compare scoring against perforing a single sample
+    # Create the scoring function
     bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]
     scoring_fn = functools.partial(
         scoring_function_brax_envs,
@@ -83,34 +93,141 @@
         play_step_fn=play_step_fn,
         behavior_descriptor_extractor=bd_extraction_fn,
     )
-    scoring_1_sample_fn = functools.partial(
-        sampling,
-        scoring_fn=scoring_fn,
-        num_samples=1,
-    )
 
-    # Evaluate individuals using the scoring functions
-    fitnesses, descriptors, _, _ = scoring_fn(init_variables, random_key)
-    sample_fitnesses, sample_descriptors, _, _ = scoring_1_sample_fn(
-        init_variables, random_key
-    )
+    # Test function for different extractors
+    def sampling_test(
+        fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray],
+        descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray],
+    ) -> None:
+
+        # Compare scoring against performing a single sample
+        scoring_1_sample_fn = functools.partial(
+            sampling,
+            scoring_fn=scoring_fn,
+            num_samples=1,
+            fitness_extractor=fitness_extractor,
+            descriptor_extractor=descriptor_extractor,
+        )
 
-    # Compare
-    pytest.assume(jnp.allclose(descriptors, sample_descriptors, rtol=1e-05, atol=1e-08))
-    pytest.assume(jnp.allclose(fitnesses, sample_fitnesses, rtol=1e-05, atol=1e-08))
+        # Evaluate individuals using the scoring functions
+        fitnesses, descriptors, _, _ = scoring_fn(init_variables, random_key)
+        sample_fitnesses, sample_descriptors, _, _ = scoring_1_sample_fn(
+            init_variables, random_key
+        )
 
-    # Compare scoring against perforing multiple samples
-    scoring_multi_sample_fn = functools.partial(
-        sampling,
-        scoring_fn=scoring_fn,
-        num_samples=sample_number,
-    )
+        # Compare
+        pytest.assume(
+            jnp.allclose(descriptors, sample_descriptors, rtol=1e-05, atol=1e-08)
+        )
+        pytest.assume(jnp.allclose(fitnesses, sample_fitnesses, rtol=1e-05, atol=1e-08))
+
+        # Compare scoring against performing multiple samples
+        scoring_multi_sample_fn = functools.partial(
+            sampling,
+            scoring_fn=scoring_fn,
+            num_samples=sample_number,
+            fitness_extractor=fitness_extractor,
+            descriptor_extractor=descriptor_extractor,
+        )
 
-    # Evaluate individuals using the scoring functions
-    sample_fitnesses, sample_descriptors, _, _ = scoring_multi_sample_fn(
-        init_variables, random_key
-    )
+        # Evaluate individuals using the scoring functions
+        sample_fitnesses, sample_descriptors, _, _ = scoring_multi_sample_fn(
+            init_variables, random_key
+        )
+
+        # Compare
+        pytest.assume(
+            jnp.allclose(descriptors, sample_descriptors, rtol=1e-05, atol=1e-08)
+        )
+        pytest.assume(jnp.allclose(fitnesses, sample_fitnesses, rtol=1e-05, atol=1e-08))
+
+    # Call the test for each type of extractor
+    sampling_test(average, average)
+    sampling_test(median, median)
+    sampling_test(mode, mode)
+    sampling_test(closest, closest)
+
+    # Test function for different reproducibility extractors
+    def sampling_reproducibility_test(
+        fitness_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray],
+        descriptor_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray],
+    ) -> None:
+
+        # Compare scoring against performing a single sample
+        scoring_1_sample_fn = functools.partial(
+            sampling_reproducibility,
+            scoring_fn=scoring_fn,
+            num_samples=1,
+            fitness_reproducibility_extractor=fitness_reproducibility_extractor,
+            descriptor_reproducibility_extractor=descriptor_reproducibility_extractor,
+        )
+
+        # Evaluate individuals using the scoring functions
+        (
+            _,
+            _,
+            _,
+            fitnesses_reproducibility,
+            descriptors_reproducibility,
+            _,
+        ) = scoring_1_sample_fn(init_variables, random_key)
+
+        # Compare - all reproducibility should be 0
+        pytest.assume(
+            jnp.allclose(
+                fitnesses_reproducibility,
+                jnp.zeros_like(fitnesses_reproducibility),
+                rtol=1e-05,
+                atol=1e-05,
+            )
+        )
+        pytest.assume(
+            jnp.allclose(
+                descriptors_reproducibility,
+                jnp.zeros_like(descriptors_reproducibility),
+                rtol=1e-05,
+                atol=1e-05,
+            )
+        )
+
+        # Compare scoring against performing multiple samples
+        scoring_multi_sample_fn = functools.partial(
+            sampling_reproducibility,
+            scoring_fn=scoring_fn,
+            num_samples=sample_number,
+            fitness_reproducibility_extractor=fitness_reproducibility_extractor,
+            descriptor_reproducibility_extractor=descriptor_reproducibility_extractor,
+        )
+
+        # Evaluate individuals using the scoring functions
+        (
+            _,
+            _,
+            _,
+            fitnesses_reproducibility,
+            descriptors_reproducibility,
+            _,
+        ) = scoring_multi_sample_fn(init_variables, random_key)
+
+        # Compare - all reproducibility should be 0
+        pytest.assume(
+            jnp.allclose(
+                fitnesses_reproducibility,
+                jnp.zeros_like(fitnesses_reproducibility),
+                rtol=1e-05,
+                atol=1e-05,
+            )
+        )
+        pytest.assume(
+            jnp.allclose(
+                descriptors_reproducibility,
+                jnp.zeros_like(descriptors_reproducibility),
+                rtol=1e-05,
+                atol=1e-05,
+            )
+        )
 
-    # Compare
-    pytest.assume(jnp.allclose(descriptors, sample_descriptors, rtol=1e-05, atol=1e-08))
-    pytest.assume(jnp.allclose(fitnesses, sample_fitnesses, rtol=1e-05, atol=1e-08))
+    # Call the test for each type of extractor
+    sampling_reproducibility_test(std, std)
+    sampling_reproducibility_test(mad, mad)
+    sampling_reproducibility_test(iqr, iqr)
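
Usage sketch: the new extractors plug into `sampling` and `sampling_reproducibility` through the `*_extractor` keyword arguments added above. The snippet below is a minimal, hypothetical illustration, not part of the patch: `toy_scoring_fn` is an assumed stand-in for any scoring function matching the `(genotypes, random_key) -> (fitnesses, descriptors, extra_scores, random_key)` signature documented above, and the `median`/`iqr` pairing is only one possible choice rather than the library default (`average`/`std`).

from functools import partial

import jax
import jax.numpy as jnp

from qdax.utils.sampling import iqr, median, sampling_reproducibility


def toy_scoring_fn(genotypes, random_key):
    # Hypothetical noisy scoring function: fitness is a noisy sphere objective,
    # descriptors are the first two genotype dimensions.
    random_key, subkey = jax.random.split(random_key)
    noise = 0.1 * jax.random.normal(subkey, (genotypes.shape[0],))
    fitnesses = -jnp.sum(jnp.square(genotypes), axis=-1) + noise
    descriptors = genotypes[:, :2]
    return fitnesses, descriptors, {}, random_key


# Use median as the expectation extractor and inter-quartile range as the
# reproducibility extractor instead of the default average/std.
robust_scoring_fn = partial(
    sampling_reproducibility,
    scoring_fn=toy_scoring_fn,
    num_samples=32,
    fitness_extractor=median,
    descriptor_extractor=median,
    fitness_reproducibility_extractor=iqr,
    descriptor_reproducibility_extractor=iqr,
)

random_key = jax.random.PRNGKey(0)
genotypes = jax.random.uniform(random_key, (10, 4))
(
    fitnesses,
    descriptors,
    extra_scores,
    fitnesses_reproducibility,
    descriptors_reproducibility,
    random_key,
) = robust_scoring_fn(genotypes, random_key)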