Skip to content

Commit

Permalink
Replay widget
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Jan 12, 2024
1 parent a3edbce commit 6642ded
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 50 deletions.
40 changes: 36 additions & 4 deletions experiments/cf_asexual_evo.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def replace_net(
)

if visualizer is not None:
visualizer.render(env_state)
visualizer.render(env_state.physics) # type: ignore
visualizer.show()
# Extinct?
n_active = jnp.sum(env_state.unique_id.is_active()) # type: ignore
Expand Down Expand Up @@ -429,8 +429,6 @@ def evolve(
"alpha_action": slice_last(alpha, 3),
},
)
print(reward_fn_instance.alpha)
print(reward_fn_instance.weight)
else:
raise ValueError(f"Invalid reward_fn {reward_fn}")

Expand Down Expand Up @@ -491,10 +489,44 @@ def replay(
for i in range(start, end_index):
phys = phys_state.set_by_index(i, env_state.physics)
env_state = dataclasses.replace(env_state, physics=phys)
visualizer.render(env_state)
visualizer.render(env_state.physics)
visualizer.show()
visualizer.close()


@app.command()
def widget(
physstate_path: Path,
n_agents: int = 20,
start: int = 0,
end: Optional[int] = None,
cfconfig_path: Path = here.joinpath("../config/env/20231214-square.toml"),
env_override: str = "",
) -> None:
import sys

from PySide6.QtWidgets import QApplication

from emevo.analysis.qt_widget import CFEnvReplayWidget

with cfconfig_path.open("r") as f:
cfconfig = toml.from_toml(CfConfig, f.read())
cfconfig.n_initial_agents = n_agents
cfconfig.apply_override(env_override)
phys_state = SavedPhysicsState.load(physstate_path)
env = make("CircleForaging-v0", **dataclasses.asdict(cfconfig))
end_index = end if end is not None else phys_state.circle_axy.shape[0]

app = QApplication([])
widget = CFEnvReplayWidget(
int(cfconfig.xlim[1]),
int(cfconfig.ylim[1]),
env=env, # type: ignore
saved_physics=phys_state,
)
widget.show()
sys.exit(app.exec())


if __name__ == "__main__":
app()
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@ dependencies = [
dynamic = ["version"]

[project.optional-dependencies]
analysis = ["networkx >= 3.0", "pygraphviz >= 1.0"]
analysis = [
"matplotlib >= 3.0",
"networkx >= 3.0",
"pygraphviz >= 1.0",
"PySide6 >= 6.5",
]
video = ["imageio-ffmpeg >= 0.4"]
widget = ["PySide6 >= 6.5"]

[project.readme]
file = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion smoke-tests/circle_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def main(
print("Parents: ", parents)

if visualizer is not None:
visualizer.render(state)
visualizer.render(state.physics)
visualizer.show()

print(f"Avg. μs for step: {np.mean(elapsed_list)}")
Expand Down
5 changes: 3 additions & 2 deletions smoke-tests/circle_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def step(key: chex.PRNGKey, state: State, obs: Obs) -> tuple[State, Obs, jax.Arr

for key in keys[1:]:
state, obs, act = step(key, state, obs)
del act
# print(f"Act: {act[0]}")
visualizer.render(state)
visualizer.render(state.physics) # type: ignore
visualizer.show()


Expand Down Expand Up @@ -201,7 +202,7 @@ def run_training(
ri = jnp.sum(jnp.squeeze(rewards_i, axis=-1), axis=0)
rewards = rewards + ri
if visualizer is not None:
visualizer.render(env_state)
visualizer.render(env_state.physics) # type: ignore
visualizer.show()
print(f"Rewards: {[x.item() for x in ri[: n_agents]]}")
if reset:
Expand Down
142 changes: 122 additions & 20 deletions src/emevo/environments/qt_vis.py → src/emevo/analysis/qt_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,15 @@
from functools import partial
from typing import Callable

import jax
import matplotlib as mpl
import matplotlib.colors as mc
import moderngl
import numpy as np
import pyarrow as pa
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg
from numpy.typing import NDArray
from PySide6 import QtWidgets
from PySide6.QtCharts import (
QBarCategoryAxis,
QBarSeries,
Expand All @@ -19,13 +25,15 @@
QChartView,
QValueAxis,
)
from PySide6.QtCore import Qt, QTimer, Signal, Slot
from PySide6.QtCore import QPointF, Qt, QTimer, Signal, Slot
from PySide6.QtGui import QGuiApplication, QMouseEvent, QPainter, QSurfaceFormat
from PySide6.QtOpenGLWidgets import QOpenGLWidget
from PySide6 import QtWidgets

from emevo.environments.circle_foraging import CircleForaging
from emevo.environments.moderngl_vis import MglRenderer
from emevo.environments.phyjax2d import Space, StateDict
from emevo.environments.phyjax2d import StateDict
from emevo.exp_utils import SavedPhysicsState
from emevo.plotting import CBarRenderer


def _mgl_qsurface_fmt() -> QSurfaceFormat:
Expand All @@ -51,38 +59,52 @@ class MglWidget(QOpenGLWidget):
def __init__(
self,
*,
x_range: float,
y_range: float,
space: Space,
stated: StateDict,
timer: QTimer,
env: CircleForaging,
saved_physics: SavedPhysicsState,
figsize: tuple[float, float],
sensor_fn: Callable[[StateDict], tuple[NDArray, NDArray]] | None = None,
parent: QtWidgets.QWidget | None = None,
) -> None:
# Set default format
QSurfaceFormat.setDefaultFormat(_mgl_qsurface_fmt())
super().__init__(parent)
# init renderer
self._env_state, _ = env.reset(jax.random.PRNGKey(0))
self._figsize = int(figsize[0]), int(figsize[1])
x_range, y_range = env._x_range, env._y_range
self._scaling = x_range / figsize[0], y_range / figsize[1]
self._phys_state = saved_physics
self._make_renderer = partial(
MglRenderer,
screen_width=self._figsize[0],
screen_height=self._figsize[1],
x_range=x_range,
y_range=y_range,
space=space,
stated=stated,
sensor_fn=sensor_fn,
space=env._physics,
stated=self._get_stated(0),
sensor_fn=env._get_sensors,
)
self._index = 0
self._state = AppState()
self._initialized = False
self._overlay_fns = []
self._initial_state = stated

# Set timer
self._timer = timer
self._timer.timeout.connect(self.update)

self.setFixedSize(*self._figsize)
self.setMouseTracking(True)

def _scale_position(self, position: QPointF) -> tuple[float, float]:
return (
position.x() * self._scaling[0],
(self._figsize[1] - position.y()) * self._scaling[1],
)

def _get_stated(self, index: int) -> StateDict:
return self._phys_state.set_by_index(index, self._env_state.physics)

def _set_default_viewport(self) -> None:
self._ctx.viewport = 0, 0, *self._figsize
self._fbo.viewport = 0, 0, *self._figsize
Expand All @@ -102,24 +124,21 @@ def paintGL(self) -> None:
self._fbo = self._ctx.detect_framebuffer()
self._renderer = self._make_renderer(self._ctx)
self._initialized = True
self.render(self._initial_state)
self._index += 1
self._render(self._get_stated(self._index))

def render(self, stated: StateDict) -> None:
def _render(self, stated: StateDict) -> None:
self._fbo.use()
self._ctx.clear(1.0, 1.0, 1.0)
self._renderer.render(stated) # type: ignore

def show(self, timer: QTimer):
self._timer = timer
self._timer.timeout.connect(self.update) # type: ignore

def _emit_selected(self, index: int | None) -> None:
if index is None:
self.selectionChanged.emit(-1)
else:
self.selectionChanged.emit(index)

def mousePressEvent(self, evt: QMouseEvent) -> None:
def mousePressEvent(self, evt: QMouseEvent) -> None: # type: ignore
position = self._scale_position(evt.position())
# query = self._env.get_space().point_query(
# position,
Expand All @@ -138,7 +157,7 @@ def mousePressEvent(self, evt: QMouseEvent) -> None:
# self._timer.stop()
# self.update()

def mouseReleaseEvent(self, evt: QMouseEvent) -> None:
def mouseReleaseEvent(self, evt: QMouseEvent) -> None: # type: ignore
pass

@Slot()
Expand Down Expand Up @@ -242,3 +261,86 @@ def updateValues(self, values: dict[str, float | list[float]]) -> None:
new_barsets.popleft().setColor(old_bs.color())
self.series.remove(old_bs)
self._update_yrange(values.values())


class CFEnvReplayWidget(QtWidgets.QWidget):
energyUpdated = Signal(float)
rewardUpdated = Signal(dict)
foodrankUpdated = Signal(dict)
valueUpdated = Signal(float)

def __init__(
self,
xlim: int,
ylim: int,
env: CircleForaging,
saved_physics: SavedPhysicsState,
profile_and_reward: pa.Table | None = None,
) -> None:
super().__init__()

timer = QTimer()
# Environment
self._mgl_widget = MglWidget(
timer=timer,
env=env,
saved_physics=saved_physics,
figsize=(xlim * 2, ylim * 2),
)
# Pause/Play
self._pause_button = QtWidgets.QPushButton("⏸️")
self._pause_button.clicked.connect(self._mgl_widget.pause)
self._play_button = QtWidgets.QPushButton("▶️")
self._play_button.clicked.connect(self._mgl_widget.play)
self._cbar_select_button = QtWidgets.QPushButton("Switch Value/Energy")
self._cbar_select_button.clicked.connect(self.change_cbar)
# Colorbar
self._cbar_renderer = CBarRenderer(xlim * 2, ylim // 4)
self._showing_energy = True
self._cbar_changed = True
self._cbar_canvas = FigureCanvasQTAgg(self._cbar_renderer._fig)
self._value_cm = mpl.colormaps["YlOrRd"]
self._energy_cm = mpl.colormaps["YlGnBu"]
self._norm = mc.Normalize(vmin=0.0, vmax=1.0)
if profile_and_reward is not None:
self._reward_widget = BarChart(
next(iter(self._rewards.values())).to_pydict()
)
# Layout buttons
buttons = QtWidgets.QHBoxLayout()
buttons.addWidget(self._pause_button)
buttons.addWidget(self._play_button)
buttons.addWidget(self._cbar_select_button)
# Total layout
total_layout = QtWidgets.QVBoxLayout()
total_layout.addLayout(buttons)
total_layout.addWidget(self._cbar_canvas)
if profile_and_reward is None:
total_layout.addWidget(self._mgl_widget)
else:
env_and_reward_layout = QtWidgets.QHBoxLayout()
env_and_reward_layout.addWidget(self._mgl_widget)
env_and_reward_layout.addWidget(self._reward_widget)
total_layout.addLayout(env_and_reward_layout)
self.setLayout(total_layout)
timer.start(30) # 40fps
self._arrow_cached = None
self._obs_cached = {}
# Signals
self._mgl_widget.selectionChanged.connect(self.updateRewards)
if profile_and_reward is not None:
self.rewardUpdated.connect(self._reward_widget.updateValues)
# Initial size
self.resize(xlim * 3, int(ylim * 2.4))

@Slot(int)
def updateRewards(self, body_index: int) -> None:
pass
# if self._rewards is None or body_index == -1:
# return
# self.rewardUpdated.emit(self._rewards[body_index].to_pydict())

@Slot()
def change_cbar(self) -> None:
self._showing_energy = not self._showing_energy
self._cbar_changed = True
2 changes: 1 addition & 1 deletion src/emevo/environments/circle_foraging.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ def visualizer(
figsize: tuple[float, float] | None = None,
backend: str = "pyglet",
**kwargs,
) -> Visualizer:
) -> Visualizer[StateDict]:
"""Create a visualizer for the environment"""
from emevo.environments import moderngl_vis

Expand Down
15 changes: 5 additions & 10 deletions src/emevo/environments/moderngl_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
from __future__ import annotations

from typing import Any, Callable, ClassVar, Protocol
from typing import Any, Callable, ClassVar

import jax.numpy as jnp
import moderngl as mgl
Expand All @@ -16,10 +16,6 @@
from emevo.environments.phyjax2d import Circle, Segment, Space, State, StateDict


class HasStateD(Protocol):
stated: StateDict


NOWHERE: float = -1000.0


Expand Down Expand Up @@ -598,16 +594,15 @@ def get_image(self) -> NDArray:
w, h = self._figsize
return output.reshape(h, w, -1)[::-1]

def overlay(self, name: str, value: Any) -> None:
def overlay(self, name: str, value: Any) -> Any:
self._renderer.overlay(name, value)

def render(self, state: HasStateD) -> None:
def render(self, state: StateDict) -> None:
self._window.clear(1.0, 1.0, 1.0)
self._window.use()
self._renderer.render(stated=state.stated)
self._renderer.render(stated=state)

def show(self, *args, **kwargs) -> None:
del args, kwargs
def show(self) -> None:
self._window.swap_buffers()


Expand Down
Loading

0 comments on commit 6642ded

Please sign in to comment.