diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index 86f0e4a9b..b5f3cd7e7 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -682,6 +682,7 @@ def check_iti_delays(data, subtract_pauses=False, iti_delay_secs=ITI_DELAY_SECS, numpy.array An array of boolean values, 1 per trial, where True means trial passes QC threshold. """ + # Initialize array the length of completed trials metric = np.full(data['intervals'].shape[0], np.nan) passed = metric.copy() pauses = (data['pause_duration'] if subtract_pauses else np.zeros_like(metric))[:-1] diff --git a/ibllib/qc/task_qc_viewer/ViewEphysQC.py b/ibllib/qc/task_qc_viewer/ViewEphysQC.py index 48155b270..cae7431c2 100644 --- a/ibllib/qc/task_qc_viewer/ViewEphysQC.py +++ b/ibllib/qc/task_qc_viewer/ViewEphysQC.py @@ -1,7 +1,21 @@ """An interactive PyQT QC data frame.""" + import logging -from PyQt5 import QtCore, QtWidgets +from PyQt5 import QtWidgets +from PyQt5.QtCore import ( + Qt, + QModelIndex, + pyqtSignal, + pyqtSlot, + QCoreApplication, + QSettings, + QSize, + QPoint, +) +from PyQt5.QtGui import QPalette, QShowEvent +from PyQt5.QtWidgets import QMenu, QAction +from iblqt.core import ColoredDataFrameTableModel from matplotlib.figure import Figure from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT import pandas as pd @@ -12,101 +26,17 @@ _logger = logging.getLogger(__name__) -class DataFrameModel(QtCore.QAbstractTableModel): - DtypeRole = QtCore.Qt.UserRole + 1000 - ValueRole = QtCore.Qt.UserRole + 1001 - - def __init__(self, df=pd.DataFrame(), parent=None): - super(DataFrameModel, self).__init__(parent) - self._dataframe = df - - def setDataFrame(self, dataframe): - self.beginResetModel() - self._dataframe = dataframe.copy() - self.endResetModel() - - def dataFrame(self): - return self._dataframe - - dataFrame = QtCore.pyqtProperty(pd.DataFrame, fget=dataFrame, fset=setDataFrame) - - @QtCore.pyqtSlot(int, QtCore.Qt.Orientation, result=str) - def headerData(self, section: int, orientation: QtCore.Qt.Orientation, - role: int = QtCore.Qt.DisplayRole): - if role == QtCore.Qt.DisplayRole: - if orientation == QtCore.Qt.Horizontal: - return self._dataframe.columns[section] - else: - return str(self._dataframe.index[section]) - return QtCore.QVariant() - - def rowCount(self, parent=QtCore.QModelIndex()): - if parent.isValid(): - return 0 - return len(self._dataframe.index) - - def columnCount(self, parent=QtCore.QModelIndex()): - if parent.isValid(): - return 0 - return self._dataframe.columns.size - - def data(self, index, role=QtCore.Qt.DisplayRole): - if (not index.isValid() or not (0 <= index.row() < self.rowCount() and - 0 <= index.column() < self.columnCount())): - return QtCore.QVariant() - row = self._dataframe.index[index.row()] - col = self._dataframe.columns[index.column()] - dt = self._dataframe[col].dtype - - val = self._dataframe.iloc[row][col] - if role == QtCore.Qt.DisplayRole: - return str(val) - elif role == DataFrameModel.ValueRole: - return val - if role == DataFrameModel.DtypeRole: - return dt - return QtCore.QVariant() - - def roleNames(self): - roles = { - QtCore.Qt.DisplayRole: b'display', - DataFrameModel.DtypeRole: b'dtype', - DataFrameModel.ValueRole: b'value' - } - return roles - - def sort(self, col, order): - """ - Sort table by given column number. - - :param col: the column number selected (between 0 and self._dataframe.columns.size) - :param order: the order to be sorted, 0 is descending; 1, ascending - :return: - """ - self.layoutAboutToBeChanged.emit() - col_name = self._dataframe.columns.values[col] - # print('sorting by ' + col_name) - self._dataframe.sort_values(by=col_name, ascending=not order, inplace=True) - self._dataframe.reset_index(inplace=True, drop=True) - self.layoutChanged.emit() - - class PlotCanvas(FigureCanvasQTAgg): - def __init__(self, parent=None, width=5, height=4, dpi=100, wheel=None): fig = Figure(figsize=(width, height), dpi=dpi) FigureCanvasQTAgg.__init__(self, fig) self.setParent(parent) - FigureCanvasQTAgg.setSizePolicy( - self, - QtWidgets.QSizePolicy.Expanding, - QtWidgets.QSizePolicy.Expanding) + FigureCanvasQTAgg.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) FigureCanvasQTAgg.updateGeometry(self) if wheel: - self.ax, self.ax2 = fig.subplots( - 2, 1, gridspec_kw={'height_ratios': [2, 1]}, sharex=True) + self.ax, self.ax2 = fig.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, sharex=True) else: self.ax = fig.add_subplot(111) self.draw() @@ -116,69 +46,210 @@ class PlotWindow(QtWidgets.QWidget): def __init__(self, parent=None, wheel=None): QtWidgets.QWidget.__init__(self, parent=None) self.canvas = PlotCanvas(wheel=wheel) - self.vbl = QtWidgets.QVBoxLayout() # Set box for plotting + self.vbl = QtWidgets.QVBoxLayout() # Set box for plotting self.vbl.addWidget(self.canvas) self.setLayout(self.vbl) self.vbl.addWidget(NavigationToolbar2QT(self.canvas, self)) class GraphWindow(QtWidgets.QWidget): + _pinnedColumns = [] + def __init__(self, parent=None, wheel=None): QtWidgets.QWidget.__init__(self, parent=parent) - vLayout = QtWidgets.QVBoxLayout(self) + + self.columnPinned = pyqtSignal(int, bool) + + # load button + self.pushButtonLoad = QtWidgets.QPushButton('Select File', self) + self.pushButtonLoad.clicked.connect(self.loadFile) + + # define table model & view + self.tableModel = ColoredDataFrameTableModel(self) + self.tableView = QtWidgets.QTableView(self) + self.tableView.setModel(self.tableModel) + self.tableView.setSortingEnabled(True) + self.tableView.horizontalHeader().setDefaultAlignment(Qt.AlignLeft | Qt.AlignVCenter) + self.tableView.horizontalHeader().setSectionsMovable(True) + self.tableView.horizontalHeader().setContextMenuPolicy(Qt.CustomContextMenu) + self.tableView.horizontalHeader().customContextMenuRequested.connect(self.contextMenu) + self.tableView.verticalHeader().hide() + self.tableView.doubleClicked.connect(self.tv_double_clicked) + + # define colors for highlighted cells + p = self.tableView.palette() + p.setColor(QPalette.Highlight, Qt.black) + p.setColor(QPalette.HighlightedText, Qt.white) + self.tableView.setPalette(p) + + # QAction for pinning columns + self.pinAction = QAction('Pin column', self) + self.pinAction.setCheckable(True) + self.pinAction.toggled.connect(self.pinColumn) + + # Filter columns by name + self.lineEditFilter = QtWidgets.QLineEdit(self) + self.lineEditFilter.setPlaceholderText('Filter columns') + self.lineEditFilter.textChanged.connect(self.changeFilter) + self.lineEditFilter.setMinimumWidth(200) + + # colormap picker + self.comboboxColormap = QtWidgets.QComboBox(self) + colormaps = {self.tableModel.colormap, 'inferno', 'magma', 'plasma', 'summer'} + self.comboboxColormap.addItems(sorted(list(colormaps))) + self.comboboxColormap.setCurrentText(self.tableModel.colormap) + self.comboboxColormap.currentTextChanged.connect(self.tableModel.setColormap) + + # slider for alpha values + self.sliderAlpha = QtWidgets.QSlider(Qt.Horizontal, self) + self.sliderAlpha.setMaximumWidth(100) + self.sliderAlpha.setMinimum(0) + self.sliderAlpha.setMaximum(255) + self.sliderAlpha.setValue(self.tableModel.alpha) + self.sliderAlpha.valueChanged.connect(self.tableModel.setAlpha) + + # Horizontal layout hLayout = QtWidgets.QHBoxLayout() - self.pathLE = QtWidgets.QLineEdit(self) - hLayout.addWidget(self.pathLE) - self.loadBtn = QtWidgets.QPushButton("Select File", self) - hLayout.addWidget(self.loadBtn) + hLayout.addWidget(self.lineEditFilter) + hLayout.addSpacing(50) + hLayout.addWidget(QtWidgets.QLabel('Colormap', self)) + hLayout.addWidget(self.comboboxColormap) + hLayout.addWidget(QtWidgets.QLabel('Alpha', self)) + hLayout.addWidget(self.sliderAlpha) + hLayout.addSpacing(50) + hLayout.addWidget(self.pushButtonLoad) + + # Vertical layout + vLayout = QtWidgets.QVBoxLayout(self) vLayout.addLayout(hLayout) - self.pandasTv = QtWidgets.QTableView(self) - vLayout.addWidget(self.pandasTv) - self.loadBtn.clicked.connect(self.load_file) - self.pandasTv.setSortingEnabled(True) - self.pandasTv.doubleClicked.connect(self.tv_double_clicked) + vLayout.addWidget(self.tableView) + + # Recover layout from QSettings + self.settings = QSettings() + self.settings.beginGroup('MainWindow') + self.resize(self.settings.value('size', QSize(800, 600), QSize)) + self.comboboxColormap.setCurrentText(self.settings.value('colormap', 'plasma', str)) + self.sliderAlpha.setValue(self.settings.value('alpha', 255, int)) + self.settings.endGroup() + self.wplot = PlotWindow(wheel=wheel) self.wplot.show() + self.tableModel.dataChanged.connect(self.wplot.canvas.draw) + self.wheel = wheel - def load_file(self): - fileName, _ = QtWidgets.QFileDialog.getOpenFileName( - self, "Open File", "", "CSV Files (*.csv)") - self.pathLE.setText(fileName) + def closeEvent(self, _) -> bool: + self.settings.beginGroup('MainWindow') + self.settings.setValue('size', self.size()) + self.settings.setValue('colormap', self.tableModel.colormap) + self.settings.setValue('alpha', self.tableModel.alpha) + self.settings.endGroup() + self.wplot.close() + + def showEvent(self, a0: QShowEvent) -> None: + super().showEvent(a0) + self.activateWindow() + + def contextMenu(self, pos: QPoint): + idx = self.sender().logicalIndexAt(pos) + action = self.pinAction + action.setData(idx) + action.setChecked(idx in self._pinnedColumns) + menu = QMenu(self) + menu.addAction(action) + menu.exec(self.sender().mapToGlobal(pos)) + + @pyqtSlot(bool) + @pyqtSlot(bool, int) + def pinColumn(self, pin: bool, idx: int | None = None): + idx = idx if idx is not None else self.sender().data() + if not pin and idx in self._pinnedColumns: + self._pinnedColumns.remove(idx) + if pin and idx not in self._pinnedColumns: + self._pinnedColumns.append(idx) + self.changeFilter(self.lineEditFilter.text()) + + def changeFilter(self, string: str): + headers = [ + self.tableModel.headerData(x, Qt.Horizontal, Qt.DisplayRole).lower() + for x in range(self.tableModel.columnCount()) + ] + tokens = [y.lower() for y in (x.strip() for x in string.split(',')) if len(y)] + showAll = len(tokens) == 0 + for idx, column in enumerate(headers): + show = showAll or any((t in column for t in tokens)) or idx in self._pinnedColumns + self.tableView.setColumnHidden(idx, not show) + + def loadFile(self): + fileName, _ = QtWidgets.QFileDialog.getOpenFileName(self, 'Open File', '', 'CSV Files (*.csv)') + if len(fileName) == 0: + return df = pd.read_csv(fileName) - self.update_df(df) - - def update_df(self, df): - model = DataFrameModel(df) - self.pandasTv.setModel(model) - self.wplot.canvas.draw() - - def tv_double_clicked(self): - df = self.pandasTv.model()._dataframe - ind = self.pandasTv.currentIndex() - start = df.loc[ind.row()]['intervals_0'] - finish = df.loc[ind.row()]['intervals_1'] - dt = finish - start + self.updateDataframe(df) + + def updateDataframe(self, df: pd.DataFrame): + # clear pinned columns + self._pinnedColumns = [] + + # try to identify and sort columns containing timestamps + col_names = df.select_dtypes('number').columns + df_interp = df[col_names].replace([-np.inf, np.inf], np.nan) + df_interp = df_interp.interpolate(limit_direction='both') + cols_mono = col_names[[df_interp[c].is_monotonic_increasing for c in col_names]] + cols_mono = [c for c in cols_mono if df[c].nunique() > 1] + cols_mono = df_interp[cols_mono].mean().sort_values().keys() + for idx, col_name in enumerate(cols_mono): + df.insert(idx, col_name, df.pop(col_name)) + + # columns containing boolean values are sorted to the end + # of those, columns containing 'pass' in their title will be sorted by number of False values + col_names = df.columns + cols_bool = list(df.select_dtypes(['bool', 'boolean']).columns) + cols_pass = [c for c in cols_bool if 'pass' in c] + cols_bool = [c for c in cols_bool if c not in cols_pass] # I know. Friday evening, brain is fried ... sorry. + cols_pass = list((~df[cols_pass]).sum().sort_values().keys()) + cols_bool += cols_pass + for col_name in cols_bool: + df = df.join(df.pop(col_name)) + + # trial_no should always be the first column + if 'trial_no' in col_names: + df.insert(0, 'trial_no', df.pop('trial_no')) + + # define columns that should be pinned by default + for col in ['trial_no']: + self._pinnedColumns.append(df.columns.get_loc(col)) + + self.tableModel.setDataFrame(df) + + def tv_double_clicked(self, index: QModelIndex): + data = self.tableModel.dataFrame.iloc[index.row()] + t0 = data['intervals_0'] + t1 = data['intervals_1'] + dt = t1 - t0 if self.wheel: - idx = np.searchsorted( - self.wheel['re_ts'], np.array([start - dt / 10, finish + dt / 10])) + idx = np.searchsorted(self.wheel['re_ts'], np.array([t0 - dt / 10, t1 + dt / 10])) period = self.wheel['re_pos'][idx[0]:idx[1]] if period.size == 0: - _logger.warning('No wheel data during trial #%i', ind.row()) + _logger.warning('No wheel data during trial #%i', index.row()) else: min_val, max_val = np.min(period), np.max(period) self.wplot.canvas.ax2.set_ylim(min_val - 1, max_val + 1) - self.wplot.canvas.ax2.set_xlim(start - dt / 10, finish + dt / 10) - self.wplot.canvas.ax.set_xlim(start - dt / 10, finish + dt / 10) - + self.wplot.canvas.ax2.set_xlim(t0 - dt / 10, t1 + dt / 10) + self.wplot.canvas.ax.set_xlim(t0 - dt / 10, t1 + dt / 10) + self.wplot.setWindowTitle(f"Trial {data.get('trial_no', '?')}") self.wplot.canvas.draw() def viewqc(qc=None, title=None, wheel=None): - qt.create_app() + app = qt.create_app() + app.setStyle('Fusion') + QCoreApplication.setOrganizationName('International Brain Laboratory') + QCoreApplication.setOrganizationDomain('internationalbrainlab.org') + QCoreApplication.setApplicationName('QC Viewer') qcw = GraphWindow(wheel=wheel) qcw.setWindowTitle(title) if qc is not None: - qcw.update_df(qc) + qcw.updateDataframe(qc) qcw.show() return qcw diff --git a/ibllib/qc/task_qc_viewer/task_qc.py b/ibllib/qc/task_qc_viewer/task_qc.py index 89a8d172f..b9f212a5c 100644 --- a/ibllib/qc/task_qc_viewer/task_qc.py +++ b/ibllib/qc/task_qc_viewer/task_qc.py @@ -140,7 +140,8 @@ def create_plots(self, axes, 'ymin': 0, 'ymax': 4, 'linewidth': 2, - 'ax': axes + 'ax': axes, + 'alpha': 0.5, } bnc1 = self.qc.extractor.frame_ttls @@ -240,7 +241,8 @@ def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=N if isinstance(qc_or_session, QcFrame): qc = qc_or_session elif isinstance(qc_or_session, TaskQC): - qc = QcFrame(qc_or_session) + task_qc = qc_or_session + qc = QcFrame(task_qc) else: # assumed to be eid or session path one = one or ONE(mode='local' if local else 'auto') if not is_session_path(Path(qc_or_session)): @@ -284,8 +286,22 @@ def show_session_task_qc(qc_or_session=None, bpod_only=False, local=False, one=N trial_events=list(events), color_map=cm, linestyle=ls) + # Update table and callbacks - w.update_df(qc.frame) + n_trials = qc.frame.shape[0] + if 'task_qc' in locals(): + df_trials = pd.DataFrame({ + k: v for k, v in task_qc.extractor.data.items() + if v.size == n_trials and not k.startswith('wheel') + }) + df = df_trials.merge(qc.frame, left_index=True, right_index=True) + else: + df = qc.frame + df_pass = pd.DataFrame({k: v for k, v in qc.qc.passed.items() if isinstance(v, np.ndarray) and v.size == n_trials}) + df_pass.drop('_task_passed_trial_checks', axis=1, errors='ignore', inplace=True) + df_pass.rename(columns=lambda x: x.replace('_task', 'passed'), inplace=True) + df = df.merge(df_pass.astype('boolean'), left_index=True, right_index=True) + w.updateDataframe(df) qt.run_app() return qc diff --git a/ibllib/tests/qc/test_task_qc_viewer.py b/ibllib/tests/qc/test_task_qc_viewer.py index 6db045f91..7115f371f 100644 --- a/ibllib/tests/qc/test_task_qc_viewer.py +++ b/ibllib/tests/qc/test_task_qc_viewer.py @@ -66,6 +66,7 @@ def test_show_session_task_qc(self, trials_tasks_mock, run_app_mock): qc_mock.compute_session_status.return_value = ('Fail', qc_mock.metrics, {'foo': 'FAIL'}) qc_mock.extractor.data = {'intervals': np.array([[0, 1]])} qc_mock.extractor.frame_ttls = qc_mock.extractor.audio_ttls = qc_mock.extractor.bpod_ttls = mock.MagicMock() + qc_mock.passed = dict() active_task = mock.Mock(spec=ChoiceWorldTrialsNidq, unsafe=True) active_task.run_qc.return_value = qc_mock diff --git a/requirements.txt b/requirements.txt index 92066d9a0..b890b3e5e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ tqdm>=4.32.1 iblatlas>=0.5.3 ibl-neuropixel>=1.5.0 iblutil>=1.13.0 +iblqt>=0.3.2 mtscomp>=1.0.1 ONE-api>=2.11 phylib>=2.6.0 diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..253516e9f --- /dev/null +++ b/ruff.toml @@ -0,0 +1,4 @@ +line-length = 130 + +[format] +quote-style = "single"