Skip to content

Commit

Permalink
Fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Jul 2, 2024
1 parent 3d220f2 commit 823c1d8
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 18 deletions.
6 changes: 1 addition & 5 deletions experiments/cf_nolearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,7 @@
SavedProfile,
is_cuda_ready,
)
from emevo.rl.ppo_normal import (
NormalPPONet,
vmap_apply,
vmap_net,
)
from emevo.rl.ppo_normal import NormalPPONet, vmap_apply, vmap_net
from emevo.spaces import BoxSpace

PROJECT_ROOT = Path(__file__).parent.parent
Expand Down
5 changes: 3 additions & 2 deletions experiments/cf_simple.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Asexual reward evolution with Circle Foraging"""

import dataclasses
import json
from pathlib import Path
Expand Down Expand Up @@ -56,11 +57,11 @@ class RewardExtractor:
_max_norm: jax.Array = dataclasses.field(init=False)

def __post_init__(self) -> None:
self._max_norm = jnp.sqrt(jnp.sum(self.act_space.high ** 2, axis=-1))
self._max_norm = jnp.sqrt(jnp.sum(self.act_space.high**2, axis=-1))

def normalize_action(self, action: jax.Array) -> jax.Array:
scaled = self.act_space.sigmoid_scale(action)
norm = jnp.sqrt(jnp.sum(scaled ** 2, axis=-1, keepdims=True))
norm = jnp.sqrt(jnp.sum(scaled**2, axis=-1, keepdims=True))
return norm / self._max_norm

def extract(
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import nox

SOURCES = ["src/emevo", "tests", "smoke-tests", "experiments"]
SOURCES = ["src/emevo", "tests", "smoke-tests", "experiments", "scripts"]


def _sync(session: nox.Session, requirements: str) -> None:
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,6 @@ select = ["E", "F", "B", "UP"]
"__init__.py" = ["F401"]
"src/emevo/reward_fn.py" = ["B023"]
# For typer
"experiments/**/*.py" = ["B008", "UP006", "UP007"]
"smoke-tests/*.py" = ["B008", "UP006", "UP007"]
"experiments/*.py" = ["B008", "UP006", "UP007"]
"smoke-tests/*.py" = ["B008", "UP006", "UP007"]
"scripts/*.py" = ["UP006", "UP035"]
14 changes: 7 additions & 7 deletions scripts/make_web_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@
import numpy as np
import polars as pl
import typer
from numpy.typing import NDArray
from serde import toml

from emevo.analysis.log_plotting import load_log
from emevo.exp_utils import CfConfig, SavedPhysicsState

PROJECT_ROOT = Path(__file__).parent.parent

Expand Down Expand Up @@ -106,7 +103,7 @@ def _agg_df(

def main(
profile_and_rewards_path: Path,
starting_points: List[int] = [],
starting_points: List[int],
write_dir: Optional[Path] = None,
length: int = 100,
) -> None:
Expand All @@ -127,9 +124,12 @@ def main(
ldfi,
index * 1024000,
)
cxy_df.write_parquet(write_dir / f"saved_cpos-{point}.parqut", compression="snappy")
sxy_df.write_parquet(write_dir / f"saved_spos-{point}.parqut", compression="snappy")

cxy_df.write_parquet(
write_dir / f"saved_cpos-{point}.parqut", compression="snappy"
)
sxy_df.write_parquet(
write_dir / f"saved_spos-{point}.parqut", compression="snappy"
)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion src/emevo/exp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class GopsConfig:
path: str
init_std: float
init_mean: float
params: dict[str, float | dict[str, float]]
params: dict[str, float | dict[str, Any]]
init_kwargs: dict[str, float] = dataclasses.field(default_factory=dict)

def load_model(self) -> gops.Mutation | gops.Crossover:
Expand Down

0 comments on commit 823c1d8

Please sign in to comment.