Skip to content

Commit

Permalink
Merge pull request #6 from oist/multi-food-regen
Browse files Browse the repository at this point in the history
Enable to regerenerate multiple foods at the same time
  • Loading branch information
kngwyu authored Jun 5, 2024
2 parents e414dc7 + d68da7f commit e520de3
Show file tree
Hide file tree
Showing 25 changed files with 442 additions and 367 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ jobs:
- {name: Linux, python: '3.11', os: ubuntu-latest, nox: "tests"}
- {name: Windows, python: '3.10', os: windows-latest, nox: "tests"}
- {name: Mac, python: '3.10', os: macos-latest, nox: "tests"}
- {name: Py312, python: '3.12', os: ubuntu-latest, nox: "tests"}
- {name: Py310, python: '3.10', os: ubuntu-latest, nox: "tests"}
- {name: Py39, python: '3.9', os: ubuntu-latest, nox: "tests"}
- {name: Lint, python: '3.11', os: ubuntu-latest, nox: "lint"}
steps:
- uses: actions/checkout@v3
Expand Down
55 changes: 0 additions & 55 deletions benches/test_step.py

This file was deleted.

30 changes: 30 additions & 0 deletions config/env/20240524-const-uniform.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
n_initial_agents = 50
n_max_agents = 200
n_max_foods = 50
food_num_fn = ["constant", 40]
food_loc_fn = "uniform"
agent_loc_fn = "uniform"
xlim = [0.0, 480.0]
ylim = [0.0, 360.0]
env_shape = "square"
neighbor_stddev = 100.0
n_agent_sensors = 24
sensor_length = 200.0
sensor_range = "wide"
agent_radius = 10.0
food_radius = 4.0
foodloc_interval = 1000
dt = 0.1
linear_damping = 0.8
angular_damping = 0.6
max_force = 80.0
min_force = -20.0
init_energy = 40.0
energy_capacity = 400.0
force_energy_consumption = 2e-5
basic_energy_consumption = 1e-3
energy_share_ratio = 0.4
n_velocity_iter = 6
n_position_iter = 2
n_physics_iter = 5
max_place_attempts = 10
30 changes: 30 additions & 0 deletions config/env/20240524-uniform-square.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
n_initial_agents = 50
n_max_agents = 200
n_max_foods = 50
food_num_fn = ["linear", 20, 0.1, 60]
food_loc_fn = "uniform"
agent_loc_fn = "uniform"
xlim = [0.0, 480.0]
ylim = [0.0, 480.0]
env_shape = "square"
neighbor_stddev = 100.0
n_agent_sensors = 24
sensor_length = 200.0
sensor_range = "wide"
agent_radius = 10.0
food_radius = 4.0
foodloc_interval = 1000
dt = 0.1
linear_damping = 0.8
angular_damping = 0.6
max_force = 80.0
min_force = -20.0
init_energy = 40.0
energy_capacity = 400.0
force_energy_consumption = 2e-5
basic_energy_consumption = 1e-3
energy_share_ratio = 0.4
n_velocity_iter = 6
n_position_iter = 2
n_physics_iter = 5
max_place_attempts = 10
30 changes: 30 additions & 0 deletions config/env/20240524-uniform.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
n_initial_agents = 50
n_max_agents = 200
n_max_foods = 50
food_num_fn = ["linear", 20, 0.1, 60]
food_loc_fn = "uniform"
agent_loc_fn = "uniform"
xlim = [0.0, 480.0]
ylim = [0.0, 360.0]
env_shape = "square"
neighbor_stddev = 100.0
n_agent_sensors = 24
sensor_length = 200.0
sensor_range = "wide"
agent_radius = 10.0
food_radius = 4.0
foodloc_interval = 1000
dt = 0.1
linear_damping = 0.8
angular_damping = 0.6
max_force = 80.0
min_force = -20.0
init_energy = 40.0
energy_capacity = 400.0
force_energy_consumption = 2e-5
basic_energy_consumption = 1e-3
energy_share_ratio = 0.4
n_velocity_iter = 6
n_position_iter = 2
n_physics_iter = 5
max_place_attempts = 10
2 changes: 1 addition & 1 deletion experiments/cf_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def replace_net(
) -> tuple[NormalPPONet, optax.OptState]:
initialized = initialize_net(key)
pponet = eqx_where(flag, initialized, pponet)
opt_state = jax.tree_map(
opt_state = jax.tree_util.tree_map(
lambda a, b: jnp.where(
jnp.expand_dims(flag, tuple(range(1, a.ndim))),
b,
Expand Down
4 changes: 2 additions & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def format(session: nox.Session) -> None:
session.run("isort", *SOURCES)


@nox.session(reuse_venv=True, python=["3.9", "3.10", "3.11"])
@nox.session(reuse_venv=True, python=["3.10", "3.11", "3.12"])
def lint(session: nox.Session) -> None:
_sync(session, "requirements/lint.txt")
session.run("ruff", "check", *SOURCES)
Expand Down Expand Up @@ -110,7 +110,7 @@ def smoke(session: nox.Session) -> None:
session.run("python", DEFAULT, *session.posargs)


@nox.session(reuse_venv=True, python=["3.9", "3.10", "3.11"])
@nox.session(reuse_venv=True, python=["3.10", "3.11", "3.12"])
def tests(session: nox.Session) -> None:
_sync(session, "requirements/tests.txt")
session.run("pytest", "tests", *session.posargs)
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ dependencies = [
"moderngl >= 5.6",
"moderngl-window >= 2.4",
"jax >= 0.4",
"pyarrow >= 8.0",
"pyserde[toml] == 0.13.2", # TODO: update
"pyarrow >= 9.0",
"pyserde[toml] >= 0.14",
"optax >= 0.1",
]
dynamic = ["version"]
Expand Down Expand Up @@ -71,8 +71,6 @@ select = ["E", "F", "B", "UP"]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
"src/emevo/reward_fn.py" = ["B023"]
# For pyserde
"src/emevo/exp_utils.py" = ["UP006", "UP007", "UP035"]
# For typer
"experiments/**/*.py" = ["B008", "UP006", "UP007"]
"smoke-tests/*.py" = ["B008", "UP006", "UP007"]
2 changes: 1 addition & 1 deletion smoke-tests/circle_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

def weight_summary(network):
params, _ = eqx.partition(network, eqx.is_inexact_array)
params_mean = jax.tree_map(jnp.mean, params)
params_mean = jax.tree_util.tree_map(jnp.mean, params)
for k, v in jax.tree_util.tree_leaves_with_path(params_mean):
print(k, v)

Expand Down
1 change: 1 addition & 0 deletions src/emevo/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" Implementation of registry and built-in emevo environments.
"""

from emevo.environments.registry import register

register(
Expand Down
Loading

0 comments on commit e520de3

Please sign in to comment.