Skip to content

Commit

Permalink
training status - only extract bpod data, don't run qc and don't sync
Browse files Browse the repository at this point in the history
  • Loading branch information
mayofaulkner committed Nov 30, 2023
1 parent 822f183 commit def1353
Showing 1 changed file with 42 additions and 63 deletions.
105 changes: 42 additions & 63 deletions ibllib/pipes/training_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from ibllib.io.raw_data_loaders import load_bpod
from ibllib.oneibl.registration import _get_session_times
from ibllib.io.extractors.base import get_pipeline, get_session_extractor_type
from ibllib.io.extractors.base import get_session_extractor_type
from ibllib.io.session_params import read_params
import ibllib.pipes.dynamic_pipeline as dyn
from ibllib.io.extractors.bpod_trials import get_bpod_extractor

from iblutil.util import setup_logger
from ibllib.plots.snapshot import ReportSnapshot
Expand All @@ -22,6 +22,7 @@
import seaborn as sns
import boto3
from botocore.exceptions import ProfileNotFound, ClientError
from itertools import chain

logger = setup_logger(__name__)

Expand Down Expand Up @@ -87,43 +88,6 @@ def upload_training_table_to_aws(lab, subject):
return


def get_trials_task(session_path, one):
# If experiment description file then process this
experiment_description_file = read_params(session_path)
if experiment_description_file is not None:
tasks = []
pipeline = dyn.make_pipeline(session_path)
trials_tasks = [t for t in pipeline.tasks if 'Trials' in t]
for task in trials_tasks:
t = pipeline.tasks.get(task)
t.__init__(session_path, **t.kwargs)
tasks.append(t)
else:
# Otherwise default to old way of doing things
pipeline = get_pipeline(session_path)
if pipeline == 'training':
from ibllib.pipes.training_preprocessing import TrainingTrials
tasks = [TrainingTrials(session_path)]
elif pipeline == 'ephys':
from ibllib.pipes.ephys_preprocessing import EphysTrials
tasks = [EphysTrials(session_path)]
else:
try:
# try and look if there is a custom extractor in the personal projects extraction class
import projects.base
task_type = get_session_extractor_type(session_path)
PipelineClass = projects.base.get_pipeline(task_type)
pipeline = PipelineClass(session_path, one)
trials_task_name = next(task for task in pipeline.tasks if 'Trials' in task)
task = pipeline.tasks.get(trials_task_name)
task.__init__(session_path)
tasks = [task]
except Exception:
tasks = []

return tasks


def save_path(subj_path):
return Path(subj_path).joinpath('training.csv')

Expand Down Expand Up @@ -155,7 +119,7 @@ def load_existing_dataframe(subj_path):
def load_trials(sess_path, one, collections=None, force=True, mode='raise'):
"""
Load trials data for session. First attempts to load from local session path, if this fails will attempt to download via ONE,
if this also fails, will then attempt to re-extraxt locally
if this also fails, will then attempt to re-extract locally
:param sess_path: session path
:param one: ONE instance
:param force: when True and if the session trials can't be found, will attempt to re-extract from the disk
Expand Down Expand Up @@ -207,19 +171,24 @@ def load_trials(sess_path, one, collections=None, force=True, mode='raise'):
if 'probabilityLeft' not in trials.keys():
raise ALFObjectNotFound
except Exception:
# Finally try to rextract the trials data locally
# Finally try to re-extract the trials data locally
try:
# Get the tasks that need to be run
tasks = get_trials_task(sess_path, one)
if len(tasks) > 0:
for task in tasks:
status = task.run()
if status == 0:
return load_trials(sess_path, collections=collections, one=one, force=False)
else:
return
raw_collections, _ = get_data_collection(sess_path)

if len(raw_collections) == 0:
return None

trials_dict = {}
for i, collection in enumerate(raw_collections):
extractor = get_bpod_extractor(sess_path, task_collection=collection)
trials_data, _ = extractor.extract(task_collection=collection, save=False)
trials_dict[i] = alfio.AlfBunch.from_df(trials_data['table'])

if len(trials_dict) > 1:
trials = training.concatenate_trials(trials_dict)
else:
trials = None
trials = trials_dict[0]

except Exception as e:
if mode == 'raise':
raise Exception(f'Exhausted all possibilities for loading trials for {sess_path}') from e
Expand Down Expand Up @@ -468,20 +437,29 @@ def get_data_collection(session_path):
:param session_path: path of session
:return:
"""
experiment_description_file = read_params(session_path)
if experiment_description_file is not None:
pipeline = dyn.make_pipeline(session_path)
trials_tasks = [t for t in pipeline.tasks if 'Trials' in t]
collections = [pipeline.tasks.get(task).kwargs['collection'] for task in trials_tasks]
if len(collections) == 1 and collections[0] == 'raw_behavior_data':
alf_collections = ['alf']
elif all(['raw_task_data' in c for c in collections]):
alf_collections = [f'alf/task_{c[-2:]}' for c in collections]
else:
alf_collections = None
experiment_description = read_params(session_path)
collections = []
if experiment_description is not None:
task_protocols = experiment_description.get('tasks', [])
for i, (protocol, task_info) in enumerate(chain(*map(dict.items, task_protocols))):
if 'passiveChoiceWorld' in protocol:
continue
collection = task_info.get('collection', f'raw_task_data_{i:02}')
if collection == 'raw_passive_data':
continue
collections.append(collection)
else:
collections = ['raw_behavior_data']
settings = Path(session_path).rglob('_iblrig_taskSettings.raw.json')
for setting in settings:
if setting.parent.name != 'raw_passive_data':
collections.append(setting.parent.name)

if len(collections) == 1 and collections[0] == 'raw_behavior_data':
alf_collections = ['alf']
elif all(['raw_task_data' in c for c in collections]):
alf_collections = [f'alf/task_{c[-2:]}' for c in collections]
else:
alf_collections = None

return collections, alf_collections

Expand Down Expand Up @@ -561,6 +539,7 @@ def get_training_info_for_session(session_paths, one, force=True):

un_protocols = np.unique(protocols)
# Example, training, training, biased - training would be combined, biased not
sess_dict = None
if len(un_protocols) != 1:
print(f'Different protocols in same session {session_path} : {protocols}')
for prot in un_protocols:
Expand Down

0 comments on commit def1353

Please sign in to comment.