diff --git a/src/emevo/environments/qt_vis.py b/src/emevo/environments/qt_vis.py index 7dbbc557..f8a67c2b 100644 --- a/src/emevo/environments/qt_vis.py +++ b/src/emevo/environments/qt_vis.py @@ -22,7 +22,7 @@ from PySide6.QtCore import Qt, QTimer, Signal, Slot from PySide6.QtGui import QGuiApplication, QMouseEvent, QPainter, QSurfaceFormat from PySide6.QtOpenGLWidgets import QOpenGLWidget -from PySide6.QtWidgets import QGridLayout, QWidget +from PySide6 import QtWidgets from emevo.environments.moderngl_vis import MglRenderer from emevo.environments.phyjax2d import Space, StateDict @@ -45,7 +45,7 @@ class AppState: paused_before: bool = False -class QtVisualizer(QOpenGLWidget): +class MglWidget(QOpenGLWidget): selectionChanged = Signal(int) def __init__( @@ -55,9 +55,9 @@ def __init__( y_range: float, space: Space, stated: StateDict, - figsize: tuple[float, float] | None = None, + figsize: tuple[float, float], sensor_fn: Callable[[StateDict], tuple[NDArray, NDArray]] | None = None, - parent: QWidget | None = None, + parent: QtWidgets.QWidget | None = None, ) -> None: # Set default format QSurfaceFormat.setDefaultFormat(_mgl_qsurface_fmt()) @@ -107,7 +107,7 @@ def paintGL(self) -> None: def render(self, stated: StateDict) -> None: self._fbo.use() self._ctx.clear(1.0, 1.0, 1.0) - self._renderer.render(self._env) # type: ignore + self._renderer.render(stated) # type: ignore def show(self, timer: QTimer): self._timer = timer @@ -121,30 +121,25 @@ def _emit_selected(self, index: int | None) -> None: def mousePressEvent(self, evt: QMouseEvent) -> None: position = self._scale_position(evt.position()) - query = self._env.get_space().point_query( - position, - 0.0, - shape_filter=make_filter(CollisionType.AGENT, CollisionType.FOOD), - ) - if len(query) == 1: - shape = query[0].shape - if shape is not None: - body_index = self._env.get_body_index(shape.body) - if body_index is not None: - self._state.pantool.start_drag(position, shape, body_index) - self._emit_selected(body_index) - self._paused_before = self._state.paused - self._state.paused = True - self._timer.stop() - self.update() + # query = self._env.get_space().point_query( + # position, + # 0.0, + # shape_filter=make_filter(CollisionType.AGENT, CollisionType.FOOD), + # ) + # if len(query) == 1: + # shape = query[0].shape + # if shape is not None: + # body_index = self._env.get_body_index(shape.body) + # if body_index is not None: + # self._state.pantool.start_drag(position, shape, body_index) + # self._emit_selected(body_index) + # self._paused_before = self._state.paused + # self._state.paused = True + # self._timer.stop() + # self.update() def mouseReleaseEvent(self, evt: QMouseEvent) -> None: - if self._state.pantool.is_dragging: - self._state.pantool.stop_drag(self._scale_position(evt.position())) - self._emit_selected(None) - self._state.paused = self._state.paused_before - self._timer.start() - self.update() + pass @Slot() def pause(self) -> None: @@ -155,7 +150,7 @@ def play(self) -> None: self._state.paused = False -class BarChart(QWidget): +class BarChart(QtWidgets.QWidget): def __init__( self, initial_values: dict[str, float | list[float]], @@ -200,7 +195,7 @@ def __init__( self._chart_view.chart().show() self._chart_view.chart().legend().show() # create main layout - layout = QGridLayout(self) + layout = QtWidgets.QGridLayout(self) layout.addWidget(self._chart_view, 1, 1) self.setLayout(layout) self.setVisible(True) @@ -209,9 +204,11 @@ def _make_barset(self, name: str, value: float | list[float]) -> QBarSet: barset = QBarSet(name) if isinstance(value, float): barset.append(value) - else: + elif isinstance(value, list): for v in value: barset.append(v) + else: + raise ValueError(f"Invalid value for barset: {value}") self.barsets[name] = barset self.series.append(barset) return barset @@ -233,9 +230,11 @@ def updateValues(self, values: dict[str, float | list[float]]) -> None: new_barsets.append(barset) elif isinstance(value, float): self.barsets[name].replace(0, value) - else: + elif isinstance(value, list): for i, vi in enumerate(value): self.barsets[name].replace(i, vi) + else: + raise ValueError(f"Invalid value for barset {value}") for name in list(self.barsets.keys()): if name not in values: