Skip to content

Commit

Permalink
raise error when state type is missing in metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
weiglszonja committed Oct 3, 2024
1 parent 1186fbc commit b7820b7
Showing 1 changed file with 38 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ class Mah2024BpodInterface(BaseDataInterface):
"""Behavior interface for mah_2024 conversion"""

def __init__(
self,
file_path: Union[str, Path],
default_struct_name: str = "SessionData",
verbose: bool = True,
self,
file_path: Union[str, Path],
default_struct_name: str = "SessionData",
verbose: bool = True,
):
"""
Interface for converting raw Bpod data to NWB.
Expand Down Expand Up @@ -61,7 +61,7 @@ def get_metadata(self) -> DeepDict:
if "Info" in self._bpod_struct:
info_dict = self._bpod_struct["Info"]
date_string = info_dict["SessionDate"] + info_dict["SessionStartTime_UTC"]
session_start_time = datetime.strptime(date_string, '%d-%b-%Y%H:%M:%S')
session_start_time = datetime.strptime(date_string, "%d-%b-%Y%H:%M:%S")
metadata["NWBFile"].update(session_start_time=session_start_time)

# Device info
Expand Down Expand Up @@ -303,7 +303,9 @@ def _read_file(self) -> dict:
def get_trial_times(self) -> (List[float], List[float]):
return self._bpod_struct["TrialStartTimestamp"], self._bpod_struct["TrialEndTimestamp"]

def create_states_table(self, metadata: dict, trial_start_times: List[float]) -> tuple[StateTypesTable, StatesTable]:
def create_states_table(
self, metadata: dict, trial_start_times: List[float]
) -> tuple[StateTypesTable, StatesTable]:
state_types_metadata = metadata["Behavior"]["StateTypesTable"]
states_table_metadata = metadata["Behavior"]["StatesTable"]

Expand All @@ -313,6 +315,11 @@ def create_states_table(self, metadata: dict, trial_start_times: List[float]) ->

trials_data = self._bpod_struct["RawEvents"]["Trial"]
for state_name in trials_data[0]["States"]:
if state_name not in state_types_metadata:
raise ValueError(
f"State '{state_name}' not found in metadata. Please provide in metadata['Behavior']['StateTypesTable']."
)

state_types.add_row(
state_name=state_types_metadata[state_name]["name"],
check_ragged=False,
Expand All @@ -334,7 +341,9 @@ def create_states_table(self, metadata: dict, trial_start_times: List[float]) ->

return state_types, states_table

def create_actions_table(self, metadata: dict, trial_start_times: List[float]) -> tuple[ActionTypesTable, ActionsTable]:
def create_actions_table(
self, metadata: dict, trial_start_times: List[float]
) -> tuple[ActionTypesTable, ActionsTable]:
action_types_metadata = metadata["Behavior"]["ActionTypesTable"]
actions_table_metadata = metadata["Behavior"]["ActionsTable"]

Expand All @@ -351,7 +360,9 @@ def create_actions_table(self, metadata: dict, trial_start_times: List[float]) -
for trial_states_and_events, trial_start_time in zip(trials_data, trial_start_times):
events = trial_states_and_events["Events"]

sound_events = [event_name for event_name in events if "AudioPlayer" in event_name or "WavePlayer" in event_name]
sound_events = [
event_name for event_name in events if "AudioPlayer" in event_name or "WavePlayer" in event_name
]
if not len(sound_events):
continue

Expand All @@ -369,7 +380,9 @@ def create_actions_table(self, metadata: dict, trial_start_times: List[float]) -

return action_types, actions_table

def create_events_table(self, metadata: dict, trial_start_times: List[float]) -> tuple[EventTypesTable, EventsTable]:
def create_events_table(
self, metadata: dict, trial_start_times: List[float]
) -> tuple[EventTypesTable, EventsTable]:
event_types_metadata = metadata["Behavior"]["EventTypesTable"]
events_table_metadata = metadata["Behavior"]["EventsTable"]

Expand Down Expand Up @@ -520,10 +533,10 @@ def add_trials(self, nwbfile: NWBFile, metadata: dict) -> None:
nwbfile.trials = trials

def add_task_arguments_to_trials(
self,
nwbfile: NWBFile,
metadata: dict,
arguments_to_exclude: List[str] = None,
self,
nwbfile: NWBFile,
metadata: dict,
arguments_to_exclude: List[str] = None,
) -> None:
if arguments_to_exclude is None:
arguments_to_exclude = []
Expand All @@ -540,18 +553,25 @@ def add_task_arguments_to_trials(
if task_argument_name in arguments_to_exclude:
continue
if task_argument_name not in task_arguments_metadata:
warn(f"Task argument '{task_argument_name}' not in metadata. Skipping.")
continue
task_argument_values = np.array([trial_settings["GUI"][task_argument_name] for trial_settings in trials_settings])
warn(f"Task argument '{task_argument_name}' not in metadata.")
task_argument_column_name = task_argument_name
description = "no description"
else:
task_argument_column_name = task_arguments_metadata[task_argument_name]["name"]
description = task_arguments_metadata[task_argument_name]["description"]

task_argument_values = np.array(
[trial_settings["GUI"][task_argument_name] for trial_settings in trials_settings]
)
task_argument_type = task_arguments_metadata[task_argument_name]["expression_type"]
if task_argument_type == "boolean":
task_argument_values = task_argument_values.astype(bool)
elif task_argument_name == "Block":
task_argument_values = np.array([self._block_name_mapping[block] for block in task_argument_values])

trials.add_column(
name=task_arguments_metadata[task_argument_name]["name"],
description=task_arguments_metadata[task_argument_name]["description"],
name=task_argument_column_name,
description=description,
data=task_argument_values,
)

Expand Down

0 comments on commit b7820b7

Please sign in to comment.