def draw_cf_policy_multi(
    names: list[str],
    rotations: list[float],
    policy_means: NDArray,  # (N-obs, N-agents, LR)
    fig_unit: float,
    max_force: float,
) -> None:
    """Draw a grid of policy outputs: one row per observation, one column per policy.

    Each cell shows a circle with two red arrows. The arrow lengths are
    ``policy_means[i, j, 0]`` (drawn on the side labelled "left" below) and
    ``policy_means[i, j, 1]`` ("right"), and both point along the direction
    obtained by rotating the up vector ``(0, 1)`` by ``rotations[i]``.

    Args:
        names: Column titles, one per entry on the second axis of
            ``policy_means``. Titles are put on the first row only.
        rotations: One rotation per entry on the first axis of
            ``policy_means``. It is added to multiples of ``math.pi``, so it
            is presumably in radians (confirm against ``Vec2d.rotated``).
        policy_means: Array whose first two axes are (observation, policy) and
            whose last axis holds at least the two values drawn as arrows.
        fig_unit: Width and height of a single cell, passed to ``figsize``.
        max_force: Length scale of the drawing. Each cell spans
            ``3 * max_force`` on both axes and the circle has radius
            ``0.5 * max_force``.

    Raises:
        ValueError: If ``names`` or ``rotations`` does not match the leading
            two dimensions of ``policy_means``.
    """
    n_obs, n_policies = policy_means.shape[:2]
    if len(names) != n_policies:
        raise ValueError(
            f"Expected {n_policies} names (one per column), got {len(names)}"
        )
    if len(rotations) != n_obs:
        raise ValueError(
            f"Expected {n_obs} rotations (one per row), got {len(rotations)}"
        )

    # squeeze=False keeps `axes` two-dimensional even when the grid has a
    # single row or a single column, so axes[i][j] is always valid.
    fig, axes = plt.subplots(
        nrows=n_obs,
        ncols=n_policies,
        figsize=(n_policies * fig_unit, n_obs * fig_unit),
        squeeze=False,
    )

    # Geometry shared by every cell
    center = Vec2d(max_force * 1.5, max_force * 1.5)
    unit = Vec2d(0.0, 1.0)
    radius = max_force * 0.5
    limit = max_force * 3
    # 10% of the width? Looks thinner...
    arrow_width = max_force * 0.3

    for i, rot in enumerate(rotations):
        # The heading and the two arrow tips depend only on the row's rotation
        d_unit = unit.rotated(rot)
        tips = (
            unit.rotated(math.pi * 1.25 + rot) * radius + center,  # Left
            unit.rotated(math.pi * 0.75 + rot) * radius + center,  # Right
        )
        for j, policy_mean in enumerate(policy_means[i]):
            ax = axes[i][j]
            ax.set_xlim(0, limit)
            ax.set_ylim(0, limit)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_aspect("equal", adjustable="box")
            if i == 0:
                # Columns correspond to policies, so the title is indexed by j
                ax.set_title(names[j])
            ax.add_patch(Circle((center.x, center.y), radius, fill=False))
            for side, tip in enumerate(tips):
                delta = d_unit * policy_mean[side].item()
                # Shift the start back by the arrow itself so that the arrow
                # ends on the circle at `tip`.
                start = tip - delta
                ax.add_patch(
                    Arrow(
                        start.x,
                        start.y,
                        delta.x,
                        delta.y,
                        width=arrow_width,
                        color="r",
                    )
                )

    # Lay out after the titles exist so that they are taken into account
    fig.tight_layout()
    plt.show()