diff --git a/src/emevo/analysis/qt_widget.py b/src/emevo/analysis/qt_widget.py index 02dfb37..e73d1d2 100644 --- a/src/emevo/analysis/qt_widget.py +++ b/src/emevo/analysis/qt_widget.py @@ -51,7 +51,7 @@ def _mgl_qsurface_fmt() -> QSurfaceFormat: return fmt -N_MAX_SCAN: int = 20480 +N_MAX_SCAN: int = 10000 N_MAX_CACHED_LOG: int = 100 @@ -87,6 +87,7 @@ def __init__( x_range, y_range = env._x_range, env._y_range self._scaling = x_range / figsize[0], y_range / figsize[1] self._phys_state = saved_physics + self._index = start self._make_renderer = partial( MglRenderer, screen_width=self._figsize[0], @@ -94,13 +95,12 @@ def __init__( x_range=x_range, y_range=y_range, space=env._physics, - stated=self._get_stated(0), + stated=self._get_stated(), sensor_fn=env._get_sensors, sc_color_opt=env._food_color, ) self._env = env self._get_colors = get_colors - self._index = start self._end_index = self._phys_state.circle_axy.shape[0] if end is None else end self._paused = False self._initialized = False @@ -121,8 +121,8 @@ def _scale_position(self, position: QPointF) -> tuple[float, float]: (self._figsize[1] - position.y()) * self._scaling[1], ) - def _get_stated(self, index: int) -> StateDict: - return self._phys_state.set_by_index(index, self._env_state.physics) + def _get_stated(self) -> StateDict: + return self._phys_state.set_by_index(self._index, self._env_state.physics) def _set_default_viewport(self) -> None: self._ctx.viewport = 0, 0, *self._figsize @@ -149,7 +149,7 @@ def paintGL(self) -> None: if not self._paused and self._index < self._end_index - 1: self._index += 1 self.stepChanged.emit(self._index) - stated = self._get_stated(self._index) + stated = self._get_stated() if self._get_colors is None: circle_colors = None else: @@ -163,7 +163,7 @@ def exitable(self) -> bool: def mousePressEvent(self, evt: QMouseEvent) -> None: # type: ignore position = self._scale_position(evt.position()) - circle = self._get_stated(self._index).circle + circle = self._get_stated().circle overlap = _overlap( jnp.array(position), self._env._physics.shaped.circle, @@ -347,6 +347,8 @@ def __init__( pause_button.clicked.connect(self._mgl_widget.pause) play_button = QtWidgets.QPushButton("▶️") play_button.clicked.connect(self._mgl_widget.play) + export_button = QtWidgets.QPushButton("📤") + export_button.clicked.connect(self._exportData) # Colorbar radiobutton_1 = QtWidgets.QRadioButton("Energy") radiobutton_2 = QtWidgets.QRadioButton("Num. Children") @@ -571,6 +573,19 @@ def cbarFood2(self, checked: bool) -> None: self._cbar_state = CBarState.FOOD_REWARD2 self._cbar_changed = True + @Slot() + def exportData(self) -> None: + stated = self._mgl_widget.get_stated() + # TODO: filename + np.savez_compressed( + "exported.npz", + circle_axy=np.array(stated.circle.p.into_axy()), + circle_is_active=np.array(stated.circle.is_active), + static_circle_axy=np.array(stated.static_circle.p.into_axy()), + static_circle_is_active=np.array(stated.static_circle.is_active), + static_circle_label=np.array(stated.static_circle.label), + ) + def start_widget(widget_cls: type[QtWidgets.QWidget], **kwargs) -> None: app = QtWidgets.QApplication([])