Skip to content

Commit

Permalink
Fix ruff errors
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Dec 13, 2023
1 parent 2ea28e9 commit 0933e56
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 6 deletions.
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@ profile = "black"

[tool.ruff]
line-length = 88
# ignore = ["UP035"]
select = ["E", "F", "B", "UP"]

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
# For pyserde
"src/emevo/exp_utils.py" = ["UP006", "UP035"]
# For typer
"experiments/**/*.py" = ["B008", "UP006", "UP007"]
"smoke-tests/*.py" = ["B008", "UP006", "UP007"]
4 changes: 3 additions & 1 deletion smoke-tests/circle_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def main(

if replace and i % replace_interval == 0:
if i < steps // 2:
flag = jnp.zeros(n_max_agents, dtype=bool).at[deactivate_index].set(True)
flag = (
jnp.zeros(n_max_agents, dtype=bool).at[deactivate_index].set(True)
)
state = env.deactivate(state, flag)
deactivate_index -= 1
else:
Expand Down
2 changes: 1 addition & 1 deletion src/emevo/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
""" Implementation of registry and built-in emevo environments.
"""

from emevo.environments.registry import description, make, register
from emevo.environments.circle_foraging import CircleForaging
from emevo.environments.registry import description, make, register

register(
"CircleForaging-v0",
Expand Down
31 changes: 30 additions & 1 deletion src/emevo/environments/moderngl_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,33 @@ def _collect_heads(circle: Circle, state: State) -> NDArray:
return np.concatenate((p1, p2), axis=1).reshape(-1, 2)


# def _collect_policies(
# circle: Circle,
# state: State,
# max_arrow_length: float,
# ) -> NDArray:
# max_f = max(map(lambda bp: bp[1].max(), bodies_and_policies))
# policy_scaling = max_arrow_length / max_f
# points = []
# radius = None
# for body, policy in bodies_and_policies:
# a = body.position
# if radius is None:
# radius = next(
# filter(lambda shape: isinstance(shape, pymunk.Circle), body.shapes)
# ).radius
# f1, f2 = policy
# from1 = a + pymunk.Vec2d(0, radius).rotated(body.angle + np.pi * 0.75)
# to1 = from1 + pymunk.Vec2d(0, -f1 * policy_scaling).rotated(body.angle)
# from2 = a + pymunk.Vec2d(0, radius).rotated(body.angle - np.pi * 0.75)
# to2 = from2 + pymunk.Vec2d(0, -f2 * policy_scaling).rotated(body.angle)
# points.append(from1)
# points.append(to1)
# points.append(from2)
# points.append(to2)
# return np.array(points, dtype=np.float32)


def _get_clip_ranges(lengthes: list[float]) -> list[tuple[float, float]]:
"""Clip ranges to [-1, 1]"""
total = sum(lengthes)
Expand Down Expand Up @@ -435,7 +462,9 @@ def overlay(self, name: str, value: Any) -> Any:
"""Render additional value as an overlay"""
key = name.lower()
if key == "arrow":
segments = _collect_policies(value, self._range_min * 0.1)
# Not implmented yet
# segments = _collect_policies(value, self._range_min * 0.1)
segments = np.zeros(1)
if "arrow" in self._overlays:
do_render = self._overlays["arrow"].update(segments)
else:
Expand Down
2 changes: 1 addition & 1 deletion src/emevo/exp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class CfConfig:
n_agent_sensors: int
sensor_length: float
food_loc_fn: str
food_num_fn: Tuple["str", int]
food_num_fn: Tuple[str, int]
xlim: Tuple[float, float]
ylim: Tuple[float, float]
env_radius: float
Expand Down
2 changes: 1 addition & 1 deletion src/emevo/rl/ppo_normal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import cast, NamedTuple
from typing import NamedTuple, cast

import chex
import distrax
Expand Down
1 change: 0 additions & 1 deletion tests/test_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest

from emevo.rl.ppo_normal import (
Batch,
NormalPPONet,
Rollout,
get_minibatches,
Expand Down

0 comments on commit 0933e56

Please sign in to comment.