Skip to content

Commit

Permalink
test_config
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Dec 18, 2023
1 parent 1614885 commit 6dc9a40
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 22 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
**/build/
**/*.egg-info/
**/.mypy_cache/
.virtual_documents/
.ipynb_checkpoints/
requirements/*.txt
# This should be local
pyrightconfig.json
pyrightconfig.json
*.eqx
2 changes: 1 addition & 1 deletion config/env/20231214-square.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ max_force = 40.0
min_force = -20.0
init_energy = 20.0
energy_capacity = 100.0
force_energy_consumption = 0.01 / 40.0
force_energy_consumption = 0.00025
energy_share_ratio = 0.4
n_velocity_iter = 6
n_position_iter = 2
Expand Down
23 changes: 17 additions & 6 deletions experiments/cf_asexual_evo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Example of using circle foraging environment"""
import dataclasses
import enum
from pathlib import Path
from typing import Literal
from typing import Protocol

import chex
import equinox as eqx
Expand All @@ -11,6 +12,7 @@
import optax
import typer
from fastavro import parse_schema, writer
from jax._src.numpy.lax_numpy import Protocol
from serde import toml

from emevo import Env
Expand All @@ -33,6 +35,11 @@
N_MAX_AGENTS: int = 10


class RewardFn(Protocol):
def __call__(self, collision: jax.Array, action: jax.Array) -> jax.Array:
...


class LinearReward(eqx.Module):
weight: jax.Array
max_action_norm: float
Expand All @@ -46,6 +53,11 @@ def __call__(self, collision: jax.Array, action: jax.Array) -> jax.Array:
return jax.vmap(jnp.dot)(input_, self.weight)


class RewardKind(str, enum.Enum):
LINEAR = "linear"
SIGMOID = "sigmoid"


def visualize(
key: chex.PRNGKey,
env: Env,
Expand Down Expand Up @@ -243,7 +255,6 @@ def run_evolution(
if visualizer is not None:
visualizer.render(env_state)
visualizer.show()
print(f"Rewards: {[x.item() for x in ri[: n_agents]]}")
# weight_summary(pponet)
print(f"Sum of rewards {[x.item() for x in rewards[: n_agents]]}")
return pponet
Expand All @@ -267,7 +278,7 @@ def evolve(
n_total_steps: int = 1024 * 1000,
cfconfig_path: Path = here.joinpath("../config/env/20231214-square.toml"),
bdconfig_path: Path = here.joinpath("../config/bd/20230530-a035-e020.toml"),
reward_fn: Literal["linear", "sigmoid"] = "linear",
reward_fn: RewardKind = RewardKind.LINEAR,
logdir: Path = Path("./log"),
debug_vis: bool = False,
) -> None:
Expand All @@ -277,13 +288,13 @@ def evolve(
bdconfig = toml.from_toml(BDConfig, f.read())

# Override config
cfconfig.n_agents = n_agents
cfconfig.n_initial_agents = n_agents
env = make("CircleForaging-v0", **dataclasses.asdict(cfconfig))
birth_fn, hazard_fn = bdconfig.load_models()
key, reward_key = jax.random.split(jax.random.PRNGKey(seed))
if reward_fn == "linear":
if reward_fn == RewardKind.LINEAR:
reward_fn_instance = LinearReward(reward_key, cfconfig.n_max_agents)
elif reward_fn == "sigmoid":
elif reward_fn == RewardKind.SIGMOID:
assert False, "Unimplemented"
else:
raise ValueError(f"Invalid reward_fn {reward_fn}")
Expand Down
44 changes: 30 additions & 14 deletions src/emevo/exp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,51 @@

import dataclasses
import importlib
from typing import Dict, List, Tuple, Type
from typing import Dict, Tuple, Type, Union

import chex
import fastavro
import jax
import serde

from emevo import birth_and_death as bd
from emevo.environments.circle_foraging import SensorRange


@serde.serde
@dataclasses.dataclass
class CfConfig:
agent_radius: float
n_agents: int
n_agent_sensors: int
sensor_length: float
food_loc_fn: str
food_num_fn: Tuple[str, int]
xlim: Tuple[float, float]
ylim: Tuple[float, float]
env_radius: float
env_shape: str
obstacles: List[Tuple[float, float, float, float]]
seed: int
n_initial_agents: int = 6
n_max_agents: int = 100
n_max_foods: int = 40
food_num_fn: Union[str, Tuple[str, ...]] = "constant"
food_loc_fn: Union[str, Tuple[str, ...]] = "gaussian"
agent_loc_fn: Union[str, Tuple[str, ...]] = "uniform"
xlim: Tuple[float, float] = (0.0, 200.0)
ylim: Tuple[float, float] = (0.0, 200.0)
env_radius: float = 120.0
env_shape: str = "square"
obstacles: str = "none"
newborn_loc: str = "neighbor"
neighbor_stddev: float = 40.0
n_agent_sensors: int = 16
sensor_length: float = 100.0
sensor_range: SensorRange = SensorRange.WIDE
agent_radius: float = 10.0
food_radius: float = 4.0
foodloc_interval: int = 1000
dt: float = 0.1
linear_damping: float = 0.8
angular_damping: float = 0.6
max_force: float = 40.0
min_force: float = -20.0
init_energy: float = 20.0
energy_capacity: float = 100.0
force_energy_consumption: float = 0.01 / 40.0
energy_share_ratio: float = 0.4
n_velocity_iter: int = 6
n_position_iter: int = 2
n_physics_iter: int = 5
max_place_attempts: int = 10


def _load_cls(cls_path: str) -> Type:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from serde import toml

from emevo.exp_utils import CfConfig


def test_cfconfig() -> None:
with open("config/env/20231214-square.toml", "r") as f:
cfconfig = toml.from_toml(CfConfig, f.read())

assert cfconfig.sensor_range == "wide"

0 comments on commit 6dc9a40

Please sign in to comment.