diff --git a/src/constantinople_lab_to_nwb/mah_2024/interfaces/mah_2024_bpodinterface.py b/src/constantinople_lab_to_nwb/mah_2024/interfaces/mah_2024_bpodinterface.py index 6b24943..fd6dbca 100644 --- a/src/constantinople_lab_to_nwb/mah_2024/interfaces/mah_2024_bpodinterface.py +++ b/src/constantinople_lab_to_nwb/mah_2024/interfaces/mah_2024_bpodinterface.py @@ -27,10 +27,10 @@ class Mah2024BpodInterface(BaseDataInterface): """Behavior interface for mah_2024 conversion""" def __init__( - self, - file_path: Union[str, Path], - default_struct_name: str = "SessionData", - verbose: bool = True, + self, + file_path: Union[str, Path], + default_struct_name: str = "SessionData", + verbose: bool = True, ): """ Interface for converting raw Bpod data to NWB. @@ -61,7 +61,7 @@ def get_metadata(self) -> DeepDict: if "Info" in self._bpod_struct: info_dict = self._bpod_struct["Info"] date_string = info_dict["SessionDate"] + info_dict["SessionStartTime_UTC"] - session_start_time = datetime.strptime(date_string, '%d-%b-%Y%H:%M:%S') + session_start_time = datetime.strptime(date_string, "%d-%b-%Y%H:%M:%S") metadata["NWBFile"].update(session_start_time=session_start_time) # Device info @@ -303,7 +303,9 @@ def _read_file(self) -> dict: def get_trial_times(self) -> (List[float], List[float]): return self._bpod_struct["TrialStartTimestamp"], self._bpod_struct["TrialEndTimestamp"] - def create_states_table(self, metadata: dict, trial_start_times: List[float]) -> tuple[StateTypesTable, StatesTable]: + def create_states_table( + self, metadata: dict, trial_start_times: List[float] + ) -> tuple[StateTypesTable, StatesTable]: state_types_metadata = metadata["Behavior"]["StateTypesTable"] states_table_metadata = metadata["Behavior"]["StatesTable"] @@ -313,6 +315,11 @@ def create_states_table(self, metadata: dict, trial_start_times: List[float]) -> trials_data = self._bpod_struct["RawEvents"]["Trial"] for state_name in trials_data[0]["States"]: + if state_name not in state_types_metadata: + raise ValueError( + f"State '{state_name}' not found in metadata. Please provide in metadata['Behavior']['StateTypesTable']." + ) + state_types.add_row( state_name=state_types_metadata[state_name]["name"], check_ragged=False, @@ -334,7 +341,9 @@ def create_states_table(self, metadata: dict, trial_start_times: List[float]) -> return state_types, states_table - def create_actions_table(self, metadata: dict, trial_start_times: List[float]) -> tuple[ActionTypesTable, ActionsTable]: + def create_actions_table( + self, metadata: dict, trial_start_times: List[float] + ) -> tuple[ActionTypesTable, ActionsTable]: action_types_metadata = metadata["Behavior"]["ActionTypesTable"] actions_table_metadata = metadata["Behavior"]["ActionsTable"] @@ -351,7 +360,9 @@ def create_actions_table(self, metadata: dict, trial_start_times: List[float]) - for trial_states_and_events, trial_start_time in zip(trials_data, trial_start_times): events = trial_states_and_events["Events"] - sound_events = [event_name for event_name in events if "AudioPlayer" in event_name or "WavePlayer" in event_name] + sound_events = [ + event_name for event_name in events if "AudioPlayer" in event_name or "WavePlayer" in event_name + ] if not len(sound_events): continue @@ -369,7 +380,9 @@ def create_actions_table(self, metadata: dict, trial_start_times: List[float]) - return action_types, actions_table - def create_events_table(self, metadata: dict, trial_start_times: List[float]) -> tuple[EventTypesTable, EventsTable]: + def create_events_table( + self, metadata: dict, trial_start_times: List[float] + ) -> tuple[EventTypesTable, EventsTable]: event_types_metadata = metadata["Behavior"]["EventTypesTable"] events_table_metadata = metadata["Behavior"]["EventsTable"] @@ -520,10 +533,10 @@ def add_trials(self, nwbfile: NWBFile, metadata: dict) -> None: nwbfile.trials = trials def add_task_arguments_to_trials( - self, - nwbfile: NWBFile, - metadata: dict, - arguments_to_exclude: List[str] = None, + self, + nwbfile: NWBFile, + metadata: dict, + arguments_to_exclude: List[str] = None, ) -> None: if arguments_to_exclude is None: arguments_to_exclude = [] @@ -540,9 +553,16 @@ def add_task_arguments_to_trials( if task_argument_name in arguments_to_exclude: continue if task_argument_name not in task_arguments_metadata: - warn(f"Task argument '{task_argument_name}' not in metadata. Skipping.") - continue - task_argument_values = np.array([trial_settings["GUI"][task_argument_name] for trial_settings in trials_settings]) + warn(f"Task argument '{task_argument_name}' not in metadata.") + task_argument_column_name = task_argument_name + description = "no description" + else: + task_argument_column_name = task_arguments_metadata[task_argument_name]["name"] + description = task_arguments_metadata[task_argument_name]["description"] + + task_argument_values = np.array( + [trial_settings["GUI"][task_argument_name] for trial_settings in trials_settings] + ) task_argument_type = task_arguments_metadata[task_argument_name]["expression_type"] if task_argument_type == "boolean": task_argument_values = task_argument_values.astype(bool) @@ -550,8 +570,8 @@ def add_task_arguments_to_trials( task_argument_values = np.array([self._block_name_mapping[block] for block in task_argument_values]) trials.add_column( - name=task_arguments_metadata[task_argument_name]["name"], - description=task_arguments_metadata[task_argument_name]["description"], + name=task_argument_column_name, + description=description, data=task_argument_values, )