diff --git a/src/constantinople_lab_to_nwb/general_interfaces/bpodbehaviorinterface.py b/src/constantinople_lab_to_nwb/general_interfaces/bpodbehaviorinterface.py index 21bd738..f78649a 100644 --- a/src/constantinople_lab_to_nwb/general_interfaces/bpodbehaviorinterface.py +++ b/src/constantinople_lab_to_nwb/general_interfaces/bpodbehaviorinterface.py @@ -17,10 +17,11 @@ TaskArgumentsTable, ) from neuroconv import BaseDataInterface -from neuroconv.utils import DeepDict +from neuroconv.utils import DeepDict, get_base_schema, get_schema_from_hdmf_class from ndx_structured_behavior.utils import loadmat from pynwb import NWBFile +from pynwb.device import Device class BpodBehaviorInterface(BaseDataInterface): @@ -48,8 +49,39 @@ def __init__( self.file_path = file_path self._bpod_struct = self._read_file() self._block_name_mapping = {1: "Mixed", 2: "High", 3: "Low"} + self._trial_start_times = None + self._trial_stop_times = None super().__init__(file_path=file_path, verbose=verbose) + def get_metadata_schema(self) -> dict: + metadata_schema = super().get_metadata_schema() + metadata_schema["properties"]["Behavior"] = get_base_schema(tag="Behavior") + device_schema = get_schema_from_hdmf_class(Device) + metadata_schema["properties"]["Behavior"].update( + required=[ + "Device", + "StateTypesTable", + "StatesTable", + "ActionTypesTable", + "ActionsTable", + "EventTypesTable", + "EventsTable", + "TrialsTable", + ], + properties=dict( + Device=device_schema, + StateTypesTable=dict(type="object", properties=dict()), + StatesTable=dict(type="object", properties=dict()), + ActionTypesTable=dict(type="object", properties=dict()), + ActionsTable=dict(type="object", properties=dict()), + EventTypesTable=dict(type="object", properties=dict()), + EventsTable=dict(type="object", properties=dict()), + TrialsTable=dict(type="object", properties=dict()), + TaskArgumentsTable=dict(type="object", properties=dict()), + ), + ) + return metadata_schema + def get_metadata(self) -> DeepDict: metadata = super().get_metadata() @@ -81,6 +113,7 @@ def get_metadata(self) -> DeepDict: EventTypesTable=dict(description="Contains the name of the events in the task."), EventsTable=dict(description="Contains the onset times of events in the task."), TrialsTable=dict(description="Contains the start and end times of each trial in the task."), + TaskArgumentsTable=dict(description="Contains the task arguments for the task."), ) task_arguments = dict( @@ -301,7 +334,17 @@ def _read_file(self) -> dict: return mat_file[self.default_struct_name] def get_trial_times(self) -> (List[float], List[float]): - return self._bpod_struct["TrialStartTimestamp"], self._bpod_struct["TrialEndTimestamp"] + trial_start_times = ( + self._trial_start_times if self._trial_start_times is not None else self._bpod_struct["TrialStartTimestamp"] + ) + trial_stop_times = ( + self._trial_stop_times if self._trial_stop_times is not None else self._bpod_struct["TrialEndTimestamp"] + ) + return trial_start_times, trial_stop_times + + def set_aligned_trial_times(self, trial_start_times: List[float], trial_stop_times: List[float]) -> None: + self._trial_start_times = trial_start_times + self._trial_stop_times = trial_stop_times def create_states_table( self, metadata: dict, trial_start_times: List[float] @@ -519,7 +562,6 @@ def add_trials(self, nwbfile: NWBFile, metadata: dict) -> None: events_table=events_table, actions_table=actions_table, ) - trial_stop_times = trial_start_times[1:] + [np.nan] for start, stop in zip(trial_start_times, trial_stop_times): states_table_df = states_table[:] states_index_mask = (states_table_df["start_time"] >= start) & (states_table_df["stop_time"] < stop)