diff --git a/doc/development/development.rst b/doc/development/development.rst index a91818a271..1638c41243 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -213,6 +213,25 @@ We use Sphinx to build the documentation. To build the documentation locally, yo This will build the documentation in the :code:`doc/_build/html` folder. You can open the :code:`index.html` file in your browser to see the documentation. +Adding new documentation +------------------------ + +Documentation can be added as a +`sphinx-gallery `_ +python file ('tutorials') +or a +`sphinx rst `_ +file (all other sections). + +To add a new tutorial, add your ``.py`` file to ``spikeinterface/examples``. +Then, update the ``spikeinterface/doc/tutorials_custom_index.rst`` file +to make a new card linking to the page and an optional image. See +``tutorials_custom_index.rst`` header for more information. + +For other sections, write your documentation in ``.rst`` format and add +the page to the appropriate ``index.rst`` file found in the relevant +folder (e.g. ``how_to/index.rst``). + How to run code coverage locally -------------------------------- To run code coverage locally, you can use the following command: diff --git a/doc/index.rst b/doc/index.rst index ed443e4200..e6d8aa3fea 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -51,7 +51,7 @@ SpikeInterface is made of several modules to deal with different aspects of the overview get_started/index - tutorials/index + tutorials_custom_index how_to/index modules/index api diff --git a/doc/tutorials_custom_index.rst b/doc/tutorials_custom_index.rst new file mode 100644 index 0000000000..4c7625d811 --- /dev/null +++ b/doc/tutorials_custom_index.rst @@ -0,0 +1,196 @@ +.. This page provides a custom index to the 'Tutorials' page, rather than the default sphinx-gallery +.. generated page. The benefits of this are flexibility in design and inclusion of non-sphinx files in the index. +.. +.. To update this index with a new documentation page +.. 1) Copy the grid-item-card and associated ".. raw:: html" section. +.. 2) change :link: to a link to your page. If this is an `.rst` file, point to the rst file directly. +.. If it is a sphinx-gallery generated file, format the path as separated by underscore and prefix `sphx_glr`, +.. pointing to the .py file. e.g. `tutorials/my/page.py` -> `sphx_glr_tutorials_my_page.py +.. 3) Change :img-top: to point to the thumbnail image of your choosing. You can point to images generated +.. in the sphinx gallery page if you wish. +.. 4) In the `html` section, change the `default-title` to your pages title and `hover-content` to the subtitle. + +:orphan: + +Tutorials +============ + +Longer form tutorials about using SpikeInterface. Many of these are downloadable +as notebooks or Python scripts so that you can "code along" with the tutorials. + +If you're new to SpikeInterface, we recommend trying out the +:ref:`get_started/quickstart:Quickstart tutorial` first. + +Updating from legacy +-------------------- + +.. toctree:: + :maxdepth: 1 + + tutorials/waveform_extractor_to_sorting_analyzer + +Core tutorials +-------------- + +These tutorials focus on the :py:mod:`spikeinterface.core` module. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Recording objects + :link-type: ref + :link: sphx_glr_tutorials_core_plot_1_recording_extractor.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_1_recording_extractor_thumb.png + :img-alt: Recording objects + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Sorting objects + :link-type: ref + :link: sphx_glr_tutorials_core_plot_2_sorting_extractor.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_2_sorting_extractor_thumb.png + :img-alt: Sorting objects + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Handling probe information + :link-type: ref + :link: sphx_glr_tutorials_core_plot_3_handle_probe_info.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_3_handle_probe_info_thumb.png + :img-alt: Handling probe information + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: SortingAnalyzer + :link-type: ref + :link: sphx_glr_tutorials_core_plot_4_sorting_analyzer.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_4_sorting_analyzer_thumb.png + :img-alt: SortingAnalyzer + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Append and/or concatenate segments + :link-type: ref + :link: sphx_glr_tutorials_core_plot_5_append_concatenate_segments.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_5_append_concatenate_segments_thumb.png + :img-alt: Append/Concatenate segments + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Handle time information + :link-type: ref + :link: sphx_glr_tutorials_core_plot_6_handle_times.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_6_handle_times_thumb.png + :img-alt: Handle time information + :class-card: gallery-card + :text-align: center + +Extractors tutorials +-------------------- + +The :py:mod:`spikeinterface.extractors` module is designed to load and save recorded and sorted data, and to handle probe information. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Read various formats + :link-type: ref + :link: sphx_glr_tutorials_extractors_plot_1_read_various_formats.py + :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_1_read_various_formats_thumb.png + :img-alt: Read various formats + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Working with unscaled traces + :link-type: ref + :link: sphx_glr_tutorials_extractors_plot_2_working_with_unscaled_traces.py + :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_2_working_with_unscaled_traces_thumb.png + :img-alt: Unscaled traces + :class-card: gallery-card + :text-align: center + +Quality metrics tutorial +------------------------ + +The :code:`spikeinterface.qualitymetrics` module allows users to compute various quality metrics to assess the goodness of a spike sorting output. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Quality Metrics + :link-type: ref + :link: sphx_glr_tutorials_qualitymetrics_plot_3_quality_mertics.py + :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_3_quality_mertics_thumb.png + :img-alt: Quality Metrics + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Curation Tutorial + :link-type: ref + :link: sphx_glr_tutorials_qualitymetrics_plot_4_curation.py + :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_4_curation_thumb.png + :img-alt: Curation Tutorial + :class-card: gallery-card + :text-align: center + +Comparison tutorial +------------------- + +The :code:`spikeinterface.comparison` module allows you to compare sorter outputs or benchmark against ground truth. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Sorter Comparison + :link-type: ref + :link: sphx_glr_tutorials_comparison_plot_5_comparison_sorter_weaknesses.py + :img-top: /tutorials/comparison/images/thumb/sphx_glr_plot_5_comparison_sorter_weaknesses_thumb.png + :img-alt: Sorter Comparison + :class-card: gallery-card + :text-align: center + +Widgets tutorials +----------------- + +The :code:`widgets` module contains several plotting routines (widgets) for visualizing recordings, sorting data, probe layout, and more. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: RecordingExtractor Widgets + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_1_rec_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_1_rec_gallery_thumb.png + :img-alt: Recording Widgets + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: SortingExtractor Widgets + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_2_sort_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_2_sort_gallery_thumb.png + :img-alt: Sorting Widgets + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Waveforms Widgets + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_3_waveforms_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_3_waveforms_gallery_thumb.png + :img-alt: Waveforms Widgets + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Peaks Widgets + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_4_peaks_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_4_peaks_gallery_thumb.png + :img-alt: Peaks Widgets + :class-card: gallery-card + :text-align: center + +Download All Examples +--------------------- + +- :download:`Download all examples in Python source code ` +- :download:`Download all examples in Jupyter notebooks ` diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 5e2e9e4014..7ca527e255 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -1,4 +1,5 @@ from __future__ import annotations + import warnings from pathlib import Path @@ -7,14 +8,9 @@ from .base import BaseSegment from .baserecordingsnippets import BaseRecordingSnippets -from .core_tools import ( - convert_bytes_to_str, - convert_seconds_to_str, -) -from .recording_tools import write_binary_recording - - +from .core_tools import convert_bytes_to_str, convert_seconds_to_str from .job_tools import split_job_kwargs +from .recording_tools import write_binary_recording class BaseRecording(BaseRecordingSnippets): @@ -509,6 +505,35 @@ def reset_times(self): rs.t_start = None rs.sampling_frequency = self.sampling_frequency + def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: + """ + Shift all times by a scalar value. + + Parameters + ---------- + shift : int | float + The shift to apply. If positive, times will be increased by `shift`. + e.g. shifting by 1 will be like the recording started 1 second later. + If negative, the start time will be decreased i.e. as if the recording + started earlier. + + segment_index : int | None + The segment on which to shift the times. + If `None`, all segments will be shifted. + """ + if segment_index is None: + segments_to_shift = range(self.get_num_segments()) + else: + segments_to_shift = (segment_index,) + + for idx in segments_to_shift: + rs = self._recording_segments[idx] + + if self.has_time_vector(segment_index=idx): + rs.time_vector += shift + else: + rs.t_start += shift + def sample_index_to_time(self, sample_ind, segment_index=None): """ Transform sample index into time in seconds @@ -921,11 +946,11 @@ def time_to_sample_index(self, time_s): sample_index = time_s * self.sampling_frequency else: sample_index = (time_s - self.t_start) * self.sampling_frequency - sample_index = round(sample_index) + sample_index = np.round(sample_index).astype(int) else: sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1 - return int(sample_index) + return sample_index def get_num_samples(self) -> int: """Returns the number of samples in this signal segment diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index a129316ee7..9b7ed11bbb 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -15,7 +15,10 @@ class TestTimeHandling: is generated on the fly. Both time representations are tested here. """ - # Fixtures ##### + # ######################################################################### + # Fixtures + # ######################################################################### + @pytest.fixture(scope="session") def time_vector_recording(self): """ @@ -95,7 +98,10 @@ def _get_fixture_data(self, request, fixture_name): raw_recording, times_recording, all_times = time_recording_fixture return (raw_recording, times_recording, all_times) - # Tests ##### + # ######################################################################### + # Tests + # ######################################################################### + def test_has_time_vector(self, time_vector_recording): """ Test the `has_time_vector` function returns `False` before @@ -305,7 +311,87 @@ def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration()) - # Helpers #### + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("shift", [-123.456, 123.456]) + def test_shift_time_all_segments(self, request, fixture_name, shift): + """ + Shift the times in every segment using the `None` default, then + check that every segment of the recording is shifted as expected. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + num_segments, orig_seg_data = self._store_all_times(times_recording) + + times_recording.shift_times(shift) # use default `segment_index=None` + + for idx in range(num_segments): + assert np.allclose( + orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift, rtol=0, atol=1e-8 + ) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("shift", [-123.456, 123.456]) + def test_shift_times_different_segments(self, request, fixture_name, shift): + """ + Shift each segment separately, and check the shifted segment only + is shifted as expected. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + num_segments, orig_seg_data = self._store_all_times(times_recording) + + # For each segment, shift the segment only and check the + # times are updated as expected. + for idx in range(num_segments): + + scaler = idx + 2 + times_recording.shift_times(shift * scaler, segment_index=idx) + + assert np.allclose( + orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift * scaler, rtol=0, atol=1e-8 + ) + + # Just do a little check that we are not + # accidentally changing some other segments, + # which should remain unchanged at this point in the loop. + if idx != num_segments - 1: + assert np.array_equal(orig_seg_data[idx + 1], times_recording.get_times(segment_index=idx + 1)) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_save_and_load_time_shift(self, request, fixture_name, tmp_path): + """ + Save the shifted data and check the shift is propagated correctly. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + shift = 100 + times_recording.shift_times(shift=shift) + + times_recording.save(folder=tmp_path / "my_file") + + loaded_recording = si.load_extractor(tmp_path / "my_file") + + for idx in range(times_recording.get_num_segments()): + assert np.array_equal( + times_recording.get_times(segment_index=idx), loaded_recording.get_times(segment_index=idx) + ) + + def _store_all_times(self, recording): + """ + Convenience function to store original times of all segments to a dict. + """ + num_segments = recording.get_num_segments() + seg_data = {} + + for idx in range(num_segments): + seg_data[idx] = copy.deepcopy(recording.get_times(segment_index=idx)) + + return num_segments, seg_data + + # ######################################################################### + # Helpers + # ######################################################################### + def _check_times_match(self, recording, all_times): """ For every segment in a recording, check the `get_times()` diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index d7e5b58e11..728d352973 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +import warnings import numpy as np import probeinterface @@ -30,8 +31,10 @@ class CompressedBinaryIblExtractor(BaseRecording): stream_name : {"ap", "lp"}, default: "ap". Whether to load AP or LFP band, one of "ap" or "lp". - cbin_file : str or None, default None + cbin_file_path : str, Path or None, default None The cbin file of the recording. If None, searches in `folder_path` for file. + cbin_file : str or None, default None + (deprecated) The cbin file of the recording. If None, searches in `folder_path` for file. Returns ------- @@ -41,14 +44,23 @@ class CompressedBinaryIblExtractor(BaseRecording): installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n" - def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file=None): + def __init__( + self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None, cbin_file=None + ): from neo.rawio.spikeglxrawio import read_meta_file try: import mtscomp except ImportError: raise ImportError(self.installation_mesg) - if cbin_file is None: + if cbin_file is not None: + warnings.warn( + "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", + DeprecationWarning, + stacklevel=2, + ) + cbin_file_path = cbin_file + if cbin_file_path is None: folder_path = Path(folder_path) # check bands assert stream_name in ["ap", "lp"], "stream_name must be one of: 'ap', 'lp'" @@ -60,17 +72,17 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", assert ( len(curr_cbin_files) == 1 ), f"There should only be one `*.cbin` file in the folder, but {print(curr_cbin_files)} have been found" - cbin_file = curr_cbin_files[0] + cbin_file_path = curr_cbin_files[0] else: - cbin_file = Path(cbin_file) - folder_path = cbin_file.parent + cbin_file_path = Path(cbin_file_path) + folder_path = cbin_file_path.parent - ch_file = cbin_file.with_suffix(".ch") - meta_file = cbin_file.with_suffix(".meta") + ch_file = cbin_file_path.with_suffix(".ch") + meta_file = cbin_file_path.with_suffix(".meta") # reader cbuffer = mtscomp.Reader() - cbuffer.open(cbin_file, ch_file) + cbuffer.open(cbin_file_path, ch_file) # meta data meta = read_meta_file(meta_file) @@ -119,7 +131,7 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", self._kwargs = { "folder_path": str(Path(folder_path).resolve()), "load_sync_channel": load_sync_channel, - "cbin_file": str(Path(cbin_file).resolve()), + "cbin_file_path": str(Path(cbin_file_path).resolve()), } diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index f055e1d7c9..d2886d9e79 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -72,6 +72,7 @@ def write_recording( params_fname="params.json", geom_fname="geom.csv", dtype=None, + verbose=False, **job_kwargs, ): """Write a recording to file in MDA format. @@ -93,6 +94,8 @@ def write_recording( File name of geom file dtype : dtype or None, default: None Data type to be used. If None dtype is same as recording traces. + verbose : bool + If True, shows progress bar when saving recording. **job_kwargs: Use by job_tools modules to set: @@ -130,6 +133,7 @@ def write_recording( dtype=dtype, byte_offset=header_size, add_file_extension=False, + verbose=verbose, **job_kwargs, ) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index a5e6ded519..fc8ccb788b 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -6,6 +6,8 @@ from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment from spikeinterface.preprocessing.filter import fix_dtype +from .motion_utils import ensure_time_bin_edges, ensure_time_bins + def correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndarray: """ @@ -54,6 +56,7 @@ def interpolate_motion_on_traces( segment_index=None, channel_inds=None, interpolation_time_bin_centers_s=None, + interpolation_time_bin_edges_s=None, spatial_interpolation_method="kriging", spatial_interpolation_kwargs={}, dtype=None, @@ -61,7 +64,11 @@ def interpolate_motion_on_traces( """ Apply inverse motion with spatial interpolation on traces. - Traces can be full traces, but also waveforms snippets. + Traces can be full traces, but also waveforms snippets. Times used for looking up + displacements are controlled by interpolation_time_bin_edges_s or + interpolation_time_bin_centers_s, or fall back to the Motion object's time bins + by default; times in the recording outside these time bins use the closest edge + bin's displacement value during interpolation. Parameters ---------- @@ -80,6 +87,9 @@ def interpolate_motion_on_traces( interpolation_time_bin_centers_s : None or np.array Manually specify the time bins which the interpolation happens in for this segment. If None, these are the motion estimate's time bins. + interpolation_time_bin_edges_s : None or np.array + If present, interpolation chunks will be the time bins defined by these edges + rather than interpolation_time_bin_centers_s or the motion's bins. spatial_interpolation_method : "idw" | "kriging", default: "kriging" The spatial interpolation method used to interpolate the channel locations: * idw : Inverse Distance Weighing @@ -119,26 +129,33 @@ def interpolate_motion_on_traces( total_num_chans = channel_locations.shape[0] # -- determine the blocks of frames that will land in the same interpolation time bin - time_bins = interpolation_time_bin_centers_s - if time_bins is None: - time_bins = motion.temporal_bins_s[segment_index] - bin_s = time_bins[1] - time_bins[0] - bins_start = time_bins[0] - 0.5 * bin_s - # nearest bin center for each frame? - bin_inds = (times - bins_start) // bin_s - bin_inds = bin_inds.astype(int) + if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: + interpolation_time_bin_centers_s = motion.temporal_bins_s[segment_index] + interpolation_time_bin_edges_s = motion.temporal_bin_edges_s[segment_index] + else: + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins( + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s + ) + + # bin the frame times according to the interpolation time bins. + # searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] + # hence the -1. doing it with "left" is not as nice -- we want t==b[0] + # to lead to i=1 (rounding down). + interpolation_bin_inds = np.searchsorted(interpolation_time_bin_edges_s, times, side="right") - 1 + # the time bins may not cover the whole set of times in the recording, # so we need to clip these indices to the valid range - np.clip(bin_inds, 0, time_bins.size, out=bin_inds) + n_bins = interpolation_time_bin_edges_s.shape[0] - 1 + np.clip(interpolation_bin_inds, 0, n_bins - 1, out=interpolation_bin_inds) # -- what are the possibilities here anyway? - bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1) + interpolation_bins_here = np.arange(interpolation_bin_inds[0], interpolation_bin_inds[-1] + 1) # inperpolation kernel will be the same per temporal bin interp_times = np.empty(total_num_chans) current_start_index = 0 - for bin_ind in bins_here: - bin_time = time_bins[bin_ind] + for interp_bin_ind in interpolation_bins_here: + bin_time = interpolation_time_bin_centers_s[interp_bin_ind] interp_times.fill(bin_time) channel_motions = motion.get_displacement_at_time_and_depth( interp_times, @@ -166,16 +183,17 @@ def interpolate_motion_on_traces( # ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}") # plt.show() + # quick search logic to find frames corresponding to this interpolation bin in the recording # quickly find the end of this bin, which is also the start of the next next_start_index = current_start_index + np.searchsorted( - bin_inds[current_start_index:], bin_ind + 1, side="left" + interpolation_bin_inds[current_start_index:], interp_bin_ind + 1, side="left" ) - in_bin = slice(current_start_index, next_start_index) + frames_in_bin = slice(current_start_index, next_start_index) # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing # in ChunkRecordingExecutor) - np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin]) + np.matmul(traces[frames_in_bin], drift_kernel, out=traces_corrected[frames_in_bin]) current_start_index = next_start_index return traces_corrected @@ -297,6 +315,7 @@ def __init__( p=1, num_closest=3, interpolation_time_bin_centers_s=None, + interpolation_time_bin_edges_s=None, interpolation_time_bin_size_s=None, dtype=None, **spatial_interpolation_kwargs, @@ -363,9 +382,14 @@ def __init__( # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below - if interpolation_time_bin_centers_s is None: + if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: if interpolation_time_bin_size_s is None: interpolation_time_bin_centers_s = motion.temporal_bins_s + interpolation_time_bin_edges_s = motion.temporal_bin_edges_s + else: + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins( + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s + ) for segment_index, parent_segment in enumerate(recording._recording_segments): # finish the per-segment part of the time bin logic @@ -375,8 +399,13 @@ def __init__( t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) halfbin = interpolation_time_bin_size_s / 2.0 segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s) + segment_interpolation_time_bin_edges_s = np.arange( + t_start, t_end + halfbin, interpolation_time_bin_size_s + ) + assert segment_interpolation_time_bin_edges_s.shape == (segment_interpolation_time_bins_s.shape[0] + 1,) else: segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index] + segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s[segment_index] rec_segment = InterpolateMotionRecordingSegment( parent_segment, @@ -387,6 +416,7 @@ def __init__( channel_inds, segment_index, segment_interpolation_time_bins_s, + segment_interpolation_time_bin_edges_s, dtype=dtype_, ) self.add_recording_segment(rec_segment) @@ -420,6 +450,7 @@ def __init__( channel_inds, segment_index, interpolation_time_bin_centers_s, + interpolation_time_bin_edges_s, dtype="float32", ): BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -429,13 +460,11 @@ def __init__( self.channel_inds = channel_inds self.segment_index = segment_index self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s + self.interpolation_time_bin_edges_s = interpolation_time_bin_edges_s self.dtype = dtype self.motion = motion def get_traces(self, start_frame, end_frame, channel_indices): - if self.time_vector is not None: - raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.") - if start_frame is None: start_frame = 0 if end_frame is None: @@ -453,7 +482,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): channel_inds=self.channel_inds, spatial_interpolation_method=self.spatial_interpolation_method, spatial_interpolation_kwargs=self.spatial_interpolation_kwargs, - interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s, + interpolation_time_bin_edges_s=self.interpolation_time_bin_edges_s, ) if channel_indices is not None: diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 635624cca8..680d75f221 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -1,5 +1,5 @@ -import warnings import json +import warnings from pathlib import Path import numpy as np @@ -54,6 +54,7 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y" self.direction = direction self.dim = ["x", "y", "z"].index(direction) self.check_properties() + self.temporal_bin_edges_s = [ensure_time_bin_edges(tbins) for tbins in self.temporal_bins_s] def check_properties(self): assert all(d.ndim == 2 for d in self.displacement) @@ -576,3 +577,40 @@ def make_3d_motion_histograms( motion_histograms = np.log2(1 + motion_histograms) return motion_histograms, temporal_bin_edges, spatial_bin_edges + + +def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None): + """Ensure that both bin edges and bin centers are present + + If either of the inputs are None but not both, the missing is reconstructed + from the present. Going from edges to centers is done by taking midpoints. + Going from centers to edges is done by taking midpoints and padding with the + left and rightmost centers. + + Parameters + ---------- + time_bin_centers_s : None or np.array + time_bin_edges_s : None or np.array + + Returns + ------- + time_bin_centers_s, time_bin_edges_s + """ + if time_bin_centers_s is None and time_bin_edges_s is None: + raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.") + + if time_bin_centers_s is None: + assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2 + time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1]) + + if time_bin_edges_s is None: + time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype) + time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]] + if time_bin_centers_s.size > 2: + time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1]) + + return time_bin_centers_s, time_bin_edges_s + + +def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None): + return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1] diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index e022f0cc6c..e4ba870325 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -1,16 +1,14 @@ -from pathlib import Path +import warnings import numpy as np -import pytest import spikeinterface.core as sc -from spikeinterface import download_dataset +from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.motion.motion_interpolation import ( InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, interpolate_motion_on_traces, ) -from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -67,43 +65,45 @@ def test_interpolate_motion_on_traces(): times = rec.get_times()[0:30000] for method in ("kriging", "idw", "nearest"): - traces_corrected = interpolate_motion_on_traces( - traces, - times, - channel_locations, - motion, - channel_inds=None, - spatial_interpolation_method=method, - # spatial_interpolation_kwargs={}, - spatial_interpolation_kwargs={"force_extrapolate": True}, - ) - assert traces.shape == traces_corrected.shape - assert traces.dtype == traces_corrected.dtype + for interpolation_time_bin_centers_s in (None, np.linspace(*times[[0, -1]], num=3)): + traces_corrected = interpolate_motion_on_traces( + traces, + times, + channel_locations, + motion, + channel_inds=None, + spatial_interpolation_method=method, + interpolation_time_bin_centers_s=interpolation_time_bin_centers_s, + # spatial_interpolation_kwargs={}, + spatial_interpolation_kwargs={"force_extrapolate": True}, + ) + assert traces.shape == traces_corrected.shape + assert traces.dtype == traces_corrected.dtype def test_interpolation_simple(): # a recording where a 1 moves at 1 chan per second. 30 chans 10 frames. # there will be 9 chans of drift, so we add 9 chans of padding to the bottom - nt = nc0 = 10 # these need to be the same for this test - nc1 = nc0 + nc0 - 1 - traces = np.zeros((nt, nc1), dtype="float32") - traces[:, :nc0] = np.eye(nc0) + n_samples = num_chans_orig = 10 # these need to be the same for this test + num_chans_drifted = num_chans_orig + num_chans_orig - 1 + traces = np.zeros((n_samples, num_chans_drifted), dtype="float32") + traces[:, :num_chans_orig] = np.eye(num_chans_orig) rec = sc.NumpyRecording(traces, sampling_frequency=1) - rec.set_dummy_probe_from_locations(np.c_[np.zeros(nc1), np.arange(nc1)]) + rec.set_dummy_probe_from_locations(np.c_[np.zeros(num_chans_drifted), np.arange(num_chans_drifted)]) - true_motion = Motion(np.arange(nt)[:, None], 0.5 + np.arange(nt), np.zeros(1)) + true_motion = Motion(np.arange(n_samples)[:, None], 0.5 + np.arange(n_samples), np.zeros(1)) rec_corrected = interpolate_motion(rec, true_motion, spatial_interpolation_method="nearest") traces_corrected = rec_corrected.get_traces() - assert traces_corrected.shape == (nc0, nc0) - assert np.array_equal(traces_corrected[:, 0], np.ones(nt)) - assert np.array_equal(traces_corrected[:, 1:], np.zeros((nt, nc0 - 1))) + assert traces_corrected.shape == (num_chans_orig, num_chans_orig) + assert np.array_equal(traces_corrected[:, 0], np.ones(n_samples)) + assert np.array_equal(traces_corrected[:, 1:], np.zeros((n_samples, num_chans_orig - 1))) # let's try a new version where we interpolate too slowly rec_corrected = interpolate_motion( rec, true_motion, spatial_interpolation_method="nearest", num_closest=2, interpolation_time_bin_size_s=2 ) traces_corrected = rec_corrected.get_traces() - assert traces_corrected.shape == (nc0, nc0) + assert traces_corrected.shape == (num_chans_orig, num_chans_orig) # what happens with nearest here? # well... due to rounding towards the nearest even number, the motion (which at # these time bin centers is 0.5, 2.5, 4.5, ...) flips the signal's nearest @@ -115,6 +115,66 @@ def test_interpolation_simple(): assert np.all(traces_corrected[:, 2:] == 0) +def test_cross_band_interpolation(): + """Simple version of using LFP to interpolate AP data + + This also tests the time vector implementation in interpolation. + The idea is to have two recordings which are all 0s with a 1 that + moves from one channel to another after 3s. They're at different + sampling frequencies. motion estimation in one sampling frequency + applied to the other should still lead to perfect correction. + """ + from spikeinterface.sortingcomponents.motion import estimate_motion + + # sampling freqs and timing for AP and LFP recordings + fs_lfp = 50.0 + fs_ap = 300.0 + t_start = 10.0 + total_duration = 5.0 + num_samples_lfp = int(fs_lfp * total_duration) + num_samples_ap = int(fs_ap * total_duration) + t_switch = 3 + + # because interpolation uses bin centers logic, there will be a half + # bin offset at the change point in the AP recording. + halfbin_ap_lfp = int(0.5 * (fs_ap / fs_lfp)) + + # channel geometry + num_chans = 10 + geom = np.c_[np.zeros(num_chans), np.arange(num_chans)] + + # make an LFP recording which drifts a bit + traces_lfp = np.zeros((num_samples_lfp, num_chans)) + traces_lfp[: int(t_switch * fs_lfp), 5] = 1.0 + traces_lfp[int(t_switch * fs_lfp) :, 6] = 1.0 + rec_lfp = sc.NumpyRecording(traces_lfp, sampling_frequency=fs_lfp) + rec_lfp.set_dummy_probe_from_locations(geom) + + # same for AP + traces_ap = np.zeros((num_samples_ap, num_chans)) + traces_ap[: int(t_switch * fs_ap) - halfbin_ap_lfp, 5] = 1.0 + traces_ap[int(t_switch * fs_ap) - halfbin_ap_lfp :, 6] = 1.0 + rec_ap = sc.NumpyRecording(traces_ap, sampling_frequency=fs_ap) + rec_ap.set_dummy_probe_from_locations(geom) + + # set times for both, and silence the warning + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + rec_lfp.set_times(t_start + np.arange(num_samples_lfp) / fs_lfp) + rec_ap.set_times(t_start + np.arange(num_samples_ap) / fs_ap) + + # estimate motion + motion = estimate_motion(rec_lfp, method="dredge_lfp", rigid=True) + + # nearest to keep it simple + rec_corrected = interpolate_motion(rec_ap, motion, spatial_interpolation_method="nearest", num_closest=2) + traces_corrected = rec_corrected.get_traces() + target = np.zeros((num_samples_ap, num_chans - 2)) + target[:, 4] = 1 + ii, jj = np.nonzero(traces_corrected) + assert np.array_equal(traces_corrected, target) + + def test_InterpolateMotionRecording(): rec, sorting = make_dataset() motion = make_fake_motion(rec) @@ -147,6 +207,7 @@ def test_InterpolateMotionRecording(): if __name__ == "__main__": # test_correct_motion_on_peaks() - # test_interpolate_motion_on_traces() - test_interpolation_simple() - test_InterpolateMotionRecording() + test_interpolate_motion_on_traces() + # test_interpolation_simple() + # test_InterpolateMotionRecording() + test_cross_band_interpolation()