FoodLog
kngwyu committed Mar 16, 2024
1 parent f5427c4 commit 3aa4066
Showing 3 changed files with 71 additions and 13 deletions.
24 changes: 15 additions & 9 deletions experiments/cf_simple.py
@@ -25,6 +25,7 @@
from emevo.exp_utils import (
BDConfig,
CfConfig,
FoodLog,
GopsConfig,
Log,
Logger,
@@ -99,11 +100,11 @@ def exec_rollout(
birth_fn: bd.BirthFunction,
prng_key: jax.Array,
n_rollout_steps: int,
) -> tuple[State, Rollout, Log, SavedPhysicsState, Obs, jax.Array]:
) -> tuple[State, Rollout, Log, FoodLog, SavedPhysicsState, Obs, jax.Array]:
def step_rollout(
carried: tuple[State, Obs],
key: jax.Array,
) -> tuple[tuple[State, Obs], tuple[Rollout, Log, SavedPhysicsState]]:
) -> tuple[tuple[State, Obs], tuple[Rollout, Log, FoodLog, SavedPhysicsState]]:
act_key, hazard_key, birth_key = jax.random.split(key, 3)
state_t, obs_t = carried
obs_t_array = obs_t.as_array()
@@ -147,6 +148,10 @@ def step_rollout(
unique_id=state_t1db.unique_id.unique_id,
consumed_energy=timestep.info["energy_consumption"],
)
foodlog = FoodLog(
eaten=timestep.info["food_eaten"],
regenerated=timestep.info["food_regeneration"],
)
phys = state_t.physics # type: ignore
phys_state = SavedPhysicsState(
circle_axy=phys.circle.p.into_axy(),
@@ -155,15 +160,15 @@ def step_rollout(
static_circle_is_active=phys.static_circle.is_active,
static_circle_label=phys.static_circle.label,
)
return (state_t1db, obs_t1), (rollout, log, phys_state)
return (state_t1db, obs_t1), (rollout, log, foodlog, phys_state)

(state, obs), (rollout, log, phys_state) = jax.lax.scan(
(state, obs), (rollout, log, foodlog, phys_state) = jax.lax.scan(
step_rollout,
(state, initial_obs),
jax.random.split(prng_key, n_rollout_steps),
)
next_value = vmap_value(network, obs.as_array())
return state, rollout, log, phys_state, obs, next_value
return state, rollout, log, foodlog, phys_state, obs, next_value


@eqx.filter_jit
@@ -183,9 +188,9 @@ def epoch(
opt_state: optax.OptState,
minibatch_size: int,
n_optim_epochs: int,
) -> tuple[State, Obs, Log, SavedPhysicsState, optax.OptState, NormalPPONet]:
) -> tuple[State, Obs, Log, FoodLog, SavedPhysicsState, optax.OptState, NormalPPONet]:
keys = jax.random.split(prng_key, env.n_max_agents + 1)
env_state, rollout, log, phys_state, obs, next_value = exec_rollout(
env_state, rollout, log, foodlog, phys_state, obs, next_value = exec_rollout(
state,
initial_obs,
env,
@@ -208,7 +213,7 @@ def epoch(
0.2,
0.0,
)
return env_state, obs, log, phys_state, opt_state, pponet
return env_state, obs, log, foodlog, phys_state, opt_state, pponet


def run_evolution(
@@ -287,7 +292,7 @@ def replace_net(

for i, key in enumerate(jax.random.split(key, n_total_steps // n_rollout_steps)):
epoch_key, init_key = jax.random.split(key)
env_state, obs, log, phys_state, opt_state, pponet = epoch(
env_state, obs, log, foodlog, phys_state, opt_state, pponet = epoch(
env_state,
obs,
env,
Expand Down Expand Up @@ -345,6 +350,7 @@ def replace_net(

# Push log and physics state
logger.push_log(log_with_step.filter_active())
logger.push_foodlog(foodlog)
logger.push_physstate(phys_state)

# Save logs before exiting
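The `foodlog` returned by `exec_rollout` gets its leading time dimension from `jax.lax.scan`, which stacks whatever pytree the step function emits as its per-step output. A minimal, self-contained sketch of that behaviour, using a dummy step function and made-up values rather than code from this repository:

import chex
import jax
import jax.numpy as jnp


@chex.dataclass
class FoodLog:
    # Mirrors the dataclass added in src/emevo/exp_utils.py
    eaten: jax.Array        # i32, [N_FOOD_SOURCES,] per step
    regenerated: jax.Array  # bool, [N_FOOD_SOURCES,] per step


def step(carried: jax.Array, key: jax.Array) -> tuple[jax.Array, FoodLog]:
    # Dummy per-step log: pretend source 0 loses one food and source 1 regenerates
    log = FoodLog(
        eaten=jnp.array([1, 0], dtype=jnp.int32),
        regenerated=jnp.array([False, True]),
    )
    return carried, log


keys = jax.random.split(jax.random.PRNGKey(0), 8)  # 8 rollout steps
_, foodlog = jax.lax.scan(step, jnp.zeros(()), keys)
# foodlog.eaten has shape (8, 2): one row per rollout step, one column per food source
print(foodlog.eaten.shape, foodlog.regenerated.shape)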
22 changes: 18 additions & 4 deletions src/emevo/environments/circle_foraging.py
@@ -749,7 +749,7 @@ def step(
)
# Remove and regenerate foods
key, food_key = jax.random.split(state.key)
stated, food_num, food_loc = self._remove_and_regenerate_foods(
stated, food_num, food_loc, n_eaten, n_re = self._remove_and_regenerate_foods(
food_key,
jnp.max(c2sc, axis=0),
stated,
@@ -773,7 +773,11 @@ def step(
timestep = TimeStep(
encount=c2c,
obs=obs,
info={"energy_consumption": energy_consumption},
info={
"energy_consumption": energy_consumption,
"food_regeneration": n_re.astype(bool),
"food_eaten": n_eaten,
},
)
state = CFState(
physics=stated,
@@ -990,7 +994,9 @@ def _remove_and_regenerate_foods(
n_steps: jax.Array,
food_num_states: list[FoodNumState],
food_loc_states: list[LocatingState],
) -> tuple[StateDict, list[FoodNumState], list[LocatingState]]:
) -> tuple[
StateDict, list[FoodNumState], list[LocatingState], jax.Array, jax.Array
]:
# Remove foods
xy = jnp.where(
jnp.expand_dims(eaten, axis=1),
@@ -1005,6 +1011,7 @@ def _remove_and_regenerate_foods(
)
sc = sd.static_circle
# Regenerate food for each source
n_generated_foods = jnp.zeros(self._n_food_sources, dtype=jnp.int32)
for i in range(self._n_food_sources):
food_num = self._food_num_fns[i](
n_steps,
@@ -1034,7 +1041,14 @@ def _remove_and_regenerate_foods(
incr = jnp.sum(place)
food_num_states[i] = food_num.recover(incr)
food_loc_states[i] = food_loc.increment(incr)
return replace(sd, static_circle=sc), food_num_states, food_loc_states
n_generated_foods = n_generated_foods.at[i].add(incr)
return (
replace(sd, static_circle=sc),
food_num_states,
food_loc_states,
eaten_per_source,
n_generated_foods,
)

def visualizer(
self,
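`_remove_and_regenerate_foods` now also reports, per food source, how many foods were eaten and whether any regenerated. The per-source counting idea can be illustrated with a scatter-add over the food-source label of each static circle; `labels` and the concrete values below are made up for the example and are not the exact variables used in this method:

import jax.numpy as jnp

n_food_sources = 2
eaten = jnp.array([True, False, True, True, False])   # per-circle "was eaten" flags
labels = jnp.array([0, 0, 1, 0, 1], dtype=jnp.int32)  # food source each circle came from

# Accumulate eaten counts per source; duplicate indices add up under JAX scatter-add
eaten_per_source = jnp.zeros(n_food_sources, dtype=jnp.int32).at[labels].add(
    eaten.astype(jnp.int32)
)
print(eaten_per_source)  # [2 1]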
38 changes: 38 additions & 0 deletions src/emevo/exp_utils.py
@@ -172,6 +172,12 @@ def with_step(self, from_: int) -> LogWithStep:
)


@chex.dataclass
class FoodLog:
eaten: jax.Array # i32, [N_FOOD_SOURCES,]
regenerated: jax.Array # bool, [N_FOOD_SOURCES,]


@chex.dataclass
class LogWithStep(Log):
step: jax.Array
@@ -284,6 +290,7 @@ class Logger:
reward_fn_dict: dict[int, RewardFn] = dataclasses.field(default_factory=dict)
profile_dict: dict[int, SavedProfile] = dataclasses.field(default_factory=dict)
_log_list: list[Log] = dataclasses.field(default_factory=list, init=False)
_foodlog_list: list[FoodLog] = dataclasses.field(default_factory=list, init=False)
_physstate_list: list[SavedPhysicsState] = dataclasses.field(
default_factory=list,
init=False,
@@ -321,6 +328,37 @@ def _save_log(self) -> None:
self._log_index += 1
self._log_list.clear()

def push_foodlog(self, log: FoodLog) -> None:
if self.mode not in [LogMode.FULL, LogMode.REWARD_AND_LOG]:
return

# Move log to CPU
self._foodlog_list.append(jax.tree_map(np.array, log))

if len(self._foodlog_list) % self.log_interval == 0:
self._save_foodlog()

def _save_foodlog(self) -> None:
if len(self._foodlog_list) == 0:
return

all_log = jax.tree_map(
lambda *args: np.concatenate(args, axis=0),
*self._foodlog_list,
)
log_dict = {}
for i in range(all_log.eaten.shape[1]):
log_dict[f"eaten_{i}"] = all_log.eaten[:, i]
log_dict[f"regen_{i}"] = all_log.regenerated[:, i]

# Don't change log_index here
pq.write_table(
pa.Table.from_pydict(log_dict),
self.logdir.joinpath(f"foodlog-{self._log_index}.parquet"),
compression="zstd",
)
self._foodlog_list.clear()

def push_physstate(self, phys_state: SavedPhysicsState) -> None:
if self.mode != LogMode.FULL:
return
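Each save interval produces one `foodlog-{index}.parquet` file with `eaten_{i}` and `regen_{i}` columns per food source. A short, hypothetical example of reading the files back for analysis; the concrete log directory path is an assumption for the example, not something defined in this commit:

from pathlib import Path

import pyarrow.parquet as pq

logdir = Path("./log")  # hypothetical output directory passed to Logger
for path in sorted(logdir.glob("foodlog-*.parquet")):
    df = pq.read_table(path).to_pandas()
    # e.g. total food eaten from source 0 over the steps covered by this file
    print(path.name, df["eaten_0"].sum())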
