From cf653d0ba993d3eb10e7b80c7bbcc726ede33817 Mon Sep 17 00:00:00 2001
From: kngwyu
Date: Tue, 12 Nov 2024 18:55:55 +0900
Subject: [PATCH] Fix bug in toxin so that it has food_2 reward

---
 experiments/cf_simple.py | 13 ++++++++++---
 experiments/cf_toxin.py  | 17 +++++++++++------
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/experiments/cf_simple.py b/experiments/cf_simple.py
index c106ce7..f27fe24 100644
--- a/experiments/cf_simple.py
+++ b/experiments/cf_simple.py
@@ -501,7 +501,12 @@ def replay(
     end: int | None = None,
     cfconfig_path: Path = DEFAULT_CFCONFIG,
     env_override: str = "",
+    scale: float = 1.0,
+    force_cpu: bool = False,
 ) -> None:
+    if force_cpu:
+        jax.config.update("jax_default_device", jax.devices("cpu")[0])
+
     with cfconfig_path.open("r") as f:
         cfconfig = toml.from_toml(CfConfig, f.read())
     # For speedup
@@ -513,14 +518,16 @@ def replay(
     end_index = end if end is not None else phys_state.circle_axy.shape[0]
     visualizer = env.visualizer(
         env_state,
-        figsize=(cfconfig.xlim[1] * 2, cfconfig.ylim[1] * 2),
+        figsize=(cfconfig.xlim[1] * scale, cfconfig.ylim[1] * scale),
         backend=backend,
     )
     if videopath is not None:
         visualizer = SaveVideoWrapper(visualizer, videopath, fps=60)
     for i in range(start, end_index):
-        phys = phys_state.set_by_index(i, env_state.physics)
-        env_state = dataclasses.replace(env_state, physics=phys)
+        ph = phys_state.set_by_index(i, env_state.physics)
+        # Disable rendering agents
+        ph = ph.nested_replace("circle.is_active", jnp.zeros_like(ph.circle.is_active))
+        env_state = dataclasses.replace(env_state, physics=ph)
         visualizer.render(env_state.physics)
     visualizer.show()
     visualizer.close()
diff --git a/experiments/cf_toxin.py b/experiments/cf_toxin.py
index fca98f0..2d9f7fd 100644
--- a/experiments/cf_toxin.py
+++ b/experiments/cf_toxin.py
@@ -70,12 +70,14 @@ def normalize_action(self, action: jax.Array) -> jax.Array:
     def extract(
         self,
         ate_food: jax.Array,
+        ate_toxin: jax.Array,
         action: jax.Array,
-        energy: jax.Array,
     ) -> jax.Array:
-        del energy
         act_input = self.act_coef * self.normalize_action(action)
-        return jnp.concatenate((ate_food.astype(jnp.float32), act_input), axis=1)
+        return jnp.concatenate(
+            (ate_food.astype(jnp.float32), ate_toxin.astype(jnp.float32), act_input),
+            axis=1,
+        )
 
 
 def serialize_weight(w: jax.Array) -> dict[str, jax.Array]:
@@ -110,8 +112,11 @@ def step_rollout(
             env.act_space.sigmoid_scale(actions),  # type: ignore
         )
         obs_t1 = timestep.obs
-        energy = state_t.status.energy
-        rewards = reward_fn(timestep.info["n_ate_food"], actions, energy).reshape(-1, 1)
+        rewards = reward_fn(
+            timestep.info["n_ate_food"],
+            jnp.expand_dims(timestep.info["n_ate_toxin"], axis=1),
+            actions,
+        ).reshape(-1, 1)
         rollout = ppo.Rollout(
             observations=obs_t_array,
             actions=actions,
@@ -464,7 +469,7 @@ def evolve(
         reward_fn_instance = rfn.LinearReward(
             key=reward_key,
             n_agents=cfconfig.n_max_agents,
-            n_weights=cfconfig.n_food_sources,  # Because one of the foods is toxin
+            n_weights=cfconfig.n_food_sources + 1,
             std=gopsconfig.init_std,
             mean=gopsconfig.init_mean,
             extractor=reward_extracor.extract,
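
Note (not part of the patch): a minimal sketch of what the new extract signature implies for the reward features. The standalone function below, its act_coef default, the L2-norm action cost, and the array shapes are assumptions for illustration only; the repository's actual RewardExtractor may differ.

# Hypothetical sketch of the reward-feature layout after this patch:
# [food_1 ... food_k, toxin, action_cost], so a linear reward needs
# n_food_sources + 1 weights in addition to the action term.
import jax
import jax.numpy as jnp


def extract(
    ate_food: jax.Array,   # (n_agents, n_food_sources) foods eaten this step
    ate_toxin: jax.Array,  # (n_agents, 1) toxin eaten this step
    action: jax.Array,     # (n_agents, act_dim) raw actions
    act_coef: float = 0.01,
) -> jax.Array:
    # Collapse the action to a single cost column (assumed L2 norm here)
    act_input = act_coef * jnp.linalg.norm(action, axis=1, keepdims=True)
    return jnp.concatenate(
        (ate_food.astype(jnp.float32), ate_toxin.astype(jnp.float32), act_input),
        axis=1,
    )


# Example: 3 agents, 2 food sources -> features have 2 + 1 + 1 columns.
features = extract(
    ate_food=jnp.zeros((3, 2)),
    ate_toxin=jnp.ones((3, 1)),
    action=jnp.ones((3, 2)),
)
print(features.shape)  # (3, 4)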