Skip to content

Commit

Permalink
Fix compatible interface derivation
Browse files Browse the repository at this point in the history
  • Loading branch information
garrettmflynn committed May 30, 2024
1 parent b979b1f commit e6f6614
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions src/pyflask/manageNeuroconv/manage_neuroconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,15 @@ def set_interface_alignment(converter, alignment_info):

def get_interface_alignment(info: dict) -> dict:

from neuroconv.datainterfaces.ecephys.basesortingextractorinterface import (
BaseSortingExtractorInterface,
)

from neuroconv.datainterfaces.ecephys.baserecordingextractorinterface import (
BaseRecordingExtractorInterface,
)


alignment_info = info.get("alignment", {})
converter = instantiate_custom_converter(info["source_data"], info["interfaces"])

Expand All @@ -750,18 +759,8 @@ def get_interface_alignment(info: dict) -> dict:
for name, interface in converter.data_interface_objects.items():

metadata[name] = dict()
is_sorting = metadata[name]["sorting"] = hasattr(interface, "sorting_extractor")

if is_sorting:
metadata[name]["compatible"] = []
for sub_name in alignment_info.keys():
sub_interface = converter.data_interface_objects[sub_name]
if hasattr(sub_interface, "recording_extractor"):
try:
interface.register_recording(sub_interface)
metadata[name]["compatible"].append(name)
except Exception:
pass

metadata[name]["sorting"] = hasattr(interface, "sorting_extractor")

# Run interface.get_timestamps if it has the method
if hasattr(interface, "get_timestamps"):
Expand All @@ -775,6 +774,24 @@ def get_interface_alignment(info: dict) -> dict:
else:
timestamps[name] = []



# Derive compatible interfaces
def on_sorting_interface(name, sorting_interface):
metadata[name]["compatible"] = []

def on_recording_interface(sub_name, recording_interface):
try:
sorting_interface.register_recording(recording_interface)
metadata[name]["compatible"].append(sub_name)
except Exception:
pass

map_interfaces(on_recording_interface, converter=converter, to_match=BaseRecordingExtractorInterface)

map_interfaces(on_sorting_interface, converter=converter, to_match=BaseSortingExtractorInterface)

# Return the metadata and timestamps
return dict(
metadata=metadata,
timestamps=timestamps,
Expand Down

0 comments on commit e6f6614

Please sign in to comment.