diff --git a/doc/api.rst b/doc/api.rst index 3e825084e7..1ac37e4740 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -338,59 +338,58 @@ spikeinterface.curation spikeinterface.generation ------------------------- +.. currentmodule:: spikeinterface.generation + Core ~~~~ -.. automodule:: spikeinterface.generation - - .. autofunction:: generate_recording - .. autofunction:: generate_sorting - .. autofunction:: generate_snippets - .. autofunction:: generate_templates - .. autofunction:: generate_recording_by_size - .. autofunction:: generate_ground_truth_recording - .. autofunction:: add_synchrony_to_sorting - .. autofunction:: synthesize_random_firings - .. autofunction:: inject_some_duplicate_units - .. autofunction:: inject_some_split_units - .. autofunction:: synthetize_spike_train_bad_isi - .. autofunction:: inject_templates - .. autofunction:: noise_generator_recording - .. autoclass:: InjectTemplatesRecording - .. autoclass:: NoiseGeneratorRecording + + +.. autofunction:: generate_recording +.. autofunction:: generate_sorting +.. autofunction:: generate_snippets +.. autofunction:: generate_templates +.. autofunction:: generate_recording_by_size +.. autofunction:: generate_ground_truth_recording +.. autofunction:: add_synchrony_to_sorting +.. autofunction:: synthesize_random_firings +.. autofunction:: inject_some_duplicate_units +.. autofunction:: inject_some_split_units +.. autofunction:: synthetize_spike_train_bad_isi +.. autofunction:: inject_templates +.. autofunction:: noise_generator_recording +.. autoclass:: InjectTemplatesRecording +.. autoclass:: NoiseGeneratorRecording Drift ~~~~~ -.. automodule:: spikeinterface.generation - .. autofunction:: generate_drifting_recording - .. autofunction:: generate_displacement_vector - .. autofunction:: make_one_displacement_vector - .. autofunction:: make_linear_displacement - .. autofunction:: move_dense_templates - .. autofunction:: interpolate_templates - .. autoclass:: DriftingTemplates - .. autoclass:: InjectDriftingTemplatesRecording +.. autofunction:: generate_drifting_recording +.. autofunction:: generate_displacement_vector +.. autofunction:: make_one_displacement_vector +.. autofunction:: make_linear_displacement +.. autofunction:: move_dense_templates +.. autofunction:: interpolate_templates +.. autoclass:: DriftingTemplates +.. autoclass:: InjectDriftingTemplatesRecording Hybrid ~~~~~~ -.. automodule:: spikeinterface.generation - .. autofunction:: generate_hybrid_recording - .. autofunction:: estimate_templates_from_recording - .. autofunction:: select_templates - .. autofunction:: scale_template_to_range - .. autofunction:: relocate_templates - .. autofunction:: fetch_template_object_from_database - .. autofunction:: fetch_templates_database_info - .. autofunction:: list_available_datasets_in_template_database - .. autofunction:: query_templates_from_database +.. autofunction:: generate_hybrid_recording +.. autofunction:: estimate_templates_from_recording +.. autofunction:: select_templates +.. autofunction:: scale_template_to_range +.. autofunction:: relocate_templates +.. autofunction:: fetch_template_object_from_database +.. autofunction:: fetch_templates_database_info +.. autofunction:: list_available_datasets_in_template_database +.. autofunction:: query_templates_from_database Noise ~~~~~ -.. automodule:: spikeinterface.generation - .. autofunction:: generate_noise +.. autofunction:: generate_noise spikeinterface.sortingcomponents @@ -408,12 +407,6 @@ Peak Detection .. autofunction:: detect_peaks -Motion Correction -~~~~~~~~~~~~~~~~~ -.. automodule:: spikeinterface.sortingcomponents.motion_interpolation - - .. autoclass:: InterpolateMotionRecording - Clustering ~~~~~~~~~~ .. automodule:: spikeinterface.sortingcomponents.clustering @@ -425,3 +418,15 @@ Template Matching .. automodule:: spikeinterface.sortingcomponents.matching .. autofunction:: find_spikes_from_templates + +Motion Correction +~~~~~~~~~~~~~~~~~ +.. automodule:: spikeinterface.sortingcomponents.motion + + .. autoclass:: Motion + .. autofunction:: estimate_motion + .. autofunction:: interpolate_motion + .. autofunction:: correct_motion_on_peaks + .. autofunction:: interpolate_motion_on_traces + .. autofunction:: clean_motion_vector + .. autoclass:: InterpolateMotionRecording diff --git a/doc/conf.py b/doc/conf.py index 4373ec3c36..e3d58ca8f2 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -74,6 +74,8 @@ "IPython.sphinxext.ipython_console_highlighting" ] +autosectionlabel_prefix_document = True + numpydoc_show_class_members = False @@ -128,7 +130,7 @@ '../examples/tutorials/widgets', ]), 'within_subsection_order': FileNameSortKey, - 'ignore_pattern': '/generate_', + 'ignore_pattern': '/generate_*', 'nested_sections': False, 'copyfile_regex': r'.*\.rst|.*\.png|.*\.svg' } diff --git a/doc/get_started/install_sorters.rst b/doc/get_started/install_sorters.rst index c1666352b1..12233784fd 100644 --- a/doc/get_started/install_sorters.rst +++ b/doc/get_started/install_sorters.rst @@ -27,7 +27,7 @@ sorters to retrieve installation instructions for other operating systems. We use **pip** to install packages, but **conda** should also work in many cases. Some novel spike sorting algorithms are implemented directly in SpikeInterface using the -:py:mod:`spikeinterface.sortingcomponents` module. Checkout the :ref:`SpikeInterface-based spike sorters` section of this page +:py:mod:`spikeinterface.sortingcomponents` module. Checkout the :ref:`get_started/install_sorters:SpikeInterface-based spike sorters` section of this page for more information! If you experience installation problems please directly contact the authors of these tools or write on the diff --git a/doc/get_started/installation.rst b/doc/get_started/installation.rst index aae9cdf63c..182ce67b94 100644 --- a/doc/get_started/installation.rst +++ b/doc/get_started/installation.rst @@ -102,4 +102,4 @@ Sub-modules have more dependencies, so you should also install: All external spike sorters can be either run inside containers (Docker or Singularity - see :ref:`containerizedsorters`) -or must be installed independently (see :ref:`Installing Spike Sorters`). +or must be installed independently (see :ref:`get_started/install_sorters:Installing Spike Sorters`). diff --git a/doc/how_to/benchmark_with_hybrid_recordings.rst b/doc/how_to/benchmark_with_hybrid_recordings.rst index 5870d87955..9975bb1a4b 100644 --- a/doc/how_to/benchmark_with_hybrid_recordings.rst +++ b/doc/how_to/benchmark_with_hybrid_recordings.rst @@ -24,7 +24,7 @@ order to smoothly inject spikes into the recording. import spikeinterface.generation as sgen import spikeinterface.widgets as sw - from spikeinterface.sortingcomponents.motion_estimation import estimate_motion + from spikeinterface.sortingcomponents.motion import estimate_motion import numpy as np import matplotlib.pyplot as plt @@ -1202,63 +1202,63 @@ drifts when injecting hybrid spikes. 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 1. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 2. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 3. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 4. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 5. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 6. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 7. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 8. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 9. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 10. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 11. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 12. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 13. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 14. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385 - 0. 0. 0.07692308 0.07692308 0.15384615 0.15384615 + 15. 0. 0.07692308 0.07692308 0.15384615 0.15384615 0.23076923 0.23076923 0.30769231 0.30769231 0.38461538 0.38461538 0.46153846 0.46153846 0.53846154 0.53846154 0.61538462 0.61538462 0.69230769 0.69230769 0.76923077 0.76923077 0.84615385 0.84615385] diff --git a/doc/how_to/drift_with_lfp.rst b/doc/how_to/drift_with_lfp.rst new file mode 100644 index 0000000000..0decc1058a --- /dev/null +++ b/doc/how_to/drift_with_lfp.rst @@ -0,0 +1,163 @@ +Estimate drift using the LFP traces +=================================== + +Drift is a well known issue for long shank probes. Some datasets, especially from primates and humans, +can experience very fast motion due to breathing and heart beats. In these cases, the standard motion +estimation methods that use detected spikes as a basis for motion inference will fail, because there +are not enough spikes to "follow" such fast drifts. + +Charlie Windolf and colleagues from the Paninski Lab at Columbia have developed a method to estimate +the motion using the LFP signal: **DREDge**. (more details about the method in the paper +`DREDge: robust motion correction for high-density extracellular recordings across species `_). + +This method is particularly suited for the open dataset recorded at Massachusetts General Hospital by Angelique Paulk and colleagues in humans (more details in the [paper](https://doi.org/10.1038/s41593-021-00997-0)). The dataset can be dowloaed from [datadryad](https://datadryad.org/stash/dataset/doi:10.5061/dryad.d2547d840) and it contains recordings on human patients with a Neuropixels probe, some of which with very high and fast motion on the probe, which prevents accurate spike sorting without a proper and adequate motion correction + +The **DREDge** method has two options: **dredge_lfp** and **dredge_ap**, which have both been ported inside `SpikeInterface`. + +Here we will demonstrate the **dredge_lfp** method to estimate the fast and high drift on this recording. + +For each patient, the dataset contains two streams: + +* a highpass "action potential" (AP), sampled at 30kHz +* a lowpass "local field" (LF) sampled at 2.5kHz + +For this demonstration, we will use the LF stream. + +.. code:: ipython3 + + %matplotlib inline + %load_ext autoreload + %autoreload 2 + +.. code:: ipython3 + + from pathlib import Path + import matplotlib.pyplot as plt + + import spikeinterface.full as si + from spikeinterface.sortingcomponents.motion import estimate_motion + +.. code:: ipython3 + + # the dataset has been locally downloaded + base_folder = Path("/mnt/data/sam/DataSpikeSorting/") + np_data_drift = base_folder / 'human_neuropixel/Pt02/' + +Read the spikeglx file +~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: ipython3 + + raw_rec = si.read_spikeglx(np_data_drift) + print(raw_rec) + + +.. parsed-literal:: + + SpikeGLXRecordingExtractor: 384 channels - 2.5kHz - 1 segments - 2,183,292 samples + 873.32s (14.56 minutes) - int16 dtype - 1.56 GiB + + +Preprocessing +~~~~~~~~~~~~~ + +Contrary to the **dredge_ap** approach, which needs detected peaks and peak locations, the **dredge_lfp** +method is estimating the motion directly on traces. +Importantly, the method requires some additional pre-processing steps: + * ``bandpass_filter``: to "focus" the signal on a particular band + * ``phase_shift``: to compensate for the sampling misalignement + * ``resample``: to further reduce the sampling fequency of the signal and speed up the computation. The sampling frequency of the estimated motion will be the same as the resampling frequency. Here we choose 250Hz, which corresponds to a sampling interval of 4ms. + * ``directional_derivative``: this optional step applies a second order derivative in the spatial dimension to enhance edges on the traces. + This is not a general rules and need to be tested case by case. + * ``average_across_direction``: Neuropixels 1.0 probes have two contacts per depth. This steps averages them to obtain a unique virtual signal along the probe depth ("y" in ``spikeinterface``). + +After appying this preprocessing chain, the motion can be estimated almost by eyes ont the traces plotted with the map mode. + +.. code:: ipython3 + + lfprec = si.bandpass_filter( + raw_rec, + freq_min=0.5, + freq_max=250, + + margin_ms=1500., + filter_order=3, + dtype="float32", + add_reflect_padding=True, + ) + lfprec = si.phase_shift(lfprec) + lfprec = si.resample(lfprec, resample_rate=250, margin_ms=1000) + + lfprec = si.directional_derivative(lfprec, order=2, edge_order=1) + lfprec = si.average_across_direction(lfprec) + + print(lfprec) + + +.. parsed-literal:: + + AverageAcrossDirectionRecording: 192 channels - 0.2kHz - 1 segments - 218,329 samples + 873.32s (14.56 minutes) - float32 dtype - 159.91 MiB + + +.. code:: ipython3 + + %matplotlib inline + si.plot_traces(lfprec, backend="matplotlib", mode="map", clim=(-0.05, 0.05), time_range=(400, 420)) + + + +.. image:: drift_with_lfp_files/drift_with_lfp_8_1.png + + +Run the method +~~~~~~~~~~~~~~ + +``estimate_motion()`` is the generic function to estimate motion with multiple +methods in ``spikeinterface``. + +This function returns a ``Motion`` object and we can notice that the interval is exactly +the same as downsampled signal. + +Here we use ``rigid=True``, which means that we have one unqiue signal to +describe the motion across the entire probe depth. + +.. code:: ipython3 + + motion = estimate_motion(lfprec, method='dredge_lfp', rigid=True, progress_bar=True) + motion + + +.. parsed-literal:: + + Online chunks [10.0s each]: 0%| | 0/87 [00:00`_. - Please refer to the stable documentation `here `_. - Learn how to update your code `here `_ and read more about the - :code:`SortingAnalyzer` `here `_. + Please refer to the `stable documentation `_. + Learn how to :ref:`update your code here ` and read more about the + :ref:`SortingAnalyzer here `. diff --git a/doc/modules/core.rst b/doc/modules/core.rst index ed1d37dc64..e993d0120b 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -21,7 +21,7 @@ All classes support: * data on-demand (lazy loading) * multiple segments, where each segment is a contiguous piece of data (recording, sorting, events). -.. _core-recording: + Recording --------- @@ -162,7 +162,6 @@ Internally, any sorting object can construct 2 internal caches: 2. a unique numpy.array with structured dtype aka "spikes vector". This is useful for processing by small chunks of time, like for extracting amplitudes from a recording. -.. _core-sorting-analyzer: SortingAnalyzer --------------- @@ -178,7 +177,7 @@ to perform further analysis, such as calculating :code:`waveforms` and :code:`te Importantly, the :py:class:`~spikeinterface.core.SortingAnalyzer` handles the *sparsity* and the physical *scaling*. Sparsity defines the channels on which waveforms and templates are calculated using, for example, a -physical distance from the channel with the largest peak amplitude (see the :ref:`Sparsity` section). Scaling, set by +physical distance from the channel with the largest peak amplitude (see the :ref:`modules/core:Sparsity` section). Scaling, set by the :code:`return_scaled` argument, determines whether the data is converted from integer values to :math:`\mu V` or not. By default, :code:`return_scaled` is true and all processed data voltage values are in :math:`\mu V` (e.g., waveforms, templates, spike amplitudes, etc.). @@ -207,7 +206,7 @@ Now we will create a :code:`SortingAnalyzer` called :code:`sorting_analyzer`. The :py:class:`~spikeinterface.core.SortingAnalyzer` by default is defined *in memory*, but it can be saved at any time (or upon instantiation) to one of the following backends: -* | :code:`zarr`: the sorting analyzer is saved to a `Zarr `_ folder, and each extension is a Zarr group. This is the recommended backend, since Zarr files can be written to/read from the cloud and compression is applied. +* | :code:`zarr`: the sorting analyzer is saved to a `Zarr `__ folder, and each extension is a Zarr group. This is the recommended backend, since Zarr files can be written to/read from the cloud and compression is applied. * | :code:`binary_folder`: the sorting analyzer is saved to a folder, and each extension creates a sub-folder. The extension data are saved to either :code:`npy` (for arrays), :code:`csv` (for dataframes), or :code:`pickle` (for everything else). If the sorting analyzer is in memory, the :code:`SortingAnalyzer.save_as` function can be used to save it @@ -568,7 +567,7 @@ re-instantiate the object from scratch (this is true for all objects except in-m The :code:`save()` function allows to easily store SI objects to a folder on disk. :py:class:`~spikeinterface.core.BaseRecording` objects are stored in binary (.raw) or -`Zarr `_ (.zarr) format and +`Zarr `__ (.zarr) format and :py:class:`~spikeinterface.core.BaseSorting` and :py:class:`~spikeinterface.core.BaseSnippets` object in numpy (.npz) format. With the actual data, the :code:`save()` function also stores the provenance dictionary and all the properties and annotations associated to the object. @@ -922,7 +921,7 @@ The :py:class:`~spikeinterface.core.WaveformExtractor` allows us to: * extract waveforms * sub-sample spikes for waveform extraction * compute templates (i.e. average extracellular waveforms) with different modes -* save waveforms in a folder (in numpy / `Zarr `_) for easy retrieval +* save waveforms in a folder (in numpy / `Zarr `__) for easy retrieval * save sparse waveforms or *sparsify* dense waveforms * select units and associated waveforms diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index 46fdcc6d65..45e6fb9ae8 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -65,12 +65,12 @@ This format has two part: * "merged_unit_groups" * "removed_units" -Here is the description of the format with a simple example: +Here is the description of the format with a simple example (the first part of the +format is the definition; the second part of the format is manual action): .. code-block:: json { - # the first part of the format is the definitation "format_version": "1", "unit_ids": [ "u1", @@ -91,7 +91,7 @@ Here is the description of the format with a simple example: "MUA", "artifact" ], - "exclusive": true + "exclusive": "true" }, "putative_type": { "label_options": [ @@ -100,10 +100,10 @@ Here is the description of the format with a simple example: "pyramidal", "mitral" ], - "exclusive": false + "exclusive": "false" } }, - # the second part of the format is manual action + "manual_labels": [ { "unit_id": "u1", diff --git a/doc/modules/exporters.rst b/doc/modules/exporters.rst index 6ccee2246c..7f8eeeb19e 100644 --- a/doc/modules/exporters.rst +++ b/doc/modules/exporters.rst @@ -15,7 +15,7 @@ results. **Note** : :py:func:`~spikeinterface.exporters.export_to_phy` speed and the size of the folder will highly depend on the sparsity of the :code:`SortingAnalyzer` itself or the external specified sparsity. The Phy viewer enables one to explore PCA projections, spike amplitudes, waveforms and quality of spike sorting results. -So if these pieces of information have already been computed as extensions (see :ref:`analyzer_extensions`), +So if these pieces of information have already been computed as extensions (see :ref:`modules/postprocessing:Extensions as AnalyzerExtensions`), then exporting to Phy should be fast (and the user has better control of the parameters for the extensions). If not pre-computed, then the required extensions (e.g., :code:`spike_amplitudes`, :code:`principal_components`) can be computed directly at export time. diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index af81cb42d1..076a560e31 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -151,8 +151,7 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks - from spikeinterface.sortingcomponents.motion_estimation import estimate_motion - from spikeinterface.sortingcomponents.motion_interpolation import interpolate_motion + from spikeinterface.sortingcomponents.motion import estimate_motion, interpolate_motion job_kwargs = dict(chunk_duration="1s", n_jobs=20, progress_bar=True) # Step 1 : activity profile diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index ac80cc082a..ffb55d2929 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -59,7 +59,7 @@ To check what extensions spikeinterface can calculate, you can use the :code:`ge >>> ['random_spikes', 'waveforms', 'templates', 'noise_levels', 'amplitude_scalings', 'correlograms', 'isi_histograms', 'principal_components', 'spike_amplitudes', 'spike_locations', 'template_metrics', 'template_similarity', 'unit_locations', 'quality_metrics'] -There is detailed documentation about each extension :ref:`below`. +There is detailed documentation about each extension :ref:`below`. Each extension comes from a different module. To use the :code:`postprocessing` extensions, you'll need to have the `postprocessing` module loaded. @@ -68,11 +68,9 @@ both `random_spikes` and `waveforms`. We say that `principal_components` is a ch two. Other extensions, like `isi_histograms`, don't depend on anything. It has no children and no parents. The parent/child relationships of all the extensions currently defined in spikeinterface can be found in this diagram: -| .. figure:: ../images/parent_child.svg :alt: Parent child relationships for the extensions in spikeinterface :align: center -| If you try to calculate a child before calculating a parent, an error will be thrown. Further, when a parent is recalculated we delete its children. Why? Consider calculating :code:`principal_components`. This depends on random selection of spikes chosen @@ -312,7 +310,7 @@ By default, the following metrics are computed: The units of :code:`recovery_slope` and :code:`repolarization_slope` depend on the input. Voltages are based on the units of the template. By default this is :math:`\mu V` but can be the raw output from the recording device (this depends on the -:code:`return_scaled` parameter, read more here: :ref:`core-sorting-analyzer`). +:code:`return_scaled` parameter, read more here: :ref:`modules/core:SortingAnalyzer`). Distances are in :math:`\mu m` and times are in seconds. So, for example, if the templates are in units of :math:`\mu V` then: :code:`repolarization_slope` is in :math:`mV / s`; :code:`peak_to_trough_ratio` is in :math:`\mu m` and the diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 38a7da995b..4fbadd3ab1 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -32,7 +32,7 @@ These two preprocessors will not compute anything at instantiation, but the comp traces = recording_cmr.get_traces(start_frame=100_000, end_frame=200_000) -Some internal sorters (see :ref:`si_based`) can work directly on these preprocessed objects so there is no need to +Some internal sorters (see :ref:`modules/sorters:Intertnal Sorters`) can work directly on these preprocessed objects so there is no need to save the object: .. code-block:: python diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index f5f3581c31..04a302c597 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -67,12 +67,3 @@ This code snippet shows how to compute quality metrics (with or without principa For more information about quality metrics, check out this excellent `documentation `_ from the Allen Institute. - - -Quality Metrics References --------------------------- - -.. toctree:: - :maxdepth: 1 - - qualitymetrics/references diff --git a/doc/modules/qualitymetrics/references.rst b/doc/modules/qualitymetrics/references.rst deleted file mode 100644 index f5236cff66..0000000000 --- a/doc/modules/qualitymetrics/references.rst +++ /dev/null @@ -1,30 +0,0 @@ -References ----------- - -.. [Buzsáki] Buzsáki, György, and Kenji Mizuseki. “The Log-Dynamic Brain: How Skewed Distributions Affect Network Operations.” Nature reviews. Neuroscience 15.4 (2014): 264–278. Web. - -.. [Chung] Chung, Jason E et al. “A Fully Automated Approach to Spike Sorting.” Neuron (Cambridge, Mass.) 95.6 (2017): 1381–1394.e6. Web. - -.. [Harris] Kenneth D Harris, Hajime Hirase, Xavier Leinekugel, Darrell A Henze, and Gy ̈orgy Buzs ́aki. Temporal interaction between single spikes and complex spike bursts in hippocampal pyramidal cells. Neuron (Cambridge, Mass.), 32(1):141–149, 2001. - -.. [Hill] Hill, Daniel N., Samar B. Mehta, and David Kleinfeld. “Quality Metrics to Accompany Spike Sorting of Extracellular Signals.” The Journal of neuroscience 31.24 (2011): 8699–8705. Web. - -.. [Hruschka] Hruschka, E.R., de Castro, L.N., Campello R.J.G.B. "Evolutionary algorithms for clustering gene-expression data." Fourth IEEE International Conference on Data Mining (ICDM'04) 2004, pp 403-406. - -.. [Gruen] Sonja Grün, Moshe Abeles, and Markus Diesmann. Impact of higher-order correlations on coincidence distributions of massively parallel data. In International School on Neural Networks, Initiated by IIASS and EMFCSC, volume 5286, 96–114. Springer, 2007. - -.. [IBL] International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. - -.. [Jackson] Jadin Jackson, Neil Schmitzer-Torbert, K.D. Harris, and A.D. Redish. Quantitative assessment of extracellular multichannel recording quality using measures of cluster separation. Soc Neurosci Abstr, 518, 01 2005. - -.. [Lemon] R. Lemon. Methods for neuronal recording in conscious animals. IBRO Handbook Series, 4:56–60, 1984. - -.. [Llobet] Llobet Victor, Wyngaard Aurélien and Barbour Boris. “Automatic post-processing and merging of multiple spike-sorting analyses with Lussac“. BioRxiv (2022). - -.. [Pouzat] Pouzat Christophe, Mazor Ofer and Laurent Gilles. “Using noise signature to optimize spike-sorting and to assess neuronal classification quality“. Journal of Neuroscience Methods (2002). - -.. [Rousseeuw] Peter J Rousseeuw. Silhouettes: A graphical aid to the interpretation and validation of cluster analysis. Journal of computational and applied mathematics, 20(C):53–65, 1987. - -.. [Schmitzer-Torbert] Schmitzer-Torbert, Neil, and A. David Redish. “Neuronal Activity in the Rodent Dorsal Striatum in Sequential Navigation: Separation of Spatial and Reward Responses on the Multiple T Task.” Journal of neurophysiology 91.5 (2004): 2259–2272. Web. - -.. [Siegle] Siegle, Joshua H. et al. “Survey of Spiking in the Mouse Visual System Reveals Functional Hierarchy.” Nature (London) 592.7852 (2021): 86–. Web. diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index 41c92dd99e..d244fd0c0f 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -48,4 +48,4 @@ References Literature ---------- -Based on concepts described in [Gruen]_ +Based on concepts described in [Grün]_ diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index d8a4708236..2a440f9c0a 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -473,6 +473,8 @@ Here is the list of external sorters accessible using the run_sorter wrapper: * **HDSort** :code:`run_sorter(sorter_name='hdsort')` * **YASS** :code:`run_sorter(sorter_name='yass')` +Intertnal Sorters +----------------- Here a list of internal sorter based on `spikeinterface.sortingcomponents`; they are totally experimental for now: diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index f33a0b3cf2..a32e111bd7 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -190,10 +190,10 @@ Here is an example with non-rigid motion estimation: peak_locations = localize_peaks(recording=recording, peaks=peaks, ...) # as above - from spikeinterface.sortingcomponents.motion_estimation import estimate_motion + from spikeinterface.sortingcomponents.motion import estimate_motion motion, temporal_bins, spatial_bins, extra_check = estimate_motion(recording=recording, peaks=peaks, peak_locations=peak_locations, - direction='y', bin_duration_s=10., bin_um=10., margin_um=0., + direction='y', bin_s=10., bin_um=10., margin_um=0., method='decentralized_registration', rigid=False, win_shape='gaussian', win_step_um=50., win_sigma_um=150., progress_bar=True, verbose=True) @@ -206,7 +206,7 @@ Motion interpolation The estimated motion can be used to interpolate traces, in other words, for drift correction. One possible way is to make an interpolation sample-by-sample to compensate for the motion. -The :py:class:`~spikeinterface.sortingcomponents.motion_interpolation.InterpolateMotionRecording` is a preprocessing +The :py:class:`~spikeinterface.sortingcomponents.motion.InterpolateMotionRecording` is a preprocessing step doing this. This preprocessing is *lazy*, so that interpolation is done on-the-fly. However, the class needs the "motion vector" as input, which requires a relatively long computation (peak detection, localization and motion estimation). @@ -216,7 +216,7 @@ Here is a short example that depends on the output of "Motion interpolation": .. code-block:: python - from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording + from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording recording_corrected = InterpolateMotionRecording(recording=recording_with_drift, motion=motion, temporal_bins=temporal_bins, spatial_bins=spatial_bins spatial_interpolation_method='kriging', diff --git a/doc/modules/widgets.rst b/doc/modules/widgets.rst index 4d69867d83..5bf0658e99 100644 --- a/doc/modules/widgets.rst +++ b/doc/modules/widgets.rst @@ -30,8 +30,8 @@ backends can be installed with: pip install spikeinterface[widgets] -matplotlib -^^^^^^^^^^ +Install matplotlib +^^^^^^^^^^^^^^^^^^ The :code:`matplotlib` backend (default) uses the :code:`matplotlib` package to generate static figures. @@ -41,8 +41,8 @@ To install it, run: pip install matplotlib -ipywidgets -^^^^^^^^^^ +Install ipywidgets +^^^^^^^^^^^^^^^^^^ The :code:`ipywidgets` backend allows users to interact with the plot, for example, by selecting units or scrolling through a time series. @@ -62,8 +62,8 @@ To enable interactive widgets in your notebook, add and run a cell with: .. _sorting_view: -sortingview -^^^^^^^^^^^ +Install sortingview +^^^^^^^^^^^^^^^^^^^ The :code:`sortingview` backend generates web-based and shareable links that can be viewed in the browser. @@ -89,8 +89,8 @@ Finally, if you wish to set up another cloud provider, follow the instruction fr `kachery-cloud `_ package ("Using your own storage bucket"). -ephyviewer -^^^^^^^^^^ +Install ephyviewer +^^^^^^^^^^^^^^^^^^ This backend is Qt based with PyQt5, PyQt6 or PySide6 support. Qt is sometimes tedious to install. diff --git a/doc/references.rst b/doc/references.rst index 5fbcbecb63..ce4672a9ca 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -58,7 +58,7 @@ important for your research: - :code:`rp_violation` [Llobet]_ - :code:`sd_ratio` [Pouzat]_ - :code:`snr` [Lemon]_ [Jackson]_ -- :code:`synchrony` [Grun]_ +- :code:`synchrony` [Grün]_ If you use the :code:`qualitymetrics.pca_metrics` module, i.e. you use the :code:`compute_pc_metrics()` method, please include the citations for the :code:`metric_names` that were particularly @@ -78,7 +78,7 @@ References .. [Buccino] `SpikeInterface, a unified framework for spike sorting. 2020. `_ -.. [Buzsaki] `The Log-Dynamic Brain: How Skewed Distributions Affect Network Operations. 2014. `_ +.. [Buzsáki] `The Log-Dynamic Brain: How Skewed Distributions Affect Network Operations. 2014. `_ .. [Chaure] `A novel and fully automatic spike-sorting implementation with variable number of features. 2018. `_ @@ -88,7 +88,7 @@ References .. [Garcia] `A Modular Implementation to Handle and Benchmark Drift Correction for High-Density Extracellular Recordings. 2024. `_ -.. [Grun] `Impact of higher-order correlations on coincidence distributions of massively parallel data. 2007. `_ +.. [Grün] `Impact of higher-order correlations on coincidence distributions of massively parallel data. 2007. `_ .. [Harris] `Temporal interaction between single spikes and complex spike bursts in hippocampal pyramidal cells. 2001. `_ diff --git a/doc/releases/0.100.0.rst b/doc/releases/0.100.0.rst index d39b5569da..6bd8a33989 100644 --- a/doc/releases/0.100.0.rst +++ b/doc/releases/0.100.0.rst @@ -24,8 +24,7 @@ Main changes: core: * Add `Templates` class (#1982) -* Use python methods instead of parsing and eleminate try-except in to_dict -(#2157) +* Use python methods instead of parsing and eleminate try-except in to_dict (#2157) * `WaveformExtractor.is_extension` --> `has_extension` (#2158) * Speed improvement to `get_empty_units()` (#2173) * Allow precomputing spike trains (#2175) diff --git a/examples/how_to/benchmark_with_hybrid_recordings.py b/examples/how_to/benchmark_with_hybrid_recordings.py index 5507ab7a7f..abf6a25ff5 100644 --- a/examples/how_to/benchmark_with_hybrid_recordings.py +++ b/examples/how_to/benchmark_with_hybrid_recordings.py @@ -32,7 +32,7 @@ import spikeinterface.generation as sgen import spikeinterface.widgets as sw -from spikeinterface.sortingcomponents.motion_estimation import estimate_motion +from spikeinterface.sortingcomponents.motion import estimate_motion import numpy as np import matplotlib.pyplot as plt diff --git a/examples/how_to/drift_with_lfp.py b/examples/how_to/drift_with_lfp.py new file mode 100644 index 0000000000..66a31bd6f2 --- /dev/null +++ b/examples/how_to/drift_with_lfp.py @@ -0,0 +1,112 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# formats: ipynb,py +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.2 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# # Estimate drift using the LFP traces +# +# Drift is a well known issue for long shank probes. Some datasets, especially from primates and humans, can experience very fast motion due to breathing and heart beats. In these cases, the standard motion estimation methods that use detected spikes as a basis for motion inference will fail, because there are not enough spikes to "follow" such fast drifts. +# +# Charlie Windolf and colleagues from the Paninski Lab at Columbia have developed a method to estimate the motion using the LFP signal: **DREDge**. (more details about the method in the paper [DREDge: robust motion correction for high-density extracellular recordings across species](https://doi.org/10.1101/2023.10.24.563768)). +# +# This method is particularly suited for the open dataset recorded at Massachusetts General Hospital by Angelique Paulk and colleagues in humans (more details in the [paper](https://doi.org/10.1038/s41593-021-00997-0)). The dataset can be dowloaed from [datadryad](https://datadryad.org/stash/dataset/doi:10.5061/dryad.d2547d840) and it contains recordings on human patients with a Neuropixels probe, some of which with very high and fast motion on the probe, which prevents accurate spike sorting without a proper and adequate motion correction +# +# The **DREDge** method has two options: **dredge_lfp** and **dredge_ap**, which have both been ported inside `SpikeInterface`. +# +# Here we will demonstrate the **dredge_lfp** method to estimate the fast and high drift on this recording. +# +# For each patient, the dataset contains two streams: +# +# * a highpass "action potential" (AP), sampled at 30kHz +# * a lowpass "local field" (LF) sampled at 2.5kHz +# +# For this demonstration, we will use the LF stream. + +# %matplotlib inline +# %load_ext autoreload +# %autoreload 2 + +# + +from pathlib import Path +import matplotlib.pyplot as plt + +import spikeinterface.full as si +from spikeinterface.sortingcomponents.motion import estimate_motion +# - + +# the dataset has been downloaded locally +base_folder = Path("/mnt/data/sam/DataSpikeSorting/") +np_data_drift = base_folder / 'human_neuropixel" / "Pt02" + +# ### Read the spikeglx file + +raw_rec = si.read_spikeglx(np_data_drift) +print(raw_rec) + +# ### Preprocessing +# +# Contrary to the **dredge_ap** approach, which needs detected peaks and peak locations, the **dredge_lfp** method is estimating the motion directly on traces. +# Importantly, the method requires some additional pre-processing steps: +# * `bandpass_filter`: to "focus" the signal on a particular band +# * `phase_shift`: to compensate for the sampling misalignement +# * `resample`: to further reduce the sampling fequency of the signal and speed up the computation. The sampling frequency of the estimated motion will be the same as the resampling frequency. Here we choose 250Hz, which corresponds to a sampling interval of 4ms. +# * `directional_derivative`: this optional step applies a second order derivative in the spatial dimension to enhance edges on the traces. +# This is not a general rules and need to be tested case by case. +# * `average_across_direction`: Neuropixels 1.0 probes have two contacts per depth. This steps averages them to obtain a unique virtual signal along the probe depth ("y" in `spikeinterface`). +# +# After appying this preprocessing chain, the motion can be estimated almost by eyes ont the traces plotted with the map mode. + +# + +lfprec = si.bandpass_filter( + raw_rec, + freq_min=0.5, + freq_max=250, + margin_ms=1500., + filter_order=3, + dtype="float32", + add_reflect_padding=True, +) +lfprec = si.phase_shift(lfprec) +lfprec = si.resample(lfprec, resample_rate=250, margin_ms=1000) + +lfprec = si.directional_derivative(lfprec, order=2, edge_order=1) +lfprec = si.average_across_direction(lfprec) + +print(lfprec) +# - + +# %matplotlib inline +si.plot_traces(lfprec, backend="matplotlib", mode="map", clim=(-0.05, 0.05), time_range=(400, 420)) + +# ### Run the method +# +# `estimate_motion()` is the generic function to estimate motion with multiple methods in `spikeinterface`. +# +# This function returns a `Motion` object and we can notice that the interval is exactly the same as downsampled signal. +# +# Here we use `rigid=True`, which means that we have one unqiue signal to describe the motion across the entire probe depth. + +motion = estimate_motion(lfprec, method='dredge_lfp', rigid=True, progress_bar=True) +motion + +# ### Plot the drift +# +# When plotting the drift, we can notice a very fast drift which corresponds to the heart rate. The slower oscillations can be attributed to the breathing signal. +# +# We can appreciate how the estimated motion signal matches the processed LFP traces plotted above. + +fig, ax = plt.subplots() +si.plot_motion(motion, mode='line', ax=ax) +ax.set_xlim(400, 420) +ax.set_ylim(800, 1300) diff --git a/examples/how_to/handle_drift.py b/examples/how_to/handle_drift.py index 79a7c899f5..ecf17a1b1f 100644 --- a/examples/how_to/handle_drift.py +++ b/examples/how_to/handle_drift.py @@ -167,7 +167,7 @@ def preprocess_chain(rec): # Case 1 is used before running a spike sorter and the case 2 is used here to display the results. # + -from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks +from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks for preset in some_presets: folder = base_folder / "motion_folder_dataset1" / preset diff --git a/examples/tutorials/README.rst b/examples/tutorials/README.rst index 0697f0ee60..3f7dd10ecc 100644 --- a/examples/tutorials/README.rst +++ b/examples/tutorials/README.rst @@ -3,7 +3,7 @@ Tutorials Longer form tutorials about using spikeinterface. Many of these are downloadable as notebooks or python scripts so that you can "code along" to the tutorials. -If you're new to SpikeInterface, we recommend trying out the :ref:`Quickstart tutorial` first. +If you're new to SpikeInterface, we recommend trying out the :ref:`get_started/quickstart:Quickstart tutorial` first. Updating from legacy -------------------- diff --git a/examples/tutorials/sortingcomponents/plot_1_estimate_motion.py b/examples/tutorials/sortingcomponents/plot_1_estimate_motion.py new file mode 100644 index 0000000000..87eaa4c51a --- /dev/null +++ b/examples/tutorials/sortingcomponents/plot_1_estimate_motion.py @@ -0,0 +1,103 @@ +""" +Motion estimation +================= + +SpikeInterface offers a very flexible framework to handle drift as a +preprocessing step. If you want to know more, please read the +:ref:`motion_correction` section of the documentation. + +Here a short example with a simulated drifting recording. + +""" + +# %% +import matplotlib.pyplot as plt + + +from spikeinterface.generation import generate_drifting_recording +from spikeinterface.preprocessing import correct_motion +from spikeinterface.widgets import plot_motion, plot_motion_info, plot_probe_map + +# %% +# First, let's simulate a drifting recording using the +# :code:`spikeinterface.generation module`. +# +# Here the simulated recording has a small zigzag motion along the 'y' axis of the probe. + +static_recording, drifting_recording, sorting = generate_drifting_recording( + num_units=200, + duration=300., + probe_name='Neuropixel-128', + generate_displacement_vector_kwargs=dict( + displacement_sampling_frequency=5.0, + drift_start_um=[0, 20], + drift_stop_um=[0, -20], + drift_step_um=1, + motion_list=[ + dict( + drift_mode="zigzag", + non_rigid_gradient=None, + t_start_drift=60.0, + t_end_drift=None, + period_s=200, + ), + ], + ), + seed=2205, +) + +plot_probe_map(drifting_recording) + +# %% +# Here we will use the high level function :code:`correct_motion()` +# +# Internally, this function is doing all steps of the motion detection: +# 1. **activity profile** : detect peaks and localize them along time and depth +# 2. **motion inference**: estimate the drift motion +# 3. **motion interpolation**: interpolate traces using the estimated motion +# +# All steps have an use several methods with many parameters. This is why we can use +# 'preset' which combine methods and related parameters. +# +# This function can take a while peak detection and localization is a slow process +# that need to go trought the entire traces + +recording_corrected, motion, motion_info = correct_motion( + drifting_recording, preset="nonrigid_fast_and_accurate", + output_motion=True, output_motion_info=True, + n_jobs=-1, progress_bar=True, +) + +# %% +# The function return a recording 'corrected' +# +# A new recording is return, this recording will interpolate motion corrected traces +# when calling get_traces() + +print(recording_corrected) + +# %% +# Optionally the function also return the `Motion` object itself +# + +print(motion) + +# %% +# This motion can be plotted, in our case the motion has been estimated as non-rigid +# so we can use the use the `mode='map'` to check the motion across depth. +# + +plot_motion(motion, mode='line') +plot_motion(motion, mode='map') + + +# %% +# The dict `motion_info` can be used for more plotting. +# Here we can appreciate of the two top axes the raster of peaks depth vs times before and +# after correction. + +fig = plt.figure() +plot_motion_info(motion_info, drifting_recording, amplitude_cmap="inferno", color_amplitude=True, figure=fig) +fig.axes[0].set_ylim(520, 620) +plt.show() +# %% diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 1de5ad68ac..72c0a2c2fe 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -46,8 +46,8 @@ class BaseExtractor: # these properties are skipped by default in copy_metadata _skip_properties = [] - installed = True installation_mesg = "" + installed = True def __init__(self, main_ids: Sequence) -> None: # store init kwargs for nested serialisation diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 996718dc42..42d5561547 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -75,7 +75,7 @@ class SIJsonEncoder(json.JSONEncoder): def default(self, obj): from spikeinterface.core.base import BaseExtractor - from spikeinterface.sortingcomponents.motion_utils import Motion + from spikeinterface.sortingcomponents.motion.motion_utils import Motion # Over-write behaviors for datetime object if isinstance(obj, datetime.datetime): diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 6ce94114c4..db57d028f7 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1677,6 +1677,7 @@ class InjectTemplatesRecording(BaseRecording): templates: np.ndarray[n_units, n_samples, n_channels] or np.ndarray[n_units, n_samples, n_oversampling] Array containing the templates to inject for all the units. Shape can be: + * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce sampling jitter. nbefore: list[int] | int | None, default: None @@ -2057,6 +2058,7 @@ def generate_ground_truth_recording( The templates of units. If None they are generated. Shape can be: + * (num_units, num_samples, num_channels): standard case * (num_units, num_samples, num_channels, upsample_factor): case with oversample template to introduce jitter. ms_before: float, default: 1.5 diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 5e1856e7cd..89e9e2cf0f 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -718,7 +718,7 @@ def save_as(self, format="memory", folder=None) -> "SortingAnalyzer": def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ - This method is equivalent to `save_as()`but with a subset of units. + This method is equivalent to `save_as()` but with a subset of units. Filters units by creating a new sorting analyzer object in a new folder. Extensions are also updated to filter the selected unit ids. @@ -876,11 +876,10 @@ def compute(self, input, save=True, extension_params=None, verbose=False, **kwar ---------- input : str or dict or list The extensions to compute, which can be passed as: - * a string: compute one extension. Additional parameters can be passed as key word arguments. * a dict: compute several extensions. The keys are the extension names and the values are dictiopnaries with the extension parameters. * a list: compute several extensions. The list contains the extension names. Additional parameters can be passed with the extension_params - argument. + argument. save : bool, default: True If True the extension is saved to disk (only if sorting analyzer format is not "memory") extension_params : dict or None, default: None @@ -909,10 +908,11 @@ def compute(self, input, save=True, extension_params=None, verbose=False, **kwar Compute two extensions with an input list specifying custom parameters for one (the other will use default parameters): - >>> analyzer.compute( - ["random_spikes", "waveforms"], - extension_params={"waveforms":{"ms_before":1.5, "ms_after", "2.5"}} - ) + >>> analyzer.compute(\ +["random_spikes", "waveforms"],\ +extension_params={"waveforms":{"ms_before":1.5, "ms_after": "2.5"}}\ +) + """ if isinstance(input, str): return self.compute_one_extension(extension_name=input, save=save, verbose=verbose, **kwargs) diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index 2e70d5ba41..b3f671ebf3 100644 --- a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -25,14 +25,30 @@ class AlphaOmegaRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ NeoRawIOClass = "AlphaOmegaRawIO" - def __init__(self, folder_path, lsx_files=None, stream_id="RAW", stream_name=None, all_annotations=False): + def __init__( + self, + folder_path, + lsx_files=None, + stream_id="RAW", + stream_name=None, + all_annotations: bool = False, + use_names_as_ids: bool = False, + ): neo_kwargs = self.map_to_neo_kwargs(folder_path, lsx_files) NeoBaseRecordingExtractor.__init__( - self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, ) self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()), lsx_files=lsx_files)) diff --git a/src/spikeinterface/extractors/neoextractors/axona.py b/src/spikeinterface/extractors/neoextractors/axona.py index e086cb5dde..adfdccddd9 100644 --- a/src/spikeinterface/extractors/neoextractors/axona.py +++ b/src/spikeinterface/extractors/neoextractors/axona.py @@ -19,13 +19,21 @@ class AxonaRecordingExtractor(NeoBaseRecordingExtractor): The file path to load the recordings from. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ NeoRawIOClass = "AxonaRawIO" - def __init__(self, file_path, all_annotations=False): + def __init__(self, file_path: str | Path, all_annotations: bool = False, use_names_as_ids: bool = False): neo_kwargs = self.map_to_neo_kwargs(file_path) - NeoBaseRecordingExtractor.__init__(self, all_annotations=all_annotations, **neo_kwargs) + NeoBaseRecordingExtractor.__init__( + self, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, + ) self._kwargs.update({"file_path": str(Path(file_path).absolute())}) @classmethod diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index e7b6199ea9..4a4600853c 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -29,6 +29,9 @@ class BiocamRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ NeoRawIOClass = "BiocamRawIO" @@ -40,11 +43,17 @@ def __init__( electrode_width=None, stream_id=None, stream_name=None, - all_annotations=False, + all_annotations: bool = False, + use_names_as_ids: bool = False, ): neo_kwargs = self.map_to_neo_kwargs(file_path) NeoBaseRecordingExtractor.__init__( - self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, ) # load probe from probeinterface diff --git a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py index 9bd2b05f24..8557c811b5 100644 --- a/src/spikeinterface/extractors/neoextractors/blackrock.py +++ b/src/spikeinterface/extractors/neoextractors/blackrock.py @@ -27,7 +27,8 @@ class BlackrockRecordingExtractor(NeoBaseRecordingExtractor): all_annotations : bool, default: False Load exhaustively all annotations from neo. use_names_as_ids : bool, default: False - If False, use default IDs inherited from Neo. If True, use channel names as IDs. + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ @@ -38,8 +39,8 @@ def __init__( file_path, stream_id=None, stream_name=None, - all_annotations=False, - use_names_as_ids=False, + all_annotations: bool = False, + use_names_as_ids: bool = False, ): neo_kwargs = self.map_to_neo_kwargs(file_path) neo_kwargs["load_nev"] = False # Avoid loading spikes release in neo 0.12.0 diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index 73a783ec5d..a42a2d75a5 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -25,14 +25,24 @@ class CedRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ NeoRawIOClass = "CedRawIO" - def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): + def __init__( + self, file_path, stream_id=None, stream_name=None, all_annotations: bool = False, use_names_as_ids: bool = False + ): neo_kwargs = self.map_to_neo_kwargs(file_path) NeoBaseRecordingExtractor.__init__( - self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, ) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()))) self.extra_requirements.append("neo[ced]") diff --git a/src/spikeinterface/extractors/neoextractors/edf.py b/src/spikeinterface/extractors/neoextractors/edf.py index 8369369922..5d36067d97 100644 --- a/src/spikeinterface/extractors/neoextractors/edf.py +++ b/src/spikeinterface/extractors/neoextractors/edf.py @@ -24,14 +24,29 @@ class EDFRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations: bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ NeoRawIOClass = "EDFRawIO" - def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): + def __init__( + self, + file_path, + stream_id=None, + stream_name=None, + all_annotations: bool = False, + use_names_as_ids: bool = False, + ): neo_kwargs = {"filename": str(file_path)} NeoBaseRecordingExtractor.__init__( - self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, ) self._kwargs.update({"file_path": str(Path(file_path).absolute())}) self.extra_requirements.append("neo[edf]") diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 34c8bf2eb5..f0a1894f25 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -28,7 +28,11 @@ class IntanRecordingExtractor(NeoBaseRecordingExtractor): check we perform is that timestamps are continuous. Setting this to True will ignore this check and set the attribute `discontinuous_timestamps` to True in the underlying neo object. use_names_as_ids : bool, default: False - If False, use default IDs inherited from Neo. If True, use channel names as IDs. + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + In Intan the ids provided by NeoRawIO are the hardware channel ids while the names are custom names given by + the user """ diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 6c72696e16..58110cf7aa 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -30,6 +30,9 @@ class MaxwellRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. rec_name : str, default: None When the file contains several recordings you need to specify the one you want to extract. (rec_name='rec0000'). @@ -50,6 +53,7 @@ def __init__( all_annotations=False, rec_name=None, install_maxwell_plugin=False, + use_names_as_ids: bool = False, ): if install_maxwell_plugin: self.install_maxwell_plugin() @@ -61,6 +65,7 @@ def __init__( stream_name=stream_name, block_index=block_index, all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, **neo_kwargs, ) diff --git a/src/spikeinterface/extractors/neoextractors/mcsraw.py b/src/spikeinterface/extractors/neoextractors/mcsraw.py index 307a6c1fba..a50c10907f 100644 --- a/src/spikeinterface/extractors/neoextractors/mcsraw.py +++ b/src/spikeinterface/extractors/neoextractors/mcsraw.py @@ -28,11 +28,22 @@ class MCSRawRecordingExtractor(NeoBaseRecordingExtractor): If there are several blocks, specify the block index you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ NeoRawIOClass = "RawMCSRawIO" - def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False): + def __init__( + self, + file_path, + stream_id=None, + stream_name=None, + block_index=None, + all_annotations=False, + use_names_as_ids: bool = False, + ): neo_kwargs = self.map_to_neo_kwargs(file_path) NeoBaseRecordingExtractor.__init__( self, @@ -40,6 +51,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None stream_name=stream_name, block_index=block_index, all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, **neo_kwargs, ) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()))) diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index 21a597029b..5a8dba3c16 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -38,13 +38,21 @@ class MEArecRecordingExtractor(NeoBaseRecordingExtractor): The file path to load the recordings from. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ NeoRawIOClass = "MEArecRawIO" - def __init__(self, file_path: Union[str, Path], all_annotations: bool = False): + def __init__(self, file_path: Union[str, Path], all_annotations: bool = False, use_names_as_ids: bool = False): neo_kwargs = self.map_to_neo_kwargs(file_path) - NeoBaseRecordingExtractor.__init__(self, all_annotations=all_annotations, **neo_kwargs) + NeoBaseRecordingExtractor.__init__( + self, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, + ) self.extra_requirements.append("mearec") diff --git a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py index c30c6b94f0..a916d140fb 100644 --- a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py +++ b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py @@ -206,8 +206,8 @@ def __init__( if stream_id is None and stream_name is None: if stream_channels.size > 1: raise ValueError( - f"This reader have several streams: \nNames: {stream_names}\nIDs: {stream_ids}. " - f"Specify it with the 'stream_name' or 'stream_id' arguments" + f"This reader have several streams: \nNames: {stream_names}\nIDs: {stream_ids}. \n" + f"Specify it from the options above with the 'stream_name' or 'stream_id' arguments" ) else: stream_id = stream_ids[0] @@ -276,7 +276,7 @@ def __init__( self.set_property("gain_to_uV", final_gains) self.set_property("offset_to_uV", final_offsets) - if not use_names_as_ids and not all_annotations: + if not use_names_as_ids: self.set_property("channel_names", signal_channels["name"]) if all_annotations: @@ -287,13 +287,26 @@ def __init__( seg_ann = block_ann["segments"][0] sig_ann = seg_ann["signals"][self.stream_index] - # scalar annotations - for k, v in sig_ann.items(): - if not k.startswith("__"): - self.set_annotation(k, v) + scalar_annotations = {name: value for name, value in sig_ann.items() if not name.startswith("__")} + + # name in neo corresponds to stream name + # We don't propagate the name as an annotation because that has a differnt meaning on spikeinterface + stream_name = scalar_annotations.pop("name", None) + if stream_name: + self.set_annotation(annotation_key="stream_name", value=stream_name) + for annotation_key, value in scalar_annotations.items(): + self.set_annotation(annotation_key=annotation_key, value=value) + + array_annotations = sig_ann["__array_annotations__"] + # We do not add this because is confusing for the user to have this repeated + array_annotations.pop("channel_ids", None) + # This is duplicated when using channel_names as ids + if use_names_as_ids: + array_annotations.pop("channel_names", None) + # vector array_annotations are channel properties - for k, values in sig_ann["__array_annotations__"].items(): - self.set_property(k, values) + for key, values in array_annotations.items(): + self.set_property(key=key, values=values) nseg = self.neo_reader.segment_count(block_index=self.block_index) for segment_index in range(nseg): diff --git a/src/spikeinterface/extractors/neoextractors/neuralynx.py b/src/spikeinterface/extractors/neoextractors/neuralynx.py index 98f4a7c2ff..5c70028071 100644 --- a/src/spikeinterface/extractors/neoextractors/neuralynx.py +++ b/src/spikeinterface/extractors/neoextractors/neuralynx.py @@ -29,6 +29,9 @@ class NeuralynxRecordingExtractor(NeoBaseRecordingExtractor): exclude_filename : list[str], default: None List of filename to exclude from the loading. For example, use `exclude_filename=["events.nev"]` to skip loading the event file. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. strict_gap_mode : bool, default: False See neo documentation. Detect gaps using strict mode or not. @@ -44,16 +47,22 @@ class NeuralynxRecordingExtractor(NeoBaseRecordingExtractor): def __init__( self, - folder_path, + folder_path: str | Path, stream_id=None, stream_name=None, all_annotations=False, exclude_filename=None, strict_gap_mode=False, + use_names_as_ids: bool = False, ): neo_kwargs = self.map_to_neo_kwargs(folder_path, exclude_filename, strict_gap_mode) NeoBaseRecordingExtractor.__init__( - self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, ) self._kwargs.update( dict(folder_path=str(Path(folder_path).absolute()), exclude_filename=exclude_filename), diff --git a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py index ac569c0df0..8afc75f773 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py +++ b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py @@ -45,14 +45,24 @@ class NeuroExplorerRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ NeoRawIOClass = "NeuroExplorerRawIO" - def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): + def __init__( + self, file_path, stream_id=None, stream_name=None, all_annotations: bool = False, use_names_as_ids: bool = False + ): neo_kwargs = {"filename": str(file_path)} NeoBaseRecordingExtractor.__init__( - self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, ) self._kwargs.update({"file_path": str(Path(file_path).absolute())}) self.extra_requirements.append("neo[edf]") diff --git a/src/spikeinterface/extractors/neoextractors/neuroscope.py b/src/spikeinterface/extractors/neoextractors/neuroscope.py index 6c6f1d4bea..70a110eced 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroscope.py +++ b/src/spikeinterface/extractors/neoextractors/neuroscope.py @@ -35,15 +35,31 @@ class NeuroScopeRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ NeoRawIOClass = "NeuroScopeRawIO" - def __init__(self, file_path, xml_file_path=None, stream_id=None, stream_name=None, all_annotations=False): + def __init__( + self, + file_path, + xml_file_path=None, + stream_id=None, + stream_name: bool = None, + all_annotations: bool = False, + use_names_as_ids: bool = False, + ): neo_kwargs = self.map_to_neo_kwargs(file_path, xml_file_path) NeoBaseRecordingExtractor.__init__( - self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, ) if xml_file_path is not None: xml_file_path = str(Path(xml_file_path).absolute()) diff --git a/src/spikeinterface/extractors/neoextractors/nix.py b/src/spikeinterface/extractors/neoextractors/nix.py index b869936fa3..baae573250 100644 --- a/src/spikeinterface/extractors/neoextractors/nix.py +++ b/src/spikeinterface/extractors/neoextractors/nix.py @@ -25,11 +25,22 @@ class NixRecordingExtractor(NeoBaseRecordingExtractor): If there are several blocks, specify the block index you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ NeoRawIOClass = "NIXRawIO" - def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False): + def __init__( + self, + file_path, + stream_id=None, + stream_name=None, + block_index=None, + all_annotations: bool = False, + use_names_as_ids: bool = False, + ): neo_kwargs = self.map_to_neo_kwargs(file_path) NeoBaseRecordingExtractor.__init__( self, @@ -37,6 +48,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None stream_name=stream_name, block_index=block_index, all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, **neo_kwargs, ) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), stream_id=stream_id)) diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 04c25998f0..24bc7591e4 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -59,6 +59,9 @@ class OpenEphysLegacyRecordingExtractor(NeoBaseRecordingExtractor): If there are several blocks (experiments), specify the block index you want to load all_annotations : bool, default: False Load exhaustively all annotation from neo + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. ignore_timestamps_errors : None Deprecated keyword argument. This is now ignored. neo.OpenEphysRawIO is now handling gaps directly but makes the read slower. @@ -72,8 +75,9 @@ def __init__( stream_id=None, stream_name=None, block_index=None, - all_annotations=False, - ignore_timestamps_errors=None, + all_annotations: bool = False, + use_names_as_ids: bool = False, + ignore_timestamps_errors: bool = None, ): if ignore_timestamps_errors is not None: warnings.warn( @@ -88,6 +92,7 @@ def __init__( stream_name=stream_name, block_index=block_index, all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, **neo_kwargs, ) self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()))) diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index 9c2586dd5a..0adddc2439 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -15,7 +15,7 @@ class PlexonRecordingExtractor(NeoBaseRecordingExtractor): Parameters ---------- - file_path : str + file_path : str | Path The file path to load the recordings from. stream_id : str, default: None If there are several streams, specify the stream id you want to load. @@ -23,14 +23,33 @@ class PlexonRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: True + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + Example for wideband signals: + names: ["WB01", "WB02", "WB03", "WB04"] + ids: ["0" , "1", "2", "3"] """ NeoRawIOClass = "PlexonRawIO" - def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): + def __init__( + self, + file_path: str | Path, + stream_id=None, + stream_name=None, + all_annotations: bool = False, + use_names_as_ids: bool = True, + ): neo_kwargs = self.map_to_neo_kwargs(file_path) NeoBaseRecordingExtractor.__init__( - self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, ) self._kwargs.update({"file_path": str(Path(file_path).resolve())}) diff --git a/src/spikeinterface/extractors/neoextractors/spike2.py b/src/spikeinterface/extractors/neoextractors/spike2.py index cbc1db3f74..57fd4dbad7 100644 --- a/src/spikeinterface/extractors/neoextractors/spike2.py +++ b/src/spikeinterface/extractors/neoextractors/spike2.py @@ -24,14 +24,24 @@ class Spike2RecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ NeoRawIOClass = "Spike2RawIO" - def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): + def __init__( + self, file_path, stream_id=None, stream_name=None, all_annotations=False, use_names_as_ids: bool = False + ): neo_kwargs = self.map_to_neo_kwargs(file_path) NeoBaseRecordingExtractor.__init__( - self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, ) self._kwargs.update({"file_path": str(Path(file_path).absolute())}) self.extra_requirements.append("sonpy") diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index e7c31b8afa..89c457a573 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -26,14 +26,29 @@ class SpikeGadgetsRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ NeoRawIOClass = "SpikeGadgetsRawIO" - def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): + def __init__( + self, + file_path, + stream_id=None, + stream_name=None, + all_annotations: bool = False, + use_names_as_ids: bool = False, + ): neo_kwargs = self.map_to_neo_kwargs(file_path) NeoBaseRecordingExtractor.__init__( - self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, ) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), stream_id=stream_id)) diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index 10a1f78265..cfe20bbfa6 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -38,14 +38,30 @@ class SpikeGLXRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. """ NeoRawIOClass = "SpikeGLXRawIO" - def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_name=None, all_annotations=False): + def __init__( + self, + folder_path, + load_sync_channel=False, + stream_id=None, + stream_name=None, + all_annotations: bool = False, + use_names_as_ids: bool = False, + ): neo_kwargs = self.map_to_neo_kwargs(folder_path, load_sync_channel=load_sync_channel) NeoBaseRecordingExtractor.__init__( - self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, ) # open the corresponding stream probe for LF and AP diff --git a/src/spikeinterface/extractors/neoextractors/tdt.py b/src/spikeinterface/extractors/neoextractors/tdt.py index a1298dece7..803b21ab23 100644 --- a/src/spikeinterface/extractors/neoextractors/tdt.py +++ b/src/spikeinterface/extractors/neoextractors/tdt.py @@ -23,13 +23,24 @@ class TdtRecordingExtractor(NeoBaseRecordingExtractor): If there are several streams, specify the stream name you want to load. all_annotations : bool, default: False Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. block_index : int, default: None If there are several blocks (experiments), specify the block index you want to load """ NeoRawIOClass = "TdtRawIO" - def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=None, all_annotations=False): + def __init__( + self, + folder_path, + stream_id=None, + stream_name=None, + block_index=None, + all_annotations: bool = False, + use_names_as_ids: bool = False, + ): neo_kwargs = self.map_to_neo_kwargs(folder_path) NeoBaseRecordingExtractor.__init__( self, @@ -37,6 +48,7 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No stream_name=stream_name, block_index=block_index, all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, **neo_kwargs, ) self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()))) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 7164afeac6..d797e64910 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -497,12 +497,13 @@ class NwbRecordingExtractor(BaseRecording, _BaseNWBExtractor): >>> from dandi.dandiapi import DandiAPIClient >>> >>> # get s3 path - >>> dandiset_id, filepath = "101116", "sub-001/sub-001_ecephys.nwb" - >>> with DandiAPIClient("https://api-staging.dandiarchive.org/api") as client: - >>> asset = client.get_dandiset(dandiset_id, "draft").get_asset_by_path(filepath) + >>> dandiset_id = "001054" + >>> filepath = "sub-Dory/sub-Dory_ses-2020-09-14-004_ecephys.nwb" + >>> with DandiAPIClient() as client: + >>> asset = client.get_dandiset(dandiset_id).get_asset_by_path(filepath) >>> s3_url = asset.get_content_url(follow_redirects=1, strip_query=True) >>> - >>> rec = NwbRecordingExtractor(s3_url, stream_mode="fsspec", stream_cache_path="cache") + >>> rec = NwbRecordingExtractor(s3_url, stream_mode="remfile") """ installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n" diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index a0e8ece37e..b439c57c52 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -339,11 +339,13 @@ def generate_drifting_recording( Same for both recordings. extra_infos: If extra_outputs=True, then return also a dict that contain various information like: + * displacement_vectors * displacement_sampling_frequency * unit_locations * displacement_unit_factor * unit_displacements + This can be helpfull for motion benchmark. """ diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index a57e090f5f..2806754c9d 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -15,7 +15,7 @@ ) from spikeinterface.core.template_tools import get_template_extremum_channel -from spikeinterface.sortingcomponents.motion_utils import Motion +from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.generation.drift_tools import ( InjectDriftingTemplatesRecording, @@ -400,6 +400,7 @@ def generate_hybrid_recording( num_segments = recording.get_num_segments() dtype = recording.dtype durations = np.array([recording.get_duration(segment_index) for segment_index in range(num_segments)]) + num_samples = np.array([recording.get_num_samples(segment_index) for segment_index in range(num_segments)]) channel_locations = probe.contact_positions assert ( @@ -548,7 +549,7 @@ def generate_hybrid_recording( displacement_vectors=displacement_vectors, displacement_sampling_frequency=displacement_sampling_frequency, displacement_unit_factor=displacement_unit_factor, - num_samples=(np.array(durations) * sampling_frequency).astype("int64"), + num_samples=num_samples.astype("int64"), amplitude_factor=amplitude_factor, ) diff --git a/src/spikeinterface/generation/tests/test_hybrid_tools.py b/src/spikeinterface/generation/tests/test_hybrid_tools.py index d31a0ec81d..bdcd8dbb8f 100644 --- a/src/spikeinterface/generation/tests/test_hybrid_tools.py +++ b/src/spikeinterface/generation/tests/test_hybrid_tools.py @@ -7,7 +7,7 @@ generate_templates, generate_unit_locations, ) -from spikeinterface.preprocessing.motion import correct_motion, load_motion_info +from spikeinterface.preprocessing.motion import correct_motion from spikeinterface.generation.hybrid_tools import ( estimate_templates_from_recording, generate_hybrid_recording, @@ -35,8 +35,10 @@ def test_generate_hybrid_with_sorting(): def test_generate_hybrid_motion(): - rec, _ = generate_ground_truth_recording(sampling_frequency=20000, durations=[10], seed=0) - _, motion_info = correct_motion(rec, output_motion_info=True) + rec, _ = generate_ground_truth_recording(sampling_frequency=20000, durations=[10], num_channels=16, seed=0) + _, motion_info = correct_motion( + rec, output_motion_info=True, estimate_motion_kwargs={"win_step_um": 20, "win_scale_um": 20} + ) motion = motion_info["motion"] hybrid, sorting_hybrid = generate_hybrid_recording(rec, motion=motion, seed=0) assert rec.get_num_channels() == hybrid.get_num_channels() diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 8eb375e90b..6252c0582b 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -493,7 +493,7 @@ def _transform_waveforms(self, spikes, waveforms, pca_model, progress_bar): try: proj = pca_model.transform(wfs[:, :, wf_ind]) pca_projection[:, :, wf_ind][spike_mask, :] = proj - except NotFittedError as e: + except: # this could happen if len(wfs) is less then n_comp for a channel project_on_non_fitted = True if project_on_non_fitted: diff --git a/src/spikeinterface/preprocessing/align_snippets.py b/src/spikeinterface/preprocessing/align_snippets.py index c37f8f2a97..02457d24a7 100644 --- a/src/spikeinterface/preprocessing/align_snippets.py +++ b/src/spikeinterface/preprocessing/align_snippets.py @@ -8,9 +8,7 @@ class AlignSnippets(BaseSnippets): - installed = True # check at class level if installed or not installation_mesg = "" # err - name = "align_snippets" def __init__(self, snippets, new_nbefore, new_nafter, mode="main_peak", interpolate=1, det_sign=0): assert isinstance(snippets, BaseSnippets), "'snippets' must be a SnippetsExtractor" diff --git a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index 53f0d54147..ee2083d3c4 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -7,8 +7,6 @@ class AverageAcrossDirectionRecording(BaseRecording): - name = "average_across_direction" - installed = True def __init__( self, diff --git a/src/spikeinterface/preprocessing/basepreprocessor.py b/src/spikeinterface/preprocessing/basepreprocessor.py index 106f5e2d92..3b73b306bb 100644 --- a/src/spikeinterface/preprocessing/basepreprocessor.py +++ b/src/spikeinterface/preprocessing/basepreprocessor.py @@ -4,7 +4,6 @@ class BasePreprocessor(BaseRecording): - installed = True # check at class level if installed or not installation_mesg = "" # err def __init__(self, recording, sampling_frequency=None, channel_ids=None, dtype=None): diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index 78557c70d0..47a4a20d21 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -30,8 +30,6 @@ class ClipRecording(BasePreprocessor): The clipped traces recording extractor object """ - name = "clip" - def __init__(self, recording, a_min=None, a_max=None): value_min = a_min value_max = a_max @@ -86,8 +84,6 @@ class BlankSaturationRecording(BasePreprocessor): """ - name = "blank_staturation" - def __init__( self, recording, diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 93d0448ef4..b9bc1b4b53 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -76,8 +76,6 @@ class CommonReferenceRecording(BasePreprocessor): """ - name = "common_reference" - def __init__( self, recording: BaseRecording, diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index aa5c600182..334ebb02d2 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -41,8 +41,6 @@ class DecimateRecording(BasePreprocessor): """ - name = "decimate" - def __init__( self, recording, diff --git a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py index 31ebb90831..f58bc5b578 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py @@ -49,8 +49,6 @@ class DeepInterpolatedRecording(BasePreprocessor): The deepinterpolated recording extractor object """ - name = "deepinterpolate" - def __init__( self, recording, diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index f08f6404da..a112774fb1 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -25,9 +25,6 @@ class DepthOrderRecording(ChannelSliceRecording): If flip is True then the order is upper first. """ - name = "depth_order" - installed = True - def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y"), flip=False): order_f, order_r = order_channels_by_depth( parent_recording, channel_ids=channel_ids, dimensions=dimensions, flip=flip diff --git a/src/spikeinterface/preprocessing/directional_derivative.py b/src/spikeinterface/preprocessing/directional_derivative.py index f8aeac05fc..3a6a480f59 100644 --- a/src/spikeinterface/preprocessing/directional_derivative.py +++ b/src/spikeinterface/preprocessing/directional_derivative.py @@ -7,8 +7,6 @@ class DirectionalDerivativeRecording(BasePreprocessor): - name = "directional_derivative" - installed = True def __init__( self, diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 93462ac5d8..54c5ab2b2d 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -63,8 +63,6 @@ class FilterRecording(BasePreprocessor): The filtered recording extractor object """ - name = "filter" - def __init__( self, recording, @@ -193,8 +191,6 @@ class BandpassFilterRecording(FilterRecording): The bandpass-filtered recording extractor object """ - name = "bandpass_filter" - def __init__(self, recording, freq_min=300.0, freq_max=6000.0, margin_ms=5.0, dtype=None, **filter_kwargs): FilterRecording.__init__( self, recording, band=[freq_min, freq_max], margin_ms=margin_ms, dtype=dtype, **filter_kwargs @@ -228,8 +224,6 @@ class HighpassFilterRecording(FilterRecording): The highpass-filtered recording extractor object """ - name = "highpass_filter" - def __init__(self, recording, freq_min=300.0, margin_ms=5.0, dtype=None, **filter_kwargs): FilterRecording.__init__( self, recording, band=freq_min, margin_ms=margin_ms, dtype=dtype, btype="highpass", **filter_kwargs @@ -260,8 +254,6 @@ class NotchFilterRecording(BasePreprocessor): The notch-filtered recording extractor object """ - name = "notch_filter" - def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): # coeef is 'ba' type fn = 0.5 * float(recording.get_sampling_frequency()) diff --git a/src/spikeinterface/preprocessing/filter_gaussian.py b/src/spikeinterface/preprocessing/filter_gaussian.py index 1db7d45bd8..b16df9be69 100644 --- a/src/spikeinterface/preprocessing/filter_gaussian.py +++ b/src/spikeinterface/preprocessing/filter_gaussian.py @@ -40,8 +40,6 @@ class GaussianFilterRecording(BasePreprocessor): The filtered recording extractor object. """ - name = "gaussian_filter" - def __init__( self, recording: BaseRecording, freq_min: float = 300.0, freq_max: float = 5000.0, margin_sd: float = 5.0 ): diff --git a/src/spikeinterface/preprocessing/filter_opencl.py b/src/spikeinterface/preprocessing/filter_opencl.py index 881ca26a07..903fef0b6e 100644 --- a/src/spikeinterface/preprocessing/filter_opencl.py +++ b/src/spikeinterface/preprocessing/filter_opencl.py @@ -39,8 +39,6 @@ class FilterOpenCLRecording(BasePreprocessor): """ - name = "filter" - def __init__( self, recording, diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index c0bf869317..86836f262b 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -64,8 +64,6 @@ class HighpassSpatialFilterRecording(BasePreprocessor): https://www.internationalbrainlab.com/repro-ephys """ - name = "highpass_spatial_filter" - def __init__( self, recording, diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index d60c9b27dd..508868e0bb 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -41,8 +41,6 @@ class InterpolateBadChannelsRecording(BasePreprocessor): The recording object with interpolated bad channels """ - name = "interpolate_bad_channels" - def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=None): BasePreprocessor.__init__(self, recording) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 57fe609e91..8e9911b47e 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -19,8 +19,8 @@ method="locally_exclusive", peak_sign="neg", detect_threshold=8.0, - exclude_sweep_ms=0.1, - radius_um=50, + exclude_sweep_ms=0.8, + radius_um=80.0, ), "select_kwargs": dict(), "localize_peaks_kwargs": dict( @@ -35,16 +35,13 @@ "estimate_motion_kwargs": dict( method="decentralized", direction="y", - bin_duration_s=2.0, + bin_s=1.0, rigid=False, bin_um=5.0, - margin_um=0.0, - # win_shape="gaussian", - # win_step_um=50.0, - # win_sigma_um=150.0, + hist_margin_um=20.0, win_shape="gaussian", - win_step_um=100.0, - win_sigma_um=200.0, + win_step_um=200.0, + win_scale_um=300.0, histogram_depth_smooth_um=5.0, histogram_time_smooth_s=None, pairwise_displacement_method="conv", @@ -78,13 +75,14 @@ method="locally_exclusive", peak_sign="neg", detect_threshold=8.0, - exclude_sweep_ms=0.5, - radius_um=50, + exclude_sweep_ms=0.8, + radius_um=80.0, ), "select_kwargs": dict(), "localize_peaks_kwargs": dict( method="grid_convolution", - radius_um=40.0, + # radius_um=40.0, + radius_um=80.0, upsampling_um=5.0, sigma_ms=0.25, margin_um=30.0, @@ -94,16 +92,14 @@ "estimate_motion_kwargs": dict( method="decentralized", direction="y", - bin_duration_s=2.0, + bin_s=2.0, rigid=False, bin_um=5.0, - margin_um=0.0, - # win_shape="gaussian", - # win_step_um=50.0, - # win_sigma_um=150.0, + hist_margin_um=0.0, win_shape="gaussian", win_step_um=100.0, - win_sigma_um=200.0, + win_scale_um=200.0, + win_margin_um=None, histogram_depth_smooth_um=5.0, histogram_time_smooth_s=None, pairwise_displacement_method="conv", @@ -149,7 +145,7 @@ ), "estimate_motion_kwargs": dict( method="decentralized", - bin_duration_s=10.0, + bin_s=10.0, rigid=True, ), "interpolate_motion_kwargs": dict( @@ -179,11 +175,11 @@ ), "estimate_motion_kwargs": dict( method="iterative_template", - bin_duration_s=2.0, + bin_s=2.0, rigid=False, win_step_um=50.0, - win_sigma_um=150.0, - margin_um=0, + win_scale_um=150.0, + hist_margin_um=0, win_shape="rect", ), "interpolate_motion_kwargs": dict( @@ -205,6 +201,7 @@ def correct_motion( recording, preset="nonrigid_accurate", folder=None, + output_motion=False, output_motion_info=False, overwrite=False, detect_kwargs={}, @@ -241,8 +238,8 @@ def correct_motion( * :py:func:`~spikeinterface.sortingcomponents.peak_detection.detect_peaks` * :py:func:`~spikeinterface.sortingcomponents.peak_selection.select_peaks` * :py:func:`~spikeinterface.sortingcomponents.peak_localization.localize_peaks` - * :py:func:`~spikeinterface.sortingcomponents.motion_estimation.estimate_motion` - * :py:func:`~spikeinterface.sortingcomponents.motion_interpolation.interpolate_motion` + * :py:func:`~spikeinterface.sortingcomponents.motion.motion.estimate_motion` + * :py:func:`~spikeinterface.sortingcomponents.motion.motion.interpolate_motion` Possible presets : {} @@ -255,6 +252,8 @@ def correct_motion( The preset name folder : Path str or None, default: None If not None then intermediate motion info are saved into a folder + output_motion : bool, default: False + It True, the function returns a `motion` object. output_motion_info : bool, default: False If True, then the function returns a `motion_info` dictionary that contains variables to check intermediate steps (motion_histogram, non_rigid_windows, pairwise_displacement) @@ -279,15 +278,17 @@ def correct_motion( ------- recording_corrected : Recording The motion corrected recording + motion : Motion + Optional output if `output_motion=True`. motion_info : dict - Optional output if `output_motion_info=True`. The key "motion" holds the Motion object. + Optional output if `output_motion_info=True`. This dict contains several variable for + for plotting. See `plot_motion_info()` """ # local import are important because "sortingcomponents" is not important by default from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks, localize_peak_methods - from spikeinterface.sortingcomponents.motion_estimation import estimate_motion - from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording + from spikeinterface.sortingcomponents.motion import estimate_motion, InterpolateMotionRecording from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline # get preset params and update if necessary @@ -395,11 +396,16 @@ def correct_motion( if folder is not None: save_motion_info(motion_info, folder, overwrite=overwrite) - if output_motion_info: - return recording_corrected, motion_info - else: + if not output_motion and not output_motion_info: return recording_corrected + out = (recording_corrected,) + if output_motion: + out += (motion,) + if output_motion_info: + out += (motion_info,) + return out + _doc_presets = "\n" for k, v in motion_options_preset.items(): @@ -431,7 +437,7 @@ def save_motion_info(motion_info, folder, overwrite=False): def load_motion_info(folder): - from spikeinterface.sortingcomponents.motion_utils import Motion + from spikeinterface.sortingcomponents.motion import Motion folder = Path(folder) diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index e537be4694..d464c95f4f 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -68,8 +68,6 @@ class NormalizeByQuantileRecording(BasePreprocessor): The rescaled traces recording extractor object """ - name = "normalize_by_quantile" - def __init__( self, recording, @@ -145,8 +143,6 @@ class ScaleRecording(BasePreprocessor): The transformed traces recording extractor object """ - name = "scale" - def __init__(self, recording, gain=1.0, offset=0.0, dtype="float32"): if dtype is None: dtype = recording.get_dtype() @@ -204,8 +200,6 @@ class CenterRecording(BasePreprocessor): The centered traces recording extractor object """ - name = "center" - def __init__(self, recording, mode="median", dtype="float32", **random_chunk_kwargs): assert mode in ("median", "mean") random_data = get_random_data_chunks(recording, **random_chunk_kwargs) @@ -261,8 +255,6 @@ class ZScoreRecording(BasePreprocessor): The centered traces recording extractor object """ - name = "zscore" - def __init__( self, recording, diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 5d483b3ce2..664964fcf2 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -41,8 +41,6 @@ class PhaseShiftRecording(BasePreprocessor): The phase shifted recording object """ - name = "phase_shift" - def __init__(self, recording, margin_ms=40.0, inter_sample_shift=None, dtype=None): if inter_sample_shift is None: assert "inter_sample_shift" in recording.get_property_keys(), "'inter_sample_shift' is not a property!" diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index 8f3729b49b..149c6eb458 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -80,5 +80,4 @@ UnsignedToSignedRecording, ] -installed_preprocessers_list = [pp for pp in preprocessers_full_list if pp.installed] preprocesser_dict = {pp_class.name: pp_class for pp_class in preprocessers_full_list} diff --git a/src/spikeinterface/preprocessing/rectify.py b/src/spikeinterface/preprocessing/rectify.py index 666e0babfd..aea866452b 100644 --- a/src/spikeinterface/preprocessing/rectify.py +++ b/src/spikeinterface/preprocessing/rectify.py @@ -8,7 +8,6 @@ class RectifyRecording(BasePreprocessor): - name = "rectify" def __init__(self, recording): BasePreprocessor.__init__(self, recording) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index d2aef6ba3a..aa1746df25 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -91,8 +91,6 @@ class RemoveArtifactsRecording(BasePreprocessor): The recording extractor after artifact removal """ - name = "remove_artifacts" - def __init__( self, recording, diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py index f8324817d4..f076646fdb 100644 --- a/src/spikeinterface/preprocessing/resample.py +++ b/src/spikeinterface/preprocessing/resample.py @@ -42,8 +42,6 @@ class ResampleRecording(BasePreprocessor): """ - name = "resample" - def __init__( self, recording, diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 3758d29554..8f38f01469 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -44,8 +44,6 @@ class SilencedPeriodsRecording(BasePreprocessor): The recording extractor after silencing some periods """ - name = "silence_periods" - def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, seed=None, **random_chunk_kwargs): available_modes = ("zeros", "noise") num_seg = recording.get_num_segments() diff --git a/src/spikeinterface/preprocessing/tests/test_motion.py b/src/spikeinterface/preprocessing/tests/test_motion.py index baa7235263..a1ad3766a9 100644 --- a/src/spikeinterface/preprocessing/tests/test_motion.py +++ b/src/spikeinterface/preprocessing/tests/test_motion.py @@ -13,7 +13,7 @@ def test_estimate_and_correct_motion(create_cache_folder): if folder.exists(): shutil.rmtree(folder) - rec_corrected = correct_motion(rec, folder=folder) + rec_corrected = correct_motion(rec, folder=folder, estimate_motion_kwargs={"win_step_um": 50, "win_scale_um": 100}) print(rec_corrected) # test reloading motion info diff --git a/src/spikeinterface/preprocessing/unsigned_to_signed.py b/src/spikeinterface/preprocessing/unsigned_to_signed.py index b221fd7bed..244fab1bd9 100644 --- a/src/spikeinterface/preprocessing/unsigned_to_signed.py +++ b/src/spikeinterface/preprocessing/unsigned_to_signed.py @@ -20,8 +20,6 @@ class UnsignedToSignedRecording(BasePreprocessor): For example, a `bit_depth` of 12 will correct for an offset of `2**11` """ - name = "unsigned_to_signed" - def __init__( self, recording, diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 96cf5e028f..195969ff79 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -55,8 +55,6 @@ class WhitenRecording(BasePreprocessor): The whitened recording extractor """ - name = "whiten" - def __init__( self, recording, diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 0b2ff9449f..ab1c90dfd9 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -136,8 +136,6 @@ def get_num_samples(self): class ZeroChannelPaddedRecording(BaseRecording): - name = "zero_channel_pad" - installed = True def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: Union[list, None] = None): """Pads a recording with channels that contain only zero. diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 433c04d248..24165da5b3 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -534,7 +534,7 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): References ---------- - Based on concepts described in [Gruen]_ + Based on concepts described in [Grün]_ This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ """ @@ -581,7 +581,7 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ References ---------- - Based on concepts described in [Gruen]_ + Based on concepts described in [Grün]_ This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ """ assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1" diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 2915cee8ec..fa1940c2ba 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -17,7 +17,6 @@ from ..core import get_random_data_chunks, compute_sparsity from ..core.template_tools import get_template_extremum_channel - _possible_pc_metric_names = [ "isolation_distance", "l_ratio", @@ -90,7 +89,7 @@ def compute_pc_metrics( sorting = sorting_analyzer.sorting if metric_names is None: - metric_names = _possible_pc_metric_names + metric_names = _possible_pc_metric_names.copy() if qm_params is None: qm_params = _default_params @@ -110,8 +109,13 @@ def compute_pc_metrics( if "nn_isolation" in metric_names: pc_metrics["nn_unit_id"] = {} + possible_nn_metrics = ["nn_isolation", "nn_noise_overlap"] + + nn_metrics = list(set(metric_names).intersection(possible_nn_metrics)) + non_nn_metrics = list(set(metric_names).difference(possible_nn_metrics)) + # Compute nspikes and firing rate outside of main loop for speed - if any([n in metric_names for n in ["nn_isolation", "nn_noise_overlap"]]): + if nn_metrics: n_spikes_all_units = compute_num_spikes(sorting_analyzer, unit_ids=unit_ids) fr_all_units = compute_firing_rates(sorting_analyzer, unit_ids=unit_ids) else: @@ -120,9 +124,6 @@ def compute_pc_metrics( run_in_parallel = n_jobs > 1 - if run_in_parallel: - parallel_functions = [] - # this get dense projection for selected unit_ids dense_projections, spike_unit_indices = pca_ext.get_some_projections(channel_ids=None, unit_ids=unit_ids) all_labels = sorting.unit_ids[spike_unit_indices] @@ -146,7 +147,7 @@ def compute_pc_metrics( func_args = ( pcs_flat, labels, - metric_names, + non_nn_metrics, unit_id, unit_ids, qm_params, @@ -156,16 +157,16 @@ def compute_pc_metrics( ) items.append(func_args) - if not run_in_parallel: + if not run_in_parallel and non_nn_metrics: units_loop = enumerate(unit_ids) if progress_bar: - units_loop = tqdm(units_loop, desc="calculate_pc_metrics", total=len(unit_ids)) + units_loop = tqdm(units_loop, desc="calculate pc_metrics", total=len(unit_ids)) for unit_ind, unit_id in units_loop: pca_metrics_unit = pca_metrics_one_unit(items[unit_ind]) for metric_name, metric in pca_metrics_unit.items(): pc_metrics[metric_name][unit_id] = metric - else: + elif run_in_parallel and non_nn_metrics: with ProcessPoolExecutor(n_jobs) as executor: results = executor.map(pca_metrics_one_unit, items) if progress_bar: @@ -176,6 +177,37 @@ def compute_pc_metrics( for metric_name, metric in pca_metrics_unit.items(): pc_metrics[metric_name][unit_id] = metric + for metric_name in nn_metrics: + units_loop = enumerate(unit_ids) + if progress_bar: + units_loop = tqdm(units_loop, desc=f"calculate {metric_name} metric", total=len(unit_ids)) + + func = _nn_metric_name_to_func[metric_name] + metric_params = qm_params[metric_name] if metric_name in qm_params else {} + + for _, unit_id in units_loop: + try: + res = func( + sorting_analyzer, + unit_id, + seed=seed, + n_spikes_all_units=n_spikes_all_units, + fr_all_units=fr_all_units, + **metric_params, + ) + except: + if metric_name == "nn_isolation": + res = (np.nan, np.nan) + elif metric_name == "nn_noise_overlap": + res = np.nan + + if metric_name == "nn_isolation": + nn_isolation, nn_unit_id = res + pc_metrics["nn_isolation"][unit_id] = nn_isolation + pc_metrics["nn_unit_id"][unit_id] = nn_unit_id + elif metric_name == "nn_noise_overlap": + pc_metrics["nn_noise_overlap"][unit_id] = res + return pc_metrics @@ -677,6 +709,14 @@ def nearest_neighbors_noise_overlap( templates_ext = sorting_analyzer.get_extension("templates") assert templates_ext is not None, "nearest_neighbors_isolation() need extension 'templates'" + try: + sorting_analyzer.get_extension("templates").get_data(operator="median") + except KeyError: + warnings.warn( + "nearest_neighbors_isolation() need extension 'templates' calculated with the 'median' operator." + "You can run sorting_analyzer.compute('templates', operators=['average', 'median']) to calculate templates based on both average and median modes." + ) + if n_spikes_all_units is None: n_spikes_all_units = compute_num_spikes(sorting_analyzer) if fr_all_units is None: @@ -955,11 +995,13 @@ def pca_metrics_one_unit(args): pc_metrics = {} # metrics if "isolation_distance" in metric_names or "l_ratio" in metric_names: + try: isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id) except: isolation_distance = np.nan l_ratio = np.nan + if "isolation_distance" in metric_names: pc_metrics["isolation_distance"] = isolation_distance if "l_ratio" in metric_names: @@ -973,6 +1015,7 @@ def pca_metrics_one_unit(args): d_prime = lda_metrics(pcs_flat, labels, unit_id) except: d_prime = np.nan + pc_metrics["d_prime"] = d_prime if "nearest_neighbor" in metric_names: @@ -986,36 +1029,6 @@ def pca_metrics_one_unit(args): pc_metrics["nn_hit_rate"] = nn_hit_rate pc_metrics["nn_miss_rate"] = nn_miss_rate - if "nn_isolation" in metric_names: - try: - nn_isolation, nn_unit_id = nearest_neighbors_isolation( - we, - unit_id, - seed=seed, - n_spikes_all_units=n_spikes_all_units, - fr_all_units=fr_all_units, - **qm_params["nn_isolation"], - ) - except: - nn_isolation = np.nan - nn_unit_id = np.nan - pc_metrics["nn_isolation"] = nn_isolation - pc_metrics["nn_unit_id"] = nn_unit_id - - if "nn_noise_overlap" in metric_names: - try: - nn_noise_overlap = nearest_neighbors_noise_overlap( - we, - unit_id, - n_spikes_all_units=n_spikes_all_units, - fr_all_units=fr_all_units, - seed=seed, - **qm_params["nn_noise_overlap"], - ) - except: - nn_noise_overlap = np.nan - pc_metrics["nn_noise_overlap"] = nn_noise_overlap - if "silhouette" in metric_names: silhouette_method = qm_params["silhouette"]["method"] if "simplified" in silhouette_method: @@ -1032,3 +1045,9 @@ def pca_metrics_one_unit(args): pc_metrics["silhouette_full"] = unit_silhouette_score return pc_metrics + + +_nn_metric_name_to_func = { + "nn_isolation": nearest_neighbors_isolation, + "nn_noise_overlap": nearest_neighbors_noise_overlap, +} diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 5a9191a256..f3eecb20bf 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -19,7 +19,7 @@ class ComputeQualityMetrics(AnalyzerExtension): """ - Compute quality metrics on sorting_. + Compute quality metrics on a `sorting_analyzer`. Parameters ---------- diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py new file mode 100644 index 0000000000..bb2a345340 --- /dev/null +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -0,0 +1,37 @@ +import pytest + +from spikeinterface.core import ( + generate_ground_truth_recording, + create_sorting_analyzer, +) + + +def _small_sorting_analyzer(): + recording, sorting = generate_ground_truth_recording( + durations=[2.0], + num_units=10, + seed=1205, + ) + + sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) + + sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {"operators": ["average", "median"]}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + + sorting_analyzer.compute(extensions_to_compute) + + return sorting_analyzer + + +@pytest.fixture(scope="module") +def small_sorting_analyzer(): + return _small_sorting_analyzer() diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index aec8201f44..90b622b9ab 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -47,37 +47,6 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def _small_sorting_analyzer(): - recording, sorting = generate_ground_truth_recording( - durations=[2.0], - num_units=4, - seed=1205, - ) - - sorting = sorting.select_units([3, 2, 0], ["#3", "#9", "#4"]) - - sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") - - extensions_to_compute = { - "random_spikes": {"seed": 1205}, - "noise_levels": {"seed": 1205}, - "waveforms": {}, - "templates": {}, - "spike_amplitudes": {}, - "spike_locations": {}, - "principal_components": {}, - } - - sorting_analyzer.compute(extensions_to_compute) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def small_sorting_analyzer(): - return _small_sorting_analyzer() - - def test_unit_structure_in_output(small_sorting_analyzer): qm_params = { @@ -126,7 +95,7 @@ def test_unit_id_order_independence(small_sorting_analyzer): """ recording = small_sorting_analyzer.recording - sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [0, 2, 3]) + sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [1, 7, 2]) small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") @@ -161,9 +130,9 @@ def test_unit_id_order_independence(small_sorting_analyzer): ) for metric, metric_1_data in quality_metrics_1.items(): - assert quality_metrics_2[metric][3] == metric_1_data["#3"] - assert quality_metrics_2[metric][2] == metric_1_data["#9"] - assert quality_metrics_2[metric][0] == metric_1_data["#4"] + assert quality_metrics_2[metric][2] == metric_1_data["#3"] + assert quality_metrics_2[metric][7] == metric_1_data["#9"] + assert quality_metrics_2[metric][1] == metric_1_data["#4"] def _sorting_analyzer_simple(): diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 4e5a4858bb..6ddeb02689 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -1,82 +1,24 @@ import pytest -from pathlib import Path import numpy as np -from spikeinterface.core import ( - generate_ground_truth_recording, - create_sorting_analyzer, -) - from spikeinterface.qualitymetrics import ( compute_pc_metrics, - nearest_neighbors_isolation, - nearest_neighbors_noise_overlap, ) -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - - -def _sorting_analyzer_simple(): - recording, sorting = generate_ground_truth_recording( - durations=[ - 50.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=2205, - ) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates", operators=["average", "std", "median"]) - sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) - sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - return _sorting_analyzer_simple() - - -def test_calculate_pc_metrics(sorting_analyzer_simple): +def test_calculate_pc_metrics(small_sorting_analyzer): import pandas as pd - sorting_analyzer = sorting_analyzer_simple - res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True) + sorting_analyzer = small_sorting_analyzer + res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True, seed=1205) res1 = pd.DataFrame(res1) - res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True) + res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True, seed=1205) res2 = pd.DataFrame(res2) - for k in res1.columns: - mask = ~np.isnan(res1[k].values) - if np.any(mask): - assert np.array_equal(res1[k].values[mask], res2[k].values[mask]) - - -def test_nearest_neighbors_isolation(sorting_analyzer_simple): - sorting_analyzer = sorting_analyzer_simple - this_unit_id = sorting_analyzer.unit_ids[0] - nearest_neighbors_isolation(sorting_analyzer, this_unit_id) - - -def test_nearest_neighbors_noise_overlap(sorting_analyzer_simple): - sorting_analyzer = sorting_analyzer_simple - this_unit_id = sorting_analyzer.unit_ids[0] - nearest_neighbors_noise_overlap(sorting_analyzer, this_unit_id) - + for metric_name in res1.columns: + if metric_name != "nn_unit_id": + assert not np.all(np.isnan(res1[metric_name].values)) + assert not np.all(np.isnan(res2[metric_name].values)) -if __name__ == "__main__": - sorting_analyzer = _sorting_analyzer_simple() - test_calculate_pc_metrics(sorting_analyzer) - test_nearest_neighbors_isolation(sorting_analyzer) - test_nearest_neighbors_noise_overlap(sorting_analyzer) + assert np.array_equal(res1[metric_name].values, res2[metric_name].values) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 2f965b0483..57755cd759 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -89,7 +89,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.clustering.main import find_cluster_from_peaks from spikeinterface.sortingcomponents.tools import remove_empty_templates from spikeinterface.preprocessing import correct_motion - from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording + from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording job_kwargs = params["job_kwargs"].copy() job_kwargs = fix_job_kwargs(job_kwargs) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 55ef21de9d..ec7e1e24a8 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -9,13 +9,13 @@ from spikeinterface.core import get_noise_levels from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis -from spikeinterface.sortingcomponents.motion_estimation import estimate_motion +from spikeinterface.sortingcomponents.motion import estimate_motion from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks from spikeinterface.widgets import plot_probe_map -from spikeinterface.sortingcomponents.motion_utils import Motion +from spikeinterface.sortingcomponents.motion import Motion # import MEArec as mr diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index a6ff05fc55..38365adfd1 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -6,7 +6,7 @@ from spikeinterface.sorters import run_sorter from spikeinterface.comparison import GroundTruthComparison -from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording +from spikeinterface.sortingcomponents.motion.motion_interpolation import InterpolateMotionRecording from spikeinterface.curation import MergeUnitsSorting diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index aaa67e3aeb..4d6dd43bce 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -445,7 +445,7 @@ def load_folder(cls, folder): result[k] = load_extractor(folder / k) elif format == "Motion": - from spikeinterface.sortingcomponents.motion_utils import Motion + from spikeinterface.sortingcomponents.motion import Motion result[k] = Motion.load(folder / k) elif format == "zarr_templates": diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py index 526cc2e92f..78a9eb7dbc 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py @@ -41,11 +41,11 @@ def test_benchmark_motion_estimaton(create_cache_folder): localize_kwargs=dict(method=loc_method), estimate_motion_kwargs=dict( method=est_method, - bin_duration_s=1.0, + bin_s=1.0, bin_um=5.0, rigid=False, win_step_um=50.0, - win_sigma_um=200.0, + win_scale_um=200.0, ), ), ) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py index 6d80d027f2..18def37d54 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py @@ -56,7 +56,7 @@ def test_benchmark_motion_interpolation(create_cache_folder): # plt.show() cases = {} - bin_duration_s = 1.0 + bin_s = 1.0 cases["static_SC2"] = dict( label="No drift - no correction - SC2", diff --git a/src/spikeinterface/sortingcomponents/motion/__init__.py b/src/spikeinterface/sortingcomponents/motion/__init__.py new file mode 100644 index 0000000000..d2e6a8a3d9 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/__init__.py @@ -0,0 +1,9 @@ +from .motion_utils import Motion +from .motion_estimation import estimate_motion +from .motion_interpolation import ( + correct_motion_on_peaks, + interpolate_motion_on_traces, + InterpolateMotionRecording, + interpolate_motion, +) +from .motion_cleaner import clean_motion_vector diff --git a/src/spikeinterface/sortingcomponents/motion/decentralized.py b/src/spikeinterface/sortingcomponents/motion/decentralized.py new file mode 100644 index 0000000000..41b03b1c43 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/decentralized.py @@ -0,0 +1,810 @@ +import numpy as np + +from tqdm.auto import tqdm, trange + + +from .motion_utils import Motion, get_spatial_windows, get_spatial_bin_edges, make_2d_motion_histogram, scipy_conv1d + +from .dredge import normxcorr1d + + +class DecentralizedRegistration: + """ + Method developed by the Paninski's group from Columbia university: + Charlie Windolf, Julien Boussard, Erdem Varol, Hyun Dong Lee + + This method is also known as DREDGe, but this implemenation does not use LFP signals. + + Original reference: + DECENTRALIZED MOTION INFERENCE AND REGISTRATION OF NEUROPIXEL DATA + https://ieeexplore.ieee.org/document/9414145 + https://proceedings.neurips.cc/paper/2021/hash/b950ea26ca12daae142bd74dba4427c8-Abstract.html + + This code was improved during Spike Sorting NY Hackathon 2022 by Erdem Varol and Charlie Windolf. + An additional major improvement can be found in this paper: + https://www.biorxiv.org/content/biorxiv/early/2022/12/05/2022.12.04.519043.full.pdf + + + Here are some various implementations by the original team: + https://github.com/int-brain-lab/spikes_localization_registration/blob/main/registration_pipeline/image_based_motion_estimate.py#L211 + https://github.com/cwindolf/spike-psvae/tree/main/spike_psvae + https://github.com/evarol/DREDge + """ + + name = "decentralized" + need_peak_location = True + params_doc = """ + bin_um: float, default: 10 + Spatial bin size in micrometers + hist_margin_um: float, default: 0 + Margin in um from histogram estimation. + Positive margin extrapolate out of the probe the motion. + Negative margin crop the motion on the border + bin_s: float, default 1.0 + Bin duration in second + histogram_depth_smooth_um: None or float + Optional gaussian smoother on histogram on depth axis. + This is given as the sigma of the gaussian in micrometers. + histogram_time_smooth_s: None or float + Optional gaussian smoother on histogram on time axis. + This is given as the sigma of the gaussian in seconds. + pairwise_displacement_method: "conv" or "phase_cross_correlation" + How to estimate the displacement in the pairwise matrix. + max_displacement_um: float + Maximum possible displacement in micrometers. + weight_scale: "linear" or "exp" + For parwaise displacement, how to to rescale the associated weight matrix. + error_sigma: float, default: 0.2 + In case weight_scale="exp" this controls the sigma of the exponential. + conv_engine: "numpy" or "torch" or None, default: None + In case of pairwise_displacement_method="conv", what library to use to compute + the underlying correlation + torch_device=None + In case of conv_engine="torch", you can control which device (cpu or gpu) + batch_size: int + Size of batch for the convolution. Increasing this will speed things up dramatically + on GPUs and sometimes on CPU as well. + corr_threshold: float + Minimum correlation between pair of time bins in order for these to be + considered when optimizing a global displacment vector to align with + the pairwise displacements. + time_horizon_s: None or float + When not None the parwise discplament matrix is computed in a small time horizon. + In short only pair of bins close in time. + So the pariwaise matrix is super sparse and have values only the diagonal. + convergence_method: "lsmr" | "lsqr_robust" | "gradient_descent", default: "lsqr_robust" + Which method to use to compute the global displacement vector from the pairwise matrix. + robust_regression_sigma: float + Use for convergence_method="lsqr_robust" for iterative selection of the regression. + temporal_prior : bool, default: True + Ensures continuity across time, unless there is evidence in the recording for jumps. + spatial_prior : bool, default: False + Ensures continuity across space. Not usually necessary except in recordings with + glitches across space. + force_spatial_median_continuity: bool, default: False + When spatial_prior=False we can optionally apply a median continuity across spatial windows. + reference_displacement : string, one of: "mean", "median", "time", "mode_search" + Strategy for picking what is considered displacement=0. + - "mean" : the mean displacement is subtracted + - "median" : the median displacement is subtracted + - "time" : the displacement at a given time (in seconds) is subtracted + - "mode_search" : an attempt is made to guess the mode. needs work. + lsqr_robust_n_iter: int + Number of iteration for convergence_method="lsqr_robust". + """ + + @classmethod + def run( + cls, + recording, + peaks, + peak_locations, + direction, + rigid, + win_shape, + win_step_um, + win_scale_um, + win_margin_um, + verbose, + progress_bar, + extra, + bin_um=1.0, + hist_margin_um=20.0, + bin_s=1.0, + histogram_depth_smooth_um=1.0, + histogram_time_smooth_s=1.0, + pairwise_displacement_method="conv", + max_displacement_um=100.0, + weight_scale="linear", + error_sigma=0.2, + conv_engine=None, + torch_device=None, + batch_size=1, + corr_threshold=0.0, + time_horizon_s=None, + convergence_method="lsqr_robust", + soft_weights=False, + normalized_xcorr=True, + centered_xcorr=True, + temporal_prior=True, + spatial_prior=False, + force_spatial_median_continuity=False, + reference_displacement="median", + reference_displacement_time_s=0, + robust_regression_sigma=2, + lsqr_robust_n_iter=20, + weight_with_amplitude=False, + ): + + dim = ["x", "y", "z"].index(direction) + contact_depths = recording.get_channel_locations()[:, dim] + + # spatial histogram bins + spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) + spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1]) + + # get spatial windows + non_rigid_windows, non_rigid_window_centers = get_spatial_windows( + contact_depths, + spatial_bin_centers, + rigid=rigid, + win_shape=win_shape, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_margin_um=win_margin_um, + zero_threshold=None, + ) + + # make 2D histogram raster + if verbose: + print("Computing motion histogram") + + motion_histogram, temporal_hist_bin_edges, spatial_hist_bin_edges = make_2d_motion_histogram( + recording, + peaks, + peak_locations, + direction=direction, + bin_s=bin_s, + spatial_bin_edges=spatial_bin_edges, + weight_with_amplitude=weight_with_amplitude, + depth_smooth_um=histogram_depth_smooth_um, + time_smooth_s=histogram_time_smooth_s, + ) + + if extra is not None: + extra["motion_histogram"] = motion_histogram + extra["pairwise_displacement_list"] = [] + extra["temporal_hist_bin_edges"] = temporal_hist_bin_edges + extra["spatial_hist_bin_edges"] = spatial_hist_bin_edges + + # temporal bins are bin center + temporal_bins = 0.5 * (temporal_hist_bin_edges[1:] + temporal_hist_bin_edges[:-1]) + + motion_array = np.zeros((temporal_bins.size, len(non_rigid_windows)), dtype=np.float64) + windows_iter = non_rigid_windows + if progress_bar: + windows_iter = tqdm(windows_iter, desc="windows") + if spatial_prior: + all_pairwise_displacements = np.empty( + (len(non_rigid_windows), temporal_bins.size, temporal_bins.size), dtype=np.float64 + ) + all_pairwise_displacement_weights = np.empty( + (len(non_rigid_windows), temporal_bins.size, temporal_bins.size), dtype=np.float64 + ) + for i, win in enumerate(windows_iter): + window_slice = np.flatnonzero(win > 1e-5) + window_slice = slice(window_slice[0], window_slice[-1]) + if verbose: + print(f"Computing pairwise displacement: {i + 1} / {len(non_rigid_windows)}") + + pairwise_displacement, pairwise_displacement_weight = compute_pairwise_displacement( + motion_histogram[:, window_slice], + bin_um, + window=win[window_slice], + method=pairwise_displacement_method, + weight_scale=weight_scale, + error_sigma=error_sigma, + conv_engine=conv_engine, + torch_device=torch_device, + batch_size=batch_size, + max_displacement_um=max_displacement_um, + normalized_xcorr=normalized_xcorr, + centered_xcorr=centered_xcorr, + corr_threshold=corr_threshold, + time_horizon_s=time_horizon_s, + bin_s=bin_s, + progress_bar=False, + ) + + if spatial_prior: + all_pairwise_displacements[i] = pairwise_displacement + all_pairwise_displacement_weights[i] = pairwise_displacement_weight + + if extra is not None: + extra["pairwise_displacement_list"].append(pairwise_displacement) + + if verbose: + print(f"Computing global displacement: {i + 1} / {len(non_rigid_windows)}") + + # TODO: if spatial_prior, do this after the loop + if not spatial_prior: + motion_array[:, i] = compute_global_displacement( + pairwise_displacement, + pairwise_displacement_weight=pairwise_displacement_weight, + convergence_method=convergence_method, + robust_regression_sigma=robust_regression_sigma, + lsqr_robust_n_iter=lsqr_robust_n_iter, + temporal_prior=temporal_prior, + spatial_prior=spatial_prior, + soft_weights=soft_weights, + progress_bar=False, + ) + + if spatial_prior: + motion_array = compute_global_displacement( + all_pairwise_displacements, + pairwise_displacement_weight=all_pairwise_displacement_weights, + convergence_method=convergence_method, + robust_regression_sigma=robust_regression_sigma, + lsqr_robust_n_iter=lsqr_robust_n_iter, + temporal_prior=temporal_prior, + spatial_prior=spatial_prior, + soft_weights=soft_weights, + progress_bar=False, + ) + elif len(non_rigid_windows) > 1: + # if spatial_prior is False, we still want keep the spatial bins + # correctly offset from each other + if force_spatial_median_continuity: + for i in range(len(non_rigid_windows) - 1): + motion_array[:, i + 1] -= np.median(motion_array[:, i + 1] - motion_array[:, i]) + + # try to avoid constant offset + # let the user choose how to do this. here are some ideas. + # (one can also -= their own number on the result of this function.) + if reference_displacement == "mean": + motion_array -= motion_array.mean() + elif reference_displacement == "median": + motion_array -= np.median(motion_array) + elif reference_displacement == "time": + # reference the motion to 0 at a specific time, independently in each window + reference_displacement_bin = np.digitize(reference_displacement_time_s, temporal_hist_bin_edges) - 1 + motion_array -= motion_array[reference_displacement_bin, :] + elif reference_displacement == "mode_search": + # just a sketch of an idea + # things might want to change, should have a configurable bin size, + # should use a call to histogram instead of the loop, ... + step_size = 0.1 + round_mode = np.round # floor? + best_ref = np.median(motion_array) + max_zeros = np.sum(round_mode(motion_array - best_ref) == 0) + for ref in np.arange(np.floor(motion_array.min()), np.ceil(motion_array.max()), step_size): + n_zeros = np.sum(round_mode(motion_array - ref) == 0) + if n_zeros > max_zeros: + max_zeros = n_zeros + best_ref = ref + motion_array -= best_ref + + # replace nan by zeros + np.nan_to_num(motion_array, copy=False) + + motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction) + + return motion + + +def compute_pairwise_displacement( + motion_hist, + bin_um, + method="conv", + weight_scale="linear", + error_sigma=0.2, + conv_engine="numpy", + torch_device=None, + batch_size=1, + max_displacement_um=1500, + corr_threshold=0, + time_horizon_s=None, + normalized_xcorr=True, + centered_xcorr=True, + bin_s=None, + progress_bar=False, + window=None, +): + """ + Compute pairwise displacement + """ + from scipy import linalg + + if conv_engine is None: + # use torch if installed + try: + import torch + + conv_engine = "torch" + except ImportError: + conv_engine = "numpy" + + if conv_engine == "torch": + import torch + + assert conv_engine in ("torch", "numpy"), f"'conv_engine' must be 'torch' or 'numpy'" + size = motion_hist.shape[0] + pairwise_displacement = np.zeros((size, size), dtype="float32") + + if time_horizon_s is not None: + band_width = int(np.ceil(time_horizon_s / bin_s)) + if band_width >= size: + time_horizon_s = None + + if conv_engine == "torch": + if torch_device is None: + torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if method == "conv": + if max_displacement_um is None: + n = motion_hist.shape[1] // 2 + else: + n = min( + motion_hist.shape[1] // 2, + int(np.ceil(max_displacement_um // bin_um)), + ) + possible_displacement = np.arange(-n, n + 1) * bin_um + + xrange = trange if progress_bar else range + + motion_hist_engine = motion_hist + window_engine = window + if conv_engine == "torch": + motion_hist_engine = torch.as_tensor(motion_hist, dtype=torch.float32, device=torch_device) + window_engine = torch.as_tensor(window, dtype=torch.float32, device=torch_device) + + pairwise_displacement = np.empty((size, size), dtype=np.float32) + correlation = np.empty((size, size), dtype=motion_hist.dtype) + + for i in xrange(0, size, batch_size): + corr = normxcorr1d( + motion_hist_engine, + motion_hist_engine[i : i + batch_size], + weights=window_engine, + padding=possible_displacement.size // 2, + conv_engine=conv_engine, + normalized=normalized_xcorr, + centered=centered_xcorr, + ) + if conv_engine == "torch": + max_corr, best_disp_inds = torch.max(corr, dim=2) + best_disp = possible_displacement[best_disp_inds.cpu()] + pairwise_displacement[i : i + batch_size] = best_disp + correlation[i : i + batch_size] = max_corr.cpu() + elif conv_engine == "numpy": + best_disp_inds = np.argmax(corr, axis=2) + max_corr = np.take_along_axis(corr, best_disp_inds[..., None], 2).squeeze() + best_disp = possible_displacement[best_disp_inds] + pairwise_displacement[i : i + batch_size] = best_disp + correlation[i : i + batch_size] = max_corr + + if corr_threshold is not None and corr_threshold > 0: + which = correlation > corr_threshold + correlation *= which + + elif method == "phase_cross_correlation": + # this 'phase_cross_correlation' is an old idea from Julien/Charlie/Erden that is kept for testing + # but this is not very releveant + try: + import skimage.registration + except ImportError: + raise ImportError("To use the 'phase_cross_correlation' method install scikit-image") + + errors = np.zeros((size, size), dtype="float32") + loop = range(size) + if progress_bar: + loop = tqdm(loop) + for i in loop: + for j in range(size): + shift, error, diffphase = skimage.registration.phase_cross_correlation( + motion_hist[i, :], motion_hist[j, :] + ) + pairwise_displacement[i, j] = shift * bin_um + errors[i, j] = error + correlation = 1 - errors + + else: + raise ValueError( + f"method {method} does not exist for compute_pairwise_displacement. Current possible methods are" + f" 'conv' or 'phase_cross_correlation'" + ) + + if weight_scale == "linear": + # between 0 and 1 + pairwise_displacement_weight = correlation + elif weight_scale == "exp": + pairwise_displacement_weight = np.exp((correlation - 1) / error_sigma) + + # handle the time horizon by multiplying the weights by a + # matrix with the time horizon on its diagonal bands. + if method == "conv" and time_horizon_s is not None and time_horizon_s > 0: + horizon_matrix = linalg.toeplitz( + np.r_[np.ones(band_width, dtype=bool), np.zeros(size - band_width, dtype=bool)] + ) + pairwise_displacement_weight *= horizon_matrix + + return pairwise_displacement, pairwise_displacement_weight + + +_possible_convergence_method = ("lsmr", "gradient_descent", "lsqr_robust") + + +def compute_global_displacement( + pairwise_displacement, + pairwise_displacement_weight=None, + sparse_mask=None, + temporal_prior=True, + spatial_prior=True, + soft_weights=False, + convergence_method="lsmr", + robust_regression_sigma=2, + lsqr_robust_n_iter=20, + progress_bar=False, +): + """ + Compute global displacement + + Arguments + --------- + pairwise_displacement : time x time array + pairwise_displacement_weight : time x time array + sparse_mask : time x time array + convergence_method : str + One of "gradient" + + """ + import scipy + from scipy.optimize import minimize + from scipy.sparse import csr_matrix + from scipy.sparse.linalg import lsqr + from scipy.stats import zscore + + if convergence_method == "gradient_descent": + size = pairwise_displacement.shape[0] + + D = pairwise_displacement + if pairwise_displacement_weight is not None or sparse_mask is not None: + # weighted problem + if pairwise_displacement_weight is None: + pairwise_displacement_weight = np.ones_like(D) + if sparse_mask is None: + sparse_mask = np.ones_like(D) + W = pairwise_displacement_weight * sparse_mask + + I, J = np.nonzero(W > 0) + Wij = W[I, J] + Dij = D[I, J] + W = csr_matrix((Wij, (I, J)), shape=W.shape) + WD = csr_matrix((Wij * Dij, (I, J)), shape=W.shape) + fixed_terms = (W @ WD).diagonal() - (WD @ W).diagonal() + diag_WW = (W @ W).diagonal() + Wsq = W.power(2) + + def obj(p): + return 0.5 * np.square(Wij * (Dij - (p[I] - p[J]))).sum() + + def jac(p): + return fixed_terms - 2 * (Wsq @ p) + 2 * p * diag_WW + + else: + # unweighted problem, it's faster when we have no weights + fixed_terms = -D.sum(axis=1) + D.sum(axis=0) + + def obj(p): + v = np.square((D - (p[:, None] - p[None, :]))).sum() + return 0.5 * v + + def jac(p): + return fixed_terms + 2 * (size * p - p.sum()) + + res = minimize(fun=obj, jac=jac, x0=D.mean(axis=1), method="L-BFGS-B") + if not res.success: + print("Global displacement gradient descent had an error") + displacement = res.x + + elif convergence_method == "lsqr_robust": + + if sparse_mask is not None: + I, J = np.nonzero(sparse_mask > 0) + elif pairwise_displacement_weight is not None: + I, J = pairwise_displacement_weight.nonzero() + else: + I, J = np.nonzero(np.ones_like(pairwise_displacement, dtype=bool)) + + nnz_ones = np.ones(I.shape[0], dtype=pairwise_displacement.dtype) + + if pairwise_displacement_weight is not None: + if isinstance(pairwise_displacement_weight, scipy.sparse.csr_matrix): + W = np.array(pairwise_displacement_weight[I, J]).T + else: + W = pairwise_displacement_weight[I, J][:, None] + else: + W = nnz_ones[:, None] + if isinstance(pairwise_displacement, scipy.sparse.csr_matrix): + V = np.array(pairwise_displacement[I, J])[0] + else: + V = pairwise_displacement[I, J] + M = csr_matrix((nnz_ones, (range(I.shape[0]), I)), shape=(I.shape[0], pairwise_displacement.shape[0])) + N = csr_matrix((nnz_ones, (range(I.shape[0]), J)), shape=(I.shape[0], pairwise_displacement.shape[0])) + A = M - N + idx = np.ones(A.shape[0], dtype=bool) + + # TODO: this is already soft_weights + xrange = trange if progress_bar else range + for i in xrange(lsqr_robust_n_iter): + p = lsqr(A[idx].multiply(W[idx]), V[idx] * W[idx][:, 0])[0] + idx = np.nonzero(np.abs(zscore(A @ p - V)) <= robust_regression_sigma) + displacement = p + + elif convergence_method == "lsmr": + import gc + from scipy import sparse + + D = pairwise_displacement + + # weighted problem + if pairwise_displacement_weight is None: + pairwise_displacement_weight = np.ones_like(D) + if sparse_mask is None: + sparse_mask = np.ones_like(D) + W = pairwise_displacement_weight * sparse_mask + if isinstance(W, scipy.sparse.csr_matrix): + W = W.astype(np.float32).toarray() + D = D.astype(np.float32).toarray() + + assert D.shape == W.shape + + # first dimension is the windows dim, which could be empty in rigid case + # we expand dims so that below we can consider only the nonrigid case + if D.ndim == 2: + W = W[None] + D = D[None] + assert D.ndim == W.ndim == 3 + B, T, T_ = D.shape + assert T == T_ + + # sparsify the problem + # we will make a list of temporal problems and then + # stack over the windows axis to finish. + # each matrix in coefficients will be (sparse_dim, T) + coefficients = [] + # each vector in targets will be (T,) + targets = [] + # we want to solve for a vector of shape BT, which we will reshape + # into a (B, T) matrix. + # after the loop below, we will stack a coefts matrix (sparse_dim, B, T) + # and a target vector of shape (B, T), both to be vectorized on last two axes, + # so that the target p is indexed by i = bT + t (block/window major). + + # calculate coefficients matrices and target vector + # this list stores boolean masks corresponding to whether or not each + # term comes from the prior or the likelihood. we can trim the likelihood terms, + # but not the prior terms, in the trimmed least squares (robust iters) iterations below. + cannot_trim = [] + for Wb, Db in zip(W, D): + # indices of active temporal pairs in this window + I, J = np.nonzero(Wb > 0) + n_sampled = I.size + + # construct Kroneckers and sparse objective in this window + pair_weights = np.ones(n_sampled) + if soft_weights: + pair_weights = Wb[I, J] + Mb = sparse.csr_matrix((pair_weights, (range(n_sampled), I)), shape=(n_sampled, T)) + Nb = sparse.csr_matrix((pair_weights, (range(n_sampled), J)), shape=(n_sampled, T)) + block_sparse_kron = Mb - Nb + block_disp_pairs = pair_weights * Db[I, J] + cannot_trim_block = np.ones_like(block_disp_pairs, dtype=bool) + + # add the temporal smoothness prior in this window + if temporal_prior: + temporal_diff_operator = sparse.diags( + ( + np.full(T - 1, -1, dtype=block_sparse_kron.dtype), + np.full(T - 1, 1, dtype=block_sparse_kron.dtype), + ), + offsets=(0, 1), + shape=(T - 1, T), + ) + block_sparse_kron = sparse.vstack( + (block_sparse_kron, temporal_diff_operator), + format="csr", + ) + block_disp_pairs = np.concatenate( + (block_disp_pairs, np.zeros(T - 1)), + ) + cannot_trim_block = np.concatenate( + (cannot_trim_block, np.zeros(T - 1, dtype=bool)), + ) + + coefficients.append(block_sparse_kron) + targets.append(block_disp_pairs) + cannot_trim.append(cannot_trim_block) + coefficients = sparse.block_diag(coefficients) + targets = np.concatenate(targets, axis=0) + cannot_trim = np.concatenate(cannot_trim, axis=0) + + # spatial smoothness prior: penalize difference of each block's + # displacement with the next. + # only if B > 1, and not in the last window. + # this is a (BT, BT) sparse matrix D such that: + # entry at (i, j) is: + # { 1 if i = j, i.e., i = j = bT + t for b = 0,...,B-2 + # { -1 if i = bT + t and j = (b+1)T + t for b = 0,...,B-2 + # { 0 otherwise. + # put more simply, the first (B-1)T diagonal entries are 1, + # and entries (i, j) such that i = j - T are -1. + if B > 1 and spatial_prior: + spatial_diff_operator = sparse.diags( + ( + np.ones((B - 1) * T, dtype=block_sparse_kron.dtype), + np.full((B - 1) * T, -1, dtype=block_sparse_kron.dtype), + ), + offsets=(0, T), + shape=((B - 1) * T, B * T), + ) + coefficients = sparse.vstack((coefficients, spatial_diff_operator)) + targets = np.concatenate((targets, np.zeros((B - 1) * T, dtype=targets.dtype))) + cannot_trim = np.concatenate((cannot_trim, np.zeros((B - 1) * T, dtype=bool))) + coefficients = coefficients.tocsr() + + # initialize at the column mean of pairwise displacements (in each window) + p0 = D.mean(axis=2).reshape(B * T) + + # use LSMR to solve the whole problem || targets - coefficients @ motion ||^2 + iters = range(max(1, lsqr_robust_n_iter)) + if progress_bar and lsqr_robust_n_iter > 1: + iters = tqdm(iters, desc="robust lsqr") + for it in iters: + # trim active set -- start with no trimming + idx = slice(None) + if it: + idx = np.flatnonzero( + cannot_trim | (np.abs(zscore(coefficients @ displacement - targets)) <= robust_regression_sigma) + ) + + # solve trimmed ols problem + displacement, *_ = sparse.linalg.lsmr(coefficients[idx], targets[idx], x0=p0) + + # warm start next iteration + p0 = displacement + # Cleanup lsmr memory (see https://stackoverflow.com/questions/56147713/memory-leak-in-scipy) + # TODO: check if this gets fixed in scipy + gc.collect() + + displacement = displacement.reshape(B, T).T + else: + raise ValueError( + f"Method {convergence_method} doesn't exist for compute_global_displacement" + f" possible values for 'convergence_method' are {_possible_convergence_method}" + ) + + return np.squeeze(displacement) + + +# normxcorr1d is now implemented in dredge +# we keep the old version here but this will be removed soon + +# def normxcorr1d( +# template, +# x, +# weights=None, +# centered=True, +# normalized=True, +# padding="same", +# conv_engine="torch", +# ): +# """normxcorr1d: Normalized cross-correlation, optionally weighted + +# The API is like torch's F.conv1d, except I have accidentally +# changed the position of input/weights -- template acts like weights, +# and x acts like input. + +# Returns the cross-correlation of `template` and `x` at spatial lags +# determined by `mode`. Useful for estimating the location of `template` +# within `x`. + +# This might not be the most efficient implementation -- ideas welcome. +# It uses a direct convolutional translation of the formula +# corr = (E[XY] - EX EY) / sqrt(var X * var Y) + +# This also supports weights! In that case, the usual adaptation of +# the above formula is made to the weighted case -- and all of the +# normalizations are done per block in the same way. + +# Parameters +# ---------- +# template : tensor, shape (num_templates, length) +# The reference template signal +# x : tensor, 1d shape (length,) or 2d shape (num_inputs, length) +# The signal in which to find `template` +# weights : tensor, shape (length,) +# Will use weighted means, variances, covariances if supplied. +# centered : bool +# If true, means will be subtracted (per weighted patch). +# normalized : bool +# If true, normalize by the variance (per weighted patch). +# padding : str +# How far to look? if unset, we'll use half the length +# conv_engine : string, one of "torch", "numpy" +# What library to use for computing cross-correlations. +# If numpy, falls back to the scipy correlate function. + +# Returns +# ------- +# corr : tensor +# """ +# if conv_engine == "torch": +# assert HAVE_TORCH +# conv1d = F.conv1d +# npx = torch +# elif conv_engine == "numpy": +# conv1d = scipy_conv1d +# npx = np +# else: +# raise ValueError(f"Unknown conv_engine {conv_engine}. 'conv_engine' must be 'torch' or 'numpy'") + +# x = npx.atleast_2d(x) +# num_templates, length = template.shape +# num_inputs, length_ = template.shape +# assert length == length_ + +# # generalize over weighted / unweighted case +# device_kw = {} if conv_engine == "numpy" else dict(device=x.device) +# ones = npx.ones((1, 1, length), dtype=x.dtype, **device_kw) +# no_weights = weights is None +# if no_weights: +# weights = ones +# wt = template[:, None, :] +# else: +# assert weights.shape == (length,) +# weights = weights[None, None] +# wt = template[:, None, :] * weights + +# # conv1d valid rule: +# # (B,1,L),(O,1,L)->(B,O,L) + +# # compute expectations +# # how many points in each window? seems necessary to normalize +# # for numerical stability. +# N = conv1d(ones, weights, padding=padding) +# if centered: +# Et = conv1d(ones, wt, padding=padding) +# Et /= N +# Ex = conv1d(x[:, None, :], weights, padding=padding) +# Ex /= N + +# # compute (weighted) covariance +# # important: the formula E[XY] - EX EY is well-suited here, +# # because the means are naturally subtracted correctly +# # patch-wise. you couldn't pre-subtract them! +# cov = conv1d(x[:, None, :], wt, padding=padding) +# cov /= N +# if centered: +# cov -= Ex * Et + +# # compute variances for denominator, using var X = E[X^2] - (EX)^2 +# if normalized: +# var_template = conv1d(ones, wt * template[:, None, :], padding=padding) +# var_template /= N +# var_x = conv1d(npx.square(x)[:, None, :], weights, padding=padding) +# var_x /= N +# if centered: +# var_template -= npx.square(Et) +# var_x -= npx.square(Ex) + +# # now find the final normxcorr +# corr = cov # renaming for clarity +# if normalized: +# corr /= npx.sqrt(var_x) +# corr /= npx.sqrt(var_template) +# # get rid of NaNs in zero-variance areas +# corr[~npx.isfinite(corr)] = 0 + +# return corr diff --git a/src/spikeinterface/sortingcomponents/motion/dredge.py b/src/spikeinterface/sortingcomponents/motion/dredge.py new file mode 100644 index 0000000000..a0dde6d52b --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/dredge.py @@ -0,0 +1,1407 @@ +""" +Copy-paste and then refactoring of DREDge +https://github.com/evarol/dredge + +For historical reason, some function from the DREDge package where implemeneted +in spikeinterface in the motion_estimation.py before the DREDge package itself! + +Here a copy/paste (and small rewriting) of some functions from DREDge. + +The main entry for this function are still: + + * motion = estimate_motion((recording, ..., method='dredge_lfp') + * motion = estimate_motion((recording, ..., method='dredge_ap') < not Done yet + +but here the original functions from Charlie, Julien and Erdem have been ported for an +easier maintenance instead of making DREDge a dependency of spikeinterface. + +Some renaming has been done. Small details has been added. +But this code is very similar to the original code. +2 classes has been added : DredgeApRegistration and DredgeLfpRegistration +but the original function dredge_ap() and dredge_online_lfp() can be used directly. + +""" + +import warnings + +from tqdm.auto import trange +import numpy as np + +import gc + +from .motion_utils import ( + Motion, + get_spatial_windows, + get_window_domains, + scipy_conv1d, + make_2d_motion_histogram, + get_spatial_bin_edges, +) + + +# simple class wrapper to be compliant with estimate_motion +class DredgeApRegistration: + """ + Estimate motion from spikes times and depth. + + This the certified and official version of the dredge implementation. + + Method developed by the Paninski's group from Columbia university: + Charlie Windolf, Julien Boussard, Erdem Varol + + This method is quite similar to "decentralized" which was the previous implementation in spikeinterface. + + The reference is here https://www.biorxiv.org/content/10.1101/2023.10.24.563768v1 + + The original code were here : https://github.com/evarol/DREDge + But this code which use the same internal function is in line with the Motion object of spikeinterface contrary to the dredge repo. + + This code has been ported in spikeinterface (with simple copy/paste) by Samuel but main author is truely Charlie Windolf. + """ + + name = "dredge_ap" + need_peak_location = True + params_doc = """ + bin_um: float + Bin duration in second + bin_s : float + The size of the bins along depth in microns and along time in seconds. + The returned object's .displacement array will respect these bins. + Increasing these can lead to more stable estimates and faster runtimes + at the cost of spatial and/or temporal resolution. + max_disp_um : float + Maximum possible displacement in microns. If you can guess a number which is larger + than the largest displacement possible in your recording across a span of `time_horizon_s` + seconds, setting this value to that number can stabilize the result and speed up + the algorithm (since it can do less cross-correlating). + By default, this is set to win-scale_um / 4, or 112.5 microns. Which can be a bit + large! + time_horizon_s : float + "Time horizon" parameter, in seconds. Time bins separated by more seconds than this + will not be cross-correlated. So, if your data has nonstationarities or changes which + could lead to bad cross-correlations at some timescale, it can help to input that + value here. If this is too small, it can make the motion estimation unstable. + mincorr : float, between 0 and 1 + Correlation threshold. Pairs of frames whose maximal cross correlation value is smaller + than this threshold will be ignored when solving for the global displacement estimate. + thomas_kw, xcorr_kw, raster_kw, weights_kw + These dictionaries allow setting parameters for fine control over the registration + device : str or torch.device + What torch device to run on? E.g., "cpu" or "cuda" or "cuda:1". + """ + + @classmethod + def run( + cls, + recording, + peaks, + peak_locations, + direction, + rigid, + win_shape, + win_step_um, + win_scale_um, + win_margin_um, + verbose, + progress_bar, + extra, + **method_kwargs, + ): + + outs = dredge_ap( + recording, + peaks, + peak_locations, + direction=direction, + rigid=rigid, + win_shape=win_shape, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_margin_um=win_margin_um, + extra_outputs=(extra is not None), + progress_bar=progress_bar, + **method_kwargs, + ) + + if extra is not None: + motion, extra_ = outs + extra.update(extra_) + else: + motion = outs + return motion + + +# @TODO : Charlie I started very small refactoring, I let you continue +def dredge_ap( + recording, + peaks, + peak_locations, + direction="y", + rigid=False, + # nonrigid window construction arguments + win_shape="gaussian", + win_step_um=400, + win_scale_um=450, + win_margin_um=None, + bin_um=1.0, + bin_s=1.0, + max_disp_um=None, + time_horizon_s=1000.0, + mincorr=0.1, + # weights arguments + do_window_weights=True, + weights_threshold_low=0.2, + weights_threshold_high=0.2, + mincorr_percentile=None, + mincorr_percentile_nneighbs=None, + # raster arguments + amp_scale_fn=None, ## @Charlie this one is not used anymore + post_transform=np.log1p, ###@this one is directly transimited to weight_correlation_matrix() and so get_wieiith() + histogram_depth_smooth_um=1, + histogram_time_smooth_s=1, + avg_in_bin=False, + # low-level keyword args + thomas_kw=None, + xcorr_kw=None, + # misc + device=None, + progress_bar=True, + extra_outputs=False, + precomputed_D_C_maxdisp=None, +): + """Estimate motion from spikes + + Spikes located at depths specified in `depths` along the probe, occurring at times in + seconds specified in `times` with amplitudes `amps` are used to create a 2d image of + the spiking activity. This image is cross-correlated with itself to produce a displacement + matrix (or several, one for each nonrigid window). This matrix is used to solve for a + motion estimate. + + Arguments + --------- + recording: BaseRecording + The recording extractor + peaks: numpy array + Peak vector (complex dtype). + Needed for decentralized and iterative_template methods. + peak_locations: numpy array + Complex dtype with "x", "y", "z" fields + Needed for decentralized and iterative_template methods. + direction : "x" | "y", default "y" + Dimension on which the motion is estimated. "y" is depth along the probe. + rigid : bool, default=False + If True, ignore the nonrigid window args (win_shape, win_step_um, win_scale_um, + win_margin_um) and do rigid registration (equivalent to one flat window, which + is how it's implemented). + win_shape : str, default="gaussian" + Nonrigid window shape + win_step_um : float + Spacing between nonrigid window centers in microns + win_scale_um : float + Controls the width of nonrigid windows centers + win_margin_um : float + Distance of nonrigid windows centers from the probe boundary (-1000 means there will + be no window center within 1000um of the edge of the probe) + {} + + Returns + ------- + motion : Motion + The motion object + extra : dict + This has extra info about what happened during registration, including the nonrigid + windows if one wants to visualize them. Set `extra_outputs` to also save displacement + and correlation matrices. + """ + + dim = ["x", "y", "z"].index(direction) + # @charlie: I removed amps/depths_um/times_s from the signature + # preaks and peak_locations are more SI compatible + # the way to get then + amps = peak_amplitudes = peaks["amplitude"] + depths_um = peak_depths = peak_locations[direction] + times_s = peak_times = recording.sample_index_to_time(peaks["sample_index"]) + + thomas_kw = thomas_kw if thomas_kw is not None else {} + xcorr_kw = xcorr_kw if xcorr_kw is not None else {} + if time_horizon_s: + xcorr_kw["max_dt_bins"] = np.ceil(time_horizon_s / bin_s) + + # TODO @charlie I think this is a bad to have the dict which is transported to every function + # this should be used only in histogram function but not in weight_correlation_matrix() + # only important kwargs should be explicitly reported + # raster_kw = dict( + # amp_scale_fn=amp_scale_fn, + # post_transform=post_transform, + # histogram_depth_smooth_um=histogram_depth_smooth_um, + # histogram_time_smooth_s=histogram_time_smooth_s, + # bin_s=bin_s, + # bin_um=bin_um, + # avg_in_bin=avg_in_bin, + # return_counts=count_masked_correlation, + # count_bins=count_bins, + # count_bin_min=count_bin_min, + # ) + + weights_kw = dict( + mincorr=mincorr, + time_horizon_s=time_horizon_s, + do_window_weights=do_window_weights, + weights_threshold_low=weights_threshold_low, + weights_threshold_high=weights_threshold_high, + ) + + # this will store return values other than the MotionEstimate + extra = {} + + # TODO charlie I switch this to make_2d_motion_histogram + # but we need to add all options from the original spike_raster() + # but I think this is OK + # raster_res = spike_raster( + # amps, + # depths_um, + # times_s, + # **raster_kw, + # ) + # if count_masked_correlation: + # raster, spatial_bin_edges_um, time_bin_edges_s, counts = raster_res + # else: + # raster, spatial_bin_edges_um, time_bin_edges_s = raster_res + + motion_histogram, time_bin_edges_s, spatial_bin_edges_um = make_2d_motion_histogram( + recording, + peaks, + peak_locations, + weight_with_amplitude=True, + avg_in_bin=avg_in_bin, + direction=direction, + bin_s=bin_s, + bin_um=bin_um, + hist_margin_um=0.0, # @charlie maybe we should expose this and set +20. for instance + spatial_bin_edges=None, + depth_smooth_um=histogram_depth_smooth_um, + time_smooth_s=histogram_time_smooth_s, + ) + raster = motion_histogram.T + + # TODO charlie : put the log for hitstogram + + # TODO @charlie you should check that we are doing the same thing + # windows, window_centers = get_spatial_windows( + # np.c_[np.zeros_like(spatial_bin_edges_um), spatial_bin_edges_um], + # win_step_um, + # win_scale_um, + # spatial_bin_edges=spatial_bin_edges_um, + # margin_um=-win_scale_um / 2 if win_margin_um is None else win_margin_um, + # win_shape=win_shape, + # zero_threshold=1e-5, + # rigid=rigid, + # ) + + dim = ["x", "y", "z"].index(direction) + contact_depths = recording.get_channel_locations()[:, dim] + spatial_bin_centers = 0.5 * (spatial_bin_edges_um[1:] + spatial_bin_edges_um[:-1]) + + windows, window_centers = get_spatial_windows( + contact_depths, + spatial_bin_centers, + rigid=rigid, + win_shape=win_shape, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_margin_um=win_margin_um, + zero_threshold=1e-5, + ) + + # TODO charlie : the count has disapeared + # if extra_outputs and count_masked_correlation: + # extra["counts"] = counts + + # cross-correlate to get D and C + if precomputed_D_C_maxdisp is None: + Ds, Cs, max_disp_um = xcorr_windows( + raster, + windows, + spatial_bin_edges_um, + win_scale_um, + rigid=rigid, + bin_um=bin_um, + max_disp_um=max_disp_um, + progress_bar=progress_bar, + device=device, + # TODO charlie : put back the count for the mask + # masks=(counts > 0) if count_masked_correlation else None, + **xcorr_kw, + ) + else: + Ds, Cs, max_disp_um = precomputed_D_C_maxdisp + + # turn Cs into weights + Us, wextra = weight_correlation_matrix( + Ds, + Cs, + windows, + raster, + spatial_bin_edges_um, + time_bin_edges_s, + # raster_kw, #@charlie this is removed + post_transform=post_transform, # @charlie this isnew + lambda_t=thomas_kw.get("lambda_t", DEFAULT_LAMBDA_T), + eps=thomas_kw.get("eps", DEFAULT_EPS), + progress_bar=progress_bar, + in_place=not extra_outputs, + **weights_kw, + ) + extra.update({k: wextra[k] for k in wextra if k not in ("S", "U")}) + if extra_outputs: + extra.update({k: wextra[k] for k in wextra if k in ("S", "U")}) + del wextra + if extra_outputs: + extra["D"] = Ds + extra["C"] = Cs + del Cs + + # @charlie : is this needed ? + gc.collect() + + # solve for P + # now we can do our tridiag solve + displacement, textra = thomas_solve(Ds, Us, progress_bar=progress_bar, **thomas_kw) + if extra_outputs: + extra.update(textra) + del textra + + if extra_outputs: + extra["windows"] = windows + extra["window_centers"] = window_centers + extra["max_disp_um"] = max_disp_um + + time_bin_centers = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1]) + motion = Motion([displacement.T], [time_bin_centers], window_centers, direction=direction) + + if extra_outputs: + return motion, extra + else: + return motion + + +dredge_ap.__doc__ = dredge_ap.__doc__.format(DredgeApRegistration.params_doc) + + +# simple class wrapper to be compliant with estimate_motion +class DredgeLfpRegistration: + """ + Estimate motion from LFP recording. + + This the certified and official version of the dredge implementation. + + Method developed by the Paninski's group from Columbia university: + Charlie Windolf, Julien Boussard, Erdem Varol + + The reference is here https://www.biorxiv.org/content/10.1101/2023.10.24.563768v1 + """ + + name = "dredge_lfp" + need_peak_location = False + params_doc = """ + lfp_recording : spikeinterface BaseRecording object + Preprocessed LFP recording. The temporal resolution of this recording will + be the target resolution of the registration, so definitely use SpikeInterface + to resample your recording to, say, 250Hz (or a value you like) rather than + estimating motion at the original frequency (which may be high). + direction : "x" | "y", default "y" + Dimension on which the motion is estimated. "y" is depth along the probe. + rigid : boolean, optional + If True, window-related arguments are ignored and we do rigid registration + win_shape, win_step_um, win_scale_um, win_margin_um : float + Nonrigid window-related arguments + The depth domain will be broken up into windows with shape controlled by win_shape, + spaced by win_step_um at a margin of win_margin_um from the boundary, and with + width controlled by win_scale_um. + chunk_len_s : float + Length of chunks (in seconds) that the recording is broken into for online + registration. The computational speed of the method is a function of the + number of samples this corresponds to, and things can get slow if it is + set high enough that the number of samples per chunk is bigger than ~10,000. + But, it can't be set too low or the algorithm doesn't have enough data + to work with. The default is set assuming sampling rate of 250Hz, leading + to 2500 samples per chunk. + time_horizon_s : float + Time-bins farther apart than this value in seconds will not be cross-correlated. + Set this to at least `chunk_len_s`. + max_disp_um : number, optional + This is the ceiling on the possible displacement estimates. It should be + set to a number which is larger than the allowed displacement in a single + chunk. Setting it as small as possible (while following that rule) can speed + things up and improve the result by making it impossible to estimate motion + which is too big. + mincorr : float in [0,1] + Minimum correlation between pairs of frames such that they will be included + in the optimization of the displacement estimates. + mincorr_percentile, mincorr_percentile_nneighbs + If mincorr_percentile is set to a number in [0, 100], then mincorr will be replaced + by this percentile of the correlations of neighbors within mincorr_percentile_nneighbs + time bins of each other. + device : string or torch.device + Controls torch device + """ + + @classmethod + def run( + cls, + recording, + peaks, + peak_locations, + direction, + rigid, + win_shape, + win_step_um, + win_scale_um, + win_margin_um, + verbose, + progress_bar, + extra, + **method_kwargs, + ): + # Note peaks and peak_locations are not used and can be None + + outs = dredge_online_lfp( + recording, + direction=direction, + rigid=rigid, + win_shape=win_shape, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_margin_um=win_margin_um, + extra_outputs=(extra is not None), + progress_bar=progress_bar, + **method_kwargs, + ) + + if extra is not None: + motion, extra_ = outs + extra.update(extra_) + else: + motion = outs + return motion + + +def dredge_online_lfp( + lfp_recording, + direction="y", + # nonrigid window construction arguments + rigid=True, + win_shape="gaussian", + win_step_um=800, + win_scale_um=850, + win_margin_um=None, + chunk_len_s=10.0, + max_disp_um=500, + time_horizon_s=None, + # weighting arguments + mincorr=0.8, + mincorr_percentile=None, + mincorr_percentile_nneighbs=20, + soft=False, + # low-level arguments + thomas_kw=None, + xcorr_kw=None, + # misc + extra_outputs=False, + device=None, + progress_bar=True, +): + """Online registration of a preprocessed LFP recording + + Arguments + --------- + {} + + Returns + ------- + motion : Motion + A motion object. + extra : dict + Dict containing extra info for debugging + """ + dim = ["x", "y", "z"].index(direction) + # contact pos is the only on the direction + contact_depths = lfp_recording.get_channel_locations()[:, dim] + + fs = lfp_recording.get_sampling_frequency() + T_total = lfp_recording.get_num_samples() + T_chunk = min(int(np.floor(fs * chunk_len_s)), T_total) + + # kwarg defaults and handling + # need lfp-specific defaults + xcorr_kw = xcorr_kw if xcorr_kw is not None else {} + thomas_kw = thomas_kw if thomas_kw is not None else {} + full_xcorr_kw = dict( + rigid=rigid, + bin_um=np.median(np.diff(contact_depths)), + max_disp_um=max_disp_um, + progress_bar=False, + device=device, + **xcorr_kw, + ) + threshold_kw = dict( + mincorr_percentile_nneighbs=mincorr_percentile_nneighbs, + in_place=True, + soft=soft, + # time_horizon_s=weights_kw["time_horizon_s"], # max_dt not implemented for lfp at this point + time_horizon_s=time_horizon_s, + bin_s=1 / fs, # only relevant for time_horizon_s + ) + + # here we check that contact positons are unique on the direction + if contact_depths.size != np.unique(contact_depths).size: + raise ValueError( + f"estimate motion with 'dredge_lfp' need channel_positions to be unique in the direction='{direction}'" + ) + if np.any(np.diff(contact_depths) < 0): + raise ValueError( + f"estimate motion with 'dredge_lfp' need channel_positions to be ordered direction='{direction}'" + "please use spikeinterface.preprocessing.depth_order(recording)" + ) + + # Important detail : in LFP bin center are contact position in the direction + spatial_bin_centers = contact_depths + + windows, window_centers = get_spatial_windows( + contact_depths=contact_depths, + spatial_bin_centers=spatial_bin_centers, + rigid=rigid, + win_margin_um=win_margin_um, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_shape=win_shape, + zero_threshold=1e-5, + ) + + B = len(windows) + + if extra_outputs: + extra = dict(window_centers=window_centers, windows=windows) + + # -- allocate output and initialize first chunk + P_online = np.empty((B, T_total), dtype=np.float32) + # below, t0 is start of prev chunk, t1 start of cur chunk, t2 end of cur + t0, t1 = 0, T_chunk + traces0 = lfp_recording.get_traces(start_frame=t0, end_frame=t1) + Ds0, Cs0, max_disp_um = xcorr_windows(traces0.T, windows, contact_depths, win_scale_um, **full_xcorr_kw) + full_xcorr_kw["max_disp_um"] = max_disp_um + Ss0, mincorr0 = threshold_correlation_matrix( + Cs0, + mincorr=mincorr, + mincorr_percentile=mincorr_percentile, + **threshold_kw, + ) + if extra_outputs: + extra["D"] = [Ds0] + extra["C"] = [Cs0] + extra["S"] = [Ss0] + extra["D01"] = [] + extra["C01"] = [] + extra["S01"] = [] + extra["mincorrs"] = [mincorr0] + extra["max_disp_um"] = max_disp_um + + P_online[:, t0:t1], _ = thomas_solve(Ds0, Ss0, **thomas_kw) + + # -- loop through chunks + chunk_starts = range(T_chunk, T_total, T_chunk) + if progress_bar: + chunk_starts = trange( + T_chunk, + T_total, + T_chunk, + desc=f"Online chunks [{chunk_len_s}s each]", + ) + for t1 in chunk_starts: + t2 = min(T_total, t1 + T_chunk) + traces1 = lfp_recording.get_traces(start_frame=t1, end_frame=t2) + + # cross-correlations between prev/cur chunks + # these are T1, T0 shaped + Ds10, Cs10, _ = xcorr_windows( + traces1.T, + windows, + contact_depths, + win_scale_um, + raster_b=traces0.T, + **full_xcorr_kw, + ) + + # cross-correlation in current chunk + Ds1, Cs1, _ = xcorr_windows(traces1.T, windows, contact_depths, win_scale_um, **full_xcorr_kw) + Ss1, mincorr1 = threshold_correlation_matrix( + Cs1, + mincorr_percentile=mincorr_percentile, + mincorr=mincorr, + **threshold_kw, + ) + Ss10, _ = threshold_correlation_matrix(Cs10, mincorr=mincorr1, t_offset_bins=T_chunk, **threshold_kw) + + if extra_outputs: + extra["mincorrs"].append(mincorr1) + extra["D"].append(Ds1) + extra["C"].append(Cs1) + extra["S"].append(Ss1) + extra["D01"].append(Ds10) + extra["C01"].append(Cs10) + extra["S01"].append(Ss10) + + # solve online problem + P_online[:, t1:t2], _ = thomas_solve( + Ds1, + Ss1, + P_prev=P_online[:, t0:t1], + Ds_curprev=Ds10, + Us_curprev=Ss10, + Ds_prevcur=-Ds10.transpose(0, 2, 1), + Us_prevcur=Ss10.transpose(0, 2, 1), + **thomas_kw, + ) + + # update loop vars + t0, t1 = t1, t2 + traces0 = traces1 + + motion = Motion([P_online.T], [lfp_recording.get_times(0)], window_centers, direction=direction) + + if extra_outputs: + return motion, extra + else: + return motion + + +dredge_online_lfp.__doc__ = dredge_online_lfp.__doc__.format(DredgeLfpRegistration.params_doc) + + +# -- functions from dredgelib (zone forbiden for sam) + +DEFAULT_LAMBDA_T = 1.0 +DEFAULT_EPS = 1e-3 + +# -- linear algebra, Newton method solver, block tridiagonal (Thomas) solver + + +def laplacian(n, wink=True, eps=DEFAULT_EPS, lambd=1.0, ridge_mask=None): + """Construct a discrete Laplacian operator (plus eps*identity).""" + lap = np.zeros((n, n)) + if ridge_mask is None: + diag = lambd + eps + else: + diag = lambd + eps * ridge_mask + np.fill_diagonal(lap, diag) + if wink: + lap[0, 0] -= 0.5 * lambd + lap[-1, -1] -= 0.5 * lambd + # fill diagonal using a for loop for space reasons when this is large + for i in range(n - 1): + lap[i, i + 1] -= 0.5 * lambd + lap[i + 1, i] -= 0.5 * lambd + return lap + + +def neg_hessian_likelihood_term(Ub, Ub_prevcur=None, Ub_curprev=None): + """Newton step coefficients + + The negative Hessian of the non-regularized cost function inside a nonrigid block. + Together with the term arising from the regularization, this constructs the + coefficients matrix in our linear problem. + """ + negHUb = -Ub - Ub.T + diagonal_terms = np.diagonal(negHUb) + Ub.sum(1) + Ub.sum(0) + if Ub_prevcur is None: + np.fill_diagonal(negHUb, diagonal_terms) + else: + diagonal_terms += Ub_prevcur.sum(0) + Ub_curprev.sum(1) + np.fill_diagonal(negHUb, diagonal_terms) + return negHUb + + +def newton_rhs( + Db, + Ub, + Pb_prev=None, + Db_prevcur=None, + Ub_prevcur=None, + Db_curprev=None, + Ub_curprev=None, +): + """Newton step right hand side + + The gradient at P=0 of the cost function, which is the right hand side of Newton's method. + """ + UDb = Ub * Db + grad_at_0 = UDb.sum(1) - UDb.sum(0) + + # batch case + if Pb_prev is None: + return grad_at_0 + + # online case + align_term = (Ub_prevcur.T + Ub_curprev) @ Pb_prev + rhs = align_term + grad_at_0 + (Ub_curprev * Db_curprev).sum(1) - (Ub_prevcur * Db_prevcur).sum(0) + + return rhs + + +def newton_solve_rigid( + D, + U, + Sigma0inv, + Pb_prev=None, + Db_prevcur=None, + Ub_prevcur=None, + Db_curprev=None, + Ub_curprev=None, +): + """Solve the rigid Newton step + + D is TxT displacement, U is TxT subsampling or soft weights matrix. + """ + from scipy.linalg import solve, lstsq + + negHU = neg_hessian_likelihood_term( + U, + Ub_prevcur=Ub_prevcur, + Ub_curprev=Ub_curprev, + ) + targ = newton_rhs( + D, + U, + Pb_prev=Pb_prev, + Db_prevcur=Db_prevcur, + Ub_prevcur=Ub_prevcur, + Db_curprev=Db_curprev, + Ub_curprev=Ub_curprev, + ) + try: + p = solve(Sigma0inv + negHU, targ, assume_a="pos") + except np.linalg.LinAlgError: + warnings.warn("Singular problem, using least squares.") + p, *_ = lstsq(Sigma0inv + negHU, targ) + return p, negHU + + +def thomas_solve( + Ds, + Us, + lambda_t=DEFAULT_LAMBDA_T, + lambda_s=1.0, + eps=DEFAULT_EPS, + P_prev=None, + Ds_prevcur=None, + Us_prevcur=None, + Ds_curprev=None, + Us_curprev=None, + progress_bar=False, + bandwidth=None, +): + """Block tridiagonal algorithm, special cased to our setting + + This code solves for the displacement estimates across the nonrigid windows, + given blockwise, pairwise (BxTxT) displacement and weights arrays `Ds` and `Us`. + + If `lambda_t>0`, a temporal prior is applied to "fill the gaps", effectively + interpolating through time to avoid artifacts in low-signal areas. Setting this + to 0 can lead to numerical warnings and should be done with care. + + If `lambda_s>0`, a spatial prior is applied. This can help fill gaps more + meaningfully in the nonrigid case, using information from the neighboring nonrigid + windows to inform the estimate in an untrusted region of a given window. + + If arguments `P_prev,Ds_prevcur,Us_prevcur` are supplied, this code handles the + online case. The return value will be the new chunk's displacement estimate, + solving the online registration problem. + """ + from scipy.linalg import solve + + Ds = np.asarray(Ds, dtype=np.float64) + Us = np.asarray(Us, dtype=np.float64) + online = P_prev is not None + online_kw_rhs = online_kw_hess = lambda b: {} + if online: + assert Ds_prevcur is not None + assert Us_prevcur is not None + online_kw_rhs = lambda b: dict( # noqa + Pb_prev=P_prev[b].astype(np.float64, copy=False), + Db_prevcur=Ds_prevcur[b].astype(np.float64, copy=False), + Ub_prevcur=Us_prevcur[b].astype(np.float64, copy=False), + Db_curprev=Ds_curprev[b].astype(np.float64, copy=False), + Ub_curprev=Us_curprev[b].astype(np.float64, copy=False), + ) + online_kw_hess = lambda b: dict( # noqa + Ub_prevcur=Us_prevcur[b].astype(np.float64, copy=False), + Ub_curprev=Us_curprev[b].astype(np.float64, copy=False), + ) + + B, T, T_ = Ds.shape + assert T == T_ + assert Us.shape == Ds.shape + + # figure out which temporal bins are included in the problem + # these are used to figure out where epsilon can be added + # for numerical stability without changing the solution + had_weights = (Us > 0).any(axis=2) + had_weights[~had_weights.any(axis=1)] = 1 + + # temporal prior matrix + L_t = [laplacian(T, eps=eps, lambd=lambda_t, ridge_mask=w) for w in had_weights] + extra = dict(L_t=L_t) + + # just solve independent problems when there's no spatial regularization + # not that there's much overhead to the backward pass etc but might as well + if B == 1 or lambda_s == 0: + P = np.zeros((B, T)) + extra["HU"] = np.zeros((B, T, T)) + for b in range(B): + P[b], extra["HU"][b] = newton_solve_rigid(Ds[b], Us[b], L_t[b], **online_kw_rhs(b)) + return P, extra + + # spatial prior is a sparse, block tridiagonal kronecker product + # the first and last diagonal blocks are + Lambda_s_diagb = laplacian(T, eps=eps, lambd=lambda_s / 2, ridge_mask=had_weights[0]) + # and the off-diagonal blocks are + Lambda_s_offdiag = laplacian(T, eps=0, lambd=-lambda_s / 2) + + # initialize block-LU stuff and forward variable + alpha_hat_b = L_t[0] + Lambda_s_diagb + neg_hessian_likelihood_term(Us[0], **online_kw_hess(0)) + targets = np.c_[Lambda_s_offdiag, newton_rhs(Us[0], Ds[0], **online_kw_rhs(0))] + res = solve(alpha_hat_b, targets, assume_a="pos") + assert res.shape == (T, T + 1) + gamma_hats = [res[:, :T]] + ys = [res[:, T]] + + # forward pass + for b in trange(1, B, desc="Solve") if progress_bar else range(1, B): + if b < B - 1: + Lambda_s_diagb = laplacian(T, eps=eps, lambd=lambda_s, ridge_mask=had_weights[b]) + else: + Lambda_s_diagb = laplacian(T, eps=eps, lambd=lambda_s / 2, ridge_mask=had_weights[b]) + + Ab = L_t[b] + Lambda_s_diagb + neg_hessian_likelihood_term(Us[b], **online_kw_hess(b)) + alpha_hat_b = Ab - Lambda_s_offdiag @ gamma_hats[b - 1] + targets[:, T] = newton_rhs(Us[b], Ds[b], **online_kw_rhs(b)) + targets[:, T] -= Lambda_s_offdiag @ ys[b - 1] + res = solve(alpha_hat_b, targets) + assert res.shape == (T, T + 1) + gamma_hats.append(res[:, :T]) + ys.append(res[:, T]) + + # back substitution + xs = [None] * B + xs[-1] = ys[-1] + for b in range(B - 2, -1, -1): + xs[b] = ys[b] - gamma_hats[b] @ xs[b + 1] + + # un-vectorize + P = np.concatenate(xs).reshape(B, T) + + return P, extra + + +def threshold_correlation_matrix( + Cs, + mincorr=0.0, + mincorr_percentile=None, + mincorr_percentile_nneighbs=20, + time_horizon_s=0, + in_place=False, + bin_s=1, + t_offset_bins=None, + T=None, + soft=True, +): + if mincorr_percentile is not None: + diags = [np.diagonal(Cs, offset=j, axis1=1, axis2=2).ravel() for j in range(1, mincorr_percentile_nneighbs)] + mincorr = np.percentile( + np.concatenate(diags), + mincorr_percentile, + ) + + # need abs to avoid -0.0s which cause numerical issues + if in_place: + Ss = Cs + if soft: + Ss[Ss < mincorr] = 0 + else: + Ss = (Ss >= mincorr).astype(Cs.dtype) + np.square(Ss, out=Ss) + else: + if soft: + Ss = np.square((Cs >= mincorr) * Cs) + else: + Ss = (Cs >= mincorr).astype(Cs.dtype) + if time_horizon_s is not None and time_horizon_s > 0 and T is not None and time_horizon_s < T: + tt0 = bin_s * np.arange(T) + tt1 = tt0 + if t_offset_bins: + tt1 = tt0 + t_offset_bins + dt = tt1[:, None] - tt0[None, :] + mask = (np.abs(dt) <= time_horizon_s).astype(Ss.dtype) + Ss *= mask[None] + return Ss, mincorr + + +def xcorr_windows( + raster_a, + windows, + spatial_bin_edges_um, + win_scale_um, + raster_b=None, + rigid=False, + bin_um=1, + max_disp_um=None, + max_dt_bins=None, + progress_bar=True, + centered=True, + normalized=True, + masks=None, + device=None, +): + """Main computational function + + Compute pairwise (time x time) maximum cross-correlation and displacement + matrices in each nonrigid window. + """ + import torch + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if max_disp_um is None: + if rigid: + max_disp_um = int(spatial_bin_edges_um.ptp() // 4) + else: + max_disp_um = int(win_scale_um // 4) + + max_disp_bins = int(max_disp_um // bin_um) + slices = get_window_domains(windows) + B, D = windows.shape + D_, T0 = raster_a.shape + + assert D == D_ + + # torch versions on device + windows_ = torch.as_tensor(windows, dtype=torch.float, device=device) + raster_a_ = torch.as_tensor(raster_a, dtype=torch.float, device=device) + if raster_b is not None: + assert raster_b.shape[0] == D + T1 = raster_b.shape[1] + raster_b_ = torch.as_tensor(raster_b, dtype=torch.float, device=device) + else: + T1 = T0 + raster_b_ = raster_a_ + if masks is not None: + masks = torch.as_tensor(masks, dtype=torch.float, device=device) + + # estimate each window's displacement + Ds = np.zeros((B, T0, T1), dtype=np.float32) + Cs = np.zeros((B, T0, T1), dtype=np.float32) + block_iter = trange(B, desc="Cross correlation") if progress_bar else range(B) + for b in block_iter: + window = windows_[b] + + # we search for the template (windowed part of raster a) + # within a larger-than-the-window neighborhood in raster b + targ_low = slices[b].start - max_disp_bins + b_low = max(0, targ_low) + targ_high = slices[b].stop + max_disp_bins + b_high = min(D, targ_high) + padding = max(b_low - targ_low, targ_high - b_high) + + # arithmetic to compute the lags in um corresponding to + # corr argmaxes + n_left = padding + slices[b].start - b_low + n_right = padding + b_high - slices[b].stop + poss_disp = -np.arange(-n_left, n_right + 1) * bin_um + + Ds[b], Cs[b] = calc_corr_decent_pair( + raster_a_[slices[b]], + raster_b_[b_low:b_high], + weights=window[slices[b]], + masks=None if masks is None else masks[slices[b]], + xmasks=None if masks is None else masks[b_low:b_high], + disp=padding, + possible_displacement=poss_disp, + device=device, + centered=centered, + normalized=normalized, + max_dt_bins=max_dt_bins, + ) + + return Ds, Cs, max_disp_um + + +def calc_corr_decent_pair( + raster_a, + raster_b, + weights=None, + masks=None, + xmasks=None, + disp=None, + batch_size=512, + normalized=True, + centered=True, + possible_displacement=None, + max_dt_bins=None, + device=None, +): + """Weighted pairwise cross-correlation + + Calculate TxT normalized xcorr and best displacement matrices + Given a DxT raster, this computes normalized cross correlations for + all pairs of time bins at offsets in the range [-disp, disp], by + increments of step_size. Then it finds the best one and its + corresponding displacement, resulting in two TxT matrices: one for + the normxcorrs at the best displacement, and the matrix of the best + displacements. + + Arguments + --------- + raster : DxT array + batch_size : int + How many raster rows to xcorr against the whole raster + at once. + step_size : int + Displacement increment. Not implemented yet but easy to do. + disp : int + Maximum displacement + device : torch device + Returns: D, C: TxT arrays + """ + import torch + + D, Ta = raster_a.shape + D_, Tb = raster_b.shape + + # sensible default: at most half the domain. + if disp is None: + disp == D // 2 + + # range of displacements + if D == D_: + if possible_displacement is None: + possible_displacement = np.arange(-disp, disp + 1) + else: + assert possible_displacement is not None + assert disp is not None + + # pick torch device if unset + if device is None: + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + # process rasters into the tensors we need for conv2ds below + # convert to TxD device floats + raster_a = torch.as_tensor(raster_a.T, dtype=torch.float32, device=device) + # normalize over depth for normalized (uncentered) xcorrs + raster_b = torch.as_tensor(raster_b.T, dtype=torch.float32, device=device) + + D = np.zeros((Ta, Tb), dtype=np.float32) + C = np.zeros((Ta, Tb), dtype=np.float32) + for i in range(0, Ta, batch_size): + for j in range(0, Tb, batch_size): + dt_bins = min(abs(i - j), abs(i + batch_size - j), abs(i - j - batch_size)) + if max_dt_bins and dt_bins > max_dt_bins: + continue + weights_ = weights + if masks is not None: + weights_ = masks.T[i : i + batch_size] * weights + corr = normxcorr1d( + raster_a[i : i + batch_size], + raster_b[j : j + batch_size], + weights=weights_, + xmasks=None if xmasks is None else xmasks.T[j : j + batch_size], + padding=disp, + normalized=normalized, + centered=centered, + ) + max_corr, best_disp_inds = torch.max(corr, dim=2) + best_disp = possible_displacement[best_disp_inds.cpu()] + D[i : i + batch_size, j : j + batch_size] = best_disp.T + C[i : i + batch_size, j : j + batch_size] = max_corr.cpu().T + + return D, C + + +def normxcorr1d( + template, + x, + weights=None, + xmasks=None, + centered=True, + normalized=True, + padding="same", + conv_engine="torch", +): + """ + normxcorr1d: Normalized cross-correlation, optionally weighted + + The API is like torch's F.conv1d, except I have accidentally + changed the position of input/weights -- template acts like weights, + and x acts like input. + + Returns the cross-correlation of `template` and `x` at spatial lags + determined by `mode`. Useful for estimating the location of `template` + within `x`. + + This might not be the most efficient implementation -- ideas welcome. + It uses a direct convolutional translation of the formula + corr = (E[XY] - EX EY) / sqrt(var X * var Y) + + This also supports weights! In that case, the usual adaptation of + the above formula is made to the weighted case -- and all of the + normalizations are done per block in the same way. + + Parameters + ---------- + template : tensor, shape (num_templates, length) + The reference template signal + x : tensor, 1d shape (length,) or 2d shape (num_inputs, length) + The signal in which to find `template` + weights : tensor, shape (length,) + Will use weighted means, variances, covariances if supplied. + centered : bool + If true, means will be subtracted (per weighted patch). + normalized : bool + If true, normalize by the variance (per weighted patch). + padding : int, optional + How far to look? if unset, we'll use half the length + conv_engine : "torch" | "numpy" + What library to use for computing cross-correlations. + If numpy, falls back to the scipy correlate function. + + Returns + ------- + corr : tensor + """ + + if conv_engine == "torch": + import torch + import torch.nn.functional as F + + conv1d = F.conv1d + npx = torch + elif conv_engine == "numpy": + conv1d = scipy_conv1d + npx = np + else: + raise ValueError(f"Unknown conv_engine {conv_engine}") + + x = npx.atleast_2d(x) + num_templates, lengtht = template.shape + num_inputs, lengthx = x.shape + + # generalize over weighted / unweighted case + device_kw = {} if conv_engine == "numpy" else dict(device=x.device) + if xmasks is None: + onesx = npx.ones((1, 1, lengthx), dtype=x.dtype, **device_kw) + wx = x[:, None, :] + else: + assert xmasks.shape == x.shape + onesx = xmasks[:, None, :] + wx = x[:, None, :] * onesx + no_weights = weights is None + if no_weights: + weights = npx.ones((1, 1, lengtht), dtype=x.dtype, **device_kw) + wt = template[:, None, :] + else: + if weights.shape == (lengtht,): + weights = weights[None, None] + elif weights.shape == (num_templates, lengtht): + weights = weights[:, None, :] + else: + assert False + wt = template[:, None, :] * weights + x = x[:, None, :] + template = template[:, None, :] + + # conv1d valid rule: + # (B,1,L),(O,1,L)->(B,O,L) + # below, we always put x on the LHS, templates on the RHS, so this reads + # (num_inputs, 1, lengthx), (num_templates, 1, lengtht) -> (num_inputs, num_templates, length_out) + + # compute expectations + # how many points in each window? seems necessary to normalize + # for numerical stability. + Nx = conv1d(onesx, weights, padding=padding) # 1,nt,l + empty = Nx == 0 + Nx[empty] = 1 + if centered: + Et = conv1d(onesx, wt, padding=padding) # 1,nt,l + Et /= Nx + Ex = conv1d(wx, weights, padding=padding) # nx,nt,l + Ex /= Nx + + # compute (weighted) covariance + # important: the formula E[XY] - EX EY is well-suited here, + # because the means are naturally subtracted correctly + # patch-wise. you couldn't pre-subtract them! + cov = conv1d(wx, wt, padding=padding) + cov /= Nx + if centered: + cov -= Ex * Et + + # compute variances for denominator, using var X = E[X^2] - (EX)^2 + if normalized: + var_template = conv1d(onesx, wt * template, padding=padding) + var_template /= Nx + var_x = conv1d(wx * x, weights, padding=padding) + var_x /= Nx + if centered: + var_template -= npx.square(Et) + var_x -= npx.square(Ex) + + # fill in zeros to avoid problems when dividing + var_template[var_template <= 0] = 1 + var_x[var_x <= 0] = 1 + + # now find the final normxcorr + corr = cov # renaming for clarity + if normalized: + corr[npx.broadcast_to(empty, corr.shape)] = 0 + corr /= npx.sqrt(var_x) + corr /= npx.sqrt(var_template) + + return corr + + +def get_weights( + Ds, + Ss, + Sigma0inv_t, + windows, + raster, + dbe, + tbe, + # @charlie raster_kw is removed in favor of post_transform only is this OK ??? + # raster_kw, + post_transform=np.log1p, + weights_threshold_low=0.0, + weights_threshold_high=np.inf, + progress_bar=False, +): + """Compute per-time-bin weighting for each nonrigid window""" + # determine window-weighted raster "heat" in each nonrigid window + # as a function of time + assert windows.shape[1] == dbe.size - 1 + weights = [] + p_inds = [] + for b in range((len(Ds))): + ilow, ihigh = np.flatnonzero(windows[b])[[0, -1]] + ihigh += 1 + window_sliced = windows[b, ilow:ihigh] + weights.append(window_sliced @ raster[ilow:ihigh]) + weights_orig = np.array(weights) + + # scale_fn = raster_kw["post_transform"] or raster_kw["amp_scale_fn"] + scale_fn = post_transform + if isinstance(weights_threshold_low, tuple): + nspikes_threshold_low, amp_threshold_low = weights_threshold_low + unif = np.full_like(windows[0], 1 / len(windows[0])) + weights_threshold_low = scale_fn(amp_threshold_low) * windows @ (nspikes_threshold_low * unif) + weights_threshold_low = weights_threshold_low[:, None] + if isinstance(weights_threshold_high, tuple): + nspikes_threshold_high, amp_threshold_high = weights_threshold_high + unif = np.full_like(windows[0], 1 / len(windows[0])) + weights_threshold_high = scale_fn(amp_threshold_high) * windows @ (nspikes_threshold_high * unif) + weights_threshold_high = weights_threshold_high[:, None] + weights_thresh = weights_orig.copy() + weights_thresh[weights_orig < weights_threshold_low] = 0 + weights_thresh[weights_orig > weights_threshold_high] = np.inf + + return weights, weights_thresh, p_inds + + +def weight_correlation_matrix( + Ds, + Cs, + windows, + raster, + depth_bin_edges, + time_bin_edges, + # @charlie raster_kw is remove in favor of post_transform only + # raster_kw, + post_transform=np.log1p, + mincorr=0.0, + mincorr_percentile=None, + mincorr_percentile_nneighbs=20, + time_horizon_s=None, + lambda_t=DEFAULT_LAMBDA_T, + eps=DEFAULT_EPS, + do_window_weights=True, + weights_threshold_low=0.0, + weights_threshold_high=np.inf, + progress_bar=True, + in_place=False, +): + """Transform the correlation matrix into the weights used in optimization.""" + extra = {} + + Ds = np.asarray(Ds) + Cs = np.asarray(Cs) + if Ds.ndim == 2: + Ds = Ds[None] + Cs = Cs[None] + B, T, T_ = Ds.shape + assert T == T_ + assert Ds.shape == Cs.shape + extra = {} + + Ss, mincorr = threshold_correlation_matrix( + Cs, + mincorr=mincorr, + mincorr_percentile=mincorr_percentile, + mincorr_percentile_nneighbs=mincorr_percentile_nneighbs, + time_horizon_s=time_horizon_s, + bin_s=time_bin_edges[1] - time_bin_edges[0], + T=T, + in_place=in_place, + ) + extra["S"] = Ss + extra["mincorr"] = mincorr + + if not do_window_weights: + return Ss, extra + + # get weights + L_t = lambda_t * laplacian(T, eps=max(1e-5, eps)) + weights_orig, weights_thresh, Pind = get_weights( + Ds, + Ss, + L_t, + windows, + raster, + depth_bin_edges, + time_bin_edges, + # raster_kw, + post_transform=post_transform, + weights_threshold_low=weights_threshold_low, + weights_threshold_high=weights_threshold_high, + progress_bar=progress_bar, + ) + extra["weights_orig"] = weights_orig + extra["weights_thresh"] = weights_thresh + extra["Pind"] = Pind + + # update noise model. we deliberately divide by zero and inf here. + Us = Ss if in_place else np.zeros_like(Ss) + with np.errstate(divide="ignore"): + # low mem impl of U = abs(1/(1/weights_thresh+1/weights_thresh'+1/S)) + np.reciprocal(Ss, out=Us) + invW = 1.0 / weights_thresh + Us += invW[:, :, None] + Us += invW[:, None, :] + np.reciprocal(Us, out=Us) + # handles possible -0s that cause issues elsewhere + np.abs(Us, out=Us) + # more readable equivalent: + # for b in range(B): + # invWbtt = invW[b, :, None] + invW[b, None, :] + # Us[b] = np.abs(1.0 / (invWbtt + 1.0 / Ss[b])) + extra["U"] = Us + + return Us, extra diff --git a/src/spikeinterface/sortingcomponents/motion/iterative_template.py b/src/spikeinterface/sortingcomponents/motion/iterative_template.py new file mode 100644 index 0000000000..1b5eb75508 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/iterative_template.py @@ -0,0 +1,296 @@ +import numpy as np + +from .motion_utils import Motion, get_spatial_windows, get_spatial_bin_edges, make_3d_motion_histograms + + +class IterativeTemplateRegistration: + """ + Alignment function implemented by Kilosort2.5 and ported from pykilosort: + https://github.com/int-brain-lab/pykilosort/blob/ibl_prod/pykilosort/datashift2.py#L166 + + The main difference with respect to the original implementation are: + * scipy is used for gaussian smoothing + * windowing is implemented as gaussian tapering (instead of rectangular blocks) + * the 3d histogram is constructed in less cryptic way + * peak_locations are computed outside and so can either center fo mass or monopolar trianglation + contrary to kilosort2.5 use exclusively center of mass + + See https://www.science.org/doi/abs/10.1126/science.abf4588?cookieSet=1 + + Ported by Alessio Buccino into SpikeInterface + """ + + name = "iterative_template" + need_peak_location = True + params_doc = """ + bin_um: float, default: 10 + Spatial bin size in micrometers + hist_margin_um: float, default: 0 + Margin in um from histogram estimation. + Positive margin extrapolate out of the probe the motion. + Negative margin crop the motion on the border + bin_s: float, default: 2.0 + Bin duration in second + num_amp_bins: int, default: 20 + number ob bins in the histogram on the log amplitues dimension + num_shifts_global: int, default: 15 + Number of spatial bin shifts to consider for global alignment + num_iterations: int, default: 10 + Number of iterations for global alignment procedure + num_shifts_block: int, default: 5 + Number of spatial bin shifts to consider for non-rigid alignment + smoothing_sigma: float, default: 0.5 + Sigma of gaussian for covariance matrices smoothing + kriging_sigma: float, + sigma parameter for kriging_kernel function + kriging_p: foat + p parameter for kriging_kernel function + kriging_d: float + d parameter for kriging_kernel function + """ + + @classmethod + def run( + cls, + recording, + peaks, + peak_locations, + direction, + rigid, + win_shape, + win_step_um, + win_scale_um, + win_margin_um, + verbose, + progress_bar, + extra, + bin_um=10.0, + hist_margin_um=0.0, + bin_s=2.0, + num_amp_bins=20, + num_shifts_global=15, + num_iterations=10, + num_shifts_block=5, + smoothing_sigma=0.5, + kriging_sigma=1, + kriging_p=2, + kriging_d=2, + ): + + dim = ["x", "y", "z"].index(direction) + contact_depths = recording.get_channel_locations()[:, dim] + + # spatial histogram bins + spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) + spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1]) + + # get spatial windows + non_rigid_windows, non_rigid_window_centers = get_spatial_windows( + contact_depths=contact_depths, + spatial_bin_centers=spatial_bin_centers, + rigid=rigid, + win_margin_um=win_margin_um, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_shape=win_shape, + zero_threshold=None, + ) + + # make a 3D histogram + if verbose: + print("Making 3D motion histograms") + motion_histograms, temporal_hist_bin_edges, spatial_hist_bin_edges = make_3d_motion_histograms( + recording, + peaks, + peak_locations, + direction=direction, + num_amp_bins=num_amp_bins, + bin_s=bin_s, + spatial_bin_edges=spatial_bin_edges, + ) + # temporal bins are bin center + temporal_bins = temporal_hist_bin_edges[:-1] + bin_s // 2.0 + + # do alignment + if verbose: + print("Estimating alignment shifts") + shift_indices, target_histogram, shift_covs_block = iterative_template_registration( + motion_histograms, + non_rigid_windows=non_rigid_windows, + num_shifts_global=num_shifts_global, + num_iterations=num_iterations, + num_shifts_block=num_shifts_block, + smoothing_sigma=smoothing_sigma, + kriging_sigma=kriging_sigma, + kriging_p=kriging_p, + kriging_d=kriging_d, + ) + + # convert to um + motion_array = -(shift_indices * bin_um) + + if extra: + extra["non_rigid_windows"] = non_rigid_windows + extra["motion_histograms"] = motion_histograms + extra["target_histogram"] = target_histogram + extra["shift_covs_block"] = shift_covs_block + extra["temporal_hist_bin_edges"] = temporal_hist_bin_edges + extra["spatial_hist_bin_edges"] = spatial_hist_bin_edges + + # replace nan by zeros + np.nan_to_num(motion_array, copy=False) + + motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction) + + return motion + + +def iterative_template_registration( + spikecounts_hist_images, + non_rigid_windows=None, + num_shifts_global=15, + num_iterations=10, + num_shifts_block=5, + smoothing_sigma=0.5, + kriging_sigma=1, + kriging_p=2, + kriging_d=2, +): + """ + + Parameters + ---------- + + spikecounts_hist_images : np.ndarray + Spike count histogram images (num_temporal_bins, num_spatial_bins, num_amps_bins) + non_rigid_windows : list, default: None + If num_non_rigid_windows > 1, this argument is required and it is a list of + windows to taper spatial bins in different blocks + num_shifts_global : int, default: 15 + Number of spatial bin shifts to consider for global alignment + num_iterations : int, default: 10 + Number of iterations for global alignment procedure + num_shifts_block : int, default: 5 + Number of spatial bin shifts to consider for non-rigid alignment + smoothing_sigma : float, default: 0.5 + Sigma of gaussian for covariance matrices smoothing + kriging_sigma : float, default: 1 + sigma parameter for kriging_kernel function + kriging_p : float, default: 2 + p parameter for kriging_kernel function + kriging_d : float, default: 2 + d parameter for kriging_kernel function + + Returns + ------- + optimal_shift_indices + Optimal shifts for each temporal and spatial bin (num_temporal_bins, num_non_rigid_windows) + target_spikecount_hist + Target histogram used for alignment (num_spatial_bins, num_amps_bins) + """ + from scipy.ndimage import gaussian_filter, gaussian_filter1d + + # F is y bins by amp bins by batches + # ysamp are the coordinates of the y bins in um + spikecounts_hist_images = spikecounts_hist_images.swapaxes(0, 1).swapaxes(1, 2) + num_temporal_bins = spikecounts_hist_images.shape[2] + + # look up and down this many y bins to find best alignment + shift_covs = np.zeros((2 * num_shifts_global + 1, num_temporal_bins)) + shifts = np.arange(-num_shifts_global, num_shifts_global + 1) + + # mean subtraction to compute covariance + F = spikecounts_hist_images + Fg = F - np.mean(F, axis=0) + + # initialize the target "frame" for alignment with a single sample + # here we removed min(299, ...) + F0 = Fg[:, :, np.floor(num_temporal_bins / 2).astype("int") - 1] + F0 = F0[:, :, np.newaxis] + + # first we do rigid registration by integer shifts + # everything is iteratively aligned until most of the shifts become 0. + best_shifts = np.zeros((num_iterations, num_temporal_bins)) + for iteration in range(num_iterations): + for t, shift in enumerate(shifts): + # for each NEW potential shift, estimate covariance + Fs = np.roll(Fg, shift, axis=0) + shift_covs[t, :] = np.mean(Fs * F0, axis=(0, 1)) + if iteration + 1 < num_iterations: + # estimate the best shifts + imax = np.argmax(shift_covs, axis=0) + # align the data by these integer shifts + for t, shift in enumerate(shifts): + ibest = imax == t + Fg[:, :, ibest] = np.roll(Fg[:, :, ibest], shift, axis=0) + best_shifts[iteration, ibest] = shift + # new target frame based on our current best alignment + F0 = np.mean(Fg, axis=2)[:, :, np.newaxis] + target_spikecount_hist = F0[:, :, 0] + + # now we figure out how to split the probe into nblocks pieces + # if len(non_rigid_windows) = 1, then we're doing rigid registration + num_non_rigid_windows = len(non_rigid_windows) + + # for each small block, we only look up and down this many samples to find + # nonrigid shift + shifts_block = np.arange(-num_shifts_block, num_shifts_block + 1) + num_shifts = len(shifts_block) + shift_covs_block = np.zeros((2 * num_shifts_block + 1, num_temporal_bins, num_non_rigid_windows)) + + # this part determines the up/down covariance for each block without + # shifting anything + for window_index in range(num_non_rigid_windows): + win = non_rigid_windows[window_index] + window_slice = np.flatnonzero(win > 1e-5) + window_slice = slice(window_slice[0], window_slice[-1]) + tiled_window = win[window_slice, np.newaxis, np.newaxis] + Ftaper = Fg[window_slice] * np.tile(tiled_window, (1,) + Fg.shape[1:]) + for t, shift in enumerate(shifts_block): + Fs = np.roll(Ftaper, shift, axis=0) + F0taper = F0[window_slice] * np.tile(tiled_window, (1,) + F0.shape[1:]) + shift_covs_block[t, :, window_index] = np.mean(Fs * F0taper, axis=(0, 1)) + + # gaussian smoothing: + # here the original my_conv2_cpu is substituted with scipy gaussian_filters + shift_covs_block_smooth = shift_covs_block.copy() + shifts_block_up = np.linspace(-num_shifts_block, num_shifts_block, (2 * num_shifts_block * 10) + 1) + # 1. 2d smoothing over time and blocks dimensions for each shift + for shift_index in range(num_shifts): + shift_covs_block_smooth[shift_index, :, :] = gaussian_filter( + shift_covs_block_smooth[shift_index, :, :], smoothing_sigma + ) # some additional smoothing for robustness, across all dimensions + # 2. 1d smoothing over shift dimension for each spatial block + for window_index in range(num_non_rigid_windows): + shift_covs_block_smooth[:, :, window_index] = gaussian_filter1d( + shift_covs_block_smooth[:, :, window_index], smoothing_sigma, axis=0 + ) # some additional smoothing for robustness, across all dimensions + upsample_kernel = kriging_kernel( + shifts_block[:, np.newaxis], shifts_block_up[:, np.newaxis], sigma=kriging_sigma, p=kriging_p, d=kriging_d + ) + + optimal_shift_indices = np.zeros((num_temporal_bins, num_non_rigid_windows)) + for window_index in range(num_non_rigid_windows): + # using the upsampling kernel K, get the upsampled cross-correlation + # curves + upsampled_cov = upsample_kernel.T @ shift_covs_block_smooth[:, :, window_index] + + # find the max index of these curves + imax = np.argmax(upsampled_cov, axis=0) + + # add the value of the shift to the last row of the matrix of shifts + # (as if it was the last iteration of the main rigid loop ) + best_shifts[num_iterations - 1, :] = shifts_block_up[imax] + + # the sum of all the shifts equals the final shifts for this block + optimal_shift_indices[:, window_index] = np.sum(best_shifts, axis=0) + + return optimal_shift_indices, target_spikecount_hist, shift_covs_block + + +def kriging_kernel(source_location, target_location, sigma=1, p=2, d=2): + from scipy.spatial.distance import cdist + + dist_xy = cdist(source_location, target_location, metric="euclidean") + K = np.exp(-((dist_xy / sigma) ** p) / d) + return K diff --git a/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py b/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py new file mode 100644 index 0000000000..6fe36a6193 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py @@ -0,0 +1,72 @@ +import numpy as np + +# TODO this need a full rewrite with motion object + + +def clean_motion_vector(motion, temporal_bins, bin_duration_s, speed_threshold=30, sigma_smooth_s=None): + """ + Simple machinery to remove spurious fast bump in the motion vector. + Also can apply a smoothing. + + + Arguments + --------- + motion: numpy array 2d + Motion estimate in um. + temporal_bins: numpy.array 1d + temporal bins (bin center) + bin_duration_s: float + bin duration in second + speed_threshold: float (units um/s) + Maximum speed treshold between 2 bins allowed. + Expressed in um/s + sigma_smooth_s: None or float + Optional smooting gaussian kernel. + + Returns + ------- + corr : tensor + + + """ + motion_clean = motion.copy() + + # STEP 1 : + # * detect long plateau or small peak corssing the speed thresh + # * mask the period and interpolate + for i in range(motion.shape[1]): + one_motion = motion_clean[:, i] + speed = np.diff(one_motion, axis=0) / bin_duration_s + (inds,) = np.nonzero(np.abs(speed) > speed_threshold) + inds += 1 + if inds.size % 2 == 1: + # more compicated case: number of of inds is odd must remove first or last + # take the smallest duration sum + inds0 = inds[:-1] + inds1 = inds[1:] + d0 = np.sum(inds0[1::2] - inds0[::2]) + d1 = np.sum(inds1[1::2] - inds1[::2]) + if d0 < d1: + inds = inds0 + mask = np.ones(motion_clean.shape[0], dtype="bool") + for i in range(inds.size // 2): + mask[inds[i * 2] : inds[i * 2 + 1]] = False + import scipy.interpolate + + f = scipy.interpolate.interp1d(temporal_bins[mask], one_motion[mask]) + one_motion[~mask] = f(temporal_bins[~mask]) + + # Step 2 : gaussian smooth + if sigma_smooth_s is not None: + half_size = motion_clean.shape[0] // 2 + if motion_clean.shape[0] % 2 == 0: + # take care of the shift + bins = (np.arange(motion_clean.shape[0]) - half_size + 1) * bin_duration_s + else: + bins = (np.arange(motion_clean.shape[0]) - half_size) * bin_duration_s + smooth_kernel = np.exp(-(bins**2) / (2 * sigma_smooth_s**2)) + smooth_kernel /= np.sum(smooth_kernel) + smooth_kernel = smooth_kernel[:, None] + motion_clean = scipy.signal.fftconvolve(motion_clean, smooth_kernel, mode="same", axes=0) + + return motion_clean diff --git a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py new file mode 100644 index 0000000000..2d8564fc54 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import warnings +import numpy as np + + +from spikeinterface.sortingcomponents.tools import make_multi_method_doc + + +from .motion_utils import Motion, get_spatial_windows, get_spatial_bin_edges +from .decentralized import DecentralizedRegistration +from .iterative_template import IterativeTemplateRegistration +from .dredge import DredgeLfpRegistration, DredgeApRegistration + + +# estimate_motion > infer_motion +def estimate_motion( + recording, + peaks=None, + peak_locations=None, + direction="y", + rigid=False, + win_shape="gaussian", + win_step_um=50.0, # @alessio charlie is proposing here instead 400 + win_scale_um=150.0, # @alessio charlie is proposing here instead 400 + win_margin_um=None, + method="decentralized", + extra_outputs=False, + progress_bar=False, + verbose=False, + margin_um=None, + **method_kwargs, +): + """ + + + Estimate motion with several possible methods. + + Most of methods except dredge_lfp needs peaks and after their localization. + + Note that the way you detect peak locations (center of mass/monopolar_triangulation/grid_convolution) + have an impact on the result. + + Parameters + ---------- + recording: BaseRecording + The recording extractor + peaks: numpy array + Peak vector (complex dtype). + Needed for decentralized and iterative_template methods. + peak_locations: numpy array + Complex dtype with "x", "y", "z" fields + Needed for decentralized and iterative_template methods. + direction: "x" | "y" | "z", default: "y" + Dimension on which the motion is estimated. "y" is depth along the probe. + + {method_doc} + + **non-rigid section** + + rigid : bool, default: False + Compute rigid (one motion for the entire probe) or non rigid motion + Rigid computation is equivalent to non-rigid with only one window with rectangular shape. + win_shape : "gaussian" | "rect" | "triangle", default: "gaussian" + The shape of the windows for non rigid. + When rigid this is force to "rect" + Nonrigid window-related arguments + The depth domain will be broken up into windows with shape controlled by win_shape, + spaced by win_step_um at a margin of win_margin_um from the boundary, and with + width controlled by win_scale_um. + When win_margin_um is None the margin is automatically set to -win_scale_um/2. + See get_spatial_windows. + win_step_um : float, default: 50 + See win_shape + win_scale_um : float, default: 150 + See win_shape + win_margin_um : None | float, default: None + See win_shape + extra_outputs: bool, default: False + If True then return an extra dict that contains variables + to check intermediate steps (motion_histogram, non_rigid_windows, pairwise_displacement) + progress_bar: bool, default: False + Display progress bar or not + verbose: bool, default: False + If True, output is verbose + + + Returns + ------- + motion: Motion object + The motion object. + extra: dict + Optional output if `extra_outputs=True` + This dict contain histogram, pairwise_displacement usefull for ploting. + """ + + if margin_um is not None: + warnings.warn("estimate_motion() margin_um has been removed used hist_margin_um or win_margin_um") + + # TODO handle multi segment one day : Charlie this is for you + assert recording.get_num_segments() == 1, "At the moment estimate_motion handle only unique segment" + + method_class = estimate_motion_methods[method] + + if method_class.need_peak_location: + if peaks is None or peak_locations is None: + raise ValueError(f"estimate_motion: the method {method} need peaks and peak_locations") + + if extra_outputs: + extra = {} + else: + extra = None + + # run method + motion = method_class.run( + recording, + peaks, + peak_locations, + direction, + rigid, + win_shape, + win_step_um, + win_scale_um, + win_margin_um, + verbose, + progress_bar, + extra, + **method_kwargs, + ) + + if extra_outputs: + return motion, extra + else: + return motion + + +_methods_list = [DecentralizedRegistration, IterativeTemplateRegistration, DredgeLfpRegistration, DredgeApRegistration] +estimate_motion_methods = {m.name: m for m in _methods_list} +method_doc = make_multi_method_doc(_methods_list) +estimate_motion.__doc__ = estimate_motion.__doc__.format(method_doc=method_doc) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py similarity index 99% rename from src/spikeinterface/sortingcomponents/motion_interpolation.py rename to src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 32bb7634e9..11ce11e1aa 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -27,6 +27,9 @@ def correct_motion_on_peaks(peaks, peak_locations, motion, recording): corrected_peak_locations: np.array Motion-corrected peak locations """ + if recording is None: + raise ValueError("correct_motion_on_peaks need recording to be not None") + corrected_peak_locations = peak_locations.copy() for segment_index in range(motion.num_segments): diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py new file mode 100644 index 0000000000..a48e10b3e1 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -0,0 +1,577 @@ +import warnings +import json +from pathlib import Path + +import numpy as np +import spikeinterface +from spikeinterface.core.core_tools import check_json + + +class Motion: + """ + Motion of the tissue relative the probe. + + Parameters + ---------- + displacement : numpy array 2d or list of + Motion estimate in um. + List is the number of segment. + For each semgent : + * shape (temporal bins, spatial bins) + * motion.shape[0] = temporal_bins.shape[0] + * motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) + temporal_bins_s : numpy.array 1d or list of + temporal bins (bin center) + spatial_bins_um : numpy.array 1d + Windows center. + spatial_bins_um.shape[0] == displacement.shape[1] + If rigid then spatial_bins_um.shape[0] == 1 + direction : str, default: 'y' + Direction of the motion. + interpolation_method : str + How to determine the displacement between bin centers? See the docs + for scipy.interpolate.RegularGridInterpolator for options. + """ + + def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y", interpolation_method="linear"): + if isinstance(displacement, np.ndarray): + self.displacement = [displacement] + assert isinstance(temporal_bins_s, np.ndarray) + self.temporal_bins_s = [temporal_bins_s] + else: + assert isinstance(displacement, (list, tuple)) + self.displacement = displacement + self.temporal_bins_s = temporal_bins_s + + assert isinstance(spatial_bins_um, np.ndarray) + self.spatial_bins_um = spatial_bins_um + + self.num_segments = len(self.displacement) + self.interpolators = None + self.interpolation_method = interpolation_method + + self.direction = direction + self.dim = ["x", "y", "z"].index(direction) + self.check_properties() + + def check_properties(self): + assert all(d.ndim == 2 for d in self.displacement) + assert all(t.ndim == 1 for t in self.temporal_bins_s) + assert all(self.spatial_bins_um.shape == (d.shape[1],) for d in self.displacement) + + def __repr__(self): + nbins = self.spatial_bins_um.shape[0] + if nbins == 1: + rigid_txt = "rigid" + else: + rigid_txt = f"non-rigid - {nbins} spatial bins" + + interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0] + txt = f"Motion {rigid_txt} - interval {interval_s}s - {self.num_segments} segments" + return txt + + def make_interpolators(self): + from scipy.interpolate import RegularGridInterpolator + + self.interpolators = [ + RegularGridInterpolator( + (self.temporal_bins_s[j], self.spatial_bins_um), self.displacement[j], method=self.interpolation_method + ) + for j in range(self.num_segments) + ] + self.temporal_bounds = [(t[0], t[-1]) for t in self.temporal_bins_s] + self.spatial_bounds = (self.spatial_bins_um.min(), self.spatial_bins_um.max()) + + def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_index=None, grid=False): + """Evaluate the motion estimate at times and positions + + Evaluate the motion estimate, returning the (linearly interpolated) estimated displacement + at the given times and locations. + + Parameters + ---------- + times_s: np.array + locations_um: np.array + Either this is a one-dimensional array (a vector of positions along self.dimension), or + else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1. + segment_index: int, default: None + The index of the segment to evaluate. If None, and there is only one segment, then that segment is used. + grid : bool, default: False + If grid=False, the default, then times_s and locations_um should have the same one-dimensional + shape, and the returned displacement[i] is the displacement at time times_s[i] and location + locations_um[i]. + If grid=True, times_s and locations_um determine a grid of positions to evaluate the displacement. + Then the returned displacement[i,j] is the displacement at depth locations_um[i] and time times_s[j]. + + Returns + ------- + displacement : np.array + A displacement per input location, of shape times_s.shape if grid=False and (locations_um.size, times_s.size) + if grid=True. + """ + if self.interpolators is None: + self.make_interpolators() + + if segment_index is None: + if self.num_segments == 1: + segment_index = 0 + else: + raise ValueError("Several segment need segment_index=") + + times_s = np.asarray(times_s) + locations_um = np.asarray(locations_um) + + if locations_um.ndim == 1: + locations_um = locations_um + elif locations_um.ndim == 2: + locations_um = locations_um[:, self.dim] + else: + assert False + + times_s = times_s.clip(*self.temporal_bounds[segment_index]) + locations_um = locations_um.clip(*self.spatial_bounds) + + if grid: + # construct a grid over which to evaluate the displacement + locations_um, times_s = np.meshgrid(locations_um, times_s, indexing="ij") + out_shape = times_s.shape + locations_um = locations_um.ravel() + times_s = times_s.ravel() + else: + # usual case: input is a point cloud + assert locations_um.shape == times_s.shape + assert times_s.ndim == 1 + out_shape = times_s.shape + + points = np.column_stack((times_s, locations_um)) + displacement = self.interpolators[segment_index](points) + # reshape to grid domain shape if necessary + displacement = displacement.reshape(out_shape) + + return displacement + + def to_dict(self): + return dict( + displacement=self.displacement, + temporal_bins_s=self.temporal_bins_s, + spatial_bins_um=self.spatial_bins_um, + interpolation_method=self.interpolation_method, + direction=self.direction, + ) + + def save(self, folder): + folder = Path(folder) + folder.mkdir(exist_ok=False, parents=True) + + info_file = folder / f"spikeinterface_info.json" + info = dict( + version=spikeinterface.__version__, + dev_mode=spikeinterface.DEV_MODE, + object="Motion", + num_segments=self.num_segments, + direction=self.direction, + interpolation_method=self.interpolation_method, + ) + with open(info_file, mode="w") as f: + json.dump(check_json(info), f, indent=4) + + np.save(folder / "spatial_bins_um.npy", self.spatial_bins_um) + + for segment_index in range(self.num_segments): + np.save(folder / f"displacement_seg{segment_index}.npy", self.displacement[segment_index]) + np.save(folder / f"temporal_bins_s_seg{segment_index}.npy", self.temporal_bins_s[segment_index]) + + @classmethod + def load(cls, folder): + folder = Path(folder) + + info_file = folder / f"spikeinterface_info.json" + err_msg = f"Motion.load(folder): the folder {folder} does not contain a Motion object." + if not info_file.exists(): + raise IOError(err_msg) + + with open(info_file, "r") as f: + info = json.load(f) + if "object" not in info or info["object"] != "Motion": + raise IOError(err_msg) + + direction = info["direction"] + interpolation_method = info["interpolation_method"] + spatial_bins_um = np.load(folder / "spatial_bins_um.npy") + displacement = [] + temporal_bins_s = [] + for segment_index in range(info["num_segments"]): + displacement.append(np.load(folder / f"displacement_seg{segment_index}.npy")) + temporal_bins_s.append(np.load(folder / f"temporal_bins_s_seg{segment_index}.npy")) + + return cls( + displacement, + temporal_bins_s, + spatial_bins_um, + direction=direction, + interpolation_method=interpolation_method, + ) + + def __eq__(self, other): + for segment_index in range(self.num_segments): + if not np.allclose(self.displacement[segment_index], other.displacement[segment_index]): + return False + if not np.allclose(self.temporal_bins_s[segment_index], other.temporal_bins_s[segment_index]): + return False + + if not np.allclose(self.spatial_bins_um, other.spatial_bins_um): + return False + + return True + + def copy(self): + return Motion( + [d.copy() for d in self.displacement], + [t.copy() for t in self.temporal_bins_s], + self.spatial_bins_um.copy(), + direction=self.direction, + interpolation_method=self.interpolation_method, + ) + + +def get_spatial_windows( + contact_depths, + spatial_bin_centers, + rigid=False, + win_shape="gaussian", + win_step_um=50.0, + win_scale_um=150.0, + win_margin_um=None, + zero_threshold=None, +): + """ + Generate spatial windows (taper) for non-rigid motion. + For rigid motion, this is equivalent to have one unique rectangular window that covers the entire probe. + The windowing can be gaussian or rectangular. + Windows are centered between the min/max of contact_depths. + We can ensure window to not be to close from border with win_margin_um. + + + Parameters + ---------- + contact_depths : np.ndarray + Position of electrodes of the corection direction shape=(num_channels, ) + spatial_bin_centers : np.array + The pre-computed spatial bin centers + rigid : bool, default False + If True, returns a single rectangular window + win_shape : str, default "gaussian" + Shape of the window + "gaussian" | "rect" | "triangle" + win_step_um : float + The steps at which windows are defined + win_scale_um : float, default 150. + Sigma of gaussian window if win_shape is gaussian + Width of the rectangle if win_shape is rect + win_margin_um : None | float, default None + The margin to extend (if positive) or shrink (if negative) the probe dimension to compute windows. + When None, then the margin is set to -win_scale_um./2 + zero_threshold: None | float + Lower value for thresholding to set zeros. + + Returns + ------- + windows : 2D arrays + The scaling for each window. Each element has num_spatial_bins values + shape: (num_window, spatial_bins) + window_centers: 1D np.array + The center of each window + + Notes + ----- + Note that kilosort2.5 uses overlaping rectangular windows. + Here by default we use gaussian window. + + """ + n = spatial_bin_centers.size + + if rigid: + # win_shape = 'rect' is forced + windows, window_centers = get_rigid_windows(spatial_bin_centers) + else: + if win_scale_um <= win_step_um / 5.0: + warnings.warn( + f"get_spatial_windows(): spatial windows are probably not overlapping because {win_scale_um=} and {win_step_um=}" + ) + + if win_margin_um is None: + # this ensure that first/last windows do not overflow outside the probe + win_margin_um = -win_scale_um / 2.0 + + min_ = np.min(contact_depths) - win_margin_um + max_ = np.max(contact_depths) + win_margin_um + num_windows = int((max_ - min_) // win_step_um) + + if num_windows < 1: + raise Exception( + f"get_spatial_windows(): {win_step_um=}/{win_scale_um=}/{win_margin_um=} are too large for the " + f"probe size (depth range={np.ptp(contact_depths)}). You can try to reduce them or use rigid motion." + ) + border = ((max_ - min_) % win_step_um) / 2 + window_centers = np.arange(num_windows + 1) * win_step_um + min_ + border + windows = [] + + for win_center in window_centers: + if win_shape == "gaussian": + win = np.exp(-((spatial_bin_centers - win_center) ** 2) / (2 * win_scale_um**2)) + elif win_shape == "rect": + win = np.abs(spatial_bin_centers - win_center) < (win_scale_um / 2.0) + win = win.astype("float64") + elif win_shape == "triangle": + center_dist = np.abs(spatial_bin_centers - win_center) + in_window = center_dist <= (win_scale_um / 2.0) + win = -center_dist + win[~in_window] = 0 + win[in_window] -= win[in_window].min() + win[in_window] /= win[in_window].max() + windows.append(win) + + windows = np.array(windows) + + if zero_threshold is not None: + windows[windows < zero_threshold] = 0 + windows /= windows.sum(axis=1, keepdims=True) + + return windows, window_centers + + +def get_rigid_windows(spatial_bin_centers): + """Generate a single rectangular window for rigid motion.""" + windows = np.ones((1, spatial_bin_centers.size), dtype="float64") + window_centers = np.array([(spatial_bin_centers[0] + spatial_bin_centers[-1]) / 2.0]) + return windows, window_centers + + +def get_window_domains(windows): + """Array of windows -> list of slices where window > 0.""" + slices = [] + for w in windows: + in_window = np.flatnonzero(w) + slices.append(slice(in_window[0], in_window[-1] + 1)) + return slices + + +def scipy_conv1d(input, weights, padding="valid"): + """SciPy translation of torch F.conv1d""" + from scipy.signal import correlate + + n, c_in, length = input.shape + c_out, in_by_groups, kernel_size = weights.shape + assert in_by_groups == c_in == 1 + + if padding == "same": + mode = "same" + length_out = length + elif padding == "valid": + mode = "valid" + length_out = length - 2 * (kernel_size // 2) + elif isinstance(padding, int): + mode = "valid" + input = np.pad(input, [*[(0, 0)] * (input.ndim - 1), (padding, padding)]) + length_out = length - (kernel_size - 1) + 2 * padding + else: + raise ValueError(f"Unknown 'padding' value of {padding}, 'padding' must be 'same', 'valid' or an integer") + + output = np.zeros((n, c_out, length_out), dtype=input.dtype) + for m in range(n): + for c in range(c_out): + output[m, c] = correlate(input[m, 0], weights[c, 0], mode=mode) + + return output + + +def get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um): + # contact along one axis + probe = recording.get_probe() + dim = ["x", "y", "z"].index(direction) + contact_depths = probe.contact_positions[:, dim] + + min_ = np.min(contact_depths) - hist_margin_um + max_ = np.max(contact_depths) + hist_margin_um + spatial_bins = np.arange(min_, max_ + bin_um, bin_um) + + return spatial_bins + + +def make_2d_motion_histogram( + recording, + peaks, + peak_locations, + weight_with_amplitude=False, + avg_in_bin=True, + direction="y", + bin_s=1.0, + bin_um=2.0, + hist_margin_um=50, + spatial_bin_edges=None, + depth_smooth_um=None, + time_smooth_s=None, +): + """ + Generate 2d motion histogram in depth and time. + + Parameters + ---------- + recording : BaseRecording + The input recording + peaks : np.array + The peaks array + peak_locations : np.array + Array with peak locations + weight_with_amplitude : bool, default: False + If True, motion histogram is weighted by amplitudes + avg_in_bin : bool, default True + If true, average the amplitudes in each bin. + This is done only if weight_with_amplitude=True. + direction : "x" | "y" | "z", default: "y" + The depth direction + bin_s : float, default: 1.0 + The temporal bin duration in s + bin_um : float, default: 2.0 + The spatial bin size in um. Ignored if spatial_bin_edges is given. + hist_margin_um : float, default: 50 + The margin to add to the minimum and maximum positions before spatial binning. + Ignored if spatial_bin_edges is given. + spatial_bin_edges : np.array, default: None + The pre-computed spatial bin edges + depth_smooth_um: None or float + Optional gaussian smoother on histogram on depth axis. + This is given as the sigma of the gaussian in micrometers. + time_smooth_s: None or float + Optional gaussian smoother on histogram on time axis. + This is given as the sigma of the gaussian in seconds. + + Returns + ------- + motion_histogram + 2d np.array with motion histogram (num_temporal_bins, num_spatial_bins) + temporal_bin_edges + 1d array with temporal bin edges + spatial_bin_edges + 1d array with spatial bin edges + """ + n_samples = recording.get_num_samples() + mint_s = recording.sample_index_to_time(0) + maxt_s = recording.sample_index_to_time(n_samples) + temporal_bin_edges = np.arange(mint_s, maxt_s + bin_s, bin_s) + if spatial_bin_edges is None: + spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) + else: + bin_um = spatial_bin_edges[1] - spatial_bin_edges[0] + + arr = np.zeros((peaks.size, 2), dtype="float64") + arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) + arr[:, 1] = peak_locations[direction] + + if weight_with_amplitude: + weights = np.abs(peaks["amplitude"]) + else: + weights = None + + motion_histogram, edges = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges), weights=weights) + + # average amplitude in each bin + if weight_with_amplitude and avg_in_bin: + bin_counts, _ = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges)) + bin_counts[bin_counts == 0] = 1 + motion_histogram = motion_histogram / bin_counts + + from scipy.ndimage import gaussian_filter1d + + if depth_smooth_um is not None: + motion_histogram = gaussian_filter1d(motion_histogram, depth_smooth_um / bin_um, axis=1, mode="constant") + + if time_smooth_s is not None: + motion_histogram = gaussian_filter1d(motion_histogram, time_smooth_s / bin_s, axis=0, mode="constant") + + return motion_histogram, temporal_bin_edges, spatial_bin_edges + + +def make_3d_motion_histograms( + recording, + peaks, + peak_locations, + direction="y", + bin_s=1.0, + bin_um=2.0, + hist_margin_um=50, + num_amp_bins=20, + log_transform=True, + spatial_bin_edges=None, +): + """ + Generate 3d motion histograms in depth, amplitude, and time. + This is used by the "iterative_template_registration" (Kilosort2.5) method. + + + Parameters + ---------- + recording : BaseRecording + The input recording + peaks : np.array + The peaks array + peak_locations : np.array + Array with peak locations + direction : "x" | "y" | "z", default: "y" + The depth direction + bin_s : float, default: 1.0 + The temporal bin duration in s. + bin_um : float, default: 2.0 + The spatial bin size in um. Ignored if spatial_bin_edges is given. + hist_margin_um : float, default: 50 + The margin to add to the minimum and maximum positions before spatial binning. + Ignored if spatial_bin_edges is given. + log_transform : bool, default: True + If True, histograms are log-transformed + spatial_bin_edges : np.array, default: None + The pre-computed spatial bin edges + + Returns + ------- + motion_histograms + 3d np.array with motion histogram (num_temporal_bins, num_spatial_bins, num_amp_bins) + temporal_bin_edges + 1d array with temporal bin edges + spatial_bin_edges + 1d array with spatial bin edges + """ + n_samples = recording.get_num_samples() + mint_s = recording.sample_index_to_time(0) + maxt_s = recording.sample_index_to_time(n_samples) + temporal_bin_edges = np.arange(mint_s, maxt_s + bin_s, bin_s) + if spatial_bin_edges is None: + spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) + + # pre-compute abs amplitude and ranges for scaling + amplitude_bin_edges = np.linspace(0, 1, num_amp_bins + 1) + abs_peaks = np.abs(peaks["amplitude"]) + max_peak_amp = np.max(abs_peaks) + min_peak_amp = np.min(abs_peaks) + # log amplitudes and scale between 0-1 + abs_peaks_log_norm = (np.log10(abs_peaks) - np.log10(min_peak_amp)) / ( + np.log10(max_peak_amp) - np.log10(min_peak_amp) + ) + + arr = np.zeros((peaks.size, 3), dtype="float64") + arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) + arr[:, 1] = peak_locations[direction] + arr[:, 2] = abs_peaks_log_norm + + motion_histograms, edges = np.histogramdd( + arr, + bins=( + temporal_bin_edges, + spatial_bin_edges, + amplitude_bin_edges, + ), + ) + + if log_transform: + motion_histograms = np.log2(1 + motion_histograms) + + return motion_histograms, temporal_bin_edges, spatial_bin_edges diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_drege.py b/src/spikeinterface/sortingcomponents/motion/tests/test_drege.py new file mode 100644 index 0000000000..8133c1fa6b --- /dev/null +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_drege.py @@ -0,0 +1,9 @@ +import pytest + + +def test_dredge_online_lfp(): + pass + + +if __name__ == "__main__": + pass diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py similarity index 90% rename from src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py rename to src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py index af62ba52ec..3c83a56b9d 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py @@ -3,7 +3,7 @@ import numpy as np import pytest from spikeinterface.core.node_pipeline import ExtractDenseWaveforms -from spikeinterface.sortingcomponents.motion_estimation import estimate_motion +from spikeinterface.sortingcomponents.motion import estimate_motion from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -18,12 +18,11 @@ plt.show() -@pytest.fixture(scope="module") -def setup_module(tmp_path_factory): - recording, sorting = make_dataset() - cache_folder = tmp_path_factory.mktemp("cache_folder") +def setup_dataset_and_peaks(cache_folder): + print(cache_folder, type(cache_folder)) cache_folder.mkdir(parents=True, exist_ok=True) + recording, sorting = make_dataset() # detect and localize extract_dense_waveforms = ExtractDenseWaveforms(recording, ms_before=0.1, ms_after=0.3, return_output=False) pipeline_nodes = [ @@ -49,9 +48,16 @@ def setup_module(tmp_path_factory): return recording, sorting, cache_folder -def test_estimate_motion(setup_module): +@pytest.fixture(scope="module", name="dataset") +def dataset_fixture(create_cache_folder): + cache_folder = create_cache_folder / "motion_estimation" + return setup_dataset_and_peaks(cache_folder) + + +def test_estimate_motion(dataset): # recording, sorting = make_dataset() - recording, sorting, cache_folder = setup_module + recording, sorting, cache_folder = dataset + peaks = np.load(cache_folder / "dataset_peaks.npy") peak_locations = np.load(cache_folder / "dataset_peak_locations.npy") @@ -146,14 +152,14 @@ def test_estimate_motion(setup_module): kwargs = dict( direction="y", - bin_duration_s=1.0, + bin_s=1.0, bin_um=10.0, margin_um=5, - output_extra_check=True, + extra_outputs=True, ) kwargs.update(cases_kwargs) - motion, extra_check = estimate_motion(recording, peaks, peak_locations, **kwargs) + motion, extra = estimate_motion(recording, peaks, peak_locations, **kwargs) motions[name] = motion if cases_kwargs["rigid"]: @@ -215,5 +221,9 @@ def test_estimate_motion(setup_module): if __name__ == "__main__": - setup_module() - test_estimate_motion() + import tempfile + + with tempfile.TemporaryDirectory() as tmpdirname: + cache_folder = Path(tmpdirname) + args = setup_dataset_and_peaks(cache_folder) + test_estimate_motion(args) diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py similarity index 97% rename from src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py rename to src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index cb26560272..e022f0cc6c 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -4,13 +4,13 @@ import pytest import spikeinterface.core as sc from spikeinterface import download_dataset -from spikeinterface.sortingcomponents.motion_interpolation import ( +from spikeinterface.sortingcomponents.motion.motion_interpolation import ( InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, interpolate_motion_on_traces, ) -from spikeinterface.sortingcomponents.motion_utils import Motion +from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_utils.py similarity index 97% rename from src/spikeinterface/sortingcomponents/tests/test_motion_utils.py rename to src/spikeinterface/sortingcomponents/motion/tests/test_motion_utils.py index 0b67be39c0..73c469c955 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_utils.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from spikeinterface.sortingcomponents.motion_utils import Motion +from spikeinterface.sortingcomponents.motion.motion_utils import Motion from spikeinterface.generation import make_one_displacement_vector if hasattr(pytest, "global_test_folder"): diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py deleted file mode 100644 index 3134d68681..0000000000 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ /dev/null @@ -1,1547 +0,0 @@ -from __future__ import annotations - -from tqdm.auto import tqdm, trange -import numpy as np - - -from .motion_utils import Motion -from .tools import make_multi_method_doc - -try: - import torch - import torch.nn.functional as F - - HAVE_TORCH = True -except ImportError: - HAVE_TORCH = False - - -def estimate_motion( - recording, - peaks, - peak_locations, - direction="y", - bin_duration_s=10.0, - bin_um=10.0, - margin_um=0.0, - rigid=False, - win_shape="gaussian", - win_step_um=50.0, - win_sigma_um=150.0, - post_clean=False, - speed_threshold=30, - sigma_smooth_s=None, - method="decentralized", - output_extra_check=False, - progress_bar=False, - upsample_to_histogram_bin=False, - verbose=False, - **method_kwargs, -): - """ - Estimate motion for given peaks and after their localization. - - Note that the way you detect peak locations (center of mass/monopolar triangulation) have an impact on the result. - - Parameters - ---------- - recording: BaseRecording - The recording extractor - peaks: numpy array - Peak vector (complex dtype) - peak_locations: numpy array - Complex dtype with "x", "y", "z" fields - - {method_doc} - - **histogram section** - - direction: "x" | "y" | "z", default: "y" - Dimension on which the motion is estimated. "y" is depth along the probe. - bin_duration_s: float, default: 10 - Bin duration in second - bin_um: float, default: 10 - Spatial bin size in micrometers - margin_um: float, default: 0 - Margin in um to exclude from histogram estimation and - non-rigid smoothing functions to avoid edge effects. - Positive margin extrapolate out of the probe the motion. - Negative margin crop the motion on the border - - **non-rigid section** - - rigid : bool, default: False - Compute rigid (one motion for the entire probe) or non rigid motion - Rigid computation is equivalent to non-rigid with only one window with rectangular shape. - win_shape: "gaussian" | "rect" | "triangle", default: "gaussian" - The shape of the windows for non rigid. - When rigid this is force to "rect" - win_step_um: float, default: 50 - Step deteween window - win_sigma_um: float, default: 150 - Sigma of the gaussian window - - **motion cleaning section** - - post_clean: bool, default: False - Apply some post cleaning to motion matrix or not - speed_threshold: float default: 30. - Detect to fast motion bump and remove then with interpolation - sigma_smooth_s: None or float - Optional smooting gaussian kernel when not None - - output_extra_check: bool, default: False - If True then return an extra dict that contains variables - to check intermediate steps (motion_histogram, non_rigid_windows, pairwise_displacement) - upsample_to_histogram_bin: bool or None, default: False - If True then upsample the returned motion array to the number of depth bins specified by bin_um. - When None: - * for non rigid case: then automatically True - * for rigid (non_rigid_kwargs=None): automatically False - This feature is in fact a bad idea and the interpolation should be done outside using better methods - progress_bar: bool, default: False - Display progress bar or not - verbose: bool, default: False - If True, output is verbose - - - Returns - ------- - motion: Motion object - The motion object. - extra_check: dict - Optional output if `output_extra_check=True` - This dict contain histogram, pairwise_displacement usefull for ploting. - """ - # TODO handle multi segment one day - assert recording.get_num_segments() == 1 - - if output_extra_check: - extra_check = {} - else: - extra_check = None - - # contact positions - probe = recording.get_probe() - dim = ["x", "y", "z"].index(direction) - contact_pos = probe.contact_positions[:, dim] - - # spatial bins - spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um) - - # get windows - non_rigid_windows, non_rigid_window_centers = get_windows( - rigid, bin_um, contact_pos, spatial_bin_edges, margin_um, win_step_um, win_sigma_um, win_shape - ) - - if output_extra_check: - extra_check["non_rigid_windows"] = non_rigid_windows - - # run method - method_class = estimate_motion_methods[method] - motion_array, temporal_bins = method_class.run( - recording, - peaks, - peak_locations, - direction, - bin_duration_s, - bin_um, - spatial_bin_edges, - non_rigid_windows, - verbose, - progress_bar, - extra_check, - **method_kwargs, - ) - - # replace nan by zeros - np.nan_to_num(motion_array, copy=False) - - if post_clean: - motion_array = clean_motion_vector( - motion_array, temporal_bins, bin_duration_s, speed_threshold=speed_threshold, sigma_smooth_s=sigma_smooth_s - ) - - if upsample_to_histogram_bin is None: - upsample_to_histogram_bin = not rigid - if upsample_to_histogram_bin: - extra_check["motion_array"] = motion_array - extra_check["non_rigid_window_centers"] = non_rigid_window_centers - non_rigid_windows = np.array(non_rigid_windows) - non_rigid_windows /= non_rigid_windows.sum(axis=0, keepdims=True) - non_rigid_window_centers = spatial_bin_edges[:-1] + bin_um / 2 - motion_array = motion_array @ non_rigid_windows - - # TODO handle multi segment - motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction) - - if output_extra_check: - return motion, extra_check - else: - return motion - - -class DecentralizedRegistration: - """ - Method developed by the Paninski's group from Columbia university: - Charlie Windolf, Julien Boussard, Erdem Varol, Hyun Dong Lee - - This method is also known as DREDGe, but this implemenation does not use LFP signals. - - Original reference: - DECENTRALIZED MOTION INFERENCE AND REGISTRATION OF NEUROPIXEL DATA - https://ieeexplore.ieee.org/document/9414145 - https://proceedings.neurips.cc/paper/2021/hash/b950ea26ca12daae142bd74dba4427c8-Abstract.html - - This code was improved during Spike Sorting NY Hackathon 2022 by Erdem Varol and Charlie Windolf. - An additional major improvement can be found in this paper: - https://www.biorxiv.org/content/biorxiv/early/2022/12/05/2022.12.04.519043.full.pdf - - - Here are some various implementations by the original team: - https://github.com/int-brain-lab/spikes_localization_registration/blob/main/registration_pipeline/image_based_motion_estimate.py#L211 - https://github.com/cwindolf/spike-psvae/tree/main/spike_psvae - https://github.com/evarol/DREDge - """ - - name = "decentralized" - params_doc = """ - histogram_depth_smooth_um: None or float - Optional gaussian smoother on histogram on depth axis. - This is given as the sigma of the gaussian in micrometers. - histogram_time_smooth_s: None or float - Optional gaussian smoother on histogram on time axis. - This is given as the sigma of the gaussian in seconds. - pairwise_displacement_method: "conv" or "phase_cross_correlation" - How to estimate the displacement in the pairwise matrix. - max_displacement_um: float - Maximum possible displacement in micrometers. - weight_scale: "linear" or "exp" - For parwaise displacement, how to to rescale the associated weight matrix. - error_sigma: float, default: 0.2 - In case weight_scale="exp" this controls the sigma of the exponential. - conv_engine: "numpy" or "torch" or None, default: None - In case of pairwise_displacement_method="conv", what library to use to compute - the underlying correlation - torch_device=None - In case of conv_engine="torch", you can control which device (cpu or gpu) - batch_size: int - Size of batch for the convolution. Increasing this will speed things up dramatically - on GPUs and sometimes on CPU as well. - corr_threshold: float - Minimum correlation between pair of time bins in order for these to be - considered when optimizing a global displacment vector to align with - the pairwise displacements. - time_horizon_s: None or float - When not None the parwise discplament matrix is computed in a small time horizon. - In short only pair of bins close in time. - So the pariwaise matrix is super sparse and have values only the diagonal. - convergence_method: "lsmr" | "lsqr_robust" | "gradient_descent", default: "lsqr_robust" - Which method to use to compute the global displacement vector from the pairwise matrix. - robust_regression_sigma: float - Use for convergence_method="lsqr_robust" for iterative selection of the regression. - temporal_prior : bool, default: True - Ensures continuity across time, unless there is evidence in the recording for jumps. - spatial_prior : bool, default: False - Ensures continuity across space. Not usually necessary except in recordings with - glitches across space. - force_spatial_median_continuity: bool, default: False - When spatial_prior=False we can optionally apply a median continuity across spatial windows. - reference_displacement : string, one of: "mean", "median", "time", "mode_search" - Strategy for picking what is considered displacement=0. - - "mean" : the mean displacement is subtracted - - "median" : the median displacement is subtracted - - "time" : the displacement at a given time (in seconds) is subtracted - - "mode_search" : an attempt is made to guess the mode. needs work. - lsqr_robust_n_iter: int - Number of iteration for convergence_method="lsqr_robust". - """ - - @classmethod - def run( - cls, - recording, - peaks, - peak_locations, - direction, - bin_duration_s, - bin_um, - spatial_bin_edges, - non_rigid_windows, - verbose, - progress_bar, - extra_check, - histogram_depth_smooth_um=None, - histogram_time_smooth_s=None, - pairwise_displacement_method="conv", - max_displacement_um=100.0, - weight_scale="linear", - error_sigma=0.2, - conv_engine=None, - torch_device=None, - batch_size=1, - corr_threshold=0.0, - time_horizon_s=None, - convergence_method="lsqr_robust", - soft_weights=False, - normalized_xcorr=True, - centered_xcorr=True, - temporal_prior=True, - spatial_prior=False, - force_spatial_median_continuity=False, - reference_displacement="median", - reference_displacement_time_s=0, - robust_regression_sigma=2, - lsqr_robust_n_iter=20, - weight_with_amplitude=False, - ): - # use torch if installed - if conv_engine is None: - conv_engine = "torch" if HAVE_TORCH else "numpy" - - # make 2D histogram raster - if verbose: - print("Computing motion histogram") - - motion_histogram, temporal_hist_bin_edges, spatial_hist_bin_edges = make_2d_motion_histogram( - recording, - peaks, - peak_locations, - direction=direction, - bin_duration_s=bin_duration_s, - spatial_bin_edges=spatial_bin_edges, - weight_with_amplitude=weight_with_amplitude, - ) - import scipy.signal - - if histogram_depth_smooth_um is not None: - bins = np.arange(motion_histogram.shape[1]) * bin_um - bins = bins - np.mean(bins) - smooth_kernel = np.exp(-(bins**2) / (2 * histogram_depth_smooth_um**2)) - smooth_kernel /= np.sum(smooth_kernel) - - motion_histogram = scipy.signal.fftconvolve(motion_histogram, smooth_kernel[None, :], mode="same", axes=1) - - if histogram_time_smooth_s is not None: - bins = np.arange(motion_histogram.shape[0]) * bin_duration_s - bins = bins - np.mean(bins) - smooth_kernel = np.exp(-(bins**2) / (2 * histogram_time_smooth_s**2)) - smooth_kernel /= np.sum(smooth_kernel) - motion_histogram = scipy.signal.fftconvolve(motion_histogram, smooth_kernel[:, None], mode="same", axes=0) - - if extra_check is not None: - extra_check["motion_histogram"] = motion_histogram - extra_check["pairwise_displacement_list"] = [] - extra_check["temporal_hist_bin_edges"] = temporal_hist_bin_edges - extra_check["spatial_hist_bin_edges"] = spatial_hist_bin_edges - - # temporal bins are bin center - temporal_bins = 0.5 * (temporal_hist_bin_edges[1:] + temporal_hist_bin_edges[:-1]) - - motion = np.zeros((temporal_bins.size, len(non_rigid_windows)), dtype=np.float64) - windows_iter = non_rigid_windows - if progress_bar: - windows_iter = tqdm(windows_iter, desc="windows") - if spatial_prior: - all_pairwise_displacements = np.empty( - (len(non_rigid_windows), temporal_bins.size, temporal_bins.size), dtype=np.float64 - ) - all_pairwise_displacement_weights = np.empty( - (len(non_rigid_windows), temporal_bins.size, temporal_bins.size), dtype=np.float64 - ) - for i, win in enumerate(windows_iter): - window_slice = np.flatnonzero(win > 1e-5) - window_slice = slice(window_slice[0], window_slice[-1]) - if verbose: - print(f"Computing pairwise displacement: {i + 1} / {len(non_rigid_windows)}") - - pairwise_displacement, pairwise_displacement_weight = compute_pairwise_displacement( - motion_histogram[:, window_slice], - bin_um, - window=win[window_slice], - method=pairwise_displacement_method, - weight_scale=weight_scale, - error_sigma=error_sigma, - conv_engine=conv_engine, - torch_device=torch_device, - batch_size=batch_size, - max_displacement_um=max_displacement_um, - normalized_xcorr=normalized_xcorr, - centered_xcorr=centered_xcorr, - corr_threshold=corr_threshold, - time_horizon_s=time_horizon_s, - bin_duration_s=bin_duration_s, - progress_bar=False, - ) - - if spatial_prior: - all_pairwise_displacements[i] = pairwise_displacement - all_pairwise_displacement_weights[i] = pairwise_displacement_weight - - if extra_check is not None: - extra_check["pairwise_displacement_list"].append(pairwise_displacement) - - if verbose: - print(f"Computing global displacement: {i + 1} / {len(non_rigid_windows)}") - - # TODO: if spatial_prior, do this after the loop - if not spatial_prior: - motion[:, i] = compute_global_displacement( - pairwise_displacement, - pairwise_displacement_weight=pairwise_displacement_weight, - convergence_method=convergence_method, - robust_regression_sigma=robust_regression_sigma, - lsqr_robust_n_iter=lsqr_robust_n_iter, - temporal_prior=temporal_prior, - spatial_prior=spatial_prior, - soft_weights=soft_weights, - progress_bar=False, - ) - - if spatial_prior: - motion = compute_global_displacement( - all_pairwise_displacements, - pairwise_displacement_weight=all_pairwise_displacement_weights, - convergence_method=convergence_method, - robust_regression_sigma=robust_regression_sigma, - lsqr_robust_n_iter=lsqr_robust_n_iter, - temporal_prior=temporal_prior, - spatial_prior=spatial_prior, - soft_weights=soft_weights, - progress_bar=False, - ) - elif len(non_rigid_windows) > 1: - # if spatial_prior is False, we still want keep the spatial bins - # correctly offset from each other - if force_spatial_median_continuity: - for i in range(len(non_rigid_windows) - 1): - motion[:, i + 1] -= np.median(motion[:, i + 1] - motion[:, i]) - - # try to avoid constant offset - # let the user choose how to do this. here are some ideas. - # (one can also -= their own number on the result of this function.) - if reference_displacement == "mean": - motion -= motion.mean() - elif reference_displacement == "median": - motion -= np.median(motion) - elif reference_displacement == "time": - # reference the motion to 0 at a specific time, independently in each window - reference_displacement_bin = np.digitize(reference_displacement_time_s, temporal_hist_bin_edges) - 1 - motion -= motion[reference_displacement_bin, :] - elif reference_displacement == "mode_search": - # just a sketch of an idea - # things might want to change, should have a configurable bin size, - # should use a call to histogram instead of the loop, ... - step_size = 0.1 - round_mode = np.round # floor? - best_ref = np.median(motion) - max_zeros = np.sum(round_mode(motion - best_ref) == 0) - for ref in np.arange(np.floor(motion.min()), np.ceil(motion.max()), step_size): - n_zeros = np.sum(round_mode(motion - ref) == 0) - if n_zeros > max_zeros: - max_zeros = n_zeros - best_ref = ref - motion -= best_ref - - return motion, temporal_bins - - -class IterativeTemplateRegistration: - """ - Alignment function implemented by Kilosort2.5 and ported from pykilosort: - https://github.com/int-brain-lab/pykilosort/blob/ibl_prod/pykilosort/datashift2.py#L166 - - The main difference with respect to the original implementation are: - * scipy is used for gaussian smoothing - * windowing is implemented as gaussian tapering (instead of rectangular blocks) - * the 3d histogram is constructed in less cryptic way - * peak_locations are computed outside and so can either center fo mass or monopolar trianglation - contrary to kilosort2.5 use exclusively center of mass - - See https://www.science.org/doi/abs/10.1126/science.abf4588?cookieSet=1 - - Ported by Alessio Buccino into SpikeInterface - """ - - name = "iterative_template" - params_doc = """ - num_amp_bins: int, default: 20 - number ob bins in the histogram on the log amplitues dimension - num_shifts_global: int, default: 15 - Number of spatial bin shifts to consider for global alignment - num_iterations: int, default: 10 - Number of iterations for global alignment procedure - num_shifts_block: int, default: 5 - Number of spatial bin shifts to consider for non-rigid alignment - smoothing_sigma: float, default: 0.5 - Sigma of gaussian for covariance matrices smoothing - kriging_sigma: float, - sigma parameter for kriging_kernel function - kriging_p: foat - p parameter for kriging_kernel function - kriging_d: float - d parameter for kriging_kernel function - """ - - @classmethod - def run( - cls, - recording, - peaks, - peak_locations, - direction, - bin_duration_s, - bin_um, - spatial_bin_edges, - non_rigid_windows, - verbose, - progress_bar, - extra_check, - num_amp_bins=20, - num_shifts_global=15, - num_iterations=10, - num_shifts_block=5, - smoothing_sigma=0.5, - kriging_sigma=1, - kriging_p=2, - kriging_d=2, - ): - # make a 3D histogram - motion_histograms, temporal_hist_bin_edges, spatial_hist_bin_edges = make_3d_motion_histograms( - recording, - peaks, - peak_locations, - direction=direction, - num_amp_bins=num_amp_bins, - bin_duration_s=bin_duration_s, - spatial_bin_edges=spatial_bin_edges, - ) - # temporal bins are bin center - temporal_bins = temporal_hist_bin_edges[:-1] + bin_duration_s // 2.0 - - # do alignment - shift_indices, target_histogram, shift_covs_block = iterative_template_registration( - motion_histograms, - non_rigid_windows=non_rigid_windows, - num_shifts_global=num_shifts_global, - num_iterations=num_iterations, - num_shifts_block=num_shifts_block, - smoothing_sigma=smoothing_sigma, - kriging_sigma=kriging_sigma, - kriging_p=kriging_p, - kriging_d=kriging_d, - ) - - # convert to um - motion = -(shift_indices * bin_um) - - if extra_check: - extra_check["motion_histograms"] = motion_histograms - extra_check["target_histogram"] = target_histogram - extra_check["shift_covs_block"] = shift_covs_block - extra_check["temporal_hist_bin_edges"] = temporal_hist_bin_edges - extra_check["spatial_hist_bin_edges"] = spatial_hist_bin_edges - - return motion, temporal_bins - - -_methods_list = [DecentralizedRegistration, IterativeTemplateRegistration] -estimate_motion_methods = {m.name: m for m in _methods_list} -method_doc = make_multi_method_doc(_methods_list) -estimate_motion.__doc__ = estimate_motion.__doc__.format(method_doc=method_doc) - - -def get_spatial_bin_edges(recording, direction, margin_um, bin_um): - # contact along one axis - probe = recording.get_probe() - dim = ["x", "y", "z"].index(direction) - contact_pos = probe.contact_positions[:, dim] - - min_ = np.min(contact_pos) - margin_um - max_ = np.max(contact_pos) + margin_um - spatial_bins = np.arange(min_, max_ + bin_um, bin_um) - - return spatial_bins - - -def get_windows(rigid, bin_um, contact_pos, spatial_bin_edges, margin_um, win_step_um, win_sigma_um, win_shape): - """ - Generate spatial windows (taper) for non-rigid motion. - For rigid motion, this is equivalent to have one unique rectangular window that covers the entire probe. - The windowing can be gaussian or rectangular. - - Parameters - ---------- - rigid : bool - If True, returns a single rectangular window - bin_um : float - Spatial bin size in um - contact_pos : np.ndarray - Position of electrodes (num_channels, 2) - spatial_bin_edges : np.array - The pre-computed spatial bin edges - margin_um : float - The margin to extend (if positive) or shrink (if negative) the probe dimension to compute windows.= - win_step_um : float - The steps at which windows are defined - win_sigma_um : float - Sigma of gaussian window (if win_shape is gaussian) - win_shape : float - "gaussian" | "rect" - - Returns - ------- - non_rigid_windows : list of 1D arrays - The scaling for each window. Each element has num_spatial_bins values - non_rigid_window_centers: 1D np.array - The center of each window - - Notes - ----- - Note that kilosort2.5 uses overlaping rectangular windows. - Here by default we use gaussian window. - - """ - bin_centers = spatial_bin_edges[:-1] + bin_um / 2.0 - n = bin_centers.size - - if rigid: - # win_shape = 'rect' is forced - non_rigid_windows = [np.ones(n, dtype="float64")] - middle = (spatial_bin_edges[0] + spatial_bin_edges[-1]) / 2.0 - non_rigid_window_centers = np.array([middle]) - else: - assert win_sigma_um >= win_step_um, f"win_sigma_um too low {win_sigma_um} compared to win_step_um {win_step_um}" - - min_ = np.min(contact_pos) - margin_um - max_ = np.max(contact_pos) + margin_um - num_non_rigid_windows = int((max_ - min_) // win_step_um) - border = ((max_ - min_) % win_step_um) / 2 - non_rigid_window_centers = np.arange(num_non_rigid_windows + 1) * win_step_um + min_ + border - non_rigid_windows = [] - - for win_center in non_rigid_window_centers: - if win_shape == "gaussian": - win = np.exp(-((bin_centers - win_center) ** 2) / (2 * win_sigma_um**2)) - elif win_shape == "rect": - win = np.abs(bin_centers - win_center) < (win_sigma_um / 2.0) - win = win.astype("float64") - elif win_shape == "triangle": - center_dist = np.abs(bin_centers - win_center) - in_window = center_dist <= (win_sigma_um / 2.0) - win = -center_dist - win[~in_window] = 0 - win[in_window] -= win[in_window].min() - win[in_window] /= win[in_window].max() - - non_rigid_windows.append(win) - - return non_rigid_windows, non_rigid_window_centers - - -def make_2d_motion_histogram( - recording, - peaks, - peak_locations, - weight_with_amplitude=False, - direction="y", - bin_duration_s=1.0, - bin_um=2.0, - margin_um=50, - spatial_bin_edges=None, -): - """ - Generate 2d motion histogram in depth and time. - - Parameters - ---------- - recording : BaseRecording - The input recording - peaks : np.array - The peaks array - peak_locations : np.array - Array with peak locations - weight_with_amplitude : bool, default: False - If True, motion histogram is weighted by amplitudes - direction : "x" | "y" | "z", default: "y" - The depth direction - bin_duration_s : float, default: 1.0 - The temporal bin duration in s - bin_um : float, default: 2.0 - The spatial bin size in um. Ignored if spatial_bin_edges is given. - margin_um : float, default: 50 - The margin to add to the minimum and maximum positions before spatial binning. - Ignored if spatial_bin_edges is given. - spatial_bin_edges : np.array, default: None - The pre-computed spatial bin edges - - Returns - ------- - motion_histogram - 2d np.array with motion histogram (num_temporal_bins, num_spatial_bins) - temporal_bin_edges - 1d array with temporal bin edges - spatial_bin_edges - 1d array with spatial bin edges - """ - n_samples = recording.get_num_samples() - mint_s = recording.sample_index_to_time(0) - maxt_s = recording.sample_index_to_time(n_samples) - temporal_bin_edges = np.arange(mint_s, maxt_s + bin_duration_s, bin_duration_s) - if spatial_bin_edges is None: - spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um) - - arr = np.zeros((peaks.size, 2), dtype="float64") - arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) - arr[:, 1] = peak_locations[direction] - - if weight_with_amplitude: - weights = np.abs(peaks["amplitude"]) - else: - weights = None - - motion_histogram, edges = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges), weights=weights) - - # average amplitude in each bin - if weight_with_amplitude: - bin_counts, _ = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges)) - bin_counts[bin_counts == 0] = 1 - motion_histogram = motion_histogram / bin_counts - - return motion_histogram, temporal_bin_edges, spatial_bin_edges - - -def make_3d_motion_histograms( - recording, - peaks, - peak_locations, - direction="y", - bin_duration_s=1.0, - bin_um=2.0, - margin_um=50, - num_amp_bins=20, - log_transform=True, - spatial_bin_edges=None, -): - """ - Generate 3d motion histograms in depth, amplitude, and time. - This is used by the "iterative_template_registration" (Kilosort2.5) method. - - - Parameters - ---------- - recording : BaseRecording - The input recording - peaks : np.array - The peaks array - peak_locations : np.array - Array with peak locations - direction : "x" | "y" | "z", default: "y" - The depth direction - bin_duration_s : float, default: 1.0 - The temporal bin duration in s. - bin_um : float, default: 2.0 - The spatial bin size in um. Ignored if spatial_bin_edges is given. - margin_um : float, default: 50 - The margin to add to the minimum and maximum positions before spatial binning. - Ignored if spatial_bin_edges is given. - log_transform : bool, default: True - If True, histograms are log-transformed - spatial_bin_edges : np.array, default: None - The pre-computed spatial bin edges - - Returns - ------- - motion_histograms - 3d np.array with motion histogram (num_temporal_bins, num_spatial_bins, num_amp_bins) - temporal_bin_edges - 1d array with temporal bin edges - spatial_bin_edges - 1d array with spatial bin edges - """ - n_samples = recording.get_num_samples() - mint_s = recording.sample_index_to_time(0) - maxt_s = recording.sample_index_to_time(n_samples) - temporal_bin_edges = np.arange(mint_s, maxt_s + bin_duration_s, bin_duration_s) - if spatial_bin_edges is None: - spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um) - - # pre-compute abs amplitude and ranges for scaling - amplitude_bin_edges = np.linspace(0, 1, num_amp_bins + 1) - abs_peaks = np.abs(peaks["amplitude"]) - max_peak_amp = np.max(abs_peaks) - min_peak_amp = np.min(abs_peaks) - # log amplitudes and scale between 0-1 - abs_peaks_log_norm = (np.log10(abs_peaks) - np.log10(min_peak_amp)) / ( - np.log10(max_peak_amp) - np.log10(min_peak_amp) - ) - - arr = np.zeros((peaks.size, 3), dtype="float64") - arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) - arr[:, 1] = peak_locations[direction] - arr[:, 2] = abs_peaks_log_norm - - motion_histograms, edges = np.histogramdd( - arr, - bins=( - temporal_bin_edges, - spatial_bin_edges, - amplitude_bin_edges, - ), - ) - - if log_transform: - motion_histograms = np.log2(1 + motion_histograms) - - return motion_histograms, temporal_bin_edges, spatial_bin_edges - - -def compute_pairwise_displacement( - motion_hist, - bin_um, - method="conv", - weight_scale="linear", - error_sigma=0.2, - conv_engine="numpy", - torch_device=None, - batch_size=1, - max_displacement_um=1500, - corr_threshold=0, - time_horizon_s=None, - normalized_xcorr=True, - centered_xcorr=True, - bin_duration_s=None, - progress_bar=False, - window=None, -): - """ - Compute pairwise displacement - """ - from scipy import linalg - - assert conv_engine in ("torch", "numpy"), f"'conv_engine' must be 'torch' or 'numpy'" - size = motion_hist.shape[0] - pairwise_displacement = np.zeros((size, size), dtype="float32") - - if time_horizon_s is not None: - band_width = int(np.ceil(time_horizon_s / bin_duration_s)) - if band_width >= size: - time_horizon_s = None - - if conv_engine == "torch": - if torch_device is None: - torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - if method == "conv": - if max_displacement_um is None: - n = motion_hist.shape[1] // 2 - else: - n = min( - motion_hist.shape[1] // 2, - int(np.ceil(max_displacement_um // bin_um)), - ) - possible_displacement = np.arange(-n, n + 1) * bin_um - - xrange = trange if progress_bar else range - - motion_hist_engine = motion_hist - window_engine = window - if conv_engine == "torch": - motion_hist_engine = torch.as_tensor(motion_hist, dtype=torch.float32, device=torch_device) - window_engine = torch.as_tensor(window, dtype=torch.float32, device=torch_device) - - pairwise_displacement = np.empty((size, size), dtype=np.float32) - correlation = np.empty((size, size), dtype=motion_hist.dtype) - - for i in xrange(0, size, batch_size): - corr = normxcorr1d( - motion_hist_engine, - motion_hist_engine[i : i + batch_size], - weights=window_engine, - padding=possible_displacement.size // 2, - conv_engine=conv_engine, - normalized=normalized_xcorr, - centered=centered_xcorr, - ) - if conv_engine == "torch": - max_corr, best_disp_inds = torch.max(corr, dim=2) - best_disp = possible_displacement[best_disp_inds.cpu()] - pairwise_displacement[i : i + batch_size] = best_disp - correlation[i : i + batch_size] = max_corr.cpu() - elif conv_engine == "numpy": - best_disp_inds = np.argmax(corr, axis=2) - max_corr = np.take_along_axis(corr, best_disp_inds[..., None], 2).squeeze() - best_disp = possible_displacement[best_disp_inds] - pairwise_displacement[i : i + batch_size] = best_disp - correlation[i : i + batch_size] = max_corr - - if corr_threshold is not None and corr_threshold > 0: - which = correlation > corr_threshold - correlation *= which - - elif method == "phase_cross_correlation": - # this 'phase_cross_correlation' is an old idea from Julien/Charlie/Erden that is kept for testing - # but this is not very releveant - try: - import skimage.registration - except ImportError: - raise ImportError("To use the 'phase_cross_correlation' method install scikit-image") - - errors = np.zeros((size, size), dtype="float32") - loop = range(size) - if progress_bar: - loop = tqdm(loop) - for i in loop: - for j in range(size): - shift, error, diffphase = skimage.registration.phase_cross_correlation( - motion_hist[i, :], motion_hist[j, :] - ) - pairwise_displacement[i, j] = shift * bin_um - errors[i, j] = error - correlation = 1 - errors - - else: - raise ValueError( - f"method {method} does not exist for compute_pairwise_displacement. Current possible methods are" - f" 'conv' or 'phase_cross_correlation'" - ) - - if weight_scale == "linear": - # between 0 and 1 - pairwise_displacement_weight = correlation - elif weight_scale == "exp": - pairwise_displacement_weight = np.exp((correlation - 1) / error_sigma) - - # handle the time horizon by multiplying the weights by a - # matrix with the time horizon on its diagonal bands. - if method == "conv" and time_horizon_s is not None and time_horizon_s > 0: - horizon_matrix = linalg.toeplitz( - np.r_[np.ones(band_width, dtype=bool), np.zeros(size - band_width, dtype=bool)] - ) - pairwise_displacement_weight *= horizon_matrix - - return pairwise_displacement, pairwise_displacement_weight - - -_possible_convergence_method = ("lsmr", "gradient_descent", "lsqr_robust") - - -def compute_global_displacement( - pairwise_displacement, - pairwise_displacement_weight=None, - sparse_mask=None, - temporal_prior=True, - spatial_prior=True, - soft_weights=False, - convergence_method="lsmr", - robust_regression_sigma=2, - lsqr_robust_n_iter=20, - progress_bar=False, -): - """ - Compute global displacement - - Arguments - --------- - pairwise_displacement : time x time array - pairwise_displacement_weight : time x time array - sparse_mask : time x time array - convergence_method : str - One of "gradient" - - """ - import scipy - from scipy.optimize import minimize - from scipy.sparse import csr_matrix - from scipy.sparse.linalg import lsqr - from scipy.stats import zscore - - if convergence_method == "gradient_descent": - size = pairwise_displacement.shape[0] - - D = pairwise_displacement - if pairwise_displacement_weight is not None or sparse_mask is not None: - # weighted problem - if pairwise_displacement_weight is None: - pairwise_displacement_weight = np.ones_like(D) - if sparse_mask is None: - sparse_mask = np.ones_like(D) - W = pairwise_displacement_weight * sparse_mask - - I, J = np.nonzero(W > 0) - Wij = W[I, J] - Dij = D[I, J] - W = csr_matrix((Wij, (I, J)), shape=W.shape) - WD = csr_matrix((Wij * Dij, (I, J)), shape=W.shape) - fixed_terms = (W @ WD).diagonal() - (WD @ W).diagonal() - diag_WW = (W @ W).diagonal() - Wsq = W.power(2) - - def obj(p): - return 0.5 * np.square(Wij * (Dij - (p[I] - p[J]))).sum() - - def jac(p): - return fixed_terms - 2 * (Wsq @ p) + 2 * p * diag_WW - - else: - # unweighted problem, it's faster when we have no weights - fixed_terms = -D.sum(axis=1) + D.sum(axis=0) - - def obj(p): - v = np.square((D - (p[:, None] - p[None, :]))).sum() - return 0.5 * v - - def jac(p): - return fixed_terms + 2 * (size * p - p.sum()) - - res = minimize(fun=obj, jac=jac, x0=D.mean(axis=1), method="L-BFGS-B") - if not res.success: - print("Global displacement gradient descent had an error") - displacement = res.x - - elif convergence_method == "lsqr_robust": - - if sparse_mask is not None: - I, J = np.nonzero(sparse_mask > 0) - elif pairwise_displacement_weight is not None: - I, J = pairwise_displacement_weight.nonzero() - else: - I, J = np.nonzero(np.ones_like(pairwise_displacement, dtype=bool)) - - nnz_ones = np.ones(I.shape[0], dtype=pairwise_displacement.dtype) - - if pairwise_displacement_weight is not None: - if isinstance(pairwise_displacement_weight, scipy.sparse.csr_matrix): - W = np.array(pairwise_displacement_weight[I, J]).T - else: - W = pairwise_displacement_weight[I, J][:, None] - else: - W = nnz_ones[:, None] - if isinstance(pairwise_displacement, scipy.sparse.csr_matrix): - V = np.array(pairwise_displacement[I, J])[0] - else: - V = pairwise_displacement[I, J] - M = csr_matrix((nnz_ones, (range(I.shape[0]), I)), shape=(I.shape[0], pairwise_displacement.shape[0])) - N = csr_matrix((nnz_ones, (range(I.shape[0]), J)), shape=(I.shape[0], pairwise_displacement.shape[0])) - A = M - N - idx = np.ones(A.shape[0], dtype=bool) - - # TODO: this is already soft_weights - xrange = trange if progress_bar else range - for i in xrange(lsqr_robust_n_iter): - p = lsqr(A[idx].multiply(W[idx]), V[idx] * W[idx][:, 0])[0] - idx = np.nonzero(np.abs(zscore(A @ p - V)) <= robust_regression_sigma) - displacement = p - - elif convergence_method == "lsmr": - import gc - from scipy import sparse - - D = pairwise_displacement - - # weighted problem - if pairwise_displacement_weight is None: - pairwise_displacement_weight = np.ones_like(D) - if sparse_mask is None: - sparse_mask = np.ones_like(D) - W = pairwise_displacement_weight * sparse_mask - if isinstance(W, scipy.sparse.csr_matrix): - W = W.astype(np.float32).toarray() - D = D.astype(np.float32).toarray() - - assert D.shape == W.shape - - # first dimension is the windows dim, which could be empty in rigid case - # we expand dims so that below we can consider only the nonrigid case - if D.ndim == 2: - W = W[None] - D = D[None] - assert D.ndim == W.ndim == 3 - B, T, T_ = D.shape - assert T == T_ - - # sparsify the problem - # we will make a list of temporal problems and then - # stack over the windows axis to finish. - # each matrix in coefficients will be (sparse_dim, T) - coefficients = [] - # each vector in targets will be (T,) - targets = [] - # we want to solve for a vector of shape BT, which we will reshape - # into a (B, T) matrix. - # after the loop below, we will stack a coefts matrix (sparse_dim, B, T) - # and a target vector of shape (B, T), both to be vectorized on last two axes, - # so that the target p is indexed by i = bT + t (block/window major). - - # calculate coefficients matrices and target vector - # this list stores boolean masks corresponding to whether or not each - # term comes from the prior or the likelihood. we can trim the likelihood terms, - # but not the prior terms, in the trimmed least squares (robust iters) iterations below. - cannot_trim = [] - for Wb, Db in zip(W, D): - # indices of active temporal pairs in this window - I, J = np.nonzero(Wb > 0) - n_sampled = I.size - - # construct Kroneckers and sparse objective in this window - pair_weights = np.ones(n_sampled) - if soft_weights: - pair_weights = Wb[I, J] - Mb = sparse.csr_matrix((pair_weights, (range(n_sampled), I)), shape=(n_sampled, T)) - Nb = sparse.csr_matrix((pair_weights, (range(n_sampled), J)), shape=(n_sampled, T)) - block_sparse_kron = Mb - Nb - block_disp_pairs = pair_weights * Db[I, J] - cannot_trim_block = np.ones_like(block_disp_pairs, dtype=bool) - - # add the temporal smoothness prior in this window - if temporal_prior: - temporal_diff_operator = sparse.diags( - ( - np.full(T - 1, -1, dtype=block_sparse_kron.dtype), - np.full(T - 1, 1, dtype=block_sparse_kron.dtype), - ), - offsets=(0, 1), - shape=(T - 1, T), - ) - block_sparse_kron = sparse.vstack( - (block_sparse_kron, temporal_diff_operator), - format="csr", - ) - block_disp_pairs = np.concatenate( - (block_disp_pairs, np.zeros(T - 1)), - ) - cannot_trim_block = np.concatenate( - (cannot_trim_block, np.zeros(T - 1, dtype=bool)), - ) - - coefficients.append(block_sparse_kron) - targets.append(block_disp_pairs) - cannot_trim.append(cannot_trim_block) - coefficients = sparse.block_diag(coefficients) - targets = np.concatenate(targets, axis=0) - cannot_trim = np.concatenate(cannot_trim, axis=0) - - # spatial smoothness prior: penalize difference of each block's - # displacement with the next. - # only if B > 1, and not in the last window. - # this is a (BT, BT) sparse matrix D such that: - # entry at (i, j) is: - # { 1 if i = j, i.e., i = j = bT + t for b = 0,...,B-2 - # { -1 if i = bT + t and j = (b+1)T + t for b = 0,...,B-2 - # { 0 otherwise. - # put more simply, the first (B-1)T diagonal entries are 1, - # and entries (i, j) such that i = j - T are -1. - if B > 1 and spatial_prior: - spatial_diff_operator = sparse.diags( - ( - np.ones((B - 1) * T, dtype=block_sparse_kron.dtype), - np.full((B - 1) * T, -1, dtype=block_sparse_kron.dtype), - ), - offsets=(0, T), - shape=((B - 1) * T, B * T), - ) - coefficients = sparse.vstack((coefficients, spatial_diff_operator)) - targets = np.concatenate((targets, np.zeros((B - 1) * T, dtype=targets.dtype))) - cannot_trim = np.concatenate((cannot_trim, np.zeros((B - 1) * T, dtype=bool))) - coefficients = coefficients.tocsr() - - # initialize at the column mean of pairwise displacements (in each window) - p0 = D.mean(axis=2).reshape(B * T) - - # use LSMR to solve the whole problem || targets - coefficients @ motion ||^2 - iters = range(max(1, lsqr_robust_n_iter)) - if progress_bar and lsqr_robust_n_iter > 1: - iters = tqdm(iters, desc="robust lsqr") - for it in iters: - # trim active set -- start with no trimming - idx = slice(None) - if it: - idx = np.flatnonzero( - cannot_trim | (np.abs(zscore(coefficients @ displacement - targets)) <= robust_regression_sigma) - ) - - # solve trimmed ols problem - displacement, *_ = sparse.linalg.lsmr(coefficients[idx], targets[idx], x0=p0) - - # warm start next iteration - p0 = displacement - # Cleanup lsmr memory (see https://stackoverflow.com/questions/56147713/memory-leak-in-scipy) - # TODO: check if this gets fixed in scipy - gc.collect() - - displacement = displacement.reshape(B, T).T - else: - raise ValueError( - f"Method {convergence_method} doesn't exist for compute_global_displacement" - f" possible values for 'convergence_method' are {_possible_convergence_method}" - ) - - return np.squeeze(displacement) - - -def iterative_template_registration( - spikecounts_hist_images, - non_rigid_windows=None, - num_shifts_global=15, - num_iterations=10, - num_shifts_block=5, - smoothing_sigma=0.5, - kriging_sigma=1, - kriging_p=2, - kriging_d=2, -): - """ - - Parameters - ---------- - - spikecounts_hist_images : np.ndarray - Spike count histogram images (num_temporal_bins, num_spatial_bins, num_amps_bins) - non_rigid_windows : list, default: None - If num_non_rigid_windows > 1, this argument is required and it is a list of - windows to taper spatial bins in different blocks - num_shifts_global : int, default: 15 - Number of spatial bin shifts to consider for global alignment - num_iterations : int, default: 10 - Number of iterations for global alignment procedure - num_shifts_block : int, default: 5 - Number of spatial bin shifts to consider for non-rigid alignment - smoothing_sigma : float, default: 0.5 - Sigma of gaussian for covariance matrices smoothing - kriging_sigma : float, default: 1 - sigma parameter for kriging_kernel function - kriging_p : float, default: 2 - p parameter for kriging_kernel function - kriging_d : float, default: 2 - d parameter for kriging_kernel function - - Returns - ------- - optimal_shift_indices - Optimal shifts for each temporal and spatial bin (num_temporal_bins, num_non_rigid_windows) - target_spikecount_hist - Target histogram used for alignment (num_spatial_bins, num_amps_bins) - """ - from scipy.ndimage import gaussian_filter, gaussian_filter1d - - # F is y bins by amp bins by batches - # ysamp are the coordinates of the y bins in um - spikecounts_hist_images = spikecounts_hist_images.swapaxes(0, 1).swapaxes(1, 2) - num_temporal_bins = spikecounts_hist_images.shape[2] - - # look up and down this many y bins to find best alignment - shift_covs = np.zeros((2 * num_shifts_global + 1, num_temporal_bins)) - shifts = np.arange(-num_shifts_global, num_shifts_global + 1) - - # mean subtraction to compute covariance - F = spikecounts_hist_images - Fg = F - np.mean(F, axis=0) - - # initialize the target "frame" for alignment with a single sample - # here we removed min(299, ...) - F0 = Fg[:, :, np.floor(num_temporal_bins / 2).astype("int") - 1] - F0 = F0[:, :, np.newaxis] - - # first we do rigid registration by integer shifts - # everything is iteratively aligned until most of the shifts become 0. - best_shifts = np.zeros((num_iterations, num_temporal_bins)) - for iteration in range(num_iterations): - for t, shift in enumerate(shifts): - # for each NEW potential shift, estimate covariance - Fs = np.roll(Fg, shift, axis=0) - shift_covs[t, :] = np.mean(Fs * F0, axis=(0, 1)) - if iteration + 1 < num_iterations: - # estimate the best shifts - imax = np.argmax(shift_covs, axis=0) - # align the data by these integer shifts - for t, shift in enumerate(shifts): - ibest = imax == t - Fg[:, :, ibest] = np.roll(Fg[:, :, ibest], shift, axis=0) - best_shifts[iteration, ibest] = shift - # new target frame based on our current best alignment - F0 = np.mean(Fg, axis=2)[:, :, np.newaxis] - target_spikecount_hist = F0[:, :, 0] - - # now we figure out how to split the probe into nblocks pieces - # if len(non_rigid_windows) = 1, then we're doing rigid registration - num_non_rigid_windows = len(non_rigid_windows) - - # for each small block, we only look up and down this many samples to find - # nonrigid shift - shifts_block = np.arange(-num_shifts_block, num_shifts_block + 1) - num_shifts = len(shifts_block) - shift_covs_block = np.zeros((2 * num_shifts_block + 1, num_temporal_bins, num_non_rigid_windows)) - - # this part determines the up/down covariance for each block without - # shifting anything - for window_index in range(num_non_rigid_windows): - win = non_rigid_windows[window_index] - window_slice = np.flatnonzero(win > 1e-5) - window_slice = slice(window_slice[0], window_slice[-1]) - tiled_window = win[window_slice, np.newaxis, np.newaxis] - Ftaper = Fg[window_slice] * np.tile(tiled_window, (1,) + Fg.shape[1:]) - for t, shift in enumerate(shifts_block): - Fs = np.roll(Ftaper, shift, axis=0) - F0taper = F0[window_slice] * np.tile(tiled_window, (1,) + F0.shape[1:]) - shift_covs_block[t, :, window_index] = np.mean(Fs * F0taper, axis=(0, 1)) - - # gaussian smoothing: - # here the original my_conv2_cpu is substituted with scipy gaussian_filters - shift_covs_block_smooth = shift_covs_block.copy() - shifts_block_up = np.linspace(-num_shifts_block, num_shifts_block, (2 * num_shifts_block * 10) + 1) - # 1. 2d smoothing over time and blocks dimensions for each shift - for shift_index in range(num_shifts): - shift_covs_block_smooth[shift_index, :, :] = gaussian_filter( - shift_covs_block_smooth[shift_index, :, :], smoothing_sigma - ) # some additional smoothing for robustness, across all dimensions - # 2. 1d smoothing over shift dimension for each spatial block - for window_index in range(num_non_rigid_windows): - shift_covs_block_smooth[:, :, window_index] = gaussian_filter1d( - shift_covs_block_smooth[:, :, window_index], smoothing_sigma, axis=0 - ) # some additional smoothing for robustness, across all dimensions - upsample_kernel = kriging_kernel( - shifts_block[:, np.newaxis], shifts_block_up[:, np.newaxis], sigma=kriging_sigma, p=kriging_p, d=kriging_d - ) - - optimal_shift_indices = np.zeros((num_temporal_bins, num_non_rigid_windows)) - for window_index in range(num_non_rigid_windows): - # using the upsampling kernel K, get the upsampled cross-correlation - # curves - upsampled_cov = upsample_kernel.T @ shift_covs_block_smooth[:, :, window_index] - - # find the max index of these curves - imax = np.argmax(upsampled_cov, axis=0) - - # add the value of the shift to the last row of the matrix of shifts - # (as if it was the last iteration of the main rigid loop ) - best_shifts[num_iterations - 1, :] = shifts_block_up[imax] - - # the sum of all the shifts equals the final shifts for this block - optimal_shift_indices[:, window_index] = np.sum(best_shifts, axis=0) - - return optimal_shift_indices, target_spikecount_hist, shift_covs_block - - -def normxcorr1d( - template, - x, - weights=None, - centered=True, - normalized=True, - padding="same", - conv_engine="torch", -): - """normxcorr1d: Normalized cross-correlation, optionally weighted - - The API is like torch's F.conv1d, except I have accidentally - changed the position of input/weights -- template acts like weights, - and x acts like input. - - Returns the cross-correlation of `template` and `x` at spatial lags - determined by `mode`. Useful for estimating the location of `template` - within `x`. - - This might not be the most efficient implementation -- ideas welcome. - It uses a direct convolutional translation of the formula - corr = (E[XY] - EX EY) / sqrt(var X * var Y) - - This also supports weights! In that case, the usual adaptation of - the above formula is made to the weighted case -- and all of the - normalizations are done per block in the same way. - - Parameters - ---------- - template : tensor, shape (num_templates, length) - The reference template signal - x : tensor, 1d shape (length,) or 2d shape (num_inputs, length) - The signal in which to find `template` - weights : tensor, shape (length,) - Will use weighted means, variances, covariances if supplied. - centered : bool - If true, means will be subtracted (per weighted patch). - normalized : bool - If true, normalize by the variance (per weighted patch). - padding : str - How far to look? if unset, we'll use half the length - conv_engine : string, one of "torch", "numpy" - What library to use for computing cross-correlations. - If numpy, falls back to the scipy correlate function. - - Returns - ------- - corr : tensor - """ - if conv_engine == "torch": - assert HAVE_TORCH - conv1d = F.conv1d - npx = torch - elif conv_engine == "numpy": - conv1d = scipy_conv1d - npx = np - else: - raise ValueError(f"Unknown conv_engine {conv_engine}. 'conv_engine' must be 'torch' or 'numpy'") - - x = npx.atleast_2d(x) - num_templates, length = template.shape - num_inputs, length_ = template.shape - assert length == length_ - - # generalize over weighted / unweighted case - device_kw = {} if conv_engine == "numpy" else dict(device=x.device) - ones = npx.ones((1, 1, length), dtype=x.dtype, **device_kw) - no_weights = weights is None - if no_weights: - weights = ones - wt = template[:, None, :] - else: - assert weights.shape == (length,) - weights = weights[None, None] - wt = template[:, None, :] * weights - - # conv1d valid rule: - # (B,1,L),(O,1,L)->(B,O,L) - - # compute expectations - # how many points in each window? seems necessary to normalize - # for numerical stability. - N = conv1d(ones, weights, padding=padding) - if centered: - Et = conv1d(ones, wt, padding=padding) - Et /= N - Ex = conv1d(x[:, None, :], weights, padding=padding) - Ex /= N - - # compute (weighted) covariance - # important: the formula E[XY] - EX EY is well-suited here, - # because the means are naturally subtracted correctly - # patch-wise. you couldn't pre-subtract them! - cov = conv1d(x[:, None, :], wt, padding=padding) - cov /= N - if centered: - cov -= Ex * Et - - # compute variances for denominator, using var X = E[X^2] - (EX)^2 - if normalized: - var_template = conv1d(ones, wt * template[:, None, :], padding=padding) - var_template /= N - var_x = conv1d(npx.square(x)[:, None, :], weights, padding=padding) - var_x /= N - if centered: - var_template -= npx.square(Et) - var_x -= npx.square(Ex) - - # now find the final normxcorr - corr = cov # renaming for clarity - if normalized: - corr /= npx.sqrt(var_x) - corr /= npx.sqrt(var_template) - # get rid of NaNs in zero-variance areas - corr[~npx.isfinite(corr)] = 0 - - return corr - - -def scipy_conv1d(input, weights, padding="valid"): - """SciPy translation of torch F.conv1d""" - from scipy.signal import correlate - - n, c_in, length = input.shape - c_out, in_by_groups, kernel_size = weights.shape - assert in_by_groups == c_in == 1 - - if padding == "same": - mode = "same" - length_out = length - elif padding == "valid": - mode = "valid" - length_out = length - 2 * (kernel_size // 2) - elif isinstance(padding, int): - mode = "valid" - input = np.pad(input, [*[(0, 0)] * (input.ndim - 1), (padding, padding)]) - length_out = length - (kernel_size - 1) + 2 * padding - else: - raise ValueError(f"Unknown 'padding' value of {padding}, 'padding' must be 'same', 'valid' or an integer") - - output = np.zeros((n, c_out, length_out), dtype=input.dtype) - for m in range(n): - for c in range(c_out): - output[m, c] = correlate(input[m, 0], weights[c, 0], mode=mode) - - return output - - -def clean_motion_vector(motion, temporal_bins, bin_duration_s, speed_threshold=30, sigma_smooth_s=None): - """ - Simple machinery to remove spurious fast bump in the motion vector. - Also can applyt a smoothing. - - - Arguments - --------- - motion: numpy array 2d - Motion estimate in um. - temporal_bins: numpy.array 1d - temporal bins (bin center) - bin_duration_s: float - bin duration in second - speed_threshold: float (units um/s) - Maximum speed treshold between 2 bins allowed. - Expressed in um/s - sigma_smooth_s: None or float - Optional smooting gaussian kernel. - - Returns - ------- - corr : tensor - - - """ - motion_clean = motion.copy() - - # STEP 1 : - # * detect long plateau or small peak corssing the speed thresh - # * mask the period and interpolate - for i in range(motion.shape[1]): - one_motion = motion_clean[:, i] - speed = np.diff(one_motion, axis=0) / bin_duration_s - (inds,) = np.nonzero(np.abs(speed) > speed_threshold) - inds += 1 - if inds.size % 2 == 1: - # more compicated case: number of of inds is odd must remove first or last - # take the smallest duration sum - inds0 = inds[:-1] - inds1 = inds[1:] - d0 = np.sum(inds0[1::2] - inds0[::2]) - d1 = np.sum(inds1[1::2] - inds1[::2]) - if d0 < d1: - inds = inds0 - mask = np.ones(motion_clean.shape[0], dtype="bool") - for i in range(inds.size // 2): - mask[inds[i * 2] : inds[i * 2 + 1]] = False - import scipy.interpolate - - f = scipy.interpolate.interp1d(temporal_bins[mask], one_motion[mask]) - one_motion[~mask] = f(temporal_bins[~mask]) - - # Step 2 : gaussian smooth - if sigma_smooth_s is not None: - half_size = motion_clean.shape[0] // 2 - if motion_clean.shape[0] % 2 == 0: - # take care of the shift - bins = (np.arange(motion_clean.shape[0]) - half_size + 1) * bin_duration_s - else: - bins = (np.arange(motion_clean.shape[0]) - half_size) * bin_duration_s - smooth_kernel = np.exp(-(bins**2) / (2 * sigma_smooth_s**2)) - smooth_kernel /= np.sum(smooth_kernel) - smooth_kernel = smooth_kernel[:, None] - motion_clean = scipy.signal.fftconvolve(motion_clean, smooth_kernel, mode="same", axes=0) - - return motion_clean - - -def kriging_kernel(source_location, target_location, sigma=1, p=2, d=2): - from scipy.spatial.distance import cdist - - dist_xy = cdist(source_location, target_location, metric="euclidean") - K = np.exp(-((dist_xy / sigma) ** p) / d) - return K diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py deleted file mode 100644 index a8de3f6d13..0000000000 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ /dev/null @@ -1,234 +0,0 @@ -import json -from pathlib import Path - -import numpy as np -import spikeinterface -from spikeinterface.core.core_tools import check_json - - -class Motion: - """ - Motion of the tissue relative the probe. - - Parameters - ---------- - displacement : numpy array 2d or list of - Motion estimate in um. - List is the number of segment. - For each semgent : - * shape (temporal bins, spatial bins) - * motion.shape[0] = temporal_bins.shape[0] - * motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) - temporal_bins_s : numpy.array 1d or list of - temporal bins (bin center) - spatial_bins_um : numpy.array 1d - Windows center. - spatial_bins_um.shape[0] == displacement.shape[1] - If rigid then spatial_bins_um.shape[0] == 1 - direction : str, default: 'y' - Direction of the motion. - interpolation_method : str - How to determine the displacement between bin centers? See the docs - for scipy.interpolate.RegularGridInterpolator for options. - """ - - def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y", interpolation_method="linear"): - if isinstance(displacement, np.ndarray): - self.displacement = [displacement] - assert isinstance(temporal_bins_s, np.ndarray) - self.temporal_bins_s = [temporal_bins_s] - else: - assert isinstance(displacement, (list, tuple)) - self.displacement = displacement - self.temporal_bins_s = temporal_bins_s - - assert isinstance(spatial_bins_um, np.ndarray) - self.spatial_bins_um = spatial_bins_um - - self.num_segments = len(self.displacement) - self.interpolators = None - self.interpolation_method = interpolation_method - - self.direction = direction - self.dim = ["x", "y", "z"].index(direction) - self.check_properties() - - def check_properties(self): - assert all(d.ndim == 2 for d in self.displacement) - assert all(t.ndim == 1 for t in self.temporal_bins_s) - assert all(self.spatial_bins_um.shape == (d.shape[1],) for d in self.displacement) - - def __repr__(self): - nbins = self.spatial_bins_um.shape[0] - if nbins == 1: - rigid_txt = "rigid" - else: - rigid_txt = f"non-rigid - {nbins} spatial bins" - - interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0] - txt = f"Motion {rigid_txt} - interval {interval_s}s - {self.num_segments} segments" - return txt - - def make_interpolators(self): - from scipy.interpolate import RegularGridInterpolator - - self.interpolators = [ - RegularGridInterpolator( - (self.temporal_bins_s[j], self.spatial_bins_um), self.displacement[j], method=self.interpolation_method - ) - for j in range(self.num_segments) - ] - self.temporal_bounds = [(t[0], t[-1]) for t in self.temporal_bins_s] - self.spatial_bounds = (self.spatial_bins_um.min(), self.spatial_bins_um.max()) - - def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_index=None, grid=False): - """Evaluate the motion estimate at times and positions - - Evaluate the motion estimate, returning the (linearly interpolated) estimated displacement - at the given times and locations. - - Parameters - ---------- - times_s: np.array - The time points at which to evaluate the displacement. - locations_um: np.array - Either this is a one-dimensional array (a vector of positions along self.dimension), or - else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1. - segment_index: int, default: None - The index of the segment to evaluate. If None, and there is only one segment, then that segment is used. - grid : bool, default: False - If grid=False, the default, then times_s and locations_um should have the same one-dimensional - shape, and the returned displacement[i] is the displacement at time times_s[i] and location - locations_um[i]. - If grid=True, times_s and locations_um determine a grid of positions to evaluate the displacement. - Then the returned displacement[i,j] is the displacement at depth locations_um[i] and time times_s[j]. - - Returns - ------- - displacement : np.array - A displacement per input location, of shape times_s.shape if grid=False and (locations_um.size, times_s.size) - if grid=True. - """ - if self.interpolators is None: - self.make_interpolators() - - if segment_index is None: - if self.num_segments == 1: - segment_index = 0 - else: - raise ValueError("Several segment need segment_index=") - - times_s = np.asarray(times_s) - locations_um = np.asarray(locations_um) - - if locations_um.ndim == 1: - locations_um = locations_um - elif locations_um.ndim == 2: - locations_um = locations_um[:, self.dim] - else: - assert False - - times_s = times_s.clip(*self.temporal_bounds[segment_index]) - locations_um = locations_um.clip(*self.spatial_bounds) - - if grid: - # construct a grid over which to evaluate the displacement - locations_um, times_s = np.meshgrid(locations_um, times_s, indexing="ij") - out_shape = times_s.shape - locations_um = locations_um.ravel() - times_s = times_s.ravel() - else: - # usual case: input is a point cloud - assert locations_um.shape == times_s.shape - assert times_s.ndim == 1 - out_shape = times_s.shape - - points = np.column_stack((times_s, locations_um)) - displacement = self.interpolators[segment_index](points) - # reshape to grid domain shape if necessary - displacement = displacement.reshape(out_shape) - - return displacement - - def to_dict(self): - return dict( - displacement=self.displacement, - temporal_bins_s=self.temporal_bins_s, - spatial_bins_um=self.spatial_bins_um, - direction=self.direction, - interpolation_method=self.interpolation_method, - ) - - def save(self, folder): - folder = Path(folder) - folder.mkdir(exist_ok=False, parents=True) - - info_file = folder / f"spikeinterface_info.json" - info = dict( - version=spikeinterface.__version__, - dev_mode=spikeinterface.DEV_MODE, - object="Motion", - num_segments=self.num_segments, - direction=self.direction, - interpolation_method=self.interpolation_method, - ) - with open(info_file, mode="w") as f: - json.dump(check_json(info), f, indent=4) - - np.save(folder / "spatial_bins_um.npy", self.spatial_bins_um) - - for segment_index in range(self.num_segments): - np.save(folder / f"displacement_seg{segment_index}.npy", self.displacement[segment_index]) - np.save(folder / f"temporal_bins_s_seg{segment_index}.npy", self.temporal_bins_s[segment_index]) - - @classmethod - def load(cls, folder): - folder = Path(folder) - - info_file = folder / f"spikeinterface_info.json" - err_msg = f"Motion.load(folder): the folder {folder} does not contain a Motion object." - if not info_file.exists(): - raise IOError(err_msg) - - with open(info_file, "r") as f: - info = json.load(f) - if "object" not in info or info["object"] != "Motion": - raise IOError(err_msg) - - direction = info["direction"] - interpolation_method = info["interpolation_method"] - spatial_bins_um = np.load(folder / "spatial_bins_um.npy") - displacement = [] - temporal_bins_s = [] - for segment_index in range(info["num_segments"]): - displacement.append(np.load(folder / f"displacement_seg{segment_index}.npy")) - temporal_bins_s.append(np.load(folder / f"temporal_bins_s_seg{segment_index}.npy")) - - return cls( - displacement, - temporal_bins_s, - spatial_bins_um, - direction=direction, - interpolation_method=interpolation_method, - ) - - def __eq__(self, other): - for segment_index in range(self.num_segments): - if not np.allclose(self.displacement[segment_index], other.displacement[segment_index]): - return False - if not np.allclose(self.temporal_bins_s[segment_index], other.temporal_bins_s[segment_index]): - return False - - if not np.allclose(self.spatial_bins_um, other.spatial_bins_um): - return False - - return True - - def copy(self): - return Motion( - [d.copy() for d in self.displacement], - [t.copy() for t in self.temporal_bins_s], - [s.copy() for s in self.spatial_bins_um], - direction=self.direction, - interpolation_method=self.interpolation_method, - ) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 0b79350a62..81cda212b2 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -5,7 +5,7 @@ from .base import BaseWidget, to_attr from spikeinterface.core import BaseRecording, SortingAnalyzer -from spikeinterface.sortingcomponents.motion_utils import Motion +from spikeinterface.sortingcomponents.motion import Motion class MotionWidget(BaseWidget): @@ -230,7 +230,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from matplotlib.colors import Normalize from .utils_matplotlib import make_mpl_figure - from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks + from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks dp = to_attr(data_plot) @@ -291,12 +291,10 @@ class MotionInfoWidget(BaseWidget): ---------- motion_info : dict The motion info returned by correct_motion() or loaded back with load_motion_info(). + recording : RecordingExtractor + The recording extractor object segment_index : int, default: None The segment index to display. - recording : RecordingExtractor, default: None - The recording extractor object (only used to get "real" times). - segment_index : int, default: 0 - The segment index to display. sampling_frequency : float, default: None The sampling frequency (needed if recording is None). depth_lim : tuple or None, default: None @@ -320,8 +318,8 @@ class MotionInfoWidget(BaseWidget): def __init__( self, motion_info: dict, + recording: BaseRecording, segment_index: int | None = None, - recording: BaseRecording | None = None, depth_lim: tuple[float, float] | None = None, motion_lim: tuple[float, float] | None = None, color_amplitude: bool = False, @@ -366,7 +364,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure - from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks + from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 012b1ac07c..a09304dc86 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -190,8 +190,16 @@ def test_plot_unit_waveforms(self): backend=backend, **self.backend_kwargs[backend], ) - # test "larger" sparsity - with self.assertRaises(AssertionError): + # channel ids + sw.plot_unit_waveforms( + self.sorting_analyzer_sparse, + channel_ids=self.sorting_analyzer_sparse.channel_ids[::3], + unit_ids=unit_ids, + backend=backend, + **self.backend_kwargs[backend], + ) + # test warning with "larger" sparsity + with self.assertWarns(UserWarning): sw.plot_unit_waveforms( self.sorting_analyzer_sparse, sparsity=self.sparsity_large, @@ -205,10 +213,10 @@ def test_plot_unit_templates(self): for backend in possible_backends: if backend not in self.skip_backends: print(f"Testing backend {backend}") - print("Dense") + # dense sw.plot_unit_templates(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] - print("Dense + radius") + # dense + radius sw.plot_unit_templates( self.sorting_analyzer_dense, sparsity=self.sparsity_radius, @@ -216,7 +224,7 @@ def test_plot_unit_templates(self): backend=backend, **self.backend_kwargs[backend], ) - print("Dense + best") + # dense + best sw.plot_unit_templates( self.sorting_analyzer_dense, sparsity=self.sparsity_best, @@ -225,7 +233,6 @@ def test_plot_unit_templates(self): **self.backend_kwargs[backend], ) # test different shadings - print("Sparse") sw.plot_unit_templates( self.sorting_analyzer_sparse, unit_ids=unit_ids, @@ -233,7 +240,6 @@ def test_plot_unit_templates(self): backend=backend, **self.backend_kwargs[backend], ) - print("Sparse2") sw.plot_unit_templates( self.sorting_analyzer_sparse, unit_ids=unit_ids, @@ -242,8 +248,6 @@ def test_plot_unit_templates(self): backend=backend, **self.backend_kwargs[backend], ) - # test different shadings - print("Sparse3") sw.plot_unit_templates( self.sorting_analyzer_sparse, unit_ids=unit_ids, @@ -252,7 +256,6 @@ def test_plot_unit_templates(self): shade_templates=False, **self.backend_kwargs[backend], ) - print("Sparse4") sw.plot_unit_templates( self.sorting_analyzer_sparse, unit_ids=unit_ids, @@ -260,7 +263,7 @@ def test_plot_unit_templates(self): backend=backend, **self.backend_kwargs[backend], ) - print("Extra sparsity") + # extra sparsity sw.plot_unit_templates( self.sorting_analyzer_sparse, sparsity=self.sparsity_strict, @@ -269,8 +272,18 @@ def test_plot_unit_templates(self): backend=backend, **self.backend_kwargs[backend], ) + # channel ids + sw.plot_unit_templates( + self.sorting_analyzer_sparse, + channel_ids=self.sorting_analyzer_sparse.channel_ids[::3], + unit_ids=unit_ids, + templates_percentile_shading=[1, 10, 90, 99], + backend=backend, + **self.backend_kwargs[backend], + ) + # test "larger" sparsity - with self.assertRaises(AssertionError): + with self.assertWarns(UserWarning): sw.plot_unit_templates( self.sorting_analyzer_sparse, sparsity=self.sparsity_large, diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index eb9a90d1d1..258ca2adaa 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -24,8 +24,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs): assert len(dp.templates_shading) <= 4, "Only 2 ans 4 templates shading are supported in sortingview" # ensure serializable for sortingview - unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids - unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices + unit_id_to_channel_ids = dp.final_sparsity.unit_id_to_channel_ids + unit_id_to_channel_indices = dp.final_sparsity.unit_id_to_channel_indices unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 59f91306ea..c593836061 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -119,38 +119,50 @@ def __init__( if unit_ids is None: unit_ids = sorting_analyzer_or_templates.unit_ids - if channel_ids is None: - channel_ids = sorting_analyzer_or_templates.channel_ids if unit_colors is None: unit_colors = get_unit_colors(sorting_analyzer_or_templates) - channel_indices = [list(sorting_analyzer_or_templates.channel_ids).index(ch) for ch in channel_ids] - channel_locations = sorting_analyzer_or_templates.get_channel_locations()[channel_indices] - extra_sparsity = False - if sorting_analyzer_or_templates.sparsity is not None: - if sparsity is None: - sparsity = sorting_analyzer_or_templates.sparsity - else: - # assert provided sparsity is a subset of waveform sparsity - combined_mask = np.logical_or(sorting_analyzer_or_templates.sparsity.mask, sparsity.mask) - assert np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer_or_templates.sparsity.mask, 1) == 0), ( - "The provided 'sparsity' needs to include only the sparse channels " - "used to extract waveforms (for example, by using a smaller 'radius_um')." - ) - extra_sparsity = True - else: - if sparsity is None: - # in this case, we construct a dense sparsity - unit_id_to_channel_ids = { - u: sorting_analyzer_or_templates.channel_ids for u in sorting_analyzer_or_templates.unit_ids - } - sparsity = ChannelSparsity.from_unit_id_to_channel_ids( - unit_id_to_channel_ids=unit_id_to_channel_ids, - unit_ids=sorting_analyzer_or_templates.unit_ids, - channel_ids=sorting_analyzer_or_templates.channel_ids, - ) - else: - assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!" + channel_locations = sorting_analyzer_or_templates.get_channel_locations() + extra_sparsity = None + # handle sparsity + sparsity_mismatch_warning = ( + "The provided 'sparsity' includes additional channels not in the analyzer sparsity. " + "These extra channels will be plotted as flat lines." + ) + analyzer_sparsity = sorting_analyzer_or_templates.sparsity + if channel_ids is not None: + assert sparsity is None, "If 'channel_ids' is provided, 'sparsity' should be None!" + channel_mask = np.tile( + np.isin(sorting_analyzer_or_templates.channel_ids, channel_ids), + (len(sorting_analyzer_or_templates.unit_ids), 1), + ) + extra_sparsity = ChannelSparsity( + mask=channel_mask, + channel_ids=sorting_analyzer_or_templates.channel_ids, + unit_ids=sorting_analyzer_or_templates.unit_ids, + ) + elif sparsity is not None: + extra_sparsity = sparsity + + if channel_ids is None: + channel_ids = sorting_analyzer_or_templates.channel_ids + + # assert provided sparsity is a subset of waveform sparsity + if extra_sparsity is not None and analyzer_sparsity is not None: + combined_mask = np.logical_or(analyzer_sparsity.mask, extra_sparsity.mask) + if not np.all(np.sum(combined_mask, 1) - np.sum(analyzer_sparsity.mask, 1) == 0): + warn(sparsity_mismatch_warning) + + final_sparsity = extra_sparsity if extra_sparsity is not None else analyzer_sparsity + if final_sparsity is None: + final_sparsity = ChannelSparsity( + mask=np.ones( + (len(sorting_analyzer_or_templates.unit_ids), len(sorting_analyzer_or_templates.channel_ids)), + dtype=bool, + ), + unit_ids=sorting_analyzer_or_templates.unit_ids, + channel_ids=sorting_analyzer_or_templates.channel_ids, + ) # get templates if isinstance(sorting_analyzer_or_templates, Templates): @@ -174,34 +186,14 @@ def __init__( templates_percentile_shading = None templates_shading = self._get_template_shadings(unit_ids, templates_percentile_shading) - wfs_by_ids = {} if plot_waveforms: # this must be a sorting_analyzer wf_ext = sorting_analyzer_or_templates.get_extension("waveforms") if wf_ext is None: raise ValueError("plot_waveforms() needs the extension 'waveforms'") - for unit_id in unit_ids: - unit_index = list(sorting_analyzer_or_templates.unit_ids).index(unit_id) - if not extra_sparsity: - if sorting_analyzer_or_templates.is_sparse(): - # wfs = we.get_waveforms(unit_id) - wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) - else: - # wfs = we.get_waveforms(unit_id, sparsity=sparsity) - wfs = wf_ext.get_waveforms_one_unit(unit_id) - wfs = wfs[:, :, sparsity.mask[unit_index]] - else: - # in this case we have to slice the waveform sparsity based on the extra sparsity - # first get the sparse waveforms - # wfs = we.get_waveforms(unit_id) - wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) - # find additional slice to apply to sparse waveforms - (wfs_sparse_indices,) = np.nonzero(sorting_analyzer_or_templates.sparsity.mask[unit_index]) - (extra_sparse_indices,) = np.nonzero(sparsity.mask[unit_index]) - (extra_slice,) = np.nonzero(np.isin(wfs_sparse_indices, extra_sparse_indices)) - # apply extra sparsity - wfs = wfs[:, :, extra_slice] - wfs_by_ids[unit_id] = wfs + wfs_by_ids = self._get_wfs_by_ids(sorting_analyzer_or_templates, unit_ids, extra_sparsity=extra_sparsity) + else: + wfs_by_ids = None plot_data = dict( sorting_analyzer_or_templates=sorting_analyzer_or_templates, @@ -209,7 +201,8 @@ def __init__( nbefore=nbefore, unit_ids=unit_ids, channel_ids=channel_ids, - sparsity=sparsity, + final_sparsity=final_sparsity, + extra_sparsity=extra_sparsity, unit_colors=unit_colors, channel_locations=channel_locations, scale=scale, @@ -270,7 +263,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax = self.axes.flatten()[i] color = dp.unit_colors[unit_id] - chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id] + chan_inds = dp.final_sparsity.unit_id_to_channel_indices[unit_id] xvectors_flat = xvectors[:, chan_inds].T.flatten() # plot waveforms @@ -502,6 +495,32 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): if backend_kwargs["display"]: display(self.widget) + def _get_wfs_by_ids(self, sorting_analyzer, unit_ids, extra_sparsity): + wfs_by_ids = {} + wf_ext = sorting_analyzer.get_extension("waveforms") + for unit_id in unit_ids: + unit_index = list(sorting_analyzer.unit_ids).index(unit_id) + if extra_sparsity is None: + wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) + else: + # in this case we have to construct waveforms based on the extra sparsity and add the + # sparse waveforms on the valid channels + if sorting_analyzer.is_sparse(): + original_mask = sorting_analyzer.sparsity.mask[unit_index] + else: + original_mask = np.ones(len(sorting_analyzer.channel_ids), dtype=bool) + wfs_orig = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) + wfs = np.zeros( + (wfs_orig.shape[0], wfs_orig.shape[1], extra_sparsity.mask[unit_index].sum()), dtype=wfs_orig.dtype + ) + # fill in the existing waveforms channels + valid_wfs_indices = extra_sparsity.mask[unit_index][original_mask] + valid_extra_indices = original_mask[extra_sparsity.mask[unit_index]] + wfs[:, :, valid_extra_indices] = wfs_orig[:, :, valid_wfs_indices] + + wfs_by_ids[unit_id] = wfs + return wfs_by_ids + def _get_template_shadings(self, unit_ids, templates_percentile_shading): templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") @@ -538,6 +557,8 @@ def _update_plot(self, change): hide_axis = self.hide_axis_button.value do_shading = self.template_shading_button.value + data_plot = self.next_data_plot + if self.sorting_analyzer is not None: templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") templates_shadings = self._get_template_shadings(unit_ids, data_plot["templates_percentile_shading"]) @@ -549,7 +570,6 @@ def _update_plot(self, change): channel_locations = self.templates.get_channel_locations() # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot data_plot["unit_ids"] = unit_ids data_plot["templates"] = templates data_plot["templates_shading"] = templates_shadings @@ -564,10 +584,10 @@ def _update_plot(self, change): data_plot["scalebar"] = self.scalebar.value if data_plot["plot_waveforms"]: - wf_ext = self.sorting_analyzer.get_extension("waveforms") - data_plot["wfs_by_ids"] = { - unit_id: wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) for unit_id in unit_ids - } + wfs_by_ids = self._get_wfs_by_ids( + self.sorting_analyzer, unit_ids, extra_sparsity=data_plot["extra_sparsity"] + ) + data_plot["wfs_by_ids"] = wfs_by_ids # TODO option for plot_legend backend_kwargs = {} @@ -611,7 +631,7 @@ def _plot_probe(self, ax, channel_locations, unit_ids): # TODO this could be done with probeinterface plotting plotting tools!! for unit in unit_ids: - channel_inds = self.data_plot["sparsity"].unit_id_to_channel_indices[unit] + channel_inds = self.data_plot["final_sparsity"].unit_id_to_channel_indices[unit] ax.plot( channel_locations[channel_inds, 0], channel_locations[channel_inds, 1], diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index ac0676e4c7..ca09cc4d8f 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -151,6 +151,7 @@ def array_to_image( output_image : 3D numpy array """ + import matplotlib.pyplot as plt from scipy.ndimage import zoom