Skip to content

Commit

Permalink
Tweak on widget
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Jan 12, 2024
1 parent cfb5ae6 commit a3edbce
Showing 1 changed file with 30 additions and 31 deletions.
61 changes: 30 additions & 31 deletions src/emevo/environments/qt_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from PySide6.QtCore import Qt, QTimer, Signal, Slot
from PySide6.QtGui import QGuiApplication, QMouseEvent, QPainter, QSurfaceFormat
from PySide6.QtOpenGLWidgets import QOpenGLWidget
from PySide6.QtWidgets import QGridLayout, QWidget
from PySide6 import QtWidgets

from emevo.environments.moderngl_vis import MglRenderer
from emevo.environments.phyjax2d import Space, StateDict
Expand All @@ -45,7 +45,7 @@ class AppState:
paused_before: bool = False


class QtVisualizer(QOpenGLWidget):
class MglWidget(QOpenGLWidget):
selectionChanged = Signal(int)

def __init__(
Expand All @@ -55,9 +55,9 @@ def __init__(
y_range: float,
space: Space,
stated: StateDict,
figsize: tuple[float, float] | None = None,
figsize: tuple[float, float],
sensor_fn: Callable[[StateDict], tuple[NDArray, NDArray]] | None = None,
parent: QWidget | None = None,
parent: QtWidgets.QWidget | None = None,
) -> None:
# Set default format
QSurfaceFormat.setDefaultFormat(_mgl_qsurface_fmt())
Expand Down Expand Up @@ -107,7 +107,7 @@ def paintGL(self) -> None:
def render(self, stated: StateDict) -> None:
self._fbo.use()
self._ctx.clear(1.0, 1.0, 1.0)
self._renderer.render(self._env) # type: ignore
self._renderer.render(stated) # type: ignore

def show(self, timer: QTimer):
self._timer = timer
Expand All @@ -121,30 +121,25 @@ def _emit_selected(self, index: int | None) -> None:

def mousePressEvent(self, evt: QMouseEvent) -> None:
position = self._scale_position(evt.position())
query = self._env.get_space().point_query(
position,
0.0,
shape_filter=make_filter(CollisionType.AGENT, CollisionType.FOOD),
)
if len(query) == 1:
shape = query[0].shape
if shape is not None:
body_index = self._env.get_body_index(shape.body)
if body_index is not None:
self._state.pantool.start_drag(position, shape, body_index)
self._emit_selected(body_index)
self._paused_before = self._state.paused
self._state.paused = True
self._timer.stop()
self.update()
# query = self._env.get_space().point_query(
# position,
# 0.0,
# shape_filter=make_filter(CollisionType.AGENT, CollisionType.FOOD),
# )
# if len(query) == 1:
# shape = query[0].shape
# if shape is not None:
# body_index = self._env.get_body_index(shape.body)
# if body_index is not None:
# self._state.pantool.start_drag(position, shape, body_index)
# self._emit_selected(body_index)
# self._paused_before = self._state.paused
# self._state.paused = True
# self._timer.stop()
# self.update()

def mouseReleaseEvent(self, evt: QMouseEvent) -> None:
if self._state.pantool.is_dragging:
self._state.pantool.stop_drag(self._scale_position(evt.position()))
self._emit_selected(None)
self._state.paused = self._state.paused_before
self._timer.start()
self.update()
pass

@Slot()
def pause(self) -> None:
Expand All @@ -155,7 +150,7 @@ def play(self) -> None:
self._state.paused = False


class BarChart(QWidget):
class BarChart(QtWidgets.QWidget):
def __init__(
self,
initial_values: dict[str, float | list[float]],
Expand Down Expand Up @@ -200,7 +195,7 @@ def __init__(
self._chart_view.chart().show()
self._chart_view.chart().legend().show()
# create main layout
layout = QGridLayout(self)
layout = QtWidgets.QGridLayout(self)
layout.addWidget(self._chart_view, 1, 1)
self.setLayout(layout)
self.setVisible(True)
Expand All @@ -209,9 +204,11 @@ def _make_barset(self, name: str, value: float | list[float]) -> QBarSet:
barset = QBarSet(name)
if isinstance(value, float):
barset.append(value)
else:
elif isinstance(value, list):
for v in value:
barset.append(v)
else:
raise ValueError(f"Invalid value for barset: {value}")
self.barsets[name] = barset
self.series.append(barset)
return barset
Expand All @@ -233,9 +230,11 @@ def updateValues(self, values: dict[str, float | list[float]]) -> None:
new_barsets.append(barset)
elif isinstance(value, float):
self.barsets[name].replace(0, value)
else:
elif isinstance(value, list):
for i, vi in enumerate(value):
self.barsets[name].replace(i, vi)
else:
raise ValueError(f"Invalid value for barset {value}")

for name in list(self.barsets.keys()):
if name not in values:
Expand Down

0 comments on commit a3edbce

Please sign in to comment.