Commit c1bcf37: reduce warnings

h-mayorquin committed Dec 13, 2024
1 parent 43477de commit c1bcf37
Showing 10 changed files with 69 additions and 124 deletions.
10 changes: 10 additions & 0 deletions src/neuroconv/tools/spikeinterface/spikeinterface.py
@@ -862,6 +862,16 @@ def add_electrical_series_to_nwbfile(
whenever possible.
"""

if starting_time is not None:
warnings.warn(
"The 'starting_time' parameter is deprecated and will be removed in June 2025. "
"Use the time alignment methods or set the recording times directlyfor modifying the starting time or timestamps "
"of the data if needed: "
"https://neuroconv.readthedocs.io/en/main/user_guide/temporal_alignment.html",
DeprecationWarning,
stacklevel=2,
)

assert write_as in [
"raw",
"processed",
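For callers migrating off starting_time, a minimal sketch of the alternative the warning points to (assuming the standard temporal-alignment methods on recording interfaces; see the linked user guide):

    from neuroconv.tools.testing.mock_interfaces import MockRecordingInterface

    interface = MockRecordingInterface(num_channels=4, durations=[1.0])
    # Shift all timestamps so the series starts 5.0 s into the session,
    # instead of passing starting_time to add_electrical_series_to_nwbfile.
    interface.set_aligned_starting_time(aligned_starting_time=5.0)
    nwbfile = interface.create_nwbfile()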
22 changes: 7 additions & 15 deletions src/neuroconv/tools/testing/data_interface_mixins.py
@@ -1,4 +1,3 @@
import inspect
import json
import tempfile
from abc import abstractmethod
@@ -407,7 +406,7 @@ def check_read_nwb(self, nwbfile_path: str):
# Spikeinterface behavior is to load the electrode table channel_name property as a channel_id
self.nwb_recording = NwbRecordingExtractor(
file_path=nwbfile_path,
electrical_series_name=electrical_series_name,
electrical_series_path=f"acquisition/{electrical_series_name}",
use_pynwb=True,
)

@@ -439,7 +438,7 @@ def check_read_nwb(self, nwbfile_path: str):
assert_array_equal(
recording.get_property(property_name), self.nwb_recording.get_property(property_name)
)
if recording.has_scaled_traces() and self.nwb_recording.has_scaled_traces():
if recording.has_scaleable_traces() and self.nwb_recording.has_scaleable_traces():
check_recordings_equal(RX1=recording, RX2=self.nwb_recording, return_scaled=True)

# Compare channel groups
@@ -625,29 +624,22 @@ def check_read_nwb(self, nwbfile_path: str):

# NWBSortingExtractor on spikeinterface does not yet support loading data written from multiple segments.
if sorting.get_num_segments() == 1:
# TODO after 0.100 release remove this if
signature = inspect.signature(NwbSortingExtractor)
if "t_start" in signature.parameters:
nwb_sorting = NwbSortingExtractor(file_path=nwbfile_path, sampling_frequency=sf, t_start=0.0)
else:
nwb_sorting = NwbSortingExtractor(file_path=nwbfile_path, sampling_frequency=sf)
nwb_sorting = NwbSortingExtractor(file_path=nwbfile_path, sampling_frequency=sf, t_start=0.0)

# In the NWBSortingExtractor, since unit_names may not be unique,
# table "ids" are loaded as unit_ids. Here we rename the original sorting accordingly
if "unit_name" in sorting.get_property_keys():
renamed_unit_ids = sorting.get_property("unit_name")
# sorting_renamed = sorting.rename_units(new_unit_ids=renamed_unit_ids) #TODO after 0.100 release use this
sorting_renamed = sorting.select_units(unit_ids=sorting.unit_ids, renamed_unit_ids=renamed_unit_ids)
sorting_renamed = sorting.rename_units(new_unit_ids=renamed_unit_ids)

else:
nwb_has_ids_as_strings = all(isinstance(id, str) for id in nwb_sorting.unit_ids)
if nwb_has_ids_as_strings:
renamed_unit_ids = sorting.get_unit_ids()
renamed_unit_ids = [str(id) for id in renamed_unit_ids]
renamed_unit_ids = [str(id) for id in sorting.get_unit_ids()]
else:
renamed_unit_ids = np.arange(len(sorting.unit_ids))

# sorting_renamed = sorting.rename_units(new_unit_ids=sorting.unit_ids) #TODO after 0.100 release use this
sorting_renamed = sorting.select_units(unit_ids=sorting.unit_ids, renamed_unit_ids=renamed_unit_ids)
sorting_renamed = sorting.rename_units(new_unit_ids=sorting.unit_ids)
check_sortings_equal(SX1=sorting_renamed, SX2=nwb_sorting)

def check_interface_set_aligned_segment_timestamps(self):
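A minimal sketch of the two SpikeInterface API updates this file tracks (electrical_series_path replacing electrical_series_name, and has_scaleable_traces() replacing has_scaled_traces(); the exact version boundaries are an assumption):

    from spikeinterface.extractors import NwbRecordingExtractor

    # The extractor now takes a full object path inside the NWB file
    # rather than a bare series name.
    recording = NwbRecordingExtractor(
        file_path="session.nwb",  # hypothetical file path
        electrical_series_path="acquisition/ElectricalSeries",
        use_pynwb=True,
    )
    if recording.has_scaleable_traces():  # both gain and offset are set
        traces_uv = recording.get_traces(return_scaled=True)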
4 changes: 4 additions & 0 deletions src/neuroconv/tools/testing/mock_interfaces.py
@@ -271,6 +271,10 @@ def __init__(
verbose=verbose,
)

# Make the sorting extractor have string unit ids until this is changed in SpikeInterface
string_unit_ids = [str(id) for id in self.sorting_extractor.unit_ids]
self.sorting_extractor = self.sorting_extractor.rename_units(new_unit_ids=string_unit_ids)

def get_metadata(self) -> dict:
metadata = super().get_metadata()
session_start_time = datetime.now().astimezone()
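For reference, a quick check of the string-id behavior introduced above (a sketch using the mock interface exactly as it is constructed in this commit's tests):

    from neuroconv.tools.testing.mock_interfaces import MockSortingInterface

    interface = MockSortingInterface(num_units=4, durations=[1.0])
    # After __init__, unit ids are strings regardless of the SpikeInterface default.
    assert all(isinstance(unit_id, str) for unit_id in interface.sorting_extractor.unit_ids)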
139 changes: 39 additions & 100 deletions tests/test_ecephys/test_ecephys_interfaces.py
@@ -1,24 +1,12 @@
import shutil
import unittest
from datetime import datetime
from pathlib import Path
from platform import python_version as get_python_version
from tempfile import mkdtemp
from warnings import warn

import jsonschema
import numpy as np
import pytest
from hdmf.testing import TestCase
from packaging.version import Version
from pynwb import NWBHDF5IO
from spikeinterface.extractors import NumpySorting

from neuroconv import NWBConverter
from neuroconv.datainterfaces import Spike2RecordingInterface
from neuroconv.datainterfaces.ecephys.basesortingextractorinterface import (
BaseSortingExtractorInterface,
)
from neuroconv.tools.nwb_helpers import get_module
from neuroconv.tools.testing.mock_interfaces import (
MockRecordingInterface,
@@ -54,6 +42,45 @@ def test_propagate_conversion_options(self, setup_interface):
assert nwbfile.units is None
assert "processed_units" in ecephys.data_interfaces

def test_stub(self):

interface = MockSortingInterface(num_units=4, durations=[1.0])
sorting_extractor = interface.sorting_extractor
unit_ids = sorting_extractor.unit_ids
first_unit_spike = {
unit_id: sorting_extractor.get_unit_spike_train(unit_id=unit_id, return_times=True)[0]
for unit_id in unit_ids
}

nwbfile = interface.create_nwbfile(stub_test=True)
units_table = nwbfile.units.to_dataframe()

for unit_id, first_spike_time in first_unit_spike.items():
unit_row = units_table[units_table["unit_name"] == unit_id]
unit_spike_times = unit_row["spike_times"].values[0]
np.testing.assert_almost_equal(unit_spike_times[0], first_spike_time, decimal=6)

def test_stub_with_recording(self):
interface = MockSortingInterface(num_units=4, durations=[1.0])

recording_interface = MockRecordingInterface(num_channels=4, durations=[2.0])
interface.register_recording(recording_interface)

sorting_extractor = interface.sorting_extractor
unit_ids = sorting_extractor.unit_ids
first_unit_spike = {
unit_id: sorting_extractor.get_unit_spike_train(unit_id=unit_id, return_times=True)[0]
for unit_id in unit_ids
}

nwbfile = interface.create_nwbfile(stub_test=True)
units_table = nwbfile.units.to_dataframe()

for unit_id, first_spike_time in first_unit_spike.items():
unit_row = units_table[units_table["unit_name"] == unit_id]
unit_spike_times = unit_row["spike_times"].values[0]
np.testing.assert_almost_equal(unit_spike_times[0], first_spike_time, decimal=6)

def test_electrode_indices(self, setup_interface):

recording_interface = MockRecordingInterface(num_channels=4, durations=[0.100])
@@ -136,91 +163,3 @@ def test_spike2_import_assertions_3_11(self):
exc_msg="\nThe package 'sonpy' is not available for Python version 3.11!",
):
Spike2RecordingInterface.get_all_channels_info(file_path="does_not_matter.smrx")


class TestSortingInterfaceOld(unittest.TestCase):
"""Old-style tests for the SortingInterface. Remove once we we are sure all the behaviors are covered by the mock."""

@classmethod
def setUpClass(cls) -> None:
cls.test_dir = Path(mkdtemp())
cls.sorting_start_frames = [100, 200, 300]
cls.num_frames = 1000
cls.sampling_frequency = 3000.0
times = np.array([], dtype="int")
labels = np.array([], dtype="int")
for i, start_frame in enumerate(cls.sorting_start_frames):
times_i = np.arange(start_frame, cls.num_frames, dtype="int")
labels_i = (i + 1) * np.ones_like(times_i, dtype="int")
times = np.concatenate((times, times_i))
labels = np.concatenate((labels, labels_i))
sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency=cls.sampling_frequency)

class TestSortingInterface(BaseSortingExtractorInterface):
ExtractorName = "NumpySorting"

def __init__(self, verbose: bool = True):
self.sorting_extractor = sorting
self.source_data = dict()
self.verbose = verbose

class TempConverter(NWBConverter):
data_interface_classes = dict(TestSortingInterface=TestSortingInterface)

source_data = dict(TestSortingInterface=dict())
cls.test_sorting_interface = TempConverter(source_data)

@classmethod
def tearDownClass(cls):
try:
shutil.rmtree(cls.test_dir)
except PermissionError: # Windows CI bug
warn(f"Unable to fully clean the temporary directory: {cls.test_dir}\n\nPlease remove it manually.")

def test_sorting_stub(self):
minimal_nwbfile = self.test_dir / "stub_temp.nwb"
conversion_options = dict(TestSortingInterface=dict(stub_test=True))
metadata = self.test_sorting_interface.get_metadata()
metadata["NWBFile"]["session_start_time"] = datetime.now().astimezone()
self.test_sorting_interface.run_conversion(
nwbfile_path=minimal_nwbfile, metadata=metadata, conversion_options=conversion_options
)
with NWBHDF5IO(minimal_nwbfile, "r") as io:
nwbfile = io.read()
start_frame_max = np.max(self.sorting_start_frames)
for i, start_times in enumerate(self.sorting_start_frames):
assert len(nwbfile.units["spike_times"][i]) == (start_frame_max * 1.1) - start_times

def test_sorting_stub_with_recording(self):
subset_end_frame = int(np.max(self.sorting_start_frames) * 1.1 - 1)
sorting_interface = self.test_sorting_interface.data_interface_objects["TestSortingInterface"]
sorting_interface.sorting_extractor = sorting_interface.sorting_extractor.frame_slice(
start_frame=0, end_frame=subset_end_frame
)
recording_interface = MockRecordingInterface(
durations=[subset_end_frame / self.sampling_frequency],
sampling_frequency=self.sampling_frequency,
)
sorting_interface.register_recording(recording_interface)

minimal_nwbfile = self.test_dir / "stub_temp_recording.nwb"
conversion_options = dict(TestSortingInterface=dict(stub_test=True))
metadata = self.test_sorting_interface.get_metadata()
metadata["NWBFile"]["session_start_time"] = datetime.now().astimezone()
self.test_sorting_interface.run_conversion(
nwbfile_path=minimal_nwbfile, metadata=metadata, conversion_options=conversion_options
)
with NWBHDF5IO(minimal_nwbfile, "r") as io:
nwbfile = io.read()
for i, start_times in enumerate(self.sorting_start_frames):
assert len(nwbfile.units["spike_times"][i]) == subset_end_frame - start_times

def test_sorting_full(self):
minimal_nwbfile = self.test_dir / "temp.nwb"
metadata = self.test_sorting_interface.get_metadata()
metadata["NWBFile"]["session_start_time"] = datetime.now().astimezone()
self.test_sorting_interface.run_conversion(nwbfile_path=minimal_nwbfile, metadata=metadata)
with NWBHDF5IO(minimal_nwbfile, "r") as io:
nwbfile = io.read()
for i, start_times in enumerate(self.sorting_start_frames):
assert len(nwbfile.units["spike_times"][i]) == self.num_frames - start_times
@@ -66,7 +66,7 @@ def test_simple_time_series(
dataset_configuration = backend_configuration.dataset_configurations["acquisition/TestTimeSeries/data"]
configure_backend(nwbfile=nwbfile, backend_configuration=backend_configuration)

nwbfile_path = str(tmpdir / f"test_configure_defaults_{case_name}_time_series.nwb.{backend}")
nwbfile_path = str(tmpdir / f"test_configure_defaults_{case_name}_time_series.nwb")
with BACKEND_NWB_IO[backend](path=nwbfile_path, mode="w") as io:
io.write(nwbfile)

@@ -98,7 +98,7 @@ def test_simple_dynamic_table(tmpdir: Path, integer_array: np.ndarray, backend:
dataset_configuration = backend_configuration.dataset_configurations["acquisition/TestDynamicTable/TestColumn/data"]
configure_backend(nwbfile=nwbfile, backend_configuration=backend_configuration)

nwbfile_path = str(tmpdir / f"test_configure_defaults_dynamic_table.nwb.{backend}")
nwbfile_path = str(tmpdir / f"test_configure_defaults_dynamic_table.nwb")
NWB_IO = BACKEND_NWB_IO[backend]
with NWB_IO(path=nwbfile_path, mode="w") as io:
io.write(nwbfile)
@@ -164,7 +164,7 @@ def test_time_series_timestamps_linkage(
assert nwbfile.acquisition["TestTimeSeries1"].timestamps
assert nwbfile.acquisition["TestTimeSeries2"].timestamps

nwbfile_path = str(tmpdir / f"test_time_series_timestamps_linkage_{case_name}_data.nwb.{backend}")
nwbfile_path = str(tmpdir / f"test_time_series_timestamps_linkage_{case_name}_data.nwb")
with BACKEND_NWB_IO[backend](path=nwbfile_path, mode="w") as io:
io.write(nwbfile)

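The BACKEND_NWB_IO lookup used throughout these tests maps each backend key to its IO class; a sketch of the assumed definition (the real mapping lives in the shared test utilities):

    from pynwb import NWBHDF5IO
    from hdmf_zarr import NWBZarrIO

    BACKEND_NWB_IO = dict(hdf5=NWBHDF5IO, zarr=NWBZarrIO)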
@@ -69,7 +69,7 @@ def test_configure_backend_equivalency(
dataset_configuration.compression_options = {"level": 2}
configure_backend(nwbfile=nwbfile_1, backend_configuration=backend_configuration_2)

nwbfile_path = str(tmpdir / f"test_configure_backend_equivalency.nwb.{backend}")
nwbfile_path = str(tmpdir / f"test_configure_backend_equivalency.nwb")
with BACKEND_NWB_IO[backend](path=nwbfile_path, mode="w") as io:
io.write(nwbfile_1)

@@ -58,7 +58,7 @@ def test_simple_time_series_override(
if case_name != "unwrapped": # TODO: eventually, even this case will be buffered automatically
assert nwbfile.acquisition["TestTimeSeries"].data

nwbfile_path = str(tmpdir / f"test_configure_defaults_{case_name}_data.nwb.{backend}")
nwbfile_path = str(tmpdir / f"test_configure_defaults_{case_name}_data.nwb")
with BACKEND_NWB_IO[backend](path=nwbfile_path, mode="w") as io:
io.write(nwbfile)

@@ -99,7 +99,7 @@ def test_simple_dynamic_table_override(tmpdir: Path, backend: Literal["hdf5", "zarr"]):

configure_backend(nwbfile=nwbfile, backend_configuration=backend_configuration)

nwbfile_path = str(tmpdir / f"test_configure_defaults_dynamic_table.nwb.{backend}")
nwbfile_path = str(tmpdir / f"test_configure_defaults_dynamic_table.nwb")
NWB_IO = BACKEND_NWB_IO[backend]
with NWB_IO(path=nwbfile_path, mode="w") as io:
io.write(nwbfile)
@@ -114,7 +114,7 @@ def test_dynamic_table_skip_zero_length_axis(
dataset_configuration = backend_configuration.dataset_configurations["acquisition/TestDynamicTable/TestColumn/data"]
configure_backend(nwbfile=nwbfile, backend_configuration=backend_configuration)

nwbfile_path = str(tmpdir / f"test_configure_defaults_dynamic_table.nwb.{backend}")
nwbfile_path = str(tmpdir / f"test_configure_defaults_dynamic_table.nwb")
NWB_IO = BACKEND_NWB_IO[backend]
with NWB_IO(path=nwbfile_path, mode="w") as io:
io.write(nwbfile)
2 changes: 1 addition & 1 deletion tests/test_on_data/ecephys/test_lfp.py
@@ -99,7 +99,7 @@ class TestConverter(NWBConverter):

npt.assert_array_equal(x=recording.get_traces(return_scaled=False), y=nwb_lfp_unscaled)
# This can only be tested if both gain and offset are present
if recording.has_scaled_traces():
if recording.has_scaleable_traces():
channel_conversion = nwb_lfp_electrical_series.channel_conversion
nwb_lfp_conversion_vector = (
channel_conversion[:]
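For reference, the scaling relation this block verifies, per the NWB ElectricalSeries definition (variable names reused from the test; the tail of the conversion-vector expression is not shown above):

    # volts = raw_data * conversion * channel_conversion[channel] + offset
    es = nwb_lfp_electrical_series
    scaled_lfp = nwb_lfp_unscaled * es.conversion * es.channel_conversion[:] + es.offset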
2 changes: 1 addition & 1 deletion tests/test_on_data/ecephys/test_raw_recordings.py
@@ -106,7 +106,7 @@ class TestConverter(NWBConverter):
# are specified, which occurs during check_recordings_equal when there is only one channel
if nwb_recording.get_channel_ids()[0] != nwb_recording.get_channel_ids()[-1]:
check_recordings_equal(RX1=recording, RX2=nwb_recording, return_scaled=False)
if recording.has_scaled_traces() and nwb_recording.has_scaled_traces():
if recording.has_scaleable_traces() and nwb_recording.has_scaleable_traces():
check_recordings_equal(RX1=recording, RX2=nwb_recording, return_scaled=True)

