Move position masking to moderngl_vis
kngwyu committed Dec 18, 2023
1 parent 6dc9a40 commit 1113880
Showing 7 changed files with 56 additions and 35 deletions.
4 changes: 2 additions & 2 deletions config/env/20231214-square.toml
@@ -1,7 +1,7 @@
n_initial_agents = 20
n_max_agents = 100
-n_max_foods = 20
-food_num_fn = ["logistic", 0.01, 20]
+n_max_foods = 40
+food_num_fn = ["logistic", 20, 0.01, 40]
food_loc_fn = "gaussian"
agent_loc_fn = "uniform"
xlim = [0.0, 360.0]
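The notable change here is that the logistic food regrowth now receives an explicit initial count (20) alongside the doubled capacity (40). As a minimal sketch of logistic growth — assuming the list encodes (initial, growth rate, capacity); the parameter names below are hypothetical, not emevo's API:

```python
# Hypothetical reading of food_num_fn = ["logistic", 20, 0.01, 40]:
# start at 20 foods, growth rate 0.01, carrying capacity 40.
def logistic_increment(n_foods: float, growth_rate: float = 0.01, capacity: float = 40.0) -> float:
    # Classic logistic growth dN = r * N * (1 - N / K):
    # regrowth slows to zero as the food count approaches capacity.
    return growth_rate * n_foods * (1.0 - n_foods / capacity)
```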
15 changes: 5 additions & 10 deletions experiments/cf_asexual_evo.py
@@ -32,8 +32,6 @@
)
from emevo.visualizer import SaveVideoWrapper

-N_MAX_AGENTS: int = 10
-

class RewardFn(Protocol):
def __call__(self, collision: jax.Array, action: jax.Array) -> jax.Array:
@@ -42,7 +40,6 @@ def __call__(self, collision: jax.Array, action: jax.Array) -> jax.Array:

class LinearReward(eqx.Module):
weight: jax.Array
-max_action_norm: float

def __init__(self, key: chex.PRNGKey, n_agents: int) -> None:
self.weight = jax.random.normal(key, (n_agents, 4))
@@ -126,7 +123,7 @@ def step_rollout(
state_t1d = env.deactivate(state_t1, dead)
birth_prob = birth_fn(state_t1d.status.age, state_t1d.status.energy)
possible_parents = jnp.logical_and(
-jnp.logical_not(dead),
+jnp.logical_and(jnp.logical_not(dead), state.profile.is_active()),
jax.random.bernoulli(birth_key, p=birth_prob),
)
state_t1db, parents = env.activate(state_t1d, possible_parents)
@@ -168,7 +165,7 @@ def epoch(
minibatch_size: int,
n_optim_epochs: int,
) -> tuple[State, Obs, Log, optax.OptState, NormalPPONet]:
-keys = jax.random.split(prng_key, N_MAX_AGENTS + 1)
+keys = jax.random.split(prng_key, env.n_max_agents + 1)
env_state, rollout, log, obs, next_value = exec_rollout(
state,
initial_obs,
@@ -197,7 +194,6 @@

def run_evolution(
key: jax.Array,
-n_agents: int,
env: Env,
adam: optax.GradientTransformation,
gamma: float,
@@ -220,15 +216,15 @@
input_size,
64,
act_size,
-jax.random.split(net_key, N_MAX_AGENTS),
+jax.random.split(net_key, env.n_max_agents),
)
adam_init, adam_update = adam
opt_state = jax.vmap(adam_init)(eqx.filter(pponet, eqx.is_array))
env_state, timestep = env.reset(reset_key)
obs = timestep.obs

n_loop = n_total_steps // n_rollout_steps
-rewards = jnp.zeros(N_MAX_AGENTS)
+rewards = jnp.zeros(env.n_max_agents)
keys = jax.random.split(key, n_loop)
if debug_vis:
visualizer = env.visualizer(env_state, figsize=(640.0, 640.0))
@@ -261,7 +257,7 @@


app = typer.Typer(pretty_exceptions_show_locals=False)
-here = Path(__file__)
+here = Path(__file__).parent


@app.command()
@@ -300,7 +296,6 @@ def evolve(
raise ValueError(f"Invalid reward_fn {reward_fn}")
network = run_evolution(
key,
-n_agents,
env,
optax.adam(adam_lr, eps=adam_eps),
gamma,
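The behavioral fix in step_rollout above is the added `state.profile.is_active()` term: without it, deactivated slots could be drawn as parents by the Bernoulli sample. A self-contained sketch of that masked selection (array names here are illustrative, not the script's):

```python
import jax
import jax.numpy as jnp

def select_parents(key: jax.Array, dead: jax.Array, is_active: jax.Array, birth_prob: jax.Array) -> jax.Array:
    # A slot qualifies as a parent only if the agent survived this step
    # AND the slot is actually occupied; the Bernoulli draw then applies
    # the per-agent birth probability.
    eligible = jnp.logical_and(jnp.logical_not(dead), is_active)
    return jnp.logical_and(eligible, jax.random.bernoulli(key, p=birth_prob))

key = jax.random.PRNGKey(0)
print(select_parents(key, jnp.array([False, False, True]), jnp.array([True, False, True]), jnp.full(3, 0.5)))
```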
12 changes: 6 additions & 6 deletions src/emevo/birth_and_death.py
@@ -22,7 +22,7 @@ def survival(self, age: jax.Array, energy: jax.Array) -> jax.Array:
return jnp.exp(-self.cumulative(age, energy))


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
class DeterministicHazard(HazardFunction):
"""
A deterministic hazard function where an agent dies when
@@ -52,7 +52,7 @@ def cumulative(self, age: jax.Array, energy: jax.Array) -> jax.Array:
)


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
class ConstantHazard(HazardFunction):
"""
Hazard with constant death rate.
@@ -79,7 +79,7 @@ def cumulative(self, age: jax.Array, energy: jax.Array) -> jax.Array:
return self(age, energy) * age


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
class EnergyLogisticHazard(HazardFunction):
"""
Hazard with death rate that only depends on energy.
@@ -101,7 +101,7 @@ def cumulative(self, age: jax.Array, energy: jax.Array) -> jax.Array:
return self._energy_death_rate(energy) * age


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
class GompertzHazard(ConstantHazard):
"""
Hazard with exponentially increasing death rate.
@@ -123,7 +123,7 @@ def cumulative(self, age: jax.Array, energy: jax.Array) -> jax.Array:
return ht - h0


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
class ELGompertzHazard(EnergyLogisticHazard):
"""
Exponentially increasing with time + EnergyLogistic
@@ -158,7 +158,7 @@ def cumulative(self, age: jax.Array, energy: jax.Array) -> jax.Array:
...


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
class EnergyLogisticBirth(BirthFunction):
"""
Only energy is important to give birth.
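A plausible motivation for `frozen=True` (the commit doesn't say): frozen dataclasses are immutable and hashable, so the hazard and birth functions behave as safe, shareable configuration values. A minimal sketch of the difference:

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class HazardSketch:  # stand-in for the hazard classes above
    alpha: float = 1e-5

h = HazardSketch()
assert hash(h) == hash(HazardSketch())  # frozen (with default eq) makes instances hashable
try:
    h.alpha = 0.1  # type: ignore[misc]
except dataclasses.FrozenInstanceError:
    pass  # in-place mutation is rejected at runtime
```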
1 change: 1 addition & 0 deletions src/emevo/env.py
@@ -130,6 +130,7 @@ class Env(abc.ABC, Generic[STATE, OBS]):

act_space: Space
obs_space: Space
+n_max_agents: int

def __init__(self, *args, **kwargs) -> None:
# To supress PyRight errors in registry
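Declaring `n_max_agents` on the abstract `Env` is what lets the experiment script drop its module-level `N_MAX_AGENTS` constant and size buffers from the environment instance instead, e.g.:

```python
# Sketch of the pattern enabled by this one-line change; EnvSketch is a
# stand-in, not emevo's actual Env.
import abc
import jax.numpy as jnp

class EnvSketch(abc.ABC):
    n_max_agents: int  # capacity lives on the env, not in a global constant

def init_rewards(env: EnvSketch):
    # mirrors `rewards = jnp.zeros(env.n_max_agents)` in cf_asexual_evo.py
    return jnp.zeros(env.n_max_agents)
```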
24 changes: 12 additions & 12 deletions src/emevo/environments/circle_foraging.py
@@ -59,7 +59,7 @@
MAX_VELOCITY: float = 10.0
AGENT_COLOR: Color = Color(2, 204, 254)
FOOD_COLOR: Color = Color(254, 2, 162)
-NOWHERE: float = -100.0
+NOWHERE: float = 0.0
N_OBJECTS: int = 3


@@ -368,7 +368,7 @@ def __init__(
assert n_max_agents > n_initial_agents
assert n_max_foods > self._food_num_fn.initial
self._n_initial_agents = n_initial_agents
-self._n_max_agents = n_max_agents
+self.n_max_agents = n_max_agents
self._n_initial_foods = self._food_num_fn.initial
self._n_max_foods = n_max_foods
self._max_place_attempts = max_place_attempts
@@ -407,11 +407,10 @@ def __init__(
# Obs
self._n_sensors = n_agent_sensors
# Some cached constants
-self._invisible_xy = jnp.ones(2) * NOWHERE
act_p1 = Vec2d(0, agent_radius).rotated(np.pi * 0.75)
act_p2 = Vec2d(0, agent_radius).rotated(-np.pi * 0.75)
-self._act_p1 = jnp.tile(jnp.array(act_p1), (self._n_max_agents, 1))
-self._act_p2 = jnp.tile(jnp.array(act_p2), (self._n_max_agents, 1))
+self._act_p1 = jnp.tile(jnp.array(act_p1), (self.n_max_agents, 1))
+self._act_p2 = jnp.tile(jnp.array(act_p2), (self.n_max_agents, 1))
self._init_agent = jax.jit(
functools.partial(
place,
@@ -653,7 +652,7 @@ def activate(
is_parent: jax.Array,
) -> tuple[CFState, jax.Array]:
circle = state.physics.circle
-keys = jax.random.split(state.key, self._n_max_agents + 1)
+keys = jax.random.split(state.key, self.n_max_agents + 1)
new_xy, ok = self._place_newborn(
state.agent_loc,
state.physics,
@@ -695,12 +694,13 @@ def activate(
return new_state, parents_id

def deactivate(self, state: CFState, flag: jax.Array) -> CFState:
-p_xy = state.physics.circle.p.xy.at[flag].set(self._invisible_xy)
+expanded_flag = jnp.expand_dims(flag, axis=1)
+p_xy = jnp.where(expanded_flag, NOWHERE, state.physics.circle.p.xy)
p = replace(state.physics.circle.p, xy=p_xy)
-v_xy = state.physics.circle.v.xy.at[flag].set(0.0)
-v_angle = state.physics.circle.v.angle.at[flag].set(0.0)
+v_xy = jnp.where(expanded_flag, 0.0, state.physics.circle.v.xy)
+v_angle = jnp.where(flag, 0.0, state.physics.circle.v.angle)
v = Velocity(angle=v_angle, xy=v_xy)
-is_active = state.physics.circle.is_active.at[flag].set(False)
+is_active = jnp.where(flag, False, state.physics.circle.is_active)
circle = replace(state.physics.circle, p=p, v=v, is_active=is_active)
physics = replace(state.physics, circle=circle)
profile = state.profile.deactivate(flag)
@@ -709,7 +709,7 @@ def deactivate(self, state: CFState, flag: jax.Array) -> CFState:

def reset(self, key: chex.PRNGKey) -> tuple[CFState, TimeStep[CFObs]]:
physics, agent_loc, food_loc = self._initialize_physics_state(key)
-nmax = self._n_max_agents
+nmax = self.n_max_agents
profile = init_profile(self._n_initial_agents, nmax)
status = init_status(self._n_initial_agents, nmax, self._init_energy)
state = CFState(
@@ -748,7 +748,7 @@ def _initialize_physics_state(
is_active_c = jnp.concatenate(
(
jnp.ones(self._n_initial_agents, dtype=bool),
-jnp.zeros(self._n_max_agents - self._n_initial_agents, dtype=bool),
+jnp.zeros(self.n_max_agents - self._n_initial_agents, dtype=bool),
)
)
is_active_s = jnp.concatenate(
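The rewritten `deactivate` swaps `.at[flag].set(...)` indexed updates for explicit `jnp.where` masking. With a boolean `flag` of shape (N,) and positions of shape (N, 2), the mask must be expanded to (N, 1) so it broadcasts across both coordinates, while 1-D fields like `angle` and `is_active` use the flag directly. A standalone sketch:

```python
import jax.numpy as jnp

flag = jnp.array([True, False, True])          # which agents to deactivate
xy = jnp.arange(6.0).reshape(3, 2)             # (N, 2) positions
expanded_flag = jnp.expand_dims(flag, axis=1)  # (N, 1) broadcasts over x and y
print(jnp.where(expanded_flag, 0.0, xy))       # deactivated rows -> NOWHERE (now 0.0)
print(jnp.where(flag, 0.0, jnp.ones(3)))       # 1-D fields need no expansion
```

Note that `NOWHERE` drops from -100.0 to 0.0 here: the environment no longer needs to park dead agents off-screen, because hiding them is now the visualizer's job.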
23 changes: 19 additions & 4 deletions src/emevo/environments/moderngl_vis.py
@@ -20,6 +20,9 @@ class HasStateD(Protocol):
stated: StateDict


+NOWHERE: float = -1000.0


_CIRCLE_VERTEX_SHADER = """
#version 330
uniform mat4 proj;
@@ -259,7 +262,8 @@ def _collect_circles(
state: State,
circle_scaling: float,
) -> tuple[NDArray, NDArray, NDArray]:
-points = np.array(state.p.xy, dtype=np.float32)
+flag = np.array(state.is_active).reshape(-1, 1)
+points = np.where(flag, np.array(state.p.xy, dtype=np.float32), NOWHERE)
scales = circle.radius * circle_scaling
colors = np.array(circle.rgba, dtype=np.float32) / 255.0
is_active = np.expand_dims(np.array(state.is_active), axis=1)
@@ -271,15 +275,17 @@ def _collect_static_lines(segment: Segment, state: State) -> NDArray:
a, b = segment.point1, segment.point2
a = state.p.transform(a)
b = state.p.transform(b)
-return np.concatenate((a, b), axis=1).reshape(-1, 2)
+flag = np.repeat(np.array(state.is_active), 2).reshape(-1, 1)
+return np.where(flag, np.concatenate((a, b), axis=1).reshape(-1, 2), NOWHERE)


def _collect_heads(circle: Circle, state: State) -> NDArray:
y = jnp.array(circle.radius)
x = jnp.zeros_like(y)
p1, p2 = jnp.stack((x, y * 0.8), axis=1), jnp.stack((x, y * 1.2), axis=1)
p1, p2 = state.p.transform(p1), state.p.transform(p2)
-return np.concatenate((p1, p2), axis=1).reshape(-1, 2)
+flag = np.repeat(np.array(state.is_active), 2).reshape(-1, 1)
+return np.where(flag, np.concatenate((p1, p2), axis=1).reshape(-1, 2), NOWHERE)


# def _collect_policies(
@@ -406,7 +412,16 @@ def collect_sensors(stated: StateDict) -> NDArray:
sensor_fn(stated=stated), # type: ignore
axis=1,
)
-return sensors.reshape(-1, 2).astype(jnp.float32)
+sensors = sensors.reshape(-1, 2).astype(jnp.float32)
+flag = np.repeat(
+    np.array(stated.circle.is_active),
+    sensors.shape[0] // stated.circle.batch_size(),
+)
+return np.where(
+    flag.reshape(-1, 1),
+    sensors,
+    NOWHERE,
+)

self._sensors = SegmentVA(
ctx=context,
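This file is the other half of the commit title: instead of the environment teleporting dead agents away, the visualizer substitutes the off-screen constant NOWHERE = -1000.0 for every inactive object's vertices just before upload. A minimal NumPy sketch of that render-time masking (the function name is illustrative):

```python
import numpy as np

NOWHERE = -1000.0  # same constant as above

def mask_inactive(xy: np.ndarray, is_active: np.ndarray) -> np.ndarray:
    # (N,) boolean -> (N, 1) so it broadcasts over the x and y columns;
    # inactive rows are sent far outside the viewport.
    flag = is_active.reshape(-1, 1)
    return np.where(flag, xy.astype(np.float32), NOWHERE)

print(mask_inactive(np.array([[10.0, 20.0], [30.0, 40.0]]), np.array([True, False])))
# [[   10.    20.]
#  [-1000. -1000.]]
```

For segments (sensor lines, heads), each object contributes two vertices, hence the `np.repeat(is_active, 2)` in the diff above.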
12 changes: 11 additions & 1 deletion tests/test_config.py
@@ -1,6 +1,16 @@
from serde import toml

-from emevo.exp_utils import CfConfig
+from emevo import birth_and_death as bd
+from emevo.exp_utils import BDConfig, CfConfig


+def test_bdconfig() -> None:
+    with open("config/bd/20230530-a035-e020.toml", "r") as f:
+        bdconfig = toml.from_toml(BDConfig, f.read())
+
+    birth_fn, hazard_fn = bdconfig.load_models()
+    assert isinstance(birth_fn, bd.EnergyLogisticBirth)
+    assert isinstance(hazard_fn, bd.ELGompertzHazard)


def test_cfconfig() -> None:
