Skip to content

Commit

Permalink
Add action reward selector to widget
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Oct 22, 2024
1 parent 37e9f10 commit 1d4a157
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 18 deletions.
77 changes: 60 additions & 17 deletions src/emevo/analysis/qt_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
QValueAxis,
)
from PySide6.QtCore import QPointF, Qt, QTimer, Signal, Slot
from PySide6.QtGui import QGuiApplication, QMouseEvent, QPainter, QSurfaceFormat
from PySide6.QtGui import (
QColorSpace,
QGuiApplication,
QMouseEvent,
QPainter,
QSurfaceFormat,
)
from PySide6.QtOpenGLWidgets import QOpenGLWidget

from emevo.environments.circle_foraging import CircleForaging
Expand All @@ -49,6 +55,7 @@ def _mgl_qsurface_fmt() -> QSurfaceFormat:
fmt.setRenderableType(QSurfaceFormat.OpenGL)
fmt.setProfile(QSurfaceFormat.OpenGLContextProfile.CoreProfile)
fmt.setSwapBehavior(QSurfaceFormat.SwapBehavior.DoubleBuffer)
fmt.setColorSpace(QColorSpace.SRgb)
return fmt


Expand Down Expand Up @@ -97,8 +104,9 @@ def __init__(
y_range=y_range,
space=env._physics,
stated=self._get_stated(),
sensor_fn=env._get_sensors,
sc_color_opt=env._food_color,
sensor_color=np.array([0.0, 0.0, 0.0, 0.2], dtype=np.float32),
sensor_fn=self._get_selected_sensors,
)
self._env = env
self._get_colors = get_colors
Expand All @@ -116,6 +124,20 @@ def __init__(
self.setFixedSize(*self._figsize)
self.setMouseTracking(True)
self._ctx, self._fbo = None, None
self._selected = None

def _get_selected_sensors(self, *args, **kwargs) -> tuple[jax.Array, jax.Array]:
    """Return sensor segment endpoints, keeping only the selected agent's sensors.

    All arguments are forwarded unchanged to ``self._env._get_sensors``, which
    yields the two endpoint arrays ``(p1, p2)``. Rows outside the selected
    agent's slice ``[i * n_sensors, (i + 1) * n_sensors)`` are zeroed in both
    arrays, so those segments have identical start and end points. When no
    agent is selected, every row is zeroed.

    Returns:
        A ``(p1, p2)`` pair with the same shapes as the environment's output.
    """
    p1, p2 = self._env._get_sensors(*args, **kwargs)
    hidden_p1 = jnp.zeros_like(p1)
    hidden_p2 = jnp.zeros_like(p2)
    if self._selected is None:
        return hidden_p1, hidden_p2
    # The click handler stores every index returned by its hit test, so more
    # than one agent may be present; use the first one, matching the index
    # that is emitted through `selectionChanged`.
    i = int(jnp.ravel(jnp.asarray(self._selected))[0])
    from_ = i * self._env._n_sensors
    to = (i + 1) * self._env._n_sensors
    # Overwrite (not add to) the placeholder rows so that the selected agent's
    # sensor coordinates are passed through unmodified.
    p1 = hidden_p1.at[from_:to].set(p1[from_:to])
    p2 = hidden_p2.at[from_:to].set(p2[from_:to])
    return p1, p2

def _scale_position(self, position: QPointF) -> tuple[float, float]:
return (
Expand Down Expand Up @@ -173,6 +195,7 @@ def mousePressEvent(self, evt: QMouseEvent) -> None: # type: ignore
)
(selected,) = jnp.nonzero(overlap)
if 0 < selected.shape[0]:
self._selected = selected
self.selectionChanged.emit(selected[0].item(), self._index)

@Slot()
Expand Down Expand Up @@ -287,6 +310,7 @@ class CBarState(str, enum.Enum):
ENERGY = "energy"
N_CHILDREN = "n-children"
FOOD_REWARD = "food-reward"
ACTION_REWARD = "action-reward"
FOOD_REWARD2 = "food-reward2" # Poison or poor foods


Expand Down Expand Up @@ -352,12 +376,14 @@ def __init__(
radiobutton_1 = QtWidgets.QRadioButton("Energy")
radiobutton_2 = QtWidgets.QRadioButton("Num. Children")
radiobutton_3 = QtWidgets.QRadioButton("Food Reward")
radiobutton_4 = QtWidgets.QRadioButton("Another Food Reward")
radiobutton_4 = QtWidgets.QRadioButton("Action Reward")
radiobutton_5 = QtWidgets.QRadioButton("Another Food Reward")
radiobutton_1.setChecked(True)
radiobutton_1.toggled.connect(self.cbarEnergy)
radiobutton_2.toggled.connect(self.cbarNChildren)
radiobutton_3.toggled.connect(self.cbarFood)
radiobutton_4.toggled.connect(self.cbarFood2)
radiobutton_4.toggled.connect(self.cbarAction)
radiobutton_5.toggled.connect(self.cbarFood2)
self._cbar_state = CBarState.ENERGY
self._cbar_renderer = CBarRenderer(int(xlim * 2), int(ylim * 0.4))
self._showing_energy = True
Expand All @@ -367,6 +393,7 @@ def __init__(
self._energy_cm = mpl.colormaps["YlGnBu"]
self._n_children_cm = mpl.colormaps["PuBuGn"]
self._food_cm = mpl.colormaps["plasma"]
self.__cm = mpl.colormaps["plasma"]
self._norm = mc.Normalize(vmin=0.0, vmax=1.0)
self._cm_fixed_minmax = {} if cm_fixed_minmax is None else cm_fixed_minmax
if profile_and_rewards is not None:
Expand All @@ -381,11 +408,12 @@ def __init__(
left_control.addLayout(buttons)
left_control.addWidget(self._slider_label)
left_control.addWidget(self._slider)
cbar_selector = QtWidgets.QVBoxLayout()
cbar_selector.addWidget(radiobutton_1)
cbar_selector.addWidget(radiobutton_2)
cbar_selector.addWidget(radiobutton_3)
cbar_selector.addWidget(radiobutton_4)
cbar_selector = QtWidgets.QGridLayout()
cbar_selector.addWidget(radiobutton_1, 0, 0)
cbar_selector.addWidget(radiobutton_2, 1, 0)
cbar_selector.addWidget(radiobutton_3, 0, 1)
cbar_selector.addWidget(radiobutton_4, 1, 1)
cbar_selector.addWidget(radiobutton_5, 2, 1)
control = QtWidgets.QHBoxLayout()
control.addLayout(left_control)
control.addLayout(cbar_selector)
Expand All @@ -411,9 +439,9 @@ def __init__(
self.rewardUpdated.connect(self._reward_widget.updateValues)
# Initial size
if profile_and_rewards is None:
self.resize(xlim * scale * 1.5, ylim * scale * 1.5)
self.resize(xlim * scale * 1.6, ylim * scale * 1.75)
else:
self.resize(xlim * scale * 1.5, ylim * scale * 1.2)
self.resize(xlim * scale * 1.6, ylim * scale * 1.4)
self._self_terminate = self_terminate

def _check_exit(self) -> None:
Expand Down Expand Up @@ -487,16 +515,25 @@ def _get_colors(self, step_index: int) -> NDArray:
value = np.zeros(self._n_max_agents)
for slot, uid in zip(log["slots"], log["unique_id"]):
rew = self._get_rewards(uid)
if "scale_food" in rew:
rew_food = rew["w_food"] * (10 ** rew["scale_food"])
elif "food" in rew:
rew_food = rew["food"]
elif "food_1" in rew:
if "food_1" in rew:
rew_food = rew["food_1"]
else:
warnings.warn("Unsupported reward", stacklevel=1)
rew_food = 0.0
value[slot] = rew_food
elif self._cbar_state is CBarState.ACTION_REWARD:
title = "Action Reward"
cm = self._food_cm
value = np.zeros(self._n_max_agents)
for slot, uid in zip(log["slots"], log["unique_id"]):
rew = self._get_rewards(uid)
if "action" in rew:
rew_food = rew["action"]
else:
print(rew)
warnings.warn("Unsupported reward", stacklevel=1)
rew_food = 0.0
value[slot] = rew_food
elif self._cbar_state is CBarState.FOOD_REWARD2:
title = "Food Reward"
cm = self._food_cm
Expand Down Expand Up @@ -544,7 +581,7 @@ def updateRewards(self, selected_slot: int, step_index: int) -> None:
for slot, uid in zip(log["slots"], log["unique_id"]):
if slot == selected_slot:
self.rewardUpdated.emit(
f"Reward function of {uid}",
f"Reward function of {uid} (slot: {slot})",
self._get_rewards(uid),
)
return
Expand All @@ -567,6 +604,12 @@ def cbarFood(self, checked: bool) -> None:
self._cbar_state = CBarState.FOOD_REWARD
self._cbar_changed = True

# Qt slot connected to the `toggled` signal of the "Action Reward" radio button.
# When `checked` is true, switches the colorbar state to CBarState.ACTION_REWARD.
# `_cbar_changed` is also set to True.
# NOTE(review): indentation is not visible in this view, so it cannot be confirmed
# here whether `_cbar_changed` is set only when `checked` is true -- verify against
# the sibling cbar* slots.
@Slot(bool)
def cbarAction(self, checked: bool) -> None:
if checked:
self._cbar_state = CBarState.ACTION_REWARD
self._cbar_changed = True

@Slot(bool)
def cbarFood2(self, checked: bool) -> None:
if checked:
Expand Down
9 changes: 8 additions & 1 deletion src/emevo/environments/moderngl_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def __init__(
voffsets: tuple[int, ...] = (),
hoffsets: tuple[int, ...] = (),
sc_color_opt: NDArray | None = None,
sensor_color: NDArray | None = None,
sensor_fn: Callable[[StateDict], tuple[NDArray, NDArray]] | None = None,
) -> None:
self._context = context
Expand Down Expand Up @@ -332,7 +333,11 @@ def __init__(
vertex_shader=_LINE_VERTEX_SHADER,
geometry_shader=_LINE_GEOMETRY_SHADER,
fragment_shader=_LINE_FRAGMENT_SHADER,
color=np.array([0.0, 0.0, 0.0, 0.1], dtype=np.float32),
color=(
np.array([0.0, 0.0, 0.0, 0.1], dtype=np.float32)
if sensor_color is None
else sensor_color
),
width=np.array([0.001], dtype=np.float32),
)

Expand Down Expand Up @@ -461,6 +466,7 @@ def __init__(
space: Space,
stated: StateDict,
food_color: NDArray,
sensor_color: NDArray | None = None,
figsize: tuple[float, float] | None = None,
voffsets: tuple[int, ...] = (),
hoffsets: tuple[int, ...] = (),
Expand Down Expand Up @@ -493,6 +499,7 @@ def __init__(
voffsets=voffsets,
hoffsets=hoffsets,
sc_color_opt=food_color,
sensor_color=sensor_color,
sensor_fn=sensor_fn,
)

Expand Down

0 comments on commit 1d4a157

Please sign in to comment.