diff --git a/experiments/cf_simple.py b/experiments/cf_simple.py index f27fe24..371fd21 100644 --- a/experiments/cf_simple.py +++ b/experiments/cf_simple.py @@ -601,83 +601,59 @@ def widget( @app.command() def vis_policy( - physstate_path: Path, - policy_path: list[Path], + physstate_path: list[Path], + policy_path: list[Path] = [], subtitle: list[str] | None = None, agent_index: int | None = None, cfconfig_path: Path = DEFAULT_CFCONFIG, fig_unit: float = 4.0, scale: float = 1.0, + seq_plot: bool = False, ) -> None: + from emevo.analysis.evaluate import eval_policy from emevo.analysis.policy import draw_cf_policy with cfconfig_path.open("r") as f: cfconfig = toml.from_toml(CfConfig, f.read()) cfconfig.n_initial_agents = 1 - # Load env state - phys_state = SavedPhysicsState.load(physstate_path) env = make("CircleForaging-v0", **dataclasses.asdict(cfconfig)) - key = jax.random.PRNGKey(0) - env_state, _ = env.reset(key) - loaded_phys = phys_state.set_by_index(..., env_state.physics) - env_state = dataclasses.replace(env_state, physics=loaded_phys) - # agent_index - if agent_index is None: - file_name = physstate_path.stem - if "slot" in file_name: - agent_index = int(file_name[file_name.index("slot") + 4 :]) - else: - print("Set --agent-index") - return - # Load agents - input_size = int(np.prod(env.obs_space.flatten().shape)) - act_size = int(np.prod(env.act_space.shape)) - ref_net = ppo.NormalPPONet(input_size, 64, act_size, key) - names, net_params = [], [] + max_force = max(cfconfig.max_force, -cfconfig.min_force) + + names = [] for policy_path_i, name in itertools.zip_longest( policy_path, [] if subtitle is None else subtitle, ): - pponet = eqx.tree_deserialise_leaves(policy_path_i, ref_net) - # Append only params of the network, excluding functions (etc. tanh). - net_params.append(eqx.filter(pponet, eqx.is_array)) names.append(policy_path_i.stem if name is None else name) - net_params = jax.tree.map(lambda *args: jnp.stack(args), *net_params) - network = eqx.combine(net_params, ref_net) - # Get obs - n_agents = cfconfig.n_max_agents - zero_action = jnp.zeros((n_agents, *env.act_space.shape)) - _, timestep = env.step(env_state, zero_action) - obs_array = timestep.obs.as_array() - obs_i = obs_array[agent_index] - - @eqx.filter_vmap(in_axes=(eqx.if_array(0), None)) - def evaluate(network: ppo.NormalPPONet, obs: jax.Array) -> ppo.Output: - return network(obs) - - # Get output - output = evaluate(network, obs_i) - # Make visualizer - visualizer = env.visualizer( - env_state, - figsize=(cfconfig.xlim[1] * scale, cfconfig.ylim[1] * scale), - sensor_index=agent_index, - sensor_width=0.004, - sensor_color=np.array([0.0, 0.0, 0.0, 0.3], dtype=np.float32), - ) - visualizer.render(env_state.physics) - visualizer.show() - max_force = max(cfconfig.max_force, -cfconfig.min_force) - rot = env_state.physics.circle.p.angle[agent_index].item() - policy_mean = env.act_space.sigmoid_scale(output.mean) - draw_cf_policy( - names, - np.array(policy_mean), - rotation=rot, - fig_unit=fig_unit, - max_force=max_force, - ) + + # Get outputs + outputs = eval_policy(env, physstate_path, policy_path, agent_index) + if seq_plot: + pass + else: + visualizer = None + for output, env_state, ag_idx in outputs: + if visualizer is None: + visualizer = env.visualizer( + env_state, + figsize=(cfconfig.xlim[1] * scale, cfconfig.ylim[1] * scale), + sensor_index=ag_idx, + sensor_width=0.004, + sensor_color=np.array([0.0, 0.0, 0.0, 0.3], dtype=np.float32), + ) + env._sensor_index = ag_idx # type:ignore + visualizer.render(env_state.physics) + visualizer.show() + rot = env_state.physics.circle.p.angle[ag_idx].item() + policy_mean = env.act_space.sigmoid_scale(output.mean) + draw_cf_policy( + names, + np.array(policy_mean), + rotation=rot, + fig_unit=fig_unit, + max_force=max_force, + ) if __name__ == "__main__": diff --git a/src/emevo/analysis/policy.py b/src/emevo/analysis/policy.py index d573415..e1e6512 100644 --- a/src/emevo/analysis/policy.py +++ b/src/emevo/analysis/policy.py @@ -39,7 +39,7 @@ def draw_cf_policy( # Arrow points center = Vec2d(max_force * 1.5, max_force * 1.5) unit = Vec2d(0.0, 1.0) - d_unit = unit.rotated(math.pi) + d_unit = unit.rotated(rotation) s_left = unit.rotated(math.pi * 1.25 + rotation) * max_force * 0.5 + center s_right = unit.rotated(math.pi * 0.75 + rotation) * max_force * 0.5 + center # Draw the arrows diff --git a/src/emevo/environments/circle_foraging.py b/src/emevo/environments/circle_foraging.py index 54bab6d..8d0f6ce 100644 --- a/src/emevo/environments/circle_foraging.py +++ b/src/emevo/environments/circle_foraging.py @@ -737,6 +737,9 @@ def place_newborn_neighbor( self._smell_diff_max = smell_diff_max self._smell_diff_coef = smell_diff_coef + # Sensor index + self._sensor_index = 0 + @staticmethod def _make_food_num_fn( food_num_fn: str | tuple | FoodNumFn, @@ -1185,6 +1188,9 @@ def visualizer( """Create a visualizer for the environment""" from emevo.environments import moderngl_vis + if sensor_index is not None: + self._sensor_index = sensor_index + return moderngl_vis.MglVisualizer( x_range=self._x_range, y_range=self._y_range, @@ -1196,7 +1202,10 @@ def visualizer( sensor_fn=( self._get_sensors # type: ignore if sensor_index is None - else lambda stated: self._get_selected_sensor(stated, sensor_index) + else lambda stated: self._get_selected_sensor( + stated, + self._sensor_index, + ) ), **kwargs, )