diff --git a/doc/api.rst b/doc/api.rst index c5c9ebe4dd..3e825084e7 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -60,6 +60,10 @@ spikeinterface.core .. autofunction:: select_segment_sorting .. autofunction:: read_binary .. autofunction:: read_zarr + .. autofunction:: apply_merges_to_sorting + .. autofunction:: spike_vector_to_spike_trains + .. autofunction:: random_spikes_selection + Low-level ~~~~~~~~~ @@ -67,7 +71,6 @@ Low-level .. automodule:: spikeinterface.core :noindex: - .. autoclass:: BaseWaveformExtractorExtension .. autoclass:: ChunkRecordingExecutor spikeinterface.extractors @@ -335,14 +338,60 @@ spikeinterface.curation spikeinterface.generation ------------------------- +Core +~~~~ +.. automodule:: spikeinterface.generation + + .. autofunction:: generate_recording + .. autofunction:: generate_sorting + .. autofunction:: generate_snippets + .. autofunction:: generate_templates + .. autofunction:: generate_recording_by_size + .. autofunction:: generate_ground_truth_recording + .. autofunction:: add_synchrony_to_sorting + .. autofunction:: synthesize_random_firings + .. autofunction:: inject_some_duplicate_units + .. autofunction:: inject_some_split_units + .. autofunction:: synthetize_spike_train_bad_isi + .. autofunction:: inject_templates + .. autofunction:: noise_generator_recording + .. autoclass:: InjectTemplatesRecording + .. autoclass:: NoiseGeneratorRecording + +Drift +~~~~~ .. automodule:: spikeinterface.generation + .. autofunction:: generate_drifting_recording + .. autofunction:: generate_displacement_vector + .. autofunction:: make_one_displacement_vector .. autofunction:: make_linear_displacement .. autofunction:: move_dense_templates .. autofunction:: interpolate_templates .. autoclass:: DriftingTemplates .. autoclass:: InjectDriftingTemplatesRecording +Hybrid +~~~~~~ +.. automodule:: spikeinterface.generation + + .. autofunction:: generate_hybrid_recording + .. autofunction:: estimate_templates_from_recording + .. autofunction:: select_templates + .. autofunction:: scale_template_to_range + .. autofunction:: relocate_templates + .. autofunction:: fetch_template_object_from_database + .. autofunction:: fetch_templates_database_info + .. autofunction:: list_available_datasets_in_template_database + .. autofunction:: query_templates_from_database + + +Noise +~~~~~ +.. automodule:: spikeinterface.generation + + .. autofunction:: generate_noise + spikeinterface.sortingcomponents -------------------------------- diff --git a/doc/how_to/analyse_neuropixels.rst b/doc/how_to/analyze_neuropixels.rst similarity index 98% rename from doc/how_to/analyse_neuropixels.rst rename to doc/how_to/analyze_neuropixels.rst index 02e497b0fe..1fe741ea48 100644 --- a/doc/how_to/analyse_neuropixels.rst +++ b/doc/how_to/analyze_neuropixels.rst @@ -1,4 +1,4 @@ -Analyse Neuropixels datasets +Analyze Neuropixels datasets ============================ This example shows how to perform Neuropixels-specific analysis, @@ -218,7 +218,7 @@ We need to specify which one to read: -.. image:: analyse_neuropixels_files/analyse_neuropixels_8_1.png +.. image:: analyze_neuropixels_files/analyze_neuropixels_8_1.png Preprocess the recording @@ -286,7 +286,7 @@ is lazy, so you can change the previsous cell (parameters, step order, -.. image:: analyse_neuropixels_files/analyse_neuropixels_13_0.png +.. image:: analyze_neuropixels_files/analyze_neuropixels_13_0.png .. code:: ipython3 @@ -306,7 +306,7 @@ is lazy, so you can change the previsous cell (parameters, step order, -.. 
image:: analyse_neuropixels_files/analyse_neuropixels_14_1.png +.. image:: analyze_neuropixels_files/analyze_neuropixels_14_1.png Should we save the preprocessed data to a binary file? @@ -389,7 +389,7 @@ Noise levels can be estimated on the scaled traces or on the raw -.. image:: analyse_neuropixels_files/analyse_neuropixels_21_1.png +.. image:: analyze_neuropixels_files/analyze_neuropixels_21_1.png Detect and localize peaks @@ -480,7 +480,7 @@ documentation for motion estimation and correction for more details. -.. image:: analyse_neuropixels_files/analyse_neuropixels_26_1.png +.. image:: analyze_neuropixels_files/analyze_neuropixels_26_1.png .. code:: ipython3 @@ -502,7 +502,7 @@ documentation for motion estimation and correction for more details. -.. image:: analyse_neuropixels_files/analyse_neuropixels_27_1.png +.. image:: analyze_neuropixels_files/analyze_neuropixels_27_1.png Run a spike sorter diff --git a/doc/how_to/analyse_neuropixels_files/analyse_neuropixels_13_0.png b/doc/how_to/analyze_neuropixels_files/analyze_neuropixels_13_0.png similarity index 100% rename from doc/how_to/analyse_neuropixels_files/analyse_neuropixels_13_0.png rename to doc/how_to/analyze_neuropixels_files/analyze_neuropixels_13_0.png diff --git a/doc/how_to/analyse_neuropixels_files/analyse_neuropixels_14_1.png b/doc/how_to/analyze_neuropixels_files/analyze_neuropixels_14_1.png similarity index 100% rename from doc/how_to/analyse_neuropixels_files/analyse_neuropixels_14_1.png rename to doc/how_to/analyze_neuropixels_files/analyze_neuropixels_14_1.png diff --git a/doc/how_to/analyse_neuropixels_files/analyse_neuropixels_21_1.png b/doc/how_to/analyze_neuropixels_files/analyze_neuropixels_21_1.png similarity index 100% rename from doc/how_to/analyse_neuropixels_files/analyse_neuropixels_21_1.png rename to doc/how_to/analyze_neuropixels_files/analyze_neuropixels_21_1.png diff --git a/doc/how_to/analyse_neuropixels_files/analyse_neuropixels_26_1.png b/doc/how_to/analyze_neuropixels_files/analyze_neuropixels_26_1.png similarity index 100% rename from doc/how_to/analyse_neuropixels_files/analyse_neuropixels_26_1.png rename to doc/how_to/analyze_neuropixels_files/analyze_neuropixels_26_1.png diff --git a/doc/how_to/analyse_neuropixels_files/analyse_neuropixels_27_1.png b/doc/how_to/analyze_neuropixels_files/analyze_neuropixels_27_1.png similarity index 100% rename from doc/how_to/analyse_neuropixels_files/analyse_neuropixels_27_1.png rename to doc/how_to/analyze_neuropixels_files/analyze_neuropixels_27_1.png diff --git a/doc/how_to/analyse_neuropixels_files/analyse_neuropixels_8_1.png b/doc/how_to/analyze_neuropixels_files/analyze_neuropixels_8_1.png similarity index 100% rename from doc/how_to/analyse_neuropixels_files/analyse_neuropixels_8_1.png rename to doc/how_to/analyze_neuropixels_files/analyze_neuropixels_8_1.png diff --git a/doc/how_to/benchmark_with_hybrid_recordings.rst b/doc/how_to/benchmark_with_hybrid_recordings.rst index 9e8c6c7d65..5870d87955 100644 --- a/doc/how_to/benchmark_with_hybrid_recordings.rst +++ b/doc/how_to/benchmark_with_hybrid_recordings.rst @@ -9,7 +9,7 @@ with known spiking activity. The template (aka average waveforms) of the injected units can be from previous spike sorted data. In this example, we will be using an open database of templates that we have constructed from the International Brain Laboratory - Brain Wide Map (available on -`DANDI `__). +`DANDI `_). Importantly, recordings from long-shank probes, such as Neuropixels, usually experience drifts. 
Such drifts have to be taken into account in
diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst
index 64e650deac..cf9cadcfc3 100644
--- a/doc/how_to/index.rst
+++ b/doc/how_to/index.rst
@@ -8,7 +8,7 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to.

     viewers
     handle_drift
-    analyse_neuropixels
+    analyze_neuropixels
     load_matlab_data
     combine_recordings
     process_by_channel_group
diff --git a/doc/modules/generation.rst b/doc/modules/generation.rst
index a647919489..191cb57f30 100644
--- a/doc/modules/generation.rst
+++ b/doc/modules/generation.rst
@@ -1,9 +1,28 @@
 Generation module
 =================

-The :py:mod:`spikeinterface.generation` provides functions to generate recordings containing spikes.
-This module proposes several approaches for this including purely synthetic recordings as well as "hybrid" recordings (where templates come from true datasets).
+The :py:mod:`spikeinterface.generation` module provides functions to generate recordings containing spikes,
+which can be used as "ground truth" for benchmarking spike sorting algorithms.
+There are several approaches to generating such recordings.
+One possibility is to generate purely synthetic recordings. Another approach is to use real
+recordings and add synthetic spikes to them, to make "hybrid" recordings.
+The advantage of the former is that the ground truth is known exactly, which is useful for benchmarking.
+The advantage of the latter is that the spikes are added to real noise, which can be more realistic.

-The :py:mod:`spikeinterface.core.generate` already provides functions for generating synthetic data but this module will supply an extended and more complex
-machinery, for instance generating recordings that possess various types of drift.
+For hybrid recordings, the main challenge is to generate realistic spike templates.
+We therefore built an open database of templates, constructed from the International
+Brain Laboratory - Brain Wide Map (available on
+`DANDI `_).
+You can check out this collection of over 600 templates from this `web app `_.
+
+The :py:mod:`spikeinterface.generation` module offers tools to interact with this database to select and download templates,
+manipulate them (e.g. rescale and relocate them), and construct hybrid recordings with them.
+Importantly, recordings from long-shank probes, such as Neuropixels, usually experience drifts.
+Such drifts can be taken into account in order to smoothly inject spikes into the recording.
+
+The :py:mod:`spikeinterface.generation` module also includes functions to generate different kinds of drift signals and drifting
+recordings, as well as synthetic noise profiles of various types.
+
+Some of the generation functions are defined in the :py:mod:`spikeinterface.core.generate` module, but are also exposed at the
+:py:mod:`spikeinterface.generation` level for convenience.
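As a rough sketch of the hybrid workflow described above, using the database-related functions listed in ``doc/api.rst`` earlier in this diff (``fetch_templates_database_info``, ``query_templates_from_database``, ``generate_hybrid_recording``); the exact signatures, return types, and the selection criteria shown here are assumptions, not taken from this diff:

.. code-block:: python

    import spikeinterface.generation as sgen
    from spikeinterface.core import generate_recording

    # Browse the template database (assumed to return a table-like object, e.g. a pandas DataFrame)
    templates_info = sgen.fetch_templates_database_info()

    # Pick a handful of templates; the filtering criterion here is purely illustrative
    selected_info = templates_info.head(10)
    templates = sgen.query_templates_from_database(selected_info)

    # Build a hybrid recording by injecting the selected templates into an existing recording;
    # the keyword arguments and the (recording, sorting) return pair are assumed for illustration
    recording = generate_recording(durations=[10.0], num_channels=64)
    hybrid_recording, hybrid_sorting = sgen.generate_hybrid_recording(
        recording=recording,
        templates=templates,
    )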
diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 6682252349..ba7268b4f0 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -9,9 +9,6 @@ import numpy as np from spikeinterface.core import load_extractor, create_sorting_analyzer, load_sorting_analyzer -from spikeinterface.core.core_tools import SIJsonEncoder -from spikeinterface.core.job_tools import split_job_kwargs - from spikeinterface.sorters import run_sorter_jobs, read_sorter_folder from spikeinterface.qualitymetrics import compute_quality_metrics @@ -54,6 +51,7 @@ def __init__(self, study_folder): self.cases = {} self.sortings = {} self.comparisons = {} + self.colors = None self.scan_folder() @@ -175,6 +173,22 @@ def remove_sorting(self, key): if f.exists(): f.unlink() + def set_colors(self, colors=None, map_name="tab20"): + from spikeinterface.widgets import get_some_colors + + if colors is None: + case_keys = list(self.cases.keys()) + self.colors = get_some_colors( + case_keys, map_name=map_name, color_engine="matplotlib", shuffle=False, margin=0 + ) + else: + self.colors = colors + + def get_colors(self): + if self.colors is None: + self.set_colors() + return self.colors + def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False): if case_keys is None: case_keys = self.cases.keys() diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index a5e1f44842..674f1ac463 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -101,7 +101,7 @@ get_chunk_with_margin, order_channels_by_depth, ) -from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection +from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection, apply_merges_to_sorting from .waveform_tools import extract_waveforms_to_buffers, estimate_templates, estimate_templates_with_accumulator from .snippets_tools import snippets_from_sorting diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 5800166f39..8b037ad10f 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -7,7 +7,6 @@ import weakref import json import pickle -import os import random import string from packaging.version import parse @@ -41,7 +40,7 @@ class BaseExtractor: # This replaces the old key_properties # These are annotations/properties that always need to be # dumped (for instance locations, groups, is_fileterd, etc.) 
- _main_annotations = [] + _main_annotations = ["name"] _main_properties = [] # these properties are skipped by default in copy_metadata @@ -79,6 +78,19 @@ def __init__(self, main_ids: Sequence) -> None: # preferred context for multiprocessing self._preferred_mp_context = None + @property + def name(self): + name = self._annotations.get("name", None) + return name if name is not None else self.__class__.__name__ + + @name.setter + def name(self, value): + if value is not None: + self.annotate(name=value) + else: + # we remove the annotation if it exists + _ = self._annotations.pop("name", None) + def get_num_segments(self) -> int: # This is implemented in BaseRecording or BaseSorting raise NotImplementedError @@ -128,8 +140,18 @@ def ids_to_indices( indices = np.arange(len(self._main_ids)) else: assert isinstance(ids, (list, np.ndarray, tuple)), "'ids' must be a list, np.ndarray or tuple" + + non_existent_ids = [id for id in ids if id not in self._main_ids] + if non_existent_ids: + error_msg = ( + f"IDs {non_existent_ids} are not channel ids of the extractor. \n" + f"Available ids are {self._main_ids} with dtype {self._main_ids.dtype}" + ) + raise ValueError(error_msg) + _main_ids = self._main_ids.tolist() indices = np.array([_main_ids.index(id) for id in ids], dtype=int) + if prefer_slice: if np.all(np.diff(indices) == 1): indices = slice(indices[0], indices[-1] + 1) @@ -928,13 +950,14 @@ def save_to_folder( folder.mkdir(parents=True, exist_ok=False) # dump provenance - provenance_file = folder / f"provenance.json" if self.check_serializability("json"): + provenance_file = folder / f"provenance.json" + self.dump(provenance_file) + elif self.check_serializability("pickle"): + provenance_file = folder / f"provenance.pkl" self.dump(provenance_file) else: - provenance_file.write_text( - json.dumps({"warning": "the provenace is not json serializable!!!"}), encoding="utf8" - ) + warnings.warn("The extractor is not serializable to file. The provenance will not be saved.") self.save_metadata_to_folder(folder) @@ -1001,7 +1024,6 @@ def save_to_zarr( cached: ZarrExtractor Saved copy of the extractor. 
""" - import zarr from .zarrextractors import read_zarr save_kwargs.pop("format", None) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index e70c95bb65..0ea9426674 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -23,7 +23,7 @@ class BaseRecording(BaseRecordingSnippets): Internally handle list of RecordingSegment """ - _main_annotations = ["is_filtered"] + _main_annotations = BaseRecordingSnippets._main_annotations + ["is_filtered"] _main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"] _main_features = [] # recording do not handle features @@ -45,9 +45,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype): self.annotate(is_filtered=False) def __repr__(self): - - class_name = self.__class__.__name__ - name_to_display = class_name num_segments = self.get_num_segments() txt = self._repr_header() @@ -57,7 +54,7 @@ def __repr__(self): split_index = txt.rfind("-", 0, 100) # Find the last "-" before character 100 if split_index != -1: first_line = txt[:split_index] - recording_string_space = len(name_to_display) + 2 # Length of name_to_display plus ": " + recording_string_space = len(self.name) + 2 # Length of self.name plus ": " white_space_to_align_with_first_line = " " * recording_string_space second_line = white_space_to_align_with_first_line + txt[split_index + 1 :].lstrip() txt = first_line + "\n" + second_line @@ -97,21 +94,21 @@ def list_to_string(lst, max_size=6): return txt def _repr_header(self): - class_name = self.__class__.__name__ - name_to_display = class_name num_segments = self.get_num_segments() num_channels = self.get_num_channels() - sf_khz = self.get_sampling_frequency() / 1000.0 + sf_hz = self.get_sampling_frequency() + sf_khz = sf_hz / 1000 dtype = self.get_dtype() total_samples = self.get_total_samples() total_duration = self.get_total_duration() total_memory_size = self.get_total_memory_size() + sampling_frequency_repr = f"{sf_khz:0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz" txt = ( - f"{name_to_display}: " + f"{self.name}: " f"{num_channels} channels - " - f"{sf_khz:0.1f}kHz - " + f"{sampling_frequency_repr} - " f"{num_segments} segments - " f"{total_samples:,} samples - " f"{convert_seconds_to_str(total_duration)} - " @@ -501,24 +498,35 @@ def time_to_sample_index(self, time_s, segment_index=None): rs = self._recording_segments[segment_index] return rs.time_to_sample_index(time_s) - def _save(self, format="binary", verbose: bool = False, **save_kwargs): + def _get_t_starts(self): # handle t_starts t_starts = [] has_time_vectors = [] - for segment_index, rs in enumerate(self._recording_segments): + for rs in self._recording_segments: d = rs.get_times_kwargs() t_starts.append(d["t_start"]) - has_time_vectors.append(d["time_vector"] is not None) if all(t_start is None for t_start in t_starts): t_starts = None + return t_starts + + def _get_time_vectors(self): + time_vectors = [] + for rs in self._recording_segments: + d = rs.get_times_kwargs() + time_vectors.append(d["time_vector"]) + if all(time_vector is None for time_vector in time_vectors): + time_vectors = None + return time_vectors + def _save(self, format="binary", verbose: bool = False, **save_kwargs): kwargs, job_kwargs = split_job_kwargs(save_kwargs) if format == "binary": folder = kwargs["folder"] file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] dtype = kwargs.get("dtype", None) or self.get_dtype() + t_starts = 
self._get_t_starts() write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) @@ -575,11 +583,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) - for segment_index, rs in enumerate(self._recording_segments): - d = rs.get_times_kwargs() - time_vector = d["time_vector"] - if time_vector is not None: - cached._recording_segments[segment_index].time_vector = time_vector + time_vectors = self._get_time_vectors() + if time_vectors is not None: + for segment_index, time_vector in enumerate(time_vectors): + if time_vector is not None: + cached.set_times(time_vector, segment_index=segment_index) return cached diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 1f3fee74a8..869842779d 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -14,7 +14,6 @@ class BaseSnippets(BaseRecordingSnippets): Abstract class representing several multichannel snippets. """ - _main_annotations = [] _main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"] _main_features = [] diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index d9a567dedf..2af48407a3 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -30,11 +30,10 @@ def __init__(self, sampling_frequency: float, unit_ids: List): self._cached_spike_trains = {} def __repr__(self): - clsname = self.__class__.__name__ nseg = self.get_num_segments() nunits = self.get_num_units() sf_khz = self.get_sampling_frequency() / 1000.0 - txt = f"{clsname}: {nunits} units - {nseg} segments - {sf_khz:0.1f}kHz" + txt = f"{self.name}: {nunits} units - {nseg} segments - {sf_khz:0.1f}kHz" if "file_path" in self._kwargs: txt += "\n file_path: {}".format(self._kwargs["file_path"]) return txt diff --git a/src/spikeinterface/core/binaryfolder.py b/src/spikeinterface/core/binaryfolder.py index fca08d9c26..86f14faa30 100644 --- a/src/spikeinterface/core/binaryfolder.py +++ b/src/spikeinterface/core/binaryfolder.py @@ -25,9 +25,6 @@ class BinaryFolderRecording(BinaryRecordingExtractor): The recording """ - mode = "folder" - name = "binaryfolder" - def __init__(self, folder_path): folder_path = Path(folder_path) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 64c1b9b2e6..a0e349728e 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -52,9 +52,6 @@ class BinaryRecordingExtractor(BaseRecording): The recording Extractor """ - mode = "file" - name = "binary" - def __init__( self, file_paths, @@ -166,25 +163,17 @@ def get_binary_description(self): class BinaryRecordingSegment(BaseRecordingSegment): - def __init__(self, datfile, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset): + def __init__(self, file_path, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset): BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start) self.num_channels = num_channels self.dtype = np.dtype(dtype) self.file_offset = file_offset self.time_axis = time_axis - self.datfile = datfile - self.file = open(self.datfile, "r") - self.num_samples = (Path(datfile).stat().st_size - file_offset) // (num_channels * np.dtype(dtype).itemsize) - if self.time_axis == 0: - 
self.shape = (self.num_samples, self.num_channels) - else: - self.shape = (self.num_channels, self.num_samples) - - byte_offset = self.file_offset - dtype_size_bytes = self.dtype.itemsize - data_size_bytes = dtype_size_bytes * self.num_samples * self.num_channels - self.memmap_offset, self.array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY) - self.memmap_length = data_size_bytes + self.array_offset + self.file_path = file_path + self.file = open(self.file_path, "rb") + self.bytes_per_sample = self.num_channels * self.dtype.itemsize + self.data_size_in_bytes = Path(file_path).stat().st_size - file_offset + self.num_samples = self.data_size_in_bytes // self.bytes_per_sample def get_num_samples(self) -> int: """Returns the number of samples in this signal block @@ -200,23 +189,43 @@ def get_traces( end_frame: int | None = None, channel_indices: list | None = None, ) -> np.ndarray: - length = self.memmap_length - memmap_offset = self.memmap_offset + + # Calculate byte offsets for start and end frames + start_byte = self.file_offset + start_frame * self.bytes_per_sample + end_byte = self.file_offset + end_frame * self.bytes_per_sample + + # Calculate the length of the data chunk to load into memory + length = end_byte - start_byte + + # The mmap offset must be a multiple of mmap.ALLOCATIONGRANULARITY + memmap_offset, start_offset = divmod(start_byte, mmap.ALLOCATIONGRANULARITY) + memmap_offset *= mmap.ALLOCATIONGRANULARITY + + # Adjust the length so it includes the extra data from rounding down + # the memmap offset to a multiple of ALLOCATIONGRANULARITY + length += start_offset + + # Create the mmap object memmap_obj = mmap.mmap(self.file.fileno(), length=length, access=mmap.ACCESS_READ, offset=memmap_offset) - array = np.ndarray.__new__( - np.ndarray, - shape=self.shape, + # Create a numpy array using the mmap object as the buffer + # Note that the shape must be recalculated based on the new data chunk + if self.time_axis == 0: + shape = ((end_frame - start_frame), self.num_channels) + else: + shape = (self.num_channels, (end_frame - start_frame)) + + # Now the entire array should correspond to the data between start_frame and end_frame, so we can use it directly + traces = np.ndarray( + shape=shape, dtype=self.dtype, buffer=memmap_obj, - order="C", - offset=self.array_offset, + offset=start_offset, ) if self.time_axis == 1: - array = array.T + traces = traces.T - traces = array[start_frame:end_frame] if channel_indices is not None: traces = traces[:, channel_indices] diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index b8735dff3c..820b4fcd91 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -15,10 +15,15 @@ class ChannelsAggregationRecording(BaseRecording): def __init__(self, recording_list, renamed_channel_ids=None): + self._recordings = recording_list + + self._perform_consistency_checks() + sampling_frequency = recording_list[0].get_sampling_frequency() + dtype = recording_list[0].get_dtype() + num_segments = recording_list[0].get_num_segments() + # Generate a default list of channel ids that are unique and consecutive numbers as strings. 
- channel_map = {} num_all_channels = sum(rec.get_num_channels() for rec in recording_list) - if renamed_channel_ids is not None: assert ( len(np.unique(renamed_channel_ids)) == num_all_channels @@ -39,33 +44,6 @@ def __init__(self, recording_list, renamed_channel_ids=None): default_channel_ids = [str(i) for i in range(num_all_channels)] channel_ids = default_channel_ids - ch_id = 0 - for r_i, recording in enumerate(recording_list): - single_channel_ids = recording.get_channel_ids() - single_channel_indices = recording.ids_to_indices(single_channel_ids) - for chan_id, chan_idx in zip(single_channel_ids, single_channel_indices): - channel_map[ch_id] = {"recording_id": r_i, "channel_index": chan_idx} - ch_id += 1 - - sampling_frequency = recording_list[0].get_sampling_frequency() - num_segments = recording_list[0].get_num_segments() - dtype = recording_list[0].get_dtype() - - ok1 = all(sampling_frequency == rec.get_sampling_frequency() for rec in recording_list) - ok2 = all(num_segments == rec.get_num_segments() for rec in recording_list) - ok3 = all(dtype == rec.get_dtype() for rec in recording_list) - ok4 = True - for i_seg in range(num_segments): - num_samples = recording_list[0].get_num_samples(i_seg) - ok4 = all(num_samples == rec.get_num_samples(i_seg) for rec in recording_list) - if not ok4: - break - - if not (ok1 and ok2 and ok3 and ok4): - raise ValueError( - "Recordings do not have consistent sampling frequency, number of segments, data type, or number of samples." - ) - BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype) property_keys = recording_list[0].get_property_keys() @@ -99,19 +77,60 @@ def __init__(self, recording_list, renamed_channel_ids=None): "Locations are not unique! " "Cannot aggregate recordings!" ) - # finally add segments + # finally add segments, we need a channel mapping + ch_id = 0 + channel_map = {} + for r_i, recording in enumerate(recording_list): + single_channel_ids = recording.get_channel_ids() + single_channel_indices = recording.ids_to_indices(single_channel_ids) + for chan_id, chan_idx in zip(single_channel_ids, single_channel_indices): + channel_map[ch_id] = {"recording_id": r_i, "channel_index": chan_idx} + ch_id += 1 + for i_seg in range(num_segments): parent_segments = [rec._recording_segments[i_seg] for rec in recording_list] sub_segment = ChannelsAggregationRecordingSegment(channel_map, parent_segments) self.add_recording_segment(sub_segment) - self._recordings = recording_list self._kwargs = {"recording_list": recording_list, "renamed_channel_ids": renamed_channel_ids} @property def recordings(self): return self._recordings + def _perform_consistency_checks(self): + + # Check for consistent sampling frequency across recordings + sampling_frequencies = [rec.get_sampling_frequency() for rec in self.recordings] + sampling_frequency = sampling_frequencies[0] + consistent_sampling_frequency = all(sampling_frequency == sf for sf in sampling_frequencies) + if not consistent_sampling_frequency: + raise ValueError(f"Inconsistent sampling frequency among recordings: {sampling_frequencies}") + + # Check for consistent number of segments across recordings + num_segments_list = [rec.get_num_segments() for rec in self.recordings] + num_segments = num_segments_list[0] + consistent_num_segments = all(num_segments == ns for ns in num_segments_list) + if not consistent_num_segments: + raise ValueError(f"Inconsistent number of segments among recordings: {num_segments_list}") + + # Check for consistent data type across recordings + data_types = 
[rec.get_dtype() for rec in self.recordings] + dtype = data_types[0] + consistent_dtype = all(dtype == dt for dt in data_types) + if not consistent_dtype: + raise ValueError(f"Inconsistent data type among recordings: {data_types}") + + # Check for consistent number of samples across recordings for each segment + for segment_index in range(num_segments): + num_samples_list = [rec.get_num_samples(segment_index=segment_index) for rec in self.recordings] + num_samples = num_samples_list[0] + consistent_num_samples = all(num_samples == ns for ns in num_samples_list) + if not consistent_num_samples: + raise ValueError( + f"Inconsistent number of samples in segment {segment_index} among recordings: {num_samples_list}" + ) + class ChannelsAggregationRecordingSegment(BaseRecordingSegment): """ diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 1f2e644be6..996718dc42 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -659,3 +659,28 @@ def retrieve_importing_provenance(a_class): } return info + + +def measure_memory_allocation(measure_in_process: bool = True) -> float: + """ + A local utility to measure memory allocation at a specific point in time. + Can measure either the process resident memory or system wide memory available + + Uses psutil package. + + Parameters + ---------- + measure_in_process : bool, True by default + Mesure memory allocation in the current process only, if false then measures at the system + level. + """ + import psutil + + if measure_in_process: + process = psutil.Process() + memory = process.memory_info().rss + else: + mem_info = psutil.virtual_memory() + memory = mem_info.total - mem_info.available + + return memory diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 11909bce0e..6ce94114c4 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -80,6 +80,8 @@ def generate_recording( probe.set_device_channel_indices(np.arange(num_channels)) recording.set_probe(probe, in_place=True) + recording.name = "SyntheticRecording" + return recording @@ -101,11 +103,11 @@ def generate_sorting( Parameters ---------- num_units : int, default: 5 - Number of units + Number of units. sampling_frequency : float, default: 30000.0 - The sampling frequency + The sampling frequency. durations : list, default: [10.325, 3.5] - Duration of each segment in s + Duration of each segment in s. firing_rates : float, default: 3.0 The firing rate of each unit (in Hz). empty_units : list, default: None @@ -121,12 +123,12 @@ def generate_sorting( border_size_samples : int, default: 20 The size of the border in samples to add border spikes. seed : int, default: None - The random seed + The random seed. Returns ------- sorting : NumpySorting - The sorting object + The sorting object. """ seed = _ensure_seed(seed) rng = np.random.default_rng(seed) @@ -185,19 +187,19 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): Parameters ---------- sorting : BaseSorting - The sorting object + The sorting object. sync_event_ratio : float The ratio of added synchronous spikes with respect to the total number of spikes. E.g., 0.5 means that the final sorting will have 1.5 times number of spikes, and all the extra spikes are synchronous (same sample_index), but on different units (not duplicates). seed : int, default: None - The random seed + The random seed. 
Returns ------- sorting : TransformSorting - The sorting object, keeping track of added spikes + The sorting object, keeping track of added spikes. """ rng = np.random.default_rng(seed) @@ -247,18 +249,18 @@ def generate_sorting_to_inject( Parameters ---------- sorting : BaseSorting - The sorting object + The sorting object. num_samples: list of size num_segments. The number of samples in all the segments of the sorting, to generate spike times - covering entire the entire duration of the segments + covering entire the entire duration of the segments. max_injected_per_unit: int, default 1000 - The maximal number of spikes injected per units + The maximal number of spikes injected per units. injected_rate: float, default 0.05 - The rate at which spikes are injected + The rate at which spikes are injected. refractory_period_ms: float, default 1.5 - The refractory period that should not be violated while injecting new spikes + The refractory period that should not be violated while injecting new spikes. seed: int, default None - The random seed + The random seed. Returns ------- @@ -310,22 +312,22 @@ class TransformSorting(BaseSorting): Parameters ---------- sorting : BaseSorting - The sorting object + The sorting object. added_spikes_existing_units : np.array (spike_vector) - The spikes that should be added to the sorting object, for existing units + The spikes that should be added to the sorting object, for existing units. added_spikes_new_units: np.array (spike_vector) - The spikes that should be added to the sorting object, for new units + The spikes that should be added to the sorting object, for new units. new_units_ids: list - The unit_ids that should be added if spikes for new units are added + The unit_ids that should be added if spikes for new units are added. refractory_period_ms : float, default None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be - discarded + discarded. Returns ------- sorting : TransformSorting - The sorting object with the added spikes and/or units + The sorting object with the added spikes and/or units. """ def __init__( @@ -426,12 +428,14 @@ def add_from_sorting(sorting1: BaseSorting, sorting2: BaseSorting, refractory_pe Parameters ---------- - sorting1: the first sorting - sorting2: the second sorting + sorting1: BaseSorting + The first sorting. + sorting2: BaseSorting + The second sorting. refractory_period_ms : float, default None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be - discarded + discarded. """ assert ( sorting1.get_sampling_frequency() == sorting2.get_sampling_frequency() @@ -490,12 +494,14 @@ def add_from_unit_dict( Parameters ---------- - sorting1: the first sorting + sorting1: BaseSorting + The first sorting dict_list: list of dict + A list of dict with unit_ids as keys and spike times as values. refractory_period_ms : float, default None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be - discarded + discarded. 
""" sorting2 = NumpySorting.from_unit_dict(units_dict_list, sorting1.get_sampling_frequency()) sorting = TransformSorting.add_from_sorting(sorting1, sorting2, refractory_period_ms) @@ -513,18 +519,19 @@ def from_times_labels( Parameters ---------- - sorting1: the first sorting + sorting1: BaseSorting + The first sorting times_list: list of array (or array) - An array of spike times (in frames) + An array of spike times (in frames). labels_list: list of array (or array) - An array of spike labels corresponding to the given times + An array of spike labels corresponding to the given times. unit_ids: list or None, default: None The explicit list of unit_ids that should be extracted from labels_list - If None, then it will be np.unique(labels_list) + If None, then it will be np.unique(labels_list). refractory_period_ms : float, default None The refractory period violation to prevent duplicates and/or unphysiological addition of spikes. Any spike times in added_spikes violating the refractory period will be - discarded + discarded. """ sorting2 = NumpySorting.from_times_labels(times_list, labels_list, sampling_frequency, unit_ids) @@ -554,6 +561,16 @@ def clean_refractory_period(self): def create_sorting_npz(num_seg, file_path): + """ + Create a NPZ sorting file. + + Parameters + ---------- + num_seg : int + The number of segments. + file_path : str | Path + The file path to save the NPZ file. + """ # create a NPZ sorting file d = {} d["unit_ids"] = np.array([0, 1, 2], dtype="int64") @@ -583,6 +600,35 @@ def generate_snippets( empty_units=None, **job_kwargs, ): + """ + Generates a synthetic Snippets object. + + Parameters + ---------- + nbefore : int, default: 20 + Number of samples before the peak. + nafter : int, default: 44 + Number of samples after the peak. + num_channels : int, default: 2 + Number of channels. + wf_folder : str | Path | None, default: None + Optional folder to save the waveform snippets. If None, snippets are in memory. + sampling_frequency : float, default: 30000.0 + The sampling frequency of the snippets. + ndim : int, default: 2 + The number of dimensions of the probe. + num_units : int, default: 5 + The number of units. + empty_units : list | None, default: None + A list of units that will have no spikes. + + Returns + ------- + snippets : NumpySnippets + The snippets object. + sorting : NumpySorting + The associated sorting object. + """ recording = generate_recording( durations=durations, num_channels=num_channels, @@ -643,18 +689,18 @@ def synthesize_poisson_spike_vector( Parameters ---------- num_units : int, default: 20 - Number of neuronal units to simulate + Number of neuronal units to simulate. sampling_frequency : float, default: 30000.0 - Sampling frequency in Hz + Sampling frequency in Hz. duration : float, default: 60.0 - Duration of the simulation in seconds + Duration of the simulation in seconds. refractory_period_ms : float, default: 4.0 - Refractory period between spikes in milliseconds + Refractory period between spikes in milliseconds. firing_rates : float or array_like or tuple, default: 3.0 Firing rate(s) in Hz. Can be a single value for all units or an array of firing rates with - each element being the firing rate for one unit + each element being the firing rate for one unit. seed : int, default: 0 - Seed for random number generator + Seed for random number generator. Returns ------- @@ -748,27 +794,27 @@ def synthesize_random_firings( Parameters ---------- num_units : int - number of units + Number of units. 
sampling_frequency : float - sampling rate + Sampling rate. duration : float - duration of the segment in seconds + Duration of the segment in seconds. refractory_period_ms: float - refractory_period in ms + Refractory period in ms. firing_rates: float or list[float] The firing rate of each unit (in Hz). If float, all units will have the same firing rate. add_shift_shuffle: bool, default: False Optionally add a small shuffle on half of the spikes to make the autocorrelogram less flat. seed: int, default: None - seed for the generator + Seed for the generator. Returns ------- - times: - Concatenated and sorted times vector - labels: - Concatenated and sorted label vector + times: np.array + Concatenated and sorted times vector. + labels: np.array + Concatenated and sorted label vector. """ @@ -852,11 +898,11 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No Parameters ---------- sorting : - Original sorting + Original sorting. num : int - Number of injected units + Number of injected units. max_shift : int - range of the shift in sample + range of the shift in sample. ratio: float Proportion of original spike in the injected units. @@ -907,8 +953,27 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=False, seed=None): - """ """ + """ + Inject some split units in a sorting. + Parameters + ---------- + sorting : BaseSorting + Original sorting. + split_ids : list + List of unit_ids to split. + num_split : int, default: 2 + Number of split units. + output_ids : bool, default: False + If True, return the new unit_ids. + seed : int, default: None + Random seed. + + Returns + ------- + sorting_with_split : NumpySorting + A sorting with split units. + """ unit_ids = sorting.unit_ids assert unit_ids.dtype.kind == "i" @@ -958,7 +1023,7 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol num_violations : int Number of contaminating spikes. violation_delta : float, default: 1e-5 - Temporal offset of contaminating spikes (in seconds) + Temporal offset of contaminating spikes (in seconds). Returns ------- @@ -1215,7 +1280,7 @@ def generate_recording_by_size( num_channels: int Number of channels. seed : int, default: None - The seed for np.random.default_rng + The seed for np.random.default_rng. Returns ------- @@ -1615,7 +1680,7 @@ class InjectTemplatesRecording(BaseRecording): * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce sampling jitter. nbefore: list[int] | int | None, default: None - Where is the center of the template for each unit? + The number of samples before the peak of the template to align the spike. If None, will default to the highest peak. amplitude_factor: list[float] | float | None, default: None The amplitude of each spike for each unit. @@ -1630,7 +1695,7 @@ class InjectTemplatesRecording(BaseRecording): You can use int for mono-segment objects. upsample_vector: np.array or None, default: None. When templates is 4d we can simulate a jitter. - Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.sahpe[3] + Optional the upsample_vector is the jitter index with a number per spike in range 0-templates.shape[3]. 
Returns ------- @@ -1738,6 +1803,8 @@ def __init__( ) self.add_recording_segment(recording_segment) + # to discuss: maybe we could set json serializability to False always + # because templates could be large! if not sorting.check_serializability("json"): self._serializability["json"] = False if parent_recording is not None: @@ -2122,4 +2189,7 @@ def generate_ground_truth_recording( recording.set_channel_gains(1.0) recording.set_channel_offsets(0.0) + recording.name = "GroundTruthRecording" + sorting.name = "GroundTruthSorting" + return recording, sorting diff --git a/src/spikeinterface/core/npzsortingextractor.py b/src/spikeinterface/core/npzsortingextractor.py index f60dadd8ec..b8e7357e8c 100644 --- a/src/spikeinterface/core/npzsortingextractor.py +++ b/src/spikeinterface/core/npzsortingextractor.py @@ -16,9 +16,6 @@ class NpzSortingExtractor(BaseSorting): All spike are store in two columns maner index+labels """ - mode = "file" - name = "npz" - def __init__(self, file_path): self.npz_filename = file_path diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 1ee472ffa4..f4790817a8 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -37,9 +37,6 @@ class NumpyRecording(BaseRecording): An optional list of channel_ids. If None, linear channels are assumed """ - mode = "memory" - name = "numpy" - def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=None): if isinstance(traces_list, list): all_elements_are_list = all(isinstance(e, list) for e in traces_list) @@ -86,6 +83,9 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N @staticmethod def from_recording(source_recording, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs) + + t_starts = source_recording._get_t_starts() + if shms[0] is not None: # if the computation was done in parallel then traces_list is shared array # this can lead to problem @@ -94,13 +94,14 @@ def from_recording(source_recording, **job_kwargs): for shm in shms: shm.close() shm.unlink() - # TODO later : propagte t_starts ? + recording = NumpyRecording( traces_list, source_recording.get_sampling_frequency(), - t_starts=None, + t_starts=t_starts, channel_ids=source_recording.channel_ids, ) + return recording class NumpyRecordingSegment(BaseRecordingSegment): @@ -142,9 +143,6 @@ class SharedMemoryRecording(BaseRecording): If True, the main instance will unlink the sharedmem buffer when deleted """ - mode = "memory" - name = "SharedMemory" - def __init__( self, shm_names, shape_list, dtype, sampling_frequency, channel_ids=None, t_starts=None, main_shm_owner=True ): @@ -212,7 +210,7 @@ def __del__(self): def from_recording(source_recording, **job_kwargs): traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs) - # TODO later : propagte t_starts ? + t_starts = source_recording._get_t_starts() recording = SharedMemoryRecording( shm_names=[shm.name for shm in shms], @@ -220,7 +218,7 @@ def from_recording(source_recording, **job_kwargs): dtype=source_recording.dtype, sampling_frequency=source_recording.sampling_frequency, channel_ids=source_recording.channel_ids, - t_starts=None, + t_starts=t_starts, main_shm_owner=True, ) @@ -252,8 +250,6 @@ class NumpySorting(BaseSorting): A list of unit_ids. 
""" - name = "numpy" - def __init__(self, spikes, sampling_frequency, unit_ids): """ """ BaseSorting.__init__(self, sampling_frequency, unit_ids) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index b23b7202c6..039fa8fd60 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -191,16 +191,20 @@ def get_traces(self, start_frame, end_frame, channel_indices): seg_start = self.cumsum_length[i] if i == i0: # first - traces_chunk = rec_seg.get_traces(start_frame - seg_start, None, channel_indices) + end_frame_ = rec_seg.get_num_samples() + traces_chunk = rec_seg.get_traces(start_frame - seg_start, end_frame_, channel_indices) all_traces.append(traces_chunk) elif i == i1: # last if (end_frame - seg_start) > 0: - traces_chunk = rec_seg.get_traces(None, end_frame - seg_start, channel_indices) + start_frame_ = 0 + traces_chunk = rec_seg.get_traces(start_frame_, end_frame - seg_start, channel_indices) all_traces.append(traces_chunk) else: # in between - traces_chunk = rec_seg.get_traces(None, None, channel_indices) + start_frame_ = 0 + end_frame_ = rec_seg.get_num_samples() + traces_chunk = rec_seg.get_traces(start_frame_, end_frame_, channel_indices) all_traces.append(traces_chunk) traces = np.concatenate(all_traces, axis=0) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 02f4529a98..918d95bf52 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -1,7 +1,10 @@ from __future__ import annotations -from .basesorting import BaseSorting + import numpy as np +from .basesorting import BaseSorting +from .numpyextractors import NumpySorting + def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.array) -> dict[dict[str, np.array]]: """ @@ -220,3 +223,202 @@ def random_spikes_selection( raise ValueError(f"random_spikes_selection(): method must be 'all' or 'uniform'") return random_spikes_indices + + +def apply_merges_to_sorting( + sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append" +): + """ + Apply a resolved representation of the merges to a sorting object. + + This function is not lazy and creates a new NumpySorting with a compact spike_vector as fast as possible. + + If `censor_ms` is not None, duplicated spikes violating the `censor_ms` refractory period are removed. + + Optionally, the boolean mask of kept spikes is returned. + + Parameters + ---------- + sorting : Sorting + The Sorting object to apply merges. + units_to_merge : list/tuple of lists/tuples + A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), + but it can also have more (merge multiple units at once). + new_unit_ids : list | None, default: None + A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. If None, + merged units will have the first unit_id of every lists of merges. + censor_ms: float | None, default: None + When applying the merges, should be discard consecutive spikes violating a given refractory per + return_kept : bool, default: False + If True, also return also a booolean mask of kept spikes. + new_id_strategy : "append" | "take_first", default: "append" + The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. 
+ + * "append" : new_units_ids will be added at the end of max(sorging.unit_ids) + * "take_first" : new_unit_ids will be the first unit_id of every list of merges + + Returns + ------- + sorting : The new Sorting object + The newly create sorting with the merged units + keep_mask : numpy.array + A boolean mask, if censor_ms is not None, telling which spike from the original spike vector + has been kept, given the refractory period violations (None if censor_ms is None) + """ + + spikes = sorting.to_spike_vector().copy() + keep_mask = np.ones(len(spikes), dtype=bool) + + new_unit_ids = generate_unit_ids_for_merge_group( + sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy + ) + + rename_ids = {} + for i, merge_group in enumerate(units_to_merge): + for unit_id in merge_group: + rename_ids[unit_id] = new_unit_ids[i] + + all_unit_ids = _get_ids_after_merging(sorting.unit_ids, units_to_merge, new_unit_ids) + all_unit_ids = list(all_unit_ids) + + num_seg = sorting.get_num_segments() + seg_lims = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2)) + segment_slices = [(seg_lims[i], seg_lims[i + 1]) for i in range(num_seg)] + + # using this function vaoid to use the mask approach and simplify a lot the algo + spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices] + spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True) + + for old_unit_id in sorting.unit_ids: + if old_unit_id in rename_ids.keys(): + new_unit_id = rename_ids[old_unit_id] + else: + new_unit_id = old_unit_id + + new_unit_index = all_unit_ids.index(new_unit_id) + for segment_index in range(num_seg): + spike_inds = spike_indices[segment_index][old_unit_id] + spikes["unit_index"][spike_inds] = new_unit_index + + if censor_ms is not None: + rpv = int(sorting.sampling_frequency * censor_ms / 1000.0) + for group_old_ids in units_to_merge: + for segment_index in range(num_seg): + group_indices = [] + for unit_id in group_old_ids: + group_indices.append(spike_indices[segment_index][unit_id]) + group_indices = np.concatenate(group_indices) + group_indices = np.sort(group_indices) + inds = np.flatnonzero(np.diff(spikes["sample_index"][group_indices]) < rpv) + keep_mask[group_indices[inds + 1]] = False + + spikes = spikes[keep_mask] + sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids) + + if return_kept: + return sorting, keep_mask + else: + return sorting + + +def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids): + """ + Function to get the list of unique unit_ids after some merges, with given new_units_ids would + be provided. + + Every new unit_id will be added at the end if not already present. + + Parameters + ---------- + old_unit_ids : np.array + The old unit_ids. + units_to_merge : list/tuple of lists/tuples + A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), + but it can also have more (merge multiple units at once). + new_unit_ids : list | None + A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. 
+
+    Returns
+    -------
+
+    all_unit_ids : np.array
+        The unit_ids that will be present after the merges.
+
+    """
+    old_unit_ids = np.asarray(old_unit_ids)
+
+    assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge"
+
+    all_unit_ids = list(old_unit_ids.copy())
+    for new_unit_id, group_ids in zip(new_unit_ids, units_to_merge):
+        assert len(group_ids) > 1, "A merge should have at least two units"
+        for unit_id in group_ids:
+            assert unit_id in old_unit_ids, "Merged ids should be in the sorting"
+        for unit_id in group_ids:
+            if unit_id != new_unit_id:
+                # new_unit_id can be inside group_ids
+                all_unit_ids.remove(unit_id)
+        if new_unit_id not in all_unit_ids:
+            all_unit_ids.append(new_unit_id)
+    return np.array(all_unit_ids)
+
+
+def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy="append"):
+    """
+    Function to generate new unit ids during a merging procedure. If new_unit_ids
+    are provided, it will return these unit ids, checking that they have the same
+    length as `units_to_merge`.
+
+    Parameters
+    ----------
+    old_unit_ids : np.array
+        The old unit_ids.
+    units_to_merge : list/tuple of lists/tuples
+        A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
+        but it can also have more (merge multiple units at once).
+    new_unit_ids : list | None, default: None
+        Optional new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`.
+        If None, new ids will be generated.
+    new_id_strategy : "append" | "take_first", default: "append"
+        The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
+
+        * "append" : new_unit_ids will be added after the max of old_unit_ids
+        * "take_first" : new_unit_ids will be the first unit_id of every list of merges
+
+    Returns
+    -------
+    new_unit_ids : list
+        The new unit_ids associated with the merges.
+ """ + old_unit_ids = np.asarray(old_unit_ids) + + if new_unit_ids is not None: + # then only doing a consistency check + assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge" + # new_unit_ids can also be part of old_unit_ids only inside the same group: + for i, new_unit_id in enumerate(new_unit_ids): + if new_unit_id in old_unit_ids: + assert new_unit_id in units_to_merge[i], "new_unit_ids already exists but outside the merged groups" + else: + dtype = old_unit_ids.dtype + num_merge = len(units_to_merge) + # select new_unit_ids greater that the max id, event greater than the numerical str ids + if new_id_strategy == "take_first": + new_unit_ids = [to_be_merged[0] for to_be_merged in units_to_merge] + elif new_id_strategy == "append": + if np.issubdtype(dtype, np.character): + # dtype str + if all(p.isdigit() for p in old_unit_ids): + # All str are digit : we can generate a max + m = max(int(p) for p in old_unit_ids) + 1 + new_unit_ids = [str(m + i) for i in range(num_merge)] + else: + # we cannot automatically find new names + new_unit_ids = [f"merge{i}" for i in range(num_merge)] + else: + # dtype int + new_unit_ids = list(max(old_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) + else: + raise ValueError("wrong new_id_strategy") + + return new_unit_ids diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index 947c5686d8..7f55646b63 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -4,8 +4,9 @@ """ from typing import Sequence +import numpy as np from spikeinterface.core.base import BaseExtractor -from spikeinterface.core import generate_recording, concatenate_recordings +from spikeinterface.core import generate_recording, generate_ground_truth_recording, concatenate_recordings class DummyDictExtractor(BaseExtractor): @@ -65,6 +66,34 @@ def test_check_if_serializable(): assert not extractor.check_serializability("json") +def test_name_and_repr(): + test_recording, test_sorting = generate_ground_truth_recording(seed=0, durations=[2]) + assert test_recording.name == "GroundTruthRecording" + assert test_sorting.name == "GroundTruthSorting" + + # set a different name + test_recording.name = "MyRecording" + assert test_recording.name == "MyRecording" + + # to/from dict + test_recording_dict = test_recording.to_dict() + test_recording2 = BaseExtractor.from_dict(test_recording_dict) + assert test_recording2.name == "MyRecording" + + # repr + rec_str = str(test_recording2) + assert "MyRecording" in rec_str + test_recording2.name = None + assert "MyRecording" not in str(test_recording2) + assert test_recording2.__class__.__name__ in str(test_recording2) + # above 10khz, sampling frequency is printed in kHz + assert f"kHz" in rec_str + # below 10khz sampling frequency is printed in Hz + test_rec_low_fs = generate_recording(seed=0, durations=[2], sampling_frequency=5000) + rec_str = str(test_rec_low_fs) + assert "Hz" in rec_str + + if __name__ == "__main__": test_check_if_memory_serializable() test_check_if_serializable() diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index 8ea99e3d04..ea5edc6e6e 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -1,8 +1,11 @@ import pytest import numpy as np +from pathlib import Path from spikeinterface.core import 
BinaryRecordingExtractor from spikeinterface.core.numpyextractors import NumpyRecording +from spikeinterface.core.core_tools import measure_memory_allocation +from spikeinterface.core.generate import NoiseGeneratorRecording def test_BinaryRecordingExtractor(create_cache_folder): @@ -51,15 +54,75 @@ def test_round_trip(tmp_path): dtype=dtype, ) + # Test for full traces assert np.allclose(recording.get_traces(), binary_recorder.get_traces()) - start_frame = 200 - end_frame = 500 + # Ttest for a sub-set of the traces + start_frame = 20 + end_frame = 40 smaller_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame) binary_smaller_traces = binary_recorder.get_traces(start_frame=start_frame, end_frame=end_frame) np.allclose(smaller_traces, binary_smaller_traces) +@pytest.fixture(scope="module") +def folder_with_binary_files(tmpdir_factory): + tmp_path = Path(tmpdir_factory.mktemp("spike_interface_test")) + folder = tmp_path / "test_binary_recording" + num_channels = 32 + sampling_frequency = 30_000.0 + dtype = "float32" + recording = NoiseGeneratorRecording( + durations=[1.0], + sampling_frequency=sampling_frequency, + num_channels=num_channels, + dtype=dtype, + ) + dtype = recording.get_dtype() + recording.save(folder=folder, overwrite=True) + + return folder + + +def test_sequential_reading_of_small_traces(folder_with_binary_files): + # Test that memmap is readed correctly when pointing to specific frames + folder = folder_with_binary_files + num_channels = 32 + sampling_frequency = 30_000.0 + dtype = "float32" + + file_paths = [folder / "traces_cached_seg0.raw"] + recording = BinaryRecordingExtractor( + num_chan=num_channels, + file_paths=file_paths, + sampling_frequency=sampling_frequency, + dtype=dtype, + ) + + full_traces = recording.get_traces() + + # Test for a sub-set of the traces + start_frame = 10 + end_frame = 15 + small_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame) + expected_traces = full_traces[start_frame:end_frame, :] + assert np.allclose(small_traces, expected_traces) + + # Test for a sub-set of the traces + start_frame = 1000 + end_frame = 1100 + small_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame) + expected_traces = full_traces[start_frame:end_frame, :] + assert np.allclose(small_traces, expected_traces) + + # Test for a sub-set of the traces + start_frame = 10_000 + end_frame = 11_000 + small_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame) + expected_traces = full_traces[start_frame:end_frame, :] + assert np.allclose(small_traces, expected_traces) + + if __name__ == "__main__": test_BinaryRecordingExtractor() diff --git a/src/spikeinterface/core/tests/test_segmentutils.py b/src/spikeinterface/core/tests/test_segmentutils.py index d3c73805f0..166ecafd09 100644 --- a/src/spikeinterface/core/tests/test_segmentutils.py +++ b/src/spikeinterface/core/tests/test_segmentutils.py @@ -5,10 +5,6 @@ from numpy.testing import assert_raises from spikeinterface.core import ( - AppendSegmentRecording, - AppendSegmentSorting, - ConcatenateSegmentRecording, - ConcatenateSegmentSorting, NumpyRecording, NumpySorting, append_recordings, diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 1aefeeb062..38baf62c35 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -9,6 +9,9 @@ spike_vector_to_spike_trains, random_spikes_selection, 
spike_vector_to_indices, + apply_merges_to_sorting, + _get_ids_after_merging, + generate_unit_ids_for_merge_group, ) @@ -75,7 +78,87 @@ def test_random_spikes_selection(): assert random_spikes_indices.size == spikes.size +def test_apply_merges_to_sorting(): + + times = np.array([0, 0, 10, 20, 300]) + labels = np.array(["a", "b", "c", "a", "b"]) + + # unit_ids str + sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"]) + spikes1 = sorting1.to_spike_vector() + + sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None) + spikes2 = sorting2.to_spike_vector() + assert sorting2.unit_ids.size == 2 + assert sorting1.to_spike_vector().size == sorting1.to_spike_vector().size + assert np.array_equal(["c", "merge0"], sorting2.unit_ids) + assert np.array_equal( + spikes1[spikes1["unit_index"] == 2]["sample_index"], spikes2[spikes2["unit_index"] == 0]["sample_index"] + ) + + sorting3, keep_mask = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_kept=True) + spikes3 = sorting3.to_spike_vector() + assert spikes3.size < spikes1.size + assert not keep_mask[1] + st = sorting3.get_unit_spike_train(segment_index=0, unit_id="merge0") + assert st.size == 3 # one spike is removed by censor period + + # unit_ids int + sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=[10, 20, 30]) + spikes1 = sorting1.to_spike_vector() + sorting2 = apply_merges_to_sorting(sorting1, [[10, 20]], censor_ms=None) + assert np.array_equal(sorting2.unit_ids, [30, 31]) + + sorting1 = NumpySorting.from_times_labels([times, times], [labels, labels], 10_000.0, unit_ids=["a", "b", "c"]) + sorting2 = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=None, new_id_strategy="take_first") + assert np.array_equal(sorting2.unit_ids, ["a", "c"]) + + +def test_get_ids_after_merging(): + + all_unit_ids = _get_ids_after_merging(["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], ["x", "d"]) + assert np.array_equal(all_unit_ids, ["c", "d", "x"]) + # print(all_unit_ids) + + all_unit_ids = _get_ids_after_merging([0, 5, 12, 9, 15], [[0, 5], [9, 15]], [28, 9]) + assert np.array_equal(all_unit_ids, [12, 9, 28]) + # print(all_unit_ids) + + +def test_generate_unit_ids_for_merge_group(): + + new_unit_ids = generate_unit_ids_for_merge_group( + ["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="append" + ) + assert np.array_equal(new_unit_ids, ["merge0", "merge1"]) + + new_unit_ids = generate_unit_ids_for_merge_group( + ["a", "b", "c", "d", "e"], [["a", "b"], ["d", "e"]], new_id_strategy="take_first" + ) + assert np.array_equal(new_unit_ids, ["a", "d"]) + + new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="append") + assert np.array_equal(new_unit_ids, [16, 17]) + + new_unit_ids = generate_unit_ids_for_merge_group([0, 5, 12, 9, 15], [[0, 5], [9, 15]], new_id_strategy="take_first") + assert np.array_equal(new_unit_ids, [0, 9]) + + new_unit_ids = generate_unit_ids_for_merge_group( + ["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="append" + ) + assert np.array_equal(new_unit_ids, ["16", "17"]) + + new_unit_ids = generate_unit_ids_for_merge_group( + ["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="take_first" + ) + assert np.array_equal(new_unit_ids, ["0", "9"]) + + if __name__ == "__main__": # test_spike_vector_to_spike_trains() # test_spike_vector_to_indices() - test_random_spikes_selection() + # 
test_random_spikes_selection() + + test_apply_merges_to_sorting() + test_get_ids_after_merging() + test_generate_unit_ids_for_merge_group() diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index 487a893096..049d5ab6e5 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -1,69 +1,289 @@ +import copy + import pytest import numpy as np from spikeinterface.core import generate_recording, generate_sorting +import spikeinterface.full as si + +class TestTimeHandling: + """ + This class tests how time is handled in SpikeInterface. Under the hood, + time can be represented as a full `time_vector` or only as + `t_start` attribute on segments from which a vector of times + is generated on the fly. Both time representations are tested here. + """ -def test_time_handling(create_cache_folder): - cache_folder = create_cache_folder - durations = [[10], [10, 5]] + # Fixtures ##### + @pytest.fixture(scope="session") + def time_vector_recording(self): + """ + Add time vectors to the recording, returning the + raw recording, recording with time vectors added to + segments, and list a the time vectors added to the recording. + """ + durations = [10, 15, 20] + raw_recording = generate_recording(num_channels=4, durations=durations) - # test multi-segment - for i, dur in enumerate(durations): - rec = generate_recording(num_channels=4, durations=dur) - sort = generate_sorting(num_units=10, durations=dur) + return self._get_time_vector_recording(raw_recording) - for segment_index in range(rec.get_num_segments()): - original_times = rec.get_times(segment_index=segment_index) - new_times = original_times + 5 - rec.set_times(new_times, segment_index=segment_index) + @pytest.fixture(scope="session") + def t_start_recording(self): + """ + Add a t_starts to the recording, returning the + raw recording, recording with t_starts added to segments, + and a list of the time vectors generated from adding the + t_start to the recording times. + """ + durations = [10, 15, 20] + raw_recording = generate_recording(num_channels=4, durations=durations) - sort.register_recording(rec) - assert sort.has_recording() + return self._get_t_start_recording(raw_recording) - rec_cache = rec.save(folder=cache_folder / f"rec{i}") + def _get_time_vector_recording(self, raw_recording): + """ + Loop through all recording segments, adding a different time + vector to each segment. The time vector is the original times with + a t_start and irregularly spaced offsets to mimic irregularly + spaced timeseries data. Return the original recording, + recoridng with time vectors added and list including the added time vectors. 
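+
+        As a rough sketch (assuming the construction in the body below), the vector for
+        segment ``i`` is ``t_start_i + raw_recording.get_times(i) + np.cumsum(np.arange(n_samples) / fs)``,
+        so the sample spacing slowly increases along the segment.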
+ """ + times_recording = copy.deepcopy(raw_recording) + all_time_vectors = [] + for segment_index in range(raw_recording.get_num_segments()): - for segment_index in range(sort.get_num_segments()): - assert rec.has_time_vector(segment_index=segment_index) - assert sort.has_time_vector(segment_index=segment_index) + t_start = segment_index + 1 * 100 - # times are correctly saved by the recording - assert np.allclose( - rec.get_times(segment_index=segment_index), rec_cache.get_times(segment_index=segment_index) + some_small_increasing_numbers = np.arange(times_recording.get_num_samples(segment_index)) * ( + 1 / times_recording.get_sampling_frequency() ) - # spike times are correctly adjusted - for u in sort.get_unit_ids(): - spike_times = sort.get_unit_spike_train(u, segment_index=segment_index, return_times=True) - rec_times = rec.get_times(segment_index=segment_index) - assert np.all(spike_times >= rec_times[0]) - assert np.all(spike_times <= rec_times[-1]) + offsets = np.cumsum(some_small_increasing_numbers) + time_vector = t_start + times_recording.get_times(segment_index) + offsets + + all_time_vectors.append(time_vector) + times_recording.set_times(times=time_vector, segment_index=segment_index) + + assert np.array_equal( + times_recording._recording_segments[segment_index].time_vector, + time_vector, + ), "time_vector was not properly set during test setup" + + return (raw_recording, times_recording, all_time_vectors) + + def _get_t_start_recording(self, raw_recording): + """ + For each segment in the recording, add a different `t_start`. + Return a list of time vectors generating from the recording times + + the t_starts. + """ + t_start_recording = copy.deepcopy(raw_recording) + + all_t_starts = [] + for segment_index in range(raw_recording.get_num_segments()): + + t_start = (segment_index + 1) * 100 + + all_t_starts.append(t_start + t_start_recording.get_times(segment_index)) + t_start_recording._recording_segments[segment_index].t_start = t_start + + return (raw_recording, t_start_recording, all_t_starts) + + def _get_fixture_data(self, request, fixture_name): + """ + A convenience function to get the data from a fixture + based on the name. This is used to allow parameterising + tests across fixtures. + """ + time_recording_fixture = request.getfixturevalue(fixture_name) + raw_recording, times_recording, all_times = time_recording_fixture + return (raw_recording, times_recording, all_times) + + # Tests ##### + def test_has_time_vector(self, time_vector_recording): + """ + Test the `has_time_vector` function returns `False` before + a time vector is added and `True` afterwards. + """ + raw_recording, times_recording, _ = time_vector_recording + + for segment_idx in range(raw_recording.get_num_segments()): + + assert raw_recording.has_time_vector(segment_idx) is False + assert times_recording.has_time_vector(segment_idx) is True + + @pytest.mark.parametrize("mode", ["binary", "zarr"]) + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_times_propagated_to_save_folder(self, request, fixture_name, mode, tmp_path): + """ + Test `t_start` or `time_vector` is propagated to a saved recording, + by saving, reloading, and checking times are correct. 
+        """
+        _, times_recording, all_times = self._get_fixture_data(request, fixture_name)
+
+        folder_name = "recording"
+        recording_cache = times_recording.save(format=mode, folder=tmp_path / folder_name)
+
+        if mode == "zarr":
+            folder_name += ".zarr"
+        recording_load = si.load_extractor(tmp_path / folder_name)
+
+        self._check_times_match(recording_cache, all_times)
+        self._check_times_match(recording_load, all_times)
+
+    @pytest.mark.parametrize("sharedmem", [True, False])
+    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
+    def test_times_propagated_to_save_memory(self, request, fixture_name, sharedmem):
+        """
+        Test t_start and time_vector are propagated to a recording saved into memory.
+        """
+        _, times_recording, all_times = self._get_fixture_data(request, fixture_name)
+
+        recording_load = times_recording.save(format="memory", sharedmem=sharedmem)
+        self._check_times_match(recording_load, all_times)
+
+    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
+    def test_time_propagated_to_select_segments(self, request, fixture_name):
+        """
+        Test that when `recording.select_segments()` is used, the times
+        are propagated to the new recording object.
+        """
+        _, times_recording, all_times = self._get_fixture_data(request, fixture_name)
+
+        for segment_index in range(times_recording.get_num_segments()):
+            segment = times_recording.select_segments(segment_index)
+            assert np.array_equal(segment.get_times(), all_times[segment_index])
+
+    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
+    def test_times_propagated_to_sorting(self, request, fixture_name):
+        """
+        Check that when attached to a sorting object, the times are propagated
+        to the object. This means that all spike times should respect the
+        `t_start` or `time_vector` added.
+        """
+        raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name)
+        sorting = self._get_sorting_with_recording_attached(
+            recording_for_durations=raw_recording, recording_to_attach=times_recording
+        )
+        for segment_index in range(raw_recording.get_num_segments()):
+
+            if fixture_name == "time_vector_recording":
+                assert sorting.has_time_vector(segment_index=segment_index)
+
+            self._check_spike_times_are_correct(sorting, times_recording, segment_index)
+
+    @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
+    def test_time_sample_converters(self, request, fixture_name):
+        """
+        Test the `recording.sample_index_to_time` and
+        `recording.time_to_sample_index` convenience functions.
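+
+        A small illustration of the round trip being tested (the index 100 is
+        arbitrary, not taken from the fixtures):
+
+        >>> t = times_recording.sample_index_to_time(100, segment_index=0)
+        >>> assert times_recording.time_to_sample_index(t, segment_index=0) == 100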
+        """
+        raw_recording, times_recording, all_times = self._get_fixture_data(request, fixture_name)
+        with pytest.raises(ValueError) as e:
+            times_recording.sample_index_to_time(0)
+        assert "Provide 'segment_index'" in str(e)
+
+        for segment_index in range(times_recording.get_num_segments()):
+
+            sample_index = np.random.randint(low=0, high=times_recording.get_num_samples(segment_index))
+            time_ = times_recording.sample_index_to_time(sample_index, segment_index=segment_index)
+
+            assert time_ == all_times[segment_index][sample_index]
+
+            new_sample_index = times_recording.time_to_sample_index(time_, segment_index=segment_index)
+
+            assert new_sample_index == sample_index
+
+    @pytest.mark.parametrize("time_type", ["time_vector", "t_start"])
+    @pytest.mark.parametrize("bounds", ["start", "middle", "end"])
+    def test_slice_recording(self, time_type, bounds):
+        """
+        Test times are correct after applying `frame_slice` or `time_slice`
+        to a recording or sorting (for `frame_slice`). The recording times
+        should be correct with respect to the set `t_start` or `time_vector`.
+        """
+        raw_recording = generate_recording(num_channels=4, durations=[10])
+
+        if time_type == "time_vector":
+            raw_recording, times_recording, all_times = self._get_time_vector_recording(raw_recording)
+        else:
+            raw_recording, times_recording, all_times = self._get_t_start_recording(raw_recording)
+
+        sorting = self._get_sorting_with_recording_attached(
+            recording_for_durations=raw_recording, recording_to_attach=times_recording
+        )
+
+        # Take some different times, including min and max bounds of
+        # the recording, and some arbitrary times in the middle (20% and 80%).
+        if bounds == "start":
+            start_frame = 0
+            end_frame = int(times_recording.get_num_samples(0) * 0.8)
+        elif bounds == "end":
+            start_frame = int(times_recording.get_num_samples(0) * 0.2)
+            end_frame = times_recording.get_num_samples(0) - 1
+        elif bounds == "middle":
+            start_frame = int(times_recording.get_num_samples(0) * 0.2)
+            end_frame = int(times_recording.get_num_samples(0) * 0.8)
+
+        # Slice the recording and check the new times are correct
+        rec_frame_slice = times_recording.frame_slice(start_frame=start_frame, end_frame=end_frame)
+        sort_frame_slice = sorting.frame_slice(start_frame=start_frame, end_frame=end_frame)
+
+        assert np.allclose(rec_frame_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)
+
+        self._check_spike_times_are_correct(sort_frame_slice, rec_frame_slice, segment_index=0)
+
+        # Test `time_slice`
+        start_time = times_recording.sample_index_to_time(start_frame)
+        end_time = times_recording.sample_index_to_time(end_frame)
+
+        rec_slice = times_recording.time_slice(start_time=start_time, end_time=end_time)
+
+        assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)
+
+    # Helpers ####
+    def _check_times_match(self, recording, all_times):
+        """
+        For every segment in a recording, check that `get_times()`
+        matches the expected times in the list of time vectors, `all_times`.
+        """
+        for segment_index in range(recording.get_num_segments()):
+            assert np.array_equal(recording.get_times(segment_index), all_times[segment_index])
+
+    def _check_spike_times_are_correct(self, sorting, times_recording, segment_index):
+        """
+        For every unit in the `sorting`, for a particular segment, check that
+        the unit times match the times of the original recording as
+        retrieved with `get_times()`.
+ """ + for unit_id in sorting.get_unit_ids(): + spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True) + spike_indexes = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) + rec_times = times_recording.get_times(segment_index=segment_index) + + assert np.array_equal( + spike_times, + rec_times[spike_indexes], + ) - rec_slice = rec.frame_slice(start_frame=start_frame, end_frame=end_frame) - sort_slice = sort.frame_slice(start_frame=start_frame, end_frame=end_frame) + def _get_sorting_with_recording_attached(self, recording_for_durations, recording_to_attach): + """ + Convenience function to create a sorting object with + a recording attached. Typically use the raw recordings + for the durations of which to make the sorter, as + the generate_sorter is not setup to handle the + (strange) edge case of the irregularly spaced + test time vectors. + """ + durations = [ + recording_for_durations.get_duration(idx) for idx in range(recording_for_durations.get_num_segments()) + ] - for u in sort_slice.get_unit_ids(): - spike_times = sort_slice.get_unit_spike_train(u, return_times=True) - rec_times = rec_slice.get_times() - assert np.all(spike_times >= rec_times[0]) - assert np.all(spike_times <= rec_times[-1]) + sorting = generate_sorting(num_units=10, durations=durations) + sorting.register_recording(recording_to_attach) + assert sorting.has_recording() -if __name__ == "__main__": - test_frame_slicing() + return sorting diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 4851c0eb5c..1b9637e097 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -31,13 +31,7 @@ class ZarrRecordingExtractor(BaseRecording): The recording Extractor """ - installed = True - mode = "folder" - installation_mesg = "" - name = "zarr" - def __init__(self, folder_path: Path | str, storage_options: dict | None = None): - assert self.installed, self.installation_mesg folder_path, folder_path_kwarg = resolve_zarr_path(folder_path) @@ -167,13 +161,7 @@ class ZarrSortingExtractor(BaseSorting): The sorting Extractor """ - installed = True - mode = "folder" - installation_mesg = "" - name = "zarr" - def __init__(self, folder_path: Path | str, storage_options: dict | None = None, zarr_group: str | None = None): - assert self.installed, self.installation_mesg folder_path, folder_path_kwarg = resolve_zarr_path(folder_path) diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index bbdb70b2f6..11f26ea778 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -4,6 +4,7 @@ from spikeinterface.core.basesorting import BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import define_function_from_class from copy import deepcopy +from spikeinterface.core.sorting_tools import generate_unit_ids_for_merge_group class MergeUnitsSorting(BaseSorting): @@ -44,35 +45,15 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy parents_unit_ids = sorting.unit_ids sampling_frequency = sorting.get_sampling_frequency() + new_unit_ids = generate_unit_ids_for_merge_group( + sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy="append" + ) + all_removed_ids = [] for ids in units_to_merge: all_removed_ids.extend(ids) keep_unit_ids = [u for u in parents_unit_ids if u not in all_removed_ids] - if 
new_unit_ids is None: - dtype = parents_unit_ids.dtype - # select new_unit_ids greater that the max id, event greater than the numerical str ids - if np.issubdtype(dtype, np.character): - # dtype str - if all(p.isdigit() for p in parents_unit_ids): - # All str are digit : we can generate a max - m = max(int(p) for p in parents_unit_ids) + 1 - new_unit_ids = [str(m + i) for i in range(num_merge)] - else: - # we cannot automatically find new names - new_unit_ids = [f"merge{i}" for i in range(num_merge)] - if np.any(np.isin(new_unit_ids, keep_unit_ids)): - raise ValueError( - "Unable to find 'new_unit_ids' because it is a string and parents " - "already contain merges. Pass a list of 'new_unit_ids' as an argument." - ) - else: - # dtype int - new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) - else: - if np.any(np.isin(new_unit_ids, keep_unit_ids)): - raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones") - assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge" # some checks @@ -81,7 +62,7 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy assert properties_policy in ("keep", "remove"), "properties_policy must be " "keep" " or " "remove" "" # new units are put at the end - unit_ids = keep_unit_ids + new_unit_ids + unit_ids = keep_unit_ids + list(new_unit_ids) BaseSorting.__init__(self, sampling_frequency, unit_ids) # assert all(np.isin(keep_unit_ids, self.unit_ids)), 'new_unit_id should have a compatible format with the parent ids' diff --git a/src/spikeinterface/extractors/alfsortingextractor.py b/src/spikeinterface/extractors/alfsortingextractor.py index fa6490135c..f7b5401182 100644 --- a/src/spikeinterface/extractors/alfsortingextractor.py +++ b/src/spikeinterface/extractors/alfsortingextractor.py @@ -25,12 +25,11 @@ class ALFSortingExtractor(BaseSorting): """ installation_mesg = "To use the ALF extractors, install ONE-api: \n\n pip install ONE-api\n\n" - name = "alf" def __init__(self, folder_path, sampling_frequency=30000): try: import one.alf.io as alfio - except ImportError as e: + except ImportError: raise ImportError(self.installation_mesg) self._folder_path = Path(folder_path) diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index a09cea9863..d7e5b58e11 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -39,16 +39,14 @@ class CompressedBinaryIblExtractor(BaseRecording): The loaded data. """ - mode = "folder" installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n" - name = "cbin_ibl" def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file=None): from neo.rawio.spikeglxrawio import read_meta_file try: import mtscomp - except: + except ImportError: raise ImportError(self.installation_mesg) if cbin_file is None: folder_path = Path(folder_path) diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index 736927a1ee..0dfa3a85ad 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -29,7 +29,6 @@ class CellExplorerSortingExtractor(BaseSorting): Path to the `sessionInfo.mat` file. If None, it will be inferred from the file_path. 
""" - mode = "file" installation_mesg = "To use the CellExplorerSortingExtractor install pymatreader" def __init__( diff --git a/src/spikeinterface/extractors/combinatoextractors.py b/src/spikeinterface/extractors/combinatoextractors.py index 8828ea8b64..35fce3a8e3 100644 --- a/src/spikeinterface/extractors/combinatoextractors.py +++ b/src/spikeinterface/extractors/combinatoextractors.py @@ -7,13 +7,6 @@ from spikeinterface.core import BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import define_function_from_class -try: - import h5py - - HAVE_H5PY = True -except ImportError: - HAVE_H5PY = False - class CombinatoSortingExtractor(BaseSorting): """Load Combinato format data as a sorting extractor. @@ -37,11 +30,14 @@ class CombinatoSortingExtractor(BaseSorting): The loaded data. """ - installed = HAVE_H5PY installation_mesg = "To use the CombinatoSortingExtractor install h5py: \n\n pip install h5py\n\n" - name = "combinato" def __init__(self, folder_path, sampling_frequency=None, user="simple", det_sign="both", keep_good_only=True): + try: + import h5py + except ImportError: + raise ImportError(self.installation_mesg) + folder_path = Path(folder_path) assert folder_path.is_dir(), "Folder {} doesn't exist".format(folder_path) if sampling_frequency is None: diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py index 8948aad606..bd35180a7e 100644 --- a/src/spikeinterface/extractors/extractorlist.py +++ b/src/spikeinterface/extractors/extractorlist.py @@ -117,51 +117,20 @@ snippets_extractor_full_list = [NpySnippetsExtractor, WaveClusSnippetsExtractor] - -recording_extractor_full_dict = {recext.name: recext for recext in recording_extractor_full_list} -sorting_extractor_full_dict = {recext.name: recext for recext in sorting_extractor_full_list} -snippets_extractor_full_dict = {recext.name: recext for recext in snippets_extractor_full_list} - - -def get_recording_extractor_from_name(name: str) -> Type[BaseRecording]: - """ - Returns the Recording Extractor class based on its name. - - Parameters - ---------- - name: str - The Recording Extractor's name. - - Returns - ------- - recording_extractor: BaseRecording - The Recording Extractor class. - """ - - for recording_extractor in recording_extractor_full_list: - if recording_extractor.__name__ == name: - return recording_extractor - - raise ValueError(f"Recording extractor '{name}' not found.") - - -def get_sorting_extractor_from_name(name: str) -> Type[BaseSorting]: - """ - Returns the Sorting Extractor class based on its name. - - Parameters - ---------- - name: str - The Sorting Extractor's name. - - Returns - ------- - sorting_extractor: BaseSorting - The Sorting Extractor class. 
- """ - - for sorting_extractor in sorting_extractor_full_list: - if sorting_extractor.__name__ == name: - return sorting_extractor - - raise ValueError(f"Sorting extractor '{name}' not found.") +recording_extractor_full_dict = {} +for rec_class in recording_extractor_full_list: + # here we get the class name, remove "Recording" and "Extractor" and make it lower case + rec_class_name = rec_class.__name__.replace("Recording", "").replace("Extractor", "").lower() + recording_extractor_full_dict[rec_class_name] = rec_class + +sorting_extractor_full_dict = {} +for sort_class in sorting_extractor_full_list: + # here we get the class name, remove "Extractor" and make it lower case + sort_class_name = sort_class.__name__.replace("Sorting", "").replace("Extractor", "").lower() + sorting_extractor_full_dict[sort_class_name] = sort_class + +event_extractor_full_dict = {} +for event_class in event_extractor_full_list: + # here we get the class name, remove "Extractor" and make it lower case + event_class_name = event_class.__name__.replace("Event", "").replace("Extractor", "").lower() + event_extractor_full_dict[event_class_name] = event_class diff --git a/src/spikeinterface/extractors/hdsortextractors.py b/src/spikeinterface/extractors/hdsortextractors.py index 19038344ee..fa627d2ee3 100644 --- a/src/spikeinterface/extractors/hdsortextractors.py +++ b/src/spikeinterface/extractors/hdsortextractors.py @@ -25,9 +25,6 @@ class HDSortSortingExtractor(MatlabHelper, BaseSorting): The loaded data. """ - mode = "file" - name = "hdsort" - def __init__(self, file_path, keep_good_only=True): MatlabHelper.__init__(self, file_path) diff --git a/src/spikeinterface/extractors/herdingspikesextractors.py b/src/spikeinterface/extractors/herdingspikesextractors.py index 4fe915a96b..de4929218b 100644 --- a/src/spikeinterface/extractors/herdingspikesextractors.py +++ b/src/spikeinterface/extractors/herdingspikesextractors.py @@ -7,13 +7,6 @@ from spikeinterface.core import BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import define_function_from_class -try: - import h5py - - HAVE_HS2SX = True -except ImportError: - HAVE_HS2SX = False - class HerdingspikesSortingExtractor(BaseSorting): """Load HerdingSpikes format data as a sorting extractor. @@ -31,15 +24,13 @@ class HerdingspikesSortingExtractor(BaseSorting): The loaded data. """ - installed = HAVE_HS2SX # check at class level if installed or not - mode = "file" - installation_mesg = ( - "To use the HS2SortingExtractor install h5py: \n\n pip install h5py\n\n" # error message when not installed - ) - name = "herdingspikes" + installation_mesg = "To use the HS2SortingExtractor install h5py: \n\n pip install h5py\n\n" def __init__(self, file_path, load_unit_info=True): - assert self.installed, self.installation_mesg + try: + import h5py + except ImportError: + raise ImportError(self.installation_mesg) self._recording_file = file_path self._rf = h5py.File(self._recording_file, mode="r") diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 34481c94f1..5dd549347d 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -65,9 +65,7 @@ class IblRecordingExtractor(BaseRecording): The recording extractor which allows access to the traces. 
""" - mode = "folder" installation_mesg = "To use the IblRecordingSegment, install ibllib: \n\n pip install ONE-api\npip install ibllib\n" - name = "ibl_recording" @staticmethod def _get_default_one(cache_folder: Optional[Union[Path, str]] = None): @@ -304,7 +302,6 @@ class IblSortingExtractor(BaseSorting): The loaded data. """ - name = "ibl" installation_mesg = "IBL extractors require ibllib as a dependency." " To install, run: \n\n pip install ibllib\n\n" def __init__(self, pid: str, good_clusters_only: bool = False, load_unit_properties: bool = True, one=None): diff --git a/src/spikeinterface/extractors/klustaextractors.py b/src/spikeinterface/extractors/klustaextractors.py index 82534771a1..162376cb3c 100644 --- a/src/spikeinterface/extractors/klustaextractors.py +++ b/src/spikeinterface/extractors/klustaextractors.py @@ -18,13 +18,6 @@ from spikeinterface.core import BaseRecording, BaseSorting, BaseRecordingSegment, BaseSortingSegment, read_python from spikeinterface.core.core_tools import define_function_from_class -try: - import h5py - - HAVE_H5PY = True -except ImportError: - HAVE_H5PY = False - # noinspection SpellCheckingInspection class KlustaSortingExtractor(BaseSorting): @@ -43,18 +36,15 @@ class KlustaSortingExtractor(BaseSorting): The loaded data. """ - installed = HAVE_H5PY # check at class level if installed or not - installation_mesg = ( - "To use the KlustaSortingExtractor install h5py: \n\n pip install h5py\n\n" # error message when not installed - ) - mode = "file_or_folder" - name = "klusta" + installation_mesg = "To use the KlustaSortingExtractor install h5py: \n\n pip install h5py\n\n" default_cluster_groups = {0: "Noise", 1: "MUA", 2: "Good", 3: "Unsorted"} def __init__(self, file_or_folder_path, exclude_cluster_groups=None): - assert HAVE_H5PY, self.installation_mesg - # ~ SortingExtractor.__init__(self) + try: + import h5py + except ImportError: + raise ImportError(self.installation_mesg) kwik_file_or_folder = Path(file_or_folder_path) kwikfile = None diff --git a/src/spikeinterface/extractors/mclustextractors.py b/src/spikeinterface/extractors/mclustextractors.py index 5cfa583054..d611a1576a 100644 --- a/src/spikeinterface/extractors/mclustextractors.py +++ b/src/spikeinterface/extractors/mclustextractors.py @@ -29,8 +29,6 @@ class MClustSortingExtractor(BaseSorting): Loaded data. """ - name = "mclust" - def __init__(self, folder_path, sampling_frequency, sampling_frequency_raw=None): end_header_str = "%%ENDHEADER" ext_list = ["t64", "t32", "t", "raw64", "raw32"] diff --git a/src/spikeinterface/extractors/mcsh5extractors.py b/src/spikeinterface/extractors/mcsh5extractors.py index f419b7e64d..78296926d0 100644 --- a/src/spikeinterface/extractors/mcsh5extractors.py +++ b/src/spikeinterface/extractors/mcsh5extractors.py @@ -24,18 +24,12 @@ class MCSH5RecordingExtractor(BaseRecording): The loaded data. 
""" - mode = "file" - installation_mesg = ( - "To use the MCSH5RecordingExtractor install h5py: \n\n pip install h5py\n\n" # error message when not installed - ) - name = "mcsh5" + installation_mesg = "To use the MCSH5RecordingExtractor install h5py: \n\n pip install h5py\n\n" def __init__(self, file_path, stream_id=0): try: import h5py - - HAVE_MCSH5 = True except ImportError: raise ImportError(self.installation_mesg) @@ -61,6 +55,9 @@ def __init__(self, file_path, stream_id=0): # set gain self.set_channel_gains(mcs_info["gain"]) + # set offsets + self.set_channel_offsets(mcs_info["offset"]) + # set other properties self.set_property("electrode_labels", mcs_info["electrode_labels"]) @@ -100,7 +97,11 @@ def get_traces(self, start_frame=None, end_frame=None, channel_indices=None): def openMCSH5File(filename, stream_id): - """Open an MCS hdf5 file, read and return the recording info.""" + """Open an MCS hdf5 file, read and return the recording info. + Specs can be found online + https://www.multichannelsystems.com/downloads/documentation?page=3 + """ + import h5py rf = h5py.File(filename, "r") @@ -121,7 +122,8 @@ def openMCSH5File(filename, stream_id): Tick = info["Tick"][0] / 1e6 exponent = info["Exponent"][0] convFact = info["ConversionFactor"][0] - gain = convFact.astype(float) * (10.0**exponent) + gain_uV = 1e6 * (convFact.astype(float) * (10.0**exponent)) + offset_uV = -1e6 * (info["ADZero"].astype(float) * (10.0**exponent)) * gain_uV nRecCh, nFrames = data.shape channel_ids = [f"Ch{ch}" for ch in info["ChannelID"]] @@ -149,8 +151,9 @@ def openMCSH5File(filename, stream_id): "num_channels": nRecCh, "channel_ids": channel_ids, "electrode_labels": electrodeLabels, - "gain": gain, + "gain": gain_uV, "dtype": dtype, + "offset": offset_uV, } return mcs_info diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index acc7be58dd..f055e1d7c9 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -36,9 +36,6 @@ class MdaRecordingExtractor(BaseRecording): The loaded data. """ - mode = "folder" - name = "mda" - def __init__(self, folder_path, raw_fname="raw.mda", params_fname="params.json", geom_fname="geom.csv"): folder_path = Path(folder_path) self._folder_path = folder_path @@ -192,9 +189,6 @@ class MdaSortingExtractor(BaseSorting): The loaded data. 
""" - mode = "file" - name = "mda" - def __init__(self, file_path, sampling_frequency): firings = readmda(str(Path(file_path).absolute())) labels = firings[2, :] diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py index 0b11b72b2a..bf52de7c1d 100644 --- a/src/spikeinterface/extractors/neoextractors/__init__.py +++ b/src/spikeinterface/extractors/neoextractors/__init__.py @@ -36,7 +36,7 @@ ) from .spike2 import Spike2RecordingExtractor, read_spike2 from .spikegadgets import SpikeGadgetsRecordingExtractor, read_spikegadgets -from .spikeglx import SpikeGLXRecordingExtractor, read_spikeglx +from .spikeglx import SpikeGLXRecordingExtractor, SpikeGLXEventExtractor, read_spikeglx, read_spikeglx_event from .tdt import TdtRecordingExtractor, read_tdt from .neo_utils import get_neo_streams, get_neo_num_blocks @@ -73,4 +73,9 @@ Plexon2SortingExtractor, ] -neo_event_extractors_list = [AlphaOmegaEventExtractor, OpenEphysBinaryEventExtractor, Plexon2EventExtractor] +neo_event_extractors_list = [ + AlphaOmegaEventExtractor, + OpenEphysBinaryEventExtractor, + Plexon2EventExtractor, + SpikeGLXEventExtractor, +] diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index 239928f66d..2e70d5ba41 100644 --- a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -27,9 +27,7 @@ class AlphaOmegaRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "folder" NeoRawIOClass = "AlphaOmegaRawIO" - name = "alphaomega" def __init__(self, folder_path, lsx_files=None, stream_id="RAW", stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(folder_path, lsx_files) diff --git a/src/spikeinterface/extractors/neoextractors/axona.py b/src/spikeinterface/extractors/neoextractors/axona.py index 71e1277946..e086cb5dde 100644 --- a/src/spikeinterface/extractors/neoextractors/axona.py +++ b/src/spikeinterface/extractors/neoextractors/axona.py @@ -21,9 +21,7 @@ class AxonaRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "folder" NeoRawIOClass = "AxonaRawIO" - name = "axona" def __init__(self, file_path, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 9f23575dba..e7b6199ea9 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -31,9 +31,7 @@ class BiocamRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "BiocamRawIO" - name = "biocam" def __init__( self, diff --git a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py index ab3710e05e..9bd2b05f24 100644 --- a/src/spikeinterface/extractors/neoextractors/blackrock.py +++ b/src/spikeinterface/extractors/neoextractors/blackrock.py @@ -31,9 +31,7 @@ class BlackrockRecordingExtractor(NeoBaseRecordingExtractor): """ - mode = "file" NeoRawIOClass = "BlackrockRawIO" - name = "blackrock" def __init__( self, @@ -87,10 +85,8 @@ class BlackrockSortingExtractor(NeoBaseSortingExtractor): Used to extract information about the sampling frequency and t_start from the analog signal if provided. 
""" - mode = "file" NeoRawIOClass = "BlackrockRawIO" neo_returns_frames = False - name = "blackrock" def __init__( self, diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index e2c79478fa..73a783ec5d 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -27,9 +27,7 @@ class CedRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "CedRawIO" - name = "ced" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/edf.py b/src/spikeinterface/extractors/neoextractors/edf.py index 90627d5772..8369369922 100644 --- a/src/spikeinterface/extractors/neoextractors/edf.py +++ b/src/spikeinterface/extractors/neoextractors/edf.py @@ -26,9 +26,7 @@ class EDFRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "EDFRawIO" - name = "edf" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = {"filename": str(file_path)} diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 43439b80c9..34c8bf2eb5 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -33,9 +33,7 @@ class IntanRecordingExtractor(NeoBaseRecordingExtractor): """ - mode = "file" NeoRawIOClass = "IntanRawIO" - name = "intan" def __init__( self, diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index a66075b451..6c72696e16 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -39,9 +39,7 @@ class MaxwellRecordingExtractor(NeoBaseRecordingExtractor): If there are several blocks (experiments), specify the block index you want to load """ - mode = "file" NeoRawIOClass = "MaxwellRawIO" - name = "maxwell" def __init__( self, @@ -96,8 +94,6 @@ class MaxwellEventExtractor(BaseEvent): Class for reading TTL events from Maxwell files. """ - name = "maxwell" - def __init__(self, file_path): import h5py diff --git a/src/spikeinterface/extractors/neoextractors/mcsraw.py b/src/spikeinterface/extractors/neoextractors/mcsraw.py index 0cbd9263ba..307a6c1fba 100644 --- a/src/spikeinterface/extractors/neoextractors/mcsraw.py +++ b/src/spikeinterface/extractors/neoextractors/mcsraw.py @@ -30,9 +30,7 @@ class MCSRawRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "RawMCSRawIO" - name = "mcsraw" def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index 76f0b29f54..21a597029b 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -40,9 +40,7 @@ class MEArecRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. 
""" - mode = "file" NeoRawIOClass = "MEArecRawIO" - name = "mearec" def __init__(self, file_path: Union[str, Path], all_annotations: bool = False): neo_kwargs = self.map_to_neo_kwargs(file_path) @@ -75,10 +73,8 @@ def map_to_neo_kwargs( class MEArecSortingExtractor(NeoBaseSortingExtractor): - mode = "file" NeoRawIOClass = "MEArecRawIO" neo_returns_frames = False - name = "mearec" def __init__(self, file_path: Union[str, Path]): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/neuralynx.py b/src/spikeinterface/extractors/neoextractors/neuralynx.py index 0670371ba9..98f4a7c2ff 100644 --- a/src/spikeinterface/extractors/neoextractors/neuralynx.py +++ b/src/spikeinterface/extractors/neoextractors/neuralynx.py @@ -40,9 +40,7 @@ class NeuralynxRecordingExtractor(NeoBaseRecordingExtractor): Note that here the default is False contrary to neo. """ - mode = "folder" NeoRawIOClass = "NeuralynxRawIO" - name = "neuralynx" def __init__( self, @@ -90,11 +88,9 @@ class NeuralynxSortingExtractor(NeoBaseSortingExtractor): Used to extract information about the sampling frequency and t_start from the analog signal if provided. """ - mode = "folder" NeoRawIOClass = "NeuralynxRawIO" neo_returns_frames = True need_t_start_from_signal_stream = True - name = "neuralynx" def __init__( self, diff --git a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py index 49784418e1..ac569c0df0 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py +++ b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py @@ -47,9 +47,7 @@ class NeuroExplorerRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "NeuroExplorerRawIO" - name = "neuroexplorer" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = {"filename": str(file_path)} diff --git a/src/spikeinterface/extractors/neoextractors/neuroscope.py b/src/spikeinterface/extractors/neoextractors/neuroscope.py index 104f47af24..6c6f1d4bea 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroscope.py +++ b/src/spikeinterface/extractors/neoextractors/neuroscope.py @@ -37,9 +37,7 @@ class NeuroScopeRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "NeuroScopeRawIO" - name = "neuroscope" def __init__(self, file_path, xml_file_path=None, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path, xml_file_path) @@ -103,8 +101,6 @@ class NeuroScopeSortingExtractor(BaseSorting): Path to the .xml file referenced by this sorting. """ - name = "neuroscope" - def __init__( self, folder_path: OptionalPathType = None, diff --git a/src/spikeinterface/extractors/neoextractors/nix.py b/src/spikeinterface/extractors/neoextractors/nix.py index 00e5f8bfc1..b869936fa3 100644 --- a/src/spikeinterface/extractors/neoextractors/nix.py +++ b/src/spikeinterface/extractors/neoextractors/nix.py @@ -27,9 +27,7 @@ class NixRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. 
""" - mode = "file" NeoRawIOClass = "NIXRawIO" - name = "nix" def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index f3363b9013..04c25998f0 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -64,9 +64,7 @@ class OpenEphysLegacyRecordingExtractor(NeoBaseRecordingExtractor): neo.OpenEphysRawIO is now handling gaps directly but makes the read slower. """ - mode = "folder" NeoRawIOClass = "OpenEphysRawIO" - name = "openephyslegacy" def __init__( self, @@ -138,9 +136,7 @@ class OpenEphysBinaryRecordingExtractor(NeoBaseRecordingExtractor): """ - mode = "folder" NeoRawIOClass = "OpenEphysBinaryRawIO" - name = "openephys" def __init__( self, @@ -287,9 +283,7 @@ class OpenEphysBinaryEventExtractor(NeoBaseEventExtractor): """ - mode = "folder" NeoRawIOClass = "OpenEphysBinaryRawIO" - name = "openephys" def __init__(self, folder_path, block_index=None): neo_kwargs = self.map_to_neo_kwargs(folder_path) diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index cf08778ffa..9c2586dd5a 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -25,9 +25,7 @@ class PlexonRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "PlexonRawIO" - name = "plexon" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) @@ -54,9 +52,7 @@ class PlexonSortingExtractor(NeoBaseSortingExtractor): The file path to load the recordings from. """ - mode = "file" NeoRawIOClass = "PlexonRawIO" - name = "plexon" neo_returns_frames = True def __init__(self, file_path): diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 6c9160f13b..4434d02cc1 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -30,9 +30,7 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "Plexon2RawIO" - name = "plexon2" def __init__(self, file_path, stream_id=None, stream_name=None, use_names_as_ids=True, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) @@ -66,10 +64,8 @@ class Plexon2SortingExtractor(NeoBaseSortingExtractor): The sampling frequency of the sorting (required for multiple streams with different sampling frequencies). 
""" - mode = "file" NeoRawIOClass = "Plexon2RawIO" neo_returns_frames = True - name = "plexon2" def __init__(self, file_path, sampling_frequency=None): from neo.rawio import Plexon2RawIO @@ -98,9 +94,7 @@ class Plexon2EventExtractor(NeoBaseEventExtractor): """ - mode = "file" NeoRawIOClass = "Plexon2RawIO" - name = "plexon2" def __init__(self, folder_path, block_index=None): neo_kwargs = self.map_to_neo_kwargs(folder_path) diff --git a/src/spikeinterface/extractors/neoextractors/spike2.py b/src/spikeinterface/extractors/neoextractors/spike2.py index 1bd0351553..cbc1db3f74 100644 --- a/src/spikeinterface/extractors/neoextractors/spike2.py +++ b/src/spikeinterface/extractors/neoextractors/spike2.py @@ -26,9 +26,7 @@ class Spike2RecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "Spike2RawIO" - name = "spike2" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index 3d57817f88..e7c31b8afa 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -28,9 +28,7 @@ class SpikeGadgetsRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. """ - mode = "file" NeoRawIOClass = "SpikeGadgetsRawIO" - name = "spikegadgets" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index adfd0f702e..10a1f78265 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -40,9 +40,7 @@ class SpikeGLXRecordingExtractor(NeoBaseRecordingExtractor): Load exhaustively all annotations from neo. 
""" - mode = "folder" NeoRawIOClass = "SpikeGLXRawIO" - name = "spikeglx" def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_name=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(folder_path, load_sync_channel=load_sync_channel) @@ -110,9 +108,7 @@ class SpikeGLXEventExtractor(NeoBaseEventExtractor): """ - mode = "folder" NeoRawIOClass = "SpikeGLXRawIO" - name = "spikeglx" def __init__(self, folder_path, block_index=None): neo_kwargs = self.map_to_neo_kwargs(folder_path) diff --git a/src/spikeinterface/extractors/neoextractors/tdt.py b/src/spikeinterface/extractors/neoextractors/tdt.py index 27b456102f..a1298dece7 100644 --- a/src/spikeinterface/extractors/neoextractors/tdt.py +++ b/src/spikeinterface/extractors/neoextractors/tdt.py @@ -27,9 +27,7 @@ class TdtRecordingExtractor(NeoBaseRecordingExtractor): If there are several blocks (experiments), specify the block index you want to load """ - mode = "folder" NeoRawIOClass = "TdtRawIO" - name = "tdt" def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(folder_path) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 9786766af1..7164afeac6 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -401,7 +401,40 @@ def _retrieve_electrodes_indices_from_electrical_series_backend(open_file, elect return electrodes_indices -class NwbRecordingExtractor(BaseRecording): +class _BaseNWBExtractor: + "A class for common methods for NWB extractors." + + def _close_hdf5_file(self): + has_hdf5_backend = hasattr(self, "_file") + if has_hdf5_backend: + import h5py + + main_file_id = self._file.id + open_object_ids_main = h5py.h5f.get_obj_ids(main_file_id, types=h5py.h5f.OBJ_ALL) + for object_id in open_object_ids_main: + object_name = h5py.h5i.get_name(object_id).decode("utf-8") + try: + object_id.close() + except: + import warnings + + warnings.warn(f"Error closing object {object_name}") + + def __del__(self): + # backend mode + if hasattr(self, "_file"): + if hasattr(self._file, "store"): + self._file.store.close() + else: + self._close_hdf5_file() + # pynwb mode + elif hasattr(self, "_nwbfile"): + io = self._nwbfile.get_read_io() + if io is not None: + io.close() + + +class NwbRecordingExtractor(BaseRecording, _BaseNWBExtractor): """Load an NWBFile as a RecordingExtractor. Parameters @@ -472,8 +505,6 @@ class NwbRecordingExtractor(BaseRecording): >>> rec = NwbRecordingExtractor(s3_url, stream_mode="fsspec", stream_cache_path="cache") """ - mode = "file" - name = "nwb" installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n" def __init__( @@ -625,19 +656,6 @@ def __init__( "file": file, } - def __del__(self): - # backend mode - if hasattr(self, "_file"): - if hasattr(self._file, "store"): - self._file.store.close() - else: - self._file.close() - # pynwb mode - elif hasattr(self, "_nwbfile"): - io = self._nwbfile.get_read_io() - if io is not None: - io.close() - def _fetch_recording_segment_info_pynwb(self, file, cache, load_time_vector, samples_for_rate_estimation): self._nwbfile = read_nwbfile( backend=self.backend, @@ -951,7 +969,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): return traces -class NwbSortingExtractor(BaseSorting): +class NwbSortingExtractor(BaseSorting, _BaseNWBExtractor): """Load an NWBFile as a SortingExtractor. 
Parameters ---------- @@ -1000,9 +1018,7 @@ class NwbSortingExtractor(BaseSorting): The sorting extractor for the NWB file. """ - mode = "file" installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n" - name = "nwb" def __init__( self, @@ -1109,19 +1125,6 @@ def __init__( "t_start": self.t_start, } - def __del__(self): - # backend mode - if hasattr(self, "_file"): - if hasattr(self._file, "store"): - self._file.store.close() - else: - self._file.close() - # pynwb mode - elif hasattr(self, "_nwbfile"): - io = self._nwbfile.get_read_io() - if io is not None: - io.close() - def _fetch_sorting_segment_info_pynwb( self, unit_table_path: str = None, samples_for_rate_estimation: int = 1000, cache: bool = False ): diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 3287f7422f..737a88c51a 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -26,12 +26,9 @@ class BasePhyKilosortSortingExtractor(BaseSorting): If True, all cluster properties are loaded from the tsv/csv files. """ - installed = False # check at class level if installed or not - mode = "folder" installation_mesg = ( "To use the PhySortingExtractor install pandas: \n\n pip install pandas\n\n" # error message when not installed ) - name = "phykilosort" def __init__( self, @@ -43,14 +40,10 @@ def __init__( ): try: import pandas as pd - - HAVE_PD = True except ImportError: - HAVE_PD = False - assert HAVE_PD, self.installation_mesg + raise ImportError(self.installation_mesg) phy_folder = Path(folder_path) - spike_times = np.load(phy_folder / "spike_times.npy").astype(int) if (phy_folder / "spike_clusters.npy").is_file(): @@ -228,8 +221,6 @@ class PhySortingExtractor(BasePhyKilosortSortingExtractor): The loaded Sorting object. """ - name = "phy" - def __init__( self, folder_path: Path | str, @@ -269,8 +260,6 @@ class KiloSortSortingExtractor(BasePhyKilosortSortingExtractor): The loaded Sorting object. """ - name = "kilosort" - def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove_empty_units: bool = True): BasePhyKilosortSortingExtractor.__init__( self, diff --git a/src/spikeinterface/extractors/shybridextractors.py b/src/spikeinterface/extractors/shybridextractors.py index b53b3b2056..1c5c147c6a 100644 --- a/src/spikeinterface/extractors/shybridextractors.py +++ b/src/spikeinterface/extractors/shybridextractors.py @@ -30,24 +30,19 @@ class SHYBRIDRecordingExtractor(BinaryRecordingExtractor): Loaded data. """ - mode = "folder" installation_mesg = ( "To use the SHYBRID extractors, install SHYBRID and pyyaml: " "\n\n pip install shybrid pyyaml\n\n" ) - name = "shybrid" def __init__(self, file_path): try: import hybridizer.io as sbio import hybridizer.probes as sbprb import yaml - - HAVE_SBEX = True except ImportError: - HAVE_SBEX = False + raise ImportError(self.installation_mesg) # load params file related to the given shybrid recording - assert HAVE_SBEX, self.installation_mesg assert Path(file_path).suffix in [".yml", ".yaml"], "The 'file_path' should be a yaml file!" 
params = sbio.get_params(file_path)["data"] file_path = Path(file_path) @@ -102,12 +97,9 @@ def write_recording(recording, save_path, initial_sorting_fn, dtype="float32", * import hybridizer.io as sbio import hybridizer.probes as sbprb import yaml - - HAVE_SBEX = True except ImportError: - HAVE_SBEX = False + raise ImportError(SHYBRIDRecordingExtractor.installation_mesg) - assert HAVE_SBEX, SHYBRIDRecordingExtractor.installation_mesg assert recording.get_num_segments() == 1, "SHYBRID can only write single segment recordings" save_path = Path(save_path) recording_name = "recording.bin" @@ -159,18 +151,14 @@ class SHYBRIDSortingExtractor(BaseSorting): """ installation_mesg = "To use the SHYBRID extractors, install SHYBRID: \n\n pip install shybrid\n\n" - name = "shybrid" def __init__(self, file_path, sampling_frequency, delimiter=","): try: import hybridizer.io as sbio import hybridizer.probes as sbprb - - HAVE_SBEX = True except ImportError: - HAVE_SBEX = False + raise ImportError(self.installation_mesg) - assert HAVE_SBEX, self.installation_mesg assert Path(file_path).suffix == ".csv", "The 'file_path' should be a csv file!" if Path(file_path).is_file(): @@ -205,12 +193,9 @@ def write_sorting(sorting, save_path): try: import hybridizer.io as sbio import hybridizer.probes as sbprb - - HAVE_SBEX = True except ImportError: - HAVE_SBEX = False + raise ImportError(SHYBRIDSortingExtractor.installation_mesg) - assert HAVE_SBEX, SHYBRIDSortingExtractor.installation_mesg assert sorting.get_num_segments() == 1, "SHYBRID can only write single segment sortings" save_path = Path(save_path) diff --git a/src/spikeinterface/extractors/sinapsrecordingextractors.py b/src/spikeinterface/extractors/sinapsrecordingextractors.py index 522f639760..c3e92a63ff 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractors.py +++ b/src/spikeinterface/extractors/sinapsrecordingextractors.py @@ -23,10 +23,6 @@ class SinapsResearchPlatformRecordingExtractor(ChannelSliceRecording): "filt" extracts the filtered data, "raw" extracts the raw data, and "aux" extracts the auxiliary data. """ - extractor_name = "SinapsResearchPlatform" - mode = "file" - name = "sinaps_research_platform" - def __init__(self, file_path: str | Path, stream_name: str = "filt"): from ..preprocessing import UnsignedToSignedRecording @@ -91,10 +87,6 @@ class SinapsResearchPlatformH5RecordingExtractor(BaseRecording): Path to the SiNAPS .h5 file. """ - extractor_name = "SinapsResearchPlatformH5" - mode = "file" - name = "sinaps_research_platform_h5" - def __init__(self, file_path: str | Path): self._file_path = file_path diff --git a/src/spikeinterface/extractors/spykingcircusextractors.py b/src/spikeinterface/extractors/spykingcircusextractors.py index 7c3fb154fe..b8a1e5635e 100644 --- a/src/spikeinterface/extractors/spykingcircusextractors.py +++ b/src/spikeinterface/extractors/spykingcircusextractors.py @@ -7,13 +7,6 @@ from spikeinterface.core import BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import define_function_from_class -try: - import h5py - - HAVE_H5PY = True -except ImportError: - HAVE_H5PY = False - class SpykingCircusSortingExtractor(BaseSorting): """Load SpykingCircus format data as a recording extractor. @@ -29,13 +22,13 @@ class SpykingCircusSortingExtractor(BaseSorting): Loaded data. 
""" - installed = HAVE_H5PY # check at class level if installed or not - mode = "folder" installation_mesg = "To use the SpykingCircusSortingExtractor install h5py: \n\n pip install h5py\n\n" - name = "spykingcircus" def __init__(self, folder_path): - assert HAVE_H5PY, self.installation_mesg + try: + import h5py + except ImportError: + raise ImportError(self.installation_mesg) spykingcircus_folder = Path(folder_path) listfiles = spykingcircus_folder.iterdir() diff --git a/src/spikeinterface/extractors/tridesclousextractors.py b/src/spikeinterface/extractors/tridesclousextractors.py index 8589f03fd4..ac1ce4727b 100644 --- a/src/spikeinterface/extractors/tridesclousextractors.py +++ b/src/spikeinterface/extractors/tridesclousextractors.py @@ -22,9 +22,7 @@ class TridesclousSortingExtractor(BaseSorting): Loaded data. """ - mode = "folder" installation_mesg = "To use the TridesclousSortingExtractor install tridesclous: \n\n pip install tridesclous\n\n" # error message when not installed - name = "tridesclous" def __init__(self, folder_path, chan_grp=None): try: diff --git a/src/spikeinterface/extractors/waveclussnippetstextractors.py b/src/spikeinterface/extractors/waveclussnippetstextractors.py index 7c26eee7bd..75bae32519 100644 --- a/src/spikeinterface/extractors/waveclussnippetstextractors.py +++ b/src/spikeinterface/extractors/waveclussnippetstextractors.py @@ -10,7 +10,6 @@ class WaveClusSnippetsExtractor(MatlabHelper, BaseSnippets): - name = "waveclus" def __init__(self, file_path): file_path = Path(file_path) if isinstance(file_path, str) else file_path diff --git a/src/spikeinterface/extractors/waveclustextractors.py b/src/spikeinterface/extractors/waveclustextractors.py index 844b1cc7cf..3d024910fa 100644 --- a/src/spikeinterface/extractors/waveclustextractors.py +++ b/src/spikeinterface/extractors/waveclustextractors.py @@ -25,8 +25,6 @@ class WaveClusSortingExtractor(MatlabHelper, BaseSorting): Loaded data. """ - name = "waveclus" - def __init__(self, file_path, keep_good_only=True): MatlabHelper.__init__(self, file_path) diff --git a/src/spikeinterface/extractors/yassextractors.py b/src/spikeinterface/extractors/yassextractors.py index 61a49ccf01..7a76906acc 100644 --- a/src/spikeinterface/extractors/yassextractors.py +++ b/src/spikeinterface/extractors/yassextractors.py @@ -7,13 +7,6 @@ from spikeinterface.core import BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import define_function_from_class -try: - import yaml - - HAVE_YAML = True -except: - HAVE_YAML = False - class YassSortingExtractor(BaseSorting): """Load YASS format data as a sorting extractor. @@ -29,15 +22,13 @@ class YassSortingExtractor(BaseSorting): Loaded data. 
""" - mode = "folder" - installed = HAVE_YAML # check at class level if installed or not - installation_mesg = ( - "To use the Yass extractor, install pyyaml: \n\n pip install pyyaml\n\n" # error message when not installed - ) - name = "yass" + installation_mesg = "To use the Yass extractor, install pyyaml: \n\n pip install pyyaml\n\n" def __init__(self, folder_path): - assert HAVE_YAML, self.installation_mesg + try: + import yaml + except: + raise ImportError(self.installation_mesg) folder_path = Path(folder_path) diff --git a/src/spikeinterface/generation/__init__.py b/src/spikeinterface/generation/__init__.py index 7a2291d932..5bf42ecf0f 100644 --- a/src/spikeinterface/generation/__init__.py +++ b/src/spikeinterface/generation/__init__.py @@ -14,6 +14,7 @@ relocate_templates, ) from .noise_tools import generate_noise + from .drifting_generator import ( make_one_displacement_vector, generate_displacement_vector, @@ -26,3 +27,22 @@ list_available_datasets_in_template_database, query_templates_from_database, ) + +# expose the core generate functions +from ..core.generate import ( + generate_recording, + generate_sorting, + generate_snippets, + generate_templates, + generate_recording_by_size, + generate_ground_truth_recording, + add_synchrony_to_sorting, + synthesize_random_firings, + inject_some_duplicate_units, + inject_some_split_units, + synthetize_spike_train_bad_isi, + NoiseGeneratorRecording, + noise_generator_recording, + InjectTemplatesRecording, + inject_templates, +) diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index cce2e08b58..70e13160f4 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -458,6 +458,9 @@ def __init__( self.set_probe(drifting_templates.probe, in_place=True) + # templates are too large, we don't serialize them to JSON + self._serializability["json"] = False + self._kwargs = { "sorting": sorting, "drifting_templates": drifting_templates, diff --git a/src/spikeinterface/generation/noise_tools.py b/src/spikeinterface/generation/noise_tools.py index 11f30e352f..685f0113b4 100644 --- a/src/spikeinterface/generation/noise_tools.py +++ b/src/spikeinterface/generation/noise_tools.py @@ -7,22 +7,25 @@ def generate_noise( probe, sampling_frequency, durations, dtype="float32", noise_levels=15.0, spatial_decay=None, seed=None ): """ + Generate a noise recording. Parameters ---------- probe : Probe A probe object. sampling_frequency : float - Sampling frequency + The sampling frequency of the recording. durations : list of float - Durations + The duration(s) of the recording segment(s) in seconds. dtype : np.dtype - Dtype - noise_levels : float | np.array | tuple + The dtype of the recording. + noise_levels : float | np.array | tuple, default: 15.0 If scalar same noises on all channels. If array then per channels noise level. If tuple, then this represent the range. - seed : None | int + spatial_decay : float | None, default: None + If not None, the spatial decay of the noise used to generate the noise covariance matrix. + seed : int | None, default: None The seed for random generator. 
Returns diff --git a/src/spikeinterface/generation/template_database.py b/src/spikeinterface/generation/template_database.py index e1cba07c8e..17d2bdf521 100644 --- a/src/spikeinterface/generation/template_database.py +++ b/src/spikeinterface/generation/template_database.py @@ -20,7 +20,7 @@ def fetch_template_object_from_database(dataset="test_templates.zarr") -> Templa Returns ------- Templates - _description_ + The templates object. """ s3_path = f"s3://spikeinterface-template-database/{dataset}/" zarr_group = zarr.open_consolidated(s3_path, storage_options={"anon": True}) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index bc8ecb4cb7..93d0448ef4 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -39,7 +39,8 @@ class CommonReferenceRecording(BasePreprocessor): recording : RecordingExtractor The recording extractor to be re-referenced reference : "global" | "single" | "local", default: "global" - If "global" the reference is the average or median across all the channels. + If "global" the reference is the average or median across all the channels. To select a subset of channels, + you can use the `ref_channel_ids` parameter. If "single", the reference is a single channel or a list of channels that need to be set with the `ref_channel_ids`. If "local", the reference is the set of channels within an annulus that must be set with the `local_radius` parameter. operator : "median" | "average", default: "median" @@ -51,10 +52,10 @@ class CommonReferenceRecording(BasePreprocessor): List of lists containing the channel ids for splitting the reference. The CMR, CAR, or referencing with respect to single channels are applied group-wise. However, this is not applied for the local CAR. It is useful when dealing with different channel groups, e.g. multiple tetrodes. - ref_channel_ids : list or str or int, default: None - If no "groups" are specified, all channels are referenced to "ref_channel_ids". If "groups" is provided, then a - list of channels to be applied to each group is expected. If "single" reference, a list of one channel or an - int is expected. + ref_channel_ids : list | str | int | None, default: None + If "global" reference, a list of channels to be used as reference. + If "single" reference, a list of one channel or a single channel id is expected. + If "groups" is provided, then a list of channels to be applied to each group is expected. 
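A condensed sketch of the new "global reference on a channel subset" behavior, mirroring the updated ``test_common_reference`` later in this patch::

    # Global median reference computed on a subset of channels; the recording
    # and channel ids follow the updated test fixture.
    import numpy as np
    from spikeinterface.core import generate_recording
    from spikeinterface.preprocessing import common_reference

    rec = generate_recording(durations=[1.0], num_channels=4)
    rec = rec.rename_channels(np.array(["a", "b", "c", "d"]))

    # median over channels a, b, c is subtracted from all four channels
    rec_cmr_subset = common_reference(
        rec, reference="global", operator="median", ref_channel_ids=["a", "b", "c"]
    )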
local_radius : tuple(int, int), default: (30, 55) Use in the local CAR implementation as the selecting annulus with the following format: @@ -82,10 +83,10 @@ def __init__( recording: BaseRecording, reference: Literal["global", "single", "local"] = "global", operator: Literal["median", "average"] = "median", - groups=None, - ref_channel_ids=None, - local_radius=(30, 55), - dtype=None, + groups: list | None = None, + ref_channel_ids: list | str | int | None = None, + local_radius: tuple[float, float] = (30.0, 55.0), + dtype: str | np.dtype | None = None, ): num_chans = recording.get_num_channels() neighbors = None @@ -96,7 +97,9 @@ def __init__( raise ValueError("'operator' must be either 'median', 'average'") if reference == "global": - pass + if ref_channel_ids is not None: + if not isinstance(ref_channel_ids, list): + raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list") elif reference == "single": assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'" if groups is not None: @@ -182,7 +185,10 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) if self.reference == "global": - shift = self.operator_func(traces, axis=1, keepdims=True) + if self.ref_channel_indices is None: + shift = self.operator_func(traces, axis=1, keepdims=True) + else: + shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True) re_referenced_traces = traces[:, channel_indices] - shift elif self.reference == "single": # single channel -> no need of operator diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index 1df9b21c81..8b37e7f4b9 100644 --- a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -11,7 +11,7 @@ def _generate_test_recording(): recording = generate_recording(durations=[1.0], num_channels=4) - recording = recording.channel_slice(recording.channel_ids, np.array(["a", "b", "c", "d"])) + recording = recording.rename_channels(np.array(["a", "b", "c", "d"])) return recording @@ -23,12 +23,14 @@ def recording(): def test_common_reference(recording): # Test simple case rec_cmr = common_reference(recording, reference="global", operator="median") + rec_cmr_ref = common_reference(recording, reference="global", operator="median", ref_channel_ids=["a", "b", "c"]) rec_car = common_reference(recording, reference="global", operator="average") rec_sin = common_reference(recording, reference="single", ref_channel_ids=["a"]) rec_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median") traces = recording.get_traces() assert np.allclose(traces, rec_cmr.get_traces() + np.median(traces, axis=1, keepdims=True), atol=0.01) + assert np.allclose(traces, rec_cmr_ref.get_traces() + np.median(traces[:, :3], axis=1, keepdims=True), atol=0.01) assert np.allclose(traces, rec_car.get_traces() + np.mean(traces, axis=1, keepdims=True), atol=0.01) assert not np.all(rec_sin.get_traces()[0]) assert np.allclose(rec_sin.get_traces()[:, 1], traces[:, 1] - traces[:, 0]) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index c3d1544869..04b731de4f 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ 
b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -8,13 +8,13 @@ def test_whiten(create_cache_folder): cache_folder = create_cache_folder - rec = generate_recording(num_channels=4) + rec = generate_recording(num_channels=4, seed=2205) print(rec.get_channel_locations()) random_chunk_kwargs = {} - W, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) - print(W) - print(M) + W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) + # print(W) + # print(M) with pytest.raises(AssertionError): W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None) @@ -41,6 +41,10 @@ def test_whiten(create_cache_folder): assert rec4.get_dtype() == "int16" assert rec4._kwargs["M"] is None + # test regularization : norm should be smaller + W2, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, regularize=True) + assert np.linalg.norm(W1) > np.linalg.norm(W2) + if __name__ == "__main__": test_whiten() diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 874d4304e3..96cf5e028f 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -7,6 +7,7 @@ from ..core import get_random_data_chunks, get_channel_distances from .filter import fix_dtype +from ..core.globals import get_global_job_kwargs class WhitenRecording(BasePreprocessor): @@ -40,6 +41,12 @@ class WhitenRecording(BasePreprocessor): M : 1d np.array or None, default: None Pre-computed means. M can be None when previously computed with apply_mean=False + regularize : bool, default: False + Boolean to decide if we want to regularize the covariance matrix, using a chosen method + of sklearn, specified in regularize_kwargs. Default is GraphicalLassoCV + regularize_kwargs : {'method' : 'GraphicalLassoCV'} + Dictionary of the parameters that could be provided to the method of sklearn, if + the covariance matrix needs to be regularized. 
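A hedged usage sketch of the new regularization options through the high-level ``whiten`` wrapper; choosing ``LedoitWolf`` instead of the default ``GraphicalLassoCV`` is only to illustrate ``regularize_kwargs``::

    # whiten() with covariance regularization; "method" can be any estimator class
    # name from sklearn.covariance, GraphicalLassoCV being the default.
    from spikeinterface.core import generate_recording
    from spikeinterface.preprocessing import whiten

    rec = generate_recording(num_channels=4, seed=2205)
    rec_w = whiten(
        rec,
        mode="global",
        regularize=True,
        regularize_kwargs={"method": "LedoitWolf"},
    )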
**random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function Returns @@ -55,6 +62,8 @@ def __init__( recording, dtype=None, apply_mean=False, + regularize=False, + regularize_kwargs=None, mode="global", radius_um=100.0, int_scale=None, @@ -75,7 +84,14 @@ def __init__( M = np.asarray(M) else: W, M = compute_whitening_matrix( - recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=eps + recording, + mode, + random_chunk_kwargs, + apply_mean, + radius_um=radius_um, + eps=eps, + regularize=regularize, + regularize_kwargs=regularize_kwargs, ) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -90,6 +106,8 @@ def __init__( mode=mode, radius_um=radius_um, apply_mean=apply_mean, + regularize=regularize, + regularize_kwargs=regularize_kwargs, int_scale=float(int_scale) if int_scale is not None else None, M=M.tolist() if M is not None else None, W=W.tolist(), @@ -129,7 +147,9 @@ def get_traces(self, start_frame, end_frame, channel_indices): whiten = define_function_from_class(source_class=WhitenRecording, name="whiten") -def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None): +def compute_whitening_matrix( + recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None, regularize=False, regularize_kwargs=None +): """ Compute whitening matrix @@ -152,7 +172,12 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r eps : float or None, default: None Small epsilon to regularize SVD. If None, the default is set to 1e-8, but if the data is float type and scaled down to very small values, eps is automatically set to a small fraction (1e-3) of the median of the squared data. - + regularize : bool, default: False + Boolean to decide if we want to regularize the covariance matrix, using a chosen method + of sklearn, specified in regularize_kwargs. Default is GraphicalLassoCV + regularize_kwargs : {'method' : 'GraphicalLassoCV'} + Dictionary of the parameters that could be provided to the method of sklearn, if + the covariance matrix needs to be regularized. Returns ------- W : 2D array @@ -162,7 +187,8 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r """ random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) - random_data = random_data.astype("float32") + + regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"} if apply_mean: M = np.mean(random_data, axis=0) @@ -172,8 +198,18 @@ def compute_whitening_matrix(recording, mode, random_chunk_kwargs, apply_mean, r M = None data = random_data - cov = data.T @ data - cov = cov / data.shape[0] + if not regularize: + cov = data.T @ data + cov = cov / data.shape[0] + else: + import sklearn.covariance + + method = regularize_kwargs.pop("method") + regularize_kwargs["assume_centered"] = True + estimator_class = getattr(sklearn.covariance, method) + estimator = estimator_class(**regularize_kwargs) + estimator.fit(data) + cov = estimator.covariance_ # Here we determine eps used below to avoid division by zero. 
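The added test asserts that regularization shrinks the whitening matrix; a minimal sketch of that check using the low-level function::

    # Whitening matrix with and without regularization; per the new test,
    # the regularized matrix is expected to have a smaller norm.
    import numpy as np
    from spikeinterface.core import generate_recording
    from spikeinterface.preprocessing.whiten import compute_whitening_matrix

    rec = generate_recording(num_channels=4, seed=2205)
    W1, M1 = compute_whitening_matrix(
        rec, "global", random_chunk_kwargs={}, apply_mean=False, radius_um=None
    )
    W2, M2 = compute_whitening_matrix(
        rec, "global", random_chunk_kwargs={}, apply_mean=False, radius_um=None, regularize=True
    )
    assert np.linalg.norm(W1) > np.linalg.norm(W2)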
# Typically we can assume that data is either unscaled integers or in units of diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 788044c0f1..3502d27548 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -15,7 +15,7 @@ import warnings -from spikeinterface.core import load_extractor, BaseRecordingSnippets +from spikeinterface.core import load_extractor, BaseRecordingSnippets, BaseRecording from spikeinterface.core.core_tools import check_json from spikeinterface.core.globals import get_global_job_kwargs from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs @@ -167,16 +167,20 @@ def params_description(cls): return p @classmethod - def set_params_to_folder(cls, recording, output_folder, new_params, verbose): + def set_params_to_folder( + cls, + recording: BaseRecording, + output_folder: str | Path, + new_params: dict, + verbose: bool, + ) -> dict: params = cls.default_params() + valid_parameters = params.keys() + invalid_parameters = [k for k in new_params.keys() if k not in valid_parameters] - # verify params are in list - bad_params = [] - for p in new_params.keys(): - if p not in params.keys(): - bad_params.append(p) - if len(bad_params) > 0: - raise AttributeError("Bad parameters: " + str(bad_params)) + if invalid_parameters: + error_msg = f"Invalid parameters: {invalid_parameters} \n" f"Valid parameters are: {valid_parameters}" + raise ValueError(error_msg) params.update(new_params) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 45cc93d0b6..be75877f02 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -147,7 +147,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We need to whiten before the template matching step, to boost the results # TODO add , regularize=True chen ready - recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32") + recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True) noise_levels = get_noise_levels(recording_w, return_scaled=False) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 2da950ceda..92fcda35d9 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -185,11 +185,11 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)): case_keys = list(self.cases.keys()) import pylab as plt - fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) + fig, axes = plt.subplots(ncols=1, nrows=3, figsize=figsize) for count, k in enumerate(("accuracy", "recall", "precision")): - ax = axs[count] + ax = axes[count] for key in case_keys: label = self.cases[key]["label"] @@ -211,7 +211,7 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): case_keys = list(self.cases.keys()) import pylab as plt - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) for count, key in enumerate(case_keys): @@ -234,21 +234,25 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): else: distances = 
sklearn.metrics.pairwise_distances(a, b, metric) - im = axs[0, count].imshow(distances, aspect="auto") - axs[0, count].set_title(metric) - fig.colorbar(im, ax=axs[0, count]) + im = axes[0, count].imshow(distances, aspect="auto") + axes[0, count].set_title(metric) + fig.colorbar(im, ax=axes[0, count]) label = self.cases[key]["label"] - axs[0, count].set_title(label) + axes[0, count].set_title(label) return fig - def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)): + def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5), axes=None): if case_keys is None: case_keys = list(self.cases.keys()) import pylab as plt - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + if axes is None: + fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + axes = axes.flatten() + else: + fig = None for count, key in enumerate(case_keys): @@ -287,13 +291,13 @@ def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5 elif metric == "agreement": for found, real in zip(matched_ids2[mask], unit_ids1[mask]): to_plot += [scores.at[real, found]] - axs[0, count].plot(snr_matched, to_plot, ".", label="matched") - axs[0, count].plot(snr_missed, np.zeros(len(snr_missed)), ".", c="r", label="missed") - axs[0, count].set_xlabel("snr") - axs[0, count].set_ylabel(metric) + axes[count].plot(snr_matched, to_plot, ".", label="matched") + axes[count].plot(snr_missed, np.zeros(len(snr_missed)), ".", c="r", label="missed") + axes[count].set_xlabel("snr") + axes[count].set_ylabel(metric) label = self.cases[key]["label"] - axs[0, count].set_title(label) - axs[0, count].legend() + axes[count].set_title(label) + axes[count].legend() return fig @@ -303,7 +307,7 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs case_keys = list(self.cases.keys()) import pylab as plt - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) for count, key in enumerate(case_keys): @@ -348,47 +352,61 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs elif metric == "agreement": for found, real in zip(matched_ids2[mask], unit_ids1[mask]): to_plot += [scores.at[real, found]] - axs[0, count].scatter(depth_matched, snr_matched, c=to_plot, label="matched") - axs[0, count].scatter(depth_missed, snr_missed, c=np.zeros(len(snr_missed)), label="missed") - axs[0, count].set_xlabel("depth") - axs[0, count].set_ylabel("snr") + elif metric in ["recall", "precision", "accuracy"]: + to_plot = result["gt_comparison"].get_performance()[metric].values + depth_matched = depth + snr_matched = metrics["snr"] + + im = axes[0, count].scatter(depth_matched, snr_matched, c=to_plot, label="matched") + im.set_clim(0, 1) + axes[0, count].scatter(depth_missed, snr_missed, c=np.zeros(len(snr_missed)), label="missed") + axes[0, count].set_xlabel("depth") + axes[0, count].set_ylabel("snr") label = self.cases[key]["label"] - axs[0, count].set_title(label) + axes[0, count].set_title(label) + if count > 0: + axes[0, count].set_ylabel("") + axes[0, count].set_yticks([], []) # axs[0, count].legend() + fig.subplots_adjust(right=0.85) + cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75]) + fig.colorbar(im, cax=cbar_ax, label=metric) + return fig - def plot_unit_losses(self, case_before, case_after, metric="agreement", figsize=None): - import pylab as plt + def 
plot_unit_losses(self, cases_before, cases_after, metric="agreement", figsize=None): - fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) + fig, axs = plt.subplots(ncols=len(cases_before), nrows=1, figsize=figsize) - for count, k in enumerate(("accuracy", "recall", "precision")): + for count, (case_before, case_after) in enumerate(zip(cases_before, cases_after)): ax = axs[count] - - # label = self.cases[case_after]["label"] - - # positions = self.get_result(case_before)["gt_comparison"].sorting1.get_property("gt_unit_locations") - dataset_key = self.cases[case_before]["dataset"] - rec, gt_sorting1 = self.datasets[dataset_key] + _, gt_sorting1 = self.datasets[dataset_key] positions = gt_sorting1.get_property("gt_unit_locations") analyzer = self.get_sorting_analyzer(case_before) metrics_before = analyzer.get_extension("quality_metrics").get_data() x = metrics_before["snr"].values - y_before = self.get_result(case_before)["gt_comparison"].get_performance()[k].values - y_after = self.get_result(case_after)["gt_comparison"].get_performance()[k].values - if count < 2: - ax.set_xticks([], []) - elif count == 2: - ax.set_xlabel("depth (um)") - im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), marker=".", s=50, cmap="copper") - fig.colorbar(im, ax=ax) - ax.set_title(k) + y_before = self.get_result(case_before)["gt_comparison"].get_performance()[metric].values + y_after = self.get_result(case_after)["gt_comparison"].get_performance()[metric].values + ax.set_ylabel("depth (um)") ax.set_ylabel("snr") + if count > 0: + ax.set_ylabel("") + ax.set_yticks([], []) + im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), cmap="coolwarm") + im.set_clim(-1, 1) + # fig.colorbar(im, ax=ax) + # ax.set_title(k) + + fig.subplots_adjust(right=0.85) + cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75]) + cbar = fig.colorbar(im, cax=cbar_ax, label=metric) + # cbar.set_clim(-1, 1) + return fig def plot_comparison_clustering( diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index cf91c8b873..ab1523d13a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -11,6 +11,9 @@ import numpy as np from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.sortingcomponents.tools import remove_empty_templates +from spikeinterface.core.recording_tools import get_noise_levels +from spikeinterface.core.sparsity import compute_sparsity class MatchingBenchmark(Benchmark): @@ -73,17 +76,15 @@ def plot_agreements(self, case_keys=None, figsize=None): ax.set_title(self.cases[key]["label"]) plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) - return fig - - def plot_performances_vs_snr(self, case_keys=None, figsize=None): + def plot_performances_vs_snr(self, case_keys=None, figsize=None, metrics=["accuracy", "recall", "precision"]): if case_keys is None: case_keys = list(self.cases.keys()) - fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) + fig, axs = plt.subplots(ncols=1, nrows=len(metrics), figsize=figsize, squeeze=False) - for count, k in enumerate(("accuracy", "recall", "precision")): + for count, k in enumerate(metrics): - ax = axs[count] + ax = axs[count, 0] for key in case_keys: label = self.cases[key]["label"] @@ -223,13 +224,13 @@ def plot_unit_counts(self, 
case_keys=None, figsize=None): plot_study_unit_counts(self, case_keys, figsize=figsize) - def plot_unit_losses(self, before, after, figsize=None): + def plot_unit_losses(self, before, after, metric=["precision"], figsize=None): - fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) + fig, axs = plt.subplots(ncols=1, nrows=len(metric), figsize=figsize, squeeze=False) - for count, k in enumerate(("accuracy", "recall", "precision")): + for count, k in enumerate(metric): - ax = axs[count] + ax = axs[0, count] label = self.cases[after]["label"] @@ -241,15 +242,20 @@ def plot_unit_losses(self, before, after, figsize=None): y_before = self.get_result(before)["gt_comparison"].get_performance()[k].values y_after = self.get_result(after)["gt_comparison"].get_performance()[k].values - if count < 2: - ax.set_xticks([], []) - elif count == 2: - ax.set_xlabel("depth (um)") - im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), marker=".", s=50, cmap="copper") - fig.colorbar(im, ax=ax) + # if count < 2: + # ax.set_xticks([], []) + # elif count == 2: + ax.set_xlabel("depth (um)") + im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), cmap="coolwarm") + fig.colorbar(im, ax=ax, label=k) + im.set_clim(-1, 1) ax.set_title(k) ax.set_ylabel("snr") + # fig.subplots_adjust(right=0.85) + # cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75]) + # cbar = fig.colorbar(im, cax=cbar_ax, label=metric) + # if count == 2: # ax.legend() return fig diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py index 062309b581..7d862343d2 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py @@ -196,6 +196,8 @@ def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5), detect_thres abs_threshold = -detect_threshold * noise_levels ax.plot([abs_threshold, abs_threshold], [ymin, ymax], "k--") + return fig + def plot_deltas_per_cells(self, case_keys=None, figsize=(15, 5)): if case_keys is None: diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 7dc3fad280..4d6dd43bce 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -10,9 +10,11 @@ from spikeinterface.core import SortingAnalyzer -from spikeinterface import load_extractor, split_job_kwargs, create_sorting_analyzer, load_sorting_analyzer + +from spikeinterface import load_extractor, create_sorting_analyzer, load_sorting_analyzer from spikeinterface.widgets import get_some_colors + import pickle _key_separator = "_-°°-_" diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 0872a6066c..facefac4c5 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -140,3 +140,14 @@ def remove_empty_templates(templates): probe=templates.probe, is_scaled=templates.is_scaled, ) + + +def sigmoid(x, x0, k, b): + return (1 / (1 + np.exp(-k * (x - x0)))) + b + + +def fit_sigmoid(xdata, ydata, p0=None): + from scipy.optimize import curve_fit + + popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) + return popt diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index a2c366851b..85043d0d12 100644 --- 
a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -30,9 +30,7 @@ def __init__( case_keys = list(study.cases.keys()) plot_data = dict( - study=study, - run_times=study.get_run_times(case_keys), - case_keys=case_keys, + study=study, run_times=study.get_run_times(case_keys), case_keys=case_keys, colors=study.get_colors() ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -48,8 +46,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): for i, key in enumerate(dp.case_keys): label = dp.study.cases[key]["label"] rt = dp.run_times.loc[key] - self.ax.bar(i, rt, width=0.8, label=label) - + self.ax.bar(i, rt, width=0.8, label=label, facecolor=dp.colors[key]) + self.ax.set_ylabel("run time (s)") self.ax.legend() @@ -167,6 +165,8 @@ def __init__( case_keys=case_keys, ) + self.colors = study.get_colors() + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): @@ -192,7 +192,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): label = study.cases[key]["label"] val = perfs.xs(key).loc[:, performance_name].values val = np.sort(val)[::-1] - ax.plot(val, label=label) + ax.plot(val, label=label, c=self.colors[key]) ax.set_title(performance_name) if count == len(dp.performance_names) - 1: ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8) @@ -207,7 +207,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): x = study.get_metrics(key).loc[:, metric_name].values y = perfs.xs(key).loc[:, performance_name].values label = study.cases[key]["label"] - ax.scatter(x, y, s=10, label=label) + ax.scatter(x, y, s=10, label=label, color=self.colors[key]) max_metric = max(max_metric, np.max(x)) ax.set_title(performance_name) ax.set_xlim(0, max_metric * 1.05)
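The two small helpers added to ``sortingcomponents.tools`` above can be exercised on synthetic data; the values below are purely illustrative::

    # Round-trip check of the new sigmoid / fit_sigmoid helpers: generate noisy sigmoid
    # samples, then recover the parameters via scipy's curve_fit wrapped by fit_sigmoid.
    import numpy as np
    from spikeinterface.sortingcomponents.tools import sigmoid, fit_sigmoid

    rng = np.random.default_rng(0)
    xdata = np.linspace(0.0, 20.0, 100)
    ydata = sigmoid(xdata, x0=10.0, k=0.8, b=0.1) + 0.01 * rng.normal(size=xdata.size)

    x0, k, b = fit_sigmoid(xdata, ydata, p0=[9.0, 1.0, 0.0])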