diff --git a/tests/utils_test/uncertainty_metrics_test.py b/tests/utils_test/uncertainty_metrics_test.py index 810006e8..5d1abbea 100644 --- a/tests/utils_test/uncertainty_metrics_test.py +++ b/tests/utils_test/uncertainty_metrics_test.py @@ -1,5 +1,5 @@ import functools -from typing import Tuple +from typing import Callable, Tuple import jax import jax.numpy as jnp @@ -32,15 +32,15 @@ def test_uncertainty_metrics() -> None: random_key = jax.random.PRNGKey(seed) # First, init a deterministic environment - scoring_fn = arm_scoring_function + genotype_dim = 8 # Init policies init_policies = jax.random.uniform( - random_key, shape=(batch_size, 8), minval=0, maxval=1 + random_key, shape=(batch_size, genotype_dim), minval=0, maxval=1 ) # Evaluate in the deterministic environment - fitnesses, descriptors, extra_scores, random_key = scoring_fn( + fitnesses, descriptors, extra_scores, random_key = arm_scoring_function( init_policies, random_key ) @@ -75,7 +75,7 @@ def test_uncertainty_metrics() -> None: corrected_repertoire, random_key = reevaluation_function( repertoire=repertoire, empty_corrected_repertoire=empty_corrected_repertoire, - scoring_fn=scoring_fn, + scoring_fn=arm_scoring_function, num_reevals=num_reevals, random_key=random_key, ) @@ -94,7 +94,7 @@ def test_uncertainty_metrics() -> None: ) = reevaluation_reproducibility_function( repertoire=repertoire, empty_corrected_repertoire=empty_corrected_repertoire, - scoring_fn=scoring_fn, + scoring_fn=arm_scoring_function, num_reevals=num_reevals, random_key=random_key, ) @@ -178,7 +178,7 @@ def play_step_fn( # Create the scoring function bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] - scoring_fn = functools.partial( + brax_scoring_fn: Callable = functools.partial( scoring_function_brax_envs, init_states=init_states, episode_length=episode_length, @@ -187,7 +187,7 @@ def play_step_fn( ) # Evaluate in the Brax environment - fitnesses, descriptors, extra_scores, random_key = scoring_fn( + fitnesses, descriptors, extra_scores, random_key = brax_scoring_fn( init_policies, random_key ) @@ -220,11 +220,9 @@ def play_step_fn( ) # Test that reevaluation_function runs and keeps at least one solution - keys = jnp.repeat( - jnp.expand_dims(subkey, axis=0), repeats=num_centroids, axis=0 - ) + keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=num_centroids, axis=0) init_states = reset_fn(keys) - reeval_scoring_fn = functools.partial( + reeval_brax_scoring_fn: Callable = functools.partial( scoring_function_brax_envs, init_states=init_states, episode_length=episode_length, @@ -234,7 +232,7 @@ def play_step_fn( corrected_repertoire, random_key = reevaluation_function( repertoire=repertoire, empty_corrected_repertoire=empty_corrected_repertoire, - scoring_fn=reeval_scoring_fn, + scoring_fn=reeval_brax_scoring_fn, num_reevals=num_reevals, random_key=random_key, )