Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Oct 10, 2024
2 parents 9448a94 + b50d561 commit 8b40b8a
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions src/emevo/analysis/qt_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _mgl_qsurface_fmt() -> QSurfaceFormat:
return fmt


N_MAX_SCAN: int = 20480
N_MAX_SCAN: int = 10000
N_MAX_CACHED_LOG: int = 100


Expand Down Expand Up @@ -87,20 +87,20 @@ def __init__(
x_range, y_range = env._x_range, env._y_range
self._scaling = x_range / figsize[0], y_range / figsize[1]
self._phys_state = saved_physics
self._index = start
self._make_renderer = partial(
MglRenderer,
screen_width=self._figsize[0],
screen_height=self._figsize[1],
x_range=x_range,
y_range=y_range,
space=env._physics,
stated=self._get_stated(0),
stated=self._get_stated(),
sensor_fn=env._get_sensors,
sc_color_opt=env._food_color,
)
self._env = env
self._get_colors = get_colors
self._index = start
self._end_index = self._phys_state.circle_axy.shape[0] if end is None else end
self._paused = False
self._initialized = False
Expand All @@ -121,8 +121,8 @@ def _scale_position(self, position: QPointF) -> tuple[float, float]:
(self._figsize[1] - position.y()) * self._scaling[1],
)

def _get_stated(self, index: int) -> StateDict:
return self._phys_state.set_by_index(index, self._env_state.physics)
def _get_stated(self) -> StateDict:
return self._phys_state.set_by_index(self._index, self._env_state.physics)

def _set_default_viewport(self) -> None:
self._ctx.viewport = 0, 0, *self._figsize
Expand All @@ -149,7 +149,7 @@ def paintGL(self) -> None:
if not self._paused and self._index < self._end_index - 1:
self._index += 1
self.stepChanged.emit(self._index)
stated = self._get_stated(self._index)
stated = self._get_stated()
if self._get_colors is None:
circle_colors = None
else:
Expand All @@ -163,7 +163,7 @@ def exitable(self) -> bool:

def mousePressEvent(self, evt: QMouseEvent) -> None: # type: ignore
position = self._scale_position(evt.position())
circle = self._get_stated(self._index).circle
circle = self._get_stated().circle
overlap = _overlap(
jnp.array(position),
self._env._physics.shaped.circle,
Expand Down Expand Up @@ -347,6 +347,8 @@ def __init__(
pause_button.clicked.connect(self._mgl_widget.pause)
play_button = QtWidgets.QPushButton("▶️")
play_button.clicked.connect(self._mgl_widget.play)
export_button = QtWidgets.QPushButton("📤")
export_button.clicked.connect(self._exportData)
# Colorbar
radiobutton_1 = QtWidgets.QRadioButton("Energy")
radiobutton_2 = QtWidgets.QRadioButton("Num. Children")
Expand Down Expand Up @@ -571,6 +573,19 @@ def cbarFood2(self, checked: bool) -> None:
self._cbar_state = CBarState.FOOD_REWARD2
self._cbar_changed = True

@Slot()
def exportData(self) -> None:
stated = self._mgl_widget.get_stated()
# TODO: filename
np.savez_compressed(
"exported.npz",
circle_axy=np.array(stated.circle.p.into_axy()),
circle_is_active=np.array(stated.circle.is_active),
static_circle_axy=np.array(stated.static_circle.p.into_axy()),
static_circle_is_active=np.array(stated.static_circle.is_active),
static_circle_label=np.array(stated.static_circle.label),
)


def start_widget(widget_cls: type[QtWidgets.QWidget], **kwargs) -> None:
app = QtWidgets.QApplication([])
Expand Down

0 comments on commit 8b40b8a

Please sign in to comment.