From f55a9405040310064f716015f8d9b0c976b97923 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 28 Oct 2024 14:28:05 +0000 Subject: [PATCH 01/32] Add 'shift start time' function. --- src/spikeinterface/core/baserecording.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 5e2e9e4014..b8a0420794 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -509,6 +509,26 @@ def reset_times(self): rs.t_start = None rs.sampling_frequency = self.sampling_frequency + def shift_start_time(self, shift, segment_index=None): + """ + Shift the starting time of the times. + + shift : int | float + The shift to apply to the first time point. If positive, + the current start time will be increased by `shift`. If + negative, the start time will be decreased. + + segment_index : int | None + The segment on which to shift the times. + """ + segment_index = self._check_segment_index(segment_index) + rs = self._recording_segments[segment_index] + + if self.has_time_vector(): + rs.time_vector += shift + else: + rs.t_start += shift + def sample_index_to_time(self, sample_ind, segment_index=None): """ Transform sample index into time in seconds From f34da1aff682828dfba78cd17034c0fc2cb40fda Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 1 Nov 2024 14:03:13 +0000 Subject: [PATCH 02/32] Make new index page with hover CSS. --- doc/_static/css/custom.css | 20 +++ doc/conf.py | 6 +- doc/index.rst | 1 + doc/tutorials_custom_index.rst | 254 +++++++++++++++++++++++++++++++++ 4 files changed, 279 insertions(+), 2 deletions(-) create mode 100644 doc/_static/css/custom.css create mode 100644 doc/tutorials_custom_index.rst diff --git a/doc/_static/css/custom.css b/doc/_static/css/custom.css new file mode 100644 index 0000000000..0c51da539e --- /dev/null +++ b/doc/_static/css/custom.css @@ -0,0 +1,20 @@ +/* Center and make the title bold */ +.gallery-card .grid-item-card-title { + text-align: center; + font-weight: bold; +} + +/* Default style for hover content (hidden) */ +.gallery-card .hover-content { + display: none; + text-align: center; +} + +/* Show the hover content when hovering over the card */ +.gallery-card:hover .default-title { + display: none; +} + +.gallery-card:hover .hover-content { + display: block; +} diff --git a/doc/conf.py b/doc/conf.py index e3d58ca8f2..db16269991 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -109,8 +109,10 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -# html_static_path = ['_static'] - +html_static_path = ['_static'] +html_css_files = [ + 'css/custom.css', +] html_favicon = "images/favicon-32x32.png" diff --git a/doc/index.rst b/doc/index.rst index ed443e4200..57a0c95443 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -51,6 +51,7 @@ SpikeInterface is made of several modules to deal with different aspects of the overview get_started/index + tutorials_custom_index tutorials/index how_to/index modules/index diff --git a/doc/tutorials_custom_index.rst b/doc/tutorials_custom_index.rst new file mode 100644 index 0000000000..46a7bea630 --- /dev/null +++ b/doc/tutorials_custom_index.rst @@ -0,0 +1,254 @@ +.. This page provides a custom index to the 'Tutorials' page, rather than the default sphinx-gallery +.. generated page. The benefits of this are flexibility in design and inclusion of non-sphinx files in the index. +.. +.. To update this index with a new documentation page +.. 1) Copy the grid-item-card and associated ".. raw:: html" section. +.. 2) change :link: to a link to your page. If this is an `.rst` file, point to the rst file directly. +.. If it is a sphinx-gallery generated file, format the path as separated by underscore and prefix `sphx_glr`, +.. pointing to the .py file. e.g. `tutorials/my/page.py` -> `sphx_glr_tutorials_my_page.py +.. 3) Change :img-top: to point to the thumbnail image of your choosing. You can point to images generated +.. in the sphinx gallery page if you wish. +.. 4) In the `html` section, change the `default-title` to your pages title and `hover-content` to the subtitle. + +:orphan: + +TutorialsNew +============ + +Longer form tutorials about using SpikeInterface. Many of these are downloadable as notebooks or Python scripts so that you can "code along" with the tutorials. + +If you're new to SpikeInterface, we recommend trying out the :ref:`get_started/quickstart:Quickstart tutorial` first. + +Updating from legacy +-------------------- + +.. toctree:: + :maxdepth: 1 + + tutorials/waveform_extractor_to_sorting_analyzer + +Core tutorials +-------------- + +These tutorials focus on the :py:mod:`spikeinterface.core` module. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_core_plot_1_recording_extractor.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_1_recording_extractor_thumb.png + :img-alt: Recording objects + :class-card: gallery-card + + .. raw:: html + +
Recording objects
+
Manage loaded recordings in SpikeInterface
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_core_plot_2_sorting_extractor.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_2_sorting_extractor_thumb.png + :img-alt: Sorting objects + :class-card: gallery-card + + .. raw:: html + +
Sorting objects
+
Explore sorting extractor features
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_core_plot_3_handle_probe_info.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_3_handle_probe_info_thumb.png + :img-alt: Handling probe information + :class-card: gallery-card + + .. raw:: html + +
Handling probe information
+
Handle and visualize probe information
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_core_plot_4_sorting_analyzer.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_4_sorting_analyzer_thumb.png + :img-alt: SortingAnalyzer + :class-card: gallery-card + + .. raw:: html + +
SortingAnalyzer
+
Analyze sorting results with ease
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_core_plot_5_append_concatenate_segments.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_5_append_concatenate_segments_thumb.png + :img-alt: Append/Concatenate segments + :class-card: gallery-card + + .. raw:: html + +
Append and/or concatenate segments
+
Combine segments efficiently
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_core_plot_6_handle_times.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_6_handle_times_thumb.png + :img-alt: Handle time information + :class-card: gallery-card + + .. raw:: html + +
Handle time information
+
Manage and analyze time information
+ +Extractors tutorials +-------------------- + +The :py:mod:`spikeinterface.extractors` module is designed to load and save recorded and sorted data, and to handle probe information. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_extractors_plot_1_read_various_formats.py + :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_1_read_various_formats_thumb.png + :img-alt: Read various formats + :class-card: gallery-card + + .. raw:: html + +
Read various formats
+
Read different recording formats efficiently
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_extractors_plot_2_working_with_unscaled_traces.py + :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_2_working_with_unscaled_traces_thumb.png + :img-alt: Unscaled traces + :class-card: gallery-card + + .. raw:: html + +
Working with unscaled traces
+
Learn about managing unscaled traces
+ +Quality metrics tutorial +------------------------ + +The :code:`spikeinterface.qualitymetrics` module allows users to compute various quality metrics to assess the goodness of a spike sorting output. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_qualitymetrics_plot_3_quality_mertics.py + :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_3_quality_mertics_thumb.png + :img-alt: Quality Metrics + :class-card: gallery-card + + .. raw:: html + +
Quality Metrics
+
Evaluate sorting quality using metrics
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_qualitymetrics_plot_4_curation.py + :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_4_curation_thumb.png + :img-alt: Curation Tutorial + :class-card: gallery-card + + .. raw:: html + +
Curation Tutorial
+
Learn how to curate spike sorting data
+ +Comparison tutorial +------------------- + +The :code:`spikeinterface.comparison` module allows you to compare sorter outputs or benchmark against ground truth. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_comparison_plot_5_comparison_sorter_weaknesses.py + :img-top: /tutorials/comparison/images/thumb/sphx_glr_plot_5_comparison_sorter_weaknesses_thumb.png + :img-alt: Sorter Comparison + :class-card: gallery-card + + .. raw:: html + +
Sorter Comparison
+
Compare sorter outputs and assess weaknesses
+ +Widgets tutorials +----------------- + +The :code:`widgets` module contains several plotting routines (widgets) for visualizing recordings, sorting data, probe layout, and more. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_1_rec_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_1_rec_gallery_thumb.png + :img-alt: Recording Widgets + :class-card: gallery-card + + .. raw:: html + +
RecordingExtractor Widgets
+
Visualize recordings with widgets
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_2_sort_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_2_sort_gallery_thumb.png + :img-alt: Sorting Widgets + :class-card: gallery-card + + .. raw:: html + +
SortingExtractor Widgets
+
Explore sorting data using widgets
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_3_waveforms_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_3_waveforms_gallery_thumb.png + :img-alt: Waveforms Widgets + :class-card: gallery-card + + .. raw:: html + +
Waveforms Widgets
+
Display waveforms using SpikeInterface
+ + .. grid-item-card:: + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_4_peaks_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_4_peaks_gallery_thumb.png + :img-alt: Peaks Widgets + :class-card: gallery-card + + .. raw:: html + +
Peaks Widgets
+
Visualize detected peaks
+ +Download All Examples +--------------------- + +- :download:`Download all examples in Python source code ` +- :download:`Download all examples in Jupyter notebooks ` From 7aa93490cca20916338629518800a1cbf976b8ff Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 1 Nov 2024 14:53:22 +0000 Subject: [PATCH 03/32] Remove CSS and update development docs. --- doc/_static/css/custom.css | 20 ----- doc/conf.py | 6 +- doc/development/development.rst | 19 +++++ doc/index.rst | 1 - doc/tutorials_custom_index.rst | 128 +++++++++----------------------- 5 files changed, 56 insertions(+), 118 deletions(-) delete mode 100644 doc/_static/css/custom.css diff --git a/doc/_static/css/custom.css b/doc/_static/css/custom.css deleted file mode 100644 index 0c51da539e..0000000000 --- a/doc/_static/css/custom.css +++ /dev/null @@ -1,20 +0,0 @@ -/* Center and make the title bold */ -.gallery-card .grid-item-card-title { - text-align: center; - font-weight: bold; -} - -/* Default style for hover content (hidden) */ -.gallery-card .hover-content { - display: none; - text-align: center; -} - -/* Show the hover content when hovering over the card */ -.gallery-card:hover .default-title { - display: none; -} - -.gallery-card:hover .hover-content { - display: block; -} diff --git a/doc/conf.py b/doc/conf.py index db16269991..e3d58ca8f2 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -109,10 +109,8 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] -html_css_files = [ - 'css/custom.css', -] +# html_static_path = ['_static'] + html_favicon = "images/favicon-32x32.png" diff --git a/doc/development/development.rst b/doc/development/development.rst index a91818a271..1638c41243 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -213,6 +213,25 @@ We use Sphinx to build the documentation. To build the documentation locally, yo This will build the documentation in the :code:`doc/_build/html` folder. You can open the :code:`index.html` file in your browser to see the documentation. +Adding new documentation +------------------------ + +Documentation can be added as a +`sphinx-gallery `_ +python file ('tutorials') +or a +`sphinx rst `_ +file (all other sections). + +To add a new tutorial, add your ``.py`` file to ``spikeinterface/examples``. +Then, update the ``spikeinterface/doc/tutorials_custom_index.rst`` file +to make a new card linking to the page and an optional image. See +``tutorials_custom_index.rst`` header for more information. + +For other sections, write your documentation in ``.rst`` format and add +the page to the appropriate ``index.rst`` file found in the relevant +folder (e.g. ``how_to/index.rst``). + How to run code coverage locally -------------------------------- To run code coverage locally, you can use the following command: diff --git a/doc/index.rst b/doc/index.rst index 57a0c95443..e6d8aa3fea 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -52,7 +52,6 @@ SpikeInterface is made of several modules to deal with different aspects of the overview get_started/index tutorials_custom_index - tutorials/index how_to/index modules/index api diff --git a/doc/tutorials_custom_index.rst b/doc/tutorials_custom_index.rst index 46a7bea630..4c7625d811 100644 --- a/doc/tutorials_custom_index.rst +++ b/doc/tutorials_custom_index.rst @@ -12,12 +12,14 @@ :orphan: -TutorialsNew +Tutorials ============ -Longer form tutorials about using SpikeInterface. Many of these are downloadable as notebooks or Python scripts so that you can "code along" with the tutorials. +Longer form tutorials about using SpikeInterface. Many of these are downloadable +as notebooks or Python scripts so that you can "code along" with the tutorials. -If you're new to SpikeInterface, we recommend trying out the :ref:`get_started/quickstart:Quickstart tutorial` first. +If you're new to SpikeInterface, we recommend trying out the +:ref:`get_started/quickstart:Quickstart tutorial` first. Updating from legacy -------------------- @@ -35,77 +37,53 @@ These tutorials focus on the :py:mod:`spikeinterface.core` module. .. grid:: 1 2 2 3 :gutter: 2 - .. grid-item-card:: + .. grid-item-card:: Recording objects :link-type: ref :link: sphx_glr_tutorials_core_plot_1_recording_extractor.py :img-top: /tutorials/core/images/thumb/sphx_glr_plot_1_recording_extractor_thumb.png :img-alt: Recording objects :class-card: gallery-card + :text-align: center - .. raw:: html - -
Recording objects
-
Manage loaded recordings in SpikeInterface
- - .. grid-item-card:: + .. grid-item-card:: Sorting objects :link-type: ref :link: sphx_glr_tutorials_core_plot_2_sorting_extractor.py :img-top: /tutorials/core/images/thumb/sphx_glr_plot_2_sorting_extractor_thumb.png :img-alt: Sorting objects :class-card: gallery-card + :text-align: center - .. raw:: html - -
Sorting objects
-
Explore sorting extractor features
- - .. grid-item-card:: + .. grid-item-card:: Handling probe information :link-type: ref :link: sphx_glr_tutorials_core_plot_3_handle_probe_info.py :img-top: /tutorials/core/images/thumb/sphx_glr_plot_3_handle_probe_info_thumb.png :img-alt: Handling probe information :class-card: gallery-card + :text-align: center - .. raw:: html - -
Handling probe information
-
Handle and visualize probe information
- - .. grid-item-card:: + .. grid-item-card:: SortingAnalyzer :link-type: ref :link: sphx_glr_tutorials_core_plot_4_sorting_analyzer.py :img-top: /tutorials/core/images/thumb/sphx_glr_plot_4_sorting_analyzer_thumb.png :img-alt: SortingAnalyzer :class-card: gallery-card + :text-align: center - .. raw:: html - -
SortingAnalyzer
-
Analyze sorting results with ease
- - .. grid-item-card:: + .. grid-item-card:: Append and/or concatenate segments :link-type: ref :link: sphx_glr_tutorials_core_plot_5_append_concatenate_segments.py :img-top: /tutorials/core/images/thumb/sphx_glr_plot_5_append_concatenate_segments_thumb.png :img-alt: Append/Concatenate segments :class-card: gallery-card + :text-align: center - .. raw:: html - -
Append and/or concatenate segments
-
Combine segments efficiently
- - .. grid-item-card:: + .. grid-item-card:: Handle time information :link-type: ref :link: sphx_glr_tutorials_core_plot_6_handle_times.py :img-top: /tutorials/core/images/thumb/sphx_glr_plot_6_handle_times_thumb.png :img-alt: Handle time information :class-card: gallery-card - - .. raw:: html - -
Handle time information
-
Manage and analyze time information
+ :text-align: center Extractors tutorials -------------------- @@ -115,29 +93,21 @@ The :py:mod:`spikeinterface.extractors` module is designed to load and save reco .. grid:: 1 2 2 3 :gutter: 2 - .. grid-item-card:: + .. grid-item-card:: Read various formats :link-type: ref :link: sphx_glr_tutorials_extractors_plot_1_read_various_formats.py :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_1_read_various_formats_thumb.png :img-alt: Read various formats :class-card: gallery-card + :text-align: center - .. raw:: html - -
Read various formats
-
Read different recording formats efficiently
- - .. grid-item-card:: + .. grid-item-card:: Working with unscaled traces :link-type: ref :link: sphx_glr_tutorials_extractors_plot_2_working_with_unscaled_traces.py :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_2_working_with_unscaled_traces_thumb.png :img-alt: Unscaled traces :class-card: gallery-card - - .. raw:: html - -
Working with unscaled traces
-
Learn about managing unscaled traces
+ :text-align: center Quality metrics tutorial ------------------------ @@ -147,29 +117,21 @@ The :code:`spikeinterface.qualitymetrics` module allows users to compute various .. grid:: 1 2 2 3 :gutter: 2 - .. grid-item-card:: + .. grid-item-card:: Quality Metrics :link-type: ref :link: sphx_glr_tutorials_qualitymetrics_plot_3_quality_mertics.py :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_3_quality_mertics_thumb.png :img-alt: Quality Metrics :class-card: gallery-card + :text-align: center - .. raw:: html - -
Quality Metrics
-
Evaluate sorting quality using metrics
- - .. grid-item-card:: + .. grid-item-card:: Curation Tutorial :link-type: ref :link: sphx_glr_tutorials_qualitymetrics_plot_4_curation.py :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_4_curation_thumb.png :img-alt: Curation Tutorial :class-card: gallery-card - - .. raw:: html - -
Curation Tutorial
-
Learn how to curate spike sorting data
+ :text-align: center Comparison tutorial ------------------- @@ -179,17 +141,13 @@ The :code:`spikeinterface.comparison` module allows you to compare sorter output .. grid:: 1 2 2 3 :gutter: 2 - .. grid-item-card:: + .. grid-item-card:: Sorter Comparison :link-type: ref :link: sphx_glr_tutorials_comparison_plot_5_comparison_sorter_weaknesses.py :img-top: /tutorials/comparison/images/thumb/sphx_glr_plot_5_comparison_sorter_weaknesses_thumb.png :img-alt: Sorter Comparison :class-card: gallery-card - - .. raw:: html - -
Sorter Comparison
-
Compare sorter outputs and assess weaknesses
+ :text-align: center Widgets tutorials ----------------- @@ -199,53 +157,37 @@ The :code:`widgets` module contains several plotting routines (widgets) for visu .. grid:: 1 2 2 3 :gutter: 2 - .. grid-item-card:: + .. grid-item-card:: RecordingExtractor Widgets :link-type: ref :link: sphx_glr_tutorials_widgets_plot_1_rec_gallery.py :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_1_rec_gallery_thumb.png :img-alt: Recording Widgets :class-card: gallery-card + :text-align: center - .. raw:: html - -
RecordingExtractor Widgets
-
Visualize recordings with widgets
- - .. grid-item-card:: + .. grid-item-card:: SortingExtractor Widgets :link-type: ref :link: sphx_glr_tutorials_widgets_plot_2_sort_gallery.py :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_2_sort_gallery_thumb.png :img-alt: Sorting Widgets :class-card: gallery-card + :text-align: center - .. raw:: html - -
SortingExtractor Widgets
-
Explore sorting data using widgets
- - .. grid-item-card:: + .. grid-item-card:: Waveforms Widgets :link-type: ref :link: sphx_glr_tutorials_widgets_plot_3_waveforms_gallery.py :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_3_waveforms_gallery_thumb.png :img-alt: Waveforms Widgets :class-card: gallery-card + :text-align: center - .. raw:: html - -
Waveforms Widgets
-
Display waveforms using SpikeInterface
- - .. grid-item-card:: + .. grid-item-card:: Peaks Widgets :link-type: ref :link: sphx_glr_tutorials_widgets_plot_4_peaks_gallery.py :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_4_peaks_gallery_thumb.png :img-alt: Peaks Widgets :class-card: gallery-card - - .. raw:: html - -
Peaks Widgets
-
Visualize detected peaks
+ :text-align: center Download All Examples --------------------- From 507b6b3cf19d0f10069e2415f134dca7fb709b47 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sun, 3 Nov 2024 19:16:26 -0500 Subject: [PATCH 04/32] Address time bin issue arising in LFP-based reg, which AP-based reg doesn't trigger --- .../motion/motion_interpolation.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index a5e6ded519..975f43919d 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -3,7 +3,8 @@ import numpy as np from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.preprocessing import get_spatial_interpolation_kernel -from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment +from spikeinterface.preprocessing.basepreprocessor import ( + BasePreprocessor, BasePreprocessorSegment) from spikeinterface.preprocessing.filter import fix_dtype @@ -122,14 +123,18 @@ def interpolate_motion_on_traces( time_bins = interpolation_time_bin_centers_s if time_bins is None: time_bins = motion.temporal_bins_s[segment_index] + + # nearest interpolation bin: + # seachsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] + # hence the -1. doing it with "left" is not as nice. + # time_bins are bin centers, so subtract half the bin length. this leads + # to snapping to the nearest bin center. bin_s = time_bins[1] - time_bins[0] - bins_start = time_bins[0] - 0.5 * bin_s - # nearest bin center for each frame? - bin_inds = (times - bins_start) // bin_s - bin_inds = bin_inds.astype(int) + bin_inds = np.searchsorted(time_bins - bin_s / 2, times, side="right") - 1 + # the time bins may not cover the whole set of times in the recording, # so we need to clip these indices to the valid range - np.clip(bin_inds, 0, time_bins.size, out=bin_inds) + np.clip(bin_inds, 0, time_bins.size - 1, out=bin_inds) # -- what are the possibilities here anyway? bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1) From 4e38ac18be65051d30d15f3d25bada943af3e31f Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 4 Nov 2024 11:06:46 -0500 Subject: [PATCH 05/32] Fix LFP-based AP interp bug and allow time_vector in interpolation --- .../motion/motion_interpolation.py | 18 +++-- .../motion/tests/test_motion_interpolation.py | 78 ++++++++++++++++--- 2 files changed, 77 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 975f43919d..4fd42a8b39 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -126,11 +126,16 @@ def interpolate_motion_on_traces( # nearest interpolation bin: # seachsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] - # hence the -1. doing it with "left" is not as nice. - # time_bins are bin centers, so subtract half the bin length. this leads - # to snapping to the nearest bin center. - bin_s = time_bins[1] - time_bins[0] - bin_inds = np.searchsorted(time_bins - bin_s / 2, times, side="right") - 1 + # hence the -1. doing it with "left" is not as nice -- we want t==b[0] + # to lead to i=1 (rounding down). + # time_bins are bin centers, but we want to snap to the nearest center. + # idea is to get the left bin edges and bin the interp times. + # this is like subtracting bin_dt_s/2, but allows non-equally-spaced bins. + bin_left = np.zeros_like(time_bins) + # it's fine to use the first bin center for the first left edge + bin_left[0] = time_bins[0] + bin_left[1:] = 0.5 * (time_bins[1:] + time_bins[:-1]) + bin_inds = np.searchsorted(bin_left, times, side="right") - 1 # the time bins may not cover the whole set of times in the recording, # so we need to clip these indices to the valid range @@ -438,9 +443,6 @@ def __init__( self.motion = motion def get_traces(self, start_frame, end_frame, channel_indices): - if self.time_vector is not None: - raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.") - if start_frame is None: start_frame = 0 if end_frame is None: diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index e022f0cc6c..69f681a1be 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -1,16 +1,11 @@ -from pathlib import Path +import warnings import numpy as np -import pytest import spikeinterface.core as sc -from spikeinterface import download_dataset -from spikeinterface.sortingcomponents.motion.motion_interpolation import ( - InterpolateMotionRecording, - correct_motion_on_peaks, - interpolate_motion, - interpolate_motion_on_traces, -) from spikeinterface.sortingcomponents.motion import Motion +from spikeinterface.sortingcomponents.motion.motion_interpolation import ( + InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, + interpolate_motion_on_traces) from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -115,6 +110,66 @@ def test_interpolation_simple(): assert np.all(traces_corrected[:, 2:] == 0) +def test_cross_band_interpolation(): + """Simple version of using LFP to interpolate AP data + + This also tests the time vector implementation in interpolation. + The idea is to have two recordings which are all 0s with a 1 that + moves from one channel to another after 3s. They're at different + sampling frequencies. motion estimation in one sampling frequency + applied to the other should still lead to perfect correction. + """ + from spikeinterface.sortingcomponents.motion import estimate_motion + + # sampling freqs and timing for AP and LFP recordings + fs_lfp = 50.0 + fs_ap = 300.0 + t_start = 10.0 + total_duration = 5.0 + nt_lfp = int(fs_lfp * total_duration) + nt_ap = int(fs_ap * total_duration) + t_switch = 3 + + # because interpolation uses bin centers logic, there will be a half + # bin offset at the change point in the AP recording. + halfbin_ap_lfp = int(0.5 * (fs_ap / fs_lfp)) + + # channel geometry + nc = 10 + geom = np.c_[np.zeros(nc), np.arange(nc)] + + # make an LFP recording which drifts a bit + traces_lfp = np.zeros((nt_lfp, nc)) + traces_lfp[: int(t_switch * fs_lfp), 5] = 1.0 + traces_lfp[int(t_switch * fs_lfp) :, 6] = 1.0 + rec_lfp = sc.NumpyRecording(traces_lfp, sampling_frequency=fs_lfp) + rec_lfp.set_dummy_probe_from_locations(geom) + + # same for AP + traces_ap = np.zeros((nt_ap, nc)) + traces_ap[: int(t_switch * fs_ap) - halfbin_ap_lfp, 5] = 1.0 + traces_ap[int(t_switch * fs_ap) - halfbin_ap_lfp :, 6] = 1.0 + rec_ap = sc.NumpyRecording(traces_ap, sampling_frequency=fs_ap) + rec_ap.set_dummy_probe_from_locations(geom) + + # set times for both, and silence the warning + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + rec_lfp.set_times(t_start + np.arange(nt_lfp) / fs_lfp) + rec_ap.set_times(t_start + np.arange(nt_ap) / fs_ap) + + # estimate motion + motion = estimate_motion(rec_lfp, method="dredge_lfp", rigid=True) + + # nearest to keep it simple + rec_corrected = interpolate_motion(rec_ap, motion, spatial_interpolation_method="nearest", num_closest=2) + traces_corrected = rec_corrected.get_traces() + target = np.zeros((nt_ap, nc - 2)) + target[:, 4] = 1 + ii, jj = np.nonzero(traces_corrected) + assert np.array_equal(traces_corrected, target) + + def test_InterpolateMotionRecording(): rec, sorting = make_dataset() motion = make_fake_motion(rec) @@ -148,5 +203,6 @@ def test_InterpolateMotionRecording(): if __name__ == "__main__": # test_correct_motion_on_peaks() # test_interpolate_motion_on_traces() - test_interpolation_simple() - test_InterpolateMotionRecording() + # test_interpolation_simple() + # test_InterpolateMotionRecording() + test_cross_band_interpolation() From 726170b1526b954b5a26edd70d3162e476ed9f53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:28:44 +0000 Subject: [PATCH 06/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/motion/motion_interpolation.py | 3 +-- .../motion/tests/test_motion_interpolation.py | 7 +++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 4fd42a8b39..810264d9e4 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -3,8 +3,7 @@ import numpy as np from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.preprocessing import get_spatial_interpolation_kernel -from spikeinterface.preprocessing.basepreprocessor import ( - BasePreprocessor, BasePreprocessorSegment) +from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment from spikeinterface.preprocessing.filter import fix_dtype diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index 69f681a1be..88af619220 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -4,8 +4,11 @@ import spikeinterface.core as sc from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.motion.motion_interpolation import ( - InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, - interpolate_motion_on_traces) + InterpolateMotionRecording, + correct_motion_on_peaks, + interpolate_motion, + interpolate_motion_on_traces, +) from spikeinterface.sortingcomponents.tests.common import make_dataset From d6b4c1e7474c372c6d9f71787ddbe707854bd11f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 7 Nov 2024 11:44:13 +0100 Subject: [PATCH 07/32] Fix cbin_file_path --- src/spikeinterface/extractors/cbin_ibl.py | 30 +++++++++++++++-------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index d7e5b58e11..88e1029ab0 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +import warnings import numpy as np import probeinterface @@ -30,8 +31,10 @@ class CompressedBinaryIblExtractor(BaseRecording): stream_name : {"ap", "lp"}, default: "ap". Whether to load AP or LFP band, one of "ap" or "lp". - cbin_file : str or None, default None + cbin_file_path : str or None, default None The cbin file of the recording. If None, searches in `folder_path` for file. + cbin_file : str or None, default None + (deprecated) The cbin file of the recording. If None, searches in `folder_path` for file. Returns ------- @@ -41,14 +44,21 @@ class CompressedBinaryIblExtractor(BaseRecording): installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n" - def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file=None): + def __init__( + self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None, cbin_file=None + ): from neo.rawio.spikeglxrawio import read_meta_file try: import mtscomp except ImportError: raise ImportError(self.installation_mesg) - if cbin_file is None: + if cbin_file is not None: + warnings.warn( + "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", DeprecationWarning + ) + cbin_file_path = cbin_file + if cbin_file_path is None: folder_path = Path(folder_path) # check bands assert stream_name in ["ap", "lp"], "stream_name must be one of: 'ap', 'lp'" @@ -60,17 +70,17 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", assert ( len(curr_cbin_files) == 1 ), f"There should only be one `*.cbin` file in the folder, but {print(curr_cbin_files)} have been found" - cbin_file = curr_cbin_files[0] + cbin_file_path = curr_cbin_files[0] else: - cbin_file = Path(cbin_file) - folder_path = cbin_file.parent + cbin_file_path = Path(cbin_file_path) + folder_path = cbin_file_path.parent - ch_file = cbin_file.with_suffix(".ch") - meta_file = cbin_file.with_suffix(".meta") + ch_file = cbin_file_path.with_suffix(".ch") + meta_file = cbin_file_path.with_suffix(".meta") # reader cbuffer = mtscomp.Reader() - cbuffer.open(cbin_file, ch_file) + cbuffer.open(cbin_file_path, ch_file) # meta data meta = read_meta_file(meta_file) @@ -119,7 +129,7 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", self._kwargs = { "folder_path": str(Path(folder_path).resolve()), "load_sync_channel": load_sync_channel, - "cbin_file": str(Path(cbin_file).resolve()), + "cbin_file_path": str(Path(cbin_file_path).resolve()), } From e6f45056852e181fb8d6909c8a3365a08cb2c8f5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 7 Nov 2024 15:56:48 +0100 Subject: [PATCH 08/32] Update src/spikeinterface/extractors/cbin_ibl.py --- src/spikeinterface/extractors/cbin_ibl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 88e1029ab0..357afde04e 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -55,7 +55,7 @@ def __init__( raise ImportError(self.installation_mesg) if cbin_file is not None: warnings.warn( - "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", DeprecationWarning + "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", DeprecationWarning, stacklevel=2 ) cbin_file_path = cbin_file if cbin_file_path is None: From 471ce724faac7245766538880b7fcd196f49fa30 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:57:14 +0000 Subject: [PATCH 09/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/extractors/cbin_ibl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 357afde04e..8fe19f3d7e 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -55,7 +55,9 @@ def __init__( raise ImportError(self.installation_mesg) if cbin_file is not None: warnings.warn( - "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", DeprecationWarning, stacklevel=2 + "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", + DeprecationWarning, + stacklevel=2, ) cbin_file_path = cbin_file if cbin_file_path is None: From e791fe18671c2998fde9d44295c54a5781ca2e46 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 11 Nov 2024 10:43:48 -0500 Subject: [PATCH 10/32] Cache bin edges in 2 places as discussed with Sam --- .../motion/motion_interpolation.py | 39 +++++++++++++------ .../sortingcomponents/motion/motion_utils.py | 24 +++++++++++- .../motion/tests/test_motion_interpolation.py | 28 ++++++------- 3 files changed, 66 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 4fd42a8b39..89696f5041 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -7,6 +7,8 @@ BasePreprocessor, BasePreprocessorSegment) from spikeinterface.preprocessing.filter import fix_dtype +from .motion_utils import ensure_time_bin_edges, ensure_time_bins + def correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndarray: """ @@ -55,6 +57,7 @@ def interpolate_motion_on_traces( segment_index=None, channel_inds=None, interpolation_time_bin_centers_s=None, + interpolation_time_bin_edges_s=None, spatial_interpolation_method="kriging", spatial_interpolation_kwargs={}, dtype=None, @@ -120,9 +123,11 @@ def interpolate_motion_on_traces( total_num_chans = channel_locations.shape[0] # -- determine the blocks of frames that will land in the same interpolation time bin - time_bins = interpolation_time_bin_centers_s - if time_bins is None: - time_bins = motion.temporal_bins_s[segment_index] + if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: + bin_centers_s = motion.temporal_bin_edges_s[segment_index] + bin_edges_s = motion.temporal_bin_edges_s[segment_index] + else: + bin_centers_s, bin_edges_s = ensure_time_bins(interpolation_time_bin_centers_s, interpolation_time_bin_edges_s) # nearest interpolation bin: # seachsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] @@ -131,15 +136,13 @@ def interpolate_motion_on_traces( # time_bins are bin centers, but we want to snap to the nearest center. # idea is to get the left bin edges and bin the interp times. # this is like subtracting bin_dt_s/2, but allows non-equally-spaced bins. - bin_left = np.zeros_like(time_bins) # it's fine to use the first bin center for the first left edge - bin_left[0] = time_bins[0] - bin_left[1:] = 0.5 * (time_bins[1:] + time_bins[:-1]) - bin_inds = np.searchsorted(bin_left, times, side="right") - 1 + bin_inds = np.searchsorted(bin_edges_s, times, side="right") - 1 # the time bins may not cover the whole set of times in the recording, # so we need to clip these indices to the valid range - np.clip(bin_inds, 0, time_bins.size - 1, out=bin_inds) + n_bins = bin_edges_s.shape[0] - 1 + np.clip(bin_inds, 0, n_bins - 1, out=bin_inds) # -- what are the possibilities here anyway? bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1) @@ -148,7 +151,7 @@ def interpolate_motion_on_traces( interp_times = np.empty(total_num_chans) current_start_index = 0 for bin_ind in bins_here: - bin_time = time_bins[bin_ind] + bin_time = bin_centers_s[bin_ind] interp_times.fill(bin_time) channel_motions = motion.get_displacement_at_time_and_depth( interp_times, @@ -307,6 +310,7 @@ def __init__( p=1, num_closest=3, interpolation_time_bin_centers_s=None, + interpolation_time_bin_edges_s=None, interpolation_time_bin_size_s=None, dtype=None, **spatial_interpolation_kwargs, @@ -373,9 +377,14 @@ def __init__( # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below - if interpolation_time_bin_centers_s is None: + if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: if interpolation_time_bin_size_s is None: interpolation_time_bin_centers_s = motion.temporal_bins_s + interpolation_time_bin_edges_s = motion.temporal_bin_edges_s + else: + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins( + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s + ) for segment_index, parent_segment in enumerate(recording._recording_segments): # finish the per-segment part of the time bin logic @@ -385,8 +394,13 @@ def __init__( t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) halfbin = interpolation_time_bin_size_s / 2.0 segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s) + segment_interpolation_time_bin_edges_s = np.arange( + t_start, t_end + halfbin, interpolation_time_bin_size_s + ) + assert segment_interpolation_time_bin_edges_s.shape == (segment_interpolation_time_bins_s.shape[0] + 1,) else: segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index] + segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s[segment_index] rec_segment = InterpolateMotionRecordingSegment( parent_segment, @@ -397,6 +411,7 @@ def __init__( channel_inds, segment_index, segment_interpolation_time_bins_s, + segment_interpolation_time_bin_edges_s, dtype=dtype_, ) self.add_recording_segment(rec_segment) @@ -430,6 +445,7 @@ def __init__( channel_inds, segment_index, interpolation_time_bin_centers_s, + interpolation_time_bin_edges_s, dtype="float32", ): BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -439,6 +455,7 @@ def __init__( self.channel_inds = channel_inds self.segment_index = segment_index self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s + self.interpolation_time_bin_edges_s = interpolation_time_bin_edges_s self.dtype = dtype self.motion = motion @@ -460,7 +477,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): channel_inds=self.channel_inds, spatial_interpolation_method=self.spatial_interpolation_method, spatial_interpolation_kwargs=self.spatial_interpolation_kwargs, - interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s, + interpolation_time_bin_edges_s=self.interpolation_time_bin_edges_s, ) if channel_indices is not None: diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 635624cca8..ec0a55a8f8 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -1,5 +1,5 @@ -import warnings import json +import warnings from pathlib import Path import numpy as np @@ -54,6 +54,7 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y" self.direction = direction self.dim = ["x", "y", "z"].index(direction) self.check_properties() + self.temporal_bin_edges_s = [ensure_time_bin_edges(tbins) for tbins in self.temporal_bins_s] def check_properties(self): assert all(d.ndim == 2 for d in self.displacement) @@ -576,3 +577,24 @@ def make_3d_motion_histograms( motion_histograms = np.log2(1 + motion_histograms) return motion_histograms, temporal_bin_edges, spatial_bin_edges + + +def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None): + if time_bin_centers_s is None and time_bin_edges_s is None: + raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.") + + if time_bin_centers_s is None: + assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2 + time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1]) + + if time_bin_edges_s is None: + time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype) + time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]] + if time_bin_centers_s.size > 2: + time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1]) + + return time_bin_centers_s, time_bin_edges_s + + +def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None): + return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1] diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index 69f681a1be..07cb5b8ab6 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -62,18 +62,20 @@ def test_interpolate_motion_on_traces(): times = rec.get_times()[0:30000] for method in ("kriging", "idw", "nearest"): - traces_corrected = interpolate_motion_on_traces( - traces, - times, - channel_locations, - motion, - channel_inds=None, - spatial_interpolation_method=method, - # spatial_interpolation_kwargs={}, - spatial_interpolation_kwargs={"force_extrapolate": True}, - ) - assert traces.shape == traces_corrected.shape - assert traces.dtype == traces_corrected.dtype + for interpolation_time_bin_centers_s in (None, np.linspace(*times[[0, -1]], num=3)): + traces_corrected = interpolate_motion_on_traces( + traces, + times, + channel_locations, + motion, + channel_inds=None, + spatial_interpolation_method=method, + interpolation_time_bin_centers_s=interpolation_time_bin_centers_s, + # spatial_interpolation_kwargs={}, + spatial_interpolation_kwargs={"force_extrapolate": True}, + ) + assert traces.shape == traces_corrected.shape + assert traces.dtype == traces_corrected.dtype def test_interpolation_simple(): @@ -202,7 +204,7 @@ def test_InterpolateMotionRecording(): if __name__ == "__main__": # test_correct_motion_on_peaks() - # test_interpolate_motion_on_traces() + test_interpolate_motion_on_traces() # test_interpolation_simple() # test_InterpolateMotionRecording() test_cross_band_interpolation() From d8f39b5a70dd83f4e1fff71d41036692fba20b38 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 11 Nov 2024 12:30:32 -0500 Subject: [PATCH 11/32] Sorry if this is shoe-horning in a change... --- src/spikeinterface/core/baserecording.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 5e2e9e4014..b95bfb1ad0 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -1,20 +1,17 @@ from __future__ import annotations + import warnings from pathlib import Path import numpy as np -from probeinterface import Probe, ProbeGroup, read_probeinterface, select_axes, write_probeinterface +from probeinterface import (Probe, ProbeGroup, read_probeinterface, + select_axes, write_probeinterface) from .base import BaseSegment from .baserecordingsnippets import BaseRecordingSnippets -from .core_tools import ( - convert_bytes_to_str, - convert_seconds_to_str, -) -from .recording_tools import write_binary_recording - - +from .core_tools import convert_bytes_to_str, convert_seconds_to_str from .job_tools import split_job_kwargs +from .recording_tools import write_binary_recording class BaseRecording(BaseRecordingSnippets): @@ -921,11 +918,11 @@ def time_to_sample_index(self, time_s): sample_index = time_s * self.sampling_frequency else: sample_index = (time_s - self.t_start) * self.sampling_frequency - sample_index = round(sample_index) + sample_index = np.round(sample_index).astype(int) else: sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1 - return int(sample_index) + return sample_index def get_num_samples(self) -> int: """Returns the number of samples in this signal segment From 0a201e17a0b3de283f06c5456010fb20fd8cd209 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Nov 2024 17:31:00 +0000 Subject: [PATCH 12/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/baserecording.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index b95bfb1ad0..6d4509db12 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -4,8 +4,7 @@ from pathlib import Path import numpy as np -from probeinterface import (Probe, ProbeGroup, read_probeinterface, - select_axes, write_probeinterface) +from probeinterface import Probe, ProbeGroup, read_probeinterface, select_axes, write_probeinterface from .base import BaseSegment from .baserecordingsnippets import BaseRecordingSnippets From 620f8013b8bf4f1332a7802dd3f6804ce068493c Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 13:37:50 +0000 Subject: [PATCH 13/32] Apply to all segments if 'segment_index' is 'None'. --- src/spikeinterface/core/baserecording.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index b8a0420794..7392caa69b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -521,13 +521,20 @@ def shift_start_time(self, shift, segment_index=None): segment_index : int | None The segment on which to shift the times. """ - segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + self._check_segment_index(segment_index) - if self.has_time_vector(): - rs.time_vector += shift + if segment_index is None: + segments_to_shift = range(self.get_num_segments()) else: - rs.t_start += shift + segments_to_shift = (segment_index,) + + for idx in segments_to_shift: + rs = self._recording_segments[idx] + + if self.has_time_vector(): + rs.time_vector += shift + else: + rs.t_start += shift def sample_index_to_time(self, sample_ind, segment_index=None): """ From 22d5dfc2a552e00d7b55d7c28681e25a1f51a711 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 13:39:34 +0000 Subject: [PATCH 14/32] Add type hints. --- src/spikeinterface/core/baserecording.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 7392caa69b..0af9c4bb6a 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -509,7 +509,7 @@ def reset_times(self): rs.t_start = None rs.sampling_frequency = self.sampling_frequency - def shift_start_time(self, shift, segment_index=None): + def shift_start_time(self, shift: int | float, segment_index: int | None = None) -> None: """ Shift the starting time of the times. @@ -536,15 +536,14 @@ def shift_start_time(self, shift, segment_index=None): else: rs.t_start += shift - def sample_index_to_time(self, sample_ind, segment_index=None): - """ - Transform sample index into time in seconds - """ + def sample_index_to_time(self, sample_ind: int, segment_index: int | None = None): + """ """ segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] return rs.sample_index_to_time(sample_ind) - def time_to_sample_index(self, time_s, segment_index=None): + def time_to_sample_index(self, time_s: float, segment_index: int | None = None): + """ """ segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] return rs.time_to_sample_index(time_s) From 458a3dcc201380740583ef1f075951e83ee77ed8 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 13:43:45 +0000 Subject: [PATCH 15/32] Update name and docstring. --- src/spikeinterface/core/baserecording.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 0af9c4bb6a..91f99f17b0 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -509,19 +509,24 @@ def reset_times(self): rs.t_start = None rs.sampling_frequency = self.sampling_frequency - def shift_start_time(self, shift: int | float, segment_index: int | None = None) -> None: + def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: """ - Shift the starting time of the times. + Shift all times by a scalar value. The default behaviour is to + shift all segments uniformly. + Parameters + ---------- shift : int | float - The shift to apply to the first time point. If positive, - the current start time will be increased by `shift`. If - negative, the start time will be decreased. + The shift to apply. If positive, times will be increased by `shift`. + e.g. shifting by 1 will be like the recording started 1 second later. + If negative, the start time will be decreased i.e. as if the recording + started earlier. segment_index : int | None - The segment on which to shift the times. + The segment on which to shift the times. if `None`, all + segments will be shifted. """ - self._check_segment_index(segment_index) + self._check_segment_index(segment_index) # Check the segment index is valid only if segment_index is None: segments_to_shift = range(self.get_num_segments()) From 8845d3d7eb6caad8c6a5f0c12842f480766d3a26 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 13 Nov 2024 14:40:16 +0000 Subject: [PATCH 16/32] Add verbose kwarg to mda write_recording --- src/spikeinterface/extractors/mdaextractors.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index f055e1d7c9..d2886d9e79 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -72,6 +72,7 @@ def write_recording( params_fname="params.json", geom_fname="geom.csv", dtype=None, + verbose=False, **job_kwargs, ): """Write a recording to file in MDA format. @@ -93,6 +94,8 @@ def write_recording( File name of geom file dtype : dtype or None, default: None Data type to be used. If None dtype is same as recording traces. + verbose : bool + If True, shows progress bar when saving recording. **job_kwargs: Use by job_tools modules to set: @@ -130,6 +133,7 @@ def write_recording( dtype=dtype, byte_offset=header_size, add_file_extension=False, + verbose=verbose, **job_kwargs, ) From 3e98c670a27671590613b7c1c4118780a8c47ce8 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 18:32:48 +0000 Subject: [PATCH 17/32] Add tests. --- .../core/tests/test_time_handling.py | 92 ++++++++++++++++++- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index a129316ee7..9b7ed11bbb 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -15,7 +15,10 @@ class TestTimeHandling: is generated on the fly. Both time representations are tested here. """ - # Fixtures ##### + # ######################################################################### + # Fixtures + # ######################################################################### + @pytest.fixture(scope="session") def time_vector_recording(self): """ @@ -95,7 +98,10 @@ def _get_fixture_data(self, request, fixture_name): raw_recording, times_recording, all_times = time_recording_fixture return (raw_recording, times_recording, all_times) - # Tests ##### + # ######################################################################### + # Tests + # ######################################################################### + def test_has_time_vector(self, time_vector_recording): """ Test the `has_time_vector` function returns `False` before @@ -305,7 +311,87 @@ def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration()) - # Helpers #### + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("shift", [-123.456, 123.456]) + def test_shift_time_all_segments(self, request, fixture_name, shift): + """ + Shift the times in every segment using the `None` default, then + check that every segment of the recording is shifted as expected. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + num_segments, orig_seg_data = self._store_all_times(times_recording) + + times_recording.shift_times(shift) # use default `segment_index=None` + + for idx in range(num_segments): + assert np.allclose( + orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift, rtol=0, atol=1e-8 + ) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("shift", [-123.456, 123.456]) + def test_shift_times_different_segments(self, request, fixture_name, shift): + """ + Shift each segment separately, and check the shifted segment only + is shifted as expected. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + num_segments, orig_seg_data = self._store_all_times(times_recording) + + # For each segment, shift the segment only and check the + # times are updated as expected. + for idx in range(num_segments): + + scaler = idx + 2 + times_recording.shift_times(shift * scaler, segment_index=idx) + + assert np.allclose( + orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift * scaler, rtol=0, atol=1e-8 + ) + + # Just do a little check that we are not + # accidentally changing some other segments, + # which should remain unchanged at this point in the loop. + if idx != num_segments - 1: + assert np.array_equal(orig_seg_data[idx + 1], times_recording.get_times(segment_index=idx + 1)) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_save_and_load_time_shift(self, request, fixture_name, tmp_path): + """ + Save the shifted data and check the shift is propagated correctly. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + shift = 100 + times_recording.shift_times(shift=shift) + + times_recording.save(folder=tmp_path / "my_file") + + loaded_recording = si.load_extractor(tmp_path / "my_file") + + for idx in range(times_recording.get_num_segments()): + assert np.array_equal( + times_recording.get_times(segment_index=idx), loaded_recording.get_times(segment_index=idx) + ) + + def _store_all_times(self, recording): + """ + Convenience function to store original times of all segments to a dict. + """ + num_segments = recording.get_num_segments() + seg_data = {} + + for idx in range(num_segments): + seg_data[idx] = copy.deepcopy(recording.get_times(segment_index=idx)) + + return num_segments, seg_data + + # ######################################################################### + # Helpers + # ######################################################################### + def _check_times_match(self, recording, all_times): """ For every segment in a recording, check the `get_times()` From 4d7246a529e3d17747cf5a496a0a04bd97f4eb09 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 18:33:17 +0000 Subject: [PATCH 18/32] Fixes on shift function. --- src/spikeinterface/core/baserecording.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 91f99f17b0..4b545dc7c7 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -526,8 +526,6 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N The segment on which to shift the times. if `None`, all segments will be shifted. """ - self._check_segment_index(segment_index) # Check the segment index is valid only - if segment_index is None: segments_to_shift = range(self.get_num_segments()) else: @@ -536,7 +534,7 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N for idx in segments_to_shift: rs = self._recording_segments[idx] - if self.has_time_vector(): + if self.has_time_vector(segment_index=idx): rs.time_vector += shift else: rs.t_start += shift From a1cf3367d18a549281208b25c622f2a1ee773226 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 18:35:32 +0000 Subject: [PATCH 19/32] Undo out of scope changes. --- src/spikeinterface/core/baserecording.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 4b545dc7c7..886f7db79f 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -539,14 +539,15 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N else: rs.t_start += shift - def sample_index_to_time(self, sample_ind: int, segment_index: int | None = None): - """ """ + def sample_index_to_time(self, sample_ind, segment_index=None): + """ + Transform sample index into time in seconds + """ segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] return rs.sample_index_to_time(sample_ind) - def time_to_sample_index(self, time_s: float, segment_index: int | None = None): - """ """ + def time_to_sample_index(self, time_s, segment_index=None): segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] return rs.time_to_sample_index(time_s) From 469b3b0e36fdbc0571d37e100d99d6c741af1377 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Nov 2024 18:37:20 +0000 Subject: [PATCH 20/32] Fix docstring. --- src/spikeinterface/core/baserecording.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 886f7db79f..6d9d2a827f 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -511,8 +511,7 @@ def reset_times(self): def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: """ - Shift all times by a scalar value. The default behaviour is to - shift all segments uniformly. + Shift all times by a scalar value. Parameters ---------- @@ -523,8 +522,8 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N started earlier. segment_index : int | None - The segment on which to shift the times. if `None`, all - segments will be shifted. + The segment on which to shift the times. + If `None`, all segments will be shifted. """ if segment_index is None: segments_to_shift = range(self.get_num_segments()) From 1e53a5e06b2a90956d72150826a8f590d673b5ce Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 19 Nov 2024 17:45:40 +0100 Subject: [PATCH 21/32] Update src/spikeinterface/extractors/cbin_ibl.py Co-authored-by: Heberto Mayorquin --- src/spikeinterface/extractors/cbin_ibl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 8fe19f3d7e..728d352973 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -31,7 +31,7 @@ class CompressedBinaryIblExtractor(BaseRecording): stream_name : {"ap", "lp"}, default: "ap". Whether to load AP or LFP band, one of "ap" or "lp". - cbin_file_path : str or None, default None + cbin_file_path : str, Path or None, default None The cbin file of the recording. If None, searches in `folder_path` for file. cbin_file : str or None, default None (deprecated) The cbin file of the recording. If None, searches in `folder_path` for file. From ad00beb182967ebe68f59ebfd7f1abad3002a10e Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 21 Nov 2024 20:44:20 +0000 Subject: [PATCH 22/32] Update `interpolate_motion_on_traces` docstring --- .../sortingcomponents/motion/motion_interpolation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index d207dced08..7c6f0ba71a 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -83,6 +83,9 @@ def interpolate_motion_on_traces( interpolation_time_bin_centers_s : None or np.array Manually specify the time bins which the interpolation happens in for this segment. If None, these are the motion estimate's time bins. + interpolation_time_bin_edges_s : None or np.array + If present, interpolation chunks will be the time bins defined by these edges + rather than interpolation_time_bin_centers_s or the motion's bins. spatial_interpolation_method : "idw" | "kriging", default: "kriging" The spatial interpolation method used to interpolate the channel locations: * idw : Inverse Distance Weighing From 28527d2a8bcd03a3c6d9036ac9be180ad4dc6334 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 21 Nov 2024 20:46:30 +0000 Subject: [PATCH 23/32] Should be centers! --- .../sortingcomponents/motion/motion_interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 7c6f0ba71a..f0fff5c039 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -126,7 +126,7 @@ def interpolate_motion_on_traces( # -- determine the blocks of frames that will land in the same interpolation time bin if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: - bin_centers_s = motion.temporal_bin_edges_s[segment_index] + bin_centers_s = motion.temporal_bin_centers_s[segment_index] bin_edges_s = motion.temporal_bin_edges_s[segment_index] else: bin_centers_s, bin_edges_s = ensure_time_bins(interpolation_time_bin_centers_s, interpolation_time_bin_edges_s) From b3b3fcf5be7b67451f54cac753844cbb06b8d4a3 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 21 Nov 2024 20:47:34 +0000 Subject: [PATCH 24/32] Typo --- .../sortingcomponents/motion/motion_interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index f0fff5c039..9e32e189d9 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -132,7 +132,7 @@ def interpolate_motion_on_traces( bin_centers_s, bin_edges_s = ensure_time_bins(interpolation_time_bin_centers_s, interpolation_time_bin_edges_s) # nearest interpolation bin: - # seachsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] + # searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] # hence the -1. doing it with "left" is not as nice -- we want t==b[0] # to lead to i=1 (rounding down). # time_bins are bin centers, but we want to snap to the nearest center. From b02860e463262a3b9522ebf9badc45c860cf8d29 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 21 Nov 2024 20:51:11 +0000 Subject: [PATCH 25/32] Clarify comments --- .../sortingcomponents/motion/motion_interpolation.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 9e32e189d9..8f96579228 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -131,14 +131,10 @@ def interpolate_motion_on_traces( else: bin_centers_s, bin_edges_s = ensure_time_bins(interpolation_time_bin_centers_s, interpolation_time_bin_edges_s) - # nearest interpolation bin: + # bin the frame times according to the interpolation time bins. # searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] # hence the -1. doing it with "left" is not as nice -- we want t==b[0] # to lead to i=1 (rounding down). - # time_bins are bin centers, but we want to snap to the nearest center. - # idea is to get the left bin edges and bin the interp times. - # this is like subtracting bin_dt_s/2, but allows non-equally-spaced bins. - # it's fine to use the first bin center for the first left edge bin_inds = np.searchsorted(bin_edges_s, times, side="right") - 1 # the time bins may not cover the whole set of times in the recording, From df2484002b562861f0f64bc053b5da619253e983 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 21 Nov 2024 20:57:08 +0000 Subject: [PATCH 26/32] Rename variables for clarity --- .../motion/motion_interpolation.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 8f96579228..2bd3493650 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -126,30 +126,30 @@ def interpolate_motion_on_traces( # -- determine the blocks of frames that will land in the same interpolation time bin if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: - bin_centers_s = motion.temporal_bin_centers_s[segment_index] - bin_edges_s = motion.temporal_bin_edges_s[segment_index] + interpolation_time_bin_centers_s = motion.temporal_bin_centers_s[segment_index] + interpolation_time_bin_edges_s = motion.temporal_bin_edges_s[segment_index] else: - bin_centers_s, bin_edges_s = ensure_time_bins(interpolation_time_bin_centers_s, interpolation_time_bin_edges_s) + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins(interpolation_time_bin_centers_s, interpolation_time_bin_edges_s) # bin the frame times according to the interpolation time bins. # searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] # hence the -1. doing it with "left" is not as nice -- we want t==b[0] # to lead to i=1 (rounding down). - bin_inds = np.searchsorted(bin_edges_s, times, side="right") - 1 + interpolation_bin_inds = np.searchsorted(interpolation_time_bin_edges_s, times, side="right") - 1 # the time bins may not cover the whole set of times in the recording, # so we need to clip these indices to the valid range - n_bins = bin_edges_s.shape[0] - 1 - np.clip(bin_inds, 0, n_bins - 1, out=bin_inds) + n_bins = interpolation_time_bin_edges_s.shape[0] - 1 + np.clip(interpolation_bin_inds, 0, n_bins - 1, out=interpolation_bin_inds) # -- what are the possibilities here anyway? - bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1) + interpolation_bins_here = np.arange(interpolation_bin_inds[0], interpolation_bin_inds[-1] + 1) # inperpolation kernel will be the same per temporal bin interp_times = np.empty(total_num_chans) current_start_index = 0 - for bin_ind in bins_here: - bin_time = bin_centers_s[bin_ind] + for interp_bin_ind in interpolation_bins_here: + bin_time = bin_centers_s[interp_bin_ind] interp_times.fill(bin_time) channel_motions = motion.get_displacement_at_time_and_depth( interp_times, @@ -177,16 +177,17 @@ def interpolate_motion_on_traces( # ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}") # plt.show() + # quick search logic to find frames corresponding to this interpolation bin in the recording # quickly find the end of this bin, which is also the start of the next next_start_index = current_start_index + np.searchsorted( - bin_inds[current_start_index:], bin_ind + 1, side="left" + interpolation_bin_inds[current_start_index:], interp_bin_ind + 1, side="left" ) - in_bin = slice(current_start_index, next_start_index) + frames_in_bin = slice(current_start_index, next_start_index) # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing # in ChunkRecordingExecutor) - np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin]) + np.matmul(traces[frames_in_bin], drift_kernel, out=traces_corrected[frames_in_bin]) current_start_index = next_start_index return traces_corrected From 91fb7320eafbcdaa8fe728e36cbc7fa686f32ba8 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 21 Nov 2024 20:59:13 +0000 Subject: [PATCH 27/32] Note on clipping behavior --- .../sortingcomponents/motion/motion_interpolation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 2bd3493650..8aed8085bf 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -64,7 +64,11 @@ def interpolate_motion_on_traces( """ Apply inverse motion with spatial interpolation on traces. - Traces can be full traces, but also waveforms snippets. + Traces can be full traces, but also waveforms snippets. Times used for looking up + displacements are controlled by interpolation_time_bin_edges_s or + interpolation_time_bin_centers_s, or fall back to the Motion object's time bins + by default; times in the recording outside these time bins use the closest edge + bin's displacement value during interpolation. Parameters ---------- From b80bad71e2c1ee316048beb79a9169dd8f68c6ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Nov 2024 21:01:47 +0000 Subject: [PATCH 28/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/motion/motion_interpolation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 8aed8085bf..14471f77fc 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -65,7 +65,7 @@ def interpolate_motion_on_traces( Apply inverse motion with spatial interpolation on traces. Traces can be full traces, but also waveforms snippets. Times used for looking up - displacements are controlled by interpolation_time_bin_edges_s or + displacements are controlled by interpolation_time_bin_edges_s or interpolation_time_bin_centers_s, or fall back to the Motion object's time bins by default; times in the recording outside these time bins use the closest edge bin's displacement value during interpolation. @@ -133,7 +133,9 @@ def interpolate_motion_on_traces( interpolation_time_bin_centers_s = motion.temporal_bin_centers_s[segment_index] interpolation_time_bin_edges_s = motion.temporal_bin_edges_s[segment_index] else: - interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins(interpolation_time_bin_centers_s, interpolation_time_bin_edges_s) + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins( + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s + ) # bin the frame times according to the interpolation time bins. # searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] From c89060314e233b89e4e9112e74c7643545806d22 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 21 Nov 2024 16:09:45 -0500 Subject: [PATCH 29/32] Fix variable typo; add docstring --- .../motion/motion_interpolation.py | 5 +++-- .../sortingcomponents/motion/motion_utils.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 14471f77fc..e87f83751c 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -3,7 +3,8 @@ import numpy as np from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.preprocessing import get_spatial_interpolation_kernel -from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment +from spikeinterface.preprocessing.basepreprocessor import ( + BasePreprocessor, BasePreprocessorSegment) from spikeinterface.preprocessing.filter import fix_dtype from .motion_utils import ensure_time_bin_edges, ensure_time_bins @@ -155,7 +156,7 @@ def interpolate_motion_on_traces( interp_times = np.empty(total_num_chans) current_start_index = 0 for interp_bin_ind in interpolation_bins_here: - bin_time = bin_centers_s[interp_bin_ind] + bin_time = interpolation_time_bin_centers_s[interp_bin_ind] interp_times.fill(bin_time) channel_motions = motion.get_displacement_at_time_and_depth( interp_times, diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index ec0a55a8f8..680d75f221 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -580,6 +580,22 @@ def make_3d_motion_histograms( def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None): + """Ensure that both bin edges and bin centers are present + + If either of the inputs are None but not both, the missing is reconstructed + from the present. Going from edges to centers is done by taking midpoints. + Going from centers to edges is done by taking midpoints and padding with the + left and rightmost centers. + + Parameters + ---------- + time_bin_centers_s : None or np.array + time_bin_edges_s : None or np.array + + Returns + ------- + time_bin_centers_s, time_bin_edges_s + """ if time_bin_centers_s is None and time_bin_edges_s is None: raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.") From 6d2e47911a5e64737f2c98f9cdc199e9f6d306fc Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 21 Nov 2024 16:14:00 -0500 Subject: [PATCH 30/32] Variable names in tests --- .../motion/tests/test_motion_interpolation.py | 45 +++++++++---------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index 8542b62524..c97c8324ba 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -4,11 +4,8 @@ import spikeinterface.core as sc from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.motion.motion_interpolation import ( - InterpolateMotionRecording, - correct_motion_on_peaks, - interpolate_motion, - interpolate_motion_on_traces, -) + InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, + interpolate_motion_on_traces) from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -84,26 +81,26 @@ def test_interpolate_motion_on_traces(): def test_interpolation_simple(): # a recording where a 1 moves at 1 chan per second. 30 chans 10 frames. # there will be 9 chans of drift, so we add 9 chans of padding to the bottom - nt = nc0 = 10 # these need to be the same for this test - nc1 = nc0 + nc0 - 1 - traces = np.zeros((nt, nc1), dtype="float32") - traces[:, :nc0] = np.eye(nc0) + n_samples = num_chans_orig = 10 # these need to be the same for this test + num_chans_drifted = num_chans_orig + num_chans_orig - 1 + traces = np.zeros((n_samples, num_chans_drifted), dtype="float32") + traces[:, :num_chans_orig] = np.eye(num_chans_orig) rec = sc.NumpyRecording(traces, sampling_frequency=1) - rec.set_dummy_probe_from_locations(np.c_[np.zeros(nc1), np.arange(nc1)]) + rec.set_dummy_probe_from_locations(np.c_[np.zeros(num_chans_drifted), np.arange(num_chans_drifted)]) - true_motion = Motion(np.arange(nt)[:, None], 0.5 + np.arange(nt), np.zeros(1)) + true_motion = Motion(np.arange(n_samples)[:, None], 0.5 + np.arange(n_samples), np.zeros(1)) rec_corrected = interpolate_motion(rec, true_motion, spatial_interpolation_method="nearest") traces_corrected = rec_corrected.get_traces() - assert traces_corrected.shape == (nc0, nc0) - assert np.array_equal(traces_corrected[:, 0], np.ones(nt)) - assert np.array_equal(traces_corrected[:, 1:], np.zeros((nt, nc0 - 1))) + assert traces_corrected.shape == (num_chans_orig, num_chans_orig) + assert np.array_equal(traces_corrected[:, 0], np.ones(n_samples)) + assert np.array_equal(traces_corrected[:, 1:], np.zeros((n_samples, num_chans_orig - 1))) # let's try a new version where we interpolate too slowly rec_corrected = interpolate_motion( rec, true_motion, spatial_interpolation_method="nearest", num_closest=2, interpolation_time_bin_size_s=2 ) traces_corrected = rec_corrected.get_traces() - assert traces_corrected.shape == (nc0, nc0) + assert traces_corrected.shape == (num_chans_orig, num_chans_orig) # what happens with nearest here? # well... due to rounding towards the nearest even number, the motion (which at # these time bin centers is 0.5, 2.5, 4.5, ...) flips the signal's nearest @@ -131,8 +128,8 @@ def test_cross_band_interpolation(): fs_ap = 300.0 t_start = 10.0 total_duration = 5.0 - nt_lfp = int(fs_lfp * total_duration) - nt_ap = int(fs_ap * total_duration) + num_samples_lfp = int(fs_lfp * total_duration) + num_samples_ap = int(fs_ap * total_duration) t_switch = 3 # because interpolation uses bin centers logic, there will be a half @@ -140,18 +137,18 @@ def test_cross_band_interpolation(): halfbin_ap_lfp = int(0.5 * (fs_ap / fs_lfp)) # channel geometry - nc = 10 - geom = np.c_[np.zeros(nc), np.arange(nc)] + num_chans = 10 + geom = np.c_[np.zeros(num_chans), np.arange(num_chans)] # make an LFP recording which drifts a bit - traces_lfp = np.zeros((nt_lfp, nc)) + traces_lfp = np.zeros((num_samples_lfp, num_chans)) traces_lfp[: int(t_switch * fs_lfp), 5] = 1.0 traces_lfp[int(t_switch * fs_lfp) :, 6] = 1.0 rec_lfp = sc.NumpyRecording(traces_lfp, sampling_frequency=fs_lfp) rec_lfp.set_dummy_probe_from_locations(geom) # same for AP - traces_ap = np.zeros((nt_ap, nc)) + traces_ap = np.zeros((num_samples_ap, num_chans)) traces_ap[: int(t_switch * fs_ap) - halfbin_ap_lfp, 5] = 1.0 traces_ap[int(t_switch * fs_ap) - halfbin_ap_lfp :, 6] = 1.0 rec_ap = sc.NumpyRecording(traces_ap, sampling_frequency=fs_ap) @@ -160,8 +157,8 @@ def test_cross_band_interpolation(): # set times for both, and silence the warning with warnings.catch_warnings(): warnings.simplefilter("ignore", category=UserWarning) - rec_lfp.set_times(t_start + np.arange(nt_lfp) / fs_lfp) - rec_ap.set_times(t_start + np.arange(nt_ap) / fs_ap) + rec_lfp.set_times(t_start + np.arange(num_samples_lfp) / fs_lfp) + rec_ap.set_times(t_start + np.arange(num_samples_ap) / fs_ap) # estimate motion motion = estimate_motion(rec_lfp, method="dredge_lfp", rigid=True) @@ -169,7 +166,7 @@ def test_cross_band_interpolation(): # nearest to keep it simple rec_corrected = interpolate_motion(rec_ap, motion, spatial_interpolation_method="nearest", num_closest=2) traces_corrected = rec_corrected.get_traces() - target = np.zeros((nt_ap, nc - 2)) + target = np.zeros((num_samples_ap, num_chans - 2)) target[:, 4] = 1 ii, jj = np.nonzero(traces_corrected) assert np.array_equal(traces_corrected, target) From b4c91a0d941e97be68908b11974d18f802ec74a5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Nov 2024 21:15:30 +0000 Subject: [PATCH 31/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/motion/motion_interpolation.py | 3 +-- .../motion/tests/test_motion_interpolation.py | 7 +++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index e87f83751c..b3a4c9a207 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -3,8 +3,7 @@ import numpy as np from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.preprocessing import get_spatial_interpolation_kernel -from spikeinterface.preprocessing.basepreprocessor import ( - BasePreprocessor, BasePreprocessorSegment) +from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment from spikeinterface.preprocessing.filter import fix_dtype from .motion_utils import ensure_time_bin_edges, ensure_time_bins diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index c97c8324ba..e4ba870325 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -4,8 +4,11 @@ import spikeinterface.core as sc from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.motion.motion_interpolation import ( - InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, - interpolate_motion_on_traces) + InterpolateMotionRecording, + correct_motion_on_peaks, + interpolate_motion, + interpolate_motion_on_traces, +) from spikeinterface.sortingcomponents.tests.common import make_dataset From 38e0adac18c2c862f9193e5a5d643b394a6d3d52 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 22 Nov 2024 10:28:34 -0500 Subject: [PATCH 32/32] Typo --- .../sortingcomponents/motion/motion_interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index b3a4c9a207..fc8ccb788b 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -130,7 +130,7 @@ def interpolate_motion_on_traces( # -- determine the blocks of frames that will land in the same interpolation time bin if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: - interpolation_time_bin_centers_s = motion.temporal_bin_centers_s[segment_index] + interpolation_time_bin_centers_s = motion.temporal_bins_s[segment_index] interpolation_time_bin_edges_s = motion.temporal_bin_edges_s[segment_index] else: interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins(