Skip to content

Commit

Permalink
Multipolicy plot
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Nov 19, 2024
1 parent a7c529a commit aa7e9b4
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 4 deletions.
12 changes: 10 additions & 2 deletions experiments/cf_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def vis_policy(
seq_plot: bool = False,
) -> None:
from emevo.analysis.evaluate import eval_policy
from emevo.analysis.policy import draw_cf_policy
from emevo.analysis.policy import draw_cf_policy, draw_cf_policy_multi

with cfconfig_path.open("r") as f:
cfconfig = toml.from_toml(CfConfig, f.read())
Expand All @@ -630,7 +630,15 @@ def vis_policy(
# Get outputs
outputs = eval_policy(env, physstate_path, policy_path, agent_index)
if seq_plot:
pass
policy_means = [np.array(output.mean) for output, _, _ in outputs]
rot = [state.physics.circle.p.angle[idx].item() for _, state, idx in outputs]
draw_cf_policy_multi(
names,
rot,
env.act_space.sigmoid_scale(np.stack(policy_means)),
fig_unit=fig_unit,
max_force=max_force,
)
else:
visualizer = None
for output, env_state, ag_idx in outputs:
Expand Down
69 changes: 67 additions & 2 deletions src/emevo/analysis/policy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Policy visualization
"""
"""Policy visualization"""

import math

import matplotlib as mpl
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Arrow, Circle
Expand Down Expand Up @@ -81,3 +81,68 @@ def draw_cf_policy(
ax.add_patch(arrow)

plt.show()


def draw_cf_policy_multi(
    names: list[str],
    rotations: list[float],
    policy_means: NDArray,  # (N-obs, N-policies, LR)
    fig_unit: float,
    max_force: float,
) -> None:
    """Draw a grid of left/right force arrows, one cell per (observation, policy).

    Rows of the grid correspond to the first axis of ``policy_means``
    (observations) and columns to the second axis (policies). Each cell shows
    a circle and two red arrows whose lengths are the first (left) and second
    (right) entries of that cell's policy mean, drawn along the direction
    obtained by rotating the unit vector (0, 1) by the row's rotation.

    Args:
        names: Column titles, one per policy. Titles are placed on the top
            row only. Extra names are ignored; a column without a matching
            name is left untitled.
        rotations: One rotation per observation (row). It is added to
            multiples of ``math.pi``, so presumably radians -- confirm
            against the caller.
        policy_means: Array whose first two axes are (N-obs, N-policies) and
            whose last axis holds at least the left and right values.
        fig_unit: Width and height of a single grid cell passed to
            ``figsize``.
        max_force: Scale of the drawing; axis limits are ``3 * max_force``
            and the circle radius is ``0.5 * max_force``.

    Raises:
        ValueError: If ``len(rotations)`` does not match the number of
            observations in ``policy_means``.
    """
    n_obs, n_policies = policy_means.shape[:2]
    if len(rotations) != n_obs:
        raise ValueError(
            f"Expected {n_obs} rotations (one per observation), "
            f"got {len(rotations)}"
        )
    # squeeze=False keeps `axes` two-dimensional even for a single row/column,
    # so axes[i][j] below is always valid.
    fig, axes = plt.subplots(
        nrows=n_obs,
        ncols=n_policies,
        figsize=(n_policies * fig_unit, n_obs * fig_unit),
        squeeze=False,
    )
    # Arrow points
    center = Vec2d(max_force * 1.5, max_force * 1.5)
    unit = Vec2d(0.0, 1.0)
    radius = max_force * 0.5
    # Draw the arrows
    for i, rot in enumerate(rotations):
        d_unit = unit.rotated(rot)
        # Anchor points of the (left, right) arrows, offset from the center.
        anchors = (
            unit.rotated(math.pi * 1.25 + rot) * radius + center,
            unit.rotated(math.pi * 0.75 + rot) * radius + center,
        )
        for j, policy_mean in enumerate(policy_means[i]):
            ax = axes[i][j]
            ax.set_xlim(0, max_force * 3)
            ax.set_ylim(0, max_force * 3)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_aspect("equal", adjustable="box")
            # Title each column with its own policy name on the top row.
            if i == 0 and j < len(names):
                ax.set_title(names[j])
            # Circle
            ax.add_patch(Circle((center.x, center.y), radius, fill=False))
            # Left (index 0) and right (index 1) arrows
            for side, anchor in enumerate(anchors):
                delta = d_unit * policy_mean[side].item()
                # Shift the start back by the arrow vector so that the arrow
                # ends exactly at the anchor point.
                start = anchor - delta
                ax.add_patch(
                    Arrow(
                        start.x,
                        start.y,
                        delta.x,
                        delta.y,
                        # 10% of the width? Looks thinner...
                        width=max_force * 0.3,
                        color="r",
                    )
                )

    # Lay out after titles/patches exist so spacing accounts for them.
    fig.tight_layout()
    plt.show()

0 comments on commit aa7e9b4

Please sign in to comment.