Skip to content

Commit

Permalink
FpgaTrials refactor and FpgaHabituationTrials subclass
Browse files Browse the repository at this point in the history
  • Loading branch information
k1o0 committed Nov 23, 2023
1 parent 036d6ea commit 0159729
Show file tree
Hide file tree
Showing 10 changed files with 1,092 additions and 302 deletions.
9 changes: 9 additions & 0 deletions ibllib/io/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,16 @@ class BaseExtractor(abc.ABC):
"""

session_path = None
"""pathlib.Path: Absolute path of session folder."""

save_names = None
"""tuple of str: The filenames of each extracted dataset, or None if array should not be saved."""

var_names = None
"""tuple of str: A list of names for the extracted variables. These become the returned output keys."""

default_path = Path('alf') # relative to session
"""pathlib.Path: The default output folder relative to `session_path`."""

def __init__(self, session_path=None):
# If session_path is None Path(session_path) will fail
Expand Down Expand Up @@ -127,6 +134,8 @@ class BaseBpodTrialsExtractor(BaseExtractor):
bpod_trials = None
settings = None
task_collection = None
frame2ttl = None
audio = None

def extract(self, bpod_trials=None, settings=None, **kwargs):
"""
Expand Down
2 changes: 2 additions & 0 deletions ibllib/io/extractors/biased_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ class EphysTrials(BaseBpodTrialsExtractor):
def _extract(self, extractor_classes=None, **kwargs) -> dict:
base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials, PhasePosQuiescence]
# Get all detected TTLs. These are stored for QC purposes
self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
# Exclude from trials table
out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
save=False, task_collection=self.task_collection)
Expand Down
866 changes: 787 additions & 79 deletions ibllib/io/extractors/ephys_fpga.py

Large diffs are not rendered by default.

19 changes: 10 additions & 9 deletions ibllib/io/extractors/habituation_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ def _extract(self) -> dict:
There should be exactly three TTLs per trial. stimOff_times should be the first TTL pulse.
If 1 or more pulses are missing, we can not be confident of assigning the correct one.
"""
out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan
for sync, off in zip(ttls[1:], ends)])
out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan for sync in ttls[1:]])

# Trial intervals
"""
Expand All @@ -85,8 +84,13 @@ def _extract(self) -> dict:
# NB: We lose the last trial because the stim off event occurs at trial_num + 1
n_trials = out['stimOff_times'].size
out['intervals'] = np.c_[starts, np.r_[ends, np.nan]][:n_trials, :]
to_update = out['intervals'][:, 1] < out['stimOff_times']
out['intervals'][to_update, 1] = out['stimOff_times'][to_update]

to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1])
if np.any(to_correct):
_logger.debug(
'%i/%i stim off events occurring outside trial intervals; using stim off times as trial end',
sum(to_correct), len(to_correct))
out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct]

# itiIn times
out['itiIn_times'] = np.r_[ends, np.nan]
Expand Down Expand Up @@ -133,11 +137,8 @@ def _extract(self) -> dict:

# Double-check that the early and late trial events occur within the trial intervals
idx = ~np.isnan(out['stimOn_times'][:n_trials])
if np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]):
_logger.warning('Stim on events occurring outside trial intervals')
idx = ~np.isnan(out['stimOff_times'])
if np.any(out['stimOff_times'][idx] > out['intervals'][idx, 1]):
_logger.warning('Stim off events occurring outside trial intervals')
assert not np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]), \
'Stim on events occurring outside trial intervals'

# Truncate arrays and return in correct order
return {k: out[k][:n_trials] for k in self.var_names}
Expand Down
1 change: 0 additions & 1 deletion ibllib/io/session_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,6 @@ def iter_dict(d):
for d in filter(lambda x: isinstance(x, dict), v):
iter_dict(d)
elif isinstance(v, dict) and 'collection' in v:
print(k)
# if the key already exists, append the collection name to the list
if k in collection_map:
clist = collection_map[k] if isinstance(collection_map[k], list) else [collection_map[k]]
Expand Down
64 changes: 56 additions & 8 deletions ibllib/pipes/behavior_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ibllib.qc.task_metrics import HabituationQC, TaskQC
from ibllib.io.extractors.ephys_passive import PassiveChoiceWorld
from ibllib.io.extractors.bpod_trials import get_bpod_extractor
from ibllib.io.extractors.ephys_fpga import FpgaTrials, get_sync_and_chn_map
from ibllib.io.extractors.ephys_fpga import FpgaTrials, FpgaTrialsHabituation, get_sync_and_chn_map
from ibllib.io.extractors.mesoscope import TimelineTrials
from ibllib.pipes import training_status
from ibllib.plots.figures import BehaviourPlots
Expand Down Expand Up @@ -102,14 +102,61 @@ def _run_qc(self, trials_data=None, update=True):
qc.extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection,
one=self.one, sync_type=self.sync, task_collection=self.collection)

# Currently only the data field is accessed
# Update extractor fields
qc.extractor.data = qc.extractor.rename_data(trials_data.copy())
qc.extractor.frame_ttls = self.extractor.frame2ttl # used in iblapps QC viewer
qc.extractor.audio_ttls = self.extractor.audio # used in iblapps QC viewer
qc.extractor.settings = self.extractor.settings

namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}'
qc.run(update=update, namespace=namespace)
return qc


class HabituationTrialsNidq(HabituationTrialsBpod):
priority = 90
job_size = 'small'

@property
def signature(self):
signature = super().signature
signature['input_files'] = [
('_iblrig_taskData.raw.*', self.collection, True),
('_iblrig_taskSettings.raw.*', self.collection, True),
(f'_{self.sync_namespace}_sync.channels.npy', self.sync_collection, True),
(f'_{self.sync_namespace}_sync.polarities.npy', self.sync_collection, True),
(f'_{self.sync_namespace}_sync.times.npy', self.sync_collection, True),
('*wiring.json', self.sync_collection, False),
('*.meta', self.sync_collection, True)]
return signature

def _extract_behaviour(self, save=True, **kwargs):
"""Extract the habituationChoiceWorld trial data using NI DAQ clock."""
# Extract Bpod trials
bpod_trials, _ = super()._extract_behaviour(save=False, **kwargs)

# Sync Bpod trials to FPGA
sync, chmap = get_sync_and_chn_map(self.session_path, self.sync_collection)
self.extractor = FpgaTrialsHabituation(
self.session_path, bpod_trials=bpod_trials, bpod_extractor=self.extractor)

# NB: The stimOff times are called stimCenter times for habituation choice world
outputs, files = self.extractor.extract(
save=save, sync=sync, chmap=chmap, path_out=self.session_path.joinpath(self.output_collection),
task_collection=self.collection, protocol_number=self.protocol_number, **kwargs)
return outputs, files

def _run_qc(self, trials_data=None, update=True, **_):
"""Run and update QC.
This adds the bpod TTLs to the QC object *after* the QC is run in the super call method.
The raw Bpod TTLs are not used by the QC however they are used in the iblapps QC plot.
"""
qc = super()._run_qc(trials_data=trials_data, update=update)
qc.extractor.bpod_ttls = self.extractor.bpod
return qc


class TrialRegisterRaw(base_tasks.RegisterRawDataTask, base_tasks.BehaviourTask):
priority = 100
job_size = 'small'
Expand Down Expand Up @@ -286,9 +333,9 @@ def _run_qc(self, trials_data=None, update=True):
else:
qc = TaskQC(self.session_path, one=self.one, log=_logger)
qc_extractor.wheel_encoding = 'X1'
qc_extractor.settings = self.extractor.settings
qc_extractor.frame_ttls, qc_extractor.audio_ttls = load_bpod_fronts(
self.session_path, task_collection=self.collection)
qc_extractor.settings = self.extractor.settings
qc_extractor.frame_ttls, qc_extractor.audio_ttls = load_bpod_fronts(
self.session_path, task_collection=self.collection)
qc.extractor = qc_extractor

# Aggregate and update Alyx QC fields
Expand Down Expand Up @@ -370,14 +417,15 @@ def _run_qc(self, trials_data=None, update=False, plot_qc=False):
qc = HabituationQC(self.session_path, one=self.one, log=_logger)
else:
qc = TaskQC(self.session_path, one=self.one, log=_logger)
qc_extractor.settings = self.extractor.settings
# Add Bpod wheel data
wheel_ts_bpod = self.extractor.bpod2fpga(self.extractor.bpod_trials['wheel_timestamps'])
qc_extractor.data['wheel_timestamps_bpod'] = wheel_ts_bpod
qc_extractor.data['wheel_position_bpod'] = self.extractor.bpod_trials['wheel_position']
qc_extractor.wheel_encoding = 'X4'
qc_extractor.frame_ttls = self.extractor.frame2ttl
qc_extractor.audio_ttls = self.extractor.audio
qc_extractor.frame_ttls = self.extractor.frame2ttl
qc_extractor.audio_ttls = self.extractor.audio
qc_extractor.bpod_ttls = self.extractor.bpod
qc_extractor.settings = self.extractor.settings
qc.extractor = qc_extractor

# Aggregate and update Alyx QC fields
Expand Down
28 changes: 17 additions & 11 deletions ibllib/pipes/dynamic_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,22 +230,28 @@ def make_pipeline(session_path, **pkwargs):
# - choice_world_biased
# - choice_world_training
# - choice_world_habituation
if 'habituation' in protocol:
registration_class = btasks.HabituationRegisterRaw
behaviour_class = btasks.HabituationTrialsBpod
compute_status = False
elif 'passiveChoiceWorld' in protocol:
if 'passiveChoiceWorld' in protocol:
registration_class = btasks.PassiveRegisterRaw
behaviour_class = btasks.PassiveTask
compute_status = False
elif sync_kwargs['sync'] == 'bpod':
registration_class = btasks.TrialRegisterRaw
behaviour_class = btasks.ChoiceWorldTrialsBpod
compute_status = True
if 'habituation' in protocol:
registration_class = btasks.HabituationRegisterRaw
behaviour_class = btasks.HabituationTrialsBpod
compute_status = False
else:
registration_class = btasks.TrialRegisterRaw
behaviour_class = btasks.ChoiceWorldTrialsBpod
compute_status = True
elif sync_kwargs['sync'] == 'nidq':
registration_class = btasks.TrialRegisterRaw
behaviour_class = btasks.ChoiceWorldTrialsNidq
compute_status = True
if 'habituation' in protocol:
registration_class = btasks.HabituationRegisterRaw
behaviour_class = btasks.HabituationTrialsNidq
compute_status = False
else:
registration_class = btasks.TrialRegisterRaw
behaviour_class = btasks.ChoiceWorldTrialsNidq
compute_status = True
else:
raise NotImplementedError
tasks[f'RegisterRaw_{protocol}_{i:02}'] = type(f'RegisterRaw_{protocol}_{i:02}', (registration_class,), {})(
Expand Down
29 changes: 19 additions & 10 deletions ibllib/qc/task_metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Behaviour QC
"""Behaviour QC.
This module runs a list of quality control metrics on the behaviour data.
Examples
Expand Down Expand Up @@ -179,20 +180,22 @@ def run(self, update=False, namespace='task', **kwargs):
return outcome, results

@staticmethod
def compute_session_status_from_dict(results):
def compute_session_status_from_dict(results, criteria=None):
"""
Given a dictionary of results, computes the overall session QC for each key and aggregates
in a single value
:param results: a dictionary of qc keys containing (usually scalar) values
:param results: a dictionary of qc keys containing (usually scalar) values.
:param criteria: a dictionary of qc keys containing map of PASS, WARNING, FAIL thresholds.
:return: Overall session QC outcome as a string
:return: A dict of QC tests and their outcomes
"""
indices = np.zeros(len(results), dtype=int)
criteria = criteria or TaskQC.criteria
for i, k in enumerate(results):
if k in TaskQC.criteria.keys():
indices[i] = TaskQC._thresholding(results[k], thresholds=TaskQC.criteria[k])
if k in criteria.keys():
indices[i] = TaskQC._thresholding(results[k], thresholds=criteria[k])
else:
indices[i] = TaskQC._thresholding(results[k], thresholds=TaskQC.criteria['default'])
indices[i] = TaskQC._thresholding(results[k], thresholds=criteria['default'])

def key_map(x):
return 'NOT_SET' if x < 0 else list(TaskQC.criteria['default'].keys())[x]
Expand All @@ -213,22 +216,27 @@ def compute_session_status(self):
# Get mean passed of each check, or None if passed is None or all NaN
results = {k: None if v is None or np.isnan(v).all() else np.nanmean(v)
for k, v in self.passed.items()}
session_outcome, outcomes = self.compute_session_status_from_dict(results)
session_outcome, outcomes = self.compute_session_status_from_dict(results, self.criteria)
return session_outcome, results, outcomes


class HabituationQC(TaskQC):

def compute(self, download_data=None):
"""Compute and store the QC metrics
criteria = dict()
criteria['default'] = {'PASS': 0.99, 'WARNING': 0.90, 'FAIL': 0} # Note: WARNING was 0.95 prior to Aug 2022
criteria['_task_phase_distribution'] = {'PASS': 0.99, 'NOT_SET': 0} # This rarely passes due to low trial num

def compute(self, download_data=None, **kwargs):
"""Compute and store the QC metrics.
Runs the QC on the session and stores a map of the metrics for each datapoint for each
test, and a map of which datapoints passed for each test
:return:
"""
if self.extractor is None:
# If download_data is None, decide based on whether eid or session path was provided
ensure_data = self.download_data if download_data is None else download_data
self.load_data(download_data=ensure_data)
self.load_data(download_data=ensure_data, **kwargs)
self.log.info(f'Session {self.session_path}: Running QC on habituation data...')

# Initialize checks
Expand Down Expand Up @@ -302,6 +310,7 @@ def compute(self, download_data=None):
passed[check] = (metric <= 2 * np.pi) & (metric >= 0)
metrics[check] = metric

# This is not very useful as a check because there are so few trials
check = prefix + 'phase_distribution'
metric, _ = np.histogram(data['phase'])
_, p = chisquare(metric)
Expand Down
Loading

0 comments on commit 0159729

Please sign in to comment.