Skip to content

Commit

Permalink
Enable zarr backend testing in data tests [3] (#1094)
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin authored Sep 23, 2024
1 parent 9f67ec6 commit 7ea96d8
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 103 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
## Improvements
* Remove dev test from PR [PR #1092](https://github.com/catalystneuro/neuroconv/pull/1092)
* Run only the most basic testing while a PR is on draft [PR #1082](https://github.com/catalystneuro/neuroconv/pull/1082)
* Test that zarr backend_configuration works in gin data tests [PR #1094](https://github.com/catalystneuro/neuroconv/pull/1094)
* Consolidated weekly workflows into one workflow and added email notifications [PR #1088](https://github.com/catalystneuro/neuroconv/pull/1088)
* Avoid running link test when the PR is on draft [PR #1093](https://github.com/catalystneuro/neuroconv/pull/1093)

Expand Down
196 changes: 100 additions & 96 deletions src/neuroconv/tools/testing/data_interface_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,23 @@ def test_source_schema_valid(self):
schema = self.data_interface_cls.get_source_schema()
Draft7Validator.check_schema(schema=schema)

def check_conversion_options_schema_valid(self):
def test_conversion_options_schema_valid(self, setup_interface):
schema = self.interface.get_conversion_options_schema()
Draft7Validator.check_schema(schema=schema)

def test_metadata_schema_valid(self, setup_interface):
schema = self.interface.get_metadata_schema()
Draft7Validator.check_schema(schema=schema)

def check_metadata(self):
def test_metadata(self, setup_interface):
# Validate metadata now happens on the class itself
metadata = self.interface.get_metadata()
self.check_extracted_metadata(metadata)

def check_extracted_metadata(self, metadata: dict):
"""Override this method to make assertions about specific extracted metadata values."""
pass

def test_no_metadata_mutation(self, setup_interface):
"""Ensure the metadata object is not altered by `add_to_nwbfile` method."""

Expand All @@ -107,13 +111,35 @@ def test_no_metadata_mutation(self, setup_interface):
self.interface.add_to_nwbfile(nwbfile=nwbfile, metadata=metadata, **self.conversion_options)
assert metadata == metadata_before_add_method

def check_run_conversion_with_backend_configuration(
self, nwbfile_path: str, backend: Literal["hdf5", "zarr"] = "hdf5"
):
@pytest.mark.parametrize("backend", ["hdf5", "zarr"])
def test_run_conversion_with_backend(self, setup_interface, tmp_path, backend):

nwbfile_path = str(tmp_path / f"conversion_with_backend{backend}-{self.test_name}.nwb")

metadata = self.interface.get_metadata()
if "session_start_time" not in metadata["NWBFile"]:
metadata["NWBFile"].update(session_start_time=datetime.now().astimezone())

self.interface.run_conversion(
nwbfile_path=nwbfile_path,
overwrite=True,
metadata=metadata,
backend=backend,
**self.conversion_options,
)

if backend == "zarr":
with NWBZarrIO(path=nwbfile_path, mode="r") as io:
io.read()

@pytest.mark.parametrize("backend", ["hdf5", "zarr"])
def test_run_conversion_with_backend_configuration(self, setup_interface, tmp_path, backend):
metadata = self.interface.get_metadata()
if "session_start_time" not in metadata["NWBFile"]:
metadata["NWBFile"].update(session_start_time=datetime.now().astimezone())

nwbfile_path = str(tmp_path / f"conversion_with_backend_configuration{backend}-{self.test_name}.nwb")

nwbfile = self.interface.create_nwbfile(metadata=metadata, **self.conversion_options)
backend_configuration = self.interface.get_default_backend_configuration(nwbfile=nwbfile, backend=backend)
self.interface.run_conversion(
Expand All @@ -125,6 +151,42 @@ def check_run_conversion_with_backend_configuration(
**self.conversion_options,
)

@pytest.mark.parametrize("backend", ["hdf5", "zarr"])
def test_configure_backend_for_equivalent_nwbfiles(self, setup_interface, tmp_path, backend):
metadata = self.interface.get_metadata()
if "session_start_time" not in metadata["NWBFile"]:
metadata["NWBFile"].update(session_start_time=datetime.now().astimezone())

nwbfile_1 = self.interface.create_nwbfile(metadata=metadata, **self.conversion_options)
nwbfile_2 = self.interface.create_nwbfile(metadata=metadata, **self.conversion_options)

backend_configuration = get_default_backend_configuration(nwbfile=nwbfile_1, backend=backend)
configure_backend(nwbfile=nwbfile_2, backend_configuration=backend_configuration)

def test_all_conversion_checks(self, setup_interface, tmp_path):
interface, test_name = setup_interface

# Create a unique test name and file path
nwbfile_path = str(tmp_path / f"{self.__class__.__name__}_{self.test_name}.nwb")
self.nwbfile_path = nwbfile_path

self.check_run_conversion_in_nwbconverter_with_backend(nwbfile_path=nwbfile_path, backend="hdf5")
self.check_run_conversion_in_nwbconverter_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5")

self.check_read_nwb(nwbfile_path=nwbfile_path)

# Any extra custom checks to run
self.run_custom_checks()

@abstractmethod
def check_read_nwb(self, nwbfile_path: str):
"""Read the produced NWB file and compare it to the interface."""
pass

def run_custom_checks(self):
"""Override this in child classes to inject additional custom checks."""
pass

def check_run_conversion_in_nwbconverter_with_backend(
self, nwbfile_path: str, backend: Literal["hdf5", "zarr"] = "hdf5"
):
Expand Down Expand Up @@ -174,73 +236,6 @@ class TestNWBConverter(NWBConverter):
conversion_options=conversion_options,
)

@abstractmethod
def check_read_nwb(self, nwbfile_path: str):
"""Read the produced NWB file and compare it to the interface."""
pass

def check_extracted_metadata(self, metadata: dict):
"""Override this method to make assertions about specific extracted metadata values."""
pass

def run_custom_checks(self):
"""Override this in child classes to inject additional custom checks."""
pass

@pytest.mark.parametrize("backend", ["hdf5", "zarr"])
def test_run_conversion_with_backend(self, setup_interface, tmp_path, backend):

nwbfile_path = str(tmp_path / f"conversion_with_backend{backend}-{self.test_name}.nwb")

metadata = self.interface.get_metadata()
if "session_start_time" not in metadata["NWBFile"]:
metadata["NWBFile"].update(session_start_time=datetime.now().astimezone())

self.interface.run_conversion(
nwbfile_path=nwbfile_path,
overwrite=True,
metadata=metadata,
backend=backend,
**self.conversion_options,
)

if backend == "zarr":
with NWBZarrIO(path=nwbfile_path, mode="r") as io:
io.read()

@pytest.mark.parametrize("backend", ["hdf5", "zarr"])
def test_configure_backend_for_equivalent_nwbfiles(self, setup_interface, tmp_path, backend):
metadata = self.interface.get_metadata()
if "session_start_time" not in metadata["NWBFile"]:
metadata["NWBFile"].update(session_start_time=datetime.now().astimezone())

nwbfile_1 = self.interface.create_nwbfile(metadata=metadata, **self.conversion_options)
nwbfile_2 = self.interface.create_nwbfile(metadata=metadata, **self.conversion_options)

backend_configuration = get_default_backend_configuration(nwbfile=nwbfile_1, backend=backend)
configure_backend(nwbfile=nwbfile_2, backend_configuration=backend_configuration)

def test_all_conversion_checks(self, setup_interface, tmp_path):
interface, test_name = setup_interface

# Create a unique test name and file path
nwbfile_path = str(tmp_path / f"{self.__class__.__name__}_{self.test_name}.nwb")
self.nwbfile_path = nwbfile_path

# Now run the checks using the setup objects
self.check_conversion_options_schema_valid()
self.check_metadata()

self.check_run_conversion_in_nwbconverter_with_backend(nwbfile_path=nwbfile_path, backend="hdf5")
self.check_run_conversion_in_nwbconverter_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5")

self.check_run_conversion_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5")

self.check_read_nwb(nwbfile_path=nwbfile_path)

# Any extra custom checks to run
self.run_custom_checks()


class TemporalAlignmentMixin:
"""
Expand Down Expand Up @@ -718,27 +713,6 @@ def check_shift_segment_timestamps_by_starting_times(self):
):
assert_array_equal(x=retrieved_aligned_timestamps, y=expected_aligned_timestamps)

def test_all_conversion_checks(self, setup_interface, tmp_path):
# The fixture `setup_interface` sets up the necessary objects
interface, test_name = setup_interface

# Create a unique test name and file path
nwbfile_path = str(tmp_path / f"{self.__class__.__name__}_{self.test_name}.nwb")

# Now run the checks using the setup objects
self.check_conversion_options_schema_valid()
self.check_metadata()

self.check_run_conversion_in_nwbconverter_with_backend(nwbfile_path=nwbfile_path, backend="hdf5")
self.check_run_conversion_in_nwbconverter_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5")

self.check_run_conversion_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5")

self.check_read_nwb(nwbfile_path=nwbfile_path)

# Any extra custom checks to run
self.run_custom_checks()

def test_interface_alignment(self, setup_interface):

# TODO sorting can have times without associated recordings, test this later
Expand Down Expand Up @@ -872,12 +846,21 @@ class MedPCInterfaceMixin(DataInterfaceTestMixin, TemporalAlignmentMixin):
A mixin for testing MedPC interfaces.
"""

def test_metadata(self):
pass

def test_conversion_options_schema_valid(self):
pass

def test_metadata_schema_valid(self):
pass

def test_run_conversion_with_backend(self):
pass

def test_run_conversion_with_backend_configuration(self):
pass

def test_no_metadata_mutation(self):
pass

Expand All @@ -888,6 +871,10 @@ def check_metadata_schema_valid(self):
schema = self.interface.get_metadata_schema()
Draft7Validator.check_schema(schema=schema)

def check_conversion_options_schema_valid(self):
schema = self.interface.get_conversion_options_schema()
Draft7Validator.check_schema(schema=schema)

def check_metadata(self):
schema = self.interface.get_metadata_schema()
metadata = self.interface.get_metadata()
Expand Down Expand Up @@ -1158,9 +1145,8 @@ def check_read_nwb(self, nwbfile_path: str):
assert one_photon_series.starting_frame is None
assert one_photon_series.timestamps.shape == (15,)

imaging_extractor = self.interface.imaging_extractor
times_from_extractor = imaging_extractor._times
assert_array_equal(one_photon_series.timestamps, times_from_extractor)
interface_times = self.interface.get_original_timestamps()
assert_array_equal(one_photon_series.timestamps, interface_times)


class ScanImageSinglePlaneImagingInterfaceMixin(DataInterfaceTestMixin, TemporalAlignmentMixin):
Expand Down Expand Up @@ -1235,25 +1221,43 @@ def check_read_nwb(self, nwbfile_path: str):
class TDTFiberPhotometryInterfaceMixin(DataInterfaceTestMixin, TemporalAlignmentMixin):
"""Mixin for testing TDT Fiber Photometry interfaces."""

def test_metadata(self):
pass

def test_metadata_schema_valid(self):
pass

def test_no_metadata_mutation(self):
pass

def test_conversion_options_schema_valid(self):
pass

def test_run_conversion_with_backend(self):
pass

def test_run_conversion_with_backend_configuration(self):
pass

def test_no_metadata_mutation(self):
pass

def test_configure_backend_for_equivalent_nwbfiles(self):
pass

def check_metadata(self):
# Validate metadata now happens on the class itself
metadata = self.interface.get_metadata()
self.check_extracted_metadata(metadata)

def check_metadata_schema_valid(self):
schema = self.interface.get_metadata_schema()
Draft7Validator.check_schema(schema=schema)

def check_conversion_options_schema_valid(self):
schema = self.interface.get_conversion_options_schema()
Draft7Validator.check_schema(schema=schema)

def check_no_metadata_mutation(self, metadata: dict):
"""Ensure the metadata object was not altered by `add_to_nwbfile` method."""

Expand Down
22 changes: 15 additions & 7 deletions tests/test_on_data/ecephys/test_recording_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_add_channel_metadata_to_nwb(self, setup_interface):
else:
assert expected_value == extracted_value

# Test addition to electrodes table
# Test addition to electrodes table!~
with NWBHDF5IO(self.nwbfile_path, "r") as io:
nwbfile = io.read()
electrode_table = nwbfile.electrodes.to_dataframe()
Expand All @@ -176,9 +176,6 @@ class TestEDFRecordingInterface(RecordingExtractorInterfaceTestMixin):
interface_kwargs = dict(file_path=str(ECEPHY_DATA_PATH / "edf" / "edf+C.edf"))
save_directory = OUTPUT_PATH

def check_extracted_metadata(self, metadata: dict):
assert metadata["NWBFile"]["session_start_time"] == datetime(2022, 3, 2, 10, 42, 19)

def check_run_conversion_with_backend(self, nwbfile_path: str, backend="hdf5"):
metadata = self.interface.get_metadata()
if "session_start_time" not in metadata["NWBFile"]:
Expand All @@ -198,11 +195,10 @@ def test_all_conversion_checks(self, setup_interface, tmp_path):
self.nwbfile_path = nwbfile_path

# Now run the checks using the setup objects
self.check_conversion_options_schema_valid()
self.check_metadata()
metadata = self.interface.get_metadata()
assert metadata["NWBFile"]["session_start_time"] == datetime(2022, 3, 2, 10, 42, 19)

self.check_run_conversion_with_backend(nwbfile_path=nwbfile_path, backend="hdf5")

self.check_read_nwb(nwbfile_path=nwbfile_path)

# EDF has simultaneous access issues; can't have multiple interfaces open on the same file at once...
Expand All @@ -215,12 +211,24 @@ def test_no_metadata_mutation(self):
def test_run_conversion_with_backend(self):
pass

def test_run_conversion_with_backend_configuration(self):
pass

def test_interface_alignment(self):
pass

def test_configure_backend_for_equivalent_nwbfiles(self):
pass

def test_conversion_options_schema_valid(self):
pass

def test_metadata(self):
pass

def test_conversion_options_schema_valid(self):
pass


class TestIntanRecordingInterfaceRHS(RecordingExtractorInterfaceTestMixin):
data_interface_cls = IntanRecordingInterface
Expand Down
4 changes: 4 additions & 0 deletions tests/test_ophys/test_ophys_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ class TestMockImagingInterface(ImagingExtractorInterfaceTestMixin):
data_interface_cls = MockImagingInterface
interface_kwargs = dict()

# TODO: fix this by setting a seed on the dummy imaging extractor
def test_all_conversion_checks(self):
pass


class TestMockSegmentationInterface(SegmentationExtractorInterfaceTestMixin):

Expand Down

0 comments on commit 7ea96d8

Please sign in to comment.