Skip to content

Commit

Permalink
Merge pull request #850 from int-brain-lab/qc_viewer
Browse files Browse the repository at this point in the history
QC viewer
  • Loading branch information
bimac authored Dec 17, 2024
2 parents 4d94d8c + ad682fe commit 42c5b0b
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 126 deletions.
1 change: 1 addition & 0 deletions ibllib/qc/task_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,7 @@ def check_iti_delays(data, subtract_pauses=False, iti_delay_secs=ITI_DELAY_SECS,
numpy.array
An array of boolean values, 1 per trial, where True means trial passes QC threshold.
"""
# Initialize array the length of completed trials
metric = np.full(data['intervals'].shape[0], np.nan)
passed = metric.copy()
pauses = (data['pause_duration'] if subtract_pauses else np.zeros_like(metric))[:-1]
Expand Down
317 changes: 194 additions & 123 deletions ibllib/qc/task_qc_viewer/ViewEphysQC.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
"""An interactive PyQT QC data frame."""

import logging

from PyQt5 import QtCore, QtWidgets
from PyQt5 import QtWidgets
from PyQt5.QtCore import (
Qt,
QModelIndex,
pyqtSignal,
pyqtSlot,
QCoreApplication,
QSettings,
QSize,
QPoint,
)
from PyQt5.QtGui import QPalette, QShowEvent
from PyQt5.QtWidgets import QMenu, QAction
from iblqt.core import ColoredDataFrameTableModel
from matplotlib.figure import Figure
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT
import pandas as pd
Expand All @@ -12,101 +26,17 @@
_logger = logging.getLogger(__name__)


class DataFrameModel(QtCore.QAbstractTableModel):
DtypeRole = QtCore.Qt.UserRole + 1000
ValueRole = QtCore.Qt.UserRole + 1001

def __init__(self, df=pd.DataFrame(), parent=None):
super(DataFrameModel, self).__init__(parent)
self._dataframe = df

def setDataFrame(self, dataframe):
self.beginResetModel()
self._dataframe = dataframe.copy()
self.endResetModel()

def dataFrame(self):
return self._dataframe

dataFrame = QtCore.pyqtProperty(pd.DataFrame, fget=dataFrame, fset=setDataFrame)

@QtCore.pyqtSlot(int, QtCore.Qt.Orientation, result=str)
def headerData(self, section: int, orientation: QtCore.Qt.Orientation,
role: int = QtCore.Qt.DisplayRole):
if role == QtCore.Qt.DisplayRole:
if orientation == QtCore.Qt.Horizontal:
return self._dataframe.columns[section]
else:
return str(self._dataframe.index[section])
return QtCore.QVariant()

def rowCount(self, parent=QtCore.QModelIndex()):
if parent.isValid():
return 0
return len(self._dataframe.index)

def columnCount(self, parent=QtCore.QModelIndex()):
if parent.isValid():
return 0
return self._dataframe.columns.size

def data(self, index, role=QtCore.Qt.DisplayRole):
if (not index.isValid() or not (0 <= index.row() < self.rowCount() and
0 <= index.column() < self.columnCount())):
return QtCore.QVariant()
row = self._dataframe.index[index.row()]
col = self._dataframe.columns[index.column()]
dt = self._dataframe[col].dtype

val = self._dataframe.iloc[row][col]
if role == QtCore.Qt.DisplayRole:
return str(val)
elif role == DataFrameModel.ValueRole:
return val
if role == DataFrameModel.DtypeRole:
return dt
return QtCore.QVariant()

def roleNames(self):
roles = {
QtCore.Qt.DisplayRole: b'display',
DataFrameModel.DtypeRole: b'dtype',
DataFrameModel.ValueRole: b'value'
}
return roles

def sort(self, col, order):
"""
Sort table by given column number.
:param col: the column number selected (between 0 and self._dataframe.columns.size)
:param order: the order to be sorted, 0 is descending; 1, ascending
:return:
"""
self.layoutAboutToBeChanged.emit()
col_name = self._dataframe.columns.values[col]
# print('sorting by ' + col_name)
self._dataframe.sort_values(by=col_name, ascending=not order, inplace=True)
self._dataframe.reset_index(inplace=True, drop=True)
self.layoutChanged.emit()


class PlotCanvas(FigureCanvasQTAgg):

def __init__(self, parent=None, width=5, height=4, dpi=100, wheel=None):
fig = Figure(figsize=(width, height), dpi=dpi)

FigureCanvasQTAgg.__init__(self, fig)
self.setParent(parent)

FigureCanvasQTAgg.setSizePolicy(
self,
QtWidgets.QSizePolicy.Expanding,
QtWidgets.QSizePolicy.Expanding)
FigureCanvasQTAgg.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
FigureCanvasQTAgg.updateGeometry(self)
if wheel:
self.ax, self.ax2 = fig.subplots(
2, 1, gridspec_kw={'height_ratios': [2, 1]}, sharex=True)
self.ax, self.ax2 = fig.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, sharex=True)
else:
self.ax = fig.add_subplot(111)
self.draw()
Expand All @@ -116,69 +46,210 @@ class PlotWindow(QtWidgets.QWidget):
def __init__(self, parent=None, wheel=None):
QtWidgets.QWidget.__init__(self, parent=None)
self.canvas = PlotCanvas(wheel=wheel)
self.vbl = QtWidgets.QVBoxLayout() # Set box for plotting
self.vbl = QtWidgets.QVBoxLayout() # Set box for plotting
self.vbl.addWidget(self.canvas)
self.setLayout(self.vbl)
self.vbl.addWidget(NavigationToolbar2QT(self.canvas, self))


class GraphWindow(QtWidgets.QWidget):
_pinnedColumns = []

def __init__(self, parent=None, wheel=None):
QtWidgets.QWidget.__init__(self, parent=parent)
vLayout = QtWidgets.QVBoxLayout(self)

self.columnPinned = pyqtSignal(int, bool)

# load button
self.pushButtonLoad = QtWidgets.QPushButton('Select File', self)
self.pushButtonLoad.clicked.connect(self.loadFile)

# define table model & view
self.tableModel = ColoredDataFrameTableModel(self)
self.tableView = QtWidgets.QTableView(self)
self.tableView.setModel(self.tableModel)
self.tableView.setSortingEnabled(True)
self.tableView.horizontalHeader().setDefaultAlignment(Qt.AlignLeft | Qt.AlignVCenter)
self.tableView.horizontalHeader().setSectionsMovable(True)
self.tableView.horizontalHeader().setContextMenuPolicy(Qt.CustomContextMenu)
self.tableView.horizontalHeader().customContextMenuRequested.connect(self.contextMenu)
self.tableView.verticalHeader().hide()
self.tableView.doubleClicked.connect(self.tv_double_clicked)

# define colors for highlighted cells
p = self.tableView.palette()
p.setColor(QPalette.Highlight, Qt.black)
p.setColor(QPalette.HighlightedText, Qt.white)
self.tableView.setPalette(p)

# QAction for pinning columns
self.pinAction = QAction('Pin column', self)
self.pinAction.setCheckable(True)
self.pinAction.toggled.connect(self.pinColumn)

# Filter columns by name
self.lineEditFilter = QtWidgets.QLineEdit(self)
self.lineEditFilter.setPlaceholderText('Filter columns')
self.lineEditFilter.textChanged.connect(self.changeFilter)
self.lineEditFilter.setMinimumWidth(200)

# colormap picker
self.comboboxColormap = QtWidgets.QComboBox(self)
colormaps = {self.tableModel.colormap, 'inferno', 'magma', 'plasma', 'summer'}
self.comboboxColormap.addItems(sorted(list(colormaps)))
self.comboboxColormap.setCurrentText(self.tableModel.colormap)
self.comboboxColormap.currentTextChanged.connect(self.tableModel.setColormap)

# slider for alpha values
self.sliderAlpha = QtWidgets.QSlider(Qt.Horizontal, self)
self.sliderAlpha.setMaximumWidth(100)
self.sliderAlpha.setMinimum(0)
self.sliderAlpha.setMaximum(255)
self.sliderAlpha.setValue(self.tableModel.alpha)
self.sliderAlpha.valueChanged.connect(self.tableModel.setAlpha)

# Horizontal layout
hLayout = QtWidgets.QHBoxLayout()
self.pathLE = QtWidgets.QLineEdit(self)
hLayout.addWidget(self.pathLE)
self.loadBtn = QtWidgets.QPushButton("Select File", self)
hLayout.addWidget(self.loadBtn)
hLayout.addWidget(self.lineEditFilter)
hLayout.addSpacing(50)
hLayout.addWidget(QtWidgets.QLabel('Colormap', self))
hLayout.addWidget(self.comboboxColormap)
hLayout.addWidget(QtWidgets.QLabel('Alpha', self))
hLayout.addWidget(self.sliderAlpha)
hLayout.addSpacing(50)
hLayout.addWidget(self.pushButtonLoad)

# Vertical layout
vLayout = QtWidgets.QVBoxLayout(self)
vLayout.addLayout(hLayout)
self.pandasTv = QtWidgets.QTableView(self)
vLayout.addWidget(self.pandasTv)
self.loadBtn.clicked.connect(self.load_file)
self.pandasTv.setSortingEnabled(True)
self.pandasTv.doubleClicked.connect(self.tv_double_clicked)
vLayout.addWidget(self.tableView)

# Recover layout from QSettings
self.settings = QSettings()
self.settings.beginGroup('MainWindow')
self.resize(self.settings.value('size', QSize(800, 600), QSize))
self.comboboxColormap.setCurrentText(self.settings.value('colormap', 'plasma', str))
self.sliderAlpha.setValue(self.settings.value('alpha', 255, int))
self.settings.endGroup()

self.wplot = PlotWindow(wheel=wheel)
self.wplot.show()
self.tableModel.dataChanged.connect(self.wplot.canvas.draw)

self.wheel = wheel

def load_file(self):
fileName, _ = QtWidgets.QFileDialog.getOpenFileName(
self, "Open File", "", "CSV Files (*.csv)")
self.pathLE.setText(fileName)
def closeEvent(self, _) -> bool:
self.settings.beginGroup('MainWindow')
self.settings.setValue('size', self.size())
self.settings.setValue('colormap', self.tableModel.colormap)
self.settings.setValue('alpha', self.tableModel.alpha)
self.settings.endGroup()
self.wplot.close()

def showEvent(self, a0: QShowEvent) -> None:
super().showEvent(a0)
self.activateWindow()

def contextMenu(self, pos: QPoint):
idx = self.sender().logicalIndexAt(pos)
action = self.pinAction
action.setData(idx)
action.setChecked(idx in self._pinnedColumns)
menu = QMenu(self)
menu.addAction(action)
menu.exec(self.sender().mapToGlobal(pos))

@pyqtSlot(bool)
@pyqtSlot(bool, int)
def pinColumn(self, pin: bool, idx: int | None = None):
idx = idx if idx is not None else self.sender().data()
if not pin and idx in self._pinnedColumns:
self._pinnedColumns.remove(idx)
if pin and idx not in self._pinnedColumns:
self._pinnedColumns.append(idx)
self.changeFilter(self.lineEditFilter.text())

def changeFilter(self, string: str):
headers = [
self.tableModel.headerData(x, Qt.Horizontal, Qt.DisplayRole).lower()
for x in range(self.tableModel.columnCount())
]
tokens = [y.lower() for y in (x.strip() for x in string.split(',')) if len(y)]
showAll = len(tokens) == 0
for idx, column in enumerate(headers):
show = showAll or any((t in column for t in tokens)) or idx in self._pinnedColumns
self.tableView.setColumnHidden(idx, not show)

def loadFile(self):
fileName, _ = QtWidgets.QFileDialog.getOpenFileName(self, 'Open File', '', 'CSV Files (*.csv)')
if len(fileName) == 0:
return
df = pd.read_csv(fileName)
self.update_df(df)

def update_df(self, df):
model = DataFrameModel(df)
self.pandasTv.setModel(model)
self.wplot.canvas.draw()

def tv_double_clicked(self):
df = self.pandasTv.model()._dataframe
ind = self.pandasTv.currentIndex()
start = df.loc[ind.row()]['intervals_0']
finish = df.loc[ind.row()]['intervals_1']
dt = finish - start
self.updateDataframe(df)

def updateDataframe(self, df: pd.DataFrame):
# clear pinned columns
self._pinnedColumns = []

# try to identify and sort columns containing timestamps
col_names = df.select_dtypes('number').columns
df_interp = df[col_names].replace([-np.inf, np.inf], np.nan)
df_interp = df_interp.interpolate(limit_direction='both')
cols_mono = col_names[[df_interp[c].is_monotonic_increasing for c in col_names]]
cols_mono = [c for c in cols_mono if df[c].nunique() > 1]
cols_mono = df_interp[cols_mono].mean().sort_values().keys()
for idx, col_name in enumerate(cols_mono):
df.insert(idx, col_name, df.pop(col_name))

# columns containing boolean values are sorted to the end
# of those, columns containing 'pass' in their title will be sorted by number of False values
col_names = df.columns
cols_bool = list(df.select_dtypes(['bool', 'boolean']).columns)
cols_pass = [c for c in cols_bool if 'pass' in c]
cols_bool = [c for c in cols_bool if c not in cols_pass] # I know. Friday evening, brain is fried ... sorry.
cols_pass = list((~df[cols_pass]).sum().sort_values().keys())
cols_bool += cols_pass
for col_name in cols_bool:
df = df.join(df.pop(col_name))

# trial_no should always be the first column
if 'trial_no' in col_names:
df.insert(0, 'trial_no', df.pop('trial_no'))

# define columns that should be pinned by default
for col in ['trial_no']:
self._pinnedColumns.append(df.columns.get_loc(col))

self.tableModel.setDataFrame(df)

def tv_double_clicked(self, index: QModelIndex):
data = self.tableModel.dataFrame.iloc[index.row()]
t0 = data['intervals_0']
t1 = data['intervals_1']
dt = t1 - t0
if self.wheel:
idx = np.searchsorted(
self.wheel['re_ts'], np.array([start - dt / 10, finish + dt / 10]))
idx = np.searchsorted(self.wheel['re_ts'], np.array([t0 - dt / 10, t1 + dt / 10]))
period = self.wheel['re_pos'][idx[0]:idx[1]]
if period.size == 0:
_logger.warning('No wheel data during trial #%i', ind.row())
_logger.warning('No wheel data during trial #%i', index.row())
else:
min_val, max_val = np.min(period), np.max(period)
self.wplot.canvas.ax2.set_ylim(min_val - 1, max_val + 1)
self.wplot.canvas.ax2.set_xlim(start - dt / 10, finish + dt / 10)
self.wplot.canvas.ax.set_xlim(start - dt / 10, finish + dt / 10)

self.wplot.canvas.ax2.set_xlim(t0 - dt / 10, t1 + dt / 10)
self.wplot.canvas.ax.set_xlim(t0 - dt / 10, t1 + dt / 10)
self.wplot.setWindowTitle(f"Trial {data.get('trial_no', '?')}")
self.wplot.canvas.draw()


def viewqc(qc=None, title=None, wheel=None):
qt.create_app()
app = qt.create_app()
app.setStyle('Fusion')
QCoreApplication.setOrganizationName('International Brain Laboratory')
QCoreApplication.setOrganizationDomain('internationalbrainlab.org')
QCoreApplication.setApplicationName('QC Viewer')
qcw = GraphWindow(wheel=wheel)
qcw.setWindowTitle(title)
if qc is not None:
qcw.update_df(qc)
qcw.updateDataframe(qc)
qcw.show()
return qcw
Loading

0 comments on commit 42c5b0b

Please sign in to comment.