From 823c1d8279667723963fb5f8844ef9975d8ce27c Mon Sep 17 00:00:00 2001 From: kngwyu Date: Tue, 2 Jul 2024 15:17:42 +0900 Subject: [PATCH] Fix lint --- experiments/cf_nolearn.py | 6 +----- experiments/cf_simple.py | 5 +++-- noxfile.py | 2 +- pyproject.toml | 5 +++-- scripts/make_web_data.py | 14 +++++++------- src/emevo/exp_utils.py | 2 +- 6 files changed, 16 insertions(+), 18 deletions(-) diff --git a/experiments/cf_nolearn.py b/experiments/cf_nolearn.py index 9fb3792..0928a56 100644 --- a/experiments/cf_nolearn.py +++ b/experiments/cf_nolearn.py @@ -33,11 +33,7 @@ SavedProfile, is_cuda_ready, ) -from emevo.rl.ppo_normal import ( - NormalPPONet, - vmap_apply, - vmap_net, -) +from emevo.rl.ppo_normal import NormalPPONet, vmap_apply, vmap_net from emevo.spaces import BoxSpace PROJECT_ROOT = Path(__file__).parent.parent diff --git a/experiments/cf_simple.py b/experiments/cf_simple.py index 4ccb180..dd9efe4 100644 --- a/experiments/cf_simple.py +++ b/experiments/cf_simple.py @@ -1,4 +1,5 @@ """Asexual reward evolution with Circle Foraging""" + import dataclasses import json from pathlib import Path @@ -56,11 +57,11 @@ class RewardExtractor: _max_norm: jax.Array = dataclasses.field(init=False) def __post_init__(self) -> None: - self._max_norm = jnp.sqrt(jnp.sum(self.act_space.high ** 2, axis=-1)) + self._max_norm = jnp.sqrt(jnp.sum(self.act_space.high**2, axis=-1)) def normalize_action(self, action: jax.Array) -> jax.Array: scaled = self.act_space.sigmoid_scale(action) - norm = jnp.sqrt(jnp.sum(scaled ** 2, axis=-1, keepdims=True)) + norm = jnp.sqrt(jnp.sum(scaled**2, axis=-1, keepdims=True)) return norm / self._max_norm def extract( diff --git a/noxfile.py b/noxfile.py index 071392e..9518f9f 100644 --- a/noxfile.py +++ b/noxfile.py @@ -6,7 +6,7 @@ import nox -SOURCES = ["src/emevo", "tests", "smoke-tests", "experiments"] +SOURCES = ["src/emevo", "tests", "smoke-tests", "experiments", "scripts"] def _sync(session: nox.Session, requirements: str) -> None: diff --git a/pyproject.toml b/pyproject.toml index 086a4c1..915d390 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,5 +73,6 @@ select = ["E", "F", "B", "UP"] "__init__.py" = ["F401"] "src/emevo/reward_fn.py" = ["B023"] # For typer -"experiments/**/*.py" = ["B008", "UP006", "UP007"] -"smoke-tests/*.py" = ["B008", "UP006", "UP007"] \ No newline at end of file +"experiments/*.py" = ["B008", "UP006", "UP007"] +"smoke-tests/*.py" = ["B008", "UP006", "UP007"] +"scripts/*.py" = ["UP006", "UP035"] \ No newline at end of file diff --git a/scripts/make_web_data.py b/scripts/make_web_data.py index 12c26c4..479cbc0 100644 --- a/scripts/make_web_data.py +++ b/scripts/make_web_data.py @@ -7,11 +7,8 @@ import numpy as np import polars as pl import typer -from numpy.typing import NDArray -from serde import toml from emevo.analysis.log_plotting import load_log -from emevo.exp_utils import CfConfig, SavedPhysicsState PROJECT_ROOT = Path(__file__).parent.parent @@ -106,7 +103,7 @@ def _agg_df( def main( profile_and_rewards_path: Path, - starting_points: List[int] = [], + starting_points: List[int], write_dir: Optional[Path] = None, length: int = 100, ) -> None: @@ -127,9 +124,12 @@ def main( ldfi, index * 1024000, ) - cxy_df.write_parquet(write_dir / f"saved_cpos-{point}.parqut", compression="snappy") - sxy_df.write_parquet(write_dir / f"saved_spos-{point}.parqut", compression="snappy") - + cxy_df.write_parquet( + write_dir / f"saved_cpos-{point}.parqut", compression="snappy" + ) + sxy_df.write_parquet( + write_dir / f"saved_spos-{point}.parqut", compression="snappy" + ) if __name__ == "__main__": diff --git a/src/emevo/exp_utils.py b/src/emevo/exp_utils.py index 67cd081..148f3a5 100644 --- a/src/emevo/exp_utils.py +++ b/src/emevo/exp_utils.py @@ -134,7 +134,7 @@ class GopsConfig: path: str init_std: float init_mean: float - params: dict[str, float | dict[str, float]] + params: dict[str, float | dict[str, Any]] init_kwargs: dict[str, float] = dataclasses.field(default_factory=dict) def load_model(self) -> gops.Mutation | gops.Crossover: