diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst
index dabad818f9..da94cf549c 100644
--- a/doc/how_to/index.rst
+++ b/doc/how_to/index.rst
@@ -7,3 +7,4 @@ How to guides
     get_started
     analyse_neuropixels
     handle_drift
+    load_matlab_data
diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst
new file mode 100644
index 0000000000..aaca718096
--- /dev/null
+++ b/doc/how_to/load_matlab_data.rst
@@ -0,0 +1,100 @@
+Exporting MATLAB Data to Binary & Loading in SpikeInterface
+===========================================================
+
+In this tutorial, we walk through exporting data from MATLAB in a binary format and then loading it with SpikeInterface in Python.
+
+Exporting Data from MATLAB
+--------------------------
+
+Begin by ensuring your data structure is correct: organize your data matrix so that the first dimension corresponds to samples/time and the second to channels.
+As an illustration, the following MATLAB code creates a random dataset and writes it to a binary file.
+
+.. code-block:: matlab
+
+    % Define the size of your data
+    numSamples = 1000;
+    numChannels = 384;
+
+    % Generate random data as an example
+    data = rand(numSamples, numChannels);
+
+    % Write the data to a binary file
+    fileID = fopen('your_data_as_a_binary.bin', 'wb');
+    fwrite(fileID, data, 'double');
+    fclose(fileID);
+
+.. note::
+
+    In your own script, replace the random data generation with your actual dataset.
+
+Loading Data in SpikeInterface
+------------------------------
+
+After executing the above MATLAB code, a binary file named ``your_data_as_a_binary.bin`` will be created in your MATLAB directory. To load this file in Python, you'll need its full path.
+
+Use the following Python script to load the binary data into SpikeInterface:
+
+.. code-block:: python
+
+    import spikeinterface as si
+    from pathlib import Path
+
+    # Define file path
+    # For Linux or macOS:
+    file_path = Path("/The/Path/To/Your/Data/your_data_as_a_binary.bin")
+    # For Windows:
+    # file_path = Path(r"c:\path\to\your\data\your_data_as_a_binary.bin")
+
+    # Confirm file existence
+    assert file_path.is_file(), f"Error: {file_path} is not a valid file. Please check the path."
+
+    # Define recording parameters
+    sampling_frequency = 30_000.0  # Adjust according to your MATLAB dataset
+    num_channels = 384  # Adjust according to your MATLAB dataset
+    dtype = "float64"  # MATLAB's double corresponds to Python's float64
+
+    # Load data using SpikeInterface
+    recording = si.read_binary(file_path, sampling_frequency=sampling_frequency,
+                               num_channels=num_channels, dtype=dtype)
+
+    # Confirm the data loaded correctly by checking that the shape matches the MATLAB data
+    print(recording.get_num_frames(), recording.get_num_channels())
+
+Follow the steps above to import your MATLAB data into SpikeInterface. Once loaded, you can use the full power of SpikeInterface for data processing, including filtering, spike sorting, and more.
+
+Common Pitfalls & Tips
+----------------------
+
+1. **Data Shape**: Make sure your MATLAB data matrix's first dimension is samples/time and the second is channels. If your time is in the second dimension, use ``time_axis=1`` in ``si.read_binary()`` (see the sketch after this list).
+2. **File Path**: Always double-check the Python file path.
+3. **Data Type Consistency**: Ensure data types between MATLAB and Python are consistent. MATLAB's ``double`` is equivalent to NumPy's ``float64``.
+4. **Sampling Frequency**: Set the appropriate sampling frequency in Hz for SpikeInterface.
+5. **Transition to Python**: Moving from MATLAB to Python can be challenging. For newcomers to Python, consider reviewing the `NumPy for MATLAB users <https://numpy.org/doc/stable/user/numpy-for-matlab-users.html>`_ guide.
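+
+If your MATLAB matrix was instead written channels-first (channels x samples), a minimal
+sketch of the corresponding load call could look like this (file name and parameter
+values are illustrative only):
+
+.. code-block:: python
+
+    import spikeinterface as si
+
+    # The binary file was written as (channels x samples),
+    # so tell read_binary that time runs along axis 1
+    recording = si.read_binary("your_data_as_a_binary.bin",
+                               sampling_frequency=30_000.0,
+                               num_channels=384,
+                               dtype="float64",
+                               time_axis=1)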
+
+Using gains and offsets for integer data
+----------------------------------------
+
+Raw data formats often store data as integer values for memory efficiency. To give these integers meaningful physical units, you can apply a gain and an offset.
+In SpikeInterface, you can use the ``gain_to_uV`` and ``offset_to_uV`` parameters, since traces are handled in microvolts (uV). Both parameters can be passed to the ``read_binary`` function.
+If your data in MATLAB is stored as ``int16``, and you know the gain and offset, you can use the following code to load the data:
+
+.. code-block:: python
+
+    sampling_frequency = 30_000.0  # Adjust according to your MATLAB dataset
+    num_channels = 384  # Adjust according to your MATLAB dataset
+    dtype_int = 'int16'  # Adjust according to your MATLAB dataset
+    gain_to_uV = 0.195  # Adjust according to your MATLAB dataset
+    offset_to_uV = 0  # Adjust according to your MATLAB dataset
+
+    recording = si.read_binary(file_path, sampling_frequency=sampling_frequency,
+                               num_channels=num_channels, dtype=dtype_int,
+                               gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV)
+
+    recording.get_traces(return_scaled=True)  # Return traces in microvolts (uV)
+
+
+This equips your recording object to convert the data to float values in uV, using the :code:`get_traces()` method with the :code:`return_scaled` parameter set to :code:`True`.
+
+.. note::
+
+    The gain and offset parameters are usually format-dependent, and you will need to find the correct values for your data format. You can load your data without gain and offset, but the traces will then be in integer values rather than uV.
diff --git a/doc/modules/core.rst b/doc/modules/core.rst
index fdc4d71fe7..976a82a4a3 100644
--- a/doc/modules/core.rst
+++ b/doc/modules/core.rst
@@ -547,8 +547,7 @@ workflow.
 In order to do this, one can use the :code:`Numpy*` classes, :py:class:`~spikeinterface.core.NumpyRecording`,
 :py:class:`~spikeinterface.core.NumpySorting`, :py:class:`~spikeinterface.core.NumpyEvent`, and
 :py:class:`~spikeinterface.core.NumpySnippets`. These objects behave exactly like normal SpikeInterface objects,
-but they are not bound to a file. This makes these objects *not dumpable*, so parallel processing is not supported.
-In order to make them *dumpable*, one can simply :code:`save()` them (see :ref:`save_load`).
+but they are not bound to a file.
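+
+As a quick illustration, a recording can be built directly from an in-memory NumPy array
+(a minimal sketch; the array shape and sampling frequency here are arbitrary):
+
+.. code-block:: python
+
+    import numpy as np
+    from spikeinterface.core import NumpyRecording
+
+    # traces must be (num_samples, num_channels); one array per segment
+    traces = np.random.randn(10_000, 4).astype("float32")
+    recording = NumpyRecording(traces_list=[traces], sampling_frequency=30_000.0)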
Also note the class :py:class:`~spikeinterface.core.SharedMemorySorting`, which is very similar to
:py:class:`~spikeinterface.core.NumpySorting` but with an underlying SharedMemory which is useful for
diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst
index 8c7c0a2cc3..447d83db52 100644
--- a/doc/modules/qualitymetrics.rst
+++ b/doc/modules/qualitymetrics.rst
@@ -25,9 +25,11 @@ For more details about each metric and its availability and use within SpikeInt
     :glob:

     qualitymetrics/amplitude_cutoff
+    qualitymetrics/amplitude_cv
     qualitymetrics/amplitude_median
     qualitymetrics/d_prime
     qualitymetrics/drift
+    qualitymetrics/firing_range
     qualitymetrics/firing_rate
     qualitymetrics/isi_violations
     qualitymetrics/isolation_distance
diff --git a/doc/modules/qualitymetrics/amplitude_cutoff.rst b/doc/modules/qualitymetrics/amplitude_cutoff.rst
index 9f747f8d40..a1e4d85d01 100644
--- a/doc/modules/qualitymetrics/amplitude_cutoff.rst
+++ b/doc/modules/qualitymetrics/amplitude_cutoff.rst
@@ -21,12 +21,12 @@ Example code

 .. code-block:: python

-    import spikeinterface.qualitymetrics as qm
+    import spikeinterface.qualitymetrics as sqm

     # It is also recommended to run `compute_spike_amplitudes(wvf_extractor)`
     # in order to use amplitudes from all spikes
-    fraction_missing = qm.compute_amplitude_cutoffs(wvf_extractor, peak_sign="neg")
-    # fraction_missing is a dict containing the units' IDs as keys,
+    fraction_missing = sqm.compute_amplitude_cutoffs(wvf_extractor, peak_sign="neg")
+    # fraction_missing is a dict containing the unit IDs as keys,
     # and their estimated fraction of missing spikes as values.

 Reference
diff --git a/doc/modules/qualitymetrics/amplitude_cv.rst b/doc/modules/qualitymetrics/amplitude_cv.rst
new file mode 100644
index 0000000000..13117b607c
--- /dev/null
+++ b/doc/modules/qualitymetrics/amplitude_cv.rst
@@ -0,0 +1,55 @@
+Amplitude CV (:code:`amplitude_cv_median`, :code:`amplitude_cv_range`)
+======================================================================
+
+
+Calculation
+-----------
+
+The amplitude CV (coefficient of variation) is a measure of the amplitude variability.
+It is computed as the ratio between the standard deviation and the amplitude mean.
+To obtain a better estimate of this measure, it is first computed separately for several temporal bins.
+Out of these values, the median and the range (percentile distance, by default between the
+5th and 95th percentiles) are computed.
+
+The computation requires either spike amplitudes (see :py:func:`~spikeinterface.postprocessing.compute_spike_amplitudes()`)
+or amplitude scalings (see :py:func:`~spikeinterface.postprocessing.compute_amplitude_scalings()`) to be pre-computed.
+
+
+Expectation and use
+-------------------
+
+The amplitude CV median is expected to be relatively low for well-isolated units, indicating a "stereotypical" spike shape.
+
+The amplitude CV range can be high in the presence of noise contamination, due to amplitude outliers like in
+the example below.
+
+.. image:: amplitudes.png
+    :width: 600
+
+
+Example code
+------------
+
+.. code-block:: python
+
+    import spikeinterface.qualitymetrics as sqm
+
+    # Make recording, sorting and wvf_extractor object for your data.
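+    # For instance (illustrative names only), wvf_extractor could be prepared with:
+    #   wvf_extractor = si.extract_waveforms(recording, sorting, folder="waveforms")
+    #   spost.compute_spike_amplitudes(wvf_extractor)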
+ # It is required to run `compute_spike_amplitudes(wvf_extractor)` or + # `compute_amplitude_scalings(wvf_extractor)` (if missing, values will be NaN) + amplitude_cv_median, amplitude_cv_range = sqm.compute_amplitude_cv_metrics(wvf_extractor) + # amplitude_cv_median and amplitude_cv_range are dicts containing the unit ids as keys, + # and their amplitude_cv metrics as values. + + + +References +---------- + +.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_amplitude_cv_metrics + + +Literature +---------- + +Designed by Simon Musall and adapted to SpikeInterface by Alessio Buccino. diff --git a/doc/modules/qualitymetrics/amplitude_median.rst b/doc/modules/qualitymetrics/amplitude_median.rst index ffc45d1cf6..3ac52560e8 100644 --- a/doc/modules/qualitymetrics/amplitude_median.rst +++ b/doc/modules/qualitymetrics/amplitude_median.rst @@ -20,12 +20,12 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm # It is also recommended to run `compute_spike_amplitudes(wvf_extractor)` # in order to use amplitude values from all spikes. - amplitude_medians = qm.compute_amplitude_medians(wvf_extractor) - # amplitude_medians is a dict containing the units' IDs as keys, + amplitude_medians = sqm.compute_amplitude_medians(wvf_extractor) + # amplitude_medians is a dict containing the unit IDs as keys, # and their estimated amplitude medians as values. Reference diff --git a/doc/modules/qualitymetrics/amplitudes.png b/doc/modules/qualitymetrics/amplitudes.png new file mode 100644 index 0000000000..0ee4dd1eda Binary files /dev/null and b/doc/modules/qualitymetrics/amplitudes.png differ diff --git a/doc/modules/qualitymetrics/d_prime.rst b/doc/modules/qualitymetrics/d_prime.rst index abb8c1dc74..e3bd61c580 100644 --- a/doc/modules/qualitymetrics/d_prime.rst +++ b/doc/modules/qualitymetrics/d_prime.rst @@ -32,9 +32,9 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm - d_prime = qm.lda_metrics(all_pcs, all_labels, 0) + d_prime = sqm.lda_metrics(all_pcs, all_labels, 0) Reference diff --git a/doc/modules/qualitymetrics/drift.rst b/doc/modules/qualitymetrics/drift.rst index 0a852f80af..ae52f7f883 100644 --- a/doc/modules/qualitymetrics/drift.rst +++ b/doc/modules/qualitymetrics/drift.rst @@ -40,11 +40,12 @@ Example code .. code-block:: python - import spikeinterface.qualitymetrics as qm + import spikeinterface.qualitymetrics as sqm + # Make recording, sorting and wvf_extractor object for your data. # It is required to run `compute_spike_locations(wvf_extractor)` # (if missing, values will be NaN) - drift_ptps, drift_stds, drift_mads = qm.compute_drift_metrics(wvf_extractor, peak_sign="neg") + drift_ptps, drift_stds, drift_mads = sqm.compute_drift_metrics(wvf_extractor, peak_sign="neg") # drift_ptps, drift_stds, and drift_mads are dict containing the units' ID as keys, # and their metrics as values. diff --git a/doc/modules/qualitymetrics/firing_range.rst b/doc/modules/qualitymetrics/firing_range.rst new file mode 100644 index 0000000000..925539e9c6 --- /dev/null +++ b/doc/modules/qualitymetrics/firing_range.rst @@ -0,0 +1,40 @@ +Firing range (:code:`firing_range`) +=================================== + + +Calculation +----------- + +The firing range indicates the dispersion of the firing rate of a unit across the recording. 
It is computed by
+taking the difference between the 95th and 5th percentiles of the firing rates computed over short time bins (e.g. 10 s).
+
+
+
+Expectation and use
+-------------------
+
+Very high firing ranges, outside of the physiological range, might indicate noise contamination.
+
+
+Example code
+------------
+
+.. code-block:: python
+
+    import spikeinterface.qualitymetrics as sqm
+
+    # Make recording, sorting and wvf_extractor object for your data.
+    firing_range = sqm.compute_firing_ranges(wvf_extractor)
+    # firing_range is a dict containing the unit IDs as keys,
+    # and their firing_range values (in Hz).
+
+References
+----------
+
+.. autofunction:: spikeinterface.qualitymetrics.misc_metrics.compute_firing_ranges
+
+
+Literature
+----------
+
+Designed by Simon Musall and adapted to SpikeInterface by Alessio Buccino.
diff --git a/doc/modules/qualitymetrics/firing_rate.rst b/doc/modules/qualitymetrics/firing_rate.rst
index eddef3e48f..c0e15d7c2e 100644
--- a/doc/modules/qualitymetrics/firing_rate.rst
+++ b/doc/modules/qualitymetrics/firing_rate.rst
@@ -37,11 +37,11 @@ With SpikeInterface:

 .. code-block:: python

-    import spikeinterface.qualitymetrics as qm
+    import spikeinterface.qualitymetrics as sqm

     # Make recording, sorting and wvf_extractor object for your data.

-    firing_rate = qm.compute_firing_rates(wvf_extractor)
-    # firing_rate is a dict containing the units' IDs as keys,
+    firing_rate = sqm.compute_firing_rates(wvf_extractor)
+    # firing_rate is a dict containing the unit IDs as keys,
     # and their firing rates across segments as values (in Hz).

 References
diff --git a/doc/modules/qualitymetrics/isi_violations.rst b/doc/modules/qualitymetrics/isi_violations.rst
index 947e7d4938..725d9b0fd6 100644
--- a/doc/modules/qualitymetrics/isi_violations.rst
+++ b/doc/modules/qualitymetrics/isi_violations.rst
@@ -77,11 +77,11 @@ With SpikeInterface:

 .. code-block:: python

-    import spikeinterface.qualitymetrics as qm
+    import spikeinterface.qualitymetrics as sqm

     # Make recording, sorting and wvf_extractor object for your data.

-    isi_violations_ratio, isi_violations_count = qm.compute_isi_violations(wvf_extractor, isi_threshold_ms=1.0)
+    isi_violations_ratio, isi_violations_count = sqm.compute_isi_violations(wvf_extractor, isi_threshold_ms=1.0)

 References
 ----------
diff --git a/doc/modules/qualitymetrics/presence_ratio.rst b/doc/modules/qualitymetrics/presence_ratio.rst
index e4de2248bd..5a420c8ccf 100644
--- a/doc/modules/qualitymetrics/presence_ratio.rst
+++ b/doc/modules/qualitymetrics/presence_ratio.rst
@@ -23,12 +23,12 @@ Example code

 .. code-block:: python

-    import spikeinterface.qualitymetrics as qm
+    import spikeinterface.qualitymetrics as sqm

     # Make recording, sorting and wvf_extractor object for your data.

-    presence_ratio = qm.compute_presence_ratios(wvf_extractor)
-    # presence_ratio is a dict containing the units' IDs as keys
+    presence_ratio = sqm.compute_presence_ratios(wvf_extractor)
+    # presence_ratio is a dict containing the unit IDs as keys
     # and their presence ratio (between 0 and 1) as values.

 Links to original implementations
diff --git a/doc/modules/qualitymetrics/sliding_rp_violations.rst b/doc/modules/qualitymetrics/sliding_rp_violations.rst
index 843242c1e8..de68c3a92f 100644
--- a/doc/modules/qualitymetrics/sliding_rp_violations.rst
+++ b/doc/modules/qualitymetrics/sliding_rp_violations.rst
@@ -27,11 +27,11 @@ With SpikeInterface:

 .. code-block:: python

-    import spikeinterface.qualitymetrics as qm
+    import spikeinterface.qualitymetrics as sqm

     # Make recording, sorting and wvf_extractor object for your data.

-    contamination = qm.compute_sliding_rp_violations(wvf_extractor, bin_size_ms=0.25)
+    contamination = sqm.compute_sliding_rp_violations(wvf_extractor, bin_size_ms=0.25)

 References
 ----------
diff --git a/doc/modules/qualitymetrics/snr.rst b/doc/modules/qualitymetrics/snr.rst
index 288ab60515..b88d3291be 100644
--- a/doc/modules/qualitymetrics/snr.rst
+++ b/doc/modules/qualitymetrics/snr.rst
@@ -41,12 +41,12 @@ With SpikeInterface:

 .. code-block:: python

-    import spikeinterface.qualitymetrics as qm
+    import spikeinterface.qualitymetrics as sqm

     # Make recording, sorting and wvf_extractor object for your data.

-    SNRs = qm.compute_snrs(wvf_extractor)
-    # SNRs is a dict containing the units' IDs as keys and their SNRs as values.
+    SNRs = sqm.compute_snrs(wvf_extractor)
+    # SNRs is a dict containing the unit IDs as keys and their SNRs as values.

 Links to original implementations
 ---------------------------------
diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst
index 2f566bf8a7..0750940199 100644
--- a/doc/modules/qualitymetrics/synchrony.rst
+++ b/doc/modules/qualitymetrics/synchrony.rst
@@ -27,9 +27,9 @@ Example code

 .. code-block:: python

-    import spikeinterface.qualitymetrics as qm
+    import spikeinterface.qualitymetrics as sqm

     # Make recording, sorting and wvf_extractor object for your data.

-    synchrony = qm.compute_synchrony_metrics(wvf_extractor, synchrony_sizes=(2, 4, 8))
+    synchrony = sqm.compute_synchrony_metrics(wvf_extractor, synchrony_sizes=(2, 4, 8))
     # synchrony is a tuple of dicts with the synchrony metrics for each unit
diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py
index af410255b9..e0c98cd772 100644
--- a/src/spikeinterface/comparison/hybrid.py
+++ b/src/spikeinterface/comparison/hybrid.py
@@ -39,7 +39,7 @@ class HybridUnitsRecording(InjectTemplatesRecording):
         The refractory period of the injected spike train (in ms).
     injected_sorting_folder: str | Path | None
         If given, the injected sorting is saved to this folder.
-        It must be specified if injected_sorting is None or not dumpable.
+        It must be specified if injected_sorting is None or not serializable to file.

     Returns
     -------
@@ -84,7 +84,8 @@ def __init__(
         )
         # save injected sorting if necessary
         self.injected_sorting = injected_sorting
-        if not self.injected_sorting.check_if_json_serializable():
+        if not self.injected_sorting.check_serializablility("json"):
+            # TODO later: also allow pickle
             assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object"
             self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)

@@ -137,7 +138,7 @@ class HybridSpikesRecording(InjectTemplatesRecording):
         this refractory period.
     injected_sorting_folder: str | Path | None
         If given, the injected sorting is saved to this folder.
-        It must be specified if injected_sorting is None or not dumpable.
+        It must be specified if injected_sorting is None or not serializable to file.
Returns ------- @@ -180,7 +181,8 @@ def __init__( self.injected_sorting = injected_sorting # save injected sorting if necessary - if not self.injected_sorting.check_if_json_serializable(): + if not self.injected_sorting.check_serializablility("json"): + # TODO later : also use pickle assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object" self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder) diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index d418b92ab8..f44e14c4c4 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -189,8 +189,8 @@ def save_to_folder(self, save_folder): stacklevel=2, ) for sorting in self.object_list: - assert ( - sorting.check_if_json_serializable() + assert sorting.check_serializablility( + "json" ), "MultiSortingComparison.save_to_folder() need json serializable sortings" save_folder = Path(save_folder) @@ -259,7 +259,8 @@ def __init__( BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids) - self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = True if len(unit_ids) > 0: for k in ("agreement_number", "avg_agreement", "unit_ids"): diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 87c0805630..e8b3232e13 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -57,8 +57,7 @@ def __init__(self, main_ids: Sequence) -> None: # * number of units for sorting self._properties = {} - self._is_dumpable = True - self._is_json_serializable = True + self._serializablility = {"memory": True, "json": True, "pickle": True} # extractor specific list of pip extra requirements self.extra_requirements = [] @@ -471,24 +470,33 @@ def clone(self) -> "BaseExtractor": clone = BaseExtractor.from_dict(d) return clone - def check_if_dumpable(self): - """Check if the object is dumpable, including nested objects. + def check_serializablility(self, type): + kwargs = self._kwargs + for value in kwargs.values(): + # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors + if isinstance(value, BaseExtractor): + if not value.check_serializablility(type=type): + return False + elif isinstance(value, list): + for v in value: + if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type): + return False + elif isinstance(value, dict): + for v in value.values(): + if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type): + return False + return self._serializablility[type] + + def check_if_memory_serializable(self): + """ + Check if the object is serializable to memory with pickle, including nested objects. Returns ------- bool - True if the object is dumpable, False otherwise. + True if the object is memory serializable, False otherwise. 
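+
+        Examples
+        --------
+        A sketch of typical usage (``recording`` stands for any extractor instance):
+
+        >>> recording.check_if_memory_serializable()
+        True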
""" - kwargs = self._kwargs - for value in kwargs.values(): - # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - if isinstance(value, BaseExtractor): - return value.check_if_dumpable() - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_if_dumpable() for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_if_dumpable() for k, v in value.items()]) - return self._is_dumpable + return self.check_serializablility("memory") def check_if_json_serializable(self): """ @@ -499,16 +507,13 @@ def check_if_json_serializable(self): bool True if the object is json serializable, False otherwise. """ - kwargs = self._kwargs - for value in kwargs.values(): - # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - if isinstance(value, BaseExtractor): - return value.check_if_json_serializable() - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_if_json_serializable() for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_if_json_serializable() for k, v in value.items()]) - return self._is_json_serializable + # we keep this for backward compatilibity or not ???? + # is this needed ??? I think no. + return self.check_serializablility("json") + + def check_if_pickle_serializable(self): + # is this needed ??? I think no. + return self.check_serializablility("pickle") @staticmethod def _get_file_path(file_path: Union[str, Path], extensions: Sequence) -> Path: @@ -557,7 +562,7 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No if str(file_path).endswith(".json"): self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata) elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): - self.dump_to_pickle(file_path, relative_to=relative_to, folder_metadata=folder_metadata) + self.dump_to_pickle(file_path, folder_metadata=folder_metadata) else: raise ValueError("Dump: file must .json or .pkl") @@ -576,7 +581,7 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. """ - assert self.check_if_json_serializable(), "The extractor is not json serializable" + assert self.check_serializablility("json"), "The extractor is not json serializable" # Writing paths as relative_to requires recursively expanding the dict if relative_to: @@ -616,7 +621,7 @@ def dump_to_pickle( folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. 
""" - assert self.check_if_dumpable(), "The extractor is not dumpable" + assert self.check_if_pickle_serializable(), "The extractor is not serializable to file with pickle" dump_dict = self.to_dict( include_annotations=True, @@ -653,8 +658,8 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo d = pickle.load(f) else: raise ValueError(f"Impossible to load {file_path}") - if "warning" in d and "not dumpable" in d["warning"]: - print("The extractor was not dumpable") + if "warning" in d: + print("The extractor was not serializable to file") return None extractor = BaseExtractor.from_dict(d, base_folder=base_folder) return extractor @@ -814,10 +819,12 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): # dump provenance provenance_file = folder / f"provenance.json" - if self.check_if_json_serializable(): + if self.check_serializablility("json"): self.dump(provenance_file) else: - provenance_file.write_text(json.dumps({"warning": "the provenace is not dumpable!!!"}), encoding="utf8") + provenance_file.write_text( + json.dumps({"warning": "the provenace is not json serializable!!!"}), encoding="utf8" + ) self.save_metadata_to_folder(folder) @@ -911,7 +918,7 @@ def save_to_zarr( zarr_root = zarr.open(zarr_path_init, mode="w", storage_options=storage_options) - if self.check_if_dumpable(): + if self.check_if_json_serializable(): zarr_root.attrs["provenance"] = check_json(self.to_dict()) else: zarr_root.attrs["provenance"] = None diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 07837bcef7..eeb1e8af60 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1056,6 +1056,8 @@ def __init__( dtype = parent_recording.dtype if parent_recording is not None else templates.dtype BaseRecording.__init__(self, sorting.get_sampling_frequency(), channel_ids, dtype) + # Important : self._serializablility is not change here because it will depend on the sorting parents itself. + n_units = len(sorting.unit_ids) assert len(templates) == n_units self.spike_vector = sorting.to_spike_vector() @@ -1431,5 +1433,7 @@ def generate_ground_truth_recording( ) recording.annotate(is_filtered=True) recording.set_probe(probe, in_place=True) + recording.set_channel_gains(1.0) + recording.set_channel_offsets(0.0) return recording, sorting diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index c0ee77d2fd..84ee502c14 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -167,11 +167,11 @@ def ensure_n_jobs(recording, n_jobs=1): print(f"Python {sys.version} does not support parallel processing") n_jobs = 1 - if not recording.check_if_dumpable(): + if not recording.check_if_memory_serializable(): if n_jobs != 1: raise RuntimeError( - "Recording is not dumpable and can't be processed in parallel. " - "You can use the `recording.save()` function to make it dumpable or set 'n_jobs' to 1." + "Recording is not serializable to memory and can't be processed in parallel. " + "You can use the `rec = recording.save(folder=...)` function or set 'n_jobs' to 1." 
            )

    return n_jobs
diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py
index d5663156c7..3d7ec6cd1a 100644
--- a/src/spikeinterface/core/numpyextractors.py
+++ b/src/spikeinterface/core/numpyextractors.py
@@ -64,7 +64,8 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
            assert len(t_starts) == len(traces_list), "t_starts must be a list of same size than traces_list"
            t_starts = [float(t_start) for t_start in t_starts]

-        self._is_json_serializable = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

        for i, traces in enumerate(traces_list):
            if t_starts is None:
@@ -126,8 +127,10 @@ def __init__(self, spikes, sampling_frequency, unit_ids):
        """ """
        BaseSorting.__init__(self, sampling_frequency, unit_ids)

-        self._is_dumpable = True
-        self._is_json_serializable = False
+        self._serializablility["memory"] = True
+        self._serializablility["json"] = False
+        # theoretically this should be False, but to keep generators simple we still need it
+        self._serializablility["pickle"] = True

        if spikes.size == 0:
            nseg = 1
@@ -357,8 +360,10 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_
        assert shape[0] > 0, "SharedMemorySorting only supported with no empty sorting"

        BaseSorting.__init__(self, sampling_frequency, unit_ids)

-        self._is_dumpable = True
-        self._is_json_serializable = False
+
+        self._serializablility["memory"] = True
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

        self.shm = SharedMemory(shm_name, create=False)
        self.shm_spikes = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf)
@@ -516,8 +521,9 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore
            dtype=dtype,
        )

-        self._is_dumpable = False
-        self._is_json_serializable = False
+        self._serializablility["memory"] = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

        for snippets, spikesframes in zip(snippets_list, spikesframes_list):
            snp_segment = NumpySnippetsSegment(snippets, spikesframes)
diff --git a/src/spikeinterface/core/old_api_utils.py b/src/spikeinterface/core/old_api_utils.py
index 1ff31127f4..879700cc15 100644
--- a/src/spikeinterface/core/old_api_utils.py
+++ b/src/spikeinterface/core/old_api_utils.py
@@ -181,9 +181,10 @@ def __init__(self, oldapi_recording_extractor):
            dtype=oldapi_recording_extractor.get_dtype(return_scaled=False),
        )

-        # set _is_dumpable to False to use dumping mechanism of old extractor
-        self._is_dumpable = False
-        self._is_json_serializable = False
+        # set to False to use the dumping mechanism of the old extractor
+        self._serializablility["memory"] = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

        self.annotate(is_filtered=oldapi_recording_extractor.is_filtered)

@@ -268,8 +269,9 @@ def __init__(self, oldapi_sorting_extractor):
        sorting_segment = OldToNewSortingSegment(oldapi_sorting_extractor)
        self.add_sorting_segment(sorting_segment)

-        self._is_dumpable = False
-        self._is_json_serializable = False
+        self._serializablility["memory"] = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

        # add old properties
        copy_properties(oldapi_extractor=oldapi_sorting_extractor, new_extractor=self)
diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py
index ea1a9cf0d2..a944be3da0 100644
--- a/src/spikeinterface/core/tests/test_base.py
+++ b/src/spikeinterface/core/tests/test_base.py
@@ -31,39 +31,39 @@ def make_nested_extractors(extractor):
    )

-def test_check_if_dumpable():
+def test_check_if_memory_serializable():
    test_extractor = generate_recording(seed=0, durations=[2])

-    # make a list of dumpable objects
-    extractors_dumpable = make_nested_extractors(test_extractor)
-    for extractor in extractors_dumpable:
-        assert extractor.check_if_dumpable()
+    # make a list of memory serializable objects
+    extractors_mem_serializable = make_nested_extractors(test_extractor)
+    for extractor in extractors_mem_serializable:
+        assert extractor.check_if_memory_serializable()

-    # make not dumpable
-    test_extractor._is_dumpable = False
-    extractors_not_dumpable = make_nested_extractors(test_extractor)
-    for extractor in extractors_not_dumpable:
-        assert not extractor.check_if_dumpable()
+    # make not memory serializable
+    test_extractor._serializablility["memory"] = False
+    extractors_not_mem_serializable = make_nested_extractors(test_extractor)
+    for extractor in extractors_not_mem_serializable:
+        assert not extractor.check_if_memory_serializable()

-def test_check_if_json_serializable():
+def test_check_if_serializable():
    test_extractor = generate_recording(seed=0, durations=[2])

-    # make a list of dumpable objects
-    test_extractor._is_json_serializable = True
+    # make a list of json serializable objects
+    test_extractor._serializablility["json"] = True
    extractors_json_serializable = make_nested_extractors(test_extractor)
    for extractor in extractors_json_serializable:
        print(extractor)
-        assert extractor.check_if_json_serializable()
+        assert extractor.check_serializablility("json")

-    # make not dumpable
-    test_extractor._is_json_serializable = False
+    # make a list of non-json-serializable objects
+    test_extractor._serializablility["json"] = False
    extractors_not_json_serializable = make_nested_extractors(test_extractor)
    for extractor in extractors_not_json_serializable:
        print(extractor)
-        assert not extractor.check_if_json_serializable()
+        assert not extractor.check_serializablility("json")

 if __name__ == "__main__":
-    test_check_if_dumpable()
-    test_check_if_json_serializable()
+    test_check_if_memory_serializable()
+    test_check_if_serializable()
diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py
index a3cd0caa92..223b2a8a3a 100644
--- a/src/spikeinterface/core/tests/test_core_tools.py
+++ b/src/spikeinterface/core/tests/test_core_tools.py
@@ -142,7 +142,6 @@ def test_write_memory_recording():
    recording = NoiseGeneratorRecording(
        num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated"
    )
-    # make dumpable
    recording = recording.save()

    # write with loop
diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py
index 7d7af6025b..a904e4dd32 100644
--- a/src/spikeinterface/core/tests/test_job_tools.py
+++ b/src/spikeinterface/core/tests/test_job_tools.py
@@ -36,7 +36,7 @@ def test_ensure_n_jobs():
    n_jobs = ensure_n_jobs(recording, n_jobs=1)
    assert n_jobs == 1

-    # dumpable
+    # check serializable
    n_jobs = ensure_n_jobs(recording.save(), n_jobs=-1)
    assert n_jobs > 1

@@ -45,7 +45,7 @@ def test_ensure_chunk_size():
    recording = generate_recording(num_channels=2)
    dtype = recording.get_dtype()
    assert dtype == "float32"
-    # make dumpable
+    # make serializable
    recording = recording.save()

    chunk_size = ensure_chunk_size(recording, total_memory="512M", chunk_size=None, chunk_memory=None, n_jobs=2)

@@ -90,7 +90,7 @@ def init_func(arg1, arg2, arg3):
 def test_ChunkRecordingExecutor():
    recording = generate_recording(num_channels=2)
-    # make dumpable
+    # make serializable
    recording = recording.save()

    init_args = "a", 120, "yep"
diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py
index 473648c5ec..1c491bd7a6 100644
--- a/src/spikeinterface/core/tests/test_jsonification.py
+++ b/src/spikeinterface/core/tests/test_jsonification.py
@@ -142,9 +142,11 @@ def __init__(self, attribute, other_extractor=None, extractor_list=None, extract
        self.extractor_list = extractor_list
        self.extractor_dict = extractor_dict

+        BaseExtractor.__init__(self, main_ids=["1", "2"])
        # this is already the case by default
-        self._is_dumpable = True
-        self._is_json_serializable = True
+        self._serializablility["memory"] = True
+        self._serializablility["json"] = True
+        self._serializablility["pickle"] = True

        self._kwargs = {
            "attribute": attribute,
@@ -195,3 +197,8 @@ def test_encoding_numpy_scalars_within_nested_extractors_list(nested_extractor_l

 def test_encoding_numpy_scalars_within_nested_extractors_dict(nested_extractor_dict):
    json.dumps(nested_extractor_dict, cls=SIJsonEncoder)
+
+
+if __name__ == "__main__":
+    nested_extractor_instance = nested_extractor()
+    test_encoding_numpy_scalars_within_nested_extractors(nested_extractor_instance)
diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py
index 107ef5f180..2bbf5e9b0f 100644
--- a/src/spikeinterface/core/tests/test_waveform_extractor.py
+++ b/src/spikeinterface/core/tests/test_waveform_extractor.py
@@ -6,7 +6,13 @@

 import zarr

-from spikeinterface.core import generate_recording, generate_sorting, NumpySorting, ChannelSparsity
+from spikeinterface.core import (
+    generate_recording,
+    generate_sorting,
+    NumpySorting,
+    ChannelSparsity,
+    generate_ground_truth_recording,
+)
 from spikeinterface import WaveformExtractor, BaseRecording, extract_waveforms, load_waveforms
 from spikeinterface.core.waveform_extractor import precompute_sparsity

@@ -309,7 +315,7 @@ def test_recordingless():
    recording = recording.save(folder=cache_folder / "recording1")
    sorting = sorting.save(folder=cache_folder / "sorting1")

-    # recording and sorting are not dumpable
+    # recording and sorting are not serializable
    wf_folder = cache_folder / "wf_recordingless"

    # save with relative paths
@@ -510,10 +516,44 @@ def test_compute_sparsity():
    print(sparsity)

+def test_non_json_object():
+    recording, sorting = generate_ground_truth_recording(
+        durations=[30, 40],
+        sampling_frequency=30000.0,
+        num_channels=32,
+        num_units=5,
+    )
+
+    # the recording is not saved, to keep it in memory
+    sorting = sorting.save()
+
+    wf_folder = cache_folder / "test_waveform_extractor"
+    if wf_folder.is_dir():
+        shutil.rmtree(wf_folder)
+
+    we = extract_waveforms(
+        recording,
+        sorting,
+        wf_folder,
+        mode="folder",
+        sparsity=None,
+        sparse=False,
+        ms_before=1.0,
+        ms_after=1.6,
+        max_spikes_per_unit=50,
+        n_jobs=4,
+        chunk_size=30000,
+        progress_bar=True,
+    )
+
+    # This used to fail because of json
+    we = load_waveforms(wf_folder)
+
+
 if __name__ == "__main__":
-    test_WaveformExtractor()
+    # test_WaveformExtractor()
    # test_extract_waveforms()
-    # test_sparsity()
    # test_portability()
    # test_recordingless()
    # test_compute_sparsity()
+    test_non_json_object()
diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py
index 6881ab3ec5..2710ff1338 100644
--- a/src/spikeinterface/core/waveform_extractor.py
+++ b/src/spikeinterface/core/waveform_extractor.py
@@ -159,11 +159,20 @@ def load_from_folder(
            else:
                rec_attributes["probegroup"] = None
        else:
-            try:
-                recording = load_extractor(folder / "recording.json", base_folder=folder)
-                rec_attributes = None
-            except:
+            recording = None
+            if (folder / "recording.json").exists():
+                try:
+                    recording = load_extractor(folder / "recording.json", base_folder=folder)
+                except:
+                    pass
+            elif (folder / "recording.pickle").exists():
+                try:
+                    recording = load_extractor(folder / "recording.pickle")
+                except:
+                    pass
+            if recording is None:
                raise Exception("The recording could not be loaded. You can use the `with_recording=False` argument")
+            rec_attributes = None

        if sorting is None:
            sorting = load_extractor(folder / "sorting.json", base_folder=folder)
@@ -271,14 +280,22 @@ def create(
        else:
            relative_to = None

-        if recording.check_if_json_serializable():
+        if recording.check_serializablility("json"):
            recording.dump(folder / "recording.json", relative_to=relative_to)
-        if sorting.check_if_json_serializable():
+        elif recording.check_serializablility("pickle"):
+            # In this case we lose the relative_to!!
+            recording.dump(folder / "recording.pickle")
+
+        if sorting.check_serializablility("json"):
            sorting.dump(folder / "sorting.json", relative_to=relative_to)
+        elif sorting.check_serializablility("pickle"):
+            # In this case we lose the relative_to!!
+            # TODO later: dump to pickle should dump the dictionary, so relative_to could be put back
+            sorting.dump(folder / "sorting.pickle")
        else:
            warn(
-                "Sorting object is not dumpable, which might result in downstream errors for "
-                "parallel processing. To make the sorting dumpable, use the `sorting.save()` function."
+                "Sorting object is not serializable to file, which might result in downstream errors for "
+                "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function."
            )

        # dump some attributes of the recording for the mode with_recording=False at next load
@@ -879,14 +896,19 @@ def save(
        (folder / "params.json").write_text(json.dumps(check_json(self._params), indent=4), encoding="utf8")

        if self.has_recording():
-            if self.recording.check_if_json_serializable():
+            if self.recording.check_serializablility("json"):
                self.recording.dump(folder / "recording.json", relative_to=relative_to)
-        if self.sorting.check_if_json_serializable():
+            elif self.recording.check_serializablility("pickle"):
+                self.recording.dump(folder / "recording.pickle")
+
+        if self.sorting.check_serializablility("json"):
            self.sorting.dump(folder / "sorting.json", relative_to=relative_to)
+        elif self.sorting.check_serializablility("pickle"):
+            self.sorting.dump(folder / "sorting.pickle")
        else:
            warn(
-                "Sorting object is not dumpable, which might result in downstream errors for "
-                "parallel processing. To make the sorting dumpable, use the `sorting.save()` function."
+                "Sorting object is not serializable to file, which might result in downstream errors for "
+                "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function."
            )

        # dump some attributes of the recording for the mode with_recording=False at next load
@@ -931,16 +953,16 @@ def save(
        # write metadata
        zarr_root.attrs["params"] = check_json(self._params)
        if self.has_recording():
-            if self.recording.check_if_json_serializable():
+            if self.recording.check_serializablility("json"):
                rec_dict = self.recording.to_dict(relative_to=relative_to, recursive=True)
                zarr_root.attrs["recording"] = check_json(rec_dict)
-        if self.sorting.check_if_json_serializable():
+        if self.sorting.check_serializablility("json"):
            sort_dict = self.sorting.to_dict(relative_to=relative_to, recursive=True)
            zarr_root.attrs["sorting"] = check_json(sort_dict)
        else:
            warn(
-                "Sorting object is not dumpable, which might result in downstream errors for "
-                "parallel processing. To make the sorting dumpable, use the `sorting.save()` function."
+                "Sorting object is not json serializable, which might result in downstream errors for "
+                "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function."
            )
        recording_info = zarr_root.create_group("recording_info")
        recording_info.attrs["recording_attributes"] = check_json(rec_attributes)
diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py
index c92861a8bf..ebc810b953 100644
--- a/src/spikeinterface/exporters/to_phy.py
+++ b/src/spikeinterface/exporters/to_phy.py
@@ -35,6 +35,7 @@ def export_to_phy(
    template_mode: str = "median",
    dtype: Optional[npt.DTypeLike] = None,
    verbose: bool = True,
+    use_relative_path: bool = False,
    **job_kwargs,
 ):
    """
@@ -64,6 +65,9 @@ def export_to_phy(
        Dtype to save binary data
    verbose: bool
        If True, output is verbose
+    use_relative_path : bool, default: False
+        If True and `copy_binary=True`, saves the binary file `dat_path` in `params.py` relative to `output_folder` (i.e. `dat_path=r'recording.dat'`).
+        If `copy_binary=False`, the path is relative to the `output_folder`.
+        If False, uses an absolute path in `params.py` (i.e. `dat_path=r'path/to/the/recording.dat'`).
    {}
    """

@@ -94,7 +98,7 @@ def export_to_phy(
        used_sparsity = sparsity
    else:
        used_sparsity = ChannelSparsity.create_dense(waveform_extractor)
-    # convinient sparsity dict for the 3 cases to retrieve channl_inds
+    # convenient sparsity dict for the 3 cases to retrieve channel_inds
    sparse_dict = used_sparsity.unit_id_to_channel_indices

    empty_flag = False
@@ -106,7 +110,7 @@ def export_to_phy(
        empty_flag = True
    unit_ids = non_empty_units
    if empty_flag:
-        warnings.warn("Empty units have been removed when being exported to Phy")
+        warnings.warn("Empty units have been removed while exporting to Phy")

    if len(unit_ids) == 0:
        raise Exception("No non-empty units in the sorting result, can't save to Phy.")
@@ -149,7 +153,15 @@ def export_to_phy(

    # write params.py
    with (output_folder / "params.py").open("w") as f:
-        f.write(f"dat_path = r'{str(rec_path)}'\n")
+        if use_relative_path:
+            if copy_binary:
+                f.write(f"dat_path = r'recording.dat'\n")
+            elif rec_path == "None":
+                f.write(f"dat_path = {rec_path}\n")
+            else:
+                f.write(f"dat_path = r'{str(Path(rec_path).relative_to(output_folder))}'\n")
+        else:
+            f.write(f"dat_path = r'{str(rec_path)}'\n")
        f.write(f"n_channels_dat = {num_chans}\n")
        f.write(f"dtype = '{dtype_str}'\n")
        f.write(f"offset = 0\n")
diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py
index a771dc47b1..cd2b6fb941 100644
--- a/src/spikeinterface/extractors/neoextractors/openephys.py
+++ b/src/spikeinterface/extractors/neoextractors/openephys.py
@@ -22,6 +22,19 @@
 from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts

+def drop_invalid_neo_arguments_for_version_0_12_0(neo_kwargs):
+    # Temporary function until neo version 0.13.0 is released
+    from packaging.version import Version
+    from importlib.metadata import version as lib_version
+
+    neo_version = lib_version("neo")
+    # The possibility of ignoring timestamps errors is not present in neo <= 0.12.0
+    if Version(neo_version) <= Version("0.12.0"):
+        neo_kwargs.pop("ignore_timestamps_errors")
+
+    return neo_kwargs
+
+
 class OpenEphysLegacyRecordingExtractor(NeoBaseRecordingExtractor):
    """
    Class for reading data saved by the Open Ephys GUI.
@@ -45,14 +58,24 @@ class OpenEphysLegacyRecordingExtractor(NeoBaseRecordingExtractor):
        If there are several blocks (experiments), specify the block index you want to load.
    all_annotations: bool (default False)
        Load exhaustively all annotation from neo.
+    ignore_timestamps_errors: bool (default False)
+        If True, ignore discontinuous timestamps errors in neo.
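+
+    Examples
+    --------
+    A sketch with a hypothetical folder path:
+
+    >>> rec = OpenEphysLegacyRecordingExtractor("path/to/legacy_folder",
+    ...                                         ignore_timestamps_errors=True)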
""" mode = "folder" NeoRawIOClass = "OpenEphysRawIO" name = "openephyslegacy" - def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False): - neo_kwargs = self.map_to_neo_kwargs(folder_path) + def __init__( + self, + folder_path, + stream_id=None, + stream_name=None, + block_index=None, + all_annotations=False, + ignore_timestamps_errors=False, + ): + neo_kwargs = self.map_to_neo_kwargs(folder_path, ignore_timestamps_errors) NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, @@ -64,8 +87,9 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()))) @classmethod - def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(folder_path)} + def map_to_neo_kwargs(cls, folder_path, ignore_timestamps_errors=False): + neo_kwargs = {"dirname": str(folder_path), "ignore_timestamps_errors": ignore_timestamps_errors} + neo_kwargs = drop_invalid_neo_arguments_for_version_0_12_0(neo_kwargs) return neo_kwargs diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 38cb714d59..ccd2121174 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -73,12 +73,6 @@ def _run(self, **job_kwargs): func = _spike_amplitudes_chunk init_func = _init_worker_spike_amplitudes n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None)) - if n_jobs != 1: - # TODO: avoid dumping sorting and use spike vector and peak pipeline instead - assert sorting.check_if_dumpable(), ( - "The sorting object is not dumpable and cannot be processed in parallel. You can use the " - "`sorting.save()` function to make it dumpable" - ) init_args = (recording, sorting.to_multiprocessing(n_jobs), extremum_channels_index, peak_shifts, return_scaled) processor = ChunkRecordingExecutor( recording, func, init_func, init_args, handle_returns=True, job_name="extract amplitudes", **job_kwargs diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index e2ef6e6794..6ab1a9afce 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -333,7 +333,7 @@ def correct_motion( ) (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8") (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - if recording.check_if_json_serializable(): + if recording.check_serializablility("json"): recording.dump_to_json(folder / "recording.json") np.save(folder / "peaks.npy", peaks) diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 8dd5f857f6..e9726a16da 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -499,9 +499,8 @@ def compute_sliding_rp_violations( ) -def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **kwargs): - """ - Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of +def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs): + """Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of "synchrony_size" spikes at the exact same sample index. 
Parameters @@ -510,6 +509,8 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k The waveform extractor object. synchrony_sizes : list or tuple, default: (2, 4, 8) The synchrony sizes to compute. + unit_ids : list or None, default: None + List of unit ids to compute the synchrony metrics. If None, all units are used. Returns ------- @@ -522,16 +523,20 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k Based on concepts described in [Gruen]_ This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ """ - assert np.all(s > 1 for s in synchrony_sizes), "Synchrony sizes must be greater than 1" + assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1" spike_counts = waveform_extractor.sorting.count_num_spikes_per_unit() sorting = waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) + if unit_ids is None: + unit_ids = sorting.unit_ids + # Pre-allocate synchrony counts synchrony_counts = {} for synchrony_size in synchrony_sizes: synchrony_counts[synchrony_size] = np.zeros(len(waveform_extractor.unit_ids), dtype=np.int64) + all_unit_ids = list(sorting.unit_ids) for segment_index in range(sorting.get_num_segments()): spikes_in_segment = spikes[segment_index] @@ -539,7 +544,8 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True) # add counts for this segment - for unit_index in np.arange(len(sorting.unit_ids)): + for unit_id in unit_ids: + unit_index = all_unit_ids.index(unit_id) spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index] # some segments/units might have no spikes if len(spikes_per_unit) == 0: @@ -551,8 +557,8 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k # add counts for this segment synchrony_metrics_dict = { f"sync_spike_{synchrony_size}": { - unit_id: synchrony_counts[synchrony_size][unit_index] / spike_counts[unit_id] - for unit_index, unit_id in enumerate(sorting.unit_ids) + unit_id: synchrony_counts[synchrony_size][all_unit_ids.index(unit_id)] / spike_counts[unit_id] + for unit_id in unit_ids } for synchrony_size in synchrony_sizes } @@ -563,7 +569,172 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k return synchrony_metrics -_default_params["synchrony_metrics"] = dict(synchrony_sizes=(0, 2, 4)) +_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8)) + + +def compute_firing_ranges(waveform_extractor, bin_size_s=5, percentiles=(5, 95), unit_ids=None, **kwargs): + """Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution + computed in non-overlapping time bins. + + Parameters + ---------- + waveform_extractor : WaveformExtractor + The waveform extractor object. + bin_size_s : float, default: 5 + The size of the bin in seconds. + percentiles : tuple, default: (5, 95) + The percentiles to compute. + unit_ids : list or None + List of unit ids to compute the firing range. If None, all units are used. + + Returns + ------- + firing_ranges : dict + The firing range for each unit. + + Notes + ----- + Designed by Simon Musall and ported to SpikeInterface by Alessio Buccino. 
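+
+    Examples
+    --------
+    A sketch, assuming ``wvf_extractor`` is a precomputed WaveformExtractor
+    (the variable name is illustrative):
+
+    >>> firing_ranges = compute_firing_ranges(wvf_extractor, bin_size_s=10)
+    >>> # firing_ranges maps each unit id to its firing-rate range (in Hz)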
+ """ + sampling_frequency = waveform_extractor.sampling_frequency + bin_size_samples = int(bin_size_s * sampling_frequency) + sorting = waveform_extractor.sorting + if unit_ids is None: + unit_ids = sorting.unit_ids + + # for each segment, we compute the firing rate histogram and we concatenate them + firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} + for segment_index in range(waveform_extractor.get_num_segments()): + num_samples = waveform_extractor.get_num_samples(segment_index) + edges = np.arange(0, num_samples + 1, bin_size_samples) + + for unit_id in unit_ids: + spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + spike_counts, _ = np.histogram(spike_times, bins=edges) + firing_rates = spike_counts / bin_size_s + firing_rate_histograms[unit_id] = np.concatenate((firing_rate_histograms[unit_id], firing_rates)) + + # finally we compute the percentiles + firing_ranges = {} + for unit_id in unit_ids: + firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile( + firing_rate_histograms[unit_id], percentiles[0] + ) + + return firing_ranges + + +_default_params["firing_range"] = dict(bin_size_s=5, percentiles=(5, 95)) + + +def compute_amplitude_cv_metrics( + waveform_extractor, + average_num_spikes_per_bin=50, + percentiles=(5, 95), + min_num_bins=10, + amplitude_extension="spike_amplitudes", + unit_ids=None, +): + """Calculate coefficient of variation of spike amplitudes within defined temporal bins. + From the distribution of coefficient of variations, both the median and the "range" (the distance between the + percentiles defined by `percentiles` parameter) are returned. + + Parameters + ---------- + waveform_extractor : WaveformExtractor + The waveform extractor object. + average_num_spikes_per_bin : int, default: 50 + The average number of spikes per bin. This is used to estimate a temporal bin size using the firing rate + of each unit. For example, if a unit has a firing rate of 10 Hz, amd the average number of spikes per bin is + 100, then the temporal bin size will be 100/10 Hz = 10 s. + min_num_bins : int, default: 10 + The minimum number of bins to compute the median and range. If the number of bins is less than this then + the median and range are set to NaN. + amplitude_extension : str, default: 'spike_amplitudes' + The name of the extension to load the amplitudes from. 'spike_amplitudes' or 'amplitude_scalings'. + unit_ids : list or None + List of unit ids to compute the amplitude spread. If None, all units are used. + + Returns + ------- + amplitude_cv_median : dict + The median of the CV + amplitude_cv_range : dict + The range of the CV, computed as the distance between the percentiles. + + Notes + ----- + Designed by Simon Musall and Alessio Buccino. + """ + res = namedtuple("amplitude_cv", ["amplitude_cv_median", "amplitude_cv_range"]) + assert amplitude_extension in ( + "spike_amplitudes", + "amplitude_scalings", + ), "Invalid amplitude_extension. 
It can be either 'spike_amplitudes' or 'amplitude_scalings'" + sorting = waveform_extractor.sorting + total_duration = waveform_extractor.get_total_duration() + spikes = sorting.to_spike_vector() + num_spikes = sorting.count_num_spikes_per_unit() + if unit_ids is None: + unit_ids = sorting.unit_ids + + if waveform_extractor.is_extension(amplitude_extension): + sac = waveform_extractor.load_extension(amplitude_extension) + amps = sac.get_data(outputs="concatenated") + if amplitude_extension == "spike_amplitudes": + amps = np.concatenate(amps) + else: + warnings.warn("") + empty_dict = {unit_id: np.nan for unit_id in unit_ids} + return empty_dict + + # precompute segment slice + segment_slices = [] + for segment_index in range(waveform_extractor.get_num_segments()): + i0 = np.searchsorted(spikes["segment_index"], segment_index) + i1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + segment_slices.append(slice(i0, i1)) + + all_unit_ids = list(sorting.unit_ids) + amplitude_cv_medians, amplitude_cv_ranges = {}, {} + for unit_id in unit_ids: + firing_rate = num_spikes[unit_id] / total_duration + temporal_bin_size_samples = int( + (average_num_spikes_per_bin / firing_rate) * waveform_extractor.sampling_frequency + ) + + amp_spreads = [] + # bins and amplitude means are computed for each segment + for segment_index in range(waveform_extractor.get_num_segments()): + sample_bin_edges = np.arange( + 0, waveform_extractor.get_num_samples(segment_index) + 1, temporal_bin_size_samples + ) + spikes_in_segment = spikes[segment_slices[segment_index]] + amps_in_segment = amps[segment_slices[segment_index]] + unit_mask = spikes_in_segment["unit_index"] == all_unit_ids.index(unit_id) + spike_indices_unit = spikes_in_segment["sample_index"][unit_mask] + amps_unit = amps_in_segment[unit_mask] + amp_mean = np.abs(np.mean(amps_unit)) + for t0, t1 in zip(sample_bin_edges[:-1], sample_bin_edges[1:]): + i0 = np.searchsorted(spike_indices_unit, t0) + i1 = np.searchsorted(spike_indices_unit, t1) + amp_spreads.append(np.std(amps_unit[i0:i1]) / amp_mean) + + if len(amp_spreads) < min_num_bins: + amplitude_cv_medians[unit_id] = np.nan + amplitude_cv_ranges[unit_id] = np.nan + else: + amplitude_cv_medians[unit_id] = np.median(amp_spreads) + amplitude_cv_ranges[unit_id] = np.percentile(amp_spreads, percentiles[1]) - np.percentile( + amp_spreads, percentiles[0] + ) + + return res(amplitude_cv_medians, amplitude_cv_ranges) + + +_default_params["amplitude_cv"] = dict( + average_num_spikes_per_bin=50, percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes" +) def compute_amplitude_cutoffs( diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 90dbb47a3a..97f14ec6f4 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -12,6 +12,8 @@ compute_amplitude_medians, compute_drift_metrics, compute_synchrony_metrics, + compute_firing_ranges, + compute_amplitude_cv_metrics, ) from .pca_metrics import ( @@ -40,6 +42,8 @@ "sliding_rp_violation": compute_sliding_rp_violations, "amplitude_cutoff": compute_amplitude_cutoffs, "amplitude_median": compute_amplitude_medians, + "amplitude_cv": compute_amplitude_cv_metrics, "synchrony": compute_synchrony_metrics, + "firing_range": compute_firing_ranges, "drift": compute_drift_metrics, } diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py 
b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index d927d64c4f..2d63a06b17 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -12,6 +12,7 @@ compute_principal_components, compute_spike_locations, compute_spike_amplitudes, + compute_amplitude_scalings, ) from spikeinterface.qualitymetrics import ( @@ -31,6 +32,8 @@ compute_drift_metrics, compute_amplitude_medians, compute_synchrony_metrics, + compute_firing_ranges, + compute_amplitude_cv_metrics, ) @@ -212,6 +215,12 @@ def test_calculate_firing_rate_num_spikes(waveform_extractor_simple): # np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) +def test_calculate_firing_range(waveform_extractor_simple): + we = waveform_extractor_simple + firing_ranges = compute_firing_ranges(we) + print(firing_ranges) + + def test_calculate_amplitude_cutoff(waveform_extractor_simple): we = waveform_extractor_simple spike_amps = compute_spike_amplitudes(we) @@ -234,6 +243,24 @@ def test_calculate_amplitude_median(waveform_extractor_simple): # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) +def test_calculate_amplitude_cv_metrics(waveform_extractor_simple): + we = waveform_extractor_simple + spike_amps = compute_spike_amplitudes(we) + amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(we, average_num_spikes_per_bin=20) + print(amp_cv_median) + print(amp_cv_range) + + amps_scalings = compute_amplitude_scalings(we) + amp_cv_median_scalings, amp_cv_range_scalings = compute_amplitude_cv_metrics( + we, + average_num_spikes_per_bin=20, + amplitude_extension="amplitude_scalings", + min_num_bins=5, + ) + print(amp_cv_median_scalings) + print(amp_cv_range_scalings) + + def test_calculate_snrs(waveform_extractor_simple): we = waveform_extractor_simple snrs = compute_snrs(we) @@ -358,4 +385,6 @@ def test_calculate_drift_metrics(waveform_extractor_simple): # test_calculate_isi_violations(we) # test_calculate_sliding_rp_violations(we) # test_calculate_drift_metrics(we) - test_synchrony_metrics(we) + # test_synchrony_metrics(we) + test_calculate_firing_range(we) + test_calculate_amplitude_cv_metrics(we) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index c7581ba1e1..8d87558191 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -137,8 +137,10 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo ) rec_file = output_folder / "spikeinterface_recording.json" - if recording.check_if_json_serializable(): - recording.dump_to_json(rec_file, relative_to=output_folder) + if recording.check_serializablility("json"): + recording.dump(rec_file, relative_to=output_folder) + elif recording.check_serializablility("pickle"): + recording.dump(output_folder / "spikeinterface_recording.pickle") else: d = {"warning": "The recording is not serializable to json"} rec_file.write_text(json.dumps(d, indent=4), encoding="utf8") @@ -185,6 +187,26 @@ def set_params_to_folder(cls, recording, output_folder, new_params, verbose): return params + @classmethod + def load_recording_from_folder(cls, output_folder, with_warnings=False): + json_file = output_folder / "spikeinterface_recording.json" + pickle_file = output_folder / "spikeinterface_recording.pickle" + + if json_file.exists(): + with (json_file).open("r", encoding="utf8") as f: + recording_dict = json.load(f) + 
if "warning" in recording_dict.keys() and with_warnings: + warnings.warn( + "The recording that has been sorted is not JSON serializable: it cannot be registered to the sorting object." + ) + recording = None + else: + recording = load_extractor(json_file, base_folder=output_folder) + elif pickle_file.exits(): + recording = load_extractor(pickle_file) + + return recording + @classmethod def _dump_params(cls, recording, output_folder, sorter_params, verbose): with (output_folder / "spikeinterface_params.json").open(mode="w", encoding="utf8") as f: @@ -271,7 +293,7 @@ def run_from_folder(cls, output_folder, raise_error, verbose): return run_time @classmethod - def get_result_from_folder(cls, output_folder): + def get_result_from_folder(cls, output_folder, register_recording=True, sorting_info=True): output_folder = Path(output_folder) sorter_output_folder = output_folder / "sorter_output" # check errors in log file @@ -294,27 +316,21 @@ def get_result_from_folder(cls, output_folder): # back-compatibility sorting = cls._get_result_from_folder(output_folder) - # register recording to Sorting object - # check if not json serializable - with (output_folder / "spikeinterface_recording.json").open("r", encoding="utf8") as f: - recording_dict = json.load(f) - if "warning" in recording_dict.keys(): - warnings.warn( - "The recording that has been sorted is not JSON serializable: it cannot be registered to the sorting object." - ) - else: - recording = load_extractor(output_folder / "spikeinterface_recording.json", base_folder=output_folder) + if register_recording: + # register recording to Sorting object + recording = cls.load_recording_from_folder(output_folder, with_warnings=False) if recording is not None: - # can be None when not dumpable sorting.register_recording(recording) - # set sorting info to Sorting object - with open(output_folder / "spikeinterface_recording.json", "r") as f: - rec_dict = json.load(f) - with open(output_folder / "spikeinterface_params.json", "r") as f: - params_dict = json.load(f) - with open(output_folder / "spikeinterface_log.json", "r") as f: - log_dict = json.load(f) - sorting.set_sorting_info(rec_dict, params_dict, log_dict) + + if sorting_info: + # set sorting info to Sorting object + with open(output_folder / "spikeinterface_recording.json", "r") as f: + rec_dict = json.load(f) + with open(output_folder / "spikeinterface_params.json", "r") as f: + params_dict = json.load(f) + with open(output_folder / "spikeinterface_log.json", "r") as f: + log_dict = json.load(f) + sorting.set_sorting_info(rec_dict, params_dict, log_dict) return sorting diff --git a/src/spikeinterface/sorters/external/herdingspikes.py b/src/spikeinterface/sorters/external/herdingspikes.py index a8d702ebe9..5180e6f1cc 100644 --- a/src/spikeinterface/sorters/external/herdingspikes.py +++ b/src/spikeinterface/sorters/external/herdingspikes.py @@ -147,9 +147,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: new_api = False - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) p = params diff --git a/src/spikeinterface/sorters/external/mountainsort4.py b/src/spikeinterface/sorters/external/mountainsort4.py index 69f97fd11c..f6f0b3eaeb 100644 --- a/src/spikeinterface/sorters/external/mountainsort4.py +++ b/src/spikeinterface/sorters/external/mountainsort4.py @@ -89,9 +89,7 @@ def 
_setup_recording(cls, recording, sorter_output_folder, params, verbose): def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort4 - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) # alias to params p = params diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index df6d276bf5..a88c59d688 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -115,9 +115,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort5 as ms5 - recording: BaseRecording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) # alias to params p = params diff --git a/src/spikeinterface/sorters/external/pykilosort.py b/src/spikeinterface/sorters/external/pykilosort.py index 2a41d793d5..1962d56206 100644 --- a/src/spikeinterface/sorters/external/pykilosort.py +++ b/src/spikeinterface/sorters/external/pykilosort.py @@ -148,9 +148,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): # saved by setup recording diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 55a36d26d5..710c4f76f4 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -52,9 +52,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs["verbose"] = verbose job_kwargs["progress_bar"] = verbose - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) + sampling_rate = recording.get_sampling_frequency() num_channels = recording.get_num_channels() diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 42f51d3a77..ed327e0f3c 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -49,9 +49,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): import hdbscan - recording_raw = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) num_chans = recording_raw.get_num_channels() sampling_frequency = recording_raw.get_sampling_frequency() diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 6e6ccc0358..bd5667b15f 
100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -624,10 +624,20 @@ def run_sorter_container( ) -def read_sorter_folder(output_folder, raise_error=True): +def read_sorter_folder(output_folder, register_recording=True, sorting_info=True, raise_error=True): """ Load a sorting object from a spike sorting output folder. The 'output_folder' must contain a valid 'spikeinterface_log.json' file + + + Parameters + ---------- + output_folder: Pth or str + The sorter folder + register_recording: bool, default: True + Attach recording (when json or pickle) to the sorting + sorting_info: bool, default: True + Attach sorting info to the sorting. """ output_folder = Path(output_folder) log_file = output_folder / "spikeinterface_log.json" @@ -647,7 +657,9 @@ def read_sorter_folder(output_folder, raise_error=True): sorter_name = log["sorter_name"] SorterClass = sorter_dict[sorter_name] - sorting = SorterClass.get_result_from_folder(output_folder) + sorting = SorterClass.get_result_from_folder( + output_folder, register_recording=register_recording, sorting_info=sorting_info + ) return sorting diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py index 14c938f8ba..a5e29c8fd9 100644 --- a/src/spikeinterface/sorters/tests/test_launcher.py +++ b/src/spikeinterface/sorters/tests/test_launcher.py @@ -178,7 +178,7 @@ def test_run_sorters_with_list(): if working_folder.is_dir(): shutil.rmtree(working_folder) - # make dumpable + # make serializable rec0 = load_extractor(cache_folder / "toy_rec_0") rec1 = load_extractor(cache_folder / "toy_rec_1") diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 273b1402fe..af3a9cb86a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -533,7 +533,13 @@ def remove_duplicates( def remove_duplicates_via_matching( - waveform_extractor, noise_levels, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None, method="circus-omp-svd" + waveform_extractor, + noise_levels, + peak_labels, + method_kwargs={}, + job_kwargs={}, + tmp_folder=None, + method="circus-omp-svd", ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface import get_noise_levels
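
For reviewers who want to exercise the two new metrics by hand, here is a minimal usage sketch. It assumes a `WaveformExtractor` named `we` has already been computed from a recording/sorting pair (the variable name is illustrative); the imports mirror the ones used in the test file above.

.. code-block:: python

    from spikeinterface.postprocessing import compute_spike_amplitudes
    from spikeinterface.qualitymetrics import (
        compute_firing_ranges,
        compute_amplitude_cv_metrics,
    )

    # firing_range: distance between the 5th and 95th percentiles of the
    # firing-rate distribution computed over non-overlapping 5 s bins
    firing_ranges = compute_firing_ranges(we, bin_size_s=5, percentiles=(5, 95))

    # amplitude_cv requires the "spike_amplitudes" (or "amplitude_scalings")
    # extension; without it the metrics are NaN and a warning is raised
    compute_spike_amplitudes(we)
    amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(
        we, average_num_spikes_per_bin=50, min_num_bins=10
    )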
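
Since the metrics are registered in the quality-metric dict above under the names "firing_range" and "amplitude_cv", they should also be reachable through the generic entry point; a sketch under the same `we` assumption:

.. code-block:: python

    from spikeinterface.qualitymetrics import compute_quality_metrics

    # the new names sit alongside the existing ones ("synchrony", "drift", ...)
    qm = compute_quality_metrics(we, metric_names=["firing_range", "amplitude_cv"])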
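
On the sorter side, the new `register_recording` and `sorting_info` flags let callers read back a sorting without re-loading the recording or the info dictionaries. A sketch; the folder path is hypothetical and stands for any folder previously written by `run_sorter`:

.. code-block:: python

    from spikeinterface.sorters import read_sorter_folder

    # load only the sorted spike trains; skip re-attaching the recording
    # and the sorting info (both flags default to True)
    sorting = read_sorter_folder(
        "my_sorter_output",  # hypothetical output folder
        register_recording=False,
        sorting_info=False,
    )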