From 1186fbc4093a78648c97729e62c83ec7ebdd2a5d Mon Sep 17 00:00:00 2001 From: weiglszonja Date: Thu, 3 Oct 2024 16:21:45 +0200 Subject: [PATCH] update to use date_index --- .../mah_2024/mah_2024_convert_all_sessions.py | 62 +++++++++++++++++-- 1 file changed, 58 insertions(+), 4 deletions(-) diff --git a/src/constantinople_lab_to_nwb/mah_2024/mah_2024_convert_all_sessions.py b/src/constantinople_lab_to_nwb/mah_2024/mah_2024_convert_all_sessions.py index e05aae6..2e88d00 100644 --- a/src/constantinople_lab_to_nwb/mah_2024/mah_2024_convert_all_sessions.py +++ b/src/constantinople_lab_to_nwb/mah_2024/mah_2024_convert_all_sessions.py @@ -2,8 +2,9 @@ from datetime import datetime from pathlib import Path from typing import Union, List +from warnings import warn -from nwbinspector import inspect_all, format_messages, save_report +import pandas as pd from pymatreader import read_mat from tqdm import tqdm @@ -71,6 +72,50 @@ def _get_sessions_to_convert_from_mat( return bpod_files_to_convert +def _get_date_index(bpod_file_path: Union[str, Path], a_struct_file_path: Union[str, Path]) -> Union[int, None]: + """ + Figure out the date index for the processed behavior file. + + Parameters + ---------- + bpod_file_path: Union[str, Path] + Path to the raw Bpod output (.mat file). + a_struct_file_path: Union[str, Path] + Path to the processed behavior data (.mat file). + + Returns + ------- + int + The date index for the processed behavior file. + """ + bpod_data = read_mat(str(bpod_file_path)) + try: + bpod_session_data = bpod_data["SessionData"] + except KeyError: + warn( + f"'SessionData' key not found in '{bpod_file_path}'. The date index could not be determined from the file." + ) + return None + + num_trials = bpod_session_data["nTrials"] + date = bpod_session_data["Info"]["SessionDate"] + + a_struct_data = read_mat(str(a_struct_file_path)) + dates = a_struct_data["A"]["date"] + num_trials_per_day = a_struct_data["A"]["ntrials"] + + dates_and_trials = pd.DataFrame(dict(date=dates, num_trials=num_trials_per_day)) + filtered_dates_and_trials = dates_and_trials[ + (dates_and_trials["date"] == date) & (dates_and_trials["num_trials"] == num_trials) + ] + + if filtered_dates_and_trials.empty: + warn(f"Date index for '{date}' not found in '{a_struct_file_path}'.") + return None + + return filtered_dates_and_trials.index[0] + + def sessions_to_nwb( raw_behavior_folder_path: Union[str, Path], processed_behavior_folder_path: Union[str, Path], @@ -108,7 +153,7 @@ def sessions_to_nwb( processed_mat_files = list(processed_behavior_folder_path.glob("ratTrial*.mat")) subject_ids = [ processed_behavior_file_path.stem.split("_")[-1] for processed_behavior_file_path in processed_mat_files - ][:10] + ] sessions_to_convert_per_subject = { subject_id: _get_sessions_to_convert_from_mat( file_path=processed_behavior_file_path, bpod_folder_path=raw_behavior_folder_path @@ -128,14 +173,22 @@ def sessions_to_nwb( ) for raw_behavior_file_path in progress_bar: - session_id = raw_behavior_file_path.stem.split("_", maxsplit=1)[1].replace("_", "-") + session_id = Path(raw_behavior_file_path).stem.split("_", maxsplit=1)[1].replace("_", "-") subject_nwb_folder_path = nwbfile_folder_path / f"sub-{subject_id}" if not subject_nwb_folder_path.exists(): os.makedirs(subject_nwb_folder_path, exist_ok=True) nwbfile_path = subject_nwb_folder_path / f"sub-{subject_id}_ses-{session_id}.nwb" if nwbfile_path.exists() and not overwrite: - print(f"Skipping existing NWB file: {nwbfile_path}") + continue + + date_index = _get_date_index( + bpod_file_path=raw_behavior_file_path, a_struct_file_path=processed_behavior_file_path + ) + if date_index is None: + print( + f"Skipping '{subject_id}' session '{session_id}', session not found in the processed behavior file." + ) continue date_from_mat = session_id.split("-")[1] @@ -149,6 +202,7 @@ def sessions_to_nwb( session_to_nwb( raw_behavior_file_path=raw_behavior_file_path, processed_behavior_file_path=processed_behavior_file_path, + date_index=date_index, nwbfile_path=nwbfile_path, column_name_mapping=column_name_mapping, column_descriptions=column_descriptions,