Skip to content

Commit

Permalink
fix: some typing and names redefinition issues
Browse files Browse the repository at this point in the history
  • Loading branch information
manon-but-yes committed Aug 21, 2024
1 parent 6fe9efb commit d8ee819
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions tests/utils_test/uncertainty_metrics_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import Tuple
from typing import Callable, Tuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -32,15 +32,15 @@ def test_uncertainty_metrics() -> None:
random_key = jax.random.PRNGKey(seed)

# First, init a deterministic environment
scoring_fn = arm_scoring_function
genotype_dim = 8

# Init policies
init_policies = jax.random.uniform(
random_key, shape=(batch_size, 8), minval=0, maxval=1
random_key, shape=(batch_size, genotype_dim), minval=0, maxval=1
)

# Evaluate in the deterministic environment
fitnesses, descriptors, extra_scores, random_key = scoring_fn(
fitnesses, descriptors, extra_scores, random_key = arm_scoring_function(
init_policies, random_key
)

Expand Down Expand Up @@ -75,7 +75,7 @@ def test_uncertainty_metrics() -> None:
corrected_repertoire, random_key = reevaluation_function(
repertoire=repertoire,
empty_corrected_repertoire=empty_corrected_repertoire,
scoring_fn=scoring_fn,
scoring_fn=arm_scoring_function,
num_reevals=num_reevals,
random_key=random_key,
)
Expand All @@ -94,7 +94,7 @@ def test_uncertainty_metrics() -> None:
) = reevaluation_reproducibility_function(
repertoire=repertoire,
empty_corrected_repertoire=empty_corrected_repertoire,
scoring_fn=scoring_fn,
scoring_fn=arm_scoring_function,
num_reevals=num_reevals,
random_key=random_key,
)
Expand Down Expand Up @@ -178,7 +178,7 @@ def play_step_fn(

# Create the scoring function
bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]
scoring_fn = functools.partial(
brax_scoring_fn: Callable = functools.partial(
scoring_function_brax_envs,
init_states=init_states,
episode_length=episode_length,
Expand All @@ -187,7 +187,7 @@ def play_step_fn(
)

# Evaluate in the Brax environment
fitnesses, descriptors, extra_scores, random_key = scoring_fn(
fitnesses, descriptors, extra_scores, random_key = brax_scoring_fn(
init_policies, random_key
)

Expand Down Expand Up @@ -220,11 +220,9 @@ def play_step_fn(
)

# Test that reevaluation_function runs and keeps at least one solution
keys = jnp.repeat(
jnp.expand_dims(subkey, axis=0), repeats=num_centroids, axis=0
)
keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=num_centroids, axis=0)
init_states = reset_fn(keys)
reeval_scoring_fn = functools.partial(
reeval_brax_scoring_fn: Callable = functools.partial(
scoring_function_brax_envs,
init_states=init_states,
episode_length=episode_length,
Expand All @@ -234,7 +232,7 @@ def play_step_fn(
corrected_repertoire, random_key = reevaluation_function(
repertoire=repertoire,
empty_corrected_repertoire=empty_corrected_repertoire,
scoring_fn=reeval_scoring_fn,
scoring_fn=reeval_brax_scoring_fn,
num_reevals=num_reevals,
random_key=random_key,
)
Expand Down

0 comments on commit d8ee819

Please sign in to comment.