Skip to content

Commit

Permalink
Merge pull request #673 from int-brain-lab/HabituationNidq
Browse files Browse the repository at this point in the history
Habituation nidq
  • Loading branch information
k1o0 authored Dec 8, 2023
2 parents dd7b087 + ba2553d commit e0cac72
Show file tree
Hide file tree
Showing 14 changed files with 1,512 additions and 434 deletions.
9 changes: 9 additions & 0 deletions ibllib/io/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,16 @@ class BaseExtractor(abc.ABC):
"""

session_path = None
"""pathlib.Path: Absolute path of session folder."""

save_names = None
"""tuple of str: The filenames of each extracted dataset, or None if array should not be saved."""

var_names = None
"""tuple of str: A list of names for the extracted variables. These become the returned output keys."""

default_path = Path('alf') # relative to session
"""pathlib.Path: The default output folder relative to `session_path`."""

def __init__(self, session_path=None):
# If session_path is None Path(session_path) will fail
Expand Down Expand Up @@ -127,6 +134,8 @@ class BaseBpodTrialsExtractor(BaseExtractor):
bpod_trials = None
settings = None
task_collection = None
frame2ttl = None
audio = None

def extract(self, bpod_trials=None, settings=None, **kwargs):
"""
Expand Down
2 changes: 2 additions & 0 deletions ibllib/io/extractors/biased_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ class EphysTrials(BaseBpodTrialsExtractor):
def _extract(self, extractor_classes=None, **kwargs) -> dict:
base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials, PhasePosQuiescence]
# Get all detected TTLs. These are stored for QC purposes
self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
# Exclude from trials table
out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
save=False, task_collection=self.task_collection)
Expand Down
6 changes: 5 additions & 1 deletion ibllib/io/extractors/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,12 +513,16 @@ def attribute_times(arr, events, tol=.1, injective=True, take='first'):
Returns
-------
numpy.array
An array the same length as `events`.
An array the same length as `events` containing indices of `arr` corresponding to each
event.
"""
if (take := take.lower()) not in ('first', 'nearest', 'after'):
raise ValueError('Parameter `take` must be either "first", "nearest", or "after"')
stack = np.ma.masked_invalid(arr, copy=False)
stack.fill_value = np.inf
# If there are no invalid values, the mask is False so let's ensure it's a bool array
if stack.mask is np.bool_(0):
stack.mask = np.zeros(arr.shape, dtype=bool)
assigned = np.full(events.shape, -1, dtype=int) # Initialize output array
min_tol = 0 if take == 'after' else -tol
for i, x in enumerate(events):
Expand Down
938 changes: 853 additions & 85 deletions ibllib/io/extractors/ephys_fpga.py

Large diffs are not rendered by default.

94 changes: 64 additions & 30 deletions ibllib/io/extractors/habituation_trials.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Habituation ChoiceWorld Bpod trials extraction."""
import logging
import numpy as np

import ibllib.io.raw_data_loaders as raw
from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
from ibllib.io.extractors.biased_trials import ContrastLR
from ibllib.io.extractors.training_trials import (
FeedbackTimes, StimOnTriggerTimes, Intervals, GoCueTimes
)
from ibllib.io.extractors.training_trials import FeedbackTimes, StimOnTriggerTimes, GoCueTimes

_logger = logging.getLogger(__name__)

Expand All @@ -24,9 +23,24 @@ def __init__(self, *args, **kwargs):
self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names)

def _extract(self) -> dict:
"""
Extract the Bpod trial events.
The Bpod state machine for this task has extremely misleading names! The 'iti' state is
actually the delay between valve open and trial end (the stimulus is still present during
this period), and the 'trial_start' state is actually the ITI during which there is a 1s
Bpod TTL and gray screen period.
Returns
-------
dict
A dictionary of Bpod trial events. The keys are defined in the `var_names` attribute.
"""
# Extract all trials...

# Get all stim_sync events detected
# Get all detected TTLs. These are stored for QC purposes
self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
# These are the frame2TTL pulses as a list of lists, one per trial
ttls = [raw.get_port_events(tr, 'BNC1') for tr in self.bpod_trials]

# Report missing events
Expand All @@ -38,10 +52,49 @@ def _extract(self) -> dict:
_logger.warning(f'{self.session_path}: Missing BNC1 TTLs on {n_missing} trial(s)')

# Extract datasets common to trainingChoiceWorld
training = [ContrastLR, FeedbackTimes, Intervals, GoCueTimes, StimOnTriggerTimes]
training = [ContrastLR, FeedbackTimes, GoCueTimes, StimOnTriggerTimes]
out, _ = run_extractor_classes(training, session_path=self.session_path, save=False,
bpod_trials=self.bpod_trials, settings=self.settings, task_collection=self.task_collection)

"""
The 'trial_start' state is in fact the 1s grey screen period, therefore the first timestamp
is really the end of the previous trial and also the stimOff trigger time. The second
timestamp is the true trial start time.
"""
(_, *ends), starts = zip(*[
t['behavior_data']['States timestamps']['trial_start'][-1] for t in self.bpod_trials]
)

# StimOffTrigger times
out['stimOffTrigger_times'] = np.array(ends)

# StimOff times
"""
There should be exactly three TTLs per trial. stimOff_times should be the first TTL pulse.
If 1 or more pulses are missing, we can not be confident of assigning the correct one.
"""
out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan for sync in ttls[1:]])

# Trial intervals
"""
In terms of TTLs, the intervals are defined by the 'trial_start' state, however the stim
off time often happens after the trial end TTL front, i.e. after the 'trial_start' start
begins. For these trials, we set the trial end time as the stim off time.
"""
# NB: We lose the last trial because the stim off event occurs at trial_num + 1
n_trials = out['stimOff_times'].size
out['intervals'] = np.c_[starts, np.r_[ends, np.nan]][:n_trials, :]

to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1])
if np.any(to_correct):
_logger.debug(
'%i/%i stim off events occurring outside trial intervals; using stim off times as trial end',
sum(to_correct), len(to_correct))
out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct]

# itiIn times
out['itiIn_times'] = np.r_[ends, np.nan]

# GoCueTriggerTimes is the same event as StimOnTriggerTimes
out['goCueTrigger_times'] = out['stimOnTrigger_times'].copy()

Expand Down Expand Up @@ -75,38 +128,19 @@ def _extract(self) -> dict:
trial_volume = [x['reward_amount'] for x in self.bpod_trials]
out['rewardVolume'] = np.array(trial_volume).astype(np.float64)

# StimOffTrigger times
# StimOff occurs at trial start (ignore the first trial's state update)
out['stimOffTrigger_times'] = np.array(
[tr["behavior_data"]["States timestamps"]
["trial_start"][0][0] for tr in self.bpod_trials[1:]]
)

# StimOff times
"""
There should be exactly three TTLs per trial. stimOff_times should be the first TTL pulse.
If 1 or more pulses are missing, we can not be confident of assigning the correct one.
"""
trigg = out['stimOffTrigger_times']
out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan
for sync, off in zip(ttls[1:], trigg)])

# FeedbackType is always positive
out['feedbackType'] = np.ones(len(out['feedback_times']), dtype=np.int8)

# ItiIn times
out['itiIn_times'] = np.array(
[tr["behavior_data"]["States timestamps"]
["iti"][0][0] for tr in self.bpod_trials]
)

# Phase and position
out['position'] = np.array([t['position'] for t in self.bpod_trials])
out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials])

# NB: We lose the last trial because the stim off event occurs at trial_num + 1
n_trials = out['stimOff_times'].size
# return [out[k][:n_trials] for k in self.var_names]
# Double-check that the early and late trial events occur within the trial intervals
idx = ~np.isnan(out['stimOn_times'][:n_trials])
assert not np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]), \
'Stim on events occurring outside trial intervals'

# Truncate arrays and return in correct order
return {k: out[k][:n_trials] for k in self.var_names}


Expand Down
Loading

0 comments on commit e0cac72

Please sign in to comment.