From 036d6ea9f89c3dfa4cec8d122b2f4d942dac5104 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Wed, 15 Nov 2023 13:27:14 +0200 Subject: [PATCH 1/7] Issue #666 --- ibllib/io/extractors/habituation_trials.py | 93 +++++++++++++++------- 1 file changed, 63 insertions(+), 30 deletions(-) diff --git a/ibllib/io/extractors/habituation_trials.py b/ibllib/io/extractors/habituation_trials.py index 9dedbd3d5..59a29a269 100644 --- a/ibllib/io/extractors/habituation_trials.py +++ b/ibllib/io/extractors/habituation_trials.py @@ -1,12 +1,11 @@ +"""Habituation ChoiceWorld Bpod trials extraction.""" import logging import numpy as np import ibllib.io.raw_data_loaders as raw from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes from ibllib.io.extractors.biased_trials import ContrastLR -from ibllib.io.extractors.training_trials import ( - FeedbackTimes, StimOnTriggerTimes, Intervals, GoCueTimes -) +from ibllib.io.extractors.training_trials import FeedbackTimes, StimOnTriggerTimes, GoCueTimes _logger = logging.getLogger(__name__) @@ -24,9 +23,24 @@ def __init__(self, *args, **kwargs): self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names) def _extract(self) -> dict: + """ + Extract the Bpod trial events. + + The Bpod state machine for this task has extremely misleading names! The 'iti' state is + actually the delay between valve open and trial end (the stimulus is still present during + this period), and the 'trial_start' state is actually the ITI during which there is a 1s + Bpod TTL and gray screen period. + + Returns + ------- + dict + A dictionary of Bpod trial events. The keys are defined in the `var_names` attribute. + """ # Extract all trials... - # Get all stim_sync events detected + # Get all detected TTLs. These are stored for QC purposes + self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials) + # These are the frame2TTL pulses as a list of lists, one per trial ttls = [raw.get_port_events(tr, 'BNC1') for tr in self.bpod_trials] # Report missing events @@ -38,10 +52,45 @@ def _extract(self) -> dict: _logger.warning(f'{self.session_path}: Missing BNC1 TTLs on {n_missing} trial(s)') # Extract datasets common to trainingChoiceWorld - training = [ContrastLR, FeedbackTimes, Intervals, GoCueTimes, StimOnTriggerTimes] + training = [ContrastLR, FeedbackTimes, GoCueTimes, StimOnTriggerTimes] out, _ = run_extractor_classes(training, session_path=self.session_path, save=False, bpod_trials=self.bpod_trials, settings=self.settings, task_collection=self.task_collection) + """ + The 'trial_start' state is in fact the 1s grey screen period, therefore the first timestamp + is really the end of the previous trial and also the stimOff trigger time. The second + timestamp is the true trial start time. + """ + (_, *ends), starts = zip(*[ + t['behavior_data']['States timestamps']['trial_start'][-1] for t in self.bpod_trials] + ) + + # StimOffTrigger times + out['stimOffTrigger_times'] = np.array(ends) + + # StimOff times + """ + There should be exactly three TTLs per trial. stimOff_times should be the first TTL pulse. + If 1 or more pulses are missing, we can not be confident of assigning the correct one. + """ + out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan + for sync, off in zip(ttls[1:], ends)]) + + # Trial intervals + """ + In terms of TTLs, the intervals are defined by the 'trial_start' state, however the stim + off time often happens after the trial end TTL front, i.e. after the 'trial_start' start + begins. For these trials, we set the trial end time as the stim off time. + """ + # NB: We lose the last trial because the stim off event occurs at trial_num + 1 + n_trials = out['stimOff_times'].size + out['intervals'] = np.c_[starts, np.r_[ends, np.nan]][:n_trials, :] + to_update = out['intervals'][:, 1] < out['stimOff_times'] + out['intervals'][to_update, 1] = out['stimOff_times'][to_update] + + # itiIn times + out['itiIn_times'] = np.r_[ends, np.nan] + # GoCueTriggerTimes is the same event as StimOnTriggerTimes out['goCueTrigger_times'] = out['stimOnTrigger_times'].copy() @@ -75,38 +124,22 @@ def _extract(self) -> dict: trial_volume = [x['reward_amount'] for x in self.bpod_trials] out['rewardVolume'] = np.array(trial_volume).astype(np.float64) - # StimOffTrigger times - # StimOff occurs at trial start (ignore the first trial's state update) - out['stimOffTrigger_times'] = np.array( - [tr["behavior_data"]["States timestamps"] - ["trial_start"][0][0] for tr in self.bpod_trials[1:]] - ) - - # StimOff times - """ - There should be exactly three TTLs per trial. stimOff_times should be the first TTL pulse. - If 1 or more pulses are missing, we can not be confident of assigning the correct one. - """ - trigg = out['stimOffTrigger_times'] - out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan - for sync, off in zip(ttls[1:], trigg)]) - # FeedbackType is always positive out['feedbackType'] = np.ones(len(out['feedback_times']), dtype=np.int8) - # ItiIn times - out['itiIn_times'] = np.array( - [tr["behavior_data"]["States timestamps"] - ["iti"][0][0] for tr in self.bpod_trials] - ) - # Phase and position out['position'] = np.array([t['position'] for t in self.bpod_trials]) out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials]) - # NB: We lose the last trial because the stim off event occurs at trial_num + 1 - n_trials = out['stimOff_times'].size - # return [out[k][:n_trials] for k in self.var_names] + # Double-check that the early and late trial events occur within the trial intervals + idx = ~np.isnan(out['stimOn_times'][:n_trials]) + if np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]): + _logger.warning('Stim on events occurring outside trial intervals') + idx = ~np.isnan(out['stimOff_times']) + if np.any(out['stimOff_times'][idx] > out['intervals'][idx, 1]): + _logger.warning('Stim off events occurring outside trial intervals') + + # Truncate arrays and return in correct order return {k: out[k][:n_trials] for k in self.var_names} From 0159729e7788bec355e63fa19d5ee8a84ff520bf Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 17 Nov 2023 15:19:16 +0200 Subject: [PATCH 2/7] FpgaTrials refactor and FpgaHabituationTrials subclass --- ibllib/io/extractors/base.py | 9 + ibllib/io/extractors/biased_trials.py | 2 + ibllib/io/extractors/ephys_fpga.py | 866 +++++++++++++++++-- ibllib/io/extractors/habituation_trials.py | 19 +- ibllib/io/session_params.py | 1 - ibllib/pipes/behavior_tasks.py | 64 +- ibllib/pipes/dynamic_pipeline.py | 28 +- ibllib/qc/task_metrics.py | 29 +- ibllib/tests/extractors/test_ephys_fpga.py | 187 +--- ibllib/tests/extractors/test_ephys_trials.py | 189 ++++ 10 files changed, 1092 insertions(+), 302 deletions(-) diff --git a/ibllib/io/extractors/base.py b/ibllib/io/extractors/base.py index c1b46b22e..cfc9557f4 100644 --- a/ibllib/io/extractors/base.py +++ b/ibllib/io/extractors/base.py @@ -29,9 +29,16 @@ class BaseExtractor(abc.ABC): """ session_path = None + """pathlib.Path: Absolute path of session folder.""" + save_names = None + """tuple of str: The filenames of each extracted dataset, or None if array should not be saved.""" + var_names = None + """tuple of str: A list of names for the extracted variables. These become the returned output keys.""" + default_path = Path('alf') # relative to session + """pathlib.Path: The default output folder relative to `session_path`.""" def __init__(self, session_path=None): # If session_path is None Path(session_path) will fail @@ -127,6 +134,8 @@ class BaseBpodTrialsExtractor(BaseExtractor): bpod_trials = None settings = None task_collection = None + frame2ttl = None + audio = None def extract(self, bpod_trials=None, settings=None, **kwargs): """ diff --git a/ibllib/io/extractors/biased_trials.py b/ibllib/io/extractors/biased_trials.py index 16d8f8111..e2912d11e 100644 --- a/ibllib/io/extractors/biased_trials.py +++ b/ibllib/io/extractors/biased_trials.py @@ -183,6 +183,8 @@ class EphysTrials(BaseBpodTrialsExtractor): def _extract(self, extractor_classes=None, **kwargs) -> dict: base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials, PhasePosQuiescence] + # Get all detected TTLs. These are stored for QC purposes + self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials) # Exclude from trials table out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False, task_collection=self.task_collection) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index 74ac1e551..ad2cb0ab5 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -1,14 +1,46 @@ -"""Data extraction from raw FPGA output -Complete FPGA data extraction depends on Bpod extraction +"""Data extraction from raw FPGA output. + +The behaviour extraction happens in the following stages: + + 1. The NI DAQ events are extracted into a map of event times and TTL polarities. + 2. The Bpod trial events are extracted from the raw Bpod data, depending on the task protocol. + 3. As protocols may be chained together within a given recording, the period of a given task + protocol is determined using the 'spacer' DAQ signal (see `get_protocol_period`). + 4. Physical behaviour events such as stim on and reward time are separated out by TTL length or + sequence within the trial. + 5. The Bpod clock is sync'd with the FPGA using one of the extracted trial events. + 6. The Bpod software events are then converted to FPGA time. + +Examples +-------- +For simple extraction, use the FPGATrials class: + +>>> extractor = FpgaTrials(session_path) +>>> trials, _ = extractor.extract(update=False, save=False) + +Notes +----- +Sync extraction in this module only supports FPGA data acquired with an NI DAQ as part of a +Neuropixels recording system, however a sync and channel map extracted from a different DAQ format +can be passed to the FpgaTrials class. + +See Also +-------- +For dynamic pipeline sessions it is best to call the extractor via the BehaviorTask class. + +TODO notes on subclassing various methods of FpgaTrials for custom hardware. """ -from collections import OrderedDict import logging +from itertools import cycle from pathlib import Path import uuid import re +import warnings import matplotlib.pyplot as plt +from matplotlib.colors import TABLEAU_COLORS import numpy as np +from packaging import version import spikeglx import neurodsp.utils @@ -21,17 +53,22 @@ from ibllib.io.extractors.bpod_trials import extract_all as bpod_extract_all import ibllib.io.extractors.base as extractors_base from ibllib.io.extractors.training_wheel import extract_wheel_moves -import ibllib.plots as plots +from ibllib import plots from ibllib.io.extractors.default_channel_maps import DEFAULT_MAPS _logger = logging.getLogger(__name__) -SYNC_BATCH_SIZE_SECS = 100 # number of samples to read at once in bin file for sync +SYNC_BATCH_SIZE_SECS = 100 +"""int: Number of samples to read at once in bin file for sync.""" + WHEEL_RADIUS_CM = 1 # stay in radians +"""float: The radius of the wheel used in the task. A value of 1 ensures units remain in radians.""" + WHEEL_TICKS = 1024 +"""int: The number of encoder pulses per channel for one complete rotation.""" -BPOD_FPGA_DRIFT_THRESHOLD_PPM = 150 # throws an error if bpod to fpga clock drift is higher -F2TTL_THRESH = 0.01 # consecutive pulses with less than this threshold ignored +BPOD_FPGA_DRIFT_THRESHOLD_PPM = 150 +"""int: Throws an error if Bpod to FPGA clock drift is higher than this value.""" CHMAPS = {'3A': {'ap': @@ -62,10 +99,11 @@ {'imec_sync': 6} }, } +"""dict: The default channel indices corresponding to various devices for different recording systems.""" def data_for_keys(keys, data): - """Check keys exist in 'data' dict and contain values other than None""" + """Check keys exist in 'data' dict and contain values other than None.""" return data is not None and all(k in data and data.get(k, None) is not None for k in keys) @@ -157,6 +195,8 @@ def _assign_events_bpod(bpod_t, bpod_polarities, ignore_first_valve=True): :param bpod_fronts: numpy vector containing polarity of fronts (1 rise, -1 fall) :param ignore_first_valve (True): removes detected valve events at indices le 2 :return: numpy arrays of times t_trial_start, t_valve_open and t_iti_in + + TODO Remove function (now using FpgaTrials._assign_events) """ TRIAL_START_TTL_LEN = 2.33e-4 # the TTL length is 0.1ms but this has proven to drift on # some bpods and this is the highest possible value that discriminates trial start from valve @@ -258,6 +298,8 @@ def _assign_events_audio(audio_t, audio_polarities, return_indices=False, displa :param display (False): for debug mode, displays the raw fronts overlaid with detections :return: numpy arrays t_ready_tone_in, t_error_tone_in :return: numpy arrays ind_ready_tone_in, ind_error_tone_in if return_indices=True + + TODO Remove function (now using FpgaTrials._assign_events) """ # make sure that there are no 2 consecutive fall or consecutive rise events assert np.all(np.abs(np.diff(audio_polarities)) == 2) @@ -285,13 +327,29 @@ def _assign_events_to_trial(t_trial_start, t_event, take='last'): """ Assign events to a trial given trial start times and event times. - Trials without an event - result in nan value in output time vector. + Trials without an event result in nan value in output time vector. The output has a consistent size with t_trial_start and ready to output to alf. - :param t_trial_start: numpy vector of trial start times - :param t_event: numpy vector of event times to assign to trials - :param take: 'last' or 'first' (optional, default 'last'): index to take in case of duplicates - :return: numpy array of event times with the same shape of trial start. + + Parameters + ---------- + t_trial_start : numpy.array + An array of start times, used to bin edges for assigning values from `t_event`. + t_event : numpy.array + An array of event times to assign to trials. + take : str {'first', 'last'}, int + 'first' takes first event > t_trial_start; 'last' takes last event < the next + t_trial_start; an int defines the index to take for events within trial bounds. The index + may be negative. + + Returns + ------- + numpy.array + An array the length of `t_trial_start` containing values from `t_event`. Unassigned values + are replaced with np.nan. + + See Also + -------- + FpgaTrials._assign_events - Assign trial events based on TTL length. """ # make sure the events are sorted try: @@ -316,7 +374,7 @@ def _assign_events_to_trial(t_trial_start, t_event, take='last'): else: # if the index is arbitrary, needs to be numeric (could be negative if from the end) iall = np.unique(ind) minsize = take + 1 if take >= 0 else - take - # for each trial, take the takenth element if there are enough values in trial + # for each trial, take the take nth element if there are enough values in trial for iu in iall: match = t_event[iu == ind] if len(match) >= minsize: @@ -382,25 +440,39 @@ def _clean_audio(audio, display=False): return audio -def _clean_frame2ttl(frame2ttl, display=False): +def _clean_frame2ttl(frame2ttl, threshold=0.01, display=False): """ + Clean the frame2ttl events. + Frame 2ttl calibration can be unstable and the fronts may be flickering at an unrealistic pace. This removes the consecutive frame2ttl pulses happening too fast, below a threshold - of F2TTL_THRESH + of F2TTL_THRESH. + + Parameters + ---------- + frame2ttl : dict + A dictionary of frame2TTL events, with keys {'times', 'polarities'}. + threshold : float + Consecutive pulses occurring with this many seconds ignored. + display : bool + If true, plots the input TTLs and the cleaned output. + + Returns + ------- + """ dt = np.diff(frame2ttl['times']) - iko = np.where(np.logical_and(dt < F2TTL_THRESH, frame2ttl['polarities'][:-1] == -1))[0] + iko = np.where(np.logical_and(dt < threshold, frame2ttl['polarities'][:-1] == -1))[0] iko = np.unique(np.r_[iko, iko + 1]) frame2ttl_ = {'times': np.delete(frame2ttl['times'], iko), 'polarities': np.delete(frame2ttl['polarities'], iko)} if iko.size > (0.1 * frame2ttl['times'].size): _logger.warning(f'{iko.size} ({iko.size / frame2ttl["times"].size:.2%}) ' - f'frame to TTL polarity switches below {F2TTL_THRESH} secs') + f'frame to TTL polarity switches below {threshold} secs') if display: # pragma: no cover - from ibllib.plots import squares - plt.figure() - squares(frame2ttl['times'] * 1000, frame2ttl['polarities'], yrange=[0.1, 0.9]) - squares(frame2ttl_['times'] * 1000, frame2ttl_['polarities'], yrange=[1.1, 1.9]) + fig, (ax0, ax1) = plt.subplots(2, sharex=True) + plots.squares(frame2ttl['times'] * 1000, frame2ttl['polarities'], yrange=[0.1, 0.9], ax=ax0) + plots.squares(frame2ttl_['times'] * 1000, frame2ttl_['polarities'], yrange=[1.1, 1.9], ax=ax1) import seaborn as sns sns.displot(dt[dt < 0.05], binwidth=0.0005) @@ -425,9 +497,9 @@ def extract_wheel_sync(sync, chmap=None, tmin=None, tmax=None): Returns ------- - np.array + numpy.array Wheel timestamps in seconds. - np.array + numpy.array Wheel positions in radians. """ # Assume two separate edge count channels @@ -440,7 +512,7 @@ def extract_wheel_sync(sync, chmap=None, tmin=None, tmax=None): return re_ts, re_pos -def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tmin=None, tmax=None): +def extract_behaviour_sync(sync, chmap, display=False, bpod_trials=None, tmin=None, tmax=None): """ Extract task related event times from the sync. @@ -463,6 +535,8 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm ------- dict A map of trial event timestamps. + + TODO Remove this function (now using FpgaTrials.extract_behaviour_sync) """ bpod = get_sync_fronts(sync, chmap['bpod'], tmin=tmin, tmax=tmax) if bpod.times.size == 0: @@ -476,6 +550,7 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm t_trial_start, t_valve_open, t_iti_in = _assign_events_bpod(bpod['times'], bpod['polarities']) if not bpod_trials: raise ValueError('No Bpod trials to align') + intervals_bpod = bpod_trials['intervals'] # If there are no detected trial start times or more than double the trial end pulses, # the trial start pulses may be too small to be detected, in which case, sync using the ini_in if t_trial_start.size == 0 or (t_trial_start.size / t_iti_in.size) < .5: @@ -486,12 +561,12 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm # if it's drifting too much if drift > 200 and bpod_end.size != t_iti_in.size: raise err.SyncBpodFpgaException('sync cluster f*ck') - t_trial_start = fcn(bpod_trials['intervals_bpod'][:, 0]) + t_trial_start = fcn(intervals_bpod[:, 0]) else: # one issue is that sometimes bpod pulses may not have been detected, in this case # perform the sync bpod/FPGA, and add the start that have not been detected _logger.info('Attempting to align on trial start') - bpod_start = bpod_trials['intervals_bpod'][:, 0] + bpod_start = intervals_bpod[:, 0] fcn, drift, ibpod, ifpga = neurodsp.utils.sync_timestamps( bpod_start, t_trial_start, return_indices=True) # if it's drifting too much @@ -703,34 +778,39 @@ def get_protocol_period(session_path, protocol_number, bpod_sync): class FpgaTrials(extractors_base.BaseExtractor): - save_names = ('_ibl_trials.intervals_bpod.npy', - '_ibl_trials.goCueTrigger_times.npy', None, None, None, None, None, None, None, + save_names = ('_ibl_trials.goCueTrigger_times.npy', None, None, None, None, None, None, None, '_ibl_trials.stimOff_times.npy', None, None, None, '_ibl_trials.quiescencePeriod.npy', '_ibl_trials.table.pqt', '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy') - var_names = ('intervals_bpod', - 'goCueTrigger_times', 'stimOnTrigger_times', + var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'errorCue_times', 'itiIn_times', 'stimFreeze_times', 'stimOff_times', 'valveOpen_times', 'phase', 'position', 'quiescence', 'table', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude') - # Fields from bpod extractor that we want to re-sync to FPGA bpod_rsync_fields = ('intervals', 'response_times', 'goCueTrigger_times', 'stimOnTrigger_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times') + """tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA.""" - # Fields from bpod extractor that we want to save bpod_fields = ('feedbackType', 'choice', 'rewardVolume', 'contrastLeft', 'contrastRight', - 'probabilityLeft', 'intervals_bpod', 'phase', 'position', 'quiescence') + 'probabilityLeft', 'phase', 'position', 'quiescence') + """tuple of str: Fields from bpod extractor that we want to save.""" + + sync_field = 'intervals_0' # trial start events + """str: The trial event to synchronize (must be present in extracted trials).""" - """str: The Bpod events to synchronize (must be present in sync channel map).""" - sync_field = 'intervals' + bpod = None + """dict of numpy.array: The Bpod out TTLs recorded on the DAQ. Used in the QC viewer plot.""" def __init__(self, *args, bpod_trials=None, bpod_extractor=None, **kwargs): - """An extractor for all ephys trial data, in FPGA time""" + """An extractor for ephysChoiceWorld trials data, in FPGA time. + + This class may be subclassed to handle moderate variations in hardware and task protocol, + however there is flexible + """ super().__init__(*args, **kwargs) self.bpod2fpga = None self.bpod_trials = bpod_trials @@ -781,7 +861,7 @@ def _update_var_names(self, bpod_fields=None, bpod_rsync_fields=None): if not self.bpod_trials: self.bpod_trials = self.bpod_extractor.extract(save=False) table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys() - self.bpod_fields += (*[x for x in table_keys if x not in excluded], self.sync_field + '_bpod') + self.bpod_fields += tuple([x for x in table_keys if x not in excluded]) @staticmethod def _time_fields(trials_attr) -> set: @@ -802,72 +882,266 @@ def _time_fields(trials_attr) -> set: pattern = re.compile(fr'^[_\w]*({"|".join(FIELDS)})[_\w]*$') return set(filter(pattern.match, trials_attr)) + def load_sync(self, sync_collection='raw_ephys_data', **kwargs): + """Load the DAQ sync and channel map data. + + This method may be subclassed for novel DAQ systems. The sync must contain the following + keys: 'times' - an array timestamps in seconds; 'polarities' - an array of {-1, 1} + corresponding to TTL LOW and TTL HIGH, respectively; 'channels' - an array of ints + corresponding to channel number. + + Parameters + ---------- + sync_collection : str + The session subdirectory where the sync data are located. + kwargs + Optional arguments used by subclass methods. + + Returns + ------- + one.alf.io.AlfBunch + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and + the corresponding channel numbers. + dict + A map of channel names and their corresponding indices. + """ + return get_sync_and_chn_map(self.session_path, sync_collection) + def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task_collection='raw_behavior_data', **kwargs) -> dict: - """Extracts ephys trials by combining Bpod and FPGA sync pulses""" - # extract the behaviour data from bpod + """Extracts ephys trials by combining Bpod and FPGA sync pulses. + + It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field` + attributes are all correct for the bpod protocol used. + + Below are the steps involved: + 0. Load sync and bpod trials, if required. + 1. Determine protocol period and discard sync events outside the task. + 2. Classify and attribute DAQ TTLs to trial events (see :meth:`FpgaTrials.extract_behaviour_sync`). + 3. Sync the Bpod clock to the DAQ clock using one of the assigned trial events. + 4. Convert Bpod software event times to DAQ clock. + 5. Extract the wheel from the DAQ rotary encoder signal, if required. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. If None, the sync is loaded using the + `load_sync` method. + chmap : dict + A map of channel names and their corresponding indices. If None, the channel map is + loaded using the :meth:`FpgaTrials.load_sync` method. + sync_collection : str + The session subdirectory where the sync data are located. This is only used if the + sync or channel maps are not provided. + task_collection : str + The session subdirectory where the raw Bpod data are located. This is used for loading + the task settings and extracting the bpod trials, if not already done. + protocol_number : int + The protocol number if multiple protocols were run during the session. If provided, a + spacer signal must be present in order to determine the correct period. + kwargs + Optional arguments for subclass methods to use. + + Returns + ------- + dict + A dictionary of numpy arrays with `FpgaTrials.var_names` as keys. + """ if sync is None or chmap is None: - _sync, _chmap = get_sync_and_chn_map(self.session_path, sync_collection) + _sync, _chmap = self.load_sync(sync_collection) sync = sync or _sync chmap = chmap or _chmap - if not self.bpod_trials: + if not self.bpod_trials: # extract the behaviour data from bpod self.bpod_trials, *_ = bpod_extract_all( session_path=self.session_path, task_collection=task_collection, save=False, extractor_type=kwargs.get('extractor_type')) + # Explode trials table df - trials_table = alfio.AlfBunch.from_df(self.bpod_trials.pop('table')) - table_columns = trials_table.keys() - self.bpod_trials.update(trials_table) - self.bpod_trials['intervals_bpod'] = np.copy(self.bpod_trials['intervals']) + if 'table' in self.var_names: + trials_table = alfio.AlfBunch.from_df(self.bpod_trials.pop('table')) + table_columns = trials_table.keys() + self.bpod_trials.update(trials_table) + else: + if 'table' in self.bpod_trials: + _logger.error( + '"table" found in Bpod trials but missing from `var_names` attribute and will' + 'therefore not be extracted. This is likely in error.') + table_columns = None # Get the spacer times for this protocol - if (protocol_number := kwargs.get('protocol_number')) is not None: # look for spacer + if any(arg in kwargs for arg in ('tmin', 'tmax')): + tmin, tmax = kwargs.get('tmin'), kwargs.get('tmax') + elif (protocol_number := kwargs.get('protocol_number')) is not None: # look for spacer # The spacers are TTLs generated by Bpod at the start of each protocol bpod = get_sync_fronts(sync, chmap['bpod']) tmin, tmax = get_protocol_period(self.session_path, protocol_number, bpod) else: tmin = tmax = None - # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC - fpga_trials, self.frame2ttl, self.audio, self.bpod = extract_behaviour_sync( - sync=sync, chmap=chmap, bpod_trials=self.bpod_trials, tmin=tmin, tmax=tmax) - assert self.sync_field in self.bpod_trials and self.sync_field in fpga_trials - self.bpod_trials[f'{self.sync_field}_bpod'] = np.copy(self.bpod_trials[self.sync_field]) - - # checks consistency and compute dt with bpod - self.bpod2fpga, drift_ppm, ibpod, ifpga = neurodsp.utils.sync_timestamps( - self.bpod_trials[f'{self.sync_field}_bpod'][:, 0], fpga_trials.pop(self.sync_field)[:, 0], - return_indices=True) - nbpod = self.bpod_trials[f'{self.sync_field}_bpod'].shape[0] - npfga = fpga_trials['feedback_times'].shape[0] - nsync = len(ibpod) - _logger.info(f'N trials: {nbpod} bpod, {npfga} FPGA, {nsync} merged, sync {drift_ppm} ppm') - if drift_ppm > BPOD_FPGA_DRIFT_THRESHOLD_PPM: - _logger.warning('BPOD/FPGA synchronization shows values greater than %i ppm', - BPOD_FPGA_DRIFT_THRESHOLD_PPM) - out = OrderedDict() + # Remove unnecessary data from sync + selection = np.logical_and( + sync['times'] <= (tmax if tmax is not None else sync['times'][-1]), + sync['times'] >= (tmin if tmin is not None else sync['times'][0]), + ) + sync = alfio.AlfBunch({k: v[selection] for k, v in sync.items()}) + _logger.debug('Protocol period from %.2fs to %.2fs (~%.0f min duration)', + *sync['times'][[0, -1]], np.diff(sync['times'][[0, -1]]) / 60) + + # Get the trial events from the DAQ sync TTLs + fpga_trials = self.extract_behaviour_sync(sync, chmap, **kwargs) + + # Sync the Bpod clock to the DAQ + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_trials, self.sync_field) + + if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0': + # One issue is that sometimes pulses may not have been detected, in this case + # add the events that have not been detected and re-extract the behaviour sync. + # This is only really relevant for the Bpod interval events as the other TTLs are + # from devices where a missing TTL likely means the Bpod event was truly absent. + _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times') + bpod_start = self.bpod_trials['intervals'][:, 0] + missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))]) + t_trial_start = np.sort(np.r_[fpga_trials['intervals'][:, 0], missing_bpod]) + fpga_trials = self.extract_behaviour_sync(sync, chmap, start_times=t_trial_start, **kwargs) + + out = dict() out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) out.update({k: fpga_trials[k][ifpga] for k in sorted(fpga_trials.keys())}) # extract the wheel data - wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) - from ibllib.io.extractors.training_wheel import extract_first_movement_times - if not self.settings: - self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) - min_qt = self.settings.get('QUIESCENT_PERIOD', None) - first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt) - out.update({'firstMovement_times': first_move_onsets}) + if any(x.startswith('wheel') for x in self.var_names): + wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) + from ibllib.io.extractors.training_wheel import extract_first_movement_times + if not self.settings: + self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) + min_qt = self.settings.get('QUIESCENT_PERIOD', None) + first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt) + out.update({'firstMovement_times': first_move_onsets}) + out.update({f'wheel_{k}': v for k, v in wheel.items()}) + out.update({f'wheelMoves_{k}': v for k, v in moves.items()}) + # Re-create trials table - trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns}) - out['table'] = trials_table.to_df() + if table_columns: + trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns}) + out['table'] = trials_table.to_df() - out.update({f'wheel_{k}': v for k, v in wheel.items()}) - out.update({f'wheelMoves_{k}': v for k, v in moves.items()}) - out = {k: out[k] for k in self.var_names if k in out} # Reorder output + out = alfio.AlfBunch({k: out[k] for k in self.var_names if k in out}) # Reorder output assert self.var_names == tuple(out.keys()) return out + def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, **kwargs): + """ + Extract task related event times from the sync. + + The trial start times are the shortest Bpod TTLs and occur at the start of the trial. The + first trial start TTL of the session is longer and must be handled differently. The trial + start TTL is used to assign the other trial events to each trial. + + The trial end is the end of the so-called 'ITI' Bpod event TTL (classified as the longest + of the three Bpod event TTLs). Go cue audio TTLs are the shorter of the two expected audio + tones. The first of these after each trial start is taken to be the go cue time. Error + tones are longer audio TTLs and assigned as the last of such occurrence after each trial + start. The valve open Bpod TTLs are medium-length, the last of which is used for each trial. + The feedback times are times of either valve open or error tone as there should be only one + such event per trial. + + The stimulus times are taken from the frame2ttl events (with improbably high frequency TTLs + removed): the first TTL after each trial start is assumed to be the stim onset time; the + second to last and last are taken as the stimulus freeze and offset times, respectively. + + Parameters + ---------- + sync : dict + 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' + chmap : dict + Map of channel names and their corresponding index. Default to constant. + start_times : numpy.array + An optional array of timestamps to separate trial events by. This is useful if after + syncing the clocks, some trial start TTLs are found to be missed. If None, uses + 'trial_start' Bpod event. + display : bool, matplotlib.pyplot.Axes + Show the full session sync pulses display. + + Returns + ------- + dict + A map of trial event timestamps. + """ + # Get the events from the sync. + # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC + self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) + self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) + if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}: + raise ValueError( + 'Expected at least "ready_tone" and "error_tone" audio events.' + '`audio_event_ttls` kwarg may be incorrect.') + self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) + if not set(bpod_event_intervals.keys()) >= {'trial_start', 'valve_open', 'trial_end'}: + raise ValueError( + 'Expected at least "trial_start", "trial_end", and "valve_open" audio events. ' + '`bpod_event_ttls` kwarg may be incorrect.') + + # The first trial pulse is longer and often assigned to another event. + # Here we move the earliest non-trial_start event to the trial_start array. + t0 = bpod_event_intervals['trial_start'][0, 0] # expect 1st event to be trial_start + pretrial = [(k, v[0, 0]) for k, v in bpod_event_intervals.items() if v.size and v[0, 0] < t0] + if pretrial: + (pretrial, _) = sorted(pretrial, key=lambda x: x[1])[0] # take the earliest event + dt = np.diff(bpod_event_intervals[pretrial][0, :]) * 1e3 # record TTL length to log + _logger.debug('Reassigning first %s to trial_start. TTL length = %.3g ms', pretrial, dt) + bpod_event_intervals['trial_start'] = np.r_[ + bpod_event_intervals[pretrial][0:1, :], bpod_event_intervals['trial_start'] + ] + bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :] + + t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T + # Drop last trial start if incomplete + t_trial_start = bpod_event_intervals['trial_start'][:len(t_trial_end), 0] + t_valve_open = bpod_event_intervals['valve_open'][:, 0] + t_ready_tone_in = audio_event_intervals['ready_tone'][:, 0] + t_error_tone_in = audio_event_intervals['error_tone'][:, 0] + + start_times = start_times or t_trial_start + + trials = alfio.AlfBunch({ + 'goCue_times': _assign_events_to_trial(start_times, t_ready_tone_in, take='first'), + 'errorCue_times': _assign_events_to_trial(start_times, t_error_tone_in), + 'valveOpen_times': _assign_events_to_trial(start_times, t_valve_open), + 'stimFreeze_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take=-2), + 'stimOn_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take='first'), + 'stimOff_times': _assign_events_to_trial(start_times, self.frame2ttl['times']), + 'itiIn_times': _assign_events_to_trial(start_times, t_iti_in) + }) + + # feedback times are valve open on correct trials and error tone in on incorrect trials + trials['feedback_times'] = np.copy(trials['valveOpen_times']) + ind_err = np.isnan(trials['valveOpen_times']) + trials['feedback_times'][ind_err] = trials['errorCue_times'][ind_err] + trials['intervals'] = np.c_[start_times, t_trial_end] + + if display: # pragma: no cover + width = 0.5 + ymax = 5 + if isinstance(display, bool): + plt.figure('Bpod FPGA Sync') + ax = plt.gca() + else: + ax = display + plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k') + plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') + plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') + color_map = TABLEAU_COLORS.keys() + for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)): + plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) + ax.legend() + ax.set_yticks([0, 1, 2]) + ax.set_yticklabels(['bpod', 'f2ttl', 'audio']) + ax.set_ylim([0, 5]) + + return trials + def get_wheel_positions(self, *args, **kwargs): """Extract wheel and wheelMoves objects. @@ -875,6 +1149,432 @@ def get_wheel_positions(self, *args, **kwargs): """ return get_wheel_positions(*args, **kwargs) + def get_stimulus_update_times(self, sync, chmap, display=False, **_): + """ + Extract stimulus update times from sync. + + Gets the stimulus times from the frame2ttl channel and cleans the signal. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. + chmap : dict + A map of channel names and their corresponding indices. Must contain a 'frame2ttl' key. + display : bool + If true, plots the input TTLs and the cleaned output. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing stimulus TTL fronts. + """ + frame2ttl = get_sync_fronts(sync, chmap['frame2ttl']) + frame2ttl = _clean_frame2ttl(frame2ttl, display=display) + return frame2ttl + + def get_audio_event_times(self, sync, chmap, audio_event_ttls=None, display=False, **_): + """ + Extract audio times from sync. + + Gets the TTL times from the 'audio' channel, cleans the signal, and classifies each TTL + event by length. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. + chmap : dict + A map of channel names and their corresponding indices. Must contain an 'audio' key. + audio_event_ttls : dict + A map of event names to (min, max) TTL length. + display : bool + If true, plots the input TTLs and the cleaned output. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing audio TTL fronts. + dict + A dictionary of events (from `audio_event_ttls`) and their intervals as an Nx2 array. + """ + audio = get_sync_fronts(sync, chmap['audio']) + audio = _clean_audio(audio) + + if audio['times'].size == 0: + _logger.error('No audio sync fronts found.') + + if audio_event_ttls is None: + # For training/biased/ephys protocols, the ready tone should be below 110 ms. The error + # tone should be between 400ms and 1200ms + audio_event_ttls = {'ready_tone': (0, 0.11), 'error_tone': (0.4, 1.2)} + audio_event_intervals = self._assign_events(audio['times'], audio['polarities'], audio_event_ttls, display=display) + + return audio, audio_event_intervals + + def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs): + """ + Extract Bpod times from sync. + + Gets the Bpod TTL times from the sync 'bpod' channel and classifies each TTL event by + length. NB: The first trial has an abnormal trial_start TTL that is usually mis-assigned. + This is handled in the :meth:`FpgaTrials.extract_behaviour_sync` method. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. Must contain a 'bpod' key. + chmap : dict + A map of channel names and their corresponding indices. + bpod_event_ttls : dict of tuple + A map of event names to (min, max) TTL length. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts. + dict + A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array. + """ + bpod = get_sync_fronts(sync, chmap['bpod']) + if bpod.times.size == 0: + raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. ' + 'Check channel maps.') + # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these + # lengths are defined by the state machine of the task protocol and therefore vary. + if bpod_event_ttls is None: + # For training/biased/ephys protocols, the trial start TTL length is 0.1ms but this has + # proven to drift on some Bpods and this is the highest possible value that + # discriminates trial start from valve. Valve open events are between 50ms to 300 ms. + # ITI events are above 400 ms. + bpod_event_ttls = { + 'trial_start': (0, 2.33e-4), 'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)} + bpod_event_intervals = self._assign_events( + bpod['times'], bpod['polarities'], bpod_event_ttls, display=display) + + return bpod, bpod_event_intervals + + @staticmethod + def _assign_events(ts, polarities, event_lengths, precedence='shortest', display=False): + """ + Classify TTL events by length. + + Outputs the synchronisation events such as trial intervals, valve opening, and audio. + + Parameters + ---------- + ts : numpy.array + Numpy vector containing times of TTL fronts. + polarities : numpy.array + Numpy vector containing polarity of TTL fronts (1 rise, -1 fall). + event_lengths : dict of tuple + A map of TTL events and the range of permissible lengths, where l0 < ttl <= l1. + precedence : str {'shortest', 'longest', 'dict order'} + In the case of overlapping event TTL lengths, assign shortest/longest first or go by + the `event_lengths` dict order. + display : bool + If true, plots the TTLs with coloured lines delineating the assigned events. + + Returns + ------- + Dict[str, numpy.array] + A dictionary of events and their intervals as an Nx2 array. + + See Also + -------- + _assign_events_to_trial - classify TTLs by event order within a given trial period. + """ + event_intervals = dict.fromkeys(event_lengths) + assert 'unassigned' not in event_lengths.keys() + + if len(ts) == 0: + return {k: np.array([[], []]).T for k in (*event_lengths.keys(), 'unassigned')} + + # make sure that there are no 2 consecutive fall or consecutive rise events + assert np.all(np.abs(np.diff(polarities)) == 2) + if polarities[0] == -1: + ts = np.delete(ts, 0) + if polarities[-1] == 1: # if the final TTL is left HIGH, insert a NaN + ts = np.r_[ts, np.nan] + # take only even time differences: i.e. from rising to falling fronts + dt = np.diff(ts)[::2] + + # Assign events from shortest TTL to largest + assigned = np.zeros(ts.shape, dtype=bool) + if precedence.lower() == 'shortest': + event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1])) + elif precedence.lower() == 'longest': + event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1]), reverse=True) + elif precedence.lower() == 'dict order': + event_items = event_lengths.items() + else: + raise ValueError(f'Precedence must be one of "shortest", "longest", "dict order", got "{precedence}".') + for event, (min_len, max_len) in event_items: + _logger.debug('%s: %.4G < ttl <= %.4G', event, min_len, max_len) + i_event = np.where(np.logical_and(dt > min_len, dt <= max_len))[0] * 2 + i_event = i_event[np.where(~assigned[i_event])[0]] # remove those already assigned + event_intervals[event] = np.c_[ts[i_event], ts[i_event + 1]] + assigned[np.r_[i_event, i_event + 1]] = True + + # Include the unassigned events for convenience and debugging + event_intervals['unassigned'] = ts[~assigned].reshape(-1, 2) + + # Assert that event TTLs mutually exclusive + all_assigned = np.concatenate(list(event_intervals.values())).flatten() + assert all_assigned.size == np.unique(all_assigned).size, 'TTLs assigned to multiple events' + + # some debug plots when needed + if display: # pragma: no cover + plt.figure() + plots.squares(ts, polarities, label='raw fronts') + for event, intervals in event_intervals.items(): + plots.vertical_lines(intervals[:, 0], ymin=-0.2, ymax=1.1, linewidth=0.5, label=event) + plt.legend() + + # Return map of event intervals in the same order as `event_lengths` dict + return {k: event_intervals[k] for k in (*event_lengths, 'unassigned')} + + @staticmethod + def sync_bpod_clock(bpod_trials, fpga_trials, sync_field): + """ + Sync the Bpod clock to FPGA one using the provided trial event. + + It assumes that `sync_field` is in both `fpga_trials` and `bpod_trials`. Syncing on both + intervals is not supported so to sync on trial start times, `sync_field` should be + 'intervals_0'. + + Parameters + ---------- + bpod_trials : dict + A dictionary of extracted Bpod trial events. + fpga_trials : dict + A dictionary of trial events extracted from FPGA sync events (see + `extract_behaviour_sync` method). + sync_field : str + The trials key to use for syncing clocks. For intervals (i.e. Nx2 arrays) append the + column index, e.g. 'intervals_0'. + + Returns + ------- + function + Interpolation function such that f(timestamps_bpod) = timestamps_fpga. + float + The clock drift in parts per million. + numpy.array of int + The indices of the Bpod trial events in the FPGA trial events array. + numpy.array of int + The indices of the FPGA trial events in the Bpod trial events array. + + Raises + ------ + ValueError + The key `sync_field` was not found in either the `bpod_trials` or `fpga_trials` dicts. + """ + _logger.info(f'Attempting to align Bpod clock to DAQ using trial event "{sync_field}"') + if sync_field not in bpod_trials: + # handle syncing on intervals + if not (m := re.match(r'(.*)_(\d)', sync_field)): + raise ValueError(f'Sync field "{sync_field}" not in extracted bpod trials') + sync_field, i = m.groups() + timestamps_bpod = bpod_trials[sync_field][:, int(i)] + timestamps_fpga = fpga_trials[sync_field][:, int(i)] + elif sync_field not in fpga_trials: + raise ValueError(f'Sync field "{sync_field}" not in extracted fpga trials') + else: + timestamps_bpod = bpod_trials[sync_field] + timestamps_fpga = fpga_trials[sync_field] + + # Sync the two timestamps + fcn, drift, ibpod, ifpga = neurodsp.utils.sync_timestamps( + timestamps_bpod, timestamps_fpga, return_indices=True) + + # If it's drifting too much throw warning or error + _logger.info('N trials: %i bpod, %i FPGA, %i merged, sync %.5f ppm', + len(timestamps_bpod), len(timestamps_fpga), len(ibpod), drift) + if drift > 200 and timestamps_bpod.size != timestamps_fpga.size: + raise err.SyncBpodFpgaException('sync cluster f*ck') + elif drift > BPOD_FPGA_DRIFT_THRESHOLD_PPM: + _logger.warning('BPOD/FPGA synchronization shows values greater than %.2f ppm', + BPOD_FPGA_DRIFT_THRESHOLD_PPM) + + return fcn, drift, ibpod, ifpga + + +class FpgaTrialsHabituation(FpgaTrials): + """Extract habituationChoiceWorld trial events from an NI DAQ.""" + + save_names = ('_ibl_trials.stimCenter_times.npy', '_ibl_trials.feedbackType.npy', '_ibl_trials.rewardVolume.npy', + '_ibl_trials.stimOff_times.npy', '_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy', + '_ibl_trials.feedback_times.npy', '_ibl_trials.stimOn_times.npy', '_ibl_trials.stimOnTrigger_times.npy', + '_ibl_trials.intervals.npy', '_ibl_trials.goCue_times.npy', '_ibl_trials.goCueTrigger_times.npy', + None, None, None, None, None) + """tuple of str: The filenames of each extracted dataset, or None if array should not be saved.""" + + var_names = ('stimCenter_times', 'feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', + 'contrastRight', 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals', + 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times', + 'stimCenterTrigger_times', 'position', 'phase') + """tuple of str: A list of names for the extracted variables. These become the returned output keys.""" + + bpod_rsync_fields = ('intervals', 'stimOn_times', 'feedback_times', 'stimCenterTrigger_times', + 'goCue_times', 'itiIn_times', 'stimOffTrigger_times', 'stimOff_times', + 'stimCenter_times', 'stimOnTrigger_times', 'goCueTrigger_times') + """tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA.""" + + bpod_fields = ('feedbackType', 'rewardVolume', 'contrastLeft', 'contrastRight', 'position', 'phase') + """tuple of str: Fields from Bpod extractor that we want to save.""" + + sync_field = 'feedback_times' # valve open events + """str: The trial event to synchronize (must be present in extracted trials).""" + + def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', + task_collection='raw_behavior_data', **kwargs) -> dict: + """ + Extract habituationChoiceWorld trial events from an NI DAQ. + + It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field` + attributes are all correct for the bpod protocol used. + + Unlike FpgaTrials, this class assumes different Bpod TTL events and syncs the Bpod clock + using the valve open times, instead of the trial start times. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. If None, the sync is loaded using the + `load_sync` method. + dict + A map of channel names and their corresponding indices. If None, the channel map is + loaded using the `load_sync` method. + sync_collection : str + The session subdirectory where the sync data are located. This is only used if the + sync or channel maps are not provided. + task_collection : str + The session subdirectory where the raw Bpod data are located. This is used for loading + the task settings and extracting the bpod trials, if not already done. + protocol_number : int + The protocol number if multiple protocols were run during the session. If provided, a + spacer signal must be present in order to determine the correct period. + kwargs + Optional arguments for class methods, e.g. 'display', 'bpod_event_ttls'. + + Returns + ------- + dict + A dictionary of numpy arrays with `FpgaTrialsHabituation.var_names` as keys. + """ + # Version check: the ITI in TTL was added in a later version + iblrig_version = version.parse(self.settings.get('IBL_VERSION', '0.0.0')) + if version.parse('8.9.3') <= iblrig_version < version.parse('8.12.6'): + """A second 1s TTL was added in this version during the 'iti' state, however this is + unrelated to the trial ITI and is unfortunately the same length as the trial start TTL.""" + raise NotImplementedError('Ambiguous TTLs in 8.9.3 >= version < 8.12.6') + + # Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse + if 'bpod_event_ttls' not in kwargs: + kwargs['bpod_event_ttls'] = {'trial_iti': (1, 1.1), 'valve_open': (0, 0.4)} + trials = super()._extract(sync=sync, chmap=chmap, sync_collection=sync_collection, + task_collection=task_collection, **kwargs) + + n = trials['intervals'].shape[0] # number of trials + trials['intervals'][:, 1] = self.bpod2fpga(self.bpod_trials['intervals'][:n, 1]) + + return trials + + def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, **kwargs): + """ + Extract task related event times from the sync. + + This is called by the superclass `_extract` method. The key difference here is that the + `trial_start` LOW->HIGH is the trial end, and HIGH->LOW is trial start. + + Parameters + ---------- + sync : dict + 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' + chmap : dict + Map of channel names and their corresponding index. Default to constant. + start_times : numpy.array + An optional array of timestamps to separate trial events by. This is useful if after + syncing the clocks, some trial start TTLs are found to be missed. If None, uses + 'trial_start' Bpod event. + display : bool, matplotlib.pyplot.Axes + Show the full session sync pulses display. + + Returns + ------- + dict + A map of trial event timestamps. + """ + # Get the events from the sync. + # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC + self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) + self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) + self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) + if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_iti'}: + raise ValueError( + 'Expected at least "trial_iti" and "valve_open" Bpod events. `bpod_event_ttls` kwarg may be incorrect.') + + # The first trial pulse is shorter and assigned to valve_open. Here we remove the first + # valve event, prepend a 0 to the trial_start events, and drop the last trial if it was + # incomplete in Bpod. + n_trials = self.bpod_trials['intervals'].shape[0] + t_valve_open = bpod_event_intervals['valve_open'][1:, 0] # drop first spurious valve event + t_ready_tone_in = audio_event_intervals['ready_tone'][:, 0] + t_trial_start = np.r_[0, bpod_event_intervals['trial_iti'][:, 1]] + t_trial_end = bpod_event_intervals['trial_iti'][:, 0] + + start_times = start_times or t_trial_start + + trials = alfio.AlfBunch({ + 'goCue_times': _assign_events_to_trial(start_times, t_ready_tone_in, take='first')[:n_trials], + 'feedback_times': _assign_events_to_trial(start_times, t_valve_open)[:n_trials], + 'stimFreeze_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take=-2)[:n_trials], + 'stimOn_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take='first')[:n_trials], + 'stimOff_times': _assign_events_to_trial(start_times, self.frame2ttl['times'])[:n_trials], + # These 'raw' intervals will be used in the sync + 'intervals_1': _assign_events_to_trial(start_times, t_trial_end), + 'intervals_0': start_times + }) + + # If stim on occurs before trial end, use stim on time. Likewise for trial end and stim off + trials['intervals'] = np.c_[trials['intervals_0'], trials['intervals_1']][:n_trials, :] + to_correct = ~np.isnan(trials['stimOn_times']) & (trials['stimOn_times'] < trials['intervals'][:, 0]) + if np.any(to_correct): + _logger.warning('%i/%i stim on events occurring outside trial intervals', sum(to_correct), len(to_correct)) + trials['intervals'][to_correct, 0] = trials['stimOn_times'][to_correct] + to_correct = ~np.isnan(trials['stimOff_times']) & (trials['stimOff_times'] > trials['intervals'][:, 1]) + if np.any(to_correct): + _logger.debug( + '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end', + sum(to_correct), len(to_correct)) + trials['intervals'][to_correct, 1] = trials['stimOff_times'][to_correct] + + if display: # pragma: no cover + width = 0.5 + ymax = 5 + if isinstance(display, bool): + plt.figure('Bpod FPGA Sync') + ax = plt.gca() + else: + ax = display + plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k') + plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') + plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') + color_map = TABLEAU_COLORS.keys() + for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)): + plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) + ax.legend() + ax.set_yticks([0, 1, 2]) + ax.set_yticklabels(['bpod', 'f2ttl', 'audio']) + ax.set_ylim([0, 5]) + + return trials + def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_path=None, task_collection='raw_behavior_data', protocol_number=None, **kwargs): @@ -883,7 +1583,11 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ - sync - wheel - behaviour - - video time stamps + + These `extract_all` functions should be deprecated as they make assumptions about hardware + parameters. Additionally the FpgaTrials class now automatically loads DAQ sync files, extracts + the Bpod trials, and returns a dict instead of a tuple. Therefore this function is entirely + redundant. See the examples for the correct way to extract NI DAQ behaviour sessions. Parameters ---------- @@ -909,6 +1613,10 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ list of pathlib.Path, None If save is True, a list of file paths to the extracted data. """ + warnings.warn( + 'ibllib.io.extractors.ephys_fpga.extract_all will be removed in future versions; ' + 'use FpgaTrials instead. For reliable extraction, use the dynamic pipeline behaviour tasks.', + FutureWarning) # Extract Bpod trials bpod_raw = raw.load_data(session_path, task_collection=task_collection) assert bpod_raw is not None, 'No task trials data in raw_behavior_data - Exit' diff --git a/ibllib/io/extractors/habituation_trials.py b/ibllib/io/extractors/habituation_trials.py index 59a29a269..655ea2de1 100644 --- a/ibllib/io/extractors/habituation_trials.py +++ b/ibllib/io/extractors/habituation_trials.py @@ -73,8 +73,7 @@ def _extract(self) -> dict: There should be exactly three TTLs per trial. stimOff_times should be the first TTL pulse. If 1 or more pulses are missing, we can not be confident of assigning the correct one. """ - out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan - for sync, off in zip(ttls[1:], ends)]) + out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan for sync in ttls[1:]]) # Trial intervals """ @@ -85,8 +84,13 @@ def _extract(self) -> dict: # NB: We lose the last trial because the stim off event occurs at trial_num + 1 n_trials = out['stimOff_times'].size out['intervals'] = np.c_[starts, np.r_[ends, np.nan]][:n_trials, :] - to_update = out['intervals'][:, 1] < out['stimOff_times'] - out['intervals'][to_update, 1] = out['stimOff_times'][to_update] + + to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1]) + if np.any(to_correct): + _logger.debug( + '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end', + sum(to_correct), len(to_correct)) + out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct] # itiIn times out['itiIn_times'] = np.r_[ends, np.nan] @@ -133,11 +137,8 @@ def _extract(self) -> dict: # Double-check that the early and late trial events occur within the trial intervals idx = ~np.isnan(out['stimOn_times'][:n_trials]) - if np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]): - _logger.warning('Stim on events occurring outside trial intervals') - idx = ~np.isnan(out['stimOff_times']) - if np.any(out['stimOff_times'][idx] > out['intervals'][idx, 1]): - _logger.warning('Stim off events occurring outside trial intervals') + assert not np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]), \ + 'Stim on events occurring outside trial intervals' # Truncate arrays and return in correct order return {k: out[k][:n_trials] for k in self.var_names} diff --git a/ibllib/io/session_params.py b/ibllib/io/session_params.py index 5bcaf2873..fd9854455 100644 --- a/ibllib/io/session_params.py +++ b/ibllib/io/session_params.py @@ -413,7 +413,6 @@ def iter_dict(d): for d in filter(lambda x: isinstance(x, dict), v): iter_dict(d) elif isinstance(v, dict) and 'collection' in v: - print(k) # if the key already exists, append the collection name to the list if k in collection_map: clist = collection_map[k] if isinstance(collection_map[k], list) else [collection_map[k]] diff --git a/ibllib/pipes/behavior_tasks.py b/ibllib/pipes/behavior_tasks.py index 6f1c8d506..85e21c7ac 100644 --- a/ibllib/pipes/behavior_tasks.py +++ b/ibllib/pipes/behavior_tasks.py @@ -14,7 +14,7 @@ from ibllib.qc.task_metrics import HabituationQC, TaskQC from ibllib.io.extractors.ephys_passive import PassiveChoiceWorld from ibllib.io.extractors.bpod_trials import get_bpod_extractor -from ibllib.io.extractors.ephys_fpga import FpgaTrials, get_sync_and_chn_map +from ibllib.io.extractors.ephys_fpga import FpgaTrials, FpgaTrialsHabituation, get_sync_and_chn_map from ibllib.io.extractors.mesoscope import TimelineTrials from ibllib.pipes import training_status from ibllib.plots.figures import BehaviourPlots @@ -102,14 +102,61 @@ def _run_qc(self, trials_data=None, update=True): qc.extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one, sync_type=self.sync, task_collection=self.collection) - # Currently only the data field is accessed + # Update extractor fields qc.extractor.data = qc.extractor.rename_data(trials_data.copy()) + qc.extractor.frame_ttls = self.extractor.frame2ttl # used in iblapps QC viewer + qc.extractor.audio_ttls = self.extractor.audio # used in iblapps QC viewer + qc.extractor.settings = self.extractor.settings namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' qc.run(update=update, namespace=namespace) return qc +class HabituationTrialsNidq(HabituationTrialsBpod): + priority = 90 + job_size = 'small' + + @property + def signature(self): + signature = super().signature + signature['input_files'] = [ + ('_iblrig_taskData.raw.*', self.collection, True), + ('_iblrig_taskSettings.raw.*', self.collection, True), + (f'_{self.sync_namespace}_sync.channels.npy', self.sync_collection, True), + (f'_{self.sync_namespace}_sync.polarities.npy', self.sync_collection, True), + (f'_{self.sync_namespace}_sync.times.npy', self.sync_collection, True), + ('*wiring.json', self.sync_collection, False), + ('*.meta', self.sync_collection, True)] + return signature + + def _extract_behaviour(self, save=True, **kwargs): + """Extract the habituationChoiceWorld trial data using NI DAQ clock.""" + # Extract Bpod trials + bpod_trials, _ = super()._extract_behaviour(save=False, **kwargs) + + # Sync Bpod trials to FPGA + sync, chmap = get_sync_and_chn_map(self.session_path, self.sync_collection) + self.extractor = FpgaTrialsHabituation( + self.session_path, bpod_trials=bpod_trials, bpod_extractor=self.extractor) + + # NB: The stimOff times are called stimCenter times for habituation choice world + outputs, files = self.extractor.extract( + save=save, sync=sync, chmap=chmap, path_out=self.session_path.joinpath(self.output_collection), + task_collection=self.collection, protocol_number=self.protocol_number, **kwargs) + return outputs, files + + def _run_qc(self, trials_data=None, update=True, **_): + """Run and update QC. + + This adds the bpod TTLs to the QC object *after* the QC is run in the super call method. + The raw Bpod TTLs are not used by the QC however they are used in the iblapps QC plot. + """ + qc = super()._run_qc(trials_data=trials_data, update=update) + qc.extractor.bpod_ttls = self.extractor.bpod + return qc + + class TrialRegisterRaw(base_tasks.RegisterRawDataTask, base_tasks.BehaviourTask): priority = 100 job_size = 'small' @@ -286,9 +333,9 @@ def _run_qc(self, trials_data=None, update=True): else: qc = TaskQC(self.session_path, one=self.one, log=_logger) qc_extractor.wheel_encoding = 'X1' - qc_extractor.settings = self.extractor.settings - qc_extractor.frame_ttls, qc_extractor.audio_ttls = load_bpod_fronts( - self.session_path, task_collection=self.collection) + qc_extractor.settings = self.extractor.settings + qc_extractor.frame_ttls, qc_extractor.audio_ttls = load_bpod_fronts( + self.session_path, task_collection=self.collection) qc.extractor = qc_extractor # Aggregate and update Alyx QC fields @@ -370,14 +417,15 @@ def _run_qc(self, trials_data=None, update=False, plot_qc=False): qc = HabituationQC(self.session_path, one=self.one, log=_logger) else: qc = TaskQC(self.session_path, one=self.one, log=_logger) - qc_extractor.settings = self.extractor.settings # Add Bpod wheel data wheel_ts_bpod = self.extractor.bpod2fpga(self.extractor.bpod_trials['wheel_timestamps']) qc_extractor.data['wheel_timestamps_bpod'] = wheel_ts_bpod qc_extractor.data['wheel_position_bpod'] = self.extractor.bpod_trials['wheel_position'] qc_extractor.wheel_encoding = 'X4' - qc_extractor.frame_ttls = self.extractor.frame2ttl - qc_extractor.audio_ttls = self.extractor.audio + qc_extractor.frame_ttls = self.extractor.frame2ttl + qc_extractor.audio_ttls = self.extractor.audio + qc_extractor.bpod_ttls = self.extractor.bpod + qc_extractor.settings = self.extractor.settings qc.extractor = qc_extractor # Aggregate and update Alyx QC fields diff --git a/ibllib/pipes/dynamic_pipeline.py b/ibllib/pipes/dynamic_pipeline.py index bc2caaf1b..3c72853fb 100644 --- a/ibllib/pipes/dynamic_pipeline.py +++ b/ibllib/pipes/dynamic_pipeline.py @@ -230,22 +230,28 @@ def make_pipeline(session_path, **pkwargs): # - choice_world_biased # - choice_world_training # - choice_world_habituation - if 'habituation' in protocol: - registration_class = btasks.HabituationRegisterRaw - behaviour_class = btasks.HabituationTrialsBpod - compute_status = False - elif 'passiveChoiceWorld' in protocol: + if 'passiveChoiceWorld' in protocol: registration_class = btasks.PassiveRegisterRaw behaviour_class = btasks.PassiveTask compute_status = False elif sync_kwargs['sync'] == 'bpod': - registration_class = btasks.TrialRegisterRaw - behaviour_class = btasks.ChoiceWorldTrialsBpod - compute_status = True + if 'habituation' in protocol: + registration_class = btasks.HabituationRegisterRaw + behaviour_class = btasks.HabituationTrialsBpod + compute_status = False + else: + registration_class = btasks.TrialRegisterRaw + behaviour_class = btasks.ChoiceWorldTrialsBpod + compute_status = True elif sync_kwargs['sync'] == 'nidq': - registration_class = btasks.TrialRegisterRaw - behaviour_class = btasks.ChoiceWorldTrialsNidq - compute_status = True + if 'habituation' in protocol: + registration_class = btasks.HabituationRegisterRaw + behaviour_class = btasks.HabituationTrialsNidq + compute_status = False + else: + registration_class = btasks.TrialRegisterRaw + behaviour_class = btasks.ChoiceWorldTrialsNidq + compute_status = True else: raise NotImplementedError tasks[f'RegisterRaw_{protocol}_{i:02}'] = type(f'RegisterRaw_{protocol}_{i:02}', (registration_class,), {})( diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index 42361645d..d746626d5 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -1,4 +1,5 @@ -"""Behaviour QC +"""Behaviour QC. + This module runs a list of quality control metrics on the behaviour data. Examples @@ -179,20 +180,22 @@ def run(self, update=False, namespace='task', **kwargs): return outcome, results @staticmethod - def compute_session_status_from_dict(results): + def compute_session_status_from_dict(results, criteria=None): """ Given a dictionary of results, computes the overall session QC for each key and aggregates in a single value - :param results: a dictionary of qc keys containing (usually scalar) values + :param results: a dictionary of qc keys containing (usually scalar) values. + :param criteria: a dictionary of qc keys containing map of PASS, WARNING, FAIL thresholds. :return: Overall session QC outcome as a string :return: A dict of QC tests and their outcomes """ indices = np.zeros(len(results), dtype=int) + criteria = criteria or TaskQC.criteria for i, k in enumerate(results): - if k in TaskQC.criteria.keys(): - indices[i] = TaskQC._thresholding(results[k], thresholds=TaskQC.criteria[k]) + if k in criteria.keys(): + indices[i] = TaskQC._thresholding(results[k], thresholds=criteria[k]) else: - indices[i] = TaskQC._thresholding(results[k], thresholds=TaskQC.criteria['default']) + indices[i] = TaskQC._thresholding(results[k], thresholds=criteria['default']) def key_map(x): return 'NOT_SET' if x < 0 else list(TaskQC.criteria['default'].keys())[x] @@ -213,14 +216,19 @@ def compute_session_status(self): # Get mean passed of each check, or None if passed is None or all NaN results = {k: None if v is None or np.isnan(v).all() else np.nanmean(v) for k, v in self.passed.items()} - session_outcome, outcomes = self.compute_session_status_from_dict(results) + session_outcome, outcomes = self.compute_session_status_from_dict(results, self.criteria) return session_outcome, results, outcomes class HabituationQC(TaskQC): - def compute(self, download_data=None): - """Compute and store the QC metrics + criteria = dict() + criteria['default'] = {'PASS': 0.99, 'WARNING': 0.90, 'FAIL': 0} # Note: WARNING was 0.95 prior to Aug 2022 + criteria['_task_phase_distribution'] = {'PASS': 0.99, 'NOT_SET': 0} # This rarely passes due to low trial num + + def compute(self, download_data=None, **kwargs): + """Compute and store the QC metrics. + Runs the QC on the session and stores a map of the metrics for each datapoint for each test, and a map of which datapoints passed for each test :return: @@ -228,7 +236,7 @@ def compute(self, download_data=None): if self.extractor is None: # If download_data is None, decide based on whether eid or session path was provided ensure_data = self.download_data if download_data is None else download_data - self.load_data(download_data=ensure_data) + self.load_data(download_data=ensure_data, **kwargs) self.log.info(f'Session {self.session_path}: Running QC on habituation data...') # Initialize checks @@ -302,6 +310,7 @@ def compute(self, download_data=None): passed[check] = (metric <= 2 * np.pi) & (metric >= 0) metrics[check] = metric + # This is not very useful as a check because there are so few trials check = prefix + 'phase_distribution' metric, _ = np.histogram(data['phase']) _, p = chisquare(metric) diff --git a/ibllib/tests/extractors/test_ephys_fpga.py b/ibllib/tests/extractors/test_ephys_fpga.py index ca211e426..465322810 100644 --- a/ibllib/tests/extractors/test_ephys_fpga.py +++ b/ibllib/tests/extractors/test_ephys_fpga.py @@ -1,15 +1,11 @@ +"""Tests for ephys FPGA sync and FPGA wheel extraction.""" import unittest import tempfile from pathlib import Path -import pickle -import logging import numpy as np -from ibllib.io.extractors.training_wheel import extract_first_movement_times, infer_wheel_units from ibllib.io.extractors import ephys_fpga -from ibllib.io.extractors.training_wheel import extract_wheel_moves -import brainbox.behavior.wheel as wh import spikeglx @@ -88,189 +84,12 @@ def test_ibl_sync_maps(self): self.assertEqual(s, ephys_fpga.CHMAPS['3B']['ap']) -class TestWheelExtraction(unittest.TestCase): - - def setUp(self) -> None: - self.ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) - self.pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - self.tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) - self.pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - - def test_x1_decoding(self): - p_ = np.array([1, 2, 1, 0]) - t_ = np.array([2, 6, 11, 15]) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x1') - self.assertTrue(np.all(t == t_)) - self.assertTrue(np.all(p == p_)) - - def test_x4_decoding(self): - p_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1, 0]) / 4 - t_ = np.array([2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18]) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x4') - self.assertTrue(np.all(t == t_)) - self.assertTrue(np.all(np.isclose(p, p_))) - - def test_x2_decoding(self): - p_ = np.array([1, 2, 3, 4, 3, 2, 1, 0]) / 2 - t_ = np.array([2, 4, 6, 8, 12, 14, 16, 18]) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x2') - self.assertTrue(np.all(t == t_)) - self.assertTrue(np.all(p == p_)) - - -class TestExtractedWheelUnits(unittest.TestCase): - """Tests the infer_wheel_units function""" - - wheel_radius_cm = 3.1 - - def setUp(self) -> None: - """ - Create the wheel position data for testing: the positions attribute holds a dictionary of - units, each holding a dictionary of encoding types to test, e.g. - - positions = { - 'rad': { - 'X1': ..., - 'X2': ..., - 'X4': ... - }, - 'cm': { - 'X1': ..., - 'X2': ..., - 'X4': ... - } - } - :return: - """ - def x(unit, enc=int(1), wheel_radius=self.wheel_radius_cm): - radius = 1 if unit == 'rad' else wheel_radius - return 1 / ephys_fpga.WHEEL_TICKS * np.pi * 2 * radius / enc - - # A pseudo-random sequence of integrated fronts - seq = np.array([-1, 0, 1, 2, 1, 2, 3, 4, 3, 2, 1, 0, -1, -2, 1, -2]) - encs = (1, 2, 4) # Encoding types to test - units = ('rad', 'cm') # Units to test - self.positions = {unit: {f'X{e}': x(unit, e) * seq for e in encs} for unit in units} - - def test_extract_wheel_moves(self): - for unit in self.positions.keys(): - for encoding, pos in self.positions[unit].items(): - result = infer_wheel_units(pos) - self.assertEqual(unit, result[0], f'failed to determine units for {encoding}') - expected = int(ephys_fpga.WHEEL_TICKS * int(encoding[1])) - self.assertEqual(expected, result[1], - f'failed to determine number of ticks for {encoding} in {unit}') - self.assertEqual(encoding, result[2], f'failed to determine encoding in {unit}') - - -class TestWheelMovesExtraction(unittest.TestCase): - - def setUp(self) -> None: - """ - Test data is in the form ((inputs), (outputs)) where inputs is a tuple containing a - numpy array of timestamps and one of positions; outputs is a tuple of outputs from - the functions. For details, see help on TestWheel.setUp method in module - brainbox.tests.test_behavior - """ - pickle_file = Path(__file__).parents[3].joinpath( - 'brainbox', 'tests', 'fixtures', 'wheel_test.p') - if not pickle_file.exists(): - self.test_data = None - else: - with open(pickle_file, 'rb') as f: - self.test_data = pickle.load(f) - - # Some trial times for trial_data[1] - self.trials = { - 'goCue_times': np.array([162.5, 105.6, 55]), - 'feedback_times': np.array([164.3, 108.3, 56]) - } - - def test_extract_wheel_moves(self): - test_data = self.test_data[1] - # Wrangle data into expected form - re_ts = test_data[0][0] - re_pos = test_data[0][1] - - logger = logging.getLogger(logname := 'ibllib.io.extractors.training_wheel') - with self.assertLogs(logger, level='INFO') as cm: - wheel_moves = extract_wheel_moves(re_ts, re_pos) - self.assertEqual([f'INFO:{logname}:Wheel in cm units using X2 encoding'], cm.output) - - n = 56 # expected number of movements - self.assertTupleEqual(wheel_moves['intervals'].shape, (n, 2), - 'failed to return the correct number of intervals') - self.assertEqual(wheel_moves['peakAmplitude'].size, n) - self.assertEqual(wheel_moves['peakVelocity_times'].size, n) - - # Check the first 3 intervals - ints = np.array( - [[24.78462599, 25.22562599], - [29.58762599, 31.15062599], - [31.64262599, 31.81662599]]) - actual = wheel_moves['intervals'][:3, ] - self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') - - # Check amplitudes - actual = wheel_moves['peakAmplitude'][-3:] - expected = [0.50255486, -1.70103154, 1.00740789] - self.assertIsNone(np.testing.assert_allclose(actual, expected), 'unexpected amplitudes') - - # Check peak velocities - actual = wheel_moves['peakVelocity_times'][-3:] - expected = [175.13662599, 176.65762599, 178.57262599] - self.assertIsNone(np.testing.assert_allclose(actual, expected), 'peak times') - - # Test extraction in rad - re_pos = wh.cm_to_rad(re_pos) - with self.assertLogs(logger, level='INFO') as cm: - wheel_moves = ephys_fpga.extract_wheel_moves(re_ts, re_pos) - self.assertEqual([f'INFO:{logname}:Wheel in rad units using X2 encoding'], cm.output) - - # Check the first 3 intervals. As position thresholds are adjusted by units and - # encoding, we should expect the intervals to be identical to above - actual = wheel_moves['intervals'][:3, ] - self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') - - def test_movement_log(self): - """ - Integration test for inferring the units and decoding type for wheel data input for - extract_wheel_moves. Only expected to work for the default wheel diameter. - """ - ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) - pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) - pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - logger = logging.getLogger(logname := 'ibllib.io.extractors.training_wheel') - - for unit in ['cm', 'rad']: - for i in (1, 2, 4): - encoding = 'X' + str(i) - r = 3.1 if unit == 'cm' else 1 - # print(encoding, unit) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - ta, pa, tb, pb, ticks=1024, coding=encoding.lower(), radius=r) - expected = f'INFO:{logname}:Wheel in {unit} units using {encoding} encoding' - with self.assertLogs(logger, level='INFO') as cm: - ephys_fpga.extract_wheel_moves(t, p) - self.assertEqual([expected], cm.output) - - def test_extract_first_movement_times(self): - test_data = self.test_data[1] - wheel_moves = ephys_fpga.extract_wheel_moves(test_data[0][0], test_data[0][1]) - first, is_final, ind = extract_first_movement_times(wheel_moves, self.trials) - np.testing.assert_allclose(first, [162.48462599, 105.62562599, np.nan]) - np.testing.assert_array_equal(is_final, [False, True, False]) - np.testing.assert_array_equal(ind, [46, 18]) - - class TestEphysFPGA_TTLsExtraction(unittest.TestCase): def test_audio_ttl_wiring_camera(self): """ + Test ephys_fpga._clean_audio function. + Test removal of spurious TTLs due to a wrong wiring of the camera onto the soundcard example eid: e349a2e7-50a3-47ca-bc45-20d1899854ec """ diff --git a/ibllib/tests/extractors/test_ephys_trials.py b/ibllib/tests/extractors/test_ephys_trials.py index d5483792f..7d77079af 100644 --- a/ibllib/tests/extractors/test_ephys_trials.py +++ b/ibllib/tests/extractors/test_ephys_trials.py @@ -1,15 +1,23 @@ import unittest from pathlib import Path +import pickle + import numpy as np from ibllib.io.extractors import ephys_fpga, biased_trials import ibllib.io.raw_data_loaders as raw +from ibllib.io.extractors.training_wheel import extract_first_movement_times, infer_wheel_units +from ibllib.io.extractors.training_wheel import extract_wheel_moves +import brainbox.behavior.wheel as wh class TestEphysSyncExtraction(unittest.TestCase): def test_bpod_trace_extraction(self): + """Test ephys_fpga._assign_events_bpod function. + TODO Remove this test and corresponding function. + """ t_valve_open_ = np.array([117.12136667, 122.3873, 127.82903333, 140.56083333, 143.55326667, 155.29713333, 164.9186, 167.91133333, 171.39736667, 178.0305, 181.70343333]) @@ -48,6 +56,7 @@ def test_bpod_trace_extraction(self): self.assertTrue(np.all(np.isclose(t_valve_open, t_valve_open_))) def test_align_to_trial(self): + """Test ephys_fpga._assign_events_to_trial function.""" # simple test with one missing at the end t_trial_start = np.arange(0, 5) * 10 t_event = np.arange(0, 5) * 10 + 2 @@ -95,6 +104,7 @@ def test_align_to_trial(self): self.assertRaises(ValueError, ephys_fpga._assign_events_to_trial, t_trial_start, np.array([0., 2., 1.])) def test_wheel_trace_from_sync(self): + """Test ephys_fpga._rotary_encoder_positions_from_fronts function.""" pos_ = - np.array([-1, 0, -1, -2, -1, -2]) * (np.pi / ephys_fpga.WHEEL_TICKS) ta = np.array([1, 2, 3, 4, 5, 6]) tb = np.array([0.5, 3.2, 3.3, 3.4, 5.25, 5.5]) @@ -137,5 +147,184 @@ def test_get_probabilityLeft(self): self.assertTrue(all([x in [0.2, 0.5, 0.8] for x in np.unique(pLeft1)])) +class TestWheelExtraction(unittest.TestCase): + + def setUp(self) -> None: + self.ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) + self.pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + self.tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) + self.pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + + def test_x1_decoding(self): + p_ = np.array([1, 2, 1, 0]) + t_ = np.array([2, 6, 11, 15]) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts( + self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x1') + self.assertTrue(np.all(t == t_)) + self.assertTrue(np.all(p == p_)) + + def test_x4_decoding(self): + p_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1, 0]) / 4 + t_ = np.array([2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18]) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts( + self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x4') + self.assertTrue(np.all(t == t_)) + self.assertTrue(np.all(np.isclose(p, p_))) + + def test_x2_decoding(self): + p_ = np.array([1, 2, 3, 4, 3, 2, 1, 0]) / 2 + t_ = np.array([2, 4, 6, 8, 12, 14, 16, 18]) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts( + self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x2') + self.assertTrue(np.all(t == t_)) + self.assertTrue(np.all(p == p_)) + + +class TestExtractedWheelUnits(unittest.TestCase): + """Tests the infer_wheel_units function""" + + wheel_radius_cm = 3.1 + + def setUp(self) -> None: + """ + Create the wheel position data for testing: the positions attribute holds a dictionary of + units, each holding a dictionary of encoding types to test, e.g. + + positions = { + 'rad': { + 'X1': ..., + 'X2': ..., + 'X4': ... + }, + 'cm': { + 'X1': ..., + 'X2': ..., + 'X4': ... + } + } + :return: + """ + def x(unit, enc=int(1), wheel_radius=self.wheel_radius_cm): + radius = 1 if unit == 'rad' else wheel_radius + return 1 / ephys_fpga.WHEEL_TICKS * np.pi * 2 * radius / enc + + # A pseudo-random sequence of integrated fronts + seq = np.array([-1, 0, 1, 2, 1, 2, 3, 4, 3, 2, 1, 0, -1, -2, 1, -2]) + encs = (1, 2, 4) # Encoding types to test + units = ('rad', 'cm') # Units to test + self.positions = {unit: {f'X{e}': x(unit, e) * seq for e in encs} for unit in units} + + def test_extract_wheel_moves(self): + for unit in self.positions.keys(): + for encoding, pos in self.positions[unit].items(): + result = infer_wheel_units(pos) + self.assertEqual(unit, result[0], f'failed to determine units for {encoding}') + expected = int(ephys_fpga.WHEEL_TICKS * int(encoding[1])) + self.assertEqual(expected, result[1], + f'failed to determine number of ticks for {encoding} in {unit}') + self.assertEqual(encoding, result[2], f'failed to determine encoding in {unit}') + + +class TestWheelMovesExtraction(unittest.TestCase): + + def setUp(self) -> None: + """ + Test data is in the form ((inputs), (outputs)) where inputs is a tuple containing a + numpy array of timestamps and one of positions; outputs is a tuple of outputs from + the functions. For details, see help on TestWheel.setUp method in module + brainbox.tests.test_behavior + """ + pickle_file = Path(__file__).parents[3].joinpath( + 'brainbox', 'tests', 'fixtures', 'wheel_test.p') + if not pickle_file.exists(): + self.test_data = None + else: + with open(pickle_file, 'rb') as f: + self.test_data = pickle.load(f) + + # Some trial times for trial_data[1] + self.trials = { + 'goCue_times': np.array([162.5, 105.6, 55]), + 'feedback_times': np.array([164.3, 108.3, 56]) + } + + def test_extract_wheel_moves(self): + test_data = self.test_data[1] + # Wrangle data into expected form + re_ts = test_data[0][0] + re_pos = test_data[0][1] + + logname = 'ibllib.io.extractors.training_wheel' + with self.assertLogs(logname, level='INFO') as cm: + wheel_moves = extract_wheel_moves(re_ts, re_pos) + self.assertEqual([f'INFO:{logname}:Wheel in cm units using X2 encoding'], cm.output) + + n = 56 # expected number of movements + self.assertTupleEqual(wheel_moves['intervals'].shape, (n, 2), + 'failed to return the correct number of intervals') + self.assertEqual(wheel_moves['peakAmplitude'].size, n) + self.assertEqual(wheel_moves['peakVelocity_times'].size, n) + + # Check the first 3 intervals + ints = np.array( + [[24.78462599, 25.22562599], + [29.58762599, 31.15062599], + [31.64262599, 31.81662599]]) + actual = wheel_moves['intervals'][:3, ] + self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') + + # Check amplitudes + actual = wheel_moves['peakAmplitude'][-3:] + expected = [0.50255486, -1.70103154, 1.00740789] + self.assertIsNone(np.testing.assert_allclose(actual, expected), 'unexpected amplitudes') + + # Check peak velocities + actual = wheel_moves['peakVelocity_times'][-3:] + expected = [175.13662599, 176.65762599, 178.57262599] + self.assertIsNone(np.testing.assert_allclose(actual, expected), 'peak times') + + # Test extraction in rad + re_pos = wh.cm_to_rad(re_pos) + with self.assertLogs(logname, level='INFO') as cm: + wheel_moves = ephys_fpga.extract_wheel_moves(re_ts, re_pos) + self.assertEqual([f'INFO:{logname}:Wheel in rad units using X2 encoding'], cm.output) + + # Check the first 3 intervals. As position thresholds are adjusted by units and + # encoding, we should expect the intervals to be identical to above + actual = wheel_moves['intervals'][:3, ] + self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') + + def test_movement_log(self): + """ + Integration test for inferring the units and decoding type for wheel data input for + extract_wheel_moves. Only expected to work for the default wheel diameter. + """ + ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) + pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) + pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + logname = 'ibllib.io.extractors.training_wheel' + + for unit in ['cm', 'rad']: + for i in (1, 2, 4): + encoding = 'X' + str(i) + r = 3.1 if unit == 'cm' else 1 + # print(encoding, unit) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts( + ta, pa, tb, pb, ticks=1024, coding=encoding.lower(), radius=r) + expected = f'INFO:{logname}:Wheel in {unit} units using {encoding} encoding' + with self.assertLogs(logname, level='INFO') as cm: + ephys_fpga.extract_wheel_moves(t, p) + self.assertEqual([expected], cm.output) + + def test_extract_first_movement_times(self): + test_data = self.test_data[1] + wheel_moves = ephys_fpga.extract_wheel_moves(test_data[0][0], test_data[0][1]) + first, is_final, ind = extract_first_movement_times(wheel_moves, self.trials) + np.testing.assert_allclose(first, [162.48462599, 105.62562599, np.nan]) + np.testing.assert_array_equal(is_final, [False, True, False]) + np.testing.assert_array_equal(ind, [46, 18]) + + if __name__ == '__main__': unittest.main(exit=False, verbosity=2) From d2294982f5c7462f16fee375c4d44ea9767fc9f1 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Tue, 5 Dec 2023 14:54:49 +0200 Subject: [PATCH 3/7] Remove test reference to module constant --- ibllib/tests/extractors/test_ephys_fpga.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ibllib/tests/extractors/test_ephys_fpga.py b/ibllib/tests/extractors/test_ephys_fpga.py index 465322810..fdfe27218 100644 --- a/ibllib/tests/extractors/test_ephys_fpga.py +++ b/ibllib/tests/extractors/test_ephys_fpga.py @@ -208,14 +208,15 @@ def test_frame2ttl_flickers(self): switches under a given threshold """ DISPLAY = False # for debug purposes - diff = ephys_fpga.F2TTL_THRESH * np.array([0.5, 10]) + F2TTL_THRESH = 0.01 + diff = F2TTL_THRESH * np.array([0.5, 10]) # flicker ends with a polarity switch - downgoing pulse is removed t = np.r_[0, np.cumsum(diff[np.array([1, 1, 0, 0, 1])])] + 1 frame2ttl = {'times': t, 'polarities': np.mod(np.arange(t.size) + 1, 2) * 2 - 1} expected = {'times': np.array([1., 1.1, 1.2, 1.31]), 'polarities': np.array([1, -1, 1, -1])} - frame2ttl_ = ephys_fpga._clean_frame2ttl(frame2ttl, display=DISPLAY) + frame2ttl_ = ephys_fpga._clean_frame2ttl(frame2ttl, display=DISPLAY, threshold=F2TTL_THRESH) assert all([np.all(frame2ttl_[k] == expected[k]) for k in frame2ttl_]) # stand-alone flicker From 8aff0ad6875fba0367545cb8673400722cc29830 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 8 Dec 2023 13:01:35 +0200 Subject: [PATCH 4/7] mesoscope trials extractor refactor; fix attribute_times numpy version bug; FpgaTrials.build_trials method call after sync --- ibllib/io/extractors/camera.py | 6 +- ibllib/io/extractors/ephys_fpga.py | 65 +++-- ibllib/io/extractors/mesoscope.py | 373 +++++++++++++++++++++-------- ibllib/io/raw_daq_loaders.py | 2 +- 4 files changed, 320 insertions(+), 126 deletions(-) diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py index 7612c3e9e..93554c86a 100644 --- a/ibllib/io/extractors/camera.py +++ b/ibllib/io/extractors/camera.py @@ -513,12 +513,16 @@ def attribute_times(arr, events, tol=.1, injective=True, take='first'): Returns ------- numpy.array - An array the same length as `events`. + An array the same length as `events` containing indices of `arr` corresponding to each + event. """ if (take := take.lower()) not in ('first', 'nearest', 'after'): raise ValueError('Parameter `take` must be either "first", "nearest", or "after"') stack = np.ma.masked_invalid(arr, copy=False) stack.fill_value = np.inf + # If there are no invalid values, the mask is False so let's ensure it's a bool array + if stack.mask is np.bool_(0): + stack.mask = np.zeros(arr.shape, dtype=bool) assigned = np.full(events.shape, -1, dtype=int) # Initialize output array min_tol = 0 if take == 'after' else -tol for i, x in enumerate(events): diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index ad2cb0ab5..aa042ce8e 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -900,8 +900,8 @@ def load_sync(self, sync_collection='raw_ephys_data', **kwargs): Returns ------- one.alf.io.AlfBunch - A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and - the corresponding channel numbers. + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. dict A map of channel names and their corresponding indices. """ @@ -992,24 +992,9 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', # Get the trial events from the DAQ sync TTLs fpga_trials = self.extract_behaviour_sync(sync, chmap, **kwargs) - # Sync the Bpod clock to the DAQ - self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_trials, self.sync_field) + # Sync clocks and build final trials datasets + out = self.build_trials(fpga_trials, sync=sync, chmap=chmap, **kwargs) - if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0': - # One issue is that sometimes pulses may not have been detected, in this case - # add the events that have not been detected and re-extract the behaviour sync. - # This is only really relevant for the Bpod interval events as the other TTLs are - # from devices where a missing TTL likely means the Bpod event was truly absent. - _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times') - bpod_start = self.bpod_trials['intervals'][:, 0] - missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))]) - t_trial_start = np.sort(np.r_[fpga_trials['intervals'][:, 0], missing_bpod]) - fpga_trials = self.extract_behaviour_sync(sync, chmap, start_times=t_trial_start, **kwargs) - - out = dict() - out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) - out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) - out.update({k: fpga_trials[k][ifpga] for k in sorted(fpga_trials.keys())}) # extract the wheel data if any(x.startswith('wheel') for x in self.var_names): wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) @@ -1096,9 +1081,16 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * ] bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :] + t_trial_start = bpod_event_intervals['trial_start'][:, 0] t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T - # Drop last trial start if incomplete - t_trial_start = bpod_event_intervals['trial_start'][:len(t_trial_end), 0] + # Some protocols, e.g. Guido's ephys biased opto task, have no trial end TTL. + # This is not essential as the trial start is used to sync the clocks. + if t_trial_end.size == 0: + _logger.warning('No trial end / ITI in TTLs found') + t_trial_end = np.full_like(t_trial_start, np.nan) + else: + # Drop last trial start if incomplete + t_trial_start = t_trial_start[:len(t_trial_end)] t_valve_open = bpod_event_intervals['valve_open'][:, 0] t_ready_tone_in = audio_event_intervals['ready_tone'][:, 0] t_error_tone_in = audio_event_intervals['error_tone'][:, 0] @@ -1136,12 +1128,33 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)): plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) ax.legend() - ax.set_yticks([0, 1, 2]) - ax.set_yticklabels(['bpod', 'f2ttl', 'audio']) + ax.set_yticks([0, 1, 2, 3]) + ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) ax.set_ylim([0, 5]) return trials + def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): + # Sync the Bpod clock to the DAQ + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_trials, self.sync_field) + + if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0': + # One issue is that sometimes pulses may not have been detected, in this case + # add the events that have not been detected and re-extract the behaviour sync. + # This is only really relevant for the Bpod interval events as the other TTLs are + # from devices where a missing TTL likely means the Bpod event was truly absent. + _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times') + bpod_start = self.bpod_trials['intervals'][:, 0] + missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))]) + t_trial_start = np.sort(np.r_[fpga_trials['intervals'][:, 0], missing_bpod]) + fpga_trials = self.extract_behaviour_sync(sync, chmap, start_times=t_trial_start, **kwargs) + + out = dict() + out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) + out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) + out.update({k: fpga_trials[k][ifpga] for k in sorted(fpga_trials.keys())}) + return out + def get_wheel_positions(self, *args, **kwargs): """Extract wheel and wheelMoves objects. @@ -1569,9 +1582,9 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)): plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) ax.legend() - ax.set_yticks([0, 1, 2]) - ax.set_yticklabels(['bpod', 'f2ttl', 'audio']) - ax.set_ylim([0, 5]) + ax.set_yticks([0, 1, 2, 3]) + ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) + ax.set_ylim([0, 4]) return trials diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index 4def5ed3a..84a7622e7 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -1,22 +1,24 @@ """Mesoscope (timeline) data extraction.""" import logging +from itertools import cycle import numpy as np +from scipy.signal import find_peaks import one.alf.io as alfio from one.util import ensure_list from one.alf.files import session_path_parts import matplotlib.pyplot as plt -from neurodsp.utils import falls +from matplotlib.colors import TABLEAU_COLORS from pkg_resources import parse_version from ibllib.plots.misc import squares, vertical_lines from ibllib.io.raw_daq_loaders import (extract_sync_timeline, timeline_get_channel, correct_counter_discontinuities, load_timeline_sync_and_chmap) import ibllib.io.extractors.base as extractors_base -from ibllib.io.extractors.ephys_fpga import FpgaTrials, WHEEL_TICKS, WHEEL_RADIUS_CM, get_sync_fronts, get_protocol_period +from ibllib.io.extractors.ephys_fpga import FpgaTrials, WHEEL_TICKS, WHEEL_RADIUS_CM, _assign_events_to_trial from ibllib.io.extractors.training_wheel import extract_wheel_moves from ibllib.io.extractors.camera import attribute_times -from ibllib.io.extractors.ephys_fpga import _assign_events_bpod +from brainbox.behavior.wheel import velocity_filtered _logger = logging.getLogger(__name__) @@ -102,103 +104,240 @@ def plot_timeline(timeline, channels=None, raw=True): class TimelineTrials(FpgaTrials): """Similar extraction to the FPGA, however counter and position channels are treated differently.""" - """one.alf.io.AlfBunch: The timeline data object""" timeline = None + """one.alf.io.AlfBunch: The timeline data object.""" + + sync_field = 'itiIn_times' # trial start events + """str: The trial event to synchronize (must be present in extracted trials).""" def __init__(self, *args, sync_collection='raw_sync_data', **kwargs): """An extractor for all ephys trial data, in Timeline time""" super().__init__(*args, **kwargs) self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline') - def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict: - if not (sync or chmap): - sync, chmap = load_timeline_sync_and_chmap( - self.session_path / sync_collection, timeline=self.timeline, chmap=chmap) + def load_sync(self, sync_collection='raw_sync_data', chmap=None, **_): + """Load the DAQ sync and channel map data. + + Parameters + ---------- + sync_collection : str + The session subdirectory where the sync data are located. + chmap : dict + A map of channel names and their corresponding indices. If None, the channel map is + loaded using the :func:`ibllib.io.raw_daq_loaders.timeline_meta2chmap` method. + + Returns + ------- + one.alf.io.AlfBunch + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. + dict + A map of channel names and their corresponding indices. + """ + if not self.timeline: + self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline') + sync, chmap = load_timeline_sync_and_chmap( + self.session_path / sync_collection, timeline=self.timeline, chmap=chmap) + return sync, chmap + def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict: + trials = super()._extract(sync, chmap, sync_collection='raw_sync_data', **kwargs) if kwargs.get('display', False): plot_timeline(self.timeline, channels=chmap.keys(), raw=True) - trials = super()._extract(sync, chmap, sync_collection, extractor_type='ephys', **kwargs) - - # If no protocol number is defined, trim timestamps based on Bpod trials intervals - trials_table = trials['table'] - bpod = get_sync_fronts(sync, chmap['bpod']) - if kwargs.get('protocol_number') is None: - tmin = trials_table.intervals_0.iloc[0] - 1 - tmax = trials_table.intervals_1.iloc[-1] - # Ensure wheel is cut off based on trials - mask = np.logical_and(tmin <= trials['wheel_timestamps'], trials['wheel_timestamps'] <= tmax) - trials['wheel_timestamps'] = trials['wheel_timestamps'][mask] - trials['wheel_position'] = trials['wheel_position'][mask] - mask = np.logical_and(trials['wheelMoves_intervals'][:, 0] >= tmin, trials['wheelMoves_intervals'][:, 0] <= tmax) - trials['wheelMoves_intervals'] = trials['wheelMoves_intervals'][mask, :] + return trials + + def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, **kwargs): + """ + Extract task related event times from the sync. + + TODO Change docstring + The trial start times are the shortest Bpod TTLs and occur at the start of the trial. The + first trial start TTL of the session is longer and must be handled differently. The trial + start TTL is used to assign the other trial events to each trial. + + The trial end is the end of the so-called 'ITI' Bpod event TTL (classified as the longest + of the three Bpod event TTLs). Go cue audio TTLs are the shorter of the two expected audio + tones. The first of these after each trial start is taken to be the go cue time. Error + tones are longer audio TTLs and assigned as the last of such occurrence after each trial + start. The valve open Bpod TTLs are medium-length, the last of which is used for each trial. + The feedback times are times of either valve open or error tone as there should be only one + such event per trial. + + The stimulus times are taken from the frame2ttl events (with improbably high frequency TTLs + removed): the first TTL after each trial start is assumed to be the stim onset time; the + second to last and last are taken as the stimulus freeze and offset times, respectively. + + Parameters + ---------- + sync : dict + 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' + chmap : dict + Map of channel names and their corresponding index. Default to constant. + start_times : numpy.array + An optional array of timestamps to separate trial events by. This is useful if after + syncing the clocks, some trial start TTLs are found to be missed. If None, uses + 'trial_start' Bpod event. + display : bool, matplotlib.pyplot.Axes + Show the full session sync pulses display. + + Returns + ------- + dict + A map of trial event timestamps. + """ + # Get the events from the sync. + # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC + self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) + self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) + if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}: + raise ValueError( + 'Expected at least "ready_tone" and "error_tone" audio events.' + '`audio_event_ttls` kwarg may be incorrect.') + + self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) + if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_end'}: + raise ValueError( + 'Expected at least "trial_end" and "valve_open" audio events. ' + '`bpod_event_ttls` kwarg may be incorrect.') + + t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T + trials = alfio.AlfBunch({ + 'itiIn_times': t_iti_in, + 'intervals_1': t_trial_end, + 'valveOpen_intervals': bpod_event_intervals['valve_open'], + 'goCue_times': audio_event_intervals['ready_tone'][:, 0], + 'errorTone_times': audio_event_intervals['error_tone'][:, 0] + }) + + if display: # pragma: no cover + width = 0.5 + ymax = 5 + if isinstance(display, bool): + plt.figure('Bpod FPGA Sync') + ax = plt.gca() + else: + ax = display + squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k') + squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') + squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') + color_map = TABLEAU_COLORS.keys() + for (event_name, event_times), c in zip(trials.items(), cycle(color_map)): + vertical_lines(event_times.flat, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) + ax.legend() + ax.set_yticks([0, 1, 2, 3]) + ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) + ax.set_ylim([0, 4]) + + return trials + + def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): + # Sync the Bpod clock to the DAQ + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_trials, self.sync_field) + + out = dict() + out['intervals'] = self.bpod2fpga(self.bpod_trials['intervals']) + out['itiIn_times'] = fpga_trials['itiIn_times'][ifpga] + start_times = out['intervals'][:, 0] + + # Extract valve open times from the DAQ + valve_driver_ttls = fpga_trials.pop('valveOpen_intervals') + correct = self.bpod_trials['feedbackType'] == 1 + # If there is a reward_valve channel, the valve has + if any(ch['name'] == 'reward_valve' for ch in self.timeline['meta']['inputs']): + # TODO Let's look at the expected open length based on calibration and reward volume + # import scipy.interpolate + # # FIXME support v7 settings? + # fcn_vol2time = scipy.interpolate.pchip( + # self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_WEIGHT_PERDROP'], + # self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_OPEN_TIMES'] + # ) + # reward_time = fcn_vol2time(self.bpod_extractor.settings.get('REWARD_AMOUNT_UL')) / 1e3 + + # Use the driver TTLs to find the valve open times that correspond to the valve opening + valve_intervals, valve_open_times = self.get_valve_open_times(driver_ttls=valve_driver_ttls) + if valve_open_times.size != np.sum(correct): + _logger.warning( + 'Number of valve open times does not equal number of correct trials (%i != %i)', + valve_open_times.size, np.sum(correct)) + + out['valveOpen_times'] = _assign_events_to_trial(start_times, valve_open_times) else: - tmin, tmax = get_protocol_period(self.session_path, kwargs['protocol_number'], bpod) - bpod = get_sync_fronts(sync, chmap['bpod'], tmin, tmax) - - self.frame2ttl = get_sync_fronts(sync, chmap['frame2ttl'], tmin, tmax) # save for later access by QC - - # Replace valve open times with those extracted from the DAQ - # TODO Let's look at the expected open length based on calibration and reward volume - assert len(bpod['times']) > 0, 'No Bpod TTLs detected on DAQ' - _, driver_out, _, = _assign_events_bpod(bpod['times'], bpod['polarities'], False) - # Use the driver TTLs to find the valve open times that correspond to the valve opening - valve_open_times = self.get_valve_open_times(driver_ttls=driver_out) - assert len(valve_open_times) == sum(trials_table.feedbackType == 1) # TODO Relax assertion - correct = trials_table.feedbackType == 1 - trials['valveOpen_times'][correct] = valve_open_times - trials_table.feedback_times[correct] = valve_open_times - - # Replace audio events - self.audio = get_sync_fronts(sync, chmap['audio'], tmin, tmax) - # Attempt to assign the go cue and error tone onsets based on TTL length - go_cue, error_cue = self._assign_events_audio(self.audio['times'], self.audio['polarities']) - - assert error_cue.size == np.sum(~correct), 'N detected error tones does not match number of incorrect trials' - assert go_cue.size <= len(trials_table), 'More go cue tones detected than trials!' - - if go_cue.size < len(trials_table): - _logger.warning('%i go cue tones missed', len(trials_table) - go_cue.size) + # Use the valve controller TTLs recorded on the Bpod channel as the reward time + out['valveOpen_times'] = _assign_events_to_trial(start_times, valve_driver_ttls[:, 0]) + + # Stimulus times extracted the same as usual + out['stimFreeze_times'] = _assign_events_to_trial(start_times, self.frame2ttl['times'], take=-2) + out['stimOn_times'] = _assign_events_to_trial(start_times, self.frame2ttl['times'], take='first') + out['stimOff_times'] = _assign_events_to_trial(start_times, self.frame2ttl['times']) + + # Audio times + error_cue = fpga_trials['errorTone_times'] + if error_cue.size != np.sum(~correct): + _logger.warning( + 'N detected error tones does not match number of incorrect trials (%i != %i)', + error_cue.size, np.sum(~correct)) + go_cue = fpga_trials['goCue_times'] + out['goCue_times'] = _assign_events_to_trial(start_times, go_cue, take='first') + out['errorCue_times'] = _assign_events_to_trial(start_times, error_cue) + + if go_cue.size > start_times.size: + _logger.warning( + 'More go cue tones detected than trials! (%i vs %i)', go_cue.size, start_times.size) + elif go_cue.size < start_times.size: """ If the error cues are all assigned and some go cues are missed it may be that some - responses were so fast that the go cue and error tone merged. + responses were so fast that the go cue and error tone merged, or the go cue TTL was too + long. """ + _logger.warning('%i go cue tones missed', start_times.size - go_cue.size) err_trig = self.bpod2fpga(self.bpod_trials['errorCueTrigger_times']) go_trig = self.bpod2fpga(self.bpod_trials['goCueTrigger_times']) assert not np.any(np.isnan(go_trig)) - assert err_trig.size == go_trig.size - - def first_true(arr): - """Return the index of the first True value in an array.""" - indices = np.where(arr)[0] - return None if len(indices) == 0 else indices[0] + assert err_trig.size == go_trig.size # should be length of n trials with NaNs # Find which trials are missing a go cue - _go_cue = np.full(len(trials_table), np.nan) - for i, intervals in enumerate(trials_table[['intervals_0', 'intervals_1']].values): - idx = first_true(np.logical_and(go_cue > intervals[0], go_cue < intervals[1])) - if idx is not None: - _go_cue[i] = go_cue[idx] + _go_cue = _assign_events_to_trial(start_times, go_cue, take='first') + error_cue = _assign_events_to_trial(start_times, error_cue) + missing = np.isnan(_go_cue) # Get all the DAQ timestamps where audio channel was HIGH raw = timeline_get_channel(self.timeline, 'audio') raw = (raw - raw.min()) / (raw.max() - raw.min()) # min-max normalize ups = self.timeline.timestamps[raw > .5] # timestamps where input HIGH - for i in np.where(np.isnan(_go_cue))[0]: - # Get the timestamp of the first HIGH after the trigger times - _go_cue[i] = ups[first_true(ups > go_trig[i])] - idx = first_true(np.logical_and( - error_cue > trials_table['intervals_0'][i], - error_cue < trials_table['intervals_1'][i])) - if np.isnan(err_trig[i]): - if idx is not None: - error_cue = np.delete(error_cue, idx) # Remove mis-assigned error tone time - else: - error_cue[idx] = ups[first_true(ups > err_trig[i])] - go_cue = _go_cue - - trials_table.feedback_times[~correct] = error_cue - trials_table.goCue_times = go_cue - return {k: trials[k] for k in self.var_names} + + # Get the timestamps of the first HIGH after the trigger times (allow up to 200ms after). + # Indices of ups directly following a go trigger, or -1 if none found (or trigger NaN) + idx = attribute_times(ups, go_trig, tol=0.2, take='after') + # Trial indices that didn't have detected goCue and now has been assigned an `ups` index + assigned = np.where(idx != -1 & missing)[0] # ignore unassigned + _go_cue[assigned] = ups[idx[assigned]] + + # Remove mis-assigned error tone times (i.e. those that have now been assigned to goCue) + error_cue_without_trig, = np.where(~np.isnan(error_cue) & np.isnan(err_trig)) + i_to_remove = np.intersect1d(assigned, error_cue_without_trig, assume_unique=True) + error_cue[i_to_remove] = np.nan + + # For those trials where go cue was merged with the error cue and therefore mis-assigned, + # we must re-assign the error cue times as the first HIGH after the error trigger. + idx = attribute_times(ups, err_trig, tol=0.2, take='after') + assigned = np.where(idx != -1 & missing)[0] # ignore unassigned + error_cue[assigned] = ups[idx[assigned]] + out['goCue_times'] = _go_cue + out['errorCue_times'] = error_cue + + # Because we're not + assert np.intersect1d(out['goCue_times'], out['errorCue_times']).size == 0, \ + 'audio tones not assigned correctly; tones likely missed' + + # Feedback times + out['feedback_times'] = np.copy(out['valveOpen_times']) + ind_err = np.isnan(out['valveOpen_times']) + out['feedback_times'][ind_err] = out['errorCue_times'][ind_err] + + out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) + out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) + + return out def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None): """ @@ -234,7 +373,7 @@ def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding=' # Timeline evenly samples counter so we extract only change points d = np.diff(raw) - ind, = np.where(d.astype(int)) + ind, = np.where(~np.isclose(d, 0)) pos = raw[ind + 1] pos -= pos[0] # Start from zero pos = pos / ticks * np.pi * 2 * radius / int(coding[1]) # Convert to radians @@ -290,7 +429,7 @@ def get_wheel_positions(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding= ax1.set_ylabel('DAQ wheel position / rad'), ax1.set_xlabel('Time / s') return wheel, moves - def get_valve_open_times(self, display=False, threshold=-2.5, floor_percentile=10, driver_ttls=None): + def get_valve_open_times(self, display=False, threshold=100, driver_ttls=None): """ Get the valve open times from the raw timeline voltage trace. @@ -299,44 +438,82 @@ def get_valve_open_times(self, display=False, threshold=-2.5, floor_percentile=1 display : bool Plot detected times on the raw voltage trace. threshold : float - The threshold for applying to analogue channels. - floor_percentile : float - 10% removes the percentile value of the analog trace before thresholding. This is to - avoid DC offset drift. + The threshold of voltage change to apply. The default was set by eye; units should be + Volts per sample but doesn't appear to be. driver_ttls : numpy.array An optional array of driver TTLs to use for assigning with the valve times. Returns ------- numpy.array - The detected valve open times. - - TODO extract close times too + The detected valve open intervals. + numpy.array + If driver_ttls is not None, returns an array of open times that occurred directly after + the driver TTLs. """ + WARN_THRESH = 10e-3 # open time threshold below which to log warning tl = self.timeline info = next(x for x in tl['meta']['inputs'] if x['name'] == 'reward_valve') values = tl['raw'][:, info['arrayColumn'] - 1] # Timeline indices start from 1 - offset = np.percentile(values, floor_percentile, axis=0) - idx = falls(values - offset, step=threshold) # Voltage falls when valve opens - open_times = tl['timestamps'][idx] + + # The voltage changes over ~1ms and can therefore occur over two DAQ samples at 2kHz + # making simple thresholding an issue. For this reason we convolve the signal with a + # window and detect the peaks and troughs. + if (Fs := tl['meta']['daqSampleRate']) != 2000: # e.g. 2kHz + _logger.warning('Reward valve detection not tested with a DAQ sample rate of %i', Fs) + dt = 1e-3 # change in voltage takes ~1ms when changing valve open state + N = dt / (1 / Fs) # this means voltage change occurs over N samples + vel, _ = velocity_filtered(values, int(Fs / N)) # filtered voltage change over time + ups, _ = find_peaks(vel, height=threshold) # valve closes (-5V -> 0V) + downs, _ = find_peaks(-1 * vel, height=threshold) # valve opens (0V -> -5V) + + # Convert these times into intervals + ixs = np.argsort(np.r_[downs, ups]) # sort indices + times = tl['timestamps'][np.r_[downs, ups]][ixs] # ordered valve event times + polarities = np.r_[np.zeros_like(downs) - 1, np.ones_like(ups)][ixs] # polarity sorted + missing = np.where(np.diff(polarities) == 0)[0] # if some changes were missed insert NaN + times = np.insert(times, missing + int(polarities[0] == -1), np.nan) + if polarities[-1] == -1: # ensure ends with a valve close + times = np.r_[times, np.nan] + if polarities[0] == 1: # ensure starts with a valve open + # It seems it can start out at -5V (open), then when the reward happens it closes and + # immediately opens. In this case we insert discard the first open time. + times = np.r_[np.nan, times] + intervals = times.reshape(-1, 2) + + # Log warning of improbably short intervals + short = np.sum(np.diff(intervals) < WARN_THRESH) + if short > 0: + _logger.warning('%i valve open intervals shorter than %i ms', short, WARN_THRESH) + # The closing of the valve is noisy. Keep only the falls that occur immediately after a Bpod TTL if driver_ttls is not None: # Returns an array of open_times indices, one for each driver TTL - ind = attribute_times(open_times, driver_ttls, tol=.1, take='after') - open_times = open_times[ind[ind >= 0]] + ind = attribute_times(intervals[:, 0], driver_ttls[:, 0], tol=.1, take='after') + open_times = intervals[ind[ind >= 0], 0] # TODO Log any > 40ms? Difficult to report missing valve times because of calibration if display: fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True) - ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), 'k-o') + ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), color='grey', linestyle='-') if driver_ttls is not None: - vertical_lines(driver_ttls, ymax=5, ax=ax0, linestyle='--', color='b') - ax1.plot(tl['timestamps'], values - offset, 'k-o') + x = np.empty_like(driver_ttls.flatten()) + x[0::2] = driver_ttls[:, 0] + x[1::2] = driver_ttls[:, 1] + y = np.ones_like(x) + y[1::2] -= 2 + squares(x, y, ax=ax0, yrange=[0, 5]) + # vertical_lines(driver_ttls, ymax=5, ax=ax0, linestyle='--', color='b') + ax0.plot(open_times, np.ones_like(open_times) * 4.5, 'g*') + ax1.plot(tl['timestamps'], values, 'k-o') ax1.set_ylabel('Voltage / V'), ax1.set_xlabel('Time / s') - ax1.plot(tl['timestamps'][idx], np.zeros_like(idx), 'r*') - if driver_ttls is not None: - ax1.plot(open_times, np.zeros_like(open_times), 'g*') - return open_times + + ax2 = ax1.twinx() + ax2.set_ylabel('dV', color='grey') + ax2.plot(tl['timestamps'], vel, linestyle='-', color='grey') + ax2.plot(intervals[:, 1], np.ones(len(intervals)) * threshold, 'r*', label='close') + ax2.plot(intervals[:, 0], np.ones(len(intervals)) * threshold, 'g*', label='open') + return intervals if driver_ttls is None else (intervals, open_times) def _assign_events_audio(self, audio_times, audio_polarities, display=False): """ @@ -360,7 +537,7 @@ def _assign_events_audio(self, audio_times, audio_polarities, display=False): """ # make sure that there are no 2 consecutive fall or consecutive rise events assert np.all(np.abs(np.diff(audio_polarities)) == 2) - # take only even time differences: ie. from rising to falling fronts + # take only even time differences: i.e. from rising to falling fronts dt = np.diff(audio_times) onsets = audio_polarities[:-1] == 1 diff --git a/ibllib/io/raw_daq_loaders.py b/ibllib/io/raw_daq_loaders.py index add980130..8ac58c3e7 100644 --- a/ibllib/io/raw_daq_loaders.py +++ b/ibllib/io/raw_daq_loaders.py @@ -292,7 +292,7 @@ def extract_sync_timeline(timeline, chmap=None, floor_percentile=10, threshold=N # Bidirectional; extract indices where delta != 0 raw = correct_counter_discontinuities(raw) d = np.diff(raw) - ind, = np.where(d.astype(int)) + ind, = np.where(~np.isclose(d, 0)) sync.polarities = np.concatenate((sync.polarities, np.sign(d[ind]).astype('i1'))) ind += 1 else: From 6b1416cb9a03dc3f2cd7dd064ab6830ae532666d Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 8 Dec 2023 18:17:58 +0200 Subject: [PATCH 5/7] Build trials after syncing Bpod clock --- ibllib/io/extractors/ephys_fpga.py | 301 ++++++++++++++++------------ ibllib/io/extractors/mesoscope.py | 136 +++++++------ ibllib/pipes/ephys_preprocessing.py | 11 +- 3 files changed, 252 insertions(+), 196 deletions(-) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index aa042ce8e..3c805293f 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -36,6 +36,7 @@ import uuid import re import warnings +from functools import partial import matplotlib.pyplot as plt from matplotlib.colors import TABLEAU_COLORS @@ -917,8 +918,9 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', Below are the steps involved: 0. Load sync and bpod trials, if required. 1. Determine protocol period and discard sync events outside the task. - 2. Classify and attribute DAQ TTLs to trial events (see :meth:`FpgaTrials.extract_behaviour_sync`). + 2. Classify multiplexed TTL events based on length (see :meth:`FpgaTrials.build_trials`). 3. Sync the Bpod clock to the DAQ clock using one of the assigned trial events. + 4. Assign classified TTL events to trial events based on order within the trial. 4. Convert Bpod software event times to DAQ clock. 5. Extract the wheel from the DAQ rotary encoder signal, if required. @@ -989,11 +991,8 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', _logger.debug('Protocol period from %.2fs to %.2fs (~%.0f min duration)', *sync['times'][[0, -1]], np.diff(sync['times'][[0, -1]]) / 60) - # Get the trial events from the DAQ sync TTLs - fpga_trials = self.extract_behaviour_sync(sync, chmap, **kwargs) - - # Sync clocks and build final trials datasets - out = self.build_trials(fpga_trials, sync=sync, chmap=chmap, **kwargs) + # Get the trial events from the DAQ sync TTLs, sync clocks and build final trials datasets + out = self.build_trials(sync=sync, chmap=chmap, **kwargs) # extract the wheel data if any(x.startswith('wheel') for x in self.var_names): @@ -1016,7 +1015,7 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', assert self.var_names == tuple(out.keys()) return out - def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, **kwargs): + def build_trials(self, sync, chmap, display=False, **kwargs): """ Extract task related event times from the sync. @@ -1042,10 +1041,6 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' chmap : dict Map of channel names and their corresponding index. Default to constant. - start_times : numpy.array - An optional array of timestamps to separate trial events by. This is useful if after - syncing the clocks, some trial start TTLs are found to be missed. If None, uses - 'trial_start' Bpod event. display : bool, matplotlib.pyplot.Axes Show the full session sync pulses display. @@ -1068,50 +1063,58 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * 'Expected at least "trial_start", "trial_end", and "valve_open" audio events. ' '`bpod_event_ttls` kwarg may be incorrect.') - # The first trial pulse is longer and often assigned to another event. - # Here we move the earliest non-trial_start event to the trial_start array. - t0 = bpod_event_intervals['trial_start'][0, 0] # expect 1st event to be trial_start - pretrial = [(k, v[0, 0]) for k, v in bpod_event_intervals.items() if v.size and v[0, 0] < t0] - if pretrial: - (pretrial, _) = sorted(pretrial, key=lambda x: x[1])[0] # take the earliest event - dt = np.diff(bpod_event_intervals[pretrial][0, :]) * 1e3 # record TTL length to log - _logger.debug('Reassigning first %s to trial_start. TTL length = %.3g ms', pretrial, dt) - bpod_event_intervals['trial_start'] = np.r_[ - bpod_event_intervals[pretrial][0:1, :], bpod_event_intervals['trial_start'] - ] - bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :] - - t_trial_start = bpod_event_intervals['trial_start'][:, 0] t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T - # Some protocols, e.g. Guido's ephys biased opto task, have no trial end TTL. - # This is not essential as the trial start is used to sync the clocks. - if t_trial_end.size == 0: - _logger.warning('No trial end / ITI in TTLs found') - t_trial_end = np.full_like(t_trial_start, np.nan) - else: - # Drop last trial start if incomplete - t_trial_start = t_trial_start[:len(t_trial_end)] - t_valve_open = bpod_event_intervals['valve_open'][:, 0] - t_ready_tone_in = audio_event_intervals['ready_tone'][:, 0] - t_error_tone_in = audio_event_intervals['error_tone'][:, 0] - - start_times = start_times or t_trial_start - - trials = alfio.AlfBunch({ - 'goCue_times': _assign_events_to_trial(start_times, t_ready_tone_in, take='first'), - 'errorCue_times': _assign_events_to_trial(start_times, t_error_tone_in), - 'valveOpen_times': _assign_events_to_trial(start_times, t_valve_open), - 'stimFreeze_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take=-2), - 'stimOn_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take='first'), - 'stimOff_times': _assign_events_to_trial(start_times, self.frame2ttl['times']), - 'itiIn_times': _assign_events_to_trial(start_times, t_iti_in) + fpga_events = alfio.AlfBunch({ + 'goCue_times': audio_event_intervals['ready_tone'][:, 0], + 'errorCue_times': audio_event_intervals['error_tone'][:, 0], + 'valveOpen_times': bpod_event_intervals['valve_open'][:, 0], + 'valveClose_times': bpod_event_intervals['valve_open'][:, 1], + 'itiIn_times': t_iti_in, + 'intervals_0': bpod_event_intervals['trial_start'][:, 0], + 'intervals_1': t_trial_end }) - # feedback times are valve open on correct trials and error tone in on incorrect trials - trials['feedback_times'] = np.copy(trials['valveOpen_times']) - ind_err = np.isnan(trials['valveOpen_times']) - trials['feedback_times'][ind_err] = trials['errorCue_times'][ind_err] - trials['intervals'] = np.c_[start_times, t_trial_end] + # Sync the Bpod clock to the DAQ. + # NB: The Bpod extractor typically drops the final, incomplete, trial. Hence there is + # usually at least one extra FPGA event. This shouldn't affect the sync. The final trial is + # dropped after assigning the FPGA events, using the `ifpga` index. Doing this after + # assigning the FPGA trial events ensures the last trial has the correct timestamps. + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) + + if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0': + # One issue is that sometimes pulses may not have been detected, in this case + # add the events that have not been detected and re-extract the behaviour sync. + # This is only really relevant for the Bpod interval events as the other TTLs are + # from devices where a missing TTL likely means the Bpod event was truly absent. + _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times') + bpod_start = self.bpod_trials['intervals'][:, 0] + missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))]) + t_trial_start = np.sort(np.r_[fpga_events['intervals_0'][:, 0], missing_bpod]) + else: + t_trial_start = fpga_events['intervals_0'] + + # Assign the FPGA events to individual trials + fpga_trials = { + 'goCue_times': _assign_events_to_trial(t_trial_start, fpga_events['goCue_times'], take='first'), + 'errorCue_times': _assign_events_to_trial(t_trial_start, fpga_events['errorCue_times']), + 'valveOpen_times': _assign_events_to_trial(t_trial_start, fpga_events['valveOpen_times']), + 'itiIn_times': _assign_events_to_trial(t_trial_start, fpga_events['itiIn_times']), + 'stimFreeze_times': _assign_events_to_trial(t_trial_start, self.frame2ttl['times'], take=-2), + 'stimOn_times': _assign_events_to_trial(t_trial_start, self.frame2ttl['times'], take='first'), + 'stimOff_times': _assign_events_to_trial(t_trial_start, self.frame2ttl['times']) + } + + # Feedback times are valve open on correct trials and error tone in on incorrect trials + fpga_trials['feedback_times'] = np.copy(fpga_trials['valveOpen_times']) + ind_err = np.isnan(fpga_trials['valveOpen_times']) + fpga_trials['feedback_times'][ind_err] = fpga_trials['errorCue_times'][ind_err] + + out = alfio.AlfBunch() + # Add the Bpod trial events, converting the timestamp fields to FPGA time. + # NB: The trial intervals are by default a Bpod rsync field. + out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) + out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) + out.update({k: fpga_trials[k][ifpga] for k in fpga_trials.keys()}) if display: # pragma: no cover width = 0.5 @@ -1125,34 +1128,13 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') color_map = TABLEAU_COLORS.keys() - for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)): + for (event_name, event_times), c in zip(fpga_events.items(), cycle(color_map)): plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) ax.legend() ax.set_yticks([0, 1, 2, 3]) ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) ax.set_ylim([0, 5]) - return trials - - def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): - # Sync the Bpod clock to the DAQ - self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_trials, self.sync_field) - - if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0': - # One issue is that sometimes pulses may not have been detected, in this case - # add the events that have not been detected and re-extract the behaviour sync. - # This is only really relevant for the Bpod interval events as the other TTLs are - # from devices where a missing TTL likely means the Bpod event was truly absent. - _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times') - bpod_start = self.bpod_trials['intervals'][:, 0] - missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))]) - t_trial_start = np.sort(np.r_[fpga_trials['intervals'][:, 0], missing_bpod]) - fpga_trials = self.extract_behaviour_sync(sync, chmap, start_times=t_trial_start, **kwargs) - - out = dict() - out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) - out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) - out.update({k: fpga_trials[k][ifpga] for k in sorted(fpga_trials.keys())}) return out def get_wheel_positions(self, *args, **kwargs): @@ -1233,7 +1215,7 @@ def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, Gets the Bpod TTL times from the sync 'bpod' channel and classifies each TTL event by length. NB: The first trial has an abnormal trial_start TTL that is usually mis-assigned. - This is handled in the :meth:`FpgaTrials.extract_behaviour_sync` method. + This method accounts for this. Parameters ---------- @@ -1268,6 +1250,22 @@ def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, bpod_event_intervals = self._assign_events( bpod['times'], bpod['polarities'], bpod_event_ttls, display=display) + if 'trial_start' not in bpod_event_intervals or bpod_event_intervals['trial_start'].size == 0: + return bpod, bpod_event_intervals + + # The first trial pulse is longer and often assigned to another event. + # Here we move the earliest non-trial_start event to the trial_start array. + t0 = bpod_event_intervals['trial_start'][0, 0] # expect 1st event to be trial_start + pretrial = [(k, v[0, 0]) for k, v in bpod_event_intervals.items() if v.size and v[0, 0] < t0] + if pretrial: + (pretrial, _) = sorted(pretrial, key=lambda x: x[1])[0] # take the earliest event + dt = np.diff(bpod_event_intervals[pretrial][0, :]) * 1e3 # record TTL length to log + _logger.debug('Reassigning first %s to trial_start. TTL length = %.3g ms', pretrial, dt) + bpod_event_intervals['trial_start'] = np.r_[ + bpod_event_intervals[pretrial][0:1, :], bpod_event_intervals['trial_start'] + ] + bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :] + return bpod, bpod_event_intervals @staticmethod @@ -1364,8 +1362,8 @@ def sync_bpod_clock(bpod_trials, fpga_trials, sync_field): bpod_trials : dict A dictionary of extracted Bpod trial events. fpga_trials : dict - A dictionary of trial events extracted from FPGA sync events (see - `extract_behaviour_sync` method). + A dictionary of TTL events extracted from FPGA sync (see `extract_behaviour_sync` + method). sync_field : str The trials key to use for syncing clocks. For intervals (i.e. Nx2 arrays) append the column index, e.g. 'intervals_0'. @@ -1387,27 +1385,28 @@ def sync_bpod_clock(bpod_trials, fpga_trials, sync_field): The key `sync_field` was not found in either the `bpod_trials` or `fpga_trials` dicts. """ _logger.info(f'Attempting to align Bpod clock to DAQ using trial event "{sync_field}"') - if sync_field not in bpod_trials: - # handle syncing on intervals - if not (m := re.match(r'(.*)_(\d)', sync_field)): - raise ValueError(f'Sync field "{sync_field}" not in extracted bpod trials') - sync_field, i = m.groups() - timestamps_bpod = bpod_trials[sync_field][:, int(i)] - timestamps_fpga = fpga_trials[sync_field][:, int(i)] - elif sync_field not in fpga_trials: - raise ValueError(f'Sync field "{sync_field}" not in extracted fpga trials') - else: - timestamps_bpod = bpod_trials[sync_field] - timestamps_fpga = fpga_trials[sync_field] + bpod_fpga_timestamps = [None, None] + for i, trials in enumerate((bpod_trials, fpga_trials)): + if sync_field not in trials: + # handle syncing on intervals + if not (m := re.match(r'(.*)_(\d)', sync_field)): + # If missing from bpod trials, either the sync field is incorrect, + # or the Bpod extractor is incorrect. If missing from the fpga events, check + # the sync field and the `extract_behaviour_sync` method. + raise ValueError( + f'Sync field "{sync_field}" not in extracted {"fpga" if i else "bpod"} events') + _sync_field, n = m.groups() + bpod_fpga_timestamps[i] = trials[_sync_field][:, int(n)] + else: + bpod_fpga_timestamps[i] = trials[sync_field] # Sync the two timestamps - fcn, drift, ibpod, ifpga = neurodsp.utils.sync_timestamps( - timestamps_bpod, timestamps_fpga, return_indices=True) + fcn, drift, ibpod, ifpga = neurodsp.utils.sync_timestamps(*bpod_fpga_timestamps, return_indices=True) # If it's drifting too much throw warning or error _logger.info('N trials: %i bpod, %i FPGA, %i merged, sync %.5f ppm', - len(timestamps_bpod), len(timestamps_fpga), len(ibpod), drift) - if drift > 200 and timestamps_bpod.size != timestamps_fpga.size: + *map(len, bpod_fpga_timestamps), len(ibpod), drift) + if drift > 200 and bpod_fpga_timestamps[0].size != bpod_fpga_timestamps[1].size: raise err.SyncBpodFpgaException('sync cluster f*ck') elif drift > BPOD_FPGA_DRIFT_THRESHOLD_PPM: _logger.warning('BPOD/FPGA synchronization shows values greater than %.2f ppm', @@ -1481,24 +1480,65 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', A dictionary of numpy arrays with `FpgaTrialsHabituation.var_names` as keys. """ # Version check: the ITI in TTL was added in a later version + if not self.settings: + self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) iblrig_version = version.parse(self.settings.get('IBL_VERSION', '0.0.0')) if version.parse('8.9.3') <= iblrig_version < version.parse('8.12.6'): """A second 1s TTL was added in this version during the 'iti' state, however this is unrelated to the trial ITI and is unfortunately the same length as the trial start TTL.""" raise NotImplementedError('Ambiguous TTLs in 8.9.3 >= version < 8.12.6') - # Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse - if 'bpod_event_ttls' not in kwargs: - kwargs['bpod_event_ttls'] = {'trial_iti': (1, 1.1), 'valve_open': (0, 0.4)} trials = super()._extract(sync=sync, chmap=chmap, sync_collection=sync_collection, task_collection=task_collection, **kwargs) - n = trials['intervals'].shape[0] # number of trials - trials['intervals'][:, 1] = self.bpod2fpga(self.bpod_trials['intervals'][:n, 1]) - return trials - def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, **kwargs): + def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs): + """ + Extract Bpod times from sync. + + Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse. + Also the first trial pulse is incorrectly assigned due to its abnormal length. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. Must contain a 'bpod' key. + chmap : dict + A map of channel names and their corresponding indices. + bpod_event_ttls : dict of tuple + A map of event names to (min, max) TTL length. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts. + dict + A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array. + """ + bpod = get_sync_fronts(sync, chmap['bpod']) + if bpod.times.size == 0: + raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. ' + 'Check channel maps.') + # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these + # lengths are defined by the state machine of the task protocol and therefore vary. + if bpod_event_ttls is None: + # Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse + bpod_event_ttls = {'trial_iti': (1, 1.1), 'valve_open': (0, 0.4)} + bpod_event_intervals = self._assign_events( + bpod['times'], bpod['polarities'], bpod_event_ttls, display=display) + + # The first trial pulse is shorter and assigned to valve_open. Here we remove the first + # valve event, prepend a 0 to the trial_start events, and drop the last trial if it was + # incomplete in Bpod. + bpod_event_intervals['trial_iti'] = np.r_[bpod_event_intervals['valve_open'][0:1, :], + bpod_event_intervals['trial_iti']] + bpod_event_intervals['valve_open'] = bpod_event_intervals['valve_open'][1:, :] + + return bpod, bpod_event_intervals + + def build_trials(self, sync, chmap, display=False, **kwargs): """ Extract task related event times from the sync. @@ -1511,10 +1551,6 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' chmap : dict Map of channel names and their corresponding index. Default to constant. - start_times : numpy.array - An optional array of timestamps to separate trial events by. This is useful if after - syncing the clocks, some trial start TTLs are found to be missed. If None, uses - 'trial_start' Bpod event. display : bool, matplotlib.pyplot.Axes Show the full session sync pulses display. @@ -1532,40 +1568,47 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * raise ValueError( 'Expected at least "trial_iti" and "valve_open" Bpod events. `bpod_event_ttls` kwarg may be incorrect.') - # The first trial pulse is shorter and assigned to valve_open. Here we remove the first - # valve event, prepend a 0 to the trial_start events, and drop the last trial if it was - # incomplete in Bpod. + fpga_events = alfio.AlfBunch({ + 'feedback_times': bpod_event_intervals['valve_open'][:, 0], + 'valveClose_times': bpod_event_intervals['valve_open'][:, 1], + 'intervals_0': bpod_event_intervals['trial_iti'][:, 1], + 'intervals_1': bpod_event_intervals['trial_iti'][:, 0], + 'goCue_times': audio_event_intervals['ready_tone'][:, 0] + }) n_trials = self.bpod_trials['intervals'].shape[0] - t_valve_open = bpod_event_intervals['valve_open'][1:, 0] # drop first spurious valve event - t_ready_tone_in = audio_event_intervals['ready_tone'][:, 0] - t_trial_start = np.r_[0, bpod_event_intervals['trial_iti'][:, 1]] - t_trial_end = bpod_event_intervals['trial_iti'][:, 0] - start_times = start_times or t_trial_start + # Sync the Bpod clock to the DAQ. + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) + out = alfio.AlfBunch() + # Add the Bpod trial events, converting the timestamp fields to FPGA time. + # NB: The trial intervals are by default a Bpod rsync field. + out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) + out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) + + # Assigning each event to a trial ensures exactly one event per trial (missing events are NaN) + assign_to_trial = partial(_assign_events_to_trial, fpga_events['intervals_0']) trials = alfio.AlfBunch({ - 'goCue_times': _assign_events_to_trial(start_times, t_ready_tone_in, take='first')[:n_trials], - 'feedback_times': _assign_events_to_trial(start_times, t_valve_open)[:n_trials], - 'stimFreeze_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take=-2)[:n_trials], - 'stimOn_times': _assign_events_to_trial(start_times, self.frame2ttl['times'], take='first')[:n_trials], - 'stimOff_times': _assign_events_to_trial(start_times, self.frame2ttl['times'])[:n_trials], - # These 'raw' intervals will be used in the sync - 'intervals_1': _assign_events_to_trial(start_times, t_trial_end), - 'intervals_0': start_times + 'goCue_times': assign_to_trial(fpga_events['goCue_times'], take='first')[:n_trials], + 'feedback_times': assign_to_trial(fpga_events['feedback_times'])[:n_trials], + 'stimCenter_times': assign_to_trial(self.frame2ttl['times'], take=-2)[:n_trials], + 'stimOn_times': assign_to_trial(self.frame2ttl['times'], take='first')[:n_trials], + 'stimOff_times': assign_to_trial(self.frame2ttl['times'])[:n_trials], }) # If stim on occurs before trial end, use stim on time. Likewise for trial end and stim off - trials['intervals'] = np.c_[trials['intervals_0'], trials['intervals_1']][:n_trials, :] - to_correct = ~np.isnan(trials['stimOn_times']) & (trials['stimOn_times'] < trials['intervals'][:, 0]) + to_correct = ~np.isnan(trials['stimOn_times']) & (trials['stimOn_times'] < out['intervals'][:, 0]) if np.any(to_correct): _logger.warning('%i/%i stim on events occurring outside trial intervals', sum(to_correct), len(to_correct)) - trials['intervals'][to_correct, 0] = trials['stimOn_times'][to_correct] - to_correct = ~np.isnan(trials['stimOff_times']) & (trials['stimOff_times'] > trials['intervals'][:, 1]) + out['intervals'][to_correct, 0] = trials['stimOn_times'][to_correct] + to_correct = ~np.isnan(trials['stimOff_times']) & (trials['stimOff_times'] > out['intervals'][:, 1]) if np.any(to_correct): _logger.debug( '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end', sum(to_correct), len(to_correct)) - trials['intervals'][to_correct, 1] = trials['stimOff_times'][to_correct] + out['intervals'][to_correct, 1] = trials['stimOff_times'][to_correct] + + out.update({k: trials[k][ifpga] for k in trials.keys()}) if display: # pragma: no cover width = 0.5 @@ -1586,7 +1629,7 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) ax.set_ylim([0, 4]) - return trials + return out def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_path=None, @@ -1630,6 +1673,7 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ 'ibllib.io.extractors.ephys_fpga.extract_all will be removed in future versions; ' 'use FpgaTrials instead. For reliable extraction, use the dynamic pipeline behaviour tasks.', FutureWarning) + return_extractor = kwargs.pop('return_extractor', False) # Extract Bpod trials bpod_raw = raw.load_data(session_path, task_collection=task_collection) assert bpod_raw is not None, 'No task trials data in raw_behavior_data - Exit' @@ -1646,7 +1690,10 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ task_collection=task_collection, protocol_number=protocol_number, **kwargs) if not isinstance(outputs, dict): outputs = {k: v for k, v in zip(trials.var_names, outputs)} - return outputs, files + if return_extractor: + return outputs, files, trials + else: + return outputs, files def get_sync_and_chn_map(session_path, sync_collection): diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index 84a7622e7..e4ca6766b 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -1,6 +1,5 @@ """Mesoscope (timeline) data extraction.""" import logging -from itertools import cycle import numpy as np from scipy.signal import find_peaks @@ -8,7 +7,6 @@ from one.util import ensure_list from one.alf.files import session_path_parts import matplotlib.pyplot as plt -from matplotlib.colors import TABLEAU_COLORS from pkg_resources import parse_version from ibllib.plots.misc import squares, vertical_lines @@ -146,26 +144,51 @@ def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwa plot_timeline(self.timeline, channels=chmap.keys(), raw=True) return trials - def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, **kwargs): + def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs): """ - Extract task related event times from the sync. + Extract Bpod times from sync. + + Unlike the superclass method. This one doesn't reassign the first trial pulse. - TODO Change docstring - The trial start times are the shortest Bpod TTLs and occur at the start of the trial. The - first trial start TTL of the session is longer and must be handled differently. The trial - start TTL is used to assign the other trial events to each trial. + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. Must contain a 'bpod' key. + chmap : dict + A map of channel names and their corresponding indices. + bpod_event_ttls : dict of tuple + A map of event names to (min, max) TTL length. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts. + dict + A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array. + """ + # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these + # lengths are defined by the state machine of the task protocol and therefore vary. + if bpod_event_ttls is None: + # The trial start TTLs are often too short for the low sampling rate of the DAQ and are + # therefore not used in extraction + bpod_event_ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)} + bpod, bpod_event_intervals = super().get_bpod_event_times( + sync=sync, chmap=chmap, bpod_event_ttls=bpod_event_ttls, display=display, **kwargs) + + # TODO Here we can make use of the 'bpod_rising_edge' channel, if available + return bpod, bpod_event_intervals + + def build_trials(self, sync=None, chmap=None, **kwargs): + """ + Extract task related event times from the sync. - The trial end is the end of the so-called 'ITI' Bpod event TTL (classified as the longest - of the three Bpod event TTLs). Go cue audio TTLs are the shorter of the two expected audio - tones. The first of these after each trial start is taken to be the go cue time. Error - tones are longer audio TTLs and assigned as the last of such occurrence after each trial - start. The valve open Bpod TTLs are medium-length, the last of which is used for each trial. - The feedback times are times of either valve open or error tone as there should be only one - such event per trial. + The two major differences are that the sampling rate is lower for imaging so the short Bpod + trial start TTLs are often absent. For this reason, the sync happens using the ITI_in TTL. - The stimulus times are taken from the frame2ttl events (with improbably high frequency TTLs - removed): the first TTL after each trial start is assumed to be the stim onset time; the - second to last and last are taken as the stimulus freeze and offset times, respectively. + Second, the valve used at the mesoscope has a way to record the raw voltage across the + solenoid, giving a more accurate readout of the valve's activity. If the reward_valve + channel is present on the DAQ, this is used to extract the valve open times. Parameters ---------- @@ -173,12 +196,6 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' chmap : dict Map of channel names and their corresponding index. Default to constant. - start_times : numpy.array - An optional array of timestamps to separate trial events by. This is useful if after - syncing the clocks, some trial start TTLs are found to be missed. If None, uses - 'trial_start' Bpod event. - display : bool, matplotlib.pyplot.Axes - Show the full session sync pulses display. Returns ------- @@ -201,46 +218,36 @@ def extract_behaviour_sync(self, sync, chmap, start_times=None, display=False, * '`bpod_event_ttls` kwarg may be incorrect.') t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T - trials = alfio.AlfBunch({ + fpga_events = alfio.AlfBunch({ 'itiIn_times': t_iti_in, 'intervals_1': t_trial_end, - 'valveOpen_intervals': bpod_event_intervals['valve_open'], 'goCue_times': audio_event_intervals['ready_tone'][:, 0], 'errorTone_times': audio_event_intervals['error_tone'][:, 0] }) - if display: # pragma: no cover - width = 0.5 - ymax = 5 - if isinstance(display, bool): - plt.figure('Bpod FPGA Sync') - ax = plt.gca() - else: - ax = display - squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k') - squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') - squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') - color_map = TABLEAU_COLORS.keys() - for (event_name, event_times), c in zip(trials.items(), cycle(color_map)): - vertical_lines(event_times.flat, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) - ax.legend() - ax.set_yticks([0, 1, 2, 3]) - ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) - ax.set_ylim([0, 4]) - - return trials - - def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): # Sync the Bpod clock to the DAQ - self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_trials, self.sync_field) + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) out = dict() - out['intervals'] = self.bpod2fpga(self.bpod_trials['intervals']) - out['itiIn_times'] = fpga_trials['itiIn_times'][ifpga] + out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) + out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) + start_times = out['intervals'][:, 0] + last_trial_end = out['intervals'][-1, 1] + + def assign_to_trial(events, take='last'): + """Assign DAQ events to trials. + + Because we may not have trial start TTLs on the DAQ (because of the low sampling rate), + there may be an extra last trial that's not in the Bpod intervals as the extractor + ignores the last trial. This function trims the input array before assigning so that + the last trial's events are correctly assigned. + """ + return _assign_events_to_trial(start_times, events[events <= last_trial_end], take) + out['itiIn_times'] = assign_to_trial(fpga_events['itiIn_times'][ifpga]) # Extract valve open times from the DAQ - valve_driver_ttls = fpga_trials.pop('valveOpen_intervals') + valve_driver_ttls = bpod_event_intervals['valve_open'] correct = self.bpod_trials['feedbackType'] == 1 # If there is a reward_valve channel, the valve has if any(ch['name'] == 'reward_valve' for ch in self.timeline['meta']['inputs']): @@ -260,25 +267,25 @@ def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): 'Number of valve open times does not equal number of correct trials (%i != %i)', valve_open_times.size, np.sum(correct)) - out['valveOpen_times'] = _assign_events_to_trial(start_times, valve_open_times) + out['valveOpen_times'] = assign_to_trial(valve_open_times) else: # Use the valve controller TTLs recorded on the Bpod channel as the reward time - out['valveOpen_times'] = _assign_events_to_trial(start_times, valve_driver_ttls[:, 0]) + out['valveOpen_times'] = assign_to_trial(valve_driver_ttls[:, 0]) # Stimulus times extracted the same as usual - out['stimFreeze_times'] = _assign_events_to_trial(start_times, self.frame2ttl['times'], take=-2) - out['stimOn_times'] = _assign_events_to_trial(start_times, self.frame2ttl['times'], take='first') - out['stimOff_times'] = _assign_events_to_trial(start_times, self.frame2ttl['times']) + out['stimFreeze_times'] = assign_to_trial(self.frame2ttl['times'], take=-2) + out['stimOn_times'] = assign_to_trial(self.frame2ttl['times'], take='first') + out['stimOff_times'] = assign_to_trial(self.frame2ttl['times']) # Audio times - error_cue = fpga_trials['errorTone_times'] + error_cue = fpga_events['errorTone_times'] if error_cue.size != np.sum(~correct): _logger.warning( 'N detected error tones does not match number of incorrect trials (%i != %i)', error_cue.size, np.sum(~correct)) - go_cue = fpga_trials['goCue_times'] - out['goCue_times'] = _assign_events_to_trial(start_times, go_cue, take='first') - out['errorCue_times'] = _assign_events_to_trial(start_times, error_cue) + go_cue = fpga_events['goCue_times'] + out['goCue_times'] = assign_to_trial(go_cue, take='first') + out['errorCue_times'] = assign_to_trial(error_cue) if go_cue.size > start_times.size: _logger.warning( @@ -296,8 +303,8 @@ def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): assert err_trig.size == go_trig.size # should be length of n trials with NaNs # Find which trials are missing a go cue - _go_cue = _assign_events_to_trial(start_times, go_cue, take='first') - error_cue = _assign_events_to_trial(start_times, error_cue) + _go_cue = assign_to_trial(go_cue, take='first') + error_cue = assign_to_trial(error_cue) missing = np.isnan(_go_cue) # Get all the DAQ timestamps where audio channel was HIGH @@ -334,9 +341,6 @@ def build_trials(self, fpga_trials, sync=None, chmap=None, **kwargs): ind_err = np.isnan(out['valveOpen_times']) out['feedback_times'][ind_err] = out['errorCue_times'][ind_err] - out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) - out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) - return out def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None): diff --git a/ibllib/pipes/ephys_preprocessing.py b/ibllib/pipes/ephys_preprocessing.py index 26cef7050..09591ce2d 100644 --- a/ibllib/pipes/ephys_preprocessing.py +++ b/ibllib/pipes/ephys_preprocessing.py @@ -694,7 +694,8 @@ def _behaviour_criterion(self): ) def _extract_behaviour(self): - dsets, out_files = ephys_fpga.extract_all(self.session_path, save=True) + dsets, out_files, self.extractor = ephys_fpga.extract_all( + self.session_path, save=True, return_extractor=True) return dsets, out_files @@ -709,8 +710,12 @@ def _run(self, plot_qc=True): qc = TaskQC(self.session_path, one=self.one, log=_logger) qc.extractor = TaskQCExtractor(self.session_path, lazy=True, one=qc.one) # Extract extra datasets required for QC - qc.extractor.data = dsets - qc.extractor.extract_data() + qc.extractor.data = qc.extractor.rename_data(dsets) + wheel_ts_bpod = self.extractor.bpod2fpga(self.extractor.bpod_trials['wheel_timestamps']) + qc.extractor.data['wheel_timestamps_bpod'] = wheel_ts_bpod + qc.extractor.data['wheel_position_bpod'] = self.extractor.bpod_trials['wheel_position'] + qc.extractor.wheel_encoding = 'X4' + # Aggregate and update Alyx QC fields qc.run(update=True) From bac237450b6b1011c336af7d564f73311005e45d Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 8 Dec 2023 18:48:12 +0200 Subject: [PATCH 6/7] Include wheel in Bpod trials dict passed to FpgaTrials --- ibllib/io/extractors/ephys_fpga.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index 3c805293f..187d216f6 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -1677,14 +1677,14 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ # Extract Bpod trials bpod_raw = raw.load_data(session_path, task_collection=task_collection) assert bpod_raw is not None, 'No task trials data in raw_behavior_data - Exit' - bpod_trials, *_ = bpod_extract_all( + bpod_trials, bpod_wheel, *_ = bpod_extract_all( session_path=session_path, bpod_trials=bpod_raw, task_collection=task_collection, save=False, extractor_type=kwargs.get('extractor_type')) # Sync Bpod trials to FPGA sync, chmap = get_sync_and_chn_map(session_path, sync_collection) # sync, chmap = get_main_probe_sync(session_path, bin_exists=bin_exists) - trials = FpgaTrials(session_path, bpod_trials=bpod_trials) + trials = FpgaTrials(session_path, bpod_trials=bpod_trials | bpod_wheel) outputs, files = trials.extract( save=save, sync=sync, chmap=chmap, path_out=save_path, task_collection=task_collection, protocol_number=protocol_number, **kwargs) From ba2553dbb757d5049ec39da5f9f6b3d88baa6ab3 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 8 Dec 2023 19:18:05 +0200 Subject: [PATCH 7/7] Add more fields to qc extractor --- ibllib/pipes/ephys_preprocessing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ibllib/pipes/ephys_preprocessing.py b/ibllib/pipes/ephys_preprocessing.py index 09591ce2d..7ea845d18 100644 --- a/ibllib/pipes/ephys_preprocessing.py +++ b/ibllib/pipes/ephys_preprocessing.py @@ -715,6 +715,10 @@ def _run(self, plot_qc=True): qc.extractor.data['wheel_timestamps_bpod'] = wheel_ts_bpod qc.extractor.data['wheel_position_bpod'] = self.extractor.bpod_trials['wheel_position'] qc.extractor.wheel_encoding = 'X4' + qc.extractor.settings = self.extractor.settings + qc.extractor.frame_ttls = self.extractor.frame2ttl + qc.extractor.audio_ttls = self.extractor.audio + qc.extractor.bpod_ttls = self.extractor.bpod # Aggregate and update Alyx QC fields qc.run(update=True)