From 299077e878619ea9e7af6d83ae1cf1a5278e354f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 28 Jun 2023 22:13:08 +0200 Subject: [PATCH 01/57] failing test --- src/spikeinterface/core/core_tools.py | 25 ++++ .../tests/test_binaryrecordingextractor.py | 114 +++++++++++++++++- .../core/tests/test_generate.py | 29 +---- 3 files changed, 139 insertions(+), 29 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 316d8f79a2..3a02b6f71c 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -922,3 +922,28 @@ def convert_bytes_to_str(byte_value: int) -> str: byte_value /= 1024 i += 1 return f"{byte_value:.2f} {suffixes[i]}" + + +def measure_memory_allocation(measure_in_process: bool = True) -> float: + """ + A local utility to measure memory allocation at a specific point in time. + Can measure either the process resident memory or system wide memory available + + Uses psutil package. + + Parameters + ---------- + measure_in_process : bool, True by default + Mesure memory allocation in the current process only, if false then measures at the system + level. + """ + import psutil + + if measure_in_process: + process = psutil.Process() + memory = process.memory_info().rss + else: + mem_info = psutil.virtual_memory() + memory = mem_info.total - mem_info.available + + return memory diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index 1d2c6e4c21..16001325ae 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -4,6 +4,8 @@ from spikeinterface.core import BinaryRecordingExtractor from spikeinterface.core.numpyextractors import NumpyRecording +from spikeinterface.core.core_tools import measure_memory_allocation +from spikeinterface.core.generate import GeneratorRecording if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "core" @@ -35,9 +37,10 @@ def test_BinaryRecordingExtractor(): def test_round_trip(tmp_path): num_channels = 10 num_samples = 50 - traces_list = [np.ones(shape=(num_samples, num_channels), dtype="int32")] + + traces = np.arange(num_channels * num_samples, dtype="int16").reshape(num_samples, num_channels) sampling_frequency = 30_000.0 - recording = NumpyRecording(traces_list=traces_list, sampling_frequency=sampling_frequency) + recording = NumpyRecording(traces_list=[traces], sampling_frequency=sampling_frequency) file_path = tmp_path / "test_BinaryRecordingExtractor.raw" dtype = recording.get_dtype() @@ -59,5 +62,112 @@ def test_round_trip(tmp_path): np.allclose(smaller_traces, binary_smaller_traces) +@pytest.fixture(scope="module") +def folder_with_binary_files(tmpdir_factory): + tmp_path = Path(tmpdir_factory.mktemp("spike_interface_test")) + folder = tmp_path / "test_binary_recording" + num_channels = 32 + sampling_frequency = 30_000.0 + dtype = "float32" + recording = GeneratorRecording( + durations=[3600], + sampling_frequency=sampling_frequency, + num_channels=num_channels, + dtype=dtype, + ) + dtype = recording.get_dtype() + recording.save(folder=folder, overwrite=True) + + return folder + + +def test_memory_effcienty(folder_with_binary_files): + folder = folder_with_binary_files + num_channels = 32 + sampling_frequency = 30_000.0 + dtype = "float32" + + file_paths = [folder / "traces_cached_seg0.raw"] + recorder_binary = BinaryRecordingExtractor( + 
num_chan=num_channels, + file_paths=file_paths, + sampling_frequency=sampling_frequency, + dtype=dtype, + ) + + memory_before_traces_bytes = measure_memory_allocation() + traces = recorder_binary.get_traces(start_frame=1000, end_frame=10_000) + memory_after_traces_bytes = measure_memory_allocation() + traces_size_bytes = traces.nbytes + + expected_memory_usage = memory_before_traces_bytes + traces_size_bytes + expected_memory_usage_GiB = expected_memory_usage / 1024**3 + memory_after_traces_bytes_GiB = memory_after_traces_bytes / 1024**3 + assert expected_memory_usage_GiB == pytest.approx(memory_after_traces_bytes_GiB, rel=0.1) + + +def measure_peak_memory_usage(): + """ + Measure the peak memory usage in bytes for the current process. + + The `resource.getrusage(resource.RUSAGE_SELF).ru_maxrss` command is used to get the peak memory usage. + The `ru_maxrss` attribute represents the maximum resident set size used (in kilobytes on Linux and bytes on MacOS), + which is the maximum memory used by the process since it was started. + + This function only works on Unix systems (including Linux and MacOS). + + Returns + ------- + int + Peak memory usage in bytes. + + Raises + ------ + NotImplementedError + If the function is called on a Windows system. + """ + + import sys + import resource + + if sys.platform == "win32": + raise NotImplementedError("Function cannot be used on Windows") + + mem_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + + # If ru_maxrss returns memory in kilobytes (like on Linux), convert to bytes + if hasattr(resource, "RLIMIT_AS"): + mem_usage = mem_usage * 1024 + + return mem_usage + + +def test_peak_memory_usage(folder_with_binary_files): + folder = folder_with_binary_files + num_channels = 32 + sampling_frequency = 30_000.0 + dtype = "float32" + + file_paths = [folder / "traces_cached_seg0.raw"] + recorder_binary = BinaryRecordingExtractor( + num_chan=num_channels, + file_paths=file_paths, + sampling_frequency=sampling_frequency, + dtype=dtype, + ) + + memory_before_traces_bytes = measure_memory_allocation() + traces = recorder_binary.get_traces(start_frame=1000, end_frame=10_000) + traces_size_bytes = traces.nbytes + + expected_memory_usage = memory_before_traces_bytes + traces_size_bytes + peak_memory_GiB = measure_peak_memory_usage() / 1024**3 + expected_memory_usage_GiB = expected_memory_usage / 1024**3 + assert expected_memory_usage_GiB == pytest.approx(peak_memory_GiB, rel=0.1) + + print("Expected memory usage: {:.2f} GiB".format(expected_memory_usage_GiB)) + print(f"Peak memory usage: {peak_memory_GiB:.2f} GiB") + + if __name__ == "__main__": test_BinaryRecordingExtractor() diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index 873105e115..45ed791ab3 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -1,44 +1,19 @@ import pytest -import psutil import numpy as np from spikeinterface.core.generate import GeneratorRecording, generate_lazy_recording -from spikeinterface.core.core_tools import convert_bytes_to_str +from spikeinterface.core.core_tools import convert_bytes_to_str, measure_memory_allocation mode_list = GeneratorRecording.available_modes -def measure_memory_allocation(measure_in_process: bool = True) -> float: - """ - A local utility to measure memory allocation at a specific point in time. - Can measure either the process resident memory or system wide memory available - - Uses psutil package. 
- - Parameters - ---------- - measure_in_process : bool, True by default - Mesure memory allocation in the current process only, if false then measures at the system - level. - """ - - if measure_in_process: - process = psutil.Process() - memory = process.memory_info().rss - else: - mem_info = psutil.virtual_memory() - memory = mem_info.total - mem_info.available - - return memory - - @pytest.mark.parametrize("mode", mode_list) def test_lazy_random_recording(mode): # Test that get_traces does not consume more memory than allocated. bytes_to_MiB_factor = 1024**2 - relative_tolerance = 0.05 # relative tolerance of 5 per cent + relative_tolerance = 0.01 # relative tolerance of 5 per cent sampling_frequency = 30000 # Hz durations = [2.0] From 628b0a884a25233bed4baf56418dd4d3a3d0adc1 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 29 Jun 2023 10:54:11 +0200 Subject: [PATCH 02/57] passing tests --- .../core/binaryrecordingextractor.py | 68 +++++++++++++------ .../tests/test_binaryrecordingextractor.py | 68 +++++++++++++++---- 2 files changed, 101 insertions(+), 35 deletions(-) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index c04a1c6ec7..deadcc2624 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -155,25 +155,17 @@ def get_binary_description(self): class BinaryRecordingSegment(BaseRecordingSegment): - def __init__(self, datfile, sampling_frequency, t_start, num_chan, dtype, time_axis, file_offset): + def __init__(self, file_path, sampling_frequency, t_start, num_chan, dtype, time_axis, file_offset): BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start) self.num_chan = num_chan self.dtype = np.dtype(dtype) self.file_offset = file_offset self.time_axis = time_axis - self.datfile = datfile - self.file = open(self.datfile, "r") - self.num_samples = (Path(datfile).stat().st_size - file_offset) // (num_chan * np.dtype(dtype).itemsize) - if self.time_axis == 0: - self.shape = (self.num_samples, self.num_chan) - else: - self.shape = (self.num_chan, self.num_samples) - - byte_offset = self.file_offset - dtype_size_bytes = self.dtype.itemsize - data_size_bytes = dtype_size_bytes * self.num_samples * self.num_chan - self.memmap_offset, self.array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY) - self.memmap_length = data_size_bytes + self.array_offset + self.file_path = file_path + self.file = open(self.file_path, "rb") + self.elements_per_sample = self.num_chan * self.dtype.itemsize + self.data_size_in_bytes = Path(file_path).stat().st_size - file_offset + self.num_samples = self.data_size_in_bytes // self.elements_per_sample def get_num_samples(self) -> int: """Returns the number of samples in this signal block @@ -189,23 +181,55 @@ def get_traces( end_frame: Union[int, None] = None, channel_indices: Union[List, None] = None, ) -> np.ndarray: - length = self.memmap_length - memmap_offset = self.memmap_offset + if start_frame is None: + start_frame = 0 + + if end_frame is None: + end_frame = self.get_num_samples() + + if end_frame > self.get_num_samples(): + raise ValueError(f"end_frame {end_frame} is larger than the number of samples {self.get_num_samples()}") + + dtype_size_bytes = np.dtype(self.dtype).itemsize + elements_per_sample = self.num_chan * dtype_size_bytes + + # Calculate byte offsets for start and end frames + start_byte = self.file_offset + start_frame * elements_per_sample + 
end_byte = self.file_offset + end_frame * elements_per_sample + + # Calculate the length of the data chunk to load into memory + length = end_byte - start_byte + + # The mmap offset must be a multiple of mmap.ALLOCATIONGRANULARITY + memmap_offset, start_offset = divmod(start_byte, mmap.ALLOCATIONGRANULARITY) + memmap_offset *= mmap.ALLOCATIONGRANULARITY + + # Adjust the length so it includes the extra data from rounding down the memmap offset to a multiple of ALLOCATIONGRANULARITY + length += start_offset + + # Create the mmap object memmap_obj = mmap.mmap(self.file.fileno(), length=length, access=mmap.ACCESS_READ, offset=memmap_offset) - array = np.ndarray.__new__( - np.ndarray, - shape=self.shape, + # Create a numpy array using the mmap object as the buffer + # Note that the shape must be recalculated based on the new data chunk + if self.time_axis == 0: + shape = ((end_frame - start_frame), self.num_chan) + else: + shape = (self.num_chan, (end_frame - start_frame)) + + array = np.ndarray( + shape=shape, dtype=self.dtype, buffer=memmap_obj, - order="C", - offset=self.array_offset, + offset=start_offset, ) if self.time_axis == 1: array = array.T - traces = array[start_frame:end_frame] + # Now the entire array should correspond to the data between start_frame and end_frame, so we can use it directly + traces = array + if channel_indices is not None: traces = traces[:, channel_indices] diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index 16001325ae..ed0b2922e7 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -52,10 +52,12 @@ def test_round_trip(tmp_path): file_paths=file_path, sampling_frequency=sampling_frequency, num_chan=num_chan, dtype=dtype ) + # Test for full traces assert np.allclose(recording.get_traces(), binary_recorder.get_traces()) - start_frame = 200 - end_frame = 500 + # Ttest for a sub-set of the traces + start_frame = 20 + end_frame = 40 smaller_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame) binary_smaller_traces = binary_recorder.get_traces(start_frame=start_frame, end_frame=end_frame) @@ -70,7 +72,7 @@ def folder_with_binary_files(tmpdir_factory): sampling_frequency = 30_000.0 dtype = "float32" recording = GeneratorRecording( - durations=[3600], + durations=[1.0], sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype, @@ -81,14 +83,53 @@ def folder_with_binary_files(tmpdir_factory): return folder +def test_sequential_reading_of_small_traces(folder_with_binary_files): + folder = folder_with_binary_files + num_channels = 32 + sampling_frequency = 30_000.0 + dtype = "float32" + + file_paths = [folder / "traces_cached_seg0.raw"] + recording = BinaryRecordingExtractor( + num_chan=num_channels, + file_paths=file_paths, + sampling_frequency=sampling_frequency, + dtype=dtype, + ) + + full_traces = recording.get_traces() + + # Test for a sub-set of the traces + start_frame = 10 + end_frame = 15 + small_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame) + expected_traces = full_traces[start_frame:end_frame, :] + assert np.allclose(small_traces, expected_traces) + + # Test for a sub-set of the traces + start_frame = 1000 + end_frame = 1100 + small_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame) + expected_traces = full_traces[start_frame:end_frame, :] + assert np.allclose(small_traces, 
expected_traces) + + # Test for a sub-set of the traces + start_frame = 10_000 + end_frame = 11_000 + small_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame) + expected_traces = full_traces[start_frame:end_frame, :] + assert np.allclose(small_traces, expected_traces) + + def test_memory_effcienty(folder_with_binary_files): + "This test that memory is freed afte reading the traces" folder = folder_with_binary_files num_channels = 32 sampling_frequency = 30_000.0 dtype = "float32" file_paths = [folder / "traces_cached_seg0.raw"] - recorder_binary = BinaryRecordingExtractor( + recording = BinaryRecordingExtractor( num_chan=num_channels, file_paths=file_paths, sampling_frequency=sampling_frequency, @@ -96,14 +137,14 @@ def test_memory_effcienty(folder_with_binary_files): ) memory_before_traces_bytes = measure_memory_allocation() - traces = recorder_binary.get_traces(start_frame=1000, end_frame=10_000) + traces = recording.get_traces(start_frame=1000, end_frame=10_000) memory_after_traces_bytes = measure_memory_allocation() traces_size_bytes = traces.nbytes expected_memory_usage = memory_before_traces_bytes + traces_size_bytes expected_memory_usage_GiB = expected_memory_usage / 1024**3 memory_after_traces_bytes_GiB = memory_after_traces_bytes / 1024**3 - assert expected_memory_usage_GiB == pytest.approx(memory_after_traces_bytes_GiB, rel=0.1) + assert memory_after_traces_bytes_GiB == pytest.approx(expected_memory_usage_GiB, rel=0.1) def measure_peak_memory_usage(): @@ -143,13 +184,14 @@ def measure_peak_memory_usage(): def test_peak_memory_usage(folder_with_binary_files): + "This tests that there are no spikes in memory usage when reading traces." folder = folder_with_binary_files num_channels = 32 sampling_frequency = 30_000.0 dtype = "float32" file_paths = [folder / "traces_cached_seg0.raw"] - recorder_binary = BinaryRecordingExtractor( + recording = BinaryRecordingExtractor( num_chan=num_channels, file_paths=file_paths, sampling_frequency=sampling_frequency, @@ -157,16 +199,16 @@ def test_peak_memory_usage(folder_with_binary_files): ) memory_before_traces_bytes = measure_memory_allocation() - traces = recorder_binary.get_traces(start_frame=1000, end_frame=10_000) + traces = recording.get_traces(start_frame=1000, end_frame=2000) traces_size_bytes = traces.nbytes expected_memory_usage = memory_before_traces_bytes + traces_size_bytes - peak_memory_GiB = measure_peak_memory_usage() / 1024**3 - expected_memory_usage_GiB = expected_memory_usage / 1024**3 - assert expected_memory_usage_GiB == pytest.approx(peak_memory_GiB, rel=0.1) + peak_memory_MiB = measure_peak_memory_usage() / 1024**2 + expected_memory_usage_MiB = expected_memory_usage / 1024**2 + assert expected_memory_usage_MiB == pytest.approx(peak_memory_MiB, rel=0.1) - print("Expected memory usage: {:.2f} GiB".format(expected_memory_usage_GiB)) - print(f"Peak memory usage: {peak_memory_GiB:.2f} GiB") + print("Expected memory usage: {:.2f} MiB".format(expected_memory_usage_MiB)) + print(f"Peak memory usage: {peak_memory_MiB:.2f} MiB") if __name__ == "__main__": From 59e82a13b590809abfe8803111e4e584e5f155e2 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 4 Jul 2023 20:49:52 +0200 Subject: [PATCH 03/57] merging --- .../core/binaryrecordingextractor.py | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 5b588a31a2..2cbc52f4b4 100644 --- 
a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -175,9 +175,9 @@ def __init__(self, file_path, sampling_frequency, t_start, num_channels, dtype, self.time_axis = time_axis self.file_path = file_path self.file = open(self.file_path, "rb") - self.elements_per_sample = self.num_chan * self.dtype.itemsize + self.bytes_per_sample = self.num_channels * self.dtype.itemsize self.data_size_in_bytes = Path(file_path).stat().st_size - file_offset - self.num_samples = self.data_size_in_bytes // self.elements_per_sample + self.num_samples = self.data_size_in_bytes // self.bytes_per_sample def get_num_samples(self) -> int: """Returns the number of samples in this signal block @@ -202,12 +202,9 @@ def get_traces( if end_frame > self.get_num_samples(): raise ValueError(f"end_frame {end_frame} is larger than the number of samples {self.get_num_samples()}") - dtype_size_bytes = np.dtype(self.dtype).itemsize - elements_per_sample = self.num_chan * dtype_size_bytes - # Calculate byte offsets for start and end frames - start_byte = self.file_offset + start_frame * elements_per_sample - end_byte = self.file_offset + end_frame * elements_per_sample + start_byte = self.file_offset + start_frame * self.bytes_per_sample + end_byte = self.file_offset + end_frame * self.bytes_per_sample # Calculate the length of the data chunk to load into memory length = end_byte - start_byte @@ -216,7 +213,8 @@ def get_traces( memmap_offset, start_offset = divmod(start_byte, mmap.ALLOCATIONGRANULARITY) memmap_offset *= mmap.ALLOCATIONGRANULARITY - # Adjust the length so it includes the extra data from rounding down the memmap offset to a multiple of ALLOCATIONGRANULARITY + # Adjust the length so it includes the extra data from rounding down + # the memmap offset to a multiple of ALLOCATIONGRANULARITY length += start_offset # Create the mmap object @@ -225,11 +223,12 @@ def get_traces( # Create a numpy array using the mmap object as the buffer # Note that the shape must be recalculated based on the new data chunk if self.time_axis == 0: - shape = ((end_frame - start_frame), self.num_chan) + shape = ((end_frame - start_frame), self.num_channels) else: - shape = (self.num_chan, (end_frame - start_frame)) + shape = (self.num_channels, (end_frame - start_frame)) - array = np.ndarray( + # Now the entire array should correspond to the data between start_frame and end_frame, so we can use it directly + traces = np.ndarray( shape=shape, dtype=self.dtype, buffer=memmap_obj, @@ -237,10 +236,7 @@ def get_traces( ) if self.time_axis == 1: - array = array.T - - # Now the entire array should correspond to the data between start_frame and end_frame, so we can use it directly - traces = array + traces = traces.T if channel_indices is not None: traces = traces[:, channel_indices] From 4d4c55e140014ea79947910883f3ef360fe0f723 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 4 Jul 2023 21:03:53 +0200 Subject: [PATCH 04/57] refactor tests --- .../tests/test_binaryrecordingextractor.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index 16cd01141e..a73cc8d1f4 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -1,6 +1,8 @@ import pytest import numpy as np from pathlib import Path +import sys +import resource from 
spikeinterface.core import BinaryRecordingExtractor from spikeinterface.core.numpyextractors import NumpyRecording @@ -91,6 +93,7 @@ def folder_with_binary_files(tmpdir_factory): def test_sequential_reading_of_small_traces(folder_with_binary_files): + # Test that memmap is readed correctly when pointing to specific frames folder = folder_with_binary_files num_channels = 32 sampling_frequency = 30_000.0 @@ -151,6 +154,15 @@ def test_memory_effcienty(folder_with_binary_files): expected_memory_usage = memory_before_traces_bytes + traces_size_bytes expected_memory_usage_GiB = expected_memory_usage / 1024**3 memory_after_traces_bytes_GiB = memory_after_traces_bytes / 1024**3 + + ratio = memory_after_traces_bytes_GiB / expected_memory_usage_GiB + + assertion_msg = ( + f"Peak memory {memory_after_traces_bytes_GiB} GiB usage is {ratio:.2f} times" + f"the expected memory usage of {expected_memory_usage_GiB} GiB." + ) + assert ratio <= 1.05, assertion_msg + assert memory_after_traces_bytes_GiB == pytest.approx(expected_memory_usage_GiB, rel=0.1) @@ -175,9 +187,6 @@ def measure_peak_memory_usage(): If the function is called on a Windows system. """ - import sys - import resource - if sys.platform == "win32": raise NotImplementedError("Function cannot be used on Windows") @@ -190,6 +199,7 @@ def measure_peak_memory_usage(): return mem_usage +@pytest.mark.skipif(sys.platform == "win32", reason="Don't know how to calculate peak memory on widnows") def test_peak_memory_usage(folder_with_binary_files): "This tests that there are no spikes in memory usage when reading traces." folder = folder_with_binary_files @@ -212,10 +222,12 @@ def test_peak_memory_usage(folder_with_binary_files): expected_memory_usage = memory_before_traces_bytes + traces_size_bytes peak_memory_MiB = measure_peak_memory_usage() / 1024**2 expected_memory_usage_MiB = expected_memory_usage / 1024**2 - assert expected_memory_usage_MiB == pytest.approx(peak_memory_MiB, rel=0.1) - - print("Expected memory usage: {:.2f} MiB".format(expected_memory_usage_MiB)) - print(f"Peak memory usage: {peak_memory_MiB:.2f} MiB") + ratio = peak_memory_MiB / expected_memory_usage_MiB + assertion_msg = ( + f"Peak memory {peak_memory_MiB} MiB usage is {ratio:.2f} times" + f"the expected memory usage of {expected_memory_usage_MiB} MiB." 
+ ) + assert ratio <= 1.05, assertion_msg if __name__ == "__main__": From 2ea7f1bde6ec38c43f1a747ca43f4765a49cb97a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 4 Jul 2023 21:06:44 +0200 Subject: [PATCH 05/57] window import --- .../core/tests/test_binaryrecordingextractor.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index a73cc8d1f4..a62974c833 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -2,7 +2,6 @@ import numpy as np from pathlib import Path import sys -import resource from spikeinterface.core import BinaryRecordingExtractor from spikeinterface.core.numpyextractors import NumpyRecording @@ -188,7 +187,9 @@ def measure_peak_memory_usage(): """ if sys.platform == "win32": - raise NotImplementedError("Function cannot be used on Windows") + raise NotImplementedError("Resource module not available on Windows") + + import resource mem_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss @@ -199,7 +200,7 @@ def measure_peak_memory_usage(): return mem_usage -@pytest.mark.skipif(sys.platform == "win32", reason="Don't know how to calculate peak memory on widnows") +@pytest.mark.skipif(sys.platform == "win32", reason="resource module not available on Windows") def test_peak_memory_usage(folder_with_binary_files): "This tests that there are no spikes in memory usage when reading traces." folder = folder_with_binary_files From 80a41525b1bb583df8a05a34c50baaab1529f63d Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 11 Sep 2023 11:56:12 +0200 Subject: [PATCH 06/57] update generator recording --- .../core/tests/test_binaryrecordingextractor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index a62974c833..6f63dad576 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -6,7 +6,7 @@ from spikeinterface.core import BinaryRecordingExtractor from spikeinterface.core.numpyextractors import NumpyRecording from spikeinterface.core.core_tools import measure_memory_allocation -from spikeinterface.core.generate import GeneratorRecording +from spikeinterface.core.generate import NoiseGeneratorRecording if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "core" @@ -79,7 +79,7 @@ def folder_with_binary_files(tmpdir_factory): num_channels = 32 sampling_frequency = 30_000.0 dtype = "float32" - recording = GeneratorRecording( + recording = NoiseGeneratorRecording( durations=[1.0], sampling_frequency=sampling_frequency, num_channels=num_channels, From 25b125ece7c465f74e0cd8e100278604d8b9d124 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 11 Sep 2023 11:58:46 +0200 Subject: [PATCH 07/57] add missing import --- src/spikeinterface/core/tests/test_generate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index abcbd9c4e2..2b4e65980f 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -18,6 +18,7 @@ from spikeinterface.core.core_tools import convert_bytes_to_str, 
measure_memory_allocation +from spikeinterface.core.testing import check_recordings_equal strategy_list = ["tile_pregenerated", "on_the_fly"] From 757e939b8c3e45d12c952d073c9e22215453b9cc Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 14 Dec 2023 15:33:25 +0100 Subject: [PATCH 08/57] maybe tests --- .../tests/test_binaryrecordingextractor.py | 198 +++++++++--------- 1 file changed, 99 insertions(+), 99 deletions(-) diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index 6f63dad576..03b1927b33 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -130,105 +130,105 @@ def test_sequential_reading_of_small_traces(folder_with_binary_files): assert np.allclose(small_traces, expected_traces) -def test_memory_effcienty(folder_with_binary_files): - "This test that memory is freed afte reading the traces" - folder = folder_with_binary_files - num_channels = 32 - sampling_frequency = 30_000.0 - dtype = "float32" - - file_paths = [folder / "traces_cached_seg0.raw"] - recording = BinaryRecordingExtractor( - num_chan=num_channels, - file_paths=file_paths, - sampling_frequency=sampling_frequency, - dtype=dtype, - ) - - memory_before_traces_bytes = measure_memory_allocation() - traces = recording.get_traces(start_frame=1000, end_frame=10_000) - memory_after_traces_bytes = measure_memory_allocation() - traces_size_bytes = traces.nbytes - - expected_memory_usage = memory_before_traces_bytes + traces_size_bytes - expected_memory_usage_GiB = expected_memory_usage / 1024**3 - memory_after_traces_bytes_GiB = memory_after_traces_bytes / 1024**3 - - ratio = memory_after_traces_bytes_GiB / expected_memory_usage_GiB - - assertion_msg = ( - f"Peak memory {memory_after_traces_bytes_GiB} GiB usage is {ratio:.2f} times" - f"the expected memory usage of {expected_memory_usage_GiB} GiB." - ) - assert ratio <= 1.05, assertion_msg - - assert memory_after_traces_bytes_GiB == pytest.approx(expected_memory_usage_GiB, rel=0.1) - - -def measure_peak_memory_usage(): - """ - Measure the peak memory usage in bytes for the current process. - - The `resource.getrusage(resource.RUSAGE_SELF).ru_maxrss` command is used to get the peak memory usage. - The `ru_maxrss` attribute represents the maximum resident set size used (in kilobytes on Linux and bytes on MacOS), - which is the maximum memory used by the process since it was started. - - This function only works on Unix systems (including Linux and MacOS). - - Returns - ------- - int - Peak memory usage in bytes. - - Raises - ------ - NotImplementedError - If the function is called on a Windows system. - """ - - if sys.platform == "win32": - raise NotImplementedError("Resource module not available on Windows") - - import resource - - mem_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - - # If ru_maxrss returns memory in kilobytes (like on Linux), convert to bytes - if hasattr(resource, "RLIMIT_AS"): - mem_usage = mem_usage * 1024 - - return mem_usage - - -@pytest.mark.skipif(sys.platform == "win32", reason="resource module not available on Windows") -def test_peak_memory_usage(folder_with_binary_files): - "This tests that there are no spikes in memory usage when reading traces." 
- folder = folder_with_binary_files - num_channels = 32 - sampling_frequency = 30_000.0 - dtype = "float32" - - file_paths = [folder / "traces_cached_seg0.raw"] - recording = BinaryRecordingExtractor( - num_chan=num_channels, - file_paths=file_paths, - sampling_frequency=sampling_frequency, - dtype=dtype, - ) - - memory_before_traces_bytes = measure_memory_allocation() - traces = recording.get_traces(start_frame=1000, end_frame=2000) - traces_size_bytes = traces.nbytes - - expected_memory_usage = memory_before_traces_bytes + traces_size_bytes - peak_memory_MiB = measure_peak_memory_usage() / 1024**2 - expected_memory_usage_MiB = expected_memory_usage / 1024**2 - ratio = peak_memory_MiB / expected_memory_usage_MiB - assertion_msg = ( - f"Peak memory {peak_memory_MiB} MiB usage is {ratio:.2f} times" - f"the expected memory usage of {expected_memory_usage_MiB} MiB." - ) - assert ratio <= 1.05, assertion_msg +# def test_memory_effcienty(folder_with_binary_files): +# "This test that memory is freed afte reading the traces" +# folder = folder_with_binary_files +# num_channels = 32 +# sampling_frequency = 30_000.0 +# dtype = "float32" + +# file_paths = [folder / "traces_cached_seg0.raw"] +# recording = BinaryRecordingExtractor( +# num_chan=num_channels, +# file_paths=file_paths, +# sampling_frequency=sampling_frequency, +# dtype=dtype, +# ) + +# memory_before_traces_bytes = measure_memory_allocation() +# traces = recording.get_traces(start_frame=1000, end_frame=10_000) +# memory_after_traces_bytes = measure_memory_allocation() +# traces_size_bytes = traces.nbytes + +# expected_memory_usage = memory_before_traces_bytes + traces_size_bytes +# expected_memory_usage_GiB = expected_memory_usage / 1024**3 +# memory_after_traces_bytes_GiB = memory_after_traces_bytes / 1024**3 + +# ratio = memory_after_traces_bytes_GiB / expected_memory_usage_GiB + +# assertion_msg = ( +# f"Peak memory {memory_after_traces_bytes_GiB} GiB usage is {ratio:.2f} times" +# f"the expected memory usage of {expected_memory_usage_GiB} GiB." +# ) +# assert ratio <= 1.05, assertion_msg + +# assert memory_after_traces_bytes_GiB == pytest.approx(expected_memory_usage_GiB, rel=0.1) + + +# def measure_peak_memory_usage(): +# """ +# Measure the peak memory usage in bytes for the current process. + +# The `resource.getrusage(resource.RUSAGE_SELF).ru_maxrss` command is used to get the peak memory usage. +# The `ru_maxrss` attribute represents the maximum resident set size used (in kilobytes on Linux and bytes on MacOS), +# which is the maximum memory used by the process since it was started. + +# This function only works on Unix systems (including Linux and MacOS). + +# Returns +# ------- +# int +# Peak memory usage in bytes. + +# Raises +# ------ +# NotImplementedError +# If the function is called on a Windows system. +# """ + +# if sys.platform == "win32": +# raise NotImplementedError("Resource module not available on Windows") + +# import resource + +# mem_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + +# # If ru_maxrss returns memory in kilobytes (like on Linux), convert to bytes +# if hasattr(resource, "RLIMIT_AS"): +# mem_usage = mem_usage * 1024 + +# return mem_usage + + +# @pytest.mark.skipif(sys.platform == "win32", reason="resource module not available on Windows") +# def test_peak_memory_usage(folder_with_binary_files): +# "This tests that there are no spikes in memory usage when reading traces." 
+# folder = folder_with_binary_files +# num_channels = 32 +# sampling_frequency = 30_000.0 +# dtype = "float32" + +# file_paths = [folder / "traces_cached_seg0.raw"] +# recording = BinaryRecordingExtractor( +# num_chan=num_channels, +# file_paths=file_paths, +# sampling_frequency=sampling_frequency, +# dtype=dtype, +# ) + +# memory_before_traces_bytes = measure_memory_allocation() +# traces = recording.get_traces(start_frame=1000, end_frame=2000) +# traces_size_bytes = traces.nbytes + +# expected_memory_usage = memory_before_traces_bytes + traces_size_bytes +# peak_memory_MiB = measure_peak_memory_usage() / 1024**2 +# expected_memory_usage_MiB = expected_memory_usage / 1024**2 +# ratio = peak_memory_MiB / expected_memory_usage_MiB +# assertion_msg = ( +# f"Peak memory {peak_memory_MiB} MiB usage is {ratio:.2f} times" +# f"the expected memory usage of {expected_memory_usage_MiB} MiB." +# ) +# assert ratio <= 1.05, assertion_msg if __name__ == "__main__": From ca5a50d84807b1615c1ebc5f1363efb40f94bd4c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 29 Apr 2024 17:54:22 -0600 Subject: [PATCH 09/57] add name as an extractor attribute --- src/spikeinterface/core/base.py | 2 ++ src/spikeinterface/core/baserecording.py | 2 +- src/spikeinterface/core/basesorting.py | 2 +- src/spikeinterface/core/generate.py | 5 +++++ 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index d9c39e6ebd..526d851383 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -50,6 +50,8 @@ class BaseExtractor: installed = True installation_mesg = "" + name = None # The name of the extractor for display purposes + def __init__(self, main_ids: Sequence) -> None: # store init kwargs for nested serialisation self._kwargs = {} diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 8ea8e04246..e8c3a046be 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -48,7 +48,7 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype): self.annotate(is_filtered=False) def __repr__(self): - extractor_name = self.__class__.__name__ + extractor_name = self.__class__.__name__ if self.name is None else self.name num_segments = self.get_num_segments() num_channels = self.get_num_channels() sf_khz = self.get_sampling_frequency() / 1000.0 diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 4e3551e290..a35eb42eb0 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -30,7 +30,7 @@ def __init__(self, sampling_frequency: float, unit_ids: List): self._cached_spike_trains = {} def __repr__(self): - clsname = self.__class__.__name__ + clsname = self.__class__.__name__ if self.name is None else self.name nseg = self.get_num_segments() nunits = self.get_num_units() sf_khz = self.get_sampling_frequency() / 1000.0 diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index ec76fcbaa9..91c61ac78e 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -81,6 +81,8 @@ def generate_recording( probe.set_device_channel_indices(np.arange(num_channels)) recording.set_probe(probe, in_place=True) + recording.name = "SyntethicRecording" + return recording @@ -2130,4 +2132,7 @@ def generate_ground_truth_recording( recording.set_channel_gains(1.0) 
recording.set_channel_offsets(0.0) + recording.name = "GroundTruthRecording" + sorting.name = "GroundTruthSorting" + return recording, sorting From 64a8432aaab985a7ae89addaacd46114f26f55e7 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 30 Apr 2024 08:36:23 -0600 Subject: [PATCH 10/57] Update src/spikeinterface/core/generate.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 91c61ac78e..267e06d07a 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -81,7 +81,7 @@ def generate_recording( probe.set_device_channel_indices(np.arange(num_channels)) recording.set_probe(probe, in_place=True) - recording.name = "SyntethicRecording" + recording.name = "SyntheticRecording" return recording From 091042052165e5292b70c934b6da190bcb413b21 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 6 Jun 2024 12:05:14 -0600 Subject: [PATCH 11/57] remove unused tests --- .../tests/test_binaryrecordingextractor.py | 101 ------------------ 1 file changed, 101 deletions(-) diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index 30435df820..b0fab7a579 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -125,106 +125,5 @@ def test_sequential_reading_of_small_traces(folder_with_binary_files): assert np.allclose(small_traces, expected_traces) -# def test_memory_effcienty(folder_with_binary_files): -# "This test that memory is freed afte reading the traces" -# folder = folder_with_binary_files -# num_channels = 32 -# sampling_frequency = 30_000.0 -# dtype = "float32" - -# file_paths = [folder / "traces_cached_seg0.raw"] -# recording = BinaryRecordingExtractor( -# num_chan=num_channels, -# file_paths=file_paths, -# sampling_frequency=sampling_frequency, -# dtype=dtype, -# ) - -# memory_before_traces_bytes = measure_memory_allocation() -# traces = recording.get_traces(start_frame=1000, end_frame=10_000) -# memory_after_traces_bytes = measure_memory_allocation() -# traces_size_bytes = traces.nbytes - -# expected_memory_usage = memory_before_traces_bytes + traces_size_bytes -# expected_memory_usage_GiB = expected_memory_usage / 1024**3 -# memory_after_traces_bytes_GiB = memory_after_traces_bytes / 1024**3 - -# ratio = memory_after_traces_bytes_GiB / expected_memory_usage_GiB - -# assertion_msg = ( -# f"Peak memory {memory_after_traces_bytes_GiB} GiB usage is {ratio:.2f} times" -# f"the expected memory usage of {expected_memory_usage_GiB} GiB." -# ) -# assert ratio <= 1.05, assertion_msg - -# assert memory_after_traces_bytes_GiB == pytest.approx(expected_memory_usage_GiB, rel=0.1) - - -# def measure_peak_memory_usage(): -# """ -# Measure the peak memory usage in bytes for the current process. - -# The `resource.getrusage(resource.RUSAGE_SELF).ru_maxrss` command is used to get the peak memory usage. -# The `ru_maxrss` attribute represents the maximum resident set size used (in kilobytes on Linux and bytes on MacOS), -# which is the maximum memory used by the process since it was started. - -# This function only works on Unix systems (including Linux and MacOS). - -# Returns -# ------- -# int -# Peak memory usage in bytes. 
- -# Raises -# ------ -# NotImplementedError -# If the function is called on a Windows system. -# """ - -# if sys.platform == "win32": -# raise NotImplementedError("Resource module not available on Windows") - -# import resource - -# mem_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - -# # If ru_maxrss returns memory in kilobytes (like on Linux), convert to bytes -# if hasattr(resource, "RLIMIT_AS"): -# mem_usage = mem_usage * 1024 - -# return mem_usage - - -# @pytest.mark.skipif(sys.platform == "win32", reason="resource module not available on Windows") -# def test_peak_memory_usage(folder_with_binary_files): -# "This tests that there are no spikes in memory usage when reading traces." -# folder = folder_with_binary_files -# num_channels = 32 -# sampling_frequency = 30_000.0 -# dtype = "float32" - -# file_paths = [folder / "traces_cached_seg0.raw"] -# recording = BinaryRecordingExtractor( -# num_chan=num_channels, -# file_paths=file_paths, -# sampling_frequency=sampling_frequency, -# dtype=dtype, -# ) - -# memory_before_traces_bytes = measure_memory_allocation() -# traces = recording.get_traces(start_frame=1000, end_frame=2000) -# traces_size_bytes = traces.nbytes - -# expected_memory_usage = memory_before_traces_bytes + traces_size_bytes -# peak_memory_MiB = measure_peak_memory_usage() / 1024**2 -# expected_memory_usage_MiB = expected_memory_usage / 1024**2 -# ratio = peak_memory_MiB / expected_memory_usage_MiB -# assertion_msg = ( -# f"Peak memory {peak_memory_MiB} MiB usage is {ratio:.2f} times" -# f"the expected memory usage of {expected_memory_usage_MiB} MiB." -# ) -# assert ratio <= 1.05, assertion_msg - - if __name__ == "__main__": test_BinaryRecordingExtractor() From 3784c3898ee521b057bc3962fb66e2eecc561cfe Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 19 Jun 2024 15:08:26 -0600 Subject: [PATCH 12/57] use the value instead of book --- src/spikeinterface/core/base.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 6fbc5ac289..05f59e5349 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -56,7 +56,7 @@ def __init__(self, main_ids: Sequence) -> None: # "main_ids" will either be channel_ids or units_ids # They are used for properties - self._main_ids = np.array(main_ids) + self._main_ids = np.asarray(main_ids) if len(self._main_ids) > 0: assert ( self._main_ids.dtype.kind in "uiSU" @@ -128,8 +128,18 @@ def ids_to_indices( indices = np.arange(len(self._main_ids)) else: assert isinstance(ids, (list, np.ndarray, tuple)), "'ids' must be a list, np.ndarray or tuple" + + non_existent_ids = [id for id in ids if id not in self._main_ids] + if non_existent_ids: + error_msg = ( + f"IDs {non_existent_ids} are not channel ids of the extractor. 
\n" + f"Available ids are {self._main_ids} with dtype {self._main_ids.dtype}" + ) + raise ValueError(error_msg) + _main_ids = self._main_ids.tolist() - indices = np.array([_main_ids.index(id) for id in ids], dtype=int) + indices = np.array([_main_ids.index(id) for id in ids], dtype=np.int) + if prefer_slice: if np.all(np.diff(indices) == 1): indices = slice(indices[0], indices[-1] + 1) From e2b1a3b734245e696b8e83d73354a62161f8fcff Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 19 Jun 2024 15:22:41 -0600 Subject: [PATCH 13/57] int went flying --- src/spikeinterface/core/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 05f59e5349..a7b250690f 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -138,7 +138,7 @@ def ids_to_indices( raise ValueError(error_msg) _main_ids = self._main_ids.tolist() - indices = np.array([_main_ids.index(id) for id in ids], dtype=np.int) + indices = np.array([_main_ids.index(id) for id in ids], dtype=int) if prefer_slice: if np.all(np.diff(indices) == 1): From 16d40899d3d8dbc925d00cd034da8fb93af47946 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 1 Jul 2024 20:41:20 +0100 Subject: [PATCH 14/57] Fix t_starts not propagated to save memory. --- src/spikeinterface/core/baserecording.py | 4 ++-- src/spikeinterface/core/numpyextractors.py | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index aab7577b31..bb96fb06ca 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -545,11 +545,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): if kwargs.get("sharedmem", True): from .numpyextractors import SharedMemoryRecording - cached = SharedMemoryRecording.from_recording(self, **job_kwargs) + cached = SharedMemoryRecording.from_recording(self, t_starts=t_starts, **job_kwargs) else: from spikeinterface.core import NumpyRecording - cached = NumpyRecording.from_recording(self, **job_kwargs) + cached = NumpyRecording.from_recording(self, t_starts=t_starts, **job_kwargs) elif format == "zarr": from .zarrextractors import ZarrRecordingExtractor diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 0ba1c05417..b60ecb52a6 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -85,7 +85,7 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N } @staticmethod - def from_recording(source_recording, **job_kwargs): + def from_recording(source_recording, t_starts=None, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs) if shms[0] is not None: # if the computation was done in parallel then traces_list is shared array @@ -95,13 +95,14 @@ def from_recording(source_recording, **job_kwargs): for shm in shms: shm.close() shm.unlink() - # TODO later : propagte t_starts ? 
+ recording = NumpyRecording( traces_list, source_recording.get_sampling_frequency(), - t_starts=None, + t_starts=t_starts, channel_ids=source_recording.channel_ids, ) + return recording class NumpyRecordingSegment(BaseRecordingSegment): @@ -211,18 +212,16 @@ def __del__(self): shm.unlink() @staticmethod - def from_recording(source_recording, **job_kwargs): + def from_recording(source_recording, t_starts=None, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs) - # TODO later : propagte t_starts ? - recording = SharedMemoryRecording( shm_names=[shm.name for shm in shms], shape_list=[traces.shape for traces in traces_list], dtype=source_recording.dtype, sampling_frequency=source_recording.sampling_frequency, channel_ids=source_recording.channel_ids, - t_starts=None, + t_starts=t_starts, main_shm_owner=True, ) From 3e9652b7c64cfe112ec549bd340d55ad95d97720 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Tue, 2 Jul 2024 09:43:10 +0100 Subject: [PATCH 15/57] force tests From c95c4357126eab2619cdffe5826409689638df0c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 2 Jul 2024 07:12:51 -0600 Subject: [PATCH 16/57] Update src/spikeinterface/core/base.py --- src/spikeinterface/core/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index a7b250690f..dcb80d2a67 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -56,7 +56,7 @@ def __init__(self, main_ids: Sequence) -> None: # "main_ids" will either be channel_ids or units_ids # They are used for properties - self._main_ids = np.asarray(main_ids) + self._main_ids = np.array(main_ids) if len(self._main_ids) > 0: assert ( self._main_ids.dtype.kind in "uiSU" From f46f9c9391bdf3a11a0fef953c2eb9d097dd15ab Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 3 Jul 2024 11:55:08 +0200 Subject: [PATCH 17/57] Fix serializability of InjectDriftingTemplatesRecording --- src/spikeinterface/core/base.py | 11 +++++------ src/spikeinterface/core/generate.py | 2 ++ src/spikeinterface/generation/drift_tools.py | 3 +++ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 5800166f39..304a85e74f 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -7,7 +7,6 @@ import weakref import json import pickle -import os import random import string from packaging.version import parse @@ -928,13 +927,14 @@ def save_to_folder( folder.mkdir(parents=True, exist_ok=False) # dump provenance - provenance_file = folder / f"provenance.json" if self.check_serializability("json"): + provenance_file = folder / f"provenance.json" + self.dump(provenance_file) + elif self.check_serializability("pickle"): + provenance_file = folder / f"provenance.pkl" self.dump(provenance_file) else: - provenance_file.write_text( - json.dumps({"warning": "the provenace is not json serializable!!!"}), encoding="utf8" - ) + warnings.warn("The extractor is not serializable to file. The provenance will not be saved.") self.save_metadata_to_folder(folder) @@ -1001,7 +1001,6 @@ def save_to_zarr( cached: ZarrExtractor Saved copy of the extractor. 
""" - import zarr from .zarrextractors import read_zarr save_kwargs.pop("format", None) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 11909bce0e..4e265f3766 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1738,6 +1738,8 @@ def __init__( ) self.add_recording_segment(recording_segment) + # to discuss: maybe we could set json serializability to False always + # because templates could be large! if not sorting.check_serializability("json"): self._serializability["json"] = False if parent_recording is not None: diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index cce2e08b58..70e13160f4 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -458,6 +458,9 @@ def __init__( self.set_probe(drifting_templates.probe, in_place=True) + # templates are too large, we don't serialize them to JSON + self._serializability["json"] = False + self._kwargs = { "sorting": sorting, "drifting_templates": drifting_templates, From b0b8b9aac2e480e102bbdc4980955d28778bd919 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 3 Jul 2024 12:55:58 +0200 Subject: [PATCH 18/57] Fix select peaks --- .../sortingcomponents/peak_selection.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/peak_selection.py b/src/spikeinterface/sortingcomponents/peak_selection.py index 1ccfbc4d22..fed026b6a7 100644 --- a/src/spikeinterface/sortingcomponents/peak_selection.py +++ b/src/spikeinterface/sortingcomponents/peak_selection.py @@ -76,19 +76,18 @@ def select_peaks( selected_indices = select_peak_indices(peaks, method=method, seed=seed, **method_kwargs) selected_peaks = peaks[selected_indices] + num_segments = len(np.unique(selected_peaks["segment_index"])) if margin is not None: to_keep = np.zeros(len(selected_peaks), dtype=bool) - offset = 0 - for segment_index in range(recording.get_num_segments()): - duration = recording.get_num_frames(segment_index) + for segment_index in range(num_segments): + num_samples_in_segment = recording.get_num_samples(segment_index) i0, i1 = np.searchsorted(selected_peaks["segment_index"], [segment_index, segment_index + 1]) - while selected_peaks["sample_index"][i0] <= margin[0] + offset: + while selected_peaks["sample_index"][i0] <= margin[0]: i0 += 1 - while selected_peaks["sample_index"][i1 - 1] >= (duration - margin[1]) + offset: + while selected_peaks["sample_index"][i1 - 1] >= (num_samples_in_segment - margin[1]): i1 -= 1 to_keep[i0:i1] = True - offset += duration selected_indices = selected_indices[to_keep] selected_peaks = peaks[selected_indices] @@ -284,7 +283,9 @@ def select_peak_indices(peaks, method, seed, **method_kwargs): ) selected_indices = np.concatenate(selected_indices) - selected_indices = selected_indices[np.argsort(peaks[selected_indices]["sample_index"])] + selected_indices = selected_indices[ + np.lexsort((peaks[selected_indices]["sample_index"], peaks[selected_indices]["segment_index"])) + ] return selected_indices From 6501252bbea1ed8f6f3fa955ca89f2363f08c409 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 3 Jul 2024 19:57:11 +0200 Subject: [PATCH 19/57] Revert "Fix select peaks" This reverts commit b0b8b9aac2e480e102bbdc4980955d28778bd919. 
--- .../sortingcomponents/peak_selection.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/peak_selection.py b/src/spikeinterface/sortingcomponents/peak_selection.py index fed026b6a7..1ccfbc4d22 100644 --- a/src/spikeinterface/sortingcomponents/peak_selection.py +++ b/src/spikeinterface/sortingcomponents/peak_selection.py @@ -76,18 +76,19 @@ def select_peaks( selected_indices = select_peak_indices(peaks, method=method, seed=seed, **method_kwargs) selected_peaks = peaks[selected_indices] - num_segments = len(np.unique(selected_peaks["segment_index"])) if margin is not None: to_keep = np.zeros(len(selected_peaks), dtype=bool) - for segment_index in range(num_segments): - num_samples_in_segment = recording.get_num_samples(segment_index) + offset = 0 + for segment_index in range(recording.get_num_segments()): + duration = recording.get_num_frames(segment_index) i0, i1 = np.searchsorted(selected_peaks["segment_index"], [segment_index, segment_index + 1]) - while selected_peaks["sample_index"][i0] <= margin[0]: + while selected_peaks["sample_index"][i0] <= margin[0] + offset: i0 += 1 - while selected_peaks["sample_index"][i1 - 1] >= (num_samples_in_segment - margin[1]): + while selected_peaks["sample_index"][i1 - 1] >= (duration - margin[1]) + offset: i1 -= 1 to_keep[i0:i1] = True + offset += duration selected_indices = selected_indices[to_keep] selected_peaks = peaks[selected_indices] @@ -283,9 +284,7 @@ def select_peak_indices(peaks, method, seed, **method_kwargs): ) selected_indices = np.concatenate(selected_indices) - selected_indices = selected_indices[ - np.lexsort((peaks[selected_indices]["sample_index"], peaks[selected_indices]["segment_index"])) - ] + selected_indices = selected_indices[np.argsort(peaks[selected_indices]["sample_index"])] return selected_indices From f572bfd0a6511e58813ee0fd37d1aff93e4d4b9f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 4 Jul 2024 12:19:21 +0200 Subject: [PATCH 20/57] Add option to use ref_channel_ids in global common reference --- .../preprocessing/common_reference.py | 28 +++++++++++-------- .../tests/test_common_reference.py | 4 ++- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index bc8ecb4cb7..93d0448ef4 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -39,7 +39,8 @@ class CommonReferenceRecording(BasePreprocessor): recording : RecordingExtractor The recording extractor to be re-referenced reference : "global" | "single" | "local", default: "global" - If "global" the reference is the average or median across all the channels. + If "global" the reference is the average or median across all the channels. To select a subset of channels, + you can use the `ref_channel_ids` parameter. If "single", the reference is a single channel or a list of channels that need to be set with the `ref_channel_ids`. If "local", the reference is the set of channels within an annulus that must be set with the `local_radius` parameter. operator : "median" | "average", default: "median" @@ -51,10 +52,10 @@ class CommonReferenceRecording(BasePreprocessor): List of lists containing the channel ids for splitting the reference. The CMR, CAR, or referencing with respect to single channels are applied group-wise. However, this is not applied for the local CAR. 
It is useful when dealing with different channel groups, e.g. multiple tetrodes. - ref_channel_ids : list or str or int, default: None - If no "groups" are specified, all channels are referenced to "ref_channel_ids". If "groups" is provided, then a - list of channels to be applied to each group is expected. If "single" reference, a list of one channel or an - int is expected. + ref_channel_ids : list | str | int | None, default: None + If "global" reference, a list of channels to be used as reference. + If "single" reference, a list of one channel or a single channel id is expected. + If "groups" is provided, then a list of channels to be applied to each group is expected. local_radius : tuple(int, int), default: (30, 55) Use in the local CAR implementation as the selecting annulus with the following format: @@ -82,10 +83,10 @@ def __init__( recording: BaseRecording, reference: Literal["global", "single", "local"] = "global", operator: Literal["median", "average"] = "median", - groups=None, - ref_channel_ids=None, - local_radius=(30, 55), - dtype=None, + groups: list | None = None, + ref_channel_ids: list | str | int | None = None, + local_radius: tuple[float, float] = (30.0, 55.0), + dtype: str | np.dtype | None = None, ): num_chans = recording.get_num_channels() neighbors = None @@ -96,7 +97,9 @@ def __init__( raise ValueError("'operator' must be either 'median', 'average'") if reference == "global": - pass + if ref_channel_ids is not None: + if not isinstance(ref_channel_ids, list): + raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list") elif reference == "single": assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'" if groups is not None: @@ -182,7 +185,10 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) if self.reference == "global": - shift = self.operator_func(traces, axis=1, keepdims=True) + if self.ref_channel_indices is None: + shift = self.operator_func(traces, axis=1, keepdims=True) + else: + shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True) re_referenced_traces = traces[:, channel_indices] - shift elif self.reference == "single": # single channel -> no need of operator diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index 1df9b21c81..8b37e7f4b9 100644 --- a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -11,7 +11,7 @@ def _generate_test_recording(): recording = generate_recording(durations=[1.0], num_channels=4) - recording = recording.channel_slice(recording.channel_ids, np.array(["a", "b", "c", "d"])) + recording = recording.rename_channels(np.array(["a", "b", "c", "d"])) return recording @@ -23,12 +23,14 @@ def recording(): def test_common_reference(recording): # Test simple case rec_cmr = common_reference(recording, reference="global", operator="median") + rec_cmr_ref = common_reference(recording, reference="global", operator="median", ref_channel_ids=["a", "b", "c"]) rec_car = common_reference(recording, reference="global", operator="average") rec_sin = common_reference(recording, reference="single", ref_channel_ids=["a"]) rec_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median") traces = recording.get_traces() assert np.allclose(traces, 
rec_cmr.get_traces() + np.median(traces, axis=1, keepdims=True), atol=0.01) + assert np.allclose(traces, rec_cmr_ref.get_traces() + np.median(traces[:, :3], axis=1, keepdims=True), atol=0.01) assert np.allclose(traces, rec_car.get_traces() + np.mean(traces, axis=1, keepdims=True), atol=0.01) assert not np.all(rec_sin.get_traces()[0]) assert np.allclose(rec_sin.get_traces()[:, 1], traces[:, 1] - traces[:, 0]) From de3153179bb385a0a40cec4b2a983fb34da6f3ec Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 4 Jul 2024 17:49:31 +0200 Subject: [PATCH 21/57] analyse -> analyze neuropixels docs --- ...se_neuropixels.rst => analyze_neuropixels.rst} | 14 +++++++------- .../analyze_neuropixels_13_0.png} | Bin .../analyze_neuropixels_14_1.png} | Bin .../analyze_neuropixels_21_1.png} | Bin .../analyze_neuropixels_26_1.png} | Bin .../analyze_neuropixels_27_1.png} | Bin .../analyze_neuropixels_8_1.png} | Bin 7 files changed, 7 insertions(+), 7 deletions(-) rename doc/how_to/{analyse_neuropixels.rst => analyze_neuropixels.rst} (98%) rename doc/how_to/{analyse_neuropixels_files/analyse_neuropixels_13_0.png => analyze_neuropixels_files/analyze_neuropixels_13_0.png} (100%) rename doc/how_to/{analyse_neuropixels_files/analyse_neuropixels_14_1.png => analyze_neuropixels_files/analyze_neuropixels_14_1.png} (100%) rename doc/how_to/{analyse_neuropixels_files/analyse_neuropixels_21_1.png => analyze_neuropixels_files/analyze_neuropixels_21_1.png} (100%) rename doc/how_to/{analyse_neuropixels_files/analyse_neuropixels_26_1.png => analyze_neuropixels_files/analyze_neuropixels_26_1.png} (100%) rename doc/how_to/{analyse_neuropixels_files/analyse_neuropixels_27_1.png => analyze_neuropixels_files/analyze_neuropixels_27_1.png} (100%) rename doc/how_to/{analyse_neuropixels_files/analyse_neuropixels_8_1.png => analyze_neuropixels_files/analyze_neuropixels_8_1.png} (100%) diff --git a/doc/how_to/analyse_neuropixels.rst b/doc/how_to/analyze_neuropixels.rst similarity index 98% rename from doc/how_to/analyse_neuropixels.rst rename to doc/how_to/analyze_neuropixels.rst index 02e497b0fe..1fe741ea48 100644 --- a/doc/how_to/analyse_neuropixels.rst +++ b/doc/how_to/analyze_neuropixels.rst @@ -1,4 +1,4 @@ -Analyse Neuropixels datasets +Analyze Neuropixels datasets ============================ This example shows how to perform Neuropixels-specific analysis, @@ -218,7 +218,7 @@ We need to specify which one to read: -.. image:: analyse_neuropixels_files/analyse_neuropixels_8_1.png +.. image:: analyze_neuropixels_files/analyze_neuropixels_8_1.png Preprocess the recording @@ -286,7 +286,7 @@ is lazy, so you can change the previsous cell (parameters, step order, -.. image:: analyse_neuropixels_files/analyse_neuropixels_13_0.png +.. image:: analyze_neuropixels_files/analyze_neuropixels_13_0.png .. code:: ipython3 @@ -306,7 +306,7 @@ is lazy, so you can change the previsous cell (parameters, step order, -.. image:: analyse_neuropixels_files/analyse_neuropixels_14_1.png +.. image:: analyze_neuropixels_files/analyze_neuropixels_14_1.png Should we save the preprocessed data to a binary file? @@ -389,7 +389,7 @@ Noise levels can be estimated on the scaled traces or on the raw -.. image:: analyse_neuropixels_files/analyse_neuropixels_21_1.png +.. image:: analyze_neuropixels_files/analyze_neuropixels_21_1.png Detect and localize peaks @@ -480,7 +480,7 @@ documentation for motion estimation and correction for more details. -.. image:: analyse_neuropixels_files/analyse_neuropixels_26_1.png +.. 
image:: analyze_neuropixels_files/analyze_neuropixels_26_1.png .. code:: ipython3 @@ -502,7 +502,7 @@ documentation for motion estimation and correction for more details. -.. image:: analyse_neuropixels_files/analyse_neuropixels_27_1.png +.. image:: analyze_neuropixels_files/analyze_neuropixels_27_1.png Run a spike sorter diff --git a/doc/how_to/analyse_neuropixels_files/analyse_neuropixels_13_0.png b/doc/how_to/analyze_neuropixels_files/analyze_neuropixels_13_0.png similarity index 100% rename from doc/how_to/analyse_neuropixels_files/analyse_neuropixels_13_0.png rename to doc/how_to/analyze_neuropixels_files/analyze_neuropixels_13_0.png diff --git a/doc/how_to/analyse_neuropixels_files/analyse_neuropixels_14_1.png b/doc/how_to/analyze_neuropixels_files/analyze_neuropixels_14_1.png similarity index 100% rename from doc/how_to/analyse_neuropixels_files/analyse_neuropixels_14_1.png rename to doc/how_to/analyze_neuropixels_files/analyze_neuropixels_14_1.png diff --git a/doc/how_to/analyse_neuropixels_files/analyse_neuropixels_21_1.png b/doc/how_to/analyze_neuropixels_files/analyze_neuropixels_21_1.png similarity index 100% rename from doc/how_to/analyse_neuropixels_files/analyse_neuropixels_21_1.png rename to doc/how_to/analyze_neuropixels_files/analyze_neuropixels_21_1.png diff --git a/doc/how_to/analyse_neuropixels_files/analyse_neuropixels_26_1.png b/doc/how_to/analyze_neuropixels_files/analyze_neuropixels_26_1.png similarity index 100% rename from doc/how_to/analyse_neuropixels_files/analyse_neuropixels_26_1.png rename to doc/how_to/analyze_neuropixels_files/analyze_neuropixels_26_1.png diff --git a/doc/how_to/analyse_neuropixels_files/analyse_neuropixels_27_1.png b/doc/how_to/analyze_neuropixels_files/analyze_neuropixels_27_1.png similarity index 100% rename from doc/how_to/analyse_neuropixels_files/analyse_neuropixels_27_1.png rename to doc/how_to/analyze_neuropixels_files/analyze_neuropixels_27_1.png diff --git a/doc/how_to/analyse_neuropixels_files/analyse_neuropixels_8_1.png b/doc/how_to/analyze_neuropixels_files/analyze_neuropixels_8_1.png similarity index 100% rename from doc/how_to/analyse_neuropixels_files/analyse_neuropixels_8_1.png rename to doc/how_to/analyze_neuropixels_files/analyze_neuropixels_8_1.png From 6f462505ea7c6ec6d6e0f2a7bd4051fe554c097a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 4 Jul 2024 17:54:53 +0200 Subject: [PATCH 22/57] and index.rst --- doc/how_to/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst index 66dd9b417c..7127c4faf0 100644 --- a/doc/how_to/index.rst +++ b/doc/how_to/index.rst @@ -8,7 +8,7 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to. 
viewers handle_drift - analyse_neuropixels + analyze_neuropixels load_matlab_data combine_recordings process_by_channel_group From 170db7f90e4b5ba5a1b9f22bc01e8dabfdfbd6e0 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 4 Jul 2024 15:28:54 -0600 Subject: [PATCH 23/57] better closing conditions for nwb --- .../extractors/nwbextractors.py | 42 ++++++++++++++++++- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index ccb2ff4370..d213126f34 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -626,13 +626,29 @@ def __init__( "file": file, } + def _close_hdf5_file(self): + has_hdf5_backend = hasattr(self, "_file") + if has_hdf5_backend: + import h5py + + main_file_id = self._file.id + open_object_ids_main = h5py.h5f.get_obj_ids(main_file_id, types=h5py.h5f.OBJ_ALL) + for object_id in open_object_ids_main: + object_name = h5py.h5i.get_name(object_id).decode("utf-8") + try: + object_id.close() + except: + import warnings + + warnings.warn(f"Error closing object {object_name}") + def __del__(self): # backend mode if hasattr(self, "_file"): if hasattr(self._file, "store"): self._file.store.close() else: - self._file.close() + self._close_hdf5_file() # pynwb mode elif hasattr(self, "_nwbfile"): io = self._nwbfile.get_read_io() @@ -1111,19 +1127,41 @@ def __init__( "t_start": self.t_start, } + def _close_hdf5_file(self): + has_hdf5_backend = hasattr(self, "_file") + if has_hdf5_backend: + import h5py + + main_file_id = self._file.id + open_object_ids_main = h5py.h5f.get_obj_ids(main_file_id, types=h5py.h5f.OBJ_ALL) + for object_id in open_object_ids_main: + object_name = h5py.h5i.get_name(object_id).decode("utf-8") + try: + object_id.close() + except: + import warnings + + warnings.warn(f"Error closing object {object_name}") + def __del__(self): # backend mode if hasattr(self, "_file"): if hasattr(self._file, "store"): self._file.store.close() else: - self._file.close() + self._close_hdf5_file() # pynwb mode elif hasattr(self, "_nwbfile"): io = self._nwbfile.get_read_io() if io is not None: io.close() + # pynwb mode + elif hasattr(self, "_nwbfile"): # hdf + io = self._nwbfile.get_read_io() + if io is not None: + io.close() + def _fetch_sorting_segment_info_pynwb( self, unit_table_path: str = None, samples_for_rate_estimation: int = 1000, cache: bool = False ): From b0dd1b257658394512a2f2f58e71e2e283fed675 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 4 Jul 2024 15:59:01 -0600 Subject: [PATCH 24/57] more explicit concistency checks with error messages describing the error in channel aggregation --- .../core/channelsaggregationrecording.py | 81 ++++++++++++------- 1 file changed, 50 insertions(+), 31 deletions(-) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index b8735dff3c..ddd73cd3c8 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -15,10 +15,15 @@ class ChannelsAggregationRecording(BaseRecording): def __init__(self, recording_list, renamed_channel_ids=None): + self._recordings = recording_list + + self._perform_consistency_checks() + sampling_frequency = self.recording_list[0].get_sampling_frequency() + dtype = self.recording_list[0].get_dtype() + num_segments = recording_list[0].get_num_segments() + # Generate a default list of channel ids 
that are unique and consecutive numbers as strings. - channel_map = {} num_all_channels = sum(rec.get_num_channels() for rec in recording_list) - if renamed_channel_ids is not None: assert ( len(np.unique(renamed_channel_ids)) == num_all_channels @@ -39,33 +44,6 @@ def __init__(self, recording_list, renamed_channel_ids=None): default_channel_ids = [str(i) for i in range(num_all_channels)] channel_ids = default_channel_ids - ch_id = 0 - for r_i, recording in enumerate(recording_list): - single_channel_ids = recording.get_channel_ids() - single_channel_indices = recording.ids_to_indices(single_channel_ids) - for chan_id, chan_idx in zip(single_channel_ids, single_channel_indices): - channel_map[ch_id] = {"recording_id": r_i, "channel_index": chan_idx} - ch_id += 1 - - sampling_frequency = recording_list[0].get_sampling_frequency() - num_segments = recording_list[0].get_num_segments() - dtype = recording_list[0].get_dtype() - - ok1 = all(sampling_frequency == rec.get_sampling_frequency() for rec in recording_list) - ok2 = all(num_segments == rec.get_num_segments() for rec in recording_list) - ok3 = all(dtype == rec.get_dtype() for rec in recording_list) - ok4 = True - for i_seg in range(num_segments): - num_samples = recording_list[0].get_num_samples(i_seg) - ok4 = all(num_samples == rec.get_num_samples(i_seg) for rec in recording_list) - if not ok4: - break - - if not (ok1 and ok2 and ok3 and ok4): - raise ValueError( - "Recordings do not have consistent sampling frequency, number of segments, data type, or number of samples." - ) - BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype) property_keys = recording_list[0].get_property_keys() @@ -99,19 +77,60 @@ def __init__(self, recording_list, renamed_channel_ids=None): "Locations are not unique! " "Cannot aggregate recordings!" 
) - # finally add segments + # finally add segments, we need a channel mapping + ch_id = 0 + channel_map = {} + for r_i, recording in enumerate(recording_list): + single_channel_ids = recording.get_channel_ids() + single_channel_indices = recording.ids_to_indices(single_channel_ids) + for chan_id, chan_idx in zip(single_channel_ids, single_channel_indices): + channel_map[ch_id] = {"recording_id": r_i, "channel_index": chan_idx} + ch_id += 1 + for i_seg in range(num_segments): parent_segments = [rec._recording_segments[i_seg] for rec in recording_list] sub_segment = ChannelsAggregationRecordingSegment(channel_map, parent_segments) self.add_recording_segment(sub_segment) - self._recordings = recording_list self._kwargs = {"recording_list": recording_list, "renamed_channel_ids": renamed_channel_ids} @property def recordings(self): return self._recordings + def _perform_consistency_checks(self): + + # Check for consistent sampling frequency across recordings + sampling_frequencies = [rec.get_sampling_frequency() for rec in self.recording_list] + sampling_frequency = sampling_frequencies[0] + consistent_sampling_frequency = all(sampling_frequency == sf for sf in sampling_frequencies) + if not consistent_sampling_frequency: + raise ValueError(f"Inconsistent sampling frequency among recordings: {sampling_frequencies}") + + # Check for consistent number of segments across recordings + num_segments_list = [rec.get_num_segments() for rec in self.recording_list] + num_segments = num_segments_list[0] + consistent_num_segments = all(num_segments == ns for ns in num_segments_list) + if not consistent_num_segments: + raise ValueError(f"Inconsistent number of segments among recordings: {num_segments_list}") + + # Check for consistent data type across recordings + data_types = [rec.get_dtype() for rec in self.recording_list] + dtype = data_types[0] + consistent_dtype = all(dtype == dt for dt in data_types) + if not consistent_dtype: + raise ValueError(f"Inconsistent data type among recordings: {data_types}") + + # Check for consistent number of samples across recordings for each segment + for segment_index in range(num_segments): + num_samples_list = [rec.get_num_samples(segment_index=segment_index) for rec in self.recording_list] + num_samples = num_samples_list[0] + consistent_num_samples = all(num_samples == ns for ns in num_samples_list) + if not consistent_num_samples: + raise ValueError( + f"Inconsistent number of samples in segment {segment_index} among recordings: {num_samples_list}" + ) + class ChannelsAggregationRecordingSegment(BaseRecordingSegment): """ From 6645dbba97f638703b7d544154af7915af1651e3 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 4 Jul 2024 16:06:32 -0600 Subject: [PATCH 25/57] fix error --- .../core/channelsaggregationrecording.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index ddd73cd3c8..820b4fcd91 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -18,8 +18,8 @@ def __init__(self, recording_list, renamed_channel_ids=None): self._recordings = recording_list self._perform_consistency_checks() - sampling_frequency = self.recording_list[0].get_sampling_frequency() - dtype = self.recording_list[0].get_dtype() + sampling_frequency = recording_list[0].get_sampling_frequency() + dtype = recording_list[0].get_dtype() num_segments = 
recording_list[0].get_num_segments() # Generate a default list of channel ids that are unique and consecutive numbers as strings. @@ -101,21 +101,21 @@ def recordings(self): def _perform_consistency_checks(self): # Check for consistent sampling frequency across recordings - sampling_frequencies = [rec.get_sampling_frequency() for rec in self.recording_list] + sampling_frequencies = [rec.get_sampling_frequency() for rec in self.recordings] sampling_frequency = sampling_frequencies[0] consistent_sampling_frequency = all(sampling_frequency == sf for sf in sampling_frequencies) if not consistent_sampling_frequency: raise ValueError(f"Inconsistent sampling frequency among recordings: {sampling_frequencies}") # Check for consistent number of segments across recordings - num_segments_list = [rec.get_num_segments() for rec in self.recording_list] + num_segments_list = [rec.get_num_segments() for rec in self.recordings] num_segments = num_segments_list[0] consistent_num_segments = all(num_segments == ns for ns in num_segments_list) if not consistent_num_segments: raise ValueError(f"Inconsistent number of segments among recordings: {num_segments_list}") # Check for consistent data type across recordings - data_types = [rec.get_dtype() for rec in self.recording_list] + data_types = [rec.get_dtype() for rec in self.recordings] dtype = data_types[0] consistent_dtype = all(dtype == dt for dt in data_types) if not consistent_dtype: @@ -123,7 +123,7 @@ def _perform_consistency_checks(self): # Check for consistent number of samples across recordings for each segment for segment_index in range(num_segments): - num_samples_list = [rec.get_num_samples(segment_index=segment_index) for rec in self.recording_list] + num_samples_list = [rec.get_num_samples(segment_index=segment_index) for rec in self.recordings] num_samples = num_samples_list[0] consistent_num_samples = all(num_samples == ns for ns in num_samples_list) if not consistent_num_samples: From a97a1f056fd2eb95c75bbfb97535e7c698a6981f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 5 Jul 2024 09:56:44 +0200 Subject: [PATCH 26/57] Remove unused legacy class name/mode from extractors --- src/spikeinterface/core/binaryfolder.py | 3 -- .../core/binaryrecordingextractor.py | 3 -- .../core/npzsortingextractor.py | 3 -- src/spikeinterface/core/numpyextractors.py | 8 --- src/spikeinterface/core/zarrextractors.py | 12 ----- .../extractors/alfsortingextractor.py | 3 +- src/spikeinterface/extractors/cbin_ibl.py | 4 +- .../cellexplorersortingextractor.py | 1 - .../extractors/combinatoextractors.py | 14 ++---- .../extractors/extractorlist.py | 49 ------------------- .../extractors/hdsortextractors.py | 3 -- .../extractors/herdingspikesextractors.py | 19 ++----- .../extractors/iblextractors.py | 3 -- .../extractors/klustaextractors.py | 20 ++------ .../extractors/mclustextractors.py | 2 - .../extractors/mcsh5extractors.py | 8 +-- .../extractors/mdaextractors.py | 6 --- .../extractors/neoextractors/alphaomega.py | 2 - .../extractors/neoextractors/axona.py | 2 - .../extractors/neoextractors/biocam.py | 4 -- .../extractors/neoextractors/blackrock.py | 4 -- .../extractors/neoextractors/ced.py | 2 - .../extractors/neoextractors/edf.py | 2 - .../extractors/neoextractors/intan.py | 2 - .../extractors/neoextractors/maxwell.py | 4 -- .../extractors/neoextractors/mcsraw.py | 2 - .../extractors/neoextractors/mearec.py | 4 -- .../extractors/neoextractors/neuralynx.py | 4 -- .../extractors/neoextractors/neuroexplorer.py | 2 - .../extractors/neoextractors/neuroscope.py | 
4 -- .../extractors/neoextractors/nix.py | 2 - .../extractors/neoextractors/openephys.py | 6 --- .../extractors/neoextractors/plexon.py | 4 -- .../extractors/neoextractors/plexon2.py | 6 --- .../extractors/neoextractors/spike2.py | 2 - .../extractors/neoextractors/spikegadgets.py | 2 - .../extractors/neoextractors/spikeglx.py | 4 -- .../extractors/neoextractors/tdt.py | 2 - .../extractors/nwbextractors.py | 4 -- .../extractors/phykilosortextractors.py | 13 +---- .../extractors/shybridextractors.py | 23 ++------- .../extractors/sinapsrecordingextractors.py | 8 --- .../extractors/spykingcircusextractors.py | 15 ++---- .../extractors/tridesclousextractors.py | 2 - .../extractors/waveclussnippetstextractors.py | 1 - .../extractors/waveclustextractors.py | 2 - .../extractors/yassextractors.py | 19 ++----- 47 files changed, 32 insertions(+), 282 deletions(-) diff --git a/src/spikeinterface/core/binaryfolder.py b/src/spikeinterface/core/binaryfolder.py index fca08d9c26..86f14faa30 100644 --- a/src/spikeinterface/core/binaryfolder.py +++ b/src/spikeinterface/core/binaryfolder.py @@ -25,9 +25,6 @@ class BinaryFolderRecording(BinaryRecordingExtractor): The recording """ - mode = "folder" - name = "binaryfolder" - def __init__(self, folder_path): folder_path = Path(folder_path) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 64c1b9b2e6..84d06a599f 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -52,9 +52,6 @@ class BinaryRecordingExtractor(BaseRecording): The recording Extractor """ - mode = "file" - name = "binary" - def __init__( self, file_paths, diff --git a/src/spikeinterface/core/npzsortingextractor.py b/src/spikeinterface/core/npzsortingextractor.py index f60dadd8ec..b8e7357e8c 100644 --- a/src/spikeinterface/core/npzsortingextractor.py +++ b/src/spikeinterface/core/npzsortingextractor.py @@ -16,9 +16,6 @@ class NpzSortingExtractor(BaseSorting): All spike are store in two columns maner index+labels """ - mode = "file" - name = "npz" - def __init__(self, file_path): self.npz_filename = file_path diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 1ee472ffa4..09ba743a8c 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -37,9 +37,6 @@ class NumpyRecording(BaseRecording): An optional list of channel_ids. If None, linear channels are assumed """ - mode = "memory" - name = "numpy" - def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=None): if isinstance(traces_list, list): all_elements_are_list = all(isinstance(e, list) for e in traces_list) @@ -142,9 +139,6 @@ class SharedMemoryRecording(BaseRecording): If True, the main instance will unlink the sharedmem buffer when deleted """ - mode = "memory" - name = "SharedMemory" - def __init__( self, shm_names, shape_list, dtype, sampling_frequency, channel_ids=None, t_starts=None, main_shm_owner=True ): @@ -252,8 +246,6 @@ class NumpySorting(BaseSorting): A list of unit_ids. 
""" - name = "numpy" - def __init__(self, spikes, sampling_frequency, unit_ids): """ """ BaseSorting.__init__(self, sampling_frequency, unit_ids) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 4851c0eb5c..1b9637e097 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -31,13 +31,7 @@ class ZarrRecordingExtractor(BaseRecording): The recording Extractor """ - installed = True - mode = "folder" - installation_mesg = "" - name = "zarr" - def __init__(self, folder_path: Path | str, storage_options: dict | None = None): - assert self.installed, self.installation_mesg folder_path, folder_path_kwarg = resolve_zarr_path(folder_path) @@ -167,13 +161,7 @@ class ZarrSortingExtractor(BaseSorting): The sorting Extractor """ - installed = True - mode = "folder" - installation_mesg = "" - name = "zarr" - def __init__(self, folder_path: Path | str, storage_options: dict | None = None, zarr_group: str | None = None): - assert self.installed, self.installation_mesg folder_path, folder_path_kwarg = resolve_zarr_path(folder_path) diff --git a/src/spikeinterface/extractors/alfsortingextractor.py b/src/spikeinterface/extractors/alfsortingextractor.py index fa6490135c..f7b5401182 100644 --- a/src/spikeinterface/extractors/alfsortingextractor.py +++ b/src/spikeinterface/extractors/alfsortingextractor.py @@ -25,12 +25,11 @@ class ALFSortingExtractor(BaseSorting): """ installation_mesg = "To use the ALF extractors, install ONE-api: \n\n pip install ONE-api\n\n" - name = "alf" def __init__(self, folder_path, sampling_frequency=30000): try: import one.alf.io as alfio - except ImportError as e: + except ImportError: raise ImportError(self.installation_mesg) self._folder_path = Path(folder_path) diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index a09cea9863..d7e5b58e11 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -39,16 +39,14 @@ class CompressedBinaryIblExtractor(BaseRecording): The loaded data. """ - mode = "folder" installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n" - name = "cbin_ibl" def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file=None): from neo.rawio.spikeglxrawio import read_meta_file try: import mtscomp - except: + except ImportError: raise ImportError(self.installation_mesg) if cbin_file is None: folder_path = Path(folder_path) diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index 736927a1ee..0dfa3a85ad 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -29,7 +29,6 @@ class CellExplorerSortingExtractor(BaseSorting): Path to the `sessionInfo.mat` file. If None, it will be inferred from the file_path. 
""" - mode = "file" installation_mesg = "To use the CellExplorerSortingExtractor install pymatreader" def __init__( diff --git a/src/spikeinterface/extractors/combinatoextractors.py b/src/spikeinterface/extractors/combinatoextractors.py index 8828ea8b64..35fce3a8e3 100644 --- a/src/spikeinterface/extractors/combinatoextractors.py +++ b/src/spikeinterface/extractors/combinatoextractors.py @@ -7,13 +7,6 @@ from spikeinterface.core import BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import define_function_from_class -try: - import h5py - - HAVE_H5PY = True -except ImportError: - HAVE_H5PY = False - class CombinatoSortingExtractor(BaseSorting): """Load Combinato format data as a sorting extractor. @@ -37,11 +30,14 @@ class CombinatoSortingExtractor(BaseSorting): The loaded data. """ - installed = HAVE_H5PY installation_mesg = "To use the CombinatoSortingExtractor install h5py: \n\n pip install h5py\n\n" - name = "combinato" def __init__(self, folder_path, sampling_frequency=None, user="simple", det_sign="both", keep_good_only=True): + try: + import h5py + except ImportError: + raise ImportError(self.installation_mesg) + folder_path = Path(folder_path) assert folder_path.is_dir(), "Folder {} doesn't exist".format(folder_path) if sampling_frequency is None: diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py index 8948aad606..e56d4fff52 100644 --- a/src/spikeinterface/extractors/extractorlist.py +++ b/src/spikeinterface/extractors/extractorlist.py @@ -116,52 +116,3 @@ event_extractor_full_list += neo_event_extractors_list snippets_extractor_full_list = [NpySnippetsExtractor, WaveClusSnippetsExtractor] - - -recording_extractor_full_dict = {recext.name: recext for recext in recording_extractor_full_list} -sorting_extractor_full_dict = {recext.name: recext for recext in sorting_extractor_full_list} -snippets_extractor_full_dict = {recext.name: recext for recext in snippets_extractor_full_list} - - -def get_recording_extractor_from_name(name: str) -> Type[BaseRecording]: - """ - Returns the Recording Extractor class based on its name. - - Parameters - ---------- - name: str - The Recording Extractor's name. - - Returns - ------- - recording_extractor: BaseRecording - The Recording Extractor class. - """ - - for recording_extractor in recording_extractor_full_list: - if recording_extractor.__name__ == name: - return recording_extractor - - raise ValueError(f"Recording extractor '{name}' not found.") - - -def get_sorting_extractor_from_name(name: str) -> Type[BaseSorting]: - """ - Returns the Sorting Extractor class based on its name. - - Parameters - ---------- - name: str - The Sorting Extractor's name. - - Returns - ------- - sorting_extractor: BaseSorting - The Sorting Extractor class. - """ - - for sorting_extractor in sorting_extractor_full_list: - if sorting_extractor.__name__ == name: - return sorting_extractor - - raise ValueError(f"Sorting extractor '{name}' not found.") diff --git a/src/spikeinterface/extractors/hdsortextractors.py b/src/spikeinterface/extractors/hdsortextractors.py index 19038344ee..fa627d2ee3 100644 --- a/src/spikeinterface/extractors/hdsortextractors.py +++ b/src/spikeinterface/extractors/hdsortextractors.py @@ -25,9 +25,6 @@ class HDSortSortingExtractor(MatlabHelper, BaseSorting): The loaded data. 
""" - mode = "file" - name = "hdsort" - def __init__(self, file_path, keep_good_only=True): MatlabHelper.__init__(self, file_path) diff --git a/src/spikeinterface/extractors/herdingspikesextractors.py b/src/spikeinterface/extractors/herdingspikesextractors.py index 4fe915a96b..de4929218b 100644 --- a/src/spikeinterface/extractors/herdingspikesextractors.py +++ b/src/spikeinterface/extractors/herdingspikesextractors.py @@ -7,13 +7,6 @@ from spikeinterface.core import BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import define_function_from_class -try: - import h5py - - HAVE_HS2SX = True -except ImportError: - HAVE_HS2SX = False - class HerdingspikesSortingExtractor(BaseSorting): """Load HerdingSpikes format data as a sorting extractor. @@ -31,15 +24,13 @@ class HerdingspikesSortingExtractor(BaseSorting): The loaded data. """ - installed = HAVE_HS2SX # check at class level if installed or not - mode = "file" - installation_mesg = ( - "To use the HS2SortingExtractor install h5py: \n\n pip install h5py\n\n" # error message when not installed - ) - name = "herdingspikes" + installation_mesg = "To use the HS2SortingExtractor install h5py: \n\n pip install h5py\n\n" def __init__(self, file_path, load_unit_info=True): - assert self.installed, self.installation_mesg + try: + import h5py + except ImportError: + raise ImportError(self.installation_mesg) self._recording_file = file_path self._rf = h5py.File(self._recording_file, mode="r") diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 34481c94f1..5dd549347d 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -65,9 +65,7 @@ class IblRecordingExtractor(BaseRecording): The recording extractor which allows access to the traces. """ - mode = "folder" installation_mesg = "To use the IblRecordingSegment, install ibllib: \n\n pip install ONE-api\npip install ibllib\n" - name = "ibl_recording" @staticmethod def _get_default_one(cache_folder: Optional[Union[Path, str]] = None): @@ -304,7 +302,6 @@ class IblSortingExtractor(BaseSorting): The loaded data. """ - name = "ibl" installation_mesg = "IBL extractors require ibllib as a dependency." " To install, run: \n\n pip install ibllib\n\n" def __init__(self, pid: str, good_clusters_only: bool = False, load_unit_properties: bool = True, one=None): diff --git a/src/spikeinterface/extractors/klustaextractors.py b/src/spikeinterface/extractors/klustaextractors.py index 82534771a1..162376cb3c 100644 --- a/src/spikeinterface/extractors/klustaextractors.py +++ b/src/spikeinterface/extractors/klustaextractors.py @@ -18,13 +18,6 @@ from spikeinterface.core import BaseRecording, BaseSorting, BaseRecordingSegment, BaseSortingSegment, read_python from spikeinterface.core.core_tools import define_function_from_class -try: - import h5py - - HAVE_H5PY = True -except ImportError: - HAVE_H5PY = False - # noinspection SpellCheckingInspection class KlustaSortingExtractor(BaseSorting): @@ -43,18 +36,15 @@ class KlustaSortingExtractor(BaseSorting): The loaded data. 
""" - installed = HAVE_H5PY # check at class level if installed or not - installation_mesg = ( - "To use the KlustaSortingExtractor install h5py: \n\n pip install h5py\n\n" # error message when not installed - ) - mode = "file_or_folder" - name = "klusta" + installation_mesg = "To use the KlustaSortingExtractor install h5py: \n\n pip install h5py\n\n" default_cluster_groups = {0: "Noise", 1: "MUA", 2: "Good", 3: "Unsorted"} def __init__(self, file_or_folder_path, exclude_cluster_groups=None): - assert HAVE_H5PY, self.installation_mesg - # ~ SortingExtractor.__init__(self) + try: + import h5py + except ImportError: + raise ImportError(self.installation_mesg) kwik_file_or_folder = Path(file_or_folder_path) kwikfile = None diff --git a/src/spikeinterface/extractors/mclustextractors.py b/src/spikeinterface/extractors/mclustextractors.py index 5cfa583054..d611a1576a 100644 --- a/src/spikeinterface/extractors/mclustextractors.py +++ b/src/spikeinterface/extractors/mclustextractors.py @@ -29,8 +29,6 @@ class MClustSortingExtractor(BaseSorting): Loaded data. """ - name = "mclust" - def __init__(self, folder_path, sampling_frequency, sampling_frequency_raw=None): end_header_str = "%%ENDHEADER" ext_list = ["t64", "t32", "t", "raw64", "raw32"] diff --git a/src/spikeinterface/extractors/mcsh5extractors.py b/src/spikeinterface/extractors/mcsh5extractors.py index f419b7e64d..d86969e005 100644 --- a/src/spikeinterface/extractors/mcsh5extractors.py +++ b/src/spikeinterface/extractors/mcsh5extractors.py @@ -24,18 +24,12 @@ class MCSH5RecordingExtractor(BaseRecording): The loaded data. """ - mode = "file" - installation_mesg = ( - "To use the MCSH5RecordingExtractor install h5py: \n\n pip install h5py\n\n" # error message when not installed - ) - name = "mcsh5" + installation_mesg = "To use the MCSH5RecordingExtractor install h5py: \n\n pip install h5py\n\n" def __init__(self, file_path, stream_id=0): try: import h5py - - HAVE_MCSH5 = True except ImportError: raise ImportError(self.installation_mesg) diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index acc7be58dd..f055e1d7c9 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -36,9 +36,6 @@ class MdaRecordingExtractor(BaseRecording): The loaded data. """ - mode = "folder" - name = "mda" - def __init__(self, folder_path, raw_fname="raw.mda", params_fname="params.json", geom_fname="geom.csv"): folder_path = Path(folder_path) self._folder_path = folder_path @@ -192,9 +189,6 @@ class MdaSortingExtractor(BaseSorting): The loaded data. """ - mode = "file" - name = "mda" - def __init__(self, file_path, sampling_frequency): firings = readmda(str(Path(file_path).absolute())) labels = firings[2, :] diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index 239928f66d..2e70d5ba41 100644 --- a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -27,9 +27,7 @@ class AlphaOmegaRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. 
""" - mode = "folder" NeoRawIOClass = "AlphaOmegaRawIO" - name = "alphaomega" def __init__(self, folder_path, lsx_files=None, stream_id="RAW", stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(folder_path, lsx_files) diff --git a/src/spikeinterface/extractors/neoextractors/axona.py b/src/spikeinterface/extractors/neoextractors/axona.py index 71e1277946..e086cb5dde 100644 --- a/src/spikeinterface/extractors/neoextractors/axona.py +++ b/src/spikeinterface/extractors/neoextractors/axona.py @@ -21,9 +21,7 @@ class AxonaRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "folder" NeoRawIOClass = "AxonaRawIO" - name = "axona" def __init__(self, file_path, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 9f23575dba..15953ff6d7 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -31,10 +31,6 @@ class BiocamRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" - NeoRawIOClass = "BiocamRawIO" - name = "biocam" - def __init__( self, file_path, diff --git a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py index ab3710e05e..9bd2b05f24 100644 --- a/src/spikeinterface/extractors/neoextractors/blackrock.py +++ b/src/spikeinterface/extractors/neoextractors/blackrock.py @@ -31,9 +31,7 @@ class BlackrockRecordingExtractor(NeoBaseRecordingExtractor): """ - mode = "file" NeoRawIOClass = "BlackrockRawIO" - name = "blackrock" def __init__( self, @@ -87,10 +85,8 @@ class BlackrockSortingExtractor(NeoBaseSortingExtractor): Used to extract information about the sampling frequency and t_start from the analog signal if provided. """ - mode = "file" NeoRawIOClass = "BlackrockRawIO" neo_returns_frames = False - name = "blackrock" def __init__( self, diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index e2c79478fa..73a783ec5d 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -27,9 +27,7 @@ class CedRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "CedRawIO" - name = "ced" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/edf.py b/src/spikeinterface/extractors/neoextractors/edf.py index 90627d5772..8369369922 100644 --- a/src/spikeinterface/extractors/neoextractors/edf.py +++ b/src/spikeinterface/extractors/neoextractors/edf.py @@ -26,9 +26,7 @@ class EDFRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. 
""" - mode = "file" NeoRawIOClass = "EDFRawIO" - name = "edf" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = {"filename": str(file_path)} diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 43439b80c9..34c8bf2eb5 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -33,9 +33,7 @@ class IntanRecordingExtractor(NeoBaseRecordingExtractor): """ - mode = "file" NeoRawIOClass = "IntanRawIO" - name = "intan" def __init__( self, diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index a66075b451..6c72696e16 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -39,9 +39,7 @@ class MaxwellRecordingExtractor(NeoBaseRecordingExtractor): If there are several blocks (experiments), specify the block index you want to load """ - mode = "file" NeoRawIOClass = "MaxwellRawIO" - name = "maxwell" def __init__( self, @@ -96,8 +94,6 @@ class MaxwellEventExtractor(BaseEvent): Class for reading TTL events from Maxwell files. """ - name = "maxwell" - def __init__(self, file_path): import h5py diff --git a/src/spikeinterface/extractors/neoextractors/mcsraw.py b/src/spikeinterface/extractors/neoextractors/mcsraw.py index 0cbd9263ba..307a6c1fba 100644 --- a/src/spikeinterface/extractors/neoextractors/mcsraw.py +++ b/src/spikeinterface/extractors/neoextractors/mcsraw.py @@ -30,9 +30,7 @@ class MCSRawRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "RawMCSRawIO" - name = "mcsraw" def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index 76f0b29f54..21a597029b 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -40,9 +40,7 @@ class MEArecRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "MEArecRawIO" - name = "mearec" def __init__(self, file_path: Union[str, Path], all_annotations: bool = False): neo_kwargs = self.map_to_neo_kwargs(file_path) @@ -75,10 +73,8 @@ def map_to_neo_kwargs( class MEArecSortingExtractor(NeoBaseSortingExtractor): - mode = "file" NeoRawIOClass = "MEArecRawIO" neo_returns_frames = False - name = "mearec" def __init__(self, file_path: Union[str, Path]): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/neuralynx.py b/src/spikeinterface/extractors/neoextractors/neuralynx.py index 0670371ba9..98f4a7c2ff 100644 --- a/src/spikeinterface/extractors/neoextractors/neuralynx.py +++ b/src/spikeinterface/extractors/neoextractors/neuralynx.py @@ -40,9 +40,7 @@ class NeuralynxRecordingExtractor(NeoBaseRecordingExtractor): Note that here the default is False contrary to neo. """ - mode = "folder" NeoRawIOClass = "NeuralynxRawIO" - name = "neuralynx" def __init__( self, @@ -90,11 +88,9 @@ class NeuralynxSortingExtractor(NeoBaseSortingExtractor): Used to extract information about the sampling frequency and t_start from the analog signal if provided. 
""" - mode = "folder" NeoRawIOClass = "NeuralynxRawIO" neo_returns_frames = True need_t_start_from_signal_stream = True - name = "neuralynx" def __init__( self, diff --git a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py index 49784418e1..ac569c0df0 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py +++ b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py @@ -47,9 +47,7 @@ class NeuroExplorerRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "NeuroExplorerRawIO" - name = "neuroexplorer" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = {"filename": str(file_path)} diff --git a/src/spikeinterface/extractors/neoextractors/neuroscope.py b/src/spikeinterface/extractors/neoextractors/neuroscope.py index 104f47af24..6c6f1d4bea 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroscope.py +++ b/src/spikeinterface/extractors/neoextractors/neuroscope.py @@ -37,9 +37,7 @@ class NeuroScopeRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "NeuroScopeRawIO" - name = "neuroscope" def __init__(self, file_path, xml_file_path=None, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path, xml_file_path) @@ -103,8 +101,6 @@ class NeuroScopeSortingExtractor(BaseSorting): Path to the .xml file referenced by this sorting. """ - name = "neuroscope" - def __init__( self, folder_path: OptionalPathType = None, diff --git a/src/spikeinterface/extractors/neoextractors/nix.py b/src/spikeinterface/extractors/neoextractors/nix.py index 00e5f8bfc1..b869936fa3 100644 --- a/src/spikeinterface/extractors/neoextractors/nix.py +++ b/src/spikeinterface/extractors/neoextractors/nix.py @@ -27,9 +27,7 @@ class NixRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "NIXRawIO" - name = "nix" def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index f3363b9013..04c25998f0 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -64,9 +64,7 @@ class OpenEphysLegacyRecordingExtractor(NeoBaseRecordingExtractor): neo.OpenEphysRawIO is now handling gaps directly but makes the read slower. 
""" - mode = "folder" NeoRawIOClass = "OpenEphysRawIO" - name = "openephyslegacy" def __init__( self, @@ -138,9 +136,7 @@ class OpenEphysBinaryRecordingExtractor(NeoBaseRecordingExtractor): """ - mode = "folder" NeoRawIOClass = "OpenEphysBinaryRawIO" - name = "openephys" def __init__( self, @@ -287,9 +283,7 @@ class OpenEphysBinaryEventExtractor(NeoBaseEventExtractor): """ - mode = "folder" NeoRawIOClass = "OpenEphysBinaryRawIO" - name = "openephys" def __init__(self, folder_path, block_index=None): neo_kwargs = self.map_to_neo_kwargs(folder_path) diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index cf08778ffa..9c2586dd5a 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -25,9 +25,7 @@ class PlexonRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "PlexonRawIO" - name = "plexon" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) @@ -54,9 +52,7 @@ class PlexonSortingExtractor(NeoBaseSortingExtractor): The file path to load the recordings from. """ - mode = "file" NeoRawIOClass = "PlexonRawIO" - name = "plexon" neo_returns_frames = True def __init__(self, file_path): diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 6c9160f13b..4434d02cc1 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -30,9 +30,7 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "Plexon2RawIO" - name = "plexon2" def __init__(self, file_path, stream_id=None, stream_name=None, use_names_as_ids=True, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) @@ -66,10 +64,8 @@ class Plexon2SortingExtractor(NeoBaseSortingExtractor): The sampling frequency of the sorting (required for multiple streams with different sampling frequencies). """ - mode = "file" NeoRawIOClass = "Plexon2RawIO" neo_returns_frames = True - name = "plexon2" def __init__(self, file_path, sampling_frequency=None): from neo.rawio import Plexon2RawIO @@ -98,9 +94,7 @@ class Plexon2EventExtractor(NeoBaseEventExtractor): """ - mode = "file" NeoRawIOClass = "Plexon2RawIO" - name = "plexon2" def __init__(self, folder_path, block_index=None): neo_kwargs = self.map_to_neo_kwargs(folder_path) diff --git a/src/spikeinterface/extractors/neoextractors/spike2.py b/src/spikeinterface/extractors/neoextractors/spike2.py index 1bd0351553..cbc1db3f74 100644 --- a/src/spikeinterface/extractors/neoextractors/spike2.py +++ b/src/spikeinterface/extractors/neoextractors/spike2.py @@ -26,9 +26,7 @@ class Spike2RecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. 
""" - mode = "file" NeoRawIOClass = "Spike2RawIO" - name = "spike2" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index 3d57817f88..e7c31b8afa 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -28,9 +28,7 @@ class SpikeGadgetsRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "SpikeGadgetsRawIO" - name = "spikegadgets" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index adfd0f702e..10a1f78265 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -40,9 +40,7 @@ class SpikeGLXRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "folder" NeoRawIOClass = "SpikeGLXRawIO" - name = "spikeglx" def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(folder_path, load_sync_channel=load_sync_channel) @@ -110,9 +108,7 @@ class SpikeGLXEventExtractor(NeoBaseEventExtractor): """ - mode = "folder" NeoRawIOClass = "SpikeGLXRawIO" - name = "spikeglx" def __init__(self, folder_path, block_index=None): neo_kwargs = self.map_to_neo_kwargs(folder_path) diff --git a/src/spikeinterface/extractors/neoextractors/tdt.py b/src/spikeinterface/extractors/neoextractors/tdt.py index 27b456102f..a1298dece7 100644 --- a/src/spikeinterface/extractors/neoextractors/tdt.py +++ b/src/spikeinterface/extractors/neoextractors/tdt.py @@ -27,9 +27,7 @@ class TdtRecordingExtractor(NeoBaseRecordingExtractor): If there are several blocks (experiments), specify the block index you want to load """ - mode = "folder" NeoRawIOClass = "TdtRawIO" - name = "tdt" def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(folder_path) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 9786766af1..13fd56e959 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -472,8 +472,6 @@ class NwbRecordingExtractor(BaseRecording): >>> rec = NwbRecordingExtractor(s3_url, stream_mode="fsspec", stream_cache_path="cache") """ - mode = "file" - name = "nwb" installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n" def __init__( @@ -1000,9 +998,7 @@ class NwbSortingExtractor(BaseSorting): The sorting extractor for the NWB file. 
""" - mode = "file" installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n" - name = "nwb" def __init__( self, diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 3287f7422f..737a88c51a 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -26,12 +26,9 @@ class BasePhyKilosortSortingExtractor(BaseSorting): If True, all cluster properties are loaded from the tsv/csv files. """ - installed = False # check at class level if installed or not - mode = "folder" installation_mesg = ( "To use the PhySortingExtractor install pandas: \n\n pip install pandas\n\n" # error message when not installed ) - name = "phykilosort" def __init__( self, @@ -43,14 +40,10 @@ def __init__( ): try: import pandas as pd - - HAVE_PD = True except ImportError: - HAVE_PD = False - assert HAVE_PD, self.installation_mesg + raise ImportError(self.installation_mesg) phy_folder = Path(folder_path) - spike_times = np.load(phy_folder / "spike_times.npy").astype(int) if (phy_folder / "spike_clusters.npy").is_file(): @@ -228,8 +221,6 @@ class PhySortingExtractor(BasePhyKilosortSortingExtractor): The loaded Sorting object. """ - name = "phy" - def __init__( self, folder_path: Path | str, @@ -269,8 +260,6 @@ class KiloSortSortingExtractor(BasePhyKilosortSortingExtractor): The loaded Sorting object. """ - name = "kilosort" - def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove_empty_units: bool = True): BasePhyKilosortSortingExtractor.__init__( self, diff --git a/src/spikeinterface/extractors/shybridextractors.py b/src/spikeinterface/extractors/shybridextractors.py index b53b3b2056..1c5c147c6a 100644 --- a/src/spikeinterface/extractors/shybridextractors.py +++ b/src/spikeinterface/extractors/shybridextractors.py @@ -30,24 +30,19 @@ class SHYBRIDRecordingExtractor(BinaryRecordingExtractor): Loaded data. """ - mode = "folder" installation_mesg = ( "To use the SHYBRID extractors, install SHYBRID and pyyaml: " "\n\n pip install shybrid pyyaml\n\n" ) - name = "shybrid" def __init__(self, file_path): try: import hybridizer.io as sbio import hybridizer.probes as sbprb import yaml - - HAVE_SBEX = True except ImportError: - HAVE_SBEX = False + raise ImportError(self.installation_mesg) # load params file related to the given shybrid recording - assert HAVE_SBEX, self.installation_mesg assert Path(file_path).suffix in [".yml", ".yaml"], "The 'file_path' should be a yaml file!" 
params = sbio.get_params(file_path)["data"] file_path = Path(file_path) @@ -102,12 +97,9 @@ def write_recording(recording, save_path, initial_sorting_fn, dtype="float32", * import hybridizer.io as sbio import hybridizer.probes as sbprb import yaml - - HAVE_SBEX = True except ImportError: - HAVE_SBEX = False + raise ImportError(SHYBRIDRecordingExtractor.installation_mesg) - assert HAVE_SBEX, SHYBRIDRecordingExtractor.installation_mesg assert recording.get_num_segments() == 1, "SHYBRID can only write single segment recordings" save_path = Path(save_path) recording_name = "recording.bin" @@ -159,18 +151,14 @@ class SHYBRIDSortingExtractor(BaseSorting): """ installation_mesg = "To use the SHYBRID extractors, install SHYBRID: \n\n pip install shybrid\n\n" - name = "shybrid" def __init__(self, file_path, sampling_frequency, delimiter=","): try: import hybridizer.io as sbio import hybridizer.probes as sbprb - - HAVE_SBEX = True except ImportError: - HAVE_SBEX = False + raise ImportError(self.installation_mesg) - assert HAVE_SBEX, self.installation_mesg assert Path(file_path).suffix == ".csv", "The 'file_path' should be a csv file!" if Path(file_path).is_file(): @@ -205,12 +193,9 @@ def write_sorting(sorting, save_path): try: import hybridizer.io as sbio import hybridizer.probes as sbprb - - HAVE_SBEX = True except ImportError: - HAVE_SBEX = False + raise ImportError(SHYBRIDSortingExtractor.installation_mesg) - assert HAVE_SBEX, SHYBRIDSortingExtractor.installation_mesg assert sorting.get_num_segments() == 1, "SHYBRID can only write single segment sortings" save_path = Path(save_path) diff --git a/src/spikeinterface/extractors/sinapsrecordingextractors.py b/src/spikeinterface/extractors/sinapsrecordingextractors.py index 522f639760..c3e92a63ff 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractors.py +++ b/src/spikeinterface/extractors/sinapsrecordingextractors.py @@ -23,10 +23,6 @@ class SinapsResearchPlatformRecordingExtractor(ChannelSliceRecording): "filt" extracts the filtered data, "raw" extracts the raw data, and "aux" extracts the auxiliary data. """ - extractor_name = "SinapsResearchPlatform" - mode = "file" - name = "sinaps_research_platform" - def __init__(self, file_path: str | Path, stream_name: str = "filt"): from ..preprocessing import UnsignedToSignedRecording @@ -91,10 +87,6 @@ class SinapsResearchPlatformH5RecordingExtractor(BaseRecording): Path to the SiNAPS .h5 file. """ - extractor_name = "SinapsResearchPlatformH5" - mode = "file" - name = "sinaps_research_platform_h5" - def __init__(self, file_path: str | Path): self._file_path = file_path diff --git a/src/spikeinterface/extractors/spykingcircusextractors.py b/src/spikeinterface/extractors/spykingcircusextractors.py index 7c3fb154fe..b8a1e5635e 100644 --- a/src/spikeinterface/extractors/spykingcircusextractors.py +++ b/src/spikeinterface/extractors/spykingcircusextractors.py @@ -7,13 +7,6 @@ from spikeinterface.core import BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import define_function_from_class -try: - import h5py - - HAVE_H5PY = True -except ImportError: - HAVE_H5PY = False - class SpykingCircusSortingExtractor(BaseSorting): """Load SpykingCircus format data as a recording extractor. @@ -29,13 +22,13 @@ class SpykingCircusSortingExtractor(BaseSorting): Loaded data. 
""" - installed = HAVE_H5PY # check at class level if installed or not - mode = "folder" installation_mesg = "To use the SpykingCircusSortingExtractor install h5py: \n\n pip install h5py\n\n" - name = "spykingcircus" def __init__(self, folder_path): - assert HAVE_H5PY, self.installation_mesg + try: + import h5py + except ImportError: + raise ImportError(self.installation_mesg) spykingcircus_folder = Path(folder_path) listfiles = spykingcircus_folder.iterdir() diff --git a/src/spikeinterface/extractors/tridesclousextractors.py b/src/spikeinterface/extractors/tridesclousextractors.py index 8589f03fd4..ac1ce4727b 100644 --- a/src/spikeinterface/extractors/tridesclousextractors.py +++ b/src/spikeinterface/extractors/tridesclousextractors.py @@ -22,9 +22,7 @@ class TridesclousSortingExtractor(BaseSorting): Loaded data. """ - mode = "folder" installation_mesg = "To use the TridesclousSortingExtractor install tridesclous: \n\n pip install tridesclous\n\n" # error message when not installed - name = "tridesclous" def __init__(self, folder_path, chan_grp=None): try: diff --git a/src/spikeinterface/extractors/waveclussnippetstextractors.py b/src/spikeinterface/extractors/waveclussnippetstextractors.py index 7c26eee7bd..75bae32519 100644 --- a/src/spikeinterface/extractors/waveclussnippetstextractors.py +++ b/src/spikeinterface/extractors/waveclussnippetstextractors.py @@ -10,7 +10,6 @@ class WaveClusSnippetsExtractor(MatlabHelper, BaseSnippets): - name = "waveclus" def __init__(self, file_path): file_path = Path(file_path) if isinstance(file_path, str) else file_path diff --git a/src/spikeinterface/extractors/waveclustextractors.py b/src/spikeinterface/extractors/waveclustextractors.py index 844b1cc7cf..3d024910fa 100644 --- a/src/spikeinterface/extractors/waveclustextractors.py +++ b/src/spikeinterface/extractors/waveclustextractors.py @@ -25,8 +25,6 @@ class WaveClusSortingExtractor(MatlabHelper, BaseSorting): Loaded data. """ - name = "waveclus" - def __init__(self, file_path, keep_good_only=True): MatlabHelper.__init__(self, file_path) diff --git a/src/spikeinterface/extractors/yassextractors.py b/src/spikeinterface/extractors/yassextractors.py index 61a49ccf01..7a76906acc 100644 --- a/src/spikeinterface/extractors/yassextractors.py +++ b/src/spikeinterface/extractors/yassextractors.py @@ -7,13 +7,6 @@ from spikeinterface.core import BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import define_function_from_class -try: - import yaml - - HAVE_YAML = True -except: - HAVE_YAML = False - class YassSortingExtractor(BaseSorting): """Load YASS format data as a sorting extractor. @@ -29,15 +22,13 @@ class YassSortingExtractor(BaseSorting): Loaded data. 
""" - mode = "folder" - installed = HAVE_YAML # check at class level if installed or not - installation_mesg = ( - "To use the Yass extractor, install pyyaml: \n\n pip install pyyaml\n\n" # error message when not installed - ) - name = "yass" + installation_mesg = "To use the Yass extractor, install pyyaml: \n\n pip install pyyaml\n\n" def __init__(self, folder_path): - assert HAVE_YAML, self.installation_mesg + try: + import yaml + except: + raise ImportError(self.installation_mesg) folder_path = Path(folder_path) From 5b663ad012534c85fc0565f1cf32ad5d0d739096 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 5 Jul 2024 10:33:59 +0200 Subject: [PATCH 27/57] Fix biocam --- src/spikeinterface/extractors/neoextractors/biocam.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 15953ff6d7..e7b6199ea9 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -31,6 +31,8 @@ class BiocamRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ + NeoRawIOClass = "BiocamRawIO" + def __init__( self, file_path, From dd99121652f8875d130f008ee10c6d1eb290201e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 5 Jul 2024 14:24:14 +0200 Subject: [PATCH 28/57] Implement apply_merges_to_sorting() --- src/spikeinterface/core/sorting_tools.py | 201 +++++++++++++++++- .../core/tests/test_sorting_tools.py | 85 +++++++- .../curation/mergeunitssorting.py | 33 +-- 3 files changed, 291 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 02f4529a98..9ee8ecb528 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -1,7 +1,10 @@ from __future__ import annotations -from .basesorting import BaseSorting + import numpy as np +from .basesorting import BaseSorting +from .numpyextractors import NumpySorting + def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.array) -> dict[dict[str, np.array]]: """ @@ -220,3 +223,199 @@ def random_spikes_selection( raise ValueError(f"random_spikes_selection(): method must be 'all' or 'uniform'") return random_spikes_indices + + + +def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy='append'): + """ + Function to apply a resolved representation of the merges to a sorting object. + + This function is not lazy and create a new NumpySorting with a compact spike_vector as fast as possible. + + If censor_ms is not None, duplicated spikes violating the censor_ms refractory period are removed. + + Optionaly, the boolean of kept spikes is returned + + Parameters + ---------- + sorting : Sorting + The Sorting object to apply merges + units_to_merge : list/tuple of lists/tuples + A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), + but it can also have more (merge multiple units at once). + new_unit_ids : None or list + A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. 
If None, + merged units will have the first unit_id of every lists of merges + censor_ms: None or float + When applying the merges, should be discard consecutive spikes violating a given refractory per + return_kept : bool, default False + return also a booolean of kept spikes + new_id_strategy : "append" | "take_first", default "append" + The strategy that should be used, if new_unit_ids is None, to create new unit_ids. + "append" : new_units_ids will be added at the end of max(sorging.unit_ids) + "take_first" : new_unit_ids will be the first unit_id of every list of merges + + Returns + ------- + sorting : The new Sorting object + The newly create sorting with the merged units + keep_mask : numpy.array + A boolean mask, if censor_ms is not None, telling which spike from the original spike vector + has been kept, given the refractory period violations (None if censor_ms is None) + """ + + spikes = sorting.to_spike_vector().copy() + keep_mask = np.ones(len(spikes), dtype=bool) + + new_unit_ids = generate_unit_ids_for_merge_group(sorting.unit_ids, units_to_merge, + new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy) + + rename_ids = {} + for i, merge_group in enumerate(units_to_merge): + for unit_id in merge_group: + rename_ids[unit_id] = new_unit_ids[i] + + all_unit_ids = _get_ids_after_merging(sorting.unit_ids, units_to_merge, new_unit_ids) + all_unit_ids = list(all_unit_ids) + + num_seg = sorting.get_num_segments() + segment_limits = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2)) + segment_slices = [] + for i in range(num_seg): + segment_slices += [(segment_limits[i], segment_limits[i+1])] + + # using this function vaoid to use the mask approach and simplify a lot the algo + spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices] + spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True) + + for old_unit_id in sorting.unit_ids: + if old_unit_id in rename_ids.keys(): + new_unit_id = rename_ids[old_unit_id] + else: + new_unit_id = old_unit_id + + new_unit_index = all_unit_ids.index(new_unit_id) + for segment_index in range(num_seg): + spike_inds = spike_indices[segment_index][old_unit_id] + spikes["unit_index"][spike_inds] = new_unit_index + + if censor_ms is not None: + rpv = int(sorting.sampling_frequency * censor_ms / 1000.0) + for group_old_ids in units_to_merge: + for segment_index in range(num_seg): + group_indices = [] + for unit_id in group_old_ids: + group_indices.append(spike_indices[segment_index][unit_id]) + group_indices = np.concatenate(group_indices) + group_indices = np.sort(group_indices) + inds = np.flatnonzero(np.diff(spikes["sample_index"][group_indices]) < rpv ) + keep_mask[group_indices[inds + 1]] = False + + spikes = spikes[keep_mask] + sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids) + + if return_kept: + return sorting, keep_mask + else: + return sorting + + +def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids): + """ + Function to get the list of unique unit_ids after some merges, with given new_units_ids would + be provided. + + Every new unit_id will be added at the end if not already present. + + Parameters + ---------- + old_unit_ids : np.array + The old unit_ids. + units_to_merge : list/tuple of lists/tuples + A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), + but it can also have more (merge multiple units at once). + new_unit_ids : None or list + A new unit_ids for merged units. 
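For readers following the censor_ms branch implemented above: per merge group and per segment it simply drops any spike that falls within rpv samples of the previous spike of that group. A tiny standalone sketch of that step, on made-up sample indices:

    import numpy as np

    sample_index = np.array([100, 102, 500, 501, 900])  # merged, sorted spike times in samples (synthetic)
    censor_samples = 5  # plays the role of rpv = int(sampling_frequency * censor_ms / 1000)
    keep = np.ones(sample_index.size, dtype=bool)
    inds = np.flatnonzero(np.diff(sample_index) < censor_samples)
    keep[inds + 1] = False  # discard the second spike of each too-close pair
    print(sample_index[keep])  # [100 500 900]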
If given, it needs to have the same length as `units_to_merge`. + + Returns + ------- + + all_unit_ids : The unit ids in the merged sorting + The units_ids that will be present after merges + + """ + old_unit_ids = np.asarray(old_unit_ids) + + assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge" + + all_unit_ids = list(old_unit_ids.copy()) + for new_unit_id, group_ids in zip(new_unit_ids, units_to_merge): + assert len(group_ids) > 1, "A merge should have at least two units" + for unit_id in group_ids: + assert unit_id in old_unit_ids, "Merged ids should be in the sorting" + for unit_id in group_ids: + if unit_id != new_unit_id: + # new_unit_id can be inside group_ids + all_unit_ids.remove(unit_id) + if new_unit_id not in all_unit_ids: + all_unit_ids.append(new_unit_id) + return np.array(all_unit_ids) + + + +def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy='append'): + """ + Function to generate new units ids during a merging procedure. If new_units_ids + are provided, it will return these unit ids, checking that they have the length as + to_be_merged. + + Parameters + ---------- + old_unit_ids : np.array + The old unit_ids + units_to_merge : list/tuple of lists/tuples + A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), + but it can also have more (merge multiple units at once). + new_unit_ids : None or list + A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. If None, + merged units will have the first unit_id of every lists of merges + new_id_strategy : "append" | "take_first", default "append" + The strategy that should be used, if new_unit_ids is None, to create new unit_ids. 
+ "append" : new_units_ids will be added at the end of max(sorging.unit_ids) + "take_first" : new_unit_ids will be the first unit_id of every list of merges + + Returns + ------- + new_unit_ids : The new unit ids + The new units_ids associated with the merges + + + """ + old_unit_ids = np.asarray(old_unit_ids) + + + if new_unit_ids is not None: + assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge" + else: + dtype = old_unit_ids.dtype + num_merge = len(units_to_merge) + # select new_unit_ids greater that the max id, event greater than the numerical str ids + if new_id_strategy == "take_first": + new_unit_ids = [to_be_merged[0] for to_be_merged in units_to_merge] + elif new_id_strategy == "append": + if np.issubdtype(dtype, np.character): + # dtype str + if all(p.isdigit() for p in old_unit_ids): + # All str are digit : we can generate a max + m = max(int(p) for p in old_unit_ids) + 1 + new_unit_ids = [str(m + i) for i in range(num_merge)] + else: + # we cannot automatically find new names + new_unit_ids = [f"merge{i}" for i in range(num_merge)] + else: + # dtype int + new_unit_ids = list(max(old_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) + else: + raise ValueError("wrong new_id_strategy") + + return new_unit_ids \ No newline at end of file diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 1aefeeb062..24739fb374 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -9,6 +9,9 @@ spike_vector_to_spike_trains, random_spikes_selection, spike_vector_to_indices, + apply_merges_to_sorting, + _get_ids_after_merging, + generate_unit_ids_for_merge_group ) @@ -74,8 +77,88 @@ def test_random_spikes_selection(): random_spikes_indices = random_spikes_selection(sorting, num_samples, method="all") assert random_spikes_indices.size == spikes.size +def test_apply_merges_to_sorting(): + + times = np.array([0, 0, 10, 20, 300]) + labels = np.array(['a', 'b', 'c', 'a', 'b' ]) + + # unit_ids str + sorting1 = NumpySorting.from_times_labels( + [times, times], [labels, labels], 10_000., unit_ids=['a', 'b', 'c'] + ) + spikes1 = sorting1.to_spike_vector() + + sorting2 = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=None) + spikes2 = sorting2.to_spike_vector() + assert sorting2.unit_ids.size == 2 + assert sorting1.to_spike_vector().size == sorting1.to_spike_vector().size + assert np.array_equal(['c', 'merge0'], sorting2.unit_ids) + assert np.array_equal( + spikes1[spikes1['unit_index'] == 2]['sample_index'], + spikes2[spikes2['unit_index'] == 0]['sample_index'] + ) + + + sorting3, keep_mask = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=1.5, return_kept=True) + spikes3 = sorting3.to_spike_vector() + assert spikes3.size < spikes1.size + assert not keep_mask[1] + st = sorting3.get_unit_spike_train(segment_index=0, unit_id='merge0') + assert st.size == 3 # one spike is removed by censor period + + + # unit_ids int + sorting1 = NumpySorting.from_times_labels( + [times, times], [labels, labels], 10_000., unit_ids=[10, 20, 30] + ) + spikes1 = sorting1.to_spike_vector() + sorting2 = apply_merges_to_sorting(sorting1, [[10, 20]], censor_ms=None) + assert np.array_equal(sorting2.unit_ids, [30, 31]) + + sorting1 = NumpySorting.from_times_labels( + [times, times], [labels, labels], 10_000., unit_ids=['a', 'b', 'c'] + ) + sorting2 = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=None, 
new_id_strategy="take_first") + assert np.array_equal(sorting2.unit_ids, ['a', 'c']) + + + +def test_get_ids_after_merging(): + + all_unit_ids = _get_ids_after_merging(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], ['x', 'd']) + assert np.array_equal(all_unit_ids, ['c', 'd', 'x']) + # print(all_unit_ids) + + all_unit_ids = _get_ids_after_merging([0, 5, 12, 9, 15], [[0, 5], [9, 15]], [28, 9]) + assert np.array_equal(all_unit_ids, [12, 9, 28]) + # print(all_unit_ids) + + +def test_generate_unit_ids_for_merge_group(): + + new_unit_ids = generate_unit_ids_for_merge_group(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], new_id_strategy='append') + assert np.array_equal(new_unit_ids, ['merge0', 'merge1']) + + new_unit_ids = generate_unit_ids_for_merge_group(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], new_id_strategy='take_first') + assert np.array_equal(new_unit_ids, ['a', 'd']) + + new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy='append') + assert np.array_equal(new_unit_ids, [16, 17]) + + new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy='take_first') + assert np.array_equal(new_unit_ids, [0, 9]) + + new_unit_ids = generate_unit_ids_for_merge_group(["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy='append') + assert np.array_equal(new_unit_ids, ["16", "17"]) + + new_unit_ids = generate_unit_ids_for_merge_group(["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy='take_first') + assert np.array_equal(new_unit_ids, ["0", "9"]) if __name__ == "__main__": # test_spike_vector_to_spike_trains() # test_spike_vector_to_indices() - test_random_spikes_selection() + # test_random_spikes_selection() + + test_apply_merges_to_sorting() + test_get_ids_after_merging() + test_generate_unit_ids_for_merge_group() diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index bbdb70b2f6..c182d4130a 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -4,7 +4,7 @@ from spikeinterface.core.basesorting import BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import define_function_from_class from copy import deepcopy - +from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group class MergeUnitsSorting(BaseSorting): """ @@ -44,35 +44,16 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy parents_unit_ids = sorting.unit_ids sampling_frequency = sorting.get_sampling_frequency() + from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group + new_unit_ids = generate_unit_ids_for_merge_group(sorting.unit_ids, units_to_merge, + new_unit_ids=new_unit_ids, + new_id_strategy='append') + all_removed_ids = [] for ids in units_to_merge: all_removed_ids.extend(ids) keep_unit_ids = [u for u in parents_unit_ids if u not in all_removed_ids] - if new_unit_ids is None: - dtype = parents_unit_ids.dtype - # select new_unit_ids greater that the max id, event greater than the numerical str ids - if np.issubdtype(dtype, np.character): - # dtype str - if all(p.isdigit() for p in parents_unit_ids): - # All str are digit : we can generate a max - m = max(int(p) for p in parents_unit_ids) + 1 - new_unit_ids = [str(m + i) for i in range(num_merge)] - else: - # we cannot automatically find new names - new_unit_ids = [f"merge{i}" for i in range(num_merge)] - if 
np.any(np.isin(new_unit_ids, keep_unit_ids)): - raise ValueError( - "Unable to find 'new_unit_ids' because it is a string and parents " - "already contain merges. Pass a list of 'new_unit_ids' as an argument." - ) - else: - # dtype int - new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) - else: - if np.any(np.isin(new_unit_ids, keep_unit_ids)): - raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones") - assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge" # some checks @@ -81,7 +62,7 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy assert properties_policy in ("keep", "remove"), "properties_policy must be " "keep" " or " "remove" "" # new units are put at the end - unit_ids = keep_unit_ids + new_unit_ids + unit_ids = keep_unit_ids + list(new_unit_ids) BaseSorting.__init__(self, sampling_frequency, unit_ids) # assert all(np.isin(keep_unit_ids, self.unit_ids)), 'new_unit_id should have a compatible format with the parent ids' From 8783b7fbba6d6029fbac09357a85ccbf27ea0ca9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Jul 2024 12:28:37 +0000 Subject: [PATCH 29/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sorting_tools.py | 46 ++++++------ .../core/tests/test_sorting_tools.py | 70 +++++++++---------- .../curation/mergeunitssorting.py | 8 ++- 3 files changed, 63 insertions(+), 61 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 9ee8ecb528..8d038aa45b 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -225,17 +225,18 @@ def random_spikes_selection( return random_spikes_indices - -def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy='append'): +def apply_merges_to_sorting( + sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append" +): """ Function to apply a resolved representation of the merges to a sorting object. This function is not lazy and create a new NumpySorting with a compact spike_vector as fast as possible. - + If censor_ms is not None, duplicated spikes violating the censor_ms refractory period are removed. Optionaly, the boolean of kept spikes is returned - + Parameters ---------- sorting : Sorting @@ -251,7 +252,7 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m return_kept : bool, default False return also a booolean of kept spikes new_id_strategy : "append" | "take_first", default "append" - The strategy that should be used, if new_unit_ids is None, to create new unit_ids. + The strategy that should be used, if new_unit_ids is None, to create new unit_ids. 
"append" : new_units_ids will be added at the end of max(sorging.unit_ids) "take_first" : new_unit_ids will be the first unit_id of every list of merges @@ -267,14 +268,15 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m spikes = sorting.to_spike_vector().copy() keep_mask = np.ones(len(spikes), dtype=bool) - new_unit_ids = generate_unit_ids_for_merge_group(sorting.unit_ids, units_to_merge, - new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy) + new_unit_ids = generate_unit_ids_for_merge_group( + sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy + ) rename_ids = {} for i, merge_group in enumerate(units_to_merge): for unit_id in merge_group: rename_ids[unit_id] = new_unit_ids[i] - + all_unit_ids = _get_ids_after_merging(sorting.unit_ids, units_to_merge, new_unit_ids) all_unit_ids = list(all_unit_ids) @@ -282,23 +284,23 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m segment_limits = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2)) segment_slices = [] for i in range(num_seg): - segment_slices += [(segment_limits[i], segment_limits[i+1])] + segment_slices += [(segment_limits[i], segment_limits[i + 1])] # using this function vaoid to use the mask approach and simplify a lot the algo spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices] spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True) - + for old_unit_id in sorting.unit_ids: if old_unit_id in rename_ids.keys(): new_unit_id = rename_ids[old_unit_id] else: new_unit_id = old_unit_id - + new_unit_index = all_unit_ids.index(new_unit_id) for segment_index in range(num_seg): spike_inds = spike_indices[segment_index][old_unit_id] spikes["unit_index"][spike_inds] = new_unit_index - + if censor_ms is not None: rpv = int(sorting.sampling_frequency * censor_ms / 1000.0) for group_old_ids in units_to_merge: @@ -308,7 +310,7 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m group_indices.append(spike_indices[segment_index][unit_id]) group_indices = np.concatenate(group_indices) group_indices = np.sort(group_indices) - inds = np.flatnonzero(np.diff(spikes["sample_index"][group_indices]) < rpv ) + inds = np.flatnonzero(np.diff(spikes["sample_index"][group_indices]) < rpv) keep_mask[group_indices[inds + 1]] = False spikes = spikes[keep_mask] @@ -326,7 +328,7 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids): be provided. Every new unit_id will be added at the end if not already present. - + Parameters ---------- old_unit_ids : np.array @@ -341,7 +343,7 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids): ------- all_unit_ids : The unit ids in the merged sorting - The units_ids that will be present after merges + The units_ids that will be present after merges """ old_unit_ids = np.asarray(old_unit_ids) @@ -362,8 +364,7 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids): return np.array(all_unit_ids) - -def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy='append'): +def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy="append"): """ Function to generate new units ids during a merging procedure. 
If new_units_ids are provided, it will return these unit ids, checking that they have the length as @@ -380,20 +381,19 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. If None, merged units will have the first unit_id of every lists of merges new_id_strategy : "append" | "take_first", default "append" - The strategy that should be used, if new_unit_ids is None, to create new unit_ids. + The strategy that should be used, if new_unit_ids is None, to create new unit_ids. "append" : new_units_ids will be added at the end of max(sorging.unit_ids) "take_first" : new_unit_ids will be the first unit_id of every list of merges - + Returns ------- new_unit_ids : The new unit ids - The new units_ids associated with the merges + The new units_ids associated with the merges + - """ old_unit_ids = np.asarray(old_unit_ids) - if new_unit_ids is not None: assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge" else: @@ -418,4 +418,4 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids else: raise ValueError("wrong new_id_strategy") - return new_unit_ids \ No newline at end of file + return new_unit_ids diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 24739fb374..38baf62c35 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -11,7 +11,7 @@ spike_vector_to_indices, apply_merges_to_sorting, _get_ids_after_merging, - generate_unit_ids_for_merge_group + generate_unit_ids_for_merge_group, ) @@ -77,56 +77,47 @@ def test_random_spikes_selection(): random_spikes_indices = random_spikes_selection(sorting, num_samples, method="all") assert random_spikes_indices.size == spikes.size + def test_apply_merges_to_sorting(): times = np.array([0, 0, 10, 20, 300]) - labels = np.array(['a', 'b', 'c', 'a', 'b' ]) + labels = np.array(["a", "b", "c", "a", "b"]) # unit_ids str - sorting1 = NumpySorting.from_times_labels( - [times, times], [labels, labels], 10_000., unit_ids=['a', 'b', 'c'] - ) + sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"]) spikes1 = sorting1.to_spike_vector() - sorting2 = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=None) + sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None) spikes2 = sorting2.to_spike_vector() assert sorting2.unit_ids.size == 2 assert sorting1.to_spike_vector().size == sorting1.to_spike_vector().size - assert np.array_equal(['c', 'merge0'], sorting2.unit_ids) + assert np.array_equal(["c", "merge0"], sorting2.unit_ids) assert np.array_equal( - spikes1[spikes1['unit_index'] == 2]['sample_index'], - spikes2[spikes2['unit_index'] == 0]['sample_index'] + spikes1[spikes1["unit_index"] == 2]["sample_index"], spikes2[spikes2["unit_index"] == 0]["sample_index"] ) - - sorting3, keep_mask = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=1.5, return_kept=True) + sorting3, keep_mask = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_kept=True) spikes3 = sorting3.to_spike_vector() assert spikes3.size < spikes1.size assert not keep_mask[1] - st = sorting3.get_unit_spike_train(segment_index=0, unit_id='merge0') - assert st.size == 3 # one spike is removed by censor period - + st = 
sorting3.get_unit_spike_train(segment_index=0, unit_id="merge0") + assert st.size == 3 # one spike is removed by censor period # unit_ids int - sorting1 = NumpySorting.from_times_labels( - [times, times], [labels, labels], 10_000., unit_ids=[10, 20, 30] - ) + sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=[10, 20, 30]) spikes1 = sorting1.to_spike_vector() sorting2 = apply_merges_to_sorting(sorting1, [[10, 20]], censor_ms=None) assert np.array_equal(sorting2.unit_ids, [30, 31]) - sorting1 = NumpySorting.from_times_labels( - [times, times], [labels, labels], 10_000., unit_ids=['a', 'b', 'c'] - ) - sorting2 = apply_merges_to_sorting(sorting1, [['a', 'b']], censor_ms=None, new_id_strategy="take_first") - assert np.array_equal(sorting2.unit_ids, ['a', 'c']) - + sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"]) + sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None, new_id_strategy="take_first") + assert np.array_equal(sorting2.unit_ids, ["a", "c"]) def test_get_ids_after_merging(): - all_unit_ids = _get_ids_after_merging(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], ['x', 'd']) - assert np.array_equal(all_unit_ids, ['c', 'd', 'x']) + all_unit_ids = _get_ids_after_merging(["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], ["x", "d"]) + assert np.array_equal(all_unit_ids, ["c", "d", "x"]) # print(all_unit_ids) all_unit_ids = _get_ids_after_merging([0, 5, 12, 9, 15], [[0, 5], [9, 15]], [28, 9]) @@ -136,24 +127,33 @@ def test_get_ids_after_merging(): def test_generate_unit_ids_for_merge_group(): - new_unit_ids = generate_unit_ids_for_merge_group(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], new_id_strategy='append') - assert np.array_equal(new_unit_ids, ['merge0', 'merge1']) + new_unit_ids = generate_unit_ids_for_merge_group( + ["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="append" + ) + assert np.array_equal(new_unit_ids, ["merge0", "merge1"]) - new_unit_ids = generate_unit_ids_for_merge_group(['a', 'b', 'c', 'd', 'e'], [['a', 'b'], ['d', 'e']], new_id_strategy='take_first') - assert np.array_equal(new_unit_ids, ['a', 'd']) + new_unit_ids = generate_unit_ids_for_merge_group( + ["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="take_first" + ) + assert np.array_equal(new_unit_ids, ["a", "d"]) - new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy='append') + new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="append") assert np.array_equal(new_unit_ids, [16, 17]) - - new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy='take_first') + + new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="take_first") assert np.array_equal(new_unit_ids, [0, 9]) - new_unit_ids = generate_unit_ids_for_merge_group(["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy='append') + new_unit_ids = generate_unit_ids_for_merge_group( + ["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="append" + ) assert np.array_equal(new_unit_ids, ["16", "17"]) - - new_unit_ids = generate_unit_ids_for_merge_group(["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy='take_first') + + new_unit_ids = generate_unit_ids_for_merge_group( + ["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], 
new_id_strategy="take_first" + ) assert np.array_equal(new_unit_ids, ["0", "9"]) + if __name__ == "__main__": # test_spike_vector_to_spike_trains() # test_spike_vector_to_indices() diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index c182d4130a..3771b1c63c 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -6,6 +6,7 @@ from copy import deepcopy from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group + class MergeUnitsSorting(BaseSorting): """ Class that handles several merges of units from a Sorting object based on a list of lists of unit_ids. @@ -45,9 +46,10 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy sampling_frequency = sorting.get_sampling_frequency() from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group - new_unit_ids = generate_unit_ids_for_merge_group(sorting.unit_ids, units_to_merge, - new_unit_ids=new_unit_ids, - new_id_strategy='append') + + new_unit_ids = generate_unit_ids_for_merge_group( + sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy="append" + ) all_removed_ids = [] for ids in units_to_merge: From 9703af1c94174d8f0159ed97b91b2c2d9fcd970a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 5 Jul 2024 14:30:23 +0200 Subject: [PATCH 30/57] Regularize whitening (#2744) Regularize whitening --- .../preprocessing/tests/test_whiten.py | 12 +++-- src/spikeinterface/preprocessing/whiten.py | 48 ++++++++++++++++--- .../sorters/internal/spyking_circus2.py | 2 +- 3 files changed, 51 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index c3d1544869..04b731de4f 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -8,13 +8,13 @@ def test_whiten(create_cache_folder): cache_folder = create_cache_folder - rec = generate_recording(num_channels=4) + rec = generate_recording(num_channels=4, seed=2205) print(rec.get_channel_locations()) random_chunk_kwargs = {} - W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) - print(W) - print(M) + W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) + # print(W) + # print(M) with pytest.raises(AssertionError): W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None) @@ -41,6 +41,10 @@ def test_whiten(create_cache_folder): assert rec4.get_dtype() == "int16" assert rec4._kwargs["M"] is None + # test regularization : norm should be smaller + W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True) + assert np.linalg.norm(W1) > np.linalg.norm(W2) + if __name__ == "__main__": test_whiten() diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 874d4304e3..96cf5e028f 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -7,6 +7,7 @@ from ..core import get_random_data_chunks, get_channel_distances from .filter import fix_dtype +from ..core.globals import get_global_job_kwargs class WhitenRecording(BasePreprocessor): @@ -40,6 +41,12 @@ class WhitenRecording(BasePreprocessor): M : 1d np.array or None, default: None Pre-computed means. 
M can be None when previously computed with apply_mean=False + regularize : bool, default: False + Boolean to decide if we want to regularize the covariance matrix, using a chosen method + of sklearn, specified in regularize_kwargs. Default is GraphicalLassoCV + regularize_kwargs : {'method' : 'GraphicalLassoCV'} + Dictionary of the parameters that could be provided to the method of sklearn, if + the covariance matrix needs to be regularized. **random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function Returns @@ -55,6 +62,8 @@ def __init__( recording, dtype=None, apply_mean=False, + regularize=False, + regularize_kwargs=None, mode="global", radius_um=100.0, int_scale=None, @@ -75,7 +84,14 @@ def __init__( M = np.asarray(M) else: W, M = compute_whitening_matrix( - recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=eps + recording, + mode, + random_chunk_kwargs, + apply_mean, + radius_um=radius_um, + eps=eps, + regularize=regularize, + regularize_kwargs=regularize_kwargs, ) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -90,6 +106,8 @@ def __init__( mode=mode, radius_um=radius_um, apply_mean=apply_mean, + regularize=regularize, + regularize_kwargs=regularize_kwargs, int_scale=float(int_scale) if int_scale is not None else None, M=M.tolist() if M is not None else None, W=W.tolist(), @@ -129,7 +147,9 @@ def get_traces(self, start_frame, end_frame, channel_indices): whiten = define_function_from_class(source_class=WhitenRecording, name="whiten") -def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None): +def compute_whitening_matrix( + recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None, regularize=False, regularize_kwargs=None +): """ Compute whitening matrix @@ -152,7 +172,12 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r eps : float or None, default: None Small epsilon to regularize SVD. If None, the default is set to 1e-8, but if the data is float type and scaled down to very small values, eps is automatically set to a small fraction (1e-3) of the median of the squared data. - + regularize : bool, default: False + Boolean to decide if we want to regularize the covariance matrix, using a chosen method + of sklearn, specified in regularize_kwargs. Default is GraphicalLassoCV + regularize_kwargs : {'method' : 'GraphicalLassoCV'} + Dictionary of the parameters that could be provided to the method of sklearn, if + the covariance matrix needs to be regularized. 
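For context on the regularize branch above: the estimator is looked up by name in sklearn.covariance and its fitted covariance_ replaces the plain empirical covariance. A standalone sketch on synthetic, already-centered data (shapes and seed are arbitrary):

    import numpy as np
    import sklearn.covariance

    rng = np.random.default_rng(0)
    data = rng.normal(size=(1000, 8)).astype("float32")  # (num_samples, num_channels), zero mean

    cov_plain = data.T @ data / data.shape[0]

    estimator = sklearn.covariance.GraphicalLassoCV(assume_centered=True)
    estimator.fit(data)
    cov_regularized = estimator.covariance_

    # the graphical-lasso estimate tends to shrink noisy off-diagonal terms, which in turn
    # gives a smaller-norm whitening matrix (this is what the updated test_whiten asserts)
    print(np.linalg.norm(cov_plain), np.linalg.norm(cov_regularized))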
Returns ------- W : 2D array @@ -162,7 +187,8 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r """ random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) - random_data = random_data.astype("float32") + + regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"} if apply_mean: M = np.mean(random_data, axis=0) @@ -172,8 +198,18 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r M = None data = random_data - cov = data.T @ data - cov = cov / data.shape[0] + if not regularize: + cov = data.T @ data + cov = cov / data.shape[0] + else: + import sklearn.covariance + + method = regularize_kwargs.pop("method") + regularize_kwargs["assume_centered"] = True + estimator_class = getattr(sklearn.covariance, method) + estimator = estimator_class(**regularize_kwargs) + estimator.fit(data) + cov = estimator.covariance_ # Here we determine eps used below to avoid division by zero. # Typically we can assume that data is either unscaled integers or in units of diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 45cc93d0b6..be75877f02 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -147,7 +147,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We need to whiten before the template matching step, to boost the results # TODO add , regularize=True chen ready - recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32") + recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True) noise_levels = get_noise_levels(recording_w, return_scaled=False) From f8a43318bdac09535535cf600c6231707f4dd0a3 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 5 Jul 2024 14:30:54 +0200 Subject: [PATCH 31/57] Benchmarks components: plotting utils (#2959) Benchmarks components: plotting utils --- .../comparison/groundtruthstudy.py | 20 +++- .../benchmark/benchmark_clustering.py | 98 +++++++++++-------- .../benchmark/benchmark_matching.py | 38 ++++--- .../benchmark/benchmark_peak_detection.py | 2 + .../benchmark/benchmark_tools.py | 4 +- src/spikeinterface/sortingcomponents/tools.py | 11 +++ src/spikeinterface/widgets/gtstudy.py | 14 +-- 7 files changed, 120 insertions(+), 67 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 6682252349..ba7268b4f0 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -9,9 +9,6 @@ import numpy as np from spikeinterface.core import load_extractor, create_sorting_analyzer, load_sorting_analyzer -from spikeinterface.core.core_tools import SIJsonEncoder -from spikeinterface.core.job_tools import split_job_kwargs - from spikeinterface.sorters import run_sorter_jobs, read_sorter_folder from spikeinterface.qualitymetrics import compute_quality_metrics @@ -54,6 +51,7 @@ def __init__(self, study_folder): self.cases = {} self.sortings = {} self.comparisons = {} + self.colors = None self.scan_folder() @@ -175,6 +173,22 @@ def remove_sorting(self, key): if f.exists(): f.unlink() + def set_colors(self, colors=None, map_name="tab20"): + from spikeinterface.widgets import get_some_colors + + if colors is None: + case_keys = list(self.cases.keys()) 
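A quick sketch of the color mapping that the new set_colors/get_colors pair (just below) builds for the study; get_some_colors returns a dict keyed by case, and the case keys here are hypothetical:

    from spikeinterface.widgets import get_some_colors

    case_keys = [("sorterA", "probe1"), ("sorterB", "probe1")]  # made-up study cases
    colors = get_some_colors(case_keys, map_name="tab20", color_engine="matplotlib", shuffle=False, margin=0)
    print(colors[("sorterA", "probe1")])  # a matplotlib-compatible color for that case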
+ self.colors = get_some_colors( + case_keys, map_name=map_name, color_engine="matplotlib", shuffle=False, margin=0 + ) + else: + self.colors = colors + + def get_colors(self): + if self.colors is None: + self.set_colors() + return self.colors + def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False): if case_keys is None: case_keys = self.cases.keys() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 2da950ceda..92fcda35d9 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -185,11 +185,11 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)): case_keys = list(self.cases.keys()) import pylab as plt - fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) + fig, axes = plt.subplots(ncols=1, nrows=3, figsize=figsize) for count, k in enumerate(("accuracy", "recall", "precision")): - ax = axs[count] + ax = axes[count] for key in case_keys: label = self.cases[key]["label"] @@ -211,7 +211,7 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): case_keys = list(self.cases.keys()) import pylab as plt - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) for count, key in enumerate(case_keys): @@ -234,21 +234,25 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): else: distances = sklearn.metrics.pairwise_distances(a, b, metric) - im = axs[0, count].imshow(distances, aspect="auto") - axs[0, count].set_title(metric) - fig.colorbar(im, ax=axs[0, count]) + im = axes[0, count].imshow(distances, aspect="auto") + axes[0, count].set_title(metric) + fig.colorbar(im, ax=axes[0, count]) label = self.cases[key]["label"] - axs[0, count].set_title(label) + axes[0, count].set_title(label) return fig - def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)): + def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5), axes=None): if case_keys is None: case_keys = list(self.cases.keys()) import pylab as plt - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + if axes is None: + fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + axes = axes.flatten() + else: + fig = None for count, key in enumerate(case_keys): @@ -287,13 +291,13 @@ def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5 elif metric == "agreement": for found, real in zip(matched_ids2[mask], unit_ids1[mask]): to_plot += [scores.at[real, found]] - axs[0, count].plot(snr_matched, to_plot, ".", label="matched") - axs[0, count].plot(snr_missed, np.zeros(len(snr_missed)), ".", c="r", label="missed") - axs[0, count].set_xlabel("snr") - axs[0, count].set_ylabel(metric) + axes[count].plot(snr_matched, to_plot, ".", label="matched") + axes[count].plot(snr_missed, np.zeros(len(snr_missed)), ".", c="r", label="missed") + axes[count].set_xlabel("snr") + axes[count].set_ylabel(metric) label = self.cases[key]["label"] - axs[0, count].set_title(label) - axs[0, count].legend() + axes[count].set_title(label) + axes[count].legend() return fig @@ -303,7 +307,7 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs case_keys = 
list(self.cases.keys()) import pylab as plt - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) for count, key in enumerate(case_keys): @@ -348,47 +352,61 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs elif metric == "agreement": for found, real in zip(matched_ids2[mask], unit_ids1[mask]): to_plot += [scores.at[real, found]] - axs[0, count].scatter(depth_matched, snr_matched, c=to_plot, label="matched") - axs[0, count].scatter(depth_missed, snr_missed, c=np.zeros(len(snr_missed)), label="missed") - axs[0, count].set_xlabel("depth") - axs[0, count].set_ylabel("snr") + elif metric in ["recall", "precision", "accuracy"]: + to_plot = result["gt_comparison"].get_performance()[metric].values + depth_matched = depth + snr_matched = metrics["snr"] + + im = axes[0, count].scatter(depth_matched, snr_matched, c=to_plot, label="matched") + im.set_clim(0, 1) + axes[0, count].scatter(depth_missed, snr_missed, c=np.zeros(len(snr_missed)), label="missed") + axes[0, count].set_xlabel("depth") + axes[0, count].set_ylabel("snr") label = self.cases[key]["label"] - axs[0, count].set_title(label) + axes[0, count].set_title(label) + if count > 0: + axes[0, count].set_ylabel("") + axes[0, count].set_yticks([], []) # axs[0, count].legend() + fig.subplots_adjust(right=0.85) + cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75]) + fig.colorbar(im, cax=cbar_ax, label=metric) + return fig - def plot_unit_losses(self, case_before, case_after, metric="agreement", figsize=None): - import pylab as plt + def plot_unit_losses(self, cases_before, cases_after, metric="agreement", figsize=None): - fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) + fig, axs = plt.subplots(ncols=len(cases_before), nrows=1, figsize=figsize) - for count, k in enumerate(("accuracy", "recall", "precision")): + for count, (case_before, case_after) in enumerate(zip(cases_before, cases_after)): ax = axs[count] - - # label = self.cases[case_after]["label"] - - # positions = self.get_result(case_before)["gt_comparison"].sorting1.get_property("gt_unit_locations") - dataset_key = self.cases[case_before]["dataset"] - rec, gt_sorting1 = self.datasets[dataset_key] + _, gt_sorting1 = self.datasets[dataset_key] positions = gt_sorting1.get_property("gt_unit_locations") analyzer = self.get_sorting_analyzer(case_before) metrics_before = analyzer.get_extension("quality_metrics").get_data() x = metrics_before["snr"].values - y_before = self.get_result(case_before)["gt_comparison"].get_performance()[k].values - y_after = self.get_result(case_after)["gt_comparison"].get_performance()[k].values - if count < 2: - ax.set_xticks([], []) - elif count == 2: - ax.set_xlabel("depth (um)") - im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), marker=".", s=50, cmap="copper") - fig.colorbar(im, ax=ax) - ax.set_title(k) + y_before = self.get_result(case_before)["gt_comparison"].get_performance()[metric].values + y_after = self.get_result(case_after)["gt_comparison"].get_performance()[metric].values + ax.set_ylabel("depth (um)") ax.set_ylabel("snr") + if count > 0: + ax.set_ylabel("") + ax.set_yticks([], []) + im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), cmap="coolwarm") + im.set_clim(-1, 1) + # fig.colorbar(im, ax=ax) + # ax.set_title(k) + + fig.subplots_adjust(right=0.85) + cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75]) + cbar = fig.colorbar(im, cax=cbar_ax, label=metric) + # 
cbar.set_clim(-1, 1) + return fig def plot_comparison_clustering( diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index cf91c8b873..ab1523d13a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -11,6 +11,9 @@ import numpy as np from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.sortingcomponents.tools import remove_empty_templates +from spikeinterface.core.recording_tools import get_noise_levels +from spikeinterface.core.sparsity import compute_sparsity class MatchingBenchmark(Benchmark): @@ -73,17 +76,15 @@ def plot_agreements(self, case_keys=None, figsize=None): ax.set_title(self.cases[key]["label"]) plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) - return fig - - def plot_performances_vs_snr(self, case_keys=None, figsize=None): + def plot_performances_vs_snr(self, case_keys=None, figsize=None, metrics=["accuracy", "recall", "precision"]): if case_keys is None: case_keys = list(self.cases.keys()) - fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) + fig, axs = plt.subplots(ncols=1, nrows=len(metrics), figsize=figsize, squeeze=False) - for count, k in enumerate(("accuracy", "recall", "precision")): + for count, k in enumerate(metrics): - ax = axs[count] + ax = axs[count, 0] for key in case_keys: label = self.cases[key]["label"] @@ -223,13 +224,13 @@ def plot_unit_counts(self, case_keys=None, figsize=None): plot_study_unit_counts(self, case_keys, figsize=figsize) - def plot_unit_losses(self, before, after, figsize=None): + def plot_unit_losses(self, before, after, metric=["precision"], figsize=None): - fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) + fig, axs = plt.subplots(ncols=1, nrows=len(metric), figsize=figsize, squeeze=False) - for count, k in enumerate(("accuracy", "recall", "precision")): + for count, k in enumerate(metric): - ax = axs[count] + ax = axs[0, count] label = self.cases[after]["label"] @@ -241,15 +242,20 @@ def plot_unit_losses(self, before, after, figsize=None): y_before = self.get_result(before)["gt_comparison"].get_performance()[k].values y_after = self.get_result(after)["gt_comparison"].get_performance()[k].values - if count < 2: - ax.set_xticks([], []) - elif count == 2: - ax.set_xlabel("depth (um)") - im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), marker=".", s=50, cmap="copper") - fig.colorbar(im, ax=ax) + # if count < 2: + # ax.set_xticks([], []) + # elif count == 2: + ax.set_xlabel("depth (um)") + im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), cmap="coolwarm") + fig.colorbar(im, ax=ax, label=k) + im.set_clim(-1, 1) ax.set_title(k) ax.set_ylabel("snr") + # fig.subplots_adjust(right=0.85) + # cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75]) + # cbar = fig.colorbar(im, cax=cbar_ax, label=metric) + # if count == 2: # ax.legend() return fig diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py index 062309b581..7d862343d2 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py @@ -196,6 +196,8 @@ def plot_detected_amplitudes(self, 
case_keys=None, figsize=(15, 5), detect_thres abs_threshold = -detect_threshold * noise_levels ax.plot([abs_threshold, abs_threshold], [ymin, ymax], "k--") + return fig + def plot_deltas_per_cells(self, case_keys=None, figsize=(15, 5)): if case_keys is None: diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index e9f128993d..aaa67e3aeb 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -10,9 +10,11 @@ from spikeinterface.core import SortingAnalyzer -from spikeinterface import load_extractor, split_job_kwargs, create_sorting_analyzer, load_sorting_analyzer + +from spikeinterface import load_extractor, create_sorting_analyzer, load_sorting_analyzer from spikeinterface.widgets import get_some_colors + import pickle _key_separator = "_-°°-_" diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 0872a6066c..facefac4c5 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -140,3 +140,14 @@ def remove_empty_templates(templates): probe=templates.probe, is_scaled=templates.is_scaled, ) + + +def sigmoid(x, x0, k, b): + return (1 / (1 + np.exp(-k * (x - x0)))) + b + + +def fit_sigmoid(xdata, ydata, p0=None): + from scipy.optimize import curve_fit + + popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) + return popt diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index a2c366851b..85043d0d12 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -30,9 +30,7 @@ def __init__( case_keys = list(study.cases.keys()) plot_data = dict( - study=study, - run_times=study.get_run_times(case_keys), - case_keys=case_keys, + study=study, run_times=study.get_run_times(case_keys), case_keys=case_keys, colors=study.get_colors() ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -48,8 +46,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): for i, key in enumerate(dp.case_keys): label = dp.study.cases[key]["label"] rt = dp.run_times.loc[key] - self.ax.bar(i, rt, width=0.8, label=label) - + self.ax.bar(i, rt, width=0.8, label=label, facecolor=dp.colors[key]) + self.ax.set_ylabel("run time (s)") self.ax.legend() @@ -167,6 +165,8 @@ def __init__( case_keys=case_keys, ) + self.colors = study.get_colors() + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): @@ -192,7 +192,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): label = study.cases[key]["label"] val = perfs.xs(key).loc[:, performance_name].values val = np.sort(val)[::-1] - ax.plot(val, label=label) + ax.plot(val, label=label, c=self.colors[key]) ax.set_title(performance_name) if count == len(dp.performance_names) - 1: ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8) @@ -207,7 +207,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): x = study.get_metrics(key).loc[:, metric_name].values y = perfs.xs(key).loc[:, performance_name].values label = study.cases[key]["label"] - ax.scatter(x, y, s=10, label=label) + ax.scatter(x, y, s=10, label=label, color=self.colors[key]) max_metric = max(max_metric, np.max(x)) ax.set_title(performance_name) ax.set_xlim(0, max_metric * 1.05) From 0422cfbab1692c57eea3743bdac7b88ee0d4a7dc Mon 
Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 5 Jul 2024 14:31:18 +0200 Subject: [PATCH 32/57] Mcsh5 offsets and proper scaling in uV for return_scaled (#2988) Mcsh5 offsets and proper scaling in uV for return_scaled --- src/spikeinterface/extractors/mcsh5extractors.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/extractors/mcsh5extractors.py b/src/spikeinterface/extractors/mcsh5extractors.py index f419b7e64d..9485536d97 100644 --- a/src/spikeinterface/extractors/mcsh5extractors.py +++ b/src/spikeinterface/extractors/mcsh5extractors.py @@ -61,6 +61,9 @@ def __init__(self, file_path, stream_id=0): # set gain self.set_channel_gains(mcs_info["gain"]) + # set offsets + self.set_channel_offsets(mcs_info["offset"]) + # set other properties self.set_property("electrode_labels", mcs_info["electrode_labels"]) @@ -100,7 +103,11 @@ def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): def openMCSH5File(filename, stream_id): - """Open an MCS hdf5 file, read and return the recording info.""" + """Open an MCS hdf5 file, read and return the recording info. + Specs can be found online + https://www.multichannelsystems.com/downloads/documentation?page=3 + """ + import h5py rf = h5py.File(filename, "r") @@ -121,7 +128,8 @@ def openMCSH5File(filename, stream_id): Tick = info["Tick"][0] / 1e6 exponent = info["Exponent"][0] convFact = info["ConversionFactor"][0] - gain = convFact.astype(float) * (10.0**exponent) + gain_uV = 1e6 * (convFact.astype(float) * (10.0**exponent)) + offset_uV = -1e6 * (info["ADZero"].astype(float) * (10.0**exponent)) * gain_uV nRecCh, nFrames = data.shape channel_ids = [f"Ch{ch}" for ch in info["ChannelID"]] @@ -149,8 +157,9 @@ def openMCSH5File(filename, stream_id): "num_channels": nRecCh, "channel_ids": channel_ids, "electrode_labels": electrodeLabels, - "gain": gain, + "gain": gain_uV, "dtype": dtype, + "offset": offset_uV, } return mcs_info From 85d504dd9e0bc685069c4e426fcd32a99df48f5d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 5 Jul 2024 15:44:10 +0200 Subject: [PATCH 33/57] docstrings --- src/spikeinterface/core/sorting_tools.py | 54 ++++++++++++------------ 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 8d038aa45b..ca9697b222 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -229,32 +229,33 @@ def apply_merges_to_sorting( sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append" ): """ - Function to apply a resolved representation of the merges to a sorting object. + Apply a resolved representation of the merges to a sorting object. - This function is not lazy and create a new NumpySorting with a compact spike_vector as fast as possible. + This function is not lazy and creates a new NumpySorting with a compact spike_vector as fast as possible. - If censor_ms is not None, duplicated spikes violating the censor_ms refractory period are removed. + If `censor_ms` is not None, duplicated spikes violating the `censor_ms` refractory period are removed. - Optionaly, the boolean of kept spikes is returned + Optionally, the boolean mask of kept spikes is returned. Parameters ---------- sorting : Sorting - The Sorting object to apply merges + The Sorting object to apply merges. units_to_merge : list/tuple of lists/tuples A list of lists for every merge group. 
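To make the units_to_merge format being documented here concrete, together with the two new_id_strategy behaviours it feeds into, a small sketch whose expected outputs are copied from the unit tests added earlier in this series:

    from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group

    old_unit_ids = ["a", "b", "c", "d", "e"]
    units_to_merge = [["a", "b"], ["d", "e"]]  # two merge groups, "c" is left untouched

    print(generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_id_strategy="append"))
    # ['merge0', 'merge1']  (non-numeric string ids cannot be incremented, so generic names are used)
    print(generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_id_strategy="take_first"))
    # ['a', 'd']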
Each element needs to have at least two elements (two units to merge), but it can also have more (merge multiple units at once). - new_unit_ids : None or list + new_unit_ids : list | None, default: None A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. If None, - merged units will have the first unit_id of every lists of merges - censor_ms: None or float + merged units will have the first unit_id of every list of merges. + censor_ms: float | None, default: None When applying the merges, should be discard consecutive spikes violating a given refractory per - return_kept : bool, default False - return also a booolean of kept spikes - new_id_strategy : "append" | "take_first", default "append" - The strategy that should be used, if new_unit_ids is None, to create new unit_ids. - "append" : new_units_ids will be added at the end of max(sorging.unit_ids) - "take_first" : new_unit_ids will be the first unit_id of every list of merges + return_kept : bool, default: False + If True, also return a boolean mask of kept spikes. + new_id_strategy : "append" | "take_first", default: "append" + The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
+ + * "append" : new_units_ids will be added at the end of max(sorging.unit_ids) + * "take_first" : new_unit_ids will be the first unit_id of every list of merges Returns ------- new_unit_ids : The new unit ids - The new units_ids associated with the merges + The new units_ids associated with the merges. """ From a0d8097cc69d8efca1443ce5e8edb3152285e640 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 5 Jul 2024 08:27:51 -0600 Subject: [PATCH 34/57] Update src/spikeinterface/core/binaryrecordingextractor.py Co-authored-by: Garcia Samuel --- src/spikeinterface/core/binaryrecordingextractor.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 3733c1f0c3..f91d8165df 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -193,14 +193,6 @@ def get_traces( end_frame: int | None = None, channel_indices: list | None = None, ) -> np.ndarray: - if start_frame is None: - start_frame = 0 - - if end_frame is None: - end_frame = self.get_num_samples() - - if end_frame > self.get_num_samples(): - raise ValueError(f"end_frame {end_frame} is larger than the number of samples {self.get_num_samples()}") # Calculate byte offsets for start and end frames start_byte = self.file_offset + start_frame * self.bytes_per_sample From 752603ee4dd9678c2f66322079e3c84487552c0b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 5 Jul 2024 17:51:10 +0200 Subject: [PATCH 35/57] Add apply_merges_to_sorting in api.rst --- doc/api.rst | 5 ++++- src/spikeinterface/core/__init__.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index c5c9ebe4dd..c73cd812da 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -60,6 +60,10 @@ spikeinterface.core .. autofunction:: select_segment_sorting .. autofunction:: read_binary .. autofunction:: read_zarr + .. autofunction:: apply_merges_to_sorting + .. autofunction:: spike_vector_to_spike_trains + .. autofunction:: random_spikes_selection + Low-level ~~~~~~~~~ @@ -67,7 +71,6 @@ Low-level .. automodule:: spikeinterface.core :noindex: - .. autoclass:: BaseWaveformExtractorExtension .. 
autoclass:: ChunkRecordingExecutor spikeinterface.extractors diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index a5e1f44842..674f1ac463 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -101,7 +101,7 @@ get_chunk_with_margin, order_channels_by_depth, ) -from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection +from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection, apply_merges_to_sorting from .waveform_tools import extract_waveforms_to_buffers, estimate_templates, estimate_templates_with_accumulator from .snippets_tools import snippets_from_sorting From 669a9c25b3ef01b2a38620f77a34c0937aff7847 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 5 Jul 2024 10:20:30 -0600 Subject: [PATCH 36/57] DRY --- .../extractors/nwbextractors.py | 101 ++++++------------ 1 file changed, 35 insertions(+), 66 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index d213126f34..9aa8b1b907 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -401,7 +401,40 @@ def _retrieve_electrodes_indices_from_electrical_series_backend(open_file, elect return electrodes_indices -class NwbRecordingExtractor(BaseRecording): +class _BaseNWBExtractor: + "A class for common methods for NWB extractors." + + def _close_hdf5_file(self): + has_hdf5_backend = hasattr(self, "_file") + if has_hdf5_backend: + import h5py + + main_file_id = self._file.id + open_object_ids_main = h5py.h5f.get_obj_ids(main_file_id, types=h5py.h5f.OBJ_ALL) + for object_id in open_object_ids_main: + object_name = h5py.h5i.get_name(object_id).decode("utf-8") + try: + object_id.close() + except: + import warnings + + warnings.warn(f"Error closing object {object_name}") + + def __del__(self): + # backend mode + if hasattr(self, "_file"): + if hasattr(self._file, "store"): + self._file.store.close() + else: + self._close_hdf5_file() + # pynwb mode + elif hasattr(self, "_nwbfile"): + io = self._nwbfile.get_read_io() + if io is not None: + io.close() + + +class NwbRecordingExtractor(BaseRecording, _BaseNWBExtractor): """Load an NWBFile as a RecordingExtractor. Parameters @@ -626,35 +659,6 @@ def __init__( "file": file, } - def _close_hdf5_file(self): - has_hdf5_backend = hasattr(self, "_file") - if has_hdf5_backend: - import h5py - - main_file_id = self._file.id - open_object_ids_main = h5py.h5f.get_obj_ids(main_file_id, types=h5py.h5f.OBJ_ALL) - for object_id in open_object_ids_main: - object_name = h5py.h5i.get_name(object_id).decode("utf-8") - try: - object_id.close() - except: - import warnings - - warnings.warn(f"Error closing object {object_name}") - - def __del__(self): - # backend mode - if hasattr(self, "_file"): - if hasattr(self._file, "store"): - self._file.store.close() - else: - self._close_hdf5_file() - # pynwb mode - elif hasattr(self, "_nwbfile"): - io = self._nwbfile.get_read_io() - if io is not None: - io.close() - def _fetch_recording_segment_info_pynwb(self, file, cache, load_time_vector, samples_for_rate_estimation): self._nwbfile = read_nwbfile( backend=self.backend, @@ -968,7 +972,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): return traces -class NwbSortingExtractor(BaseSorting): +class NwbSortingExtractor(BaseSorting, _BaseNWBExtractor): """Load an NWBFile as a SortingExtractor. 
Parameters ---------- @@ -1127,41 +1131,6 @@ def __init__( "t_start": self.t_start, } - def _close_hdf5_file(self): - has_hdf5_backend = hasattr(self, "_file") - if has_hdf5_backend: - import h5py - - main_file_id = self._file.id - open_object_ids_main = h5py.h5f.get_obj_ids(main_file_id, types=h5py.h5f.OBJ_ALL) - for object_id in open_object_ids_main: - object_name = h5py.h5i.get_name(object_id).decode("utf-8") - try: - object_id.close() - except: - import warnings - - warnings.warn(f"Error closing object {object_name}") - - def __del__(self): - # backend mode - if hasattr(self, "_file"): - if hasattr(self._file, "store"): - self._file.store.close() - else: - self._close_hdf5_file() - # pynwb mode - elif hasattr(self, "_nwbfile"): - io = self._nwbfile.get_read_io() - if io is not None: - io.close() - - # pynwb mode - elif hasattr(self, "_nwbfile"): # hdf - io = self._nwbfile.get_read_io() - if io is not None: - io.close() - def _fetch_sorting_segment_info_pynwb( self, unit_table_path: str = None, samples_for_rate_estimation: int = 1000, cache: bool = False ): From a4ad437d4ad27be283faa18d30ba9fdac3028015 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 5 Jul 2024 19:07:33 +0200 Subject: [PATCH 37/57] more checks --- src/spikeinterface/core/sorting_tools.py | 15 ++++++++------- src/spikeinterface/curation/mergeunitssorting.py | 2 -- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index ca9697b222..d2104eec73 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -282,10 +282,8 @@ def apply_merges_to_sorting( all_unit_ids = list(all_unit_ids) num_seg = sorting.get_num_segments() - segment_limits = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2)) - segment_slices = [] - for i in range(num_seg): - segment_slices += [(segment_limits[i], segment_limits[i + 1])] + seg_lims = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2)) + segment_slices = [(seg_lims[i], seg_lims[i + 1]) for i in range(num_seg)] # using this function vaoid to use the mask approach and simplify a lot the algo spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices] @@ -369,7 +367,7 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids """ Function to generate new units ids during a merging procedure. If new_units_ids are provided, it will return these unit ids, checking that they have the the same - length as `units_to:merge`. + length as `units_to_merge`. Parameters ---------- @@ -391,13 +389,16 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids ------- new_unit_ids : The new unit ids The new units_ids associated with the merges. 
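A minimal usage sketch of the two `new_id_strategy` options documented above (illustrative only, not part of the patch; it relies solely on the signatures shown in this diff, and the ids produced by "append" are indicative):

from spikeinterface.core import generate_sorting, apply_merges_to_sorting
from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group

# standalone id-generation helper (plain integer ids for illustration)
old_unit_ids = [0, 1, 2, 3, 4]
units_to_merge = [[0, 1], [3, 4]]
generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_id_strategy="append")      # e.g. [5, 6]
generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_id_strategy="take_first")  # [0, 3]

# applying merges to an actual sorting, censoring duplicate spikes within 0.25 ms
sorting = generate_sorting(num_units=5, durations=[10.0], seed=0)
groups = [list(sorting.unit_ids[:2]), list(sorting.unit_ids[3:])]
merged_sorting = apply_merges_to_sorting(sorting, groups, censor_ms=0.25)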
- - """ old_unit_ids = np.asarray(old_unit_ids) if new_unit_ids is not None: + # then only doing a consistency check assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge" + # new_unit_ids can also be part of old_unit_ids only inside the same group: + for i, new_unit_id in enumerate(new_unit_ids): + if new_unit_id in old_unit_ids: + assert new_unit_id in units_to_merge[i], "new_unit_ids already exists but outside the merged groups" else: dtype = old_unit_ids.dtype num_merge = len(units_to_merge) diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 3771b1c63c..11f26ea778 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -45,8 +45,6 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy parents_unit_ids = sorting.unit_ids sampling_frequency = sorting.get_sampling_frequency() - from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group - new_unit_ids = generate_unit_ids_for_merge_group( sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy="append" ) From 1e8b5515c8feab49bc5564e0ad347fd11c0586c6 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 5 Jul 2024 11:52:22 -0600 Subject: [PATCH 38/57] Better error message for invalid parameters in sorter (#3156) * better error for invdalid parameters in sorter * zach suggestion add typing --- src/spikeinterface/sorters/basesorter.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 788044c0f1..3502d27548 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -15,7 +15,7 @@ import warnings -from spikeinterface.core import load_extractor, BaseRecordingSnippets +from spikeinterface.core import load_extractor, BaseRecordingSnippets, BaseRecording from spikeinterface.core.core_tools import check_json from spikeinterface.core.globals import get_global_job_kwargs from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs @@ -167,16 +167,20 @@ def params_description(cls): return p @classmethod - def set_params_to_folder(cls, recording, output_folder, new_params, verbose): + def set_params_to_folder( + cls, + recording: BaseRecording, + output_folder: str | Path, + new_params: dict, + verbose: bool, + ) -> dict: params = cls.default_params() + valid_parameters = params.keys() + invalid_parameters = [k for k in new_params.keys() if k not in valid_parameters] - # verify params are in list - bad_params = [] - for p in new_params.keys(): - if p not in params.keys(): - bad_params.append(p) - if len(bad_params) > 0: - raise AttributeError("Bad parameters: " + str(bad_params)) + if invalid_parameters: + error_msg = f"Invalid parameters: {invalid_parameters} \n" f"Valid parameters are: {valid_parameters}" + raise ValueError(error_msg) params.update(new_params) From e46b86a02c6fde2dc9539857a0ff5b16728c3419 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Jul 2024 17:53:49 +0000 Subject: [PATCH 39/57] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sorting_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index d2104eec73..918d95bf52 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -395,7 +395,7 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids if new_unit_ids is not None: # then only doing a consistency check assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge" - # new_unit_ids can also be part of old_unit_ids only inside the same group: + # new_unit_ids can also be part of old_unit_ids only inside the same group: for i, new_unit_id in enumerate(new_unit_ids): if new_unit_id in old_unit_ids: assert new_unit_id in units_to_merge[i], "new_unit_ids already exists but outside the merged groups" From f1d2755103cb8f9397e12aa1dad31033ba2e9fb3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 8 Jul 2024 13:33:09 +0200 Subject: [PATCH 40/57] Add _get_t_starts function and move t_starts retrieval to from_recording functions --- src/spikeinterface/core/baserecording.py | 31 +++++++++++++++------- src/spikeinterface/core/numpyextractors.py | 9 +++++-- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index f4a276a396..3c46193c02 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -501,24 +501,35 @@ def time_to_sample_index(self, time_s, segment_index=None): rs = self._recording_segments[segment_index] return rs.time_to_sample_index(time_s) - def _save(self, format="binary", verbose: bool = False, **save_kwargs): + def _get_t_starts(self): # handle t_starts t_starts = [] has_time_vectors = [] - for segment_index, rs in enumerate(self._recording_segments): + for rs in self._recording_segments: d = rs.get_times_kwargs() t_starts.append(d["t_start"]) - has_time_vectors.append(d["time_vector"] is not None) if all(t_start is None for t_start in t_starts): t_starts = None + return t_starts + def _get_time_vectors(self): + time_vectors = [] + for rs in self._recording_segments: + d = rs.get_times_kwargs() + time_vectors.append(d["time_vector"]) + if all(time_vector is None for time_vector in time_vectors): + time_vectors = None + return time_vectors + + def _save(self, format="binary", verbose: bool = False, **save_kwargs): kwargs, job_kwargs = split_job_kwargs(save_kwargs) if format == "binary": folder = kwargs["folder"] file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] dtype = kwargs.get("dtype", None) or self.get_dtype() + t_starts = self._get_t_starts() write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) @@ -548,11 +559,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): if kwargs.get("sharedmem", True): from .numpyextractors import SharedMemoryRecording - cached = SharedMemoryRecording.from_recording(self, t_starts=t_starts, **job_kwargs) + cached = SharedMemoryRecording.from_recording(self, **job_kwargs) else: from spikeinterface.core import NumpyRecording - cached = NumpyRecording.from_recording(self, t_starts=t_starts, **job_kwargs) + cached = NumpyRecording.from_recording(self, **job_kwargs) elif format == "zarr": from .zarrextractors import ZarrRecordingExtractor @@ -575,11 +586,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): probegroup = self.get_probegroup() 
cached.set_probegroup(probegroup) - for segment_index, rs in enumerate(self._recording_segments): - d = rs.get_times_kwargs() - time_vector = d["time_vector"] - if time_vector is not None: - cached._recording_segments[segment_index].time_vector = time_vector + time_vectors = self._get_time_vectors() + if time_vectors is not None: + for segment_index, time_vector in enumerate(time_vectors): + if time_vector is not None: + cached.set_times(time_vector, segment_index=segment_index) return cached diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 2cf03927a3..f4790817a8 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -81,8 +81,11 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N } @staticmethod - def from_recording(source_recording, t_starts=None, **job_kwargs): + def from_recording(source_recording, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs) + + t_starts = source_recording._get_t_starts() + if shms[0] is not None: # if the computation was done in parallel then traces_list is shared array # this can lead to problem @@ -204,9 +207,11 @@ def __del__(self): shm.unlink() @staticmethod - def from_recording(source_recording, t_starts=None, **job_kwargs): + def from_recording(source_recording, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs) + t_starts = source_recording._get_t_starts() + recording = SharedMemoryRecording( shm_names=[shm.name for shm in shms], shape_list=[traces.shape for traces in traces_list], From 11bf61e75729dea45ef40d8681ad6d0a9120368a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 8 Jul 2024 14:42:12 +0200 Subject: [PATCH 41/57] Protect against name annotation being None --- src/spikeinterface/core/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index c985586cab..f3e300fff4 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -81,7 +81,8 @@ def __init__(self, main_ids: Sequence) -> None: @property def name(self): - return self._annotations.get("name", self.__class__.__name__) + name = self._annotations.get("name", None) + return name if name is not None else self.__class__.__name__ @name.setter def name(self, value): From 2d09eb5d24136c2e7371e7b3f5de85fcae136699 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 8 Jul 2024 14:52:05 +0200 Subject: [PATCH 42/57] Fix segment start/end frame None in concatenate_recordings --- src/spikeinterface/core/segmentutils.py | 10 +++++++--- src/spikeinterface/core/tests/test_segmentutils.py | 4 ---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index b23b7202c6..08583f71c9 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -191,16 +191,20 @@ def get_traces(self, start_frame, end_frame, channel_indices): seg_start = self.cumsum_length[i] if i == i0: # first - traces_chunk = rec_seg.get_traces(start_frame - seg_start, None, channel_indices) + end_frame = rec_seg.get_num_samples() + traces_chunk = rec_seg.get_traces(start_frame - seg_start, end_frame, channel_indices) all_traces.append(traces_chunk) elif i == i1: # last if (end_frame - seg_start) > 0: - traces_chunk = 
rec_seg.get_traces(None, end_frame - seg_start, channel_indices) + start_frame = 0 + traces_chunk = rec_seg.get_traces(start_frame, end_frame - seg_start, channel_indices) all_traces.append(traces_chunk) else: # in between - traces_chunk = rec_seg.get_traces(None, None, channel_indices) + start_frame = 0 + end_frame = rec_seg.get_num_samples() + traces_chunk = rec_seg.get_traces(start_frame, end_frame, channel_indices) all_traces.append(traces_chunk) traces = np.concatenate(all_traces, axis=0) diff --git a/src/spikeinterface/core/tests/test_segmentutils.py b/src/spikeinterface/core/tests/test_segmentutils.py index d3c73805f0..166ecafd09 100644 --- a/src/spikeinterface/core/tests/test_segmentutils.py +++ b/src/spikeinterface/core/tests/test_segmentutils.py @@ -5,10 +5,6 @@ from numpy.testing import assert_raises from spikeinterface.core import ( - AppendSegmentRecording, - AppendSegmentSorting, - ConcatenateSegmentRecording, - ConcatenateSegmentSorting, NumpyRecording, NumpySorting, append_recordings, From 8bfc520032f111ade8f8ac9979b1f58c08f20a69 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 8 Jul 2024 14:56:35 +0200 Subject: [PATCH 43/57] oups --- src/spikeinterface/core/segmentutils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index 08583f71c9..039fa8fd60 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -191,20 +191,20 @@ def get_traces(self, start_frame, end_frame, channel_indices): seg_start = self.cumsum_length[i] if i == i0: # first - end_frame = rec_seg.get_num_samples() - traces_chunk = rec_seg.get_traces(start_frame - seg_start, end_frame, channel_indices) + end_frame_ = rec_seg.get_num_samples() + traces_chunk = rec_seg.get_traces(start_frame - seg_start, end_frame_, channel_indices) all_traces.append(traces_chunk) elif i == i1: # last if (end_frame - seg_start) > 0: - start_frame = 0 - traces_chunk = rec_seg.get_traces(start_frame, end_frame - seg_start, channel_indices) + start_frame_ = 0 + traces_chunk = rec_seg.get_traces(start_frame_, end_frame - seg_start, channel_indices) all_traces.append(traces_chunk) else: # in between - start_frame = 0 - end_frame = rec_seg.get_num_samples() - traces_chunk = rec_seg.get_traces(start_frame, end_frame, channel_indices) + start_frame_ = 0 + end_frame_ = rec_seg.get_num_samples() + traces_chunk = rec_seg.get_traces(start_frame_, end_frame_, channel_indices) all_traces.append(traces_chunk) traces = np.concatenate(all_traces, axis=0) From 989aa8b4eb29b70e04b3e9a730ab649eee2a084f Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 8 Jul 2024 15:45:44 +0100 Subject: [PATCH 44/57] Add more tests for time handling. --- .../core/tests/test_time_handling.py | 280 ++++++++++++++++++ 1 file changed, 280 insertions(+) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index 487a893096..8a6971b0b7 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -1,9 +1,289 @@ +import copy + import pytest import numpy as np from spikeinterface.core import generate_recording, generate_sorting +import spikeinterface.full as si + + +class TestTimeHandling: + + # Fixtures ##### + @pytest.fixture(scope="session") + def raw_recording(self): + """ + A three-segment raw recording without times added. 
+ """ + durations = [10, 15, 20] + recording = generate_recording(num_channels=4, durations=durations) + return recording + + @pytest.fixture(scope="session") + def time_vector_recording(self, raw_recording): + """ + Add time vectors to the recording, returning the + raw recording, recording with time vectors added to + segments, and list a the time vectors added to the recording. + """ + return self._get_time_vector_recording(raw_recording) + + @pytest.fixture(scope="session") + def t_start_recording(self, raw_recording): + """ + Add a t_starts to the recording, returning the + raw recording, recording with t_starts added to segments, + and a list of the time vectors generated from adding the + t_start to the recording times. + """ + return self._get_t_start_recording(raw_recording) + + def _get_time_vector_recording(self, raw_recording): + """ + Loop through all recording segments, adding a different time + vector to each segment. The time vector is the original times with + a t_start and irregularly spaced offsets to mimic irregularly + spaced timeseries data. Return the original recording, + recoridng with time vectors added and list including the added time vectors. + """ + times_recording = copy.deepcopy(raw_recording) + all_time_vectors = [] + for segment_index in range(raw_recording.get_num_segments()): + + t_start = segment_index + 1 * 100 + offsets = np.arange(times_recording.get_num_samples(segment_index)) * ( + 1 / times_recording.get_sampling_frequency() + ) + time_vector = t_start + times_recording.get_times(segment_index) + offsets + + all_time_vectors.append(time_vector) + times_recording.set_times(times=time_vector, segment_index=segment_index) + + assert np.array_equal( + times_recording._recording_segments[segment_index].time_vector, + time_vector, + ), "time_vector was not properly set during test setup" + + return (raw_recording, times_recording, all_time_vectors) + + def _get_t_start_recording(self, raw_recording): + """ + For each segment in the recording, add a different `t_start`. + Return a list of time vectors generating from the recording times + + the t_starts. + """ + t_start_recording = copy.deepcopy(raw_recording) + + all_t_starts = [] + for segment_index in range(raw_recording.get_num_segments()): + + t_start = (segment_index + 1) * 100 + + all_t_starts.append(t_start + t_start_recording.get_times(segment_index)) + t_start_recording._recording_segments[segment_index].t_start = t_start + + return (raw_recording, t_start_recording, all_t_starts) + + def _get_fixture_data(self, request, fixture_name): + """ + A convenience function to get the data from a fixture + based on the name. This is used to allow parameterising + tests across fixtures. + """ + time_recording_fixture = request.getfixturevalue(fixture_name) + raw_recording, times_recording, all_times = time_recording_fixture + return (raw_recording, times_recording, all_times) + + # Tests ##### + def test_has_time_vector(self, time_vector_recording): + """ + Test the `has_time_vector` function returns `False` before + a time vector is added and `True` afterwards. 
+ """ + raw_recording, times_recording, _ = time_vector_recording + + for segment_idx in range(raw_recording.get_num_segments()): + + assert raw_recording.has_time_vector(segment_idx) is False + assert times_recording.has_time_vector(segment_idx) is True + + @pytest.mark.parametrize("mode", ["binary", "zarr"]) + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_times_propagated_to_save_folder(self, request, fixture_name, mode, tmp_path): + """ + Test `t_start` or `time_vector` is propagated to a saved recording, + by saving, reloading, and checking times are correct. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + folder_name = "recording" + recording_cache = times_recording.save(format=mode, folder=tmp_path / folder_name) + + if mode == "zarr": + folder_name += ".zarr" + recording_load = si.load_extractor(tmp_path / folder_name) + + self._check_times_match(recording_cache, all_times) + self._check_times_match(recording_load, all_times) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("sharedmem", [True, False]) + def test_times_propagated_to_save_memory(self, request, fixture_name, sharedmem): + """ + Test t_start and time_vector are propagated to recording saved into memory. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + recording_load = times_recording.save(format="memory", sharedmem=sharedmem) + + self._check_times_match(recording_load, all_times) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_time_propagated_to_select_segments(self, request, fixture_name): + """ + Test that when `recording.select_segments()` is used, the times + are propagated to the new recoridng object. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + for segment_index in range(times_recording.get_num_segments()): + segment = times_recording.select_segments(segment_index) + assert np.array_equal(segment.get_times(), all_times[segment_index]) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_times_propagated_to_sorting(self, request, fixture_name): + """ + Check that when attached to a sorting object, the times are propagated + to the object. This means that all spike times should respect the + `t_start` or `time_vector` added. + """ + raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name) + sorting = self._get_sorting_with_recording_attached( + recording_for_durations=raw_recording, recording_to_attach=times_recording + ) + for segment_index in range(raw_recording.get_num_segments()): + + if fixture_name == "time_vector_recording": + assert sorting.has_time_vector(segment_index=segment_index) + + self._check_spike_times_are_correct(sorting, times_recording, segment_index) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_time_sample_converters(self, request, fixture_name): + """ + Test the `recording.sample_time_to_index` and + `recording.time_to_sample_index` convenience functions. 
+ """ + raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name) + with pytest.raises(ValueError) as e: + times_recording.sample_index_to_time(0) + assert "Provide 'segment_index'" in str(e) + + for segment_index in range(times_recording.get_num_segments()): + + sample_index = np.random.randint(low=0, high=times_recording.get_num_samples(segment_index)) + time_ = times_recording.sample_index_to_time(sample_index, segment_index=segment_index) + + assert time_ == all_times[segment_index][sample_index] + + new_sample_index = times_recording.time_to_sample_index(time_, segment_index=segment_index) + + assert new_sample_index == sample_index + + @pytest.mark.parametrize("time_type", ["time_vector", "t_start"]) + @pytest.mark.parametrize("bounds", ["start", "middle", "end"]) + def test_slice_recording(self, time_type, bounds): + """ + Test after `frame_slice` and `time_slice` a recording or + sorting (for `frame_slice`), the recording times are + correct with respect to the set `t_start` or `time_vector`. + """ + raw_recording = generate_recording(num_channels=4, durations=[10]) + + if time_type == "time_vector": + raw_recording, times_recording, all_times = self._get_time_vector_recording(raw_recording) + else: + raw_recording, times_recording, all_times = self._get_t_start_recording(raw_recording) + + sorting = self._get_sorting_with_recording_attached( + recording_for_durations=raw_recording, recording_to_attach=times_recording + ) + + # Take some different times, including min and max bounds of + # the recording, and some arbitaray times in the middle (20% and 80%). + if bounds == "start": + start_frame = 0 + end_frame = int(times_recording.get_num_samples(0) * 0.8) + elif bounds == "end": + start_frame = int(times_recording.get_num_samples(0) * 0.2) + end_frame = times_recording.get_num_samples(0) - 1 + elif bounds == "middle": + start_frame = int(times_recording.get_num_samples(0) * 0.2) + end_frame = int(times_recording.get_num_samples(0) * 0.8) + + # Slice the recording and get the new times are correct + rec_frame_slice = times_recording.frame_slice(start_frame=start_frame, end_frame=end_frame) + sort_frame_slice = sorting.frame_slice(start_frame=start_frame, end_frame=end_frame) + + assert np.allclose(rec_frame_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8) + + self._check_spike_times_are_correct(sort_frame_slice, rec_frame_slice, segment_index=0) + + # Test `time_slice` + start_time = times_recording.sample_index_to_time(start_frame) + end_time = times_recording.sample_index_to_time(end_frame) + + rec_slice = times_recording.time_slice(start_time=start_time, end_time=end_time) + + assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8) + + # Helpers #### + def _check_times_match(self, recording, all_times): + """ + For every segment in a recording, check the `get_times()` + match the expected times in the list of time vectors, `all_times`. + """ + for segment_index in range(recording.get_num_segments()): + assert np.array_equal(recording.get_times(segment_index), all_times[segment_index]) + + def _check_spike_times_are_correct(self, sorting, times_recording, segment_index): + """ + For every unit in the `sorting`, for a particular segment, check that + the unit times match the times of the original recording as + retrieved with `get_times()`. 
+ """ + for unit_id in sorting.get_unit_ids(): + spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True) + spike_indexes = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) + rec_times = times_recording.get_times(segment_index=segment_index) + + assert np.array_equal( + spike_times, + rec_times[spike_indexes], + ) + + def _get_sorting_with_recording_attached(self, recording_for_durations, recording_to_attach): + """ + Convenience function to create a sorting object with + a recording attached. Typically use the raw recordings + for the durations of which to make the sorter, as + the generate_sorter is not setup to handle the + (strange) edge case of the irregularly spaced + test time vectors. + """ + durations = [ + recording_for_durations.get_duration(idx) for idx in range(recording_for_durations.get_num_segments()) + ] + + sorting = generate_sorting(num_units=10, durations=durations) + + sorting.register_recording(recording_to_attach) + assert sorting.has_recording() + + return sorting +# TODO: deprecate original implementations ### def test_time_handling(create_cache_folder): cache_folder = create_cache_folder durations = [[10], [10, 5]] From 5e11c4effbe08c53f4154342138d2143e12b8021 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 8 Jul 2024 15:46:50 +0100 Subject: [PATCH 45/57] Remove some indirection in the fixtures. --- .../core/tests/test_time_handling.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index 8a6971b0b7..fb929ce5a9 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -10,15 +10,6 @@ class TestTimeHandling: # Fixtures ##### - @pytest.fixture(scope="session") - def raw_recording(self): - """ - A three-segment raw recording without times added. - """ - durations = [10, 15, 20] - recording = generate_recording(num_channels=4, durations=durations) - return recording - @pytest.fixture(scope="session") def time_vector_recording(self, raw_recording): """ @@ -26,6 +17,9 @@ def time_vector_recording(self, raw_recording): raw recording, recording with time vectors added to segments, and list a the time vectors added to the recording. """ + durations = [10, 15, 20] + raw_recording = generate_recording(num_channels=4, durations=durations) + return self._get_time_vector_recording(raw_recording) @pytest.fixture(scope="session") @@ -36,6 +30,9 @@ def t_start_recording(self, raw_recording): and a list of the time vectors generated from adding the t_start to the recording times. """ + durations = [10, 15, 20] + raw_recording = generate_recording(num_channels=4, durations=durations) + return self._get_t_start_recording(raw_recording) def _get_time_vector_recording(self, raw_recording): From 3ebd3b50d00feb464ec025f8270d9caf60223a89 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 8 Jul 2024 15:49:49 +0100 Subject: [PATCH 46/57] Minor tidy, maintain order of parameterisation across tests. 
--- src/spikeinterface/core/tests/test_time_handling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index fb929ce5a9..e80564eb14 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -123,8 +123,8 @@ def test_times_propagated_to_save_folder(self, request, fixture_name, mode, tmp_ self._check_times_match(recording_cache, all_times) self._check_times_match(recording_load, all_times) - @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) @pytest.mark.parametrize("sharedmem", [True, False]) + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) def test_times_propagated_to_save_memory(self, request, fixture_name, sharedmem): """ Test t_start and time_vector are propagated to recording saved into memory. @@ -191,9 +191,9 @@ def test_time_sample_converters(self, request, fixture_name): @pytest.mark.parametrize("bounds", ["start", "middle", "end"]) def test_slice_recording(self, time_type, bounds): """ - Test after `frame_slice` and `time_slice` a recording or - sorting (for `frame_slice`), the recording times are - correct with respect to the set `t_start` or `time_vector`. + Test times are correct after applying `frame_slice` or `time_slice` + to a recording or sorting (for `frame_slice`). The the recording times + should be correct with respect to the set `t_start` or `time_vector`. """ raw_recording = generate_recording(num_channels=4, durations=[10]) From 5011dd25fd8e0d92407c737a348f3833a7b86c68 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 8 Jul 2024 16:01:12 +0100 Subject: [PATCH 47/57] Fix tests. --- src/spikeinterface/core/tests/test_time_handling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index e80564eb14..49fa622f7a 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -11,7 +11,7 @@ class TestTimeHandling: # Fixtures ##### @pytest.fixture(scope="session") - def time_vector_recording(self, raw_recording): + def time_vector_recording(self): """ Add time vectors to the recording, returning the raw recording, recording with time vectors added to @@ -23,7 +23,7 @@ def time_vector_recording(self, raw_recording): return self._get_time_vector_recording(raw_recording) @pytest.fixture(scope="session") - def t_start_recording(self, raw_recording): + def t_start_recording(self): """ Add a t_starts to the recording, returning the raw recording, recording with t_starts added to segments, From 08adf1304833436829eca0951f50c4081f3914ca Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 8 Jul 2024 16:07:40 +0100 Subject: [PATCH 48/57] Add class docstring. --- src/spikeinterface/core/tests/test_time_handling.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index 49fa622f7a..5d46fb3eed 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -8,6 +8,12 @@ class TestTimeHandling: + """ + This class tests how time is handled in SpikeInterface. 
Under the hood, + time can be represented as a full `time_vector` or only as + `t_start` attribute on segments from which a vector of times + is generated on the fly. Both time representations are tested here. + """ # Fixtures ##### @pytest.fixture(scope="session") From b2bac4d46679c18d6b6a53079ec9ff4d02e507eb Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 8 Jul 2024 18:47:17 +0100 Subject: [PATCH 49/57] Make test time vector actually irregularly spaced! --- src/spikeinterface/core/tests/test_time_handling.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index 5d46fb3eed..eb169b77d5 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -54,9 +54,12 @@ def _get_time_vector_recording(self, raw_recording): for segment_index in range(raw_recording.get_num_segments()): t_start = segment_index + 1 * 100 - offsets = np.arange(times_recording.get_num_samples(segment_index)) * ( + + some_small_increasing_numbers = np.arange(times_recording.get_num_samples(segment_index)) * ( 1 / times_recording.get_sampling_frequency() ) + + offsets = np.cumsum(some_small_increasing_numbers) time_vector = t_start + times_recording.get_times(segment_index) + offsets all_time_vectors.append(time_vector) From 4e0f5879d9a5602555b66cdde92a1b59b2b958b3 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 8 Jul 2024 14:16:13 -0600 Subject: [PATCH 50/57] small improvement to repr --- src/spikeinterface/core/baserecording.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index dea622c482..9a9747bf0b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -103,7 +103,7 @@ def _repr_header(self): total_samples = self.get_total_samples() total_duration = self.get_total_duration() total_memory_size = self.get_total_memory_size() - sampling_frequency_repr = f"{sf_khz:0.1f}kHz" if sf_hz > 10000 else f"{sf_hz}Hz" + sampling_frequency_repr = f"{sf_khz:0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz" txt = ( f"{self.name}: " From d3cba031b3f5d44bcfd02f8e104d2716a2aef23b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 9 Jul 2024 10:39:23 +0200 Subject: [PATCH 51/57] Remove deprecated tests --- .../core/tests/test_time_handling.py | 66 ------------------- 1 file changed, 66 deletions(-) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index eb169b77d5..049d5ab6e5 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -287,69 +287,3 @@ def _get_sorting_with_recording_attached(self, recording_for_durations, recordin assert sorting.has_recording() return sorting - - -# TODO: deprecate original implementations ### -def test_time_handling(create_cache_folder): - cache_folder = create_cache_folder - durations = [[10], [10, 5]] - - # test multi-segment - for i, dur in enumerate(durations): - rec = generate_recording(num_channels=4, durations=dur) - sort = generate_sorting(num_units=10, durations=dur) - - for segment_index in range(rec.get_num_segments()): - original_times = rec.get_times(segment_index=segment_index) - new_times = original_times + 5 - rec.set_times(new_times, segment_index=segment_index) - - 
sort.register_recording(rec) - assert sort.has_recording() - - rec_cache = rec.save(folder=cache_folder / f"rec{i}") - - for segment_index in range(sort.get_num_segments()): - assert rec.has_time_vector(segment_index=segment_index) - assert sort.has_time_vector(segment_index=segment_index) - - # times are correctly saved by the recording - assert np.allclose( - rec.get_times(segment_index=segment_index), rec_cache.get_times(segment_index=segment_index) - ) - - # spike times are correctly adjusted - for u in sort.get_unit_ids(): - spike_times = sort.get_unit_spike_train(u, segment_index=segment_index, return_times=True) - rec_times = rec.get_times(segment_index=segment_index) - assert np.all(spike_times >= rec_times[0]) - assert np.all(spike_times <= rec_times[-1]) - - -def test_frame_slicing(): - duration = [10] - - rec = generate_recording(num_channels=4, durations=duration) - sort = generate_sorting(num_units=10, durations=duration) - - original_times = rec.get_times() - new_times = original_times + 5 - rec.set_times(new_times) - - sort.register_recording(rec) - - start_frame = 3 * rec.get_sampling_frequency() - end_frame = 7 * rec.get_sampling_frequency() - - rec_slice = rec.frame_slice(start_frame=start_frame, end_frame=end_frame) - sort_slice = sort.frame_slice(start_frame=start_frame, end_frame=end_frame) - - for u in sort_slice.get_unit_ids(): - spike_times = sort_slice.get_unit_spike_train(u, return_times=True) - rec_times = rec_slice.get_times() - assert np.all(spike_times >= rec_times[0]) - assert np.all(spike_times <= rec_times[-1]) - - -if __name__ == "__main__": - test_frame_slicing() From 48c10bcf454bf52a81c28166b263d2a730c83fd7 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 9 Jul 2024 11:55:59 +0200 Subject: [PATCH 52/57] Extend docs and API for generation module --- doc/api.rst | 46 +++++++++++++++++++ .../benchmark_with_hybrid_recordings.rst | 2 +- doc/modules/generation.rst | 27 +++++++++-- src/spikeinterface/generation/__init__.py | 20 ++++++++ src/spikeinterface/generation/noise_tools.py | 13 ++++-- 5 files changed, 98 insertions(+), 10 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index c73cd812da..3e825084e7 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -338,14 +338,60 @@ spikeinterface.curation spikeinterface.generation ------------------------- +Core +~~~~ .. automodule:: spikeinterface.generation + .. autofunction:: generate_recording + .. autofunction:: generate_sorting + .. autofunction:: generate_snippets + .. autofunction:: generate_templates + .. autofunction:: generate_recording_by_size + .. autofunction:: generate_ground_truth_recording + .. autofunction:: add_synchrony_to_sorting + .. autofunction:: synthesize_random_firings + .. autofunction:: inject_some_duplicate_units + .. autofunction:: inject_some_split_units + .. autofunction:: synthetize_spike_train_bad_isi + .. autofunction:: inject_templates + .. autofunction:: noise_generator_recording + .. autoclass:: InjectTemplatesRecording + .. autoclass:: NoiseGeneratorRecording + +Drift +~~~~~ +.. automodule:: spikeinterface.generation + + .. autofunction:: generate_drifting_recording + .. autofunction:: generate_displacement_vector + .. autofunction:: make_one_displacement_vector .. autofunction:: make_linear_displacement .. autofunction:: move_dense_templates .. autofunction:: interpolate_templates .. autoclass:: DriftingTemplates .. autoclass:: InjectDriftingTemplatesRecording +Hybrid +~~~~~~ +.. automodule:: spikeinterface.generation + + .. autofunction:: generate_hybrid_recording + .. 
autofunction:: estimate_templates_from_recording + .. autofunction:: select_templates + .. autofunction:: scale_template_to_range + .. autofunction:: relocate_templates + .. autofunction:: fetch_template_object_from_database + .. autofunction:: fetch_templates_database_info + .. autofunction:: list_available_datasets_in_template_database + .. autofunction:: query_templates_from_database + + +Noise +~~~~~ +.. automodule:: spikeinterface.generation + + .. autofunction:: generate_noise + spikeinterface.sortingcomponents -------------------------------- diff --git a/doc/how_to/benchmark_with_hybrid_recordings.rst b/doc/how_to/benchmark_with_hybrid_recordings.rst index 9e8c6c7d65..5870d87955 100644 --- a/doc/how_to/benchmark_with_hybrid_recordings.rst +++ b/doc/how_to/benchmark_with_hybrid_recordings.rst @@ -9,7 +9,7 @@ with known spiking activity. The template (aka average waveforms) of the injected units can be from previous spike sorted data. In this example, we will be using an open database of templates that we have constructed from the International Brain Laboratory - Brain Wide Map (available on -`DANDI `__). +`DANDI `_). Importantly, recordings from long-shank probes, such as Neuropixels, usually experience drifts. Such drifts have to be taken into account in diff --git a/doc/modules/generation.rst b/doc/modules/generation.rst index a647919489..79893aa88d 100644 --- a/doc/modules/generation.rst +++ b/doc/modules/generation.rst @@ -1,9 +1,28 @@ Generation module ================= -The :py:mod:`spikeinterface.generation` provides functions to generate recordings containing spikes. -This module proposes several approaches for this including purely synthetic recordings as well as "hybrid" recordings (where templates come from true datasets). +The :py:mod:`spikeinterface.generation` provides functions to generate recordings containing spikes, +which can be used as "ground-truth" for benchmarking spike sorting algorithms. +There are several approaches to generating such recordings. +One possibility is to generate purely synthetic recordings. Another approach is to use real +recordings and add synthetic spikes to them, to make "hybrid" recordings. +The advantage of the former is that the ground-truth is known exactly, which is useful for benchmarking. +The advantage of the latter is that the spikes are added to real noise, which can be more realistic. -The :py:mod:`spikeinterface.core.generate` already provides functions for generating synthetic data but this module will supply an extended and more complex -machinery, for instance generating recordings that possess various types of drift. +For hybrid recordings, the main challenge is to generate realistic spike templates. +We therefore built an open database of templates that we have constructed from the International +Brain Laboratory - Brain Wide Map (available on +`DANDI `_). +You can checkout this collection of over 600 templates from this [web app](https://spikeinterface.github.io/hybrid_template_library/). + +The :py:mod:`spikeinterface.generation` module offers tools to interact with this database to select and download templates, +manupulating (e.g. rescaling and relocating them), and construct hybrid recordings with them. +Importantly, recordings from long-shank probes, such as Neuropixels, usually experience drifts. +Such drifts can be taken into account in order to smoothly inject spikes into the recording. 
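As a rough sketch of the hybrid-recording workflow this section describes (function names are taken from the API listing above; the keyword arguments of `generate_hybrid_recording` are an assumption, not taken from this diff):

from spikeinterface.generation import (
    fetch_templates_database_info,
    query_templates_from_database,
    generate_hybrid_recording,
)

# browse the public template database and pick a handful of entries
templates_info = fetch_templates_database_info()
selected = templates_info.iloc[:10]

# download the selected entries as a Templates object
templates = query_templates_from_database(selected)

# inject them into an existing recording (`recording` is assumed to be a
# pre-loaded BaseRecording; the argument names below are assumptions)
# hybrid_recording, hybrid_sorting = generate_hybrid_recording(
#     recording=recording, templates=templates
# )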
+ +The :py:mod:`spikeinterface.generation` also includes functions to generate different kinds of drift signals and drifting +recordings, as well as generating synthetic noise profiles of various types. + +Some of the generation functions are defined in the :py:mod:`spikeinterface.core.generate` module, but also exposed at the +:py:mod:`spikeinterface.generation` level for convenience. diff --git a/src/spikeinterface/generation/__init__.py b/src/spikeinterface/generation/__init__.py index 7a2291d932..5bf42ecf0f 100644 --- a/src/spikeinterface/generation/__init__.py +++ b/src/spikeinterface/generation/__init__.py @@ -14,6 +14,7 @@ relocate_templates, ) from .noise_tools import generate_noise + from .drifting_generator import ( make_one_displacement_vector, generate_displacement_vector, @@ -26,3 +27,22 @@ list_available_datasets_in_template_database, query_templates_from_database, ) + +# expose the core generate functions +from ..core.generate import ( + generate_recording, + generate_sorting, + generate_snippets, + generate_templates, + generate_recording_by_size, + generate_ground_truth_recording, + add_synchrony_to_sorting, + synthesize_random_firings, + inject_some_duplicate_units, + inject_some_split_units, + synthetize_spike_train_bad_isi, + NoiseGeneratorRecording, + noise_generator_recording, + InjectTemplatesRecording, + inject_templates, +) diff --git a/src/spikeinterface/generation/noise_tools.py b/src/spikeinterface/generation/noise_tools.py index 11f30e352f..d0aee138b6 100644 --- a/src/spikeinterface/generation/noise_tools.py +++ b/src/spikeinterface/generation/noise_tools.py @@ -7,22 +7,25 @@ def generate_noise( probe, sampling_frequency, durations, dtype="float32", noise_levels=15.0, spatial_decay=None, seed=None ): """ + Generate a noise recording. Parameters ---------- probe : Probe A probe object. sampling_frequency : float - Sampling frequency + The sampling frequency of the recording. durations : list of float - Durations + The duration(s) of the recording. dtype : np.dtype - Dtype - noise_levels : float | np.array | tuple + The dtype of the recording. + noise_levels : float | np.array | tuple, default: 15.0 If scalar same noises on all channels. If array then per channels noise level. If tuple, then this represent the range. - seed : None | int + spatial_decay : float | None, default: None + If not None, the spatial decay of the noise used to generate the noise covariance matrix. + seed : int | None, default: None The seed for random generator. Returns From b97fc6b5fb6f4f8ce84c8a409f6a6ff88d1027b8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 9 Jul 2024 12:50:04 +0200 Subject: [PATCH 53/57] Update doc/modules/generation.rst Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> --- doc/modules/generation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/generation.rst b/doc/modules/generation.rst index 79893aa88d..191cb57f30 100644 --- a/doc/modules/generation.rst +++ b/doc/modules/generation.rst @@ -14,7 +14,7 @@ For hybrid recordings, the main challenge is to generate realistic spike templat We therefore built an open database of templates that we have constructed from the International Brain Laboratory - Brain Wide Map (available on `DANDI `_). -You can checkout this collection of over 600 templates from this [web app](https://spikeinterface.github.io/hybrid_template_library/). +You can check out this collection of over 600 templates from this `web app `_. 
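A minimal usage sketch for the `generate_noise` signature shown above (borrowing the dummy probe attached by `generate_recording` is a convenience assumption, not part of the patch):

from spikeinterface.core import generate_recording
from spikeinterface.generation import generate_noise

# any probeinterface Probe works; here we reuse the probe of a generated recording
probe = generate_recording(num_channels=8, durations=[1.0]).get_probe()

noise_recording = generate_noise(
    probe=probe,
    sampling_frequency=30_000.0,
    durations=[10.0],
    noise_levels=15.0,   # scalar: same noise level on every channel
    spatial_decay=None,  # no distance-based covariance between channels
    seed=42,
)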
The :py:mod:`spikeinterface.generation` module offers tools to interact with this database to select and download templates, manupulating (e.g. rescaling and relocating them), and construct hybrid recordings with them. From b3647e95f28e4cbf5584fc5bad2a345a5183ae0c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 9 Jul 2024 12:57:05 +0200 Subject: [PATCH 54/57] Add generate_snippets docs and return for fetch_template_object_from_database --- src/spikeinterface/core/generate.py | 29 +++++++++++++++++++ .../generation/template_database.py | 2 +- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 4f3977d7bb..01c719c39d 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -585,6 +585,35 @@ def generate_snippets( empty_units=None, **job_kwargs, ): + """ + Generates a synthetic Snippets object. + + Parameters + ---------- + nbefore : int, default: 20 + Number of samples before the peak. + nafter : int, default: 44 + Number of samples after the peak. + num_channels : int, default: 2 + Number of channels. + wf_folder : str | Path | None, default: None + Optional folder to save the waveform snippets. If None, snippets are in memory. + sampling_frequency : float, default: 30000.0 + The sampling frequency of the snippets. + ndim : int, default: 2 + The number of dimensions of the probe. + num_units : int, default: 5 + The number of units. + empty_units : list | None, default: None + A list of units that will have no spikes. + + Returns + ------- + snippets : NumpySnippets + The snippets object. + sorting : NumpySorting + The associated sorting object. + """ recording = generate_recording( durations=durations, num_channels=num_channels, diff --git a/src/spikeinterface/generation/template_database.py b/src/spikeinterface/generation/template_database.py index e1cba07c8e..17d2bdf521 100644 --- a/src/spikeinterface/generation/template_database.py +++ b/src/spikeinterface/generation/template_database.py @@ -20,7 +20,7 @@ def fetch_template_object_from_database(dataset="test_templates.zarr") -> Templa Returns ------- Templates - _description_ + The templates object. """ s3_path = f"s3://spikeinterface-template-database/{dataset}/" zarr_group = zarr.open_consolidated(s3_path, storage_options={"anon": True}) From 0cbbbe9de409899e3b4428b115732972a5017cba Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 9 Jul 2024 13:08:55 +0200 Subject: [PATCH 55/57] Update src/spikeinterface/generation/noise_tools.py --- src/spikeinterface/generation/noise_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/generation/noise_tools.py b/src/spikeinterface/generation/noise_tools.py index d0aee138b6..685f0113b4 100644 --- a/src/spikeinterface/generation/noise_tools.py +++ b/src/spikeinterface/generation/noise_tools.py @@ -16,7 +16,7 @@ def generate_noise( sampling_frequency : float The sampling frequency of the recording. durations : list of float - The duration(s) of the recording. + The duration(s) of the recording segment(s) in seconds. dtype : np.dtype The dtype of the recording. 
noise_levels : float | np.array | tuple, default: 15.0 From 90cb2b12e6d3724d9c70753b438fac2c71db7988 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 9 Jul 2024 13:18:17 +0200 Subject: [PATCH 56/57] MOre docstrings --- src/spikeinterface/core/generate.py | 140 +++++++++++++++++----------- 1 file changed, 87 insertions(+), 53 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 01c719c39d..f5312f9c46 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -103,11 +103,11 @@ def generate_sorting( Parameters ---------- num_units : int, default: 5 - Number of units + Number of units. sampling_frequency : float, default: 30000.0 - The sampling frequency + The sampling frequency. durations : list, default: [10.325, 3.5] - Duration of each segment in s + Duration of each segment in s. firing_rates : float, default: 3.0 The firing rate of each unit (in Hz). empty_units : list, default: None @@ -123,12 +123,12 @@ def generate_sorting( border_size_samples : int, default: 20 The size of the border in samples to add border spikes. seed : int, default: None - The random seed + The random seed. Returns ------- sorting : NumpySorting - The sorting object + The sorting object. """ seed = _ensure_seed(seed) rng = np.random.default_rng(seed) @@ -187,19 +187,19 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): Parameters ---------- sorting : BaseSorting - The sorting object + The sorting object. sync_event_ratio : float The ratio of added synchronous spikes with respect to the total number of spikes. E.g., 0.5 means that the final sorting will have 1.5 times number of spikes, and all the extra spikes are synchronous (same sample_index), but on different units (not duplicates). seed : int, default: None - The random seed + The random seed. Returns ------- sorting : TransformSorting - The sorting object, keeping track of added spikes + The sorting object, keeping track of added spikes. """ rng = np.random.default_rng(seed) @@ -249,18 +249,18 @@ def generate_sorting_to_inject( Parameters ---------- sorting : BaseSorting - The sorting object + The sorting object. num_samples: list of size num_segments. The number of samples in all the segments of the sorting, to generate spike times - covering entire the entire duration of the segments + covering entire the entire duration of the segments. max_injected_per_unit: int, default 1000 - The maximal number of spikes injected per units + The maximal number of spikes injected per units. injected_rate: float, default 0.05 - The rate at which spikes are injected + The rate at which spikes are injected. refractory_period_ms: float, default 1.5 - The refractory period that should not be violated while injecting new spikes + The refractory period that should not be violated while injecting new spikes. seed: int, default None - The random seed + The random seed. Returns ------- @@ -312,22 +312,22 @@ class TransformSorting(BaseSorting): Parameters ---------- sorting : BaseSorting - The sorting object + The sorting object. added_spikes_existing_units : np.array (spike_vector) - The spikes that should be added to the sorting object, for existing units + The spikes that should be added to the sorting object, for existing units. added_spikes_new_units: np.array (spike_vector) - The spikes that should be added to the sorting object, for new units + The spikes that should be added to the sorting object, for new units. 
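As a usage note for the two sorting generators documented in the hunk above, a short sketch (illustration only, not part of the patch); argument values are arbitrary, and the final print simply shows that extra synchronous spikes were added.

    from spikeinterface.core.generate import generate_sorting, add_synchrony_to_sorting

    # base sorting: 10 units firing at ~5 Hz for 30 s
    sorting = generate_sorting(num_units=10, durations=[30.0], firing_rates=5.0, seed=0)

    # add ~50% extra spikes that are synchronous across units; the returned
    # TransformSorting keeps track of the added spikes
    sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=0.5, seed=0)

    print(len(sorting.to_spike_vector()), len(sorting_sync.to_spike_vector()))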
new_units_ids: list - The unit_ids that should be added if spikes for new units are added + The unit_ids that should be added if spikes for new units are added. refractory_period_ms : float, default None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be - discarded + discarded. Returns ------- sorting : TransformSorting - The sorting object with the added spikes and/or units + The sorting object with the added spikes and/or units. """ def __init__( @@ -428,12 +428,14 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe Parameters ---------- - sorting1: the first sorting - sorting2: the second sorting + sorting1: BaseSorting + The first sorting. + sorting2: BaseSorting + The second sorting. refractory_period_ms : float, default None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be - discarded + discarded. """ assert ( sorting1.get_sampling_frequency() == sorting2.get_sampling_frequency() @@ -492,12 +494,14 @@ def add_from_unit_dict( Parameters ---------- - sorting1: the first sorting + sorting1: BaseSorting + The first sorting dict_list: list of dict + A list of dict with unit_ids as keys and spike times as values. refractory_period_ms : float, default None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be - discarded + discarded. """ sorting2 = NumpySorting.from_unit_dict(units_dict_list, sorting1.get_sampling_frequency()) sorting = TransformSorting.add_from_sorting(sorting1, sorting2, refractory_period_ms) @@ -515,18 +519,19 @@ def from_times_labels( Parameters ---------- - sorting1: the first sorting + sorting1: BaseSorting + The first sorting times_list: list of array (or array) - An array of spike times (in frames) + An array of spike times (in frames). labels_list: list of array (or array) - An array of spike labels corresponding to the given times + An array of spike labels corresponding to the given times. unit_ids: list or None, default: None The explicit list of unit_ids that should be extracted from labels_list - If None, then it will be np.unique(labels_list) + If None, then it will be np.unique(labels_list). refractory_period_ms : float, default None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be - discarded + discarded. """ sorting2 = NumpySorting.from_times_labels(times_list, labels_list, sampling_frequency, unit_ids) @@ -556,6 +561,16 @@ def clean_refractory_period(self): def create_sorting_npz(num_seg, file_path): + """ + Create a NPZ sorting file. + + Parameters + ---------- + num_seg : int + The number of segments. + file_path : str | Path + The file path to save the NPZ file. + """ # create a NPZ sorting file d = {} d["unit_ids"] = np.array([0, 1, 2], dtype="int64") @@ -674,18 +689,18 @@ def synthesize_poisson_spike_vector( Parameters ---------- num_units : int, default: 20 - Number of neuronal units to simulate + Number of neuronal units to simulate. sampling_frequency : float, default: 30000.0 - Sampling frequency in Hz + Sampling frequency in Hz. duration : float, default: 60.0 - Duration of the simulation in seconds + Duration of the simulation in seconds. 
refractory_period_ms : float, default: 4.0 - Refractory period between spikes in milliseconds + Refractory period between spikes in milliseconds. firing_rates : float or array_like or tuple, default: 3.0 Firing rate(s) in Hz. Can be a single value for all units or an array of firing rates with - each element being the firing rate for one unit + each element being the firing rate for one unit. seed : int, default: 0 - Seed for random number generator + Seed for random number generator. Returns ------- @@ -779,27 +794,27 @@ def synthesize_random_firings( Parameters ---------- num_units : int - number of units + Number of units. sampling_frequency : float - sampling rate + Sampling rate. duration : float - duration of the segment in seconds + Duration of the segment in seconds. refractory_period_ms: float - refractory_period in ms + Refractory period in ms. firing_rates: float or list[float] The firing rate of each unit (in Hz). If float, all units will have the same firing rate. add_shift_shuffle: bool, default: False Optionally add a small shuffle on half of the spikes to make the autocorrelogram less flat. seed: int, default: None - seed for the generator + Seed for the generator. Returns ------- - times: - Concatenated and sorted times vector - labels: - Concatenated and sorted label vector + times: np.array + Concatenated and sorted times vector. + labels: np.array + Concatenated and sorted label vector. """ @@ -883,11 +898,11 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No Parameters ---------- sorting : - Original sorting + Original sorting. num : int - Number of injected units + Number of injected units. max_shift : int - range of the shift in sample + range of the shift in sample. ratio: float Proportion of original spike in the injected units. @@ -938,8 +953,27 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=False, seed=None): - """ """ + """ + Inject some split units in a sorting. + Parameters + ---------- + sorting : BaseSorting + Original sorting. + split_ids : list + List of unit_ids to split. + num_split : int, default: 2 + Number of split units. + output_ids : bool, default: False + If True, return the new unit_ids. + seed : int, default: None + Random seed. + + Returns + ------- + sorting_with_split : NumpySorting + A sorting with split units. + """ unit_ids = sorting.unit_ids assert unit_ids.dtype.kind == "i" @@ -989,7 +1023,7 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol num_violations : int Number of contaminating spikes. violation_delta : float, default: 1e-5 - Temporal offset of contaminating spikes (in seconds) + Temporal offset of contaminating spikes (in seconds). Returns ------- @@ -1246,7 +1280,7 @@ def generate_recording_by_size( num_channels: int Number of channels. seed : int, default: None - The seed for np.random.default_rng + The seed for np.random.default_rng. Returns ------- @@ -1646,7 +1680,7 @@ class InjectTemplatesRecording(BaseRecording): * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce sampling jitter. nbefore: list[int] | int | None, default: None - Where is the center of the template for each unit? + The number of samples before the peak of the template to align the spike. If None, will default to the highest peak. 
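A short sketch of the duplicate-unit injection documented above (illustration only, not part of the patch; argument values are arbitrary and follow the docstring in the hunk):

    from spikeinterface.core.generate import generate_sorting, inject_some_duplicate_units

    sorting = generate_sorting(num_units=8, durations=[20.0], seed=1)

    # inject 2 extra units that copy existing ones, jittered by at most 5 samples
    # and keeping ~80% of the original spikes
    sorting_with_duplicates = inject_some_duplicate_units(sorting, num=2, max_shift=5, ratio=0.8, seed=1)

    print(sorting_with_duplicates.unit_ids)  # 8 original units + 2 injected duplicates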
amplitude_factor: list[float] | float | None, default: None The amplitude of each spike for each unit. @@ -1661,7 +1695,7 @@ class InjectTemplatesRecording(BaseRecording): You can use int for mono-segment objects. upsample_vector: np.array or None, default: None. When templates is 4d we can simulate a jitter. - Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.sahpe[3] + Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.shape[3]. Returns ------- From 90b95fcdc91f85e15cd3809a679d25039f802adb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 9 Jul 2024 13:30:50 +0200 Subject: [PATCH 57/57] Build extractor dicts automatically --- src/spikeinterface/extractors/extractorlist.py | 18 ++++++++++++++++++ .../extractors/neoextractors/__init__.py | 9 +++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py index e56d4fff52..bd35180a7e 100644 --- a/src/spikeinterface/extractors/extractorlist.py +++ b/src/spikeinterface/extractors/extractorlist.py @@ -116,3 +116,21 @@ event_extractor_full_list += neo_event_extractors_list snippets_extractor_full_list = [NpySnippetsExtractor, WaveClusSnippetsExtractor] + +recording_extractor_full_dict = {} +for rec_class in recording_extractor_full_list: + # here we get the class name, remove "Recording" and "Extractor" and make it lower case + rec_class_name = rec_class.__name__.replace("Recording", "").replace("Extractor", "").lower() + recording_extractor_full_dict[rec_class_name] = rec_class + +sorting_extractor_full_dict = {} +for sort_class in sorting_extractor_full_list: + # here we get the class name, remove "Extractor" and make it lower case + sort_class_name = sort_class.__name__.replace("Sorting", "").replace("Extractor", "").lower() + sorting_extractor_full_dict[sort_class_name] = sort_class + +event_extractor_full_dict = {} +for event_class in event_extractor_full_list: + # here we get the class name, remove "Extractor" and make it lower case + event_class_name = event_class.__name__.replace("Event", "").replace("Extractor", "").lower() + event_extractor_full_dict[event_class_name] = event_class diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py index 0b11b72b2a..bf52de7c1d 100644 --- a/src/spikeinterface/extractors/neoextractors/__init__.py +++ b/src/spikeinterface/extractors/neoextractors/__init__.py @@ -36,7 +36,7 @@ ) from .spike2 import Spike2RecordingExtractor, read_spike2 from .spikegadgets import SpikeGadgetsRecordingExtractor, read_spikegadgets -from .spikeglx import SpikeGLXRecordingExtractor, read_spikeglx +from .spikeglx import SpikeGLXRecordingExtractor, SpikeGLXEventExtractor, read_spikeglx, read_spikeglx_event from .tdt import TdtRecordingExtractor, read_tdt from .neo_utils import get_neo_streams, get_neo_num_blocks @@ -73,4 +73,9 @@ Plexon2SortingExtractor, ] -neo_event_extractors_list = [AlphaOmegaEventExtractor, OpenEphysBinaryEventExtractor, Plexon2EventExtractor] +neo_event_extractors_list = [ + AlphaOmegaEventExtractor, + OpenEphysBinaryEventExtractor, + Plexon2EventExtractor, + SpikeGLXEventExtractor, +]
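A standalone sketch of the key-naming rule behind the new extractor dictionaries added above (the two classes below are local stand-ins, not imports from spikeinterface; the real dictionaries are built from the full extractor lists in extractorlist.py):

    class BinaryRecordingExtractor: ...
    class NpzSortingExtractor: ...

    def dict_key(extractor_class, kind):
        # mirrors the loops above: strip the kind ("Recording", "Sorting", "Event")
        # and the "Extractor" suffix, then lower-case the remainder
        return extractor_class.__name__.replace(kind, "").replace("Extractor", "").lower()

    print(dict_key(BinaryRecordingExtractor, "Recording"))  # -> "binary"
    print(dict_key(NpzSortingExtractor, "Sorting"))         # -> "npz"

With the real lists, recording_extractor_full_dict["binary"] would map back to BinaryRecordingExtractor.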