From 6dc9a40aaf229558404d2243c3a1429ace5739d6 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Mon, 18 Dec 2023 14:40:12 +0900 Subject: [PATCH] test_config --- .gitignore | 4 ++- config/env/20231214-square.toml | 2 +- experiments/cf_asexual_evo.py | 23 ++++++++++++----- src/emevo/exp_utils.py | 44 ++++++++++++++++++++++----------- tests/test_config.py | 10 ++++++++ 5 files changed, 61 insertions(+), 22 deletions(-) create mode 100644 tests/test_config.py diff --git a/.gitignore b/.gitignore index d3da7d66..0e00be34 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,9 @@ **/build/ **/*.egg-info/ **/.mypy_cache/ +.virtual_documents/ .ipynb_checkpoints/ requirements/*.txt # This should be local -pyrightconfig.json \ No newline at end of file +pyrightconfig.json +*.eqx \ No newline at end of file diff --git a/config/env/20231214-square.toml b/config/env/20231214-square.toml index 12b31c5d..2f7cedad 100644 --- a/config/env/20231214-square.toml +++ b/config/env/20231214-square.toml @@ -22,7 +22,7 @@ max_force = 40.0 min_force = -20.0 init_energy = 20.0 energy_capacity = 100.0 -force_energy_consumption = 0.01 / 40.0 +force_energy_consumption = 0.00025 energy_share_ratio = 0.4 n_velocity_iter = 6 n_position_iter = 2 diff --git a/experiments/cf_asexual_evo.py b/experiments/cf_asexual_evo.py index 1308ef2a..3a48a697 100644 --- a/experiments/cf_asexual_evo.py +++ b/experiments/cf_asexual_evo.py @@ -1,7 +1,8 @@ """Example of using circle foraging environment""" import dataclasses +import enum from pathlib import Path -from typing import Literal +from typing import Protocol import chex import equinox as eqx @@ -11,6 +12,7 @@ import optax import typer from fastavro import parse_schema, writer +from jax._src.numpy.lax_numpy import Protocol from serde import toml from emevo import Env @@ -33,6 +35,11 @@ N_MAX_AGENTS: int = 10 +class RewardFn(Protocol): + def __call__(self, collision: jax.Array, action: jax.Array) -> jax.Array: + ... + + class LinearReward(eqx.Module): weight: jax.Array max_action_norm: float @@ -46,6 +53,11 @@ def __call__(self, collision: jax.Array, action: jax.Array) -> jax.Array: return jax.vmap(jnp.dot)(input_, self.weight) +class RewardKind(str, enum.Enum): + LINEAR = "linear" + SIGMOID = "sigmoid" + + def visualize( key: chex.PRNGKey, env: Env, @@ -243,7 +255,6 @@ def run_evolution( if visualizer is not None: visualizer.render(env_state) visualizer.show() - print(f"Rewards: {[x.item() for x in ri[: n_agents]]}") # weight_summary(pponet) print(f"Sum of rewards {[x.item() for x in rewards[: n_agents]]}") return pponet @@ -267,7 +278,7 @@ def evolve( n_total_steps: int = 1024 * 1000, cfconfig_path: Path = here.joinpath("../config/env/20231214-square.toml"), bdconfig_path: Path = here.joinpath("../config/bd/20230530-a035-e020.toml"), - reward_fn: Literal["linear", "sigmoid"] = "linear", + reward_fn: RewardKind = RewardKind.LINEAR, logdir: Path = Path("./log"), debug_vis: bool = False, ) -> None: @@ -277,13 +288,13 @@ def evolve( bdconfig = toml.from_toml(BDConfig, f.read()) # Override config - cfconfig.n_agents = n_agents + cfconfig.n_initial_agents = n_agents env = make("CircleForaging-v0", **dataclasses.asdict(cfconfig)) birth_fn, hazard_fn = bdconfig.load_models() key, reward_key = jax.random.split(jax.random.PRNGKey(seed)) - if reward_fn == "linear": + if reward_fn == RewardKind.LINEAR: reward_fn_instance = LinearReward(reward_key, cfconfig.n_max_agents) - elif reward_fn == "sigmoid": + elif reward_fn == RewardKind.SIGMOID: assert False, "Unimplemented" else: raise ValueError(f"Invalid reward_fn {reward_fn}") diff --git a/src/emevo/exp_utils.py b/src/emevo/exp_utils.py index afe0c028..670a0d9e 100644 --- a/src/emevo/exp_utils.py +++ b/src/emevo/exp_utils.py @@ -3,35 +3,51 @@ import dataclasses import importlib -from typing import Dict, List, Tuple, Type +from typing import Dict, Tuple, Type, Union import chex -import fastavro import jax import serde from emevo import birth_and_death as bd +from emevo.environments.circle_foraging import SensorRange @serde.serde @dataclasses.dataclass class CfConfig: - agent_radius: float - n_agents: int - n_agent_sensors: int - sensor_length: float - food_loc_fn: str - food_num_fn: Tuple[str, int] - xlim: Tuple[float, float] - ylim: Tuple[float, float] - env_radius: float - env_shape: str - obstacles: List[Tuple[float, float, float, float]] - seed: int + n_initial_agents: int = 6 + n_max_agents: int = 100 + n_max_foods: int = 40 + food_num_fn: Union[str, Tuple[str, ...]] = "constant" + food_loc_fn: Union[str, Tuple[str, ...]] = "gaussian" + agent_loc_fn: Union[str, Tuple[str, ...]] = "uniform" + xlim: Tuple[float, float] = (0.0, 200.0) + ylim: Tuple[float, float] = (0.0, 200.0) + env_radius: float = 120.0 + env_shape: str = "square" + obstacles: str = "none" + newborn_loc: str = "neighbor" + neighbor_stddev: float = 40.0 + n_agent_sensors: int = 16 + sensor_length: float = 100.0 + sensor_range: SensorRange = SensorRange.WIDE + agent_radius: float = 10.0 + food_radius: float = 4.0 + foodloc_interval: int = 1000 + dt: float = 0.1 linear_damping: float = 0.8 angular_damping: float = 0.6 max_force: float = 40.0 min_force: float = -20.0 + init_energy: float = 20.0 + energy_capacity: float = 100.0 + force_energy_consumption: float = 0.01 / 40.0 + energy_share_ratio: float = 0.4 + n_velocity_iter: int = 6 + n_position_iter: int = 2 + n_physics_iter: int = 5 + max_place_attempts: int = 10 def _load_cls(cls_path: str) -> Type: diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 00000000..d9654d98 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,10 @@ +from serde import toml + +from emevo.exp_utils import CfConfig + + +def test_cfconfig() -> None: + with open("config/env/20231214-square.toml", "r") as f: + cfconfig = toml.from_toml(CfConfig, f.read()) + + assert cfconfig.sensor_range == "wide"