Skip to content

Commit

Permalink
Run pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
maxencefaldor committed Sep 22, 2024
1 parent 0635296 commit fa45483
Show file tree
Hide file tree
Showing 14 changed files with 32 additions and 84 deletions.
4 changes: 2 additions & 2 deletions examples/aurora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -45,7 +45,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down
46 changes: 13 additions & 33 deletions examples/mees.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/sac_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@
},
{
"cell_type": "markdown",
"id": "aa4b43a1",
"id": "16",
"metadata": {},
"source": [
"### Visualize learnt behaviors"
Expand Down
4 changes: 1 addition & 3 deletions qdax/baselines/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,7 @@ def eval_qd_policy_fn(
true_returns = jnp.nansum(transitions.rewards, axis=0)
true_return = jnp.mean(true_returns, axis=-1)

transitions = jax.tree.map(
lambda x: jnp.swapaxes(x, 0, 1), transitions
)
transitions = jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), transitions)
masks = jnp.isnan(transitions.rewards)
descriptors = descriptor_extraction_fn(transitions, masks)

Expand Down
4 changes: 1 addition & 3 deletions qdax/baselines/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,7 @@ def eval_qd_policy_fn(
true_returns = jnp.nansum(transitions.rewards, axis=0)
true_return = jnp.mean(true_returns, axis=-1)

transitions = jax.tree.map(
lambda x: jnp.swapaxes(x, 0, 1), transitions
)
transitions = jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), transitions)
masks = jnp.isnan(transitions.rewards)
descriptors = descriptor_extraction_fn(transitions, masks)

Expand Down
4 changes: 1 addition & 3 deletions qdax/core/containers/ga_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,7 @@ def add(
survivor_indices = indices[: self.size]

# keep only the best ones
new_candidates = jax.tree.map(
lambda x: x[survivor_indices], candidates
)
new_candidates = jax.tree.map(lambda x: x[survivor_indices], candidates)

new_repertoire = self.replace(
genotypes=new_candidates, fitnesses=candidates_fitnesses[survivor_indices]
Expand Down
16 changes: 4 additions & 12 deletions qdax/core/containers/mome_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,7 @@ def sample(self, key: RNGKey, num_samples: int) -> Genotype:
cells_idx = jax.random.choice(subkey, indices, shape=(num_samples,), p=p)

# get genotypes (front) from the chosen indices
pareto_front_genotypes = jax.tree.map(
lambda x: x[cells_idx], self.genotypes
)
pareto_front_genotypes = jax.tree.map(lambda x: x[cells_idx], self.genotypes)

# prepare second sampling function
sample_in_fronts = jax.vmap(self._sample_in_masked_pareto_front)
Expand All @@ -130,9 +128,7 @@ def sample(self, key: RNGKey, num_samples: int) -> Genotype:
)

# remove the dim coming from pareto front
sampled_genotypes = jax.tree.map(
lambda x: x.squeeze(axis=1), sampled_genotypes
)
sampled_genotypes = jax.tree.map(lambda x: x.squeeze(axis=1), sampled_genotypes)

return sampled_genotypes

Expand Down Expand Up @@ -287,16 +283,12 @@ def _add_one(
index = index.astype(jnp.int32)

# get current repertoire cell data
cell_genotype = jax.tree.map(
lambda x: x[index][0], carry.genotypes
)
cell_genotype = jax.tree.map(lambda x: x[index][0], carry.genotypes)
cell_fitness = carry.fitnesses[index][0]
cell_descriptor = carry.descriptors[index][0]
cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1)

new_genotypes = jax.tree.map(
lambda x: jnp.expand_dims(x, axis=0), genotype
)
new_genotypes = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), genotype)

# update pareto front
(
Expand Down
4 changes: 1 addition & 3 deletions qdax/core/emitters/cma_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,7 @@ def state_update(
sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))

# sort the candidates
sorted_candidates = jax.tree.map(
lambda x: x[sorted_indices], genotypes
)
sorted_candidates = jax.tree.map(lambda x: x[sorted_indices], genotypes)
sorted_improvements = improvements[sorted_indices]

# compute reinitialize condition
Expand Down
4 changes: 1 addition & 3 deletions qdax/core/emitters/cma_pool_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,7 @@ def state_update(
current_index = emitter_state.current_index
emitter_states = emitter_state.emitter_states

used_emitter_state = jax.tree.map(
lambda x: x[current_index], emitter_states
)
used_emitter_state = jax.tree.map(lambda x: x[current_index], emitter_states)

# update the used emitter state
used_emitter_state = self._emitter.state_update(
Expand Down
8 changes: 2 additions & 6 deletions qdax/core/emitters/mutation_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,7 @@ def polynomial_crossover(
)
crossover_fn = jax.vmap(crossover_fn)
# TODO: check that key usage is correct
x = jax.tree.map(
lambda x1_, x2_: crossover_fn(x1_, x2_, crossover_keys), x1, x2
)
x = jax.tree.map(lambda x1_, x2_: crossover_fn(x1_, x2_, crossover_keys), x1, x2)
return x


Expand Down Expand Up @@ -223,8 +221,6 @@ def _variation_fn(x1: jnp.ndarray, x2: jnp.ndarray, key: RNGKey) -> jnp.ndarray:
keys_tree = jax.tree.unflatten(jax.tree.structure(x1), keys)

# apply isolinedd to each branch of the tree
x = jax.tree.map(
lambda y1, y2, key: _variation_fn(y1, y2, key), x1, x2, keys_tree
)
x = jax.tree.map(lambda y1, y2, key: _variation_fn(y1, y2, key), x1, x2, keys_tree)

return x
4 changes: 1 addition & 3 deletions qdax/core/emitters/omg_mega_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,7 @@ def emit(
update_grad = jnp.sum(jax.vmap(lambda x, y: x * y)(coeffs, grads), axis=-1)

# update the genotypes
new_genotypes = jax.tree.map(
lambda x, y: x + y, genotypes, update_grad
)
new_genotypes = jax.tree.map(lambda x, y: x + y, genotypes, update_grad)

return new_genotypes, {}

Expand Down
4 changes: 1 addition & 3 deletions qdax/environments/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ def reset(self, rng: jp.ndarray) -> State:
reset_state = self.env.reset(rng)
reset_state.metrics["reward"] = reset_state.reward
eval_metrics = CompletedEvalMetrics(
current_episode_metrics=jax.tree.map(
jp.zeros_like, reset_state.metrics
),
current_episode_metrics=jax.tree.map(jp.zeros_like, reset_state.metrics),
completed_episodes_metrics=jax.tree.map(
lambda x: jp.zeros_like(jp.sum(x)), reset_state.metrics
),
Expand Down
4 changes: 1 addition & 3 deletions qdax/tasks/brax_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,7 @@ def scoring_function_brax_envs(

# Reset environments
key, subkey = jax.random.split(key)
keys = jax.random.split(
subkey, jax.tree.leaves(policies_params)[0].shape[0]
)
keys = jax.random.split(subkey, jax.tree.leaves(policies_params)[0].shape[0])
init_states = jax.vmap(play_reset_fn)(keys)

# Step environments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def test_insert_batch() -> None:
buffer_size=buffer_size, transition=dummy_transition
)

simple_transition = jax.tree.map(
lambda x: x.repeat(3, axis=0), dummy_transition
)
simple_transition = jax.tree.map(lambda x: x.repeat(3, axis=0), dummy_transition)
simple_transition = simple_transition.replace(rewards=jnp.arange(3))
data = QDTransition.from_flatten(replay_buffer.data, dummy_transition)
pytest.assume(
Expand Down Expand Up @@ -85,9 +83,7 @@ def test_sample() -> None:
buffer_size=buffer_size, transition=dummy_transition
)

simple_transition = jax.tree.map(
lambda x: x.repeat(3, axis=0), dummy_transition
)
simple_transition = jax.tree.map(lambda x: x.repeat(3, axis=0), dummy_transition)
simple_transition = simple_transition.replace(rewards=jnp.arange(3))

replay_buffer = replay_buffer.insert(simple_transition)
Expand Down

0 comments on commit fa45483

Please sign in to comment.