From 83d4201fe4ebda7bb0f97a7a30ba0b5daba80573 Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Sun, 22 Sep 2024 19:25:18 +0000 Subject: [PATCH] Rename all arguments named "unused" by "_" --- examples/aurora.ipynb | 2 +- examples/pga_aurora.ipynb | 2 +- qdax/baselines/genetic_algorithm.py | 4 ++-- qdax/core/containers/archive.py | 2 +- qdax/core/distributed_map_elites.py | 2 +- qdax/core/emitters/cma_pool_emitter.py | 6 ++---- qdax/core/emitters/qpg_emitter.py | 4 ++-- qdax/core/map_elites.py | 4 ++-- qdax/utils/uncertainty_metrics.py | 2 +- tests/baselines_test/mees_test.py | 2 +- tests/baselines_test/pgame_test.py | 2 +- tests/baselines_test/qdpg_test.py | 2 +- 12 files changed, 16 insertions(+), 18 deletions(-) diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index b5075356..bb0f403a 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -322,7 +322,7 @@ "centroids = jnp.zeros(shape=(num_centroids, aurora_dims))\n", "\n", "@jax.jit\n", - "def update_scan_fn(carry: Any, unused: Any) -> Any:\n", + "def update_scan_fn(carry: Any, _: Any) -> Any:\n", " \"\"\"Scan the update function.\"\"\"\n", " repertoire, key, aurora_extra_info = carry\n", "\n", diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb index bd4059f9..02239498 100644 --- a/examples/pga_aurora.ipynb +++ b/examples/pga_aurora.ipynb @@ -368,7 +368,7 @@ "centroids = jnp.zeros(shape=(num_centroids, aurora_dims))\n", "\n", "@jax.jit\n", - "def update_scan_fn(carry: Any, unused: Any) -> Any:\n", + "def update_scan_fn(carry: Any, _: Any) -> Any:\n", " \"\"\"Scan the update function.\"\"\"\n", " (\n", " repertoire,\n", diff --git a/qdax/baselines/genetic_algorithm.py b/qdax/baselines/genetic_algorithm.py index ef23483e..1cef4410 100644 --- a/qdax/baselines/genetic_algorithm.py +++ b/qdax/baselines/genetic_algorithm.py @@ -131,7 +131,7 @@ def update( def scan_update( self, carry: Tuple[GARepertoire, Optional[EmitterState], RNGKey], - unused: Any, + _: Any, ) -> Tuple[Tuple[GARepertoire, Optional[EmitterState], RNGKey], Metrics]: """Rewrites the update function in a way that makes it compatible with the jax.lax.scan primitive. @@ -139,7 +139,7 @@ def scan_update( Args: carry: a tuple containing the repertoire, the emitter state and a random key. - unused: unused element, necessary to respect jax.lax.scan API. + _: unused element, necessary to respect jax.lax.scan API. Returns: The updated repertoire and emitter state, with a new random key and metrics. diff --git a/qdax/core/containers/archive.py b/qdax/core/containers/archive.py index d2e1f812..1590d9eb 100644 --- a/qdax/core/containers/archive.py +++ b/qdax/core/containers/archive.py @@ -311,7 +311,7 @@ def top_1(data: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: return data, value, indice def scannable_top_1( - carry: jnp.ndarray, unused: Any + carry: jnp.ndarray, _: Any ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: data = carry data, value, indice = top_1(data) diff --git a/qdax/core/distributed_map_elites.py b/qdax/core/distributed_map_elites.py index 83bd6b2c..23024d4c 100644 --- a/qdax/core/distributed_map_elites.py +++ b/qdax/core/distributed_map_elites.py @@ -191,7 +191,7 @@ def get_distributed_update_fn( @jax.jit def _scan_update( carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], - unused: Any, + _: Any, ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]: """Rewrites the update function in a way that makes it compatible with the jax.lax.scan primitive.""" diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py index 9ac52c72..14d58e14 100644 --- a/qdax/core/emitters/cma_pool_emitter.py +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -45,7 +45,7 @@ def batch_size(self) -> int: Returns: the batch size emitted by the emitter. """ - return self._emitter.batch_size + return self._emitter.batch_size # type: ignore @partial(jax.jit, static_argnames=("self",)) def init( @@ -69,9 +69,7 @@ def init( The initial state of the emitter. """ - def scan_emitter_init( - carry: RNGKey, unused: Any - ) -> Tuple[RNGKey, CMAEmitterState]: + def scan_emitter_init(carry: RNGKey, _: Any) -> Tuple[RNGKey, CMAEmitterState]: key = carry key, subkey = jax.random.split(key) emitter_state = self._emitter.init( diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index 4db06617..87ec3e4e 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -312,7 +312,7 @@ def state_update( emitter_state = emitter_state.replace(replay_buffer=replay_buffer) def scan_train_critics( - carry: QualityPGEmitterState, unused: Any + carry: QualityPGEmitterState, _: Any ) -> Tuple[QualityPGEmitterState, Any]: emitter_state = carry new_emitter_state = self._train_critics(emitter_state) @@ -501,7 +501,7 @@ def _mutation_function_pg( def scan_train_policy( carry: Tuple[QualityPGEmitterState, Genotype, optax.OptState], - unused: Any, + _: Any, ) -> Tuple[Tuple[QualityPGEmitterState, Genotype, optax.OptState], Any]: emitter_state, policy_params, policy_optimizer_state = carry ( diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index 0039c0df..1cacd4ae 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -154,7 +154,7 @@ def update( def scan_update( self, carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], - unused: Any, + _: Any, ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]: """Rewrites the update function in a way that makes it compatible with the jax.lax.scan primitive. @@ -162,7 +162,7 @@ def scan_update( Args: carry: a tuple containing the repertoire, the emitter state and a random key. - unused: unused element, necessary to respect jax.lax.scan API. + _: unused element, necessary to respect jax.lax.scan API. Returns: The updated repertoire and emitter state, with a new random key and metrics. diff --git a/qdax/utils/uncertainty_metrics.py b/qdax/utils/uncertainty_metrics.py index cc39b4f4..1fb2ec0c 100644 --- a/qdax/utils/uncertainty_metrics.py +++ b/qdax/utils/uncertainty_metrics.py @@ -288,7 +288,7 @@ def _perform_reevaluation( def _sampling_scan( key: RNGKey, - unused: Tuple[()], + _: Tuple[()], ) -> Tuple[Tuple[RNGKey], Tuple[Fitness, Descriptor, ExtraScores]]: key, subkey = jax.random.split(key) ( diff --git a/tests/baselines_test/mees_test.py b/tests/baselines_test/mees_test.py index 5d17cf4a..924b557c 100644 --- a/tests/baselines_test/mees_test.py +++ b/tests/baselines_test/mees_test.py @@ -163,7 +163,7 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey) @jax.jit - def update_scan_fn(carry: Any, unused: Any) -> Any: + def update_scan_fn(carry: Any, _: Any) -> Any: # iterate over grid repertoire, emitter_state, key = carry key, subkey = jax.random.split(key) diff --git a/tests/baselines_test/pgame_test.py b/tests/baselines_test/pgame_test.py index 33364f41..5dec3745 100644 --- a/tests/baselines_test/pgame_test.py +++ b/tests/baselines_test/pgame_test.py @@ -177,7 +177,7 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey) @jax.jit - def update_scan_fn(carry: Any, unused: Any) -> Any: + def update_scan_fn(carry: Any, _: Any) -> Any: # iterate over grid repertoire, emitter_state, key = carry key, subkey = jax.random.split(key) diff --git a/tests/baselines_test/qdpg_test.py b/tests/baselines_test/qdpg_test.py index 6eee586c..63e5c0a7 100644 --- a/tests/baselines_test/qdpg_test.py +++ b/tests/baselines_test/qdpg_test.py @@ -227,7 +227,7 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey) @jax.jit - def update_scan_fn(carry: Any, unused: Any) -> Any: + def update_scan_fn(carry: Any, _: Any) -> Any: # iterate over grid repertoire, emitter_state, key = carry key, subkey = jax.random.split(key)