Skip to content

Commit

Permalink
update to use date_index
Browse files Browse the repository at this point in the history
  • Loading branch information
weiglszonja committed Oct 3, 2024
1 parent b01f226 commit 1186fbc
Showing 1 changed file with 58 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from datetime import datetime
from pathlib import Path
from typing import Union, List
from warnings import warn

from nwbinspector import inspect_all, format_messages, save_report
import pandas as pd
from pymatreader import read_mat
from tqdm import tqdm

Expand Down Expand Up @@ -71,6 +72,50 @@ def _get_sessions_to_convert_from_mat(
return bpod_files_to_convert


def _get_date_index(bpod_file_path: Union[str, Path], a_struct_file_path: Union[str, Path]) -> Union[int, None]:
"""
Figure out the date index for the processed behavior file.
Parameters
----------
bpod_file_path: Union[str, Path]
Path to the raw Bpod output (.mat file).
a_struct_file_path: Union[str, Path]
Path to the processed behavior data (.mat file).
Returns
-------
int
The date index for the processed behavior file.
"""
bpod_data = read_mat(str(bpod_file_path))
try:
bpod_session_data = bpod_data["SessionData"]
except KeyError:
warn(
f"'SessionData' key not found in '{bpod_file_path}'. The date index could not be determined from the file."
)
return None

num_trials = bpod_session_data["nTrials"]
date = bpod_session_data["Info"]["SessionDate"]

a_struct_data = read_mat(str(a_struct_file_path))
dates = a_struct_data["A"]["date"]
num_trials_per_day = a_struct_data["A"]["ntrials"]

dates_and_trials = pd.DataFrame(dict(date=dates, num_trials=num_trials_per_day))
filtered_dates_and_trials = dates_and_trials[
(dates_and_trials["date"] == date) & (dates_and_trials["num_trials"] == num_trials)
]

if filtered_dates_and_trials.empty:
warn(f"Date index for '{date}' not found in '{a_struct_file_path}'.")
return None

return filtered_dates_and_trials.index[0]


def sessions_to_nwb(
raw_behavior_folder_path: Union[str, Path],
processed_behavior_folder_path: Union[str, Path],
Expand Down Expand Up @@ -108,7 +153,7 @@ def sessions_to_nwb(
processed_mat_files = list(processed_behavior_folder_path.glob("ratTrial*.mat"))
subject_ids = [
processed_behavior_file_path.stem.split("_")[-1] for processed_behavior_file_path in processed_mat_files
][:10]
]
sessions_to_convert_per_subject = {
subject_id: _get_sessions_to_convert_from_mat(
file_path=processed_behavior_file_path, bpod_folder_path=raw_behavior_folder_path
Expand All @@ -128,14 +173,22 @@ def sessions_to_nwb(
)

for raw_behavior_file_path in progress_bar:
session_id = raw_behavior_file_path.stem.split("_", maxsplit=1)[1].replace("_", "-")
session_id = Path(raw_behavior_file_path).stem.split("_", maxsplit=1)[1].replace("_", "-")
subject_nwb_folder_path = nwbfile_folder_path / f"sub-{subject_id}"
if not subject_nwb_folder_path.exists():
os.makedirs(subject_nwb_folder_path, exist_ok=True)
nwbfile_path = subject_nwb_folder_path / f"sub-{subject_id}_ses-{session_id}.nwb"

if nwbfile_path.exists() and not overwrite:
print(f"Skipping existing NWB file: {nwbfile_path}")
continue

date_index = _get_date_index(
bpod_file_path=raw_behavior_file_path, a_struct_file_path=processed_behavior_file_path
)
if date_index is None:
print(
f"Skipping '{subject_id}' session '{session_id}', session not found in the processed behavior file."
)
continue

date_from_mat = session_id.split("-")[1]
Expand All @@ -149,6 +202,7 @@ def sessions_to_nwb(
session_to_nwb(
raw_behavior_file_path=raw_behavior_file_path,
processed_behavior_file_path=processed_behavior_file_path,
date_index=date_index,
nwbfile_path=nwbfile_path,
column_name_mapping=column_name_mapping,
column_descriptions=column_descriptions,
Expand Down

0 comments on commit 1186fbc

Please sign in to comment.