Skip to content

Commit

Permalink
Support multiple physstate path in vis_policy
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Nov 18, 2024
1 parent a4d06fd commit a7c529a
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 61 deletions.
94 changes: 35 additions & 59 deletions experiments/cf_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,83 +601,59 @@ def widget(

@app.command()
def vis_policy(
physstate_path: Path,
policy_path: list[Path],
physstate_path: list[Path],
policy_path: list[Path] = [],
subtitle: list[str] | None = None,
agent_index: int | None = None,
cfconfig_path: Path = DEFAULT_CFCONFIG,
fig_unit: float = 4.0,
scale: float = 1.0,
seq_plot: bool = False,
) -> None:
from emevo.analysis.evaluate import eval_policy
from emevo.analysis.policy import draw_cf_policy

with cfconfig_path.open("r") as f:
cfconfig = toml.from_toml(CfConfig, f.read())

cfconfig.n_initial_agents = 1
# Load env state
phys_state = SavedPhysicsState.load(physstate_path)
env = make("CircleForaging-v0", **dataclasses.asdict(cfconfig))
key = jax.random.PRNGKey(0)
env_state, _ = env.reset(key)
loaded_phys = phys_state.set_by_index(..., env_state.physics)
env_state = dataclasses.replace(env_state, physics=loaded_phys)
# agent_index
if agent_index is None:
file_name = physstate_path.stem
if "slot" in file_name:
agent_index = int(file_name[file_name.index("slot") + 4 :])
else:
print("Set --agent-index")
return
# Load agents
input_size = int(np.prod(env.obs_space.flatten().shape))
act_size = int(np.prod(env.act_space.shape))
ref_net = ppo.NormalPPONet(input_size, 64, act_size, key)
names, net_params = [], []
max_force = max(cfconfig.max_force, -cfconfig.min_force)

names = []
for policy_path_i, name in itertools.zip_longest(
policy_path,
[] if subtitle is None else subtitle,
):
pponet = eqx.tree_deserialise_leaves(policy_path_i, ref_net)
# Append only params of the network, excluding functions (etc. tanh).
net_params.append(eqx.filter(pponet, eqx.is_array))
names.append(policy_path_i.stem if name is None else name)
net_params = jax.tree.map(lambda *args: jnp.stack(args), *net_params)
network = eqx.combine(net_params, ref_net)
# Get obs
n_agents = cfconfig.n_max_agents
zero_action = jnp.zeros((n_agents, *env.act_space.shape))
_, timestep = env.step(env_state, zero_action)
obs_array = timestep.obs.as_array()
obs_i = obs_array[agent_index]

@eqx.filter_vmap(in_axes=(eqx.if_array(0), None))
def evaluate(network: ppo.NormalPPONet, obs: jax.Array) -> ppo.Output:
return network(obs)

# Get output
output = evaluate(network, obs_i)
# Make visualizer
visualizer = env.visualizer(
env_state,
figsize=(cfconfig.xlim[1] * scale, cfconfig.ylim[1] * scale),
sensor_index=agent_index,
sensor_width=0.004,
sensor_color=np.array([0.0, 0.0, 0.0, 0.3], dtype=np.float32),
)
visualizer.render(env_state.physics)
visualizer.show()
max_force = max(cfconfig.max_force, -cfconfig.min_force)
rot = env_state.physics.circle.p.angle[agent_index].item()
policy_mean = env.act_space.sigmoid_scale(output.mean)
draw_cf_policy(
names,
np.array(policy_mean),
rotation=rot,
fig_unit=fig_unit,
max_force=max_force,
)

# Get outputs
outputs = eval_policy(env, physstate_path, policy_path, agent_index)
if seq_plot:
pass
else:
visualizer = None
for output, env_state, ag_idx in outputs:
if visualizer is None:
visualizer = env.visualizer(
env_state,
figsize=(cfconfig.xlim[1] * scale, cfconfig.ylim[1] * scale),
sensor_index=ag_idx,
sensor_width=0.004,
sensor_color=np.array([0.0, 0.0, 0.0, 0.3], dtype=np.float32),
)
env._sensor_index = ag_idx # type:ignore
visualizer.render(env_state.physics)
visualizer.show()
rot = env_state.physics.circle.p.angle[ag_idx].item()
policy_mean = env.act_space.sigmoid_scale(output.mean)
draw_cf_policy(
names,
np.array(policy_mean),
rotation=rot,
fig_unit=fig_unit,
max_force=max_force,
)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion src/emevo/analysis/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def draw_cf_policy(
# Arrow points
center = Vec2d(max_force * 1.5, max_force * 1.5)
unit = Vec2d(0.0, 1.0)
d_unit = unit.rotated(math.pi)
d_unit = unit.rotated(rotation)
s_left = unit.rotated(math.pi * 1.25 + rotation) * max_force * 0.5 + center
s_right = unit.rotated(math.pi * 0.75 + rotation) * max_force * 0.5 + center
# Draw the arrows
Expand Down
11 changes: 10 additions & 1 deletion src/emevo/environments/circle_foraging.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,9 @@ def place_newborn_neighbor(
self._smell_diff_max = smell_diff_max
self._smell_diff_coef = smell_diff_coef

# Sensor index
self._sensor_index = 0

@staticmethod
def _make_food_num_fn(
food_num_fn: str | tuple | FoodNumFn,
Expand Down Expand Up @@ -1185,6 +1188,9 @@ def visualizer(
"""Create a visualizer for the environment"""
from emevo.environments import moderngl_vis

if sensor_index is not None:
self._sensor_index = sensor_index

return moderngl_vis.MglVisualizer(
x_range=self._x_range,
y_range=self._y_range,
Expand All @@ -1196,7 +1202,10 @@ def visualizer(
sensor_fn=(
self._get_sensors # type: ignore
if sensor_index is None
else lambda stated: self._get_selected_sensor(stated, sensor_index)
else lambda stated: self._get_selected_sensor(
stated,
self._sensor_index,
)
),
**kwargs,
)

0 comments on commit a7c529a

Please sign in to comment.