Skip to content

Commit

Permalink
Fix bug in toxin so that it has food_2 reward
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Nov 12, 2024
1 parent 5b80227 commit cf653d0
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
13 changes: 10 additions & 3 deletions experiments/cf_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,12 @@ def replay(
end: int | None = None,
cfconfig_path: Path = DEFAULT_CFCONFIG,
env_override: str = "",
scale: float = 1.0,
force_cpu: bool = False,
) -> None:
if force_cpu:
jax.config.update("jax_default_device", jax.devices("cpu")[0])

with cfconfig_path.open("r") as f:
cfconfig = toml.from_toml(CfConfig, f.read())
# For speedup
Expand All @@ -513,14 +518,16 @@ def replay(
end_index = end if end is not None else phys_state.circle_axy.shape[0]
visualizer = env.visualizer(
env_state,
figsize=(cfconfig.xlim[1] * 2, cfconfig.ylim[1] * 2),
figsize=(cfconfig.xlim[1] * scale, cfconfig.ylim[1] * scale),
backend=backend,
)
if videopath is not None:
visualizer = SaveVideoWrapper(visualizer, videopath, fps=60)
for i in range(start, end_index):
phys = phys_state.set_by_index(i, env_state.physics)
env_state = dataclasses.replace(env_state, physics=phys)
ph = phys_state.set_by_index(i, env_state.physics)
# Disable rendering agents
ph = ph.nested_replace("circle.is_active", jnp.zeros_like(ph.circle.is_active))
env_state = dataclasses.replace(env_state, physics=ph)
visualizer.render(env_state.physics)
visualizer.show()
visualizer.close()
Expand Down
17 changes: 11 additions & 6 deletions experiments/cf_toxin.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,14 @@ def normalize_action(self, action: jax.Array) -> jax.Array:
def extract(
self,
ate_food: jax.Array,
ate_toxin: jax.Array,
action: jax.Array,
energy: jax.Array,
) -> jax.Array:
del energy
act_input = self.act_coef * self.normalize_action(action)
return jnp.concatenate((ate_food.astype(jnp.float32), act_input), axis=1)
return jnp.concatenate(
(ate_food.astype(jnp.float32), ate_toxin.astype(jnp.float32), act_input),
axis=1,
)


def serialize_weight(w: jax.Array) -> dict[str, jax.Array]:
Expand Down Expand Up @@ -110,8 +112,11 @@ def step_rollout(
env.act_space.sigmoid_scale(actions), # type: ignore
)
obs_t1 = timestep.obs
energy = state_t.status.energy
rewards = reward_fn(timestep.info["n_ate_food"], actions, energy).reshape(-1, 1)
rewards = reward_fn(
timestep.info["n_ate_food"],
jnp.expand_dims(timestep.info["n_ate_toxin"], axis=1),
actions,
).reshape(-1, 1)
rollout = ppo.Rollout(
observations=obs_t_array,
actions=actions,
Expand Down Expand Up @@ -464,7 +469,7 @@ def evolve(
reward_fn_instance = rfn.LinearReward(
key=reward_key,
n_agents=cfconfig.n_max_agents,
n_weights=cfconfig.n_food_sources, # Because one of the foods is toxin
n_weights=cfconfig.n_food_sources + 1,
std=gopsconfig.init_std,
mean=gopsconfig.init_mean,
extractor=reward_extracor.extract,
Expand Down

0 comments on commit cf653d0

Please sign in to comment.