Skip to content

Commit

Permalink
Fix report in circle_ppo
Browse files (browse the repository at this point in the history)
  • Loading branch information
kngwyu committed Nov 22, 2024
1 parent: d22c28d; commit: 371e69e
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions smoke-tests/circle_ppo.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import optax
import typer
from numpy.typing import NDArray
from serde import toml

from emevo import Env, make
Expand Down Expand Up @@ -187,7 +188,7 @@ def run_training(
figsize: tuple[float, float],
reset_interval: int | None = None,
debug_vis: bool = False,
) -> tuple[NormalPPONet, jax.Array]:
) -> tuple[NormalPPONet, NDArray]:
key, net_key, reset_key = jax.random.split(key, 3)
obs_space = env.obs_space.flatten()
input_size = int(np.prod(obs_space.shape))
Expand All @@ -204,7 +205,7 @@ def run_training(
obs = timestep.obs

n_loop = n_total_steps // n_rollout_steps
rewards = jnp.zeros(N_MAX_AGENTS)
rewards = None
keys = jax.random.split(key, n_loop)
if debug_vis:
visualizer = env.visualizer(env_state, figsize=figsize)
Expand All @@ -231,7 +232,10 @@ def run_training(
entropy_weight,
)
ri = np.array(jnp.sum(rewards_i, axis=0))
rewards = rewards + ri
if rewards is None:
rewards = ri
else:
rewards += ri
if visualizer is not None:
visualizer.render(env_state.physics) # type: ignore
visualizer.show()
Expand All @@ -241,8 +245,10 @@ def run_training(
env_state, timestep = env.reset(key)
obs = timestep.obs
# weight_summary(pponet)
print(f"Sum of rewards {[x.item() for x in rewards[: n_agents]]}")
return pponet, rewards
assert rewards is not None
for i in range(rewards.shape[1]):
print(f"Rewards for {i + 1}: {rewards[:n_agents, i]}")
return pponet, rewards # type: ignore


app = typer.Typer(pretty_exceptions_show_locals=False)
Expand Down Expand Up @@ -302,7 +308,7 @@ def train(
)
eqx.tree_serialise_leaves(modelpath, network)
if savelog_path is not None:
np.savez(savelog_path, np.array(rewards))
np.savez(savelog_path, rewards)


@app.command()
Expand Down

0 comments on commit 371e69e

Please sign in to comment.