Skip to content

Commit

Permalink
Merge pull request #18 from catalystneuro/update_bpod
Browse files Browse the repository at this point in the history
Updates for bpod interface
  • Loading branch information
weiglszonja authored Oct 30, 2024
2 parents e714700 + a1dd45a commit f8b9b8f
Showing 1 changed file with 45 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
TaskArgumentsTable,
)
from neuroconv import BaseDataInterface
from neuroconv.utils import DeepDict
from neuroconv.utils import DeepDict, get_base_schema, get_schema_from_hdmf_class

from ndx_structured_behavior.utils import loadmat
from pynwb import NWBFile
from pynwb.device import Device


class BpodBehaviorInterface(BaseDataInterface):
Expand Down Expand Up @@ -48,8 +49,39 @@ def __init__(
self.file_path = file_path
self._bpod_struct = self._read_file()
self._block_name_mapping = {1: "Mixed", 2: "High", 3: "Low"}
self._trial_start_times = None
self._trial_stop_times = None
super().__init__(file_path=file_path, verbose=verbose)

def get_metadata_schema(self) -> dict:
metadata_schema = super().get_metadata_schema()
metadata_schema["properties"]["Behavior"] = get_base_schema(tag="Behavior")
device_schema = get_schema_from_hdmf_class(Device)
metadata_schema["properties"]["Behavior"].update(
required=[
"Device",
"StateTypesTable",
"StatesTable",
"ActionTypesTable",
"ActionsTable",
"EventTypesTable",
"EventsTable",
"TrialsTable",
],
properties=dict(
Device=device_schema,
StateTypesTable=dict(type="object", properties=dict()),
StatesTable=dict(type="object", properties=dict()),
ActionTypesTable=dict(type="object", properties=dict()),
ActionsTable=dict(type="object", properties=dict()),
EventTypesTable=dict(type="object", properties=dict()),
EventsTable=dict(type="object", properties=dict()),
TrialsTable=dict(type="object", properties=dict()),
TaskArgumentsTable=dict(type="object", properties=dict()),
),
)
return metadata_schema

def get_metadata(self) -> DeepDict:
metadata = super().get_metadata()

Expand Down Expand Up @@ -81,6 +113,7 @@ def get_metadata(self) -> DeepDict:
EventTypesTable=dict(description="Contains the name of the events in the task."),
EventsTable=dict(description="Contains the onset times of events in the task."),
TrialsTable=dict(description="Contains the start and end times of each trial in the task."),
TaskArgumentsTable=dict(description="Contains the task arguments for the task."),
)

task_arguments = dict(
Expand Down Expand Up @@ -301,7 +334,17 @@ def _read_file(self) -> dict:
return mat_file[self.default_struct_name]

def get_trial_times(self) -> (List[float], List[float]):
return self._bpod_struct["TrialStartTimestamp"], self._bpod_struct["TrialEndTimestamp"]
trial_start_times = (
self._trial_start_times if self._trial_start_times is not None else self._bpod_struct["TrialStartTimestamp"]
)
trial_stop_times = (
self._trial_stop_times if self._trial_stop_times is not None else self._bpod_struct["TrialEndTimestamp"]
)
return trial_start_times, trial_stop_times

def set_aligned_trial_times(self, trial_start_times: List[float], trial_stop_times: List[float]) -> None:
self._trial_start_times = trial_start_times
self._trial_stop_times = trial_stop_times

def create_states_table(
self, metadata: dict, trial_start_times: List[float]
Expand Down Expand Up @@ -519,7 +562,6 @@ def add_trials(self, nwbfile: NWBFile, metadata: dict) -> None:
events_table=events_table,
actions_table=actions_table,
)
trial_stop_times = trial_start_times[1:] + [np.nan]
for start, stop in zip(trial_start_times, trial_stop_times):
states_table_df = states_table[:]
states_index_mask = (states_table_df["start_time"] >= start) & (states_table_df["stop_time"] < stop)
Expand Down

0 comments on commit f8b9b8f

Please sign in to comment.