diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index 29ab453a99..004fe31203 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -24,18 +24,15 @@ runs: pip install -e .[test,extractors,full] shell: bash - name: Force installation of latest dev from key-packages when running dev (not release) - id: version run: | source ${{ github.workspace }}/test_env/bin/activate - if python ./.github/is_spikeinterface_dev.py; then + spikeinterface_is_dev_version=$(python -c "import importlib.metadata; version = importlib.metadata.version('spikeinterface'); print(version.endswith('dev0'))") + if [ $spikeinterface_is_dev_version = "True" ]; then echo "Running spikeinterface dev version" - pip uninstall -y neo - pip uninstall -y probeinterface - pip install git+https://github.com/NeuralEnsemble/python-neo - pip install git+https://github.com/SpikeInterface/probeinterface - else - echo "Running tests for release" + pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo + pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface fi + echo "Running tests for release, using pyproject.toml versions of neo and probeinterface" shell: bash - name: git-annex install run: | diff --git a/.github/is_spikeinterface_dev.py b/.github/is_spikeinterface_dev.py deleted file mode 100644 index 621305af90..0000000000 --- a/.github/is_spikeinterface_dev.py +++ /dev/null @@ -1,6 +0,0 @@ -import importlib.metadata - -package_name = "spikeinterface" -version = importlib.metadata.version(package_name) -if version.endswith("dev0"): - print(True) diff --git a/.github/workflows/caches_cron_job.yml b/.github/workflows/caches_cron_job.yml index 3ed91b84c4..237612d5d3 100644 --- a/.github/workflows/caches_cron_job.yml +++ b/.github/workflows/caches_cron_job.yml @@ -33,6 +33,7 @@ jobs: with: path: ${{ github.workspace }}/test_env key: ${{ runner.os }}-venv-${{ steps.dependencies.outputs.hash }}-${{ steps.date.outputs.date }} + lookup-only: 'true' # Avoids downloading the data, saving behavior is not affected. - name: Cache found? run: echo "Cache-hit == ${{steps.cache-venv.outputs.cache-hit == 'true'}}" - name: Create the virtual environment to be cached @@ -64,6 +65,7 @@ jobs: with: path: ~/spikeinterface_datasets key: ${{ runner.os }}-datasets-${{ steps.repo_hash.outputs.dataset_hash }} + lookup-only: 'true' # Avoids downloading the data, saving behavior is not affected. - name: Cache found? 
run: echo "Cache-hit == ${{steps.cache-datasets.outputs.cache-hit == 'true'}}" - name: Installing datalad and git-annex @@ -88,7 +90,7 @@ jobs: run: | cd $HOME pwd - du -hs spikeinterface_datasets + du -hs spikeinterface_datasets # Should show the size of ephy_testing_data cd spikeinterface_datasets pwd ls -lh # Should show ephy_testing_data diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index 3da889d64e..a5561c2ffc 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -32,8 +32,6 @@ jobs: with: path: ${{ github.workspace }}/test_env key: ${{ runner.os }}-venv-${{ hashFiles('**/pyproject.toml') }}-${{ steps.date.outputs.date }} - restore-keys: | - ${{ runner.os }}-venv- - name: Get ephy_testing_data current head hash # the key depends on the last comit repo https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git id: vars @@ -48,8 +46,7 @@ jobs: with: path: ~/spikeinterface_datasets key: ${{ runner.os }}-datasets-${{ steps.vars.outputs.HASH_EPHY_DATASET }} - restore-keys: | - ${{ runner.os }}-datasets + restore-keys: ${{ runner.os }}-datasets - name: Install packages uses: ./.github/actions/build-test-environment - name: Shows installed packages by pip, git-annex and cached testing files diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index 3e8b082c50..ac5130bade 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -37,8 +37,6 @@ jobs: with: path: ${{ github.workspace }}/test_env key: ${{ runner.os }}-venv-${{ hashFiles('**/pyproject.toml') }}-${{ steps.date.outputs.date }} - restore-keys: | - ${{ runner.os }}-venv- - name: Get ephy_testing_data current head hash # the key depends on the last comit repo https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git id: vars @@ -53,8 +51,7 @@ jobs: with: path: ~/spikeinterface_datasets key: ${{ runner.os }}-datasets-${{ steps.vars.outputs.HASH_EPHY_DATASET }} - restore-keys: | - ${{ runner.os }}-datasets + restore-keys: ${{ runner.os }}-datasets - name: Install packages uses: ./.github/actions/build-test-environment - name: Shows installed packages by pip, git-annex and cached testing files @@ -66,6 +63,10 @@ jobs: id: modules-changed run: | for file in ${{ steps.changed-files.outputs.all_changed_files }}; do + if [[ $file == *"pyproject.toml" ]]; then + echo "pyproject.toml changed" + echo "CORE_CHANGED=true" >> $GITHUB_OUTPUT + fi if [[ $file == *"/core/"* || $file == *"/extractors/neoextractors/neobaseextractor.py" ]]; then echo "Core changed" echo "CORE_CHANGED=true" >> $GITHUB_OUTPUT diff --git a/.github/workflows/streaming-extractor-test.yml b/.github/workflows/streaming-extractor-test.yml index 1498684d77..064d38fcc4 100644 --- a/.github/workflows/streaming-extractor-test.yml +++ b/.github/workflows/streaming-extractor-test.yml @@ -1,6 +1,10 @@ name: Test streaming extractors -on: workflow_dispatch +on: + pull_request: + types: [synchronize, opened, reopened] + branches: + - main concurrency: # Cancel previous workflows on the same pull request group: ${{ github.workflow }}-${{ github.ref }} @@ -28,9 +32,21 @@ jobs: - run: git fetch --prune --unshallow --tags - name: Install openblas run: sudo apt install libopenblas-dev # Necessary for ROS3 support - - name: Install package and streaming extractor dependencies + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v35 + - name: Module changes + id: modules-changed run: | - pip 
install -e .[test_core,streaming_extractors] + for file in ${{ steps.changed-files.outputs.all_changed_files }}; do + if [[ $file == *"/nwbextractors.py" || $file == *"/iblstreamingrecording.py"* ]]; then + echo "Streaming files changed" + echo "STREAMING_CHANGED=true" >> $GITHUB_OUTPUT + fi + done + - name: Install package and streaming extractor dependencies + if: ${{ steps.modules-changed.outputs.STREAMING_CHANGED == 'true' }} + run: pip install -e .[test_core,streaming_extractors] # Temporary disabled because of complicated error with path # - name: Install h5py with ROS3 support and test it works # run: | @@ -38,4 +54,5 @@ # conda install -c conda-forge "h5py>=3.2" # python -c "import h5py; assert 'ros3' in h5py.registered_drivers(), f'ros3 suppport not available, failed to install'" - name: run tests + if: steps.modules-changed.outputs.STREAMING_CHANGED == 'true' run: pytest -m "streaming_extractors and not ros3_test" -vv -ra diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 816e4e24d6..ced1ee6a2f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.7.0 hooks: - id: black files: ^src/ diff --git a/doc/api.rst b/doc/api.rst index e0a863bd9c..2e9fc1567a 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -269,13 +269,14 @@ spikeinterface.widgets .. autofunction:: plot_amplitudes .. autofunction:: plot_autocorrelograms .. autofunction:: plot_crosscorrelograms + .. autofunction:: plot_motion .. autofunction:: plot_quality_metrics .. autofunction:: plot_sorting_summary .. autofunction:: plot_spike_locations .. autofunction:: plot_spikes_on_traces .. autofunction:: plot_template_metrics .. autofunction:: plot_template_similarity - .. autofunction:: plot_timeseries + .. autofunction:: plot_traces .. autofunction:: plot_unit_depths .. autofunction:: plot_unit_locations .. autofunction:: plot_unit_summary diff --git a/doc/how_to/analyse_neuropixels.rst b/doc/how_to/analyse_neuropixels.rst index 0a02a47211..c921b13719 100644 --- a/doc/how_to/analyse_neuropixels.rst +++ b/doc/how_to/analyse_neuropixels.rst @@ -264,7 +264,7 @@ the ipywydgets interactive ploter .. code:: python %matplotlib widget - si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') + si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') Note that using this ipywydgets make possible to explore diffrents preprocessing chain wihtout to save the entire file to disk.
Everything @@ -276,9 +276,9 @@ is lazy, so you can change the previsous cell (parameters, step order, # here we use static plot using matplotlib backend fig, axs = plt.subplots(ncols=3, figsize=(20, 10)) - si.plot_timeseries(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) - si.plot_timeseries(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) - si.plot_timeseries(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) + si.plot_traces(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) + si.plot_traces(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) + si.plot_traces(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) for i, label in enumerate(('filter', 'cmr', 'final')): axs[i].set_title(label) @@ -292,7 +292,7 @@ is lazy, so you can change the previsous cell (parameters, step order, # plot some channels fig, ax = plt.subplots(figsize=(20, 10)) some_chans = rec.channel_ids[[100, 150, 200, ]] - si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) + si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) @@ -426,7 +426,7 @@ Let’s use here the ``locally_exclusive`` method for detection and the job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16, - detect_threshold=5, local_radius_um=50., **job_kwargs) + detect_threshold=5, radius_um=50., **job_kwargs) peaks @@ -451,7 +451,7 @@ Let’s use here the ``locally_exclusive`` method for detection and the from spikeinterface.sortingcomponents.peak_localization import localize_peaks - peak_locations = localize_peaks(rec, peaks, method='center_of_mass', local_radius_um=50., **job_kwargs) + peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs) diff --git a/doc/how_to/get_started.rst b/doc/how_to/get_started.rst index 02ccb872d1..0dd618e972 100644 --- a/doc/how_to/get_started.rst +++ b/doc/how_to/get_started.rst @@ -104,7 +104,7 @@ and the raster plots. .. code:: ipython3 - w_ts = sw.plot_timeseries(recording, time_range=(0, 5)) + w_ts = sw.plot_traces(recording, time_range=(0, 5)) w_rs = sw.plot_rasters(sorting_true, time_range=(0, 5)) @@ -266,7 +266,7 @@ available parameters are dictionaries and can be accessed with: 'clustering': {}, 'detection': {'detect_threshold': 5, 'peak_sign': 'neg'}, 'filtering': {'dtype': 'float32'}, - 'general': {'local_radius_um': 100, 'ms_after': 2, 'ms_before': 2}, + 'general': {'radius_um': 100, 'ms_after': 2, 'ms_before': 2}, 'job_kwargs': {}, 'localization': {}, 'matching': {}, diff --git a/doc/how_to/handle_drift.rst b/doc/how_to/handle_drift.rst index b59fa4dfcb..7ff98a666b 100644 --- a/doc/how_to/handle_drift.rst +++ b/doc/how_to/handle_drift.rst @@ -7,7 +7,7 @@ Handle motion/drift with spikeinterface ======================================= -Spikeinterface offers a very flexible framework to handle drift as a +SpikeInterface offers a very flexible framework to handle drift as a preprocessing step. If you want to know more, please read the :ref:`motion_correction` section of the documentation. @@ -96,7 +96,7 @@ Correcting for drift is easy! You just need to run a single function. We will try this function with 3 presets. Internally a preset is a dictionary of dictionaries containing all -parameters for every steps. +parameters for each step. Here we also save the motion correction results into a folder to be able to load them later. 
@@ -118,10 +118,10 @@ to load them later. 'peak_sign': 'neg', 'detect_threshold': 8.0, 'exclude_sweep_ms': 0.1, - 'local_radius_um': 50}, + 'radius_um': 50}, 'select_kwargs': None, 'localize_peaks_kwargs': {'method': 'grid_convolution', - 'local_radius_um': 30.0, + 'radius_um': 30.0, 'upsampling_um': 3.0, 'sigma_um': array([ 5. , 12.5, 20. ]), 'sigma_ms': 0.25, @@ -185,14 +185,14 @@ A few comments on the figures: start moving is recovered quite well. * The preset **kilosort_like** gives better results because it is a non-rigid case. The motion vector is computed for different depths. The corrected peak locations are - flatter than the rigid case. The motion vector map is still be a bit - noisy at some depths (e.g around 1000um). + flatter than the rigid case. The motion vector map is still a bit + noisy at some depths (e.g. around 1000um). * The preset **nonrigid_accurate** seems to give the best results on this recording. The motion vector seems less noisy globally, but it is not “perfect” (see at the top of the probe 3200um to 3800um). Also note that in the first part of the recording before the imposed motion (0-600s) we clearly have a non-rigid motion: the upper part of the probe - (2000-3000um) experience some drifts, but the lower part (0-1000um) is + (2000-3000um) experience some drift, but the lower part (0-1000um) is relatively stable. The method defined by this preset is able to capture this. .. code:: ipython3 @@ -204,8 +204,8 @@ A few comments on the figures: # and plot fig = plt.figure(figsize=(14, 8)) - si.plot_motion(rec, motion_info, figure=fig, depth_lim=(400, 600), - color_amplitude=True, amplitude_cmap='inferno', scatter_decimate=10) + si.plot_motion(motion_info, figure=fig, depth_lim=(400, 600), + color_amplitude=True, amplitude_cmap='inferno', scatter_decimate=10) fig.suptitle(f"{preset=}") @@ -237,7 +237,7 @@ axis, especially for the preset “nonrigid_accurate”. Be aware that there are two ways to correct for the motion: 1. Interpolate traces and detect/localize peaks again -(:py:func:`interpolate_recording()`) 2. Compensate for drifts directly on peak +(:py:func:`interpolate_recording()`) 2. Compensate for drift directly on peak locations (:py:func:`correct_motion_on_peaks()`) Case 1 is used before running a spike sorter and the case 2 is used here @@ -272,7 +272,7 @@ to display the results. #color='black', ax.scatter(loc['x'][mask][sl], loc['y'][mask][sl], **color_kargs) - loc2 = correct_motion_on_peaks(motion_info['peaks'], motion_info['peak_locations'], rec.get_times(), + loc2 = correct_motion_on_peaks(motion_info['peaks'], motion_info['peak_locations'], rec.sampling_frequency, motion_info['motion'], motion_info['temporal_bins'], motion_info['spatial_bins'], direction="y") ax = axs[1] diff --git a/doc/install_sorters.rst b/doc/install_sorters.rst index 1e55827ffd..3fda05848c 100644 --- a/doc/install_sorters.rst +++ b/doc/install_sorters.rst @@ -191,7 +191,7 @@ Mountainsort5 pip install mountainsort5 SpyKING CIRCUS -^^^^^^^^^^^^^ +^^^^^^^^^^^^^^ * Python, requires MPICH * Url: https://spyking-circus.readthedocs.io diff --git a/doc/installation.rst b/doc/installation.rst index 80452a60e7..acc5117249 100644 --- a/doc/installation.rst +++ b/doc/installation.rst @@ -38,7 +38,7 @@ From source As :code:`spikeinterface` is undergoing a heavy development phase, it is sometimes convenient to install from source to get the latest bug fixes and improvements. 
We recommend constructing the package within a -[virtual environment](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/) +`virtual environment <https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/>`_ to prevent potential conflicts with local dependencies. .. code-block:: bash git clone https://github.com/SpikeInterface/spikeinterface.git cd spikeinterface pip install -e . cd .. -Note that this will install the package in [editable mode](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs). +Note that this will install the package in `editable mode <https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs>`_. It is also recommended in that case to also install :code:`neo` and :code:`probeinterface` from source, as :code:`spikeinterface` strongly relies on these packages to interface with various formats and handle probes: diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 9af69768dd..fdc4d71fe7 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -137,6 +137,7 @@ It interfaces with a spike-sorted output and has the following features: * enable selection of sub-units * handle time information + Here we assume :code:`sorting` is a :py:class:`~spikeinterface.core.BaseSorting` object with 10 units: @@ -181,6 +182,12 @@ with 10 units: # times are not set, the samples are divided by the sampling frequency +Internally, any sorting object can construct 2 internal caches: + 1. a list (per segment) of dict (per unit) of numpy.array. This cache is useful when accessing spiketrains unit + per unit across segments. + 2. a unique numpy.array with structured dtype aka "spikes vector". This is useful for processing in small chunks of + time, like extracting amplitudes from a recording. + WaveformExtractor ----------------- @@ -190,7 +197,7 @@ The :py:class:`~spikeinterface.core.WaveformExtractor` class is the core object Waveforms are very important for additional analysis, and the basis of several postprocessing and quality metrics computations. -The :py:class:`~spikeinterface.core.WaveformExtractor` allows to: +The :py:class:`~spikeinterface.core.WaveformExtractor` allows us to: * extract and waveforms * sub-sample spikes for waveform extraction @@ -199,7 +206,7 @@ The :py:class:`~spikeinterface.core.WaveformExtractor` allows to: * save sparse waveforms or *sparsify* dense waveforms * select units and associated waveforms -The default format (:code:`mode='folder'`) which waveforms are saved to is a folder structure with waveforms as +The default format (:code:`mode='folder'`) saves waveforms to a folder structure, with waveforms stored as :code:`.npy` files. In addition, waveforms can also be extracted in-memory for fast computations (:code:`mode='memory'`). Note that this mode can quickly fill up your RAM... Use it wisely! @@ -231,7 +238,7 @@ Finally, an existing :py:class:`~spikeinterface.core.WaveformExtractor` can be s # (this can also be done within the 'extract_waveforms') we.precompute_templates(modes=("std",)) - # retrieve all template means and standard devs + # retrieve all template means and standard deviations template_means = we.get_all_templates(mode="average") template_stds = we.get_all_templates(mode="std") @@ -490,11 +497,11 @@ Parallel processing and job_kwargs The :py:mod:`~spikeinterface.core` module also contains the basic tools used throughout SpikeInterface for parallel processing of recordings.
-In general, parallelization is achieved by splitting the recording in many small time chunks and process +In general, parallelization is achieved by splitting the recording into many small time chunks and processing them in parallel (for more details, see the :py:class:`~spikeinterface.core.ChunkRecordingExecutor` class). Many functions support parallel processing (e.g., :py:func:`~spikeinterface.core.extract_waveforms`, :code:`save`, -and many more). All of this functions, in addition to other arguments, also accept the so-called **job_kwargs**. +and many more). All of these functions, in addition to other arguments, also accept the so-called **job_kwargs**. These are a set of keyword arguments which are common to all functions that support parallelization: * chunk_duration or chunk_size or chunk_memory or total_memory @@ -513,11 +520,11 @@ These are a set of keyword arguments which are common to all functions that supp If True, a progress bar is printed * mp_context: str or None Context for multiprocessing. It can be None (default), "fork" or "spawn". - Note that "fork" is only available on UNIX systems + Note that "fork" is only available on UNIX systems (not Windows) The default **job_kwargs** are :code:`n_jobs=1, chunk_duration="1s", progress_bar=True`. -Any of these argument, can be overridden by manually passing the argument to a function +Any of these arguments can be overridden by manually passing the argument to a function (e.g., :code:`extract_waveforms(..., n_jobs=16)`). Alternatively, **job_kwargs** can be set globally (for each SpikeInterface session), with the :py:func:`~spikeinterface.core.set_global_job_kwargs` function: @@ -543,6 +550,10 @@ In order to do this, one can use the :code:`Numpy*` classes, :py:class:`~spikein but they are not bound to a file. This makes these objects *not dumpable*, so parallel processing is not supported. In order to make them *dumpable*, one can simply :code:`save()` them (see :ref:`save_load`). +Also note the class :py:class:`~spikeinterface.core.SharedMemorySorting` which is very similar to +:py:class:`~spikeinterface.core.NumpySorting` but with an underlying SharedMemory which is useful for +parallel computing. + In this example, we create a recording and a sorting object from numpy objects: .. code-block:: python @@ -574,6 +585,18 @@ In this example, we create a recording and a sorting object from numpy objects: sampling_frequency=sampling_frequency) +Any sorting object can be transformed into a :py:class:`~spikeinterface.core.NumpySorting` or +:py:class:`~spikeinterface.core.SharedMemorySorting` easily like this: + +.. code-block:: python + + # turn any sorting into a NumpySorting + sorting_np = sorting.to_numpy_sorting() + + # or to SharedMemorySorting for parallel computing + sorting_shm = sorting.to_shared_memory_sorting() + + .. _multi_seg: Manipulating objects: slicing, aggregating @@ -707,13 +730,13 @@ The :py:mod:`spikeinterface.core.template_tools` submodule includes functionalit Generate toy objects -------------------- -The :py:mod:`~spikeinterface.core` module also offers some functions to generate toy/fake data. +The :py:mod:`~spikeinterface.core` module also offers some functions to generate toy/simulated data. They are useful to make examples, tests, and small demos: ..
code-block:: python # recording with 2 segments and 4 channels - recording = generate_recording(generate_recording(num_channels=4, sampling_frequency=30000., + recording = generate_recording(num_channels=4, sampling_frequency=30000., durations=[10.325, 3.5], set_probe=True) # sorting with 2 segments and 5 units @@ -739,7 +762,7 @@ There are also some more advanced functions to generate sorting objects with var Downloading test datasets ------------------------- -The `NEO `_ package is maintaining a collection a files of many +The `NEO `_ package is maintaining a collection of files in many electrophysiology file formats: https://gin.g-node.org/NeuralEnsemble/ephy_testing_data The :py:func:`~spikeinterface.core.download_dataset` function is capable of downloading and caching locally dataset diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index 1b582dbafc..62c0d6b8d4 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -7,26 +7,26 @@ Motion/drift correction Overview -------- -Mechanical drifts, often observed in recordings, are currently a major issue for spike sorting. This is especially striking -with the new generation of high-density devices used in-vivo such as the neuropixel electrodes. -The first sorter that has introduced motion/drift correction as a prepossessing step was kilosort2.5 (see [Steinmetz2021]_) +Mechanical drift, often observed in recordings, is currently a major issue for spike sorting. This is especially striking +with the new generation of high-density devices used for in-vivo electrophysiology such as the neuropixel electrodes. +The first sorter that has introduced motion/drift correction as a preprocessing step was Kilosort2.5 (see [Steinmetz2021]_) Long story short, the main idea is the same as the one used for non-rigid image registration, for example with calcium imaging. However, because with extracellular recording we do not have a proper image to use as a reference, the main idea of the algorithm is create an "image" via the activity profile of the cells during a given time window. Assuming this activity profile should be kept constant over time, the motion can be estimated, by blocks, along the probe's insertion axis -(i.e. depth) so that we can interpolate the traces to compensate this estimated motion. +(i.e. depth) so that we can interpolate the traces to compensate for this estimated motion. Users with a need to handle drift were currently forced to stick to the use of Kilosort2.5 or pyKilosort (see [Pachitariu2023]_). Recently, the Paninski group from Columbia University introduced a possibly more accurate method to estimate the drift (see [Varol2021]_ -and [Windolf2023]_), but this new method was not properly integrated in any sorter. +and [Windolf2023]_), but this new method was not properly integrated into any sorter. -Because motion registration is a hard topic, with numerous hypothesis and/or implementations details that might have a large +Because motion registration is a hard topic, with numerous hypotheses and/or implementation details that might have a large impact on the spike sorting performances (see [Garcia2023]_), in SpikeInterface, we developed a full motion estimation and interpolation framework to make all these methods accessible in one place. This modular approach offers a major benefit: **the drift correction can be applied to a recording as a preprocessing step, and then used for any sorter!** In short, the motion correction is decoupled from the sorter itself.
-This gives the user an incredible flexibility to check/test and correct the drifts before the sorting process. +This gives the user an incredible flexibility to check/test and correct the drift before the sorting process. Here is an overview of the motion correction as a preprocessing: @@ -41,21 +41,21 @@ The motion correction process can be split into 3 steps: For every steps, we implemented several methods. The combination of the yellow boxes should give more or less what Kilosort2.5/3 is doing. Similarly, the combination of the green boxes gives the method developed by the Paninski group. -Of course the end user can combine any of the methods to get the best motion correction possible. -This make also an incredible framework for testing new ideas. +Of course the end user can combine any of these methods to get the best motion correction possible. +This also makes it an incredible framework for testing new ideas. For a better overview, checkout our recent paper to validate, benchmark, and compare these motion correction methods (see [Garcia2023]_). SpikeInterface offers two levels for motion correction: 1. A high level with a unique function and predefined parameter presets - 2. A low level where the user need to call one by one all functions for a better control + 2. A low level where the user needs to call all functions one by one for better control High-level API -------------- -One challenging task for motion correction is to find parameters. +One challenging task for motion correction is to determine the parameters. The high level :py:func:`~spikeinterface.preprocessing.correct_motion()` proposes the concept of a **"preset"** that already has predefined parameters, in order to achieve a calibrated behavior. @@ -69,7 +69,7 @@ We currently have 3 presets: To be used as check and/or control on a recording to check the presence of drift. Note that, in this case the drift is considered as "rigid" over the electrode. * **"kilosort_like"**: It consists of *grid convolution + iterative_template + kriging*, to mimic what is done in Kilosort (see [Pachitariu2023]_). - Note that this is not exactly 100% what Kilosort is doing, because the peak detection is done with a template mathcing + Note that this is not exactly 100% what Kilosort is doing, because the peak detection is done with a template matching in Kilosort, while in SpikeInterface we used a threshold-based method. However, this "preset" gives similar results to Kilosort2.5. @@ -85,7 +85,7 @@ We currently have 3 presets: rec_corrected = correct_motion(rec, preset="nonrigid_accurate") The process is quite long due the two first steps (activity profile + motion inference) -But the return :code:`rec_corrected` is a lazy recording object this will interpolate traces on the +But the returned :code:`rec_corrected` is a lazy recording object that will interpolate traces on the fly (step 3 motion interpolation). @@ -116,7 +116,7 @@ Optionally any parameter from the preset can be overwritten: ) ) -Importantly, all the result and intermediate computation can be saved into a folder for further loading +Importantly, all the results and intermediate computations can be saved into a folder for further loading and checking. The folder will contain the motion vector itself of course but also detected peaks, peak location, and more. @@ -134,11 +134,11 @@ Low-level API ------------- All steps (**activity profile**, **motion inference**, **motion interpolation**) can be launched with distinct functions.
-This can be useful to find the good method and finely tune/optimize parameters at every steps. +This can be useful to find the best method and finely tune/optimize parameters at each step. All functions are implemented in the :py:mod:`~spikeinterface.sortingcomponents` module. They all have a simple API with SpikeInterface objects or numpy arrays as inputs. Since motion correction is a hot topic, these functions have many possible methods and also many possible parameters. -Finding the good combination of method/parameters is not that easy, but it should be doable, assuming the presets are not +Finding the best combination of method/parameters is not that easy, but it should be doable, assuming the presets are not working properly for your particular case. @@ -159,7 +159,7 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte peaks = detect_peaks(rec, method="locally_exclusive", detect_threshold=8.0, **job_kwargs) # (optional) sub-select some peaks to speed up the localization peaks = select_peaks(peaks, ...) - peak_locations = localize_peaks(rec, peaks, method="monopolar_triangulation",local_radius_um=75.0, + peak_locations = localize_peaks(rec, peaks, method="monopolar_triangulation",radius_um=75.0, max_distance_um=150.0, **job_kwargs) # Step 2: motion inference @@ -186,7 +186,7 @@ The function :py:func:`~spikeinterface.preprocessing.correct_motion()` requires It is important to keep in mind that the preprocessing can have a strong impact on the motion estimation. -In the context of motion correction we advice: +In the context of motion correction we advise: * to not use whitening before motion estimation (as it interferes with spatial amplitude information) * to remove high frequencies in traces, to reduce noise in peak location (e.g. using a bandpass filter) * if you use Neuropixels, then use :py:func:`~spikeinterface.preprocessing.phase_shift()` in preprocessing diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index b4380fc587..aa62ea5b33 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -51,7 +51,7 @@ follows: peak_sign='neg', detect_threshold=5, exclude_sweep_ms=0.2, - local_radius_um=100, + radius_um=100, noise_levels=None, random_chunk_kwargs={}, outputs='numpy_compact', @@ -95,7 +95,7 @@ follows: job_kwargs = dict(chunk_duration='1s', n_jobs=8, progress_bar=True) peak_locations = localize_peaks(recording, peaks, method='center_of_mass', - local_radius_um=70., ms_before=0.3, ms_after=0.6, + radius_um=70., ms_before=0.3, ms_after=0.6, **job_kwargs) diff --git a/doc/modules/widgets.rst b/doc/modules/widgets.rst index 9cb99ab5a1..86c541dfd0 100644 --- a/doc/modules/widgets.rst +++ b/doc/modules/widgets.rst @@ -123,7 +123,7 @@ The :code:`plot_*(..., backend="matplotlib")` functions come with the following .. 
code-block:: python # matplotlib backend - w = plot_timeseries(recording, backend="matplotlib") + w = plot_traces(recording, backend="matplotlib") **Output:** @@ -146,9 +146,9 @@ Each function has the following additional arguments: from spikeinterface.preprocessing import common_reference - # ipywidgets backend also supports multiple "layers" for plot_timeseries + # ipywidgets backend also supports multiple "layers" for plot_traces rec_dict = dict(filt=recording, cmr=common_reference(recording)) - w = sw.plot_timeseries(rec_dict, backend="ipywidgets") + w = sw.plot_traces(rec_dict, backend="ipywidgets") **Output:** @@ -171,7 +171,7 @@ The functions have the following additional arguments: .. code-block:: python # sortingview backend - w_ts = sw.plot_timeseries(recording, backend="ipywidgets") + w_ts = sw.plot_traces(recording, backend="ipywidgets") w_ss = sw.plot_sorting_summary(recording, backend="sortingview") diff --git a/doc/releases/0.98.1.rst b/doc/releases/0.98.1.rst new file mode 100644 index 0000000000..b713e2fbd2 --- /dev/null +++ b/doc/releases/0.98.1.rst @@ -0,0 +1,24 @@ +.. _release0.98.1: + +SpikeInterface 0.98.1 release notes +----------------------------------- + +18th July 2023 + +Minor release with some bug fixes. + +* Make all paths resolved and absolute (#1834) +* Improve Documentation (#1809) +* Fix hdbscan installation in read the docs (#1838) +* Fixed numba.jit and binary num_chan warnings (#1836) +* Fix neo release bug in Mearec (#1835) +* Do not load NP probe in OE if load_sync_channel=True (#1832) +* Cleanup dumping/to_dict (#1831) +* Expose AUCsplit param in KS2+ (#1829) +* Add option relative_to=True (#1820) +* plot_motion: make recording optional, add amplitude_clim and alpha (#1818) +* Fix typo in class attribute for NeuralynxSortingExtractor (#1814) +* Make to_phy write templates.npy with datatype np.float64 as required by phy (#1810) +* Add docs requirements and build read-the-docs documentation faster (#1807) +* Fix has_channel_locations function (#1806) +* Add depth_order kwargs (#1803) diff --git a/doc/releases/0.98.2.rst b/doc/releases/0.98.2.rst new file mode 100644 index 0000000000..134aeba960 --- /dev/null +++ b/doc/releases/0.98.2.rst @@ -0,0 +1,18 @@ +.. _release0.98.2: + +SpikeInterface 0.98.2 release notes +----------------------------------- + +20th July 2023 + +Minor release with some bug fixes. + +* Remove warning (#1843) +* Fix Mearec handling of new arguments before neo release 0.13 (#1848) +* Fix full tests by updating hdbscan version (#1849) +* Relax numpy upper bound and update tridesclous dependency (#1850) +* Drop figurl-jupyter dependency (#1855) +* Update Tridesclous 1.6.8 (#1857) +* Eliminate restore keys in CI and simplify installation of dev version dependencies (#1858) +* Allow order_channel_by_depth to accept dimensions as list (#1861) +* Fixes to Neuroscope extractor before neo release 0.13 (#1863) diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 1a61946aa7..8b984e2510 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,6 +8,8 @@ Release notes ..
toctree:: :maxdepth: 1 + releases/0.98.2.rst + releases/0.98.1.rst releases/0.98.0.rst releases/0.97.1.rst releases/0.97.0.rst @@ -29,6 +31,18 @@ Release notes releases/0.9.1.rst +Version 0.98.2 +============== + +* Minor release with some bug fixes + + +Version 0.98.1 +============== + +* Minor release with some bug fixes + + Version 0.98.0 ============== diff --git a/docs_rtd.yml b/docs_rtd.yml new file mode 100644 index 0000000000..c4e1fb378c --- /dev/null +++ b/docs_rtd.yml @@ -0,0 +1,9 @@ +channels: + - conda-forge + - defaults +dependencies: + - python=3.10 + - pip + - datalad + - pip: + - -e .[docs] diff --git a/environment_rtd.yml b/environment_rtd.yml deleted file mode 100644 index 5e4b4eb92a..0000000000 --- a/environment_rtd.yml +++ /dev/null @@ -1,18 +0,0 @@ -channels: - - conda-forge - - defaults -dependencies: - - python=3.10 - - pip - - datalad - - numpy=1.23 - - pip: - - sphinx-gallery - - sphinx_rtd_theme - - numpydoc - - MEArec>=1.7.1 - - hdbscan - - numba - - git+https://github.com/NeuralEnsemble/python-neo.git - - git+https://github.com/SpikeInterface/probeinterface.git - - git+https://github.com/SpikeInterface/spikeinterface.git#egg=spikeinterface[full,widgets] diff --git a/examples/how_to/analyse_neuropixels.py b/examples/how_to/analyse_neuropixels.py index 9b9048cd0d..eed05a0ee5 100644 --- a/examples/how_to/analyse_neuropixels.py +++ b/examples/how_to/analyse_neuropixels.py @@ -82,7 +82,7 @@ # # ```python # # %matplotlib widget -# si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') +# si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') # ``` # # Note that using this ipywidgets make possible to explore different preprocessing chains without saving the entire file to disk. @@ -94,9 +94,9 @@ # here we use a static plot using matplotlib backend fig, axs = plt.subplots(ncols=3, figsize=(20, 10)) -si.plot_timeseries(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) -si.plot_timeseries(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) -si.plot_timeseries(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) +si.plot_traces(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) +si.plot_traces(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) +si.plot_traces(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) for i, label in enumerate(('filter', 'cmr', 'final')): axs[i].set_title(label) # - @@ -104,7 +104,7 @@ # plot some channels fig, ax = plt.subplots(figsize=(20, 10)) some_chans = rec.channel_ids[[100, 150, 200, ]] -si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) +si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) # ### Should we save the preprocessed data to a binary file? 
@@ -170,13 +170,13 @@ job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16, - detect_threshold=5, local_radius_um=50., **job_kwargs) + detect_threshold=5, radius_um=50., **job_kwargs) peaks # + from spikeinterface.sortingcomponents.peak_localization import localize_peaks -peak_locations = localize_peaks(rec, peaks, method='center_of_mass', local_radius_um=50., **job_kwargs) +peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs) # - # ### Check for drift diff --git a/examples/how_to/get_started.py b/examples/how_to/get_started.py index 266d585de9..7860c605af 100644 --- a/examples/how_to/get_started.py +++ b/examples/how_to/get_started.py @@ -92,7 +92,7 @@ # # Let's use the `spikeinterface.widgets` module to visualize the traces and the raster plots. -w_ts = sw.plot_timeseries(recording, time_range=(0, 5)) +w_ts = sw.plot_traces(recording, time_range=(0, 5)) w_rs = sw.plot_rasters(sorting_true, time_range=(0, 5)) # This is how you retrieve info from a `BaseRecording`... diff --git a/examples/how_to/handle_drift.py b/examples/how_to/handle_drift.py index 9c2b09954e..a1671a7424 100644 --- a/examples/how_to/handle_drift.py +++ b/examples/how_to/handle_drift.py @@ -1,3 +1,19 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# formats: py,ipynb +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.14.6 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + # %matplotlib inline # %load_ext autoreload # %autoreload 2 @@ -119,8 +135,9 @@ def preprocess_chain(rec): # and plot fig = plt.figure(figsize=(14, 8)) - si.plot_motion(rec, motion_info, figure=fig, depth_lim=(400, 600), + si.plot_motion(motion_info, figure=fig, depth_lim=(400, 600), color_amplitude=True, amplitude_cmap='inferno', scatter_decimate=10) + fig.suptitle(f"{preset=}") # ### Plot peak localization @@ -166,7 +183,7 @@ def preprocess_chain(rec): #color='black', ax.scatter(loc['x'][mask][sl], loc['y'][mask][sl], **color_kargs) - loc2 = correct_motion_on_peaks(motion_info['peaks'], motion_info['peak_locations'], rec.get_times(), + loc2 = correct_motion_on_peaks(motion_info['peaks'], motion_info['peak_locations'], rec.sampling_frequency, motion_info['motion'], motion_info['temporal_bins'], motion_info['spatial_bins'], direction="y") ax = axs[1] diff --git a/examples/modules_gallery/comparison/generate_erroneous_sorting.py b/examples/modules_gallery/comparison/generate_erroneous_sorting.py index b5f53e71ee..d62a15bdc0 100644 --- a/examples/modules_gallery/comparison/generate_erroneous_sorting.py +++ b/examples/modules_gallery/comparison/generate_erroneous_sorting.py @@ -88,7 +88,7 @@ def generate_erroneous_sorting(): for u in [15,16,17]: st = np.sort(np.random.randint(0, high=nframes, size=35)) units_err[u] = st - sorting_err = se.NumpySorting.from_dict(units_err, sampling_frequency) + sorting_err = se.NumpySorting.from_unit_dict(units_err, sampling_frequency) return sorting_true, sorting_err diff --git a/examples/modules_gallery/extractors/plot_1_read_various_formats.py b/examples/modules_gallery/extractors/plot_1_read_various_formats.py index 98988a1746..ed0ba34396 100644 --- a/examples/modules_gallery/extractors/plot_1_read_various_formats.py +++ b/examples/modules_gallery/extractors/plot_1_read_various_formats.py @@ -87,7 +87,7 @@ import spikeinterface.widgets as sw -w_ts = 
sw.plot_timeseries(recording, time_range=(0, 5)) +w_ts = sw.plot_traces(recording, time_range=(0, 5)) w_rs = sw.plot_rasters(sorting, time_range=(0, 5)) plt.show() diff --git a/examples/modules_gallery/widgets/plot_1_rec_gallery.py b/examples/modules_gallery/widgets/plot_1_rec_gallery.py index d3d4792535..1544bbfc54 100644 --- a/examples/modules_gallery/widgets/plot_1_rec_gallery.py +++ b/examples/modules_gallery/widgets/plot_1_rec_gallery.py @@ -15,22 +15,22 @@ recording, sorting = se.toy_example(duration=10, num_channels=4, seed=0, num_segments=1) ############################################################################## -# plot_timeseries() +# plot_traces() # ~~~~~~~~~~~~~~~~~ -w_ts = sw.plot_timeseries(recording) +w_ts = sw.plot_traces(recording) ############################################################################## # We can select time range -w_ts1 = sw.plot_timeseries(recording, time_range=(5, 8)) +w_ts1 = sw.plot_traces(recording, time_range=(5, 8)) ############################################################################## # We can color with groups recording2 = recording.clone() recording2.set_channel_groups(channel_ids=recording.get_channel_ids(), groups=[0, 0, 1, 1]) -w_ts2 = sw.plot_timeseries(recording2, time_range=(5, 8), color_groups=True) +w_ts2 = sw.plot_traces(recording2, time_range=(5, 8), color_groups=True) ############################################################################## # **Note**: each function returns a widget object, which allows to access the figure and axis. @@ -41,7 +41,7 @@ ############################################################################## # We can also use the 'map' mode useful for high channel count -w_ts = sw.plot_timeseries(recording, mode='map', time_range=(5, 8), +w_ts = sw.plot_traces(recording, mode='map', time_range=(5, 8), show_channel_ids=True, order_channel_by_depth=True) ############################################################################## diff --git a/examples/modules_gallery/widgets/plot_4_peaks_gallery.py b/examples/modules_gallery/widgets/plot_4_peaks_gallery.py index df7d9dbf2c..addd87c065 100644 --- a/examples/modules_gallery/widgets/plot_4_peaks_gallery.py +++ b/examples/modules_gallery/widgets/plot_4_peaks_gallery.py @@ -30,7 +30,7 @@ peaks = detect_peaks( rec_filtred, method='locally_exclusive', peak_sign='neg', detect_threshold=6, exclude_sweep_ms=0.3, - local_radius_um=100, + radius_um=100, noise_levels=None, random_chunk_kwargs={}, chunk_memory='10M', n_jobs=1, progress_bar=True) diff --git a/pyproject.toml b/pyproject.toml index 574cd79830..3ecfbe2718 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,13 +18,14 @@ classifiers = [ "Operating System :: OS Independent" ] + dependencies = [ "numpy", - "neo>=0.11.1", + "neo>=0.12.0", "joblib", "threadpoolctl", "tqdm", - "probeinterface>=0.2.16", + "probeinterface>=0.2.17", ] [build-system] @@ -99,7 +100,6 @@ widgets = [ "ipympl", "ipywidgets", "sortingview>=0.11.15", - "figurl-jupyter" ] test_core = [ @@ -118,9 +118,8 @@ test = [ "huggingface_hub", # tridesclous - "numpy<1.24", "numba", - "hdbscan", + "hdbscan>=0.8.33", # Previous version had a broken wheel # for sortingview backend "sortingview", @@ -130,7 +129,7 @@ test = [ "datalad==0.16.2", ## install tridesclous for testing ## - "tridesclous>=1.6.6.1", + "tridesclous>=1.6.8", ## sliding_nn "pymde", @@ -143,6 +142,23 @@ test = [ "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] +docs = [ + "Sphinx==5.1.1", + "sphinx_rtd_theme==1.0.0", + "sphinx-gallery", + "numpydoc", + 
+ # for notebooks in the gallery + "MEArec", # Use as an example + "datalad==0.16.2", # Download mearec data, not sure if needed as is installed with conda as well because of git-annex + "pandas", # Don't know where this is needed + "hdbscan>=0.8.33", # For sorters, probably spikingcircus + "numba", # For sorters, probably spikingcircus + # for release we need pypi, so this needs to be commented + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + +] [tool.pytest.ini_options] markers = [ diff --git a/readthedocs.yml b/readthedocs.yml index 350948104d..512fcbc709 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,17 +1,10 @@ version: 2 build: - image: latest + os: ubuntu-22.04 + tools: + python: "mambaforge-4.10" -conda: - environment: environment_rtd.yml - -# python: -# install: -# - method: pip -# path: . -# python: -# version: 3.8 -# install: -# - requirements: requirements_rtd.txt +conda: + environment: docs_rtd.yml diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index c01ea19f14..db45e2b25b 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -184,8 +184,8 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1): unit1_ids = np.array(sorting1.get_unit_ids()) unit2_ids = np.array(sorting2.get_unit_ids()) - ev_counts1 = np.array(list(sorting1.get_total_num_spikes().values())) - ev_counts2 = np.array(list(sorting2.get_total_num_spikes().values())) + ev_counts1 = np.array(list(sorting1.count_num_spikes_per_unit().values())) + ev_counts2 = np.array(list(sorting2.count_num_spikes_per_unit().values())) event_counts1 = pd.Series(ev_counts1, index=unit1_ids) event_counts2 = pd.Series(ev_counts2, index=unit2_ids) diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py index f33b81bf46..436e04f45a 100644 --- a/src/spikeinterface/comparison/hybrid.py +++ b/src/spikeinterface/comparison/hybrid.py @@ -197,7 +197,7 @@ def __init__( self._kwargs = dict( wvf_extractor=str(wvf_extractor.folder.absolute()), - injected_sorting=self.injected_sorting.to_dict(), + injected_sorting=self.injected_sorting, unit_ids=unit_ids, max_injected_per_unit=max_injected_per_unit, injected_rate=injected_rate, @@ -241,7 +241,7 @@ def generate_injected_sorting( injected_spike_trains[segment_index][unit_id] = injected_spike_train - return NumpySorting.from_dict(injected_spike_trains, sorting.get_sampling_frequency()) + return NumpySorting.from_unit_dict(injected_spike_trains, sorting.get_sampling_frequency()) create_hybrid_units_recording = define_function_from_class( diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 4a1617d622..ed9ed7520c 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -199,7 +199,7 @@ def save_to_folder(self, save_folder): json.dump(kwargs, f) sortings = {} for name, sorting in zip(self.name_list, self.object_list): - sortings[name] = sorting.to_dict() + sortings[name] = sorting.to_dict(recursive=True, relative_to=save_folder) with (save_folder / "sortings.json").open("w") as f: json.dump(sortings, f) @@ -211,7 +211,7 @@ def load_from_folder(folder_path): with (folder_path / "sortings.json").open() as 
f: dict_sortings = json.load(f) name_list = list(dict_sortings.keys()) - sorting_list = [load_extractor(v) for v in dict_sortings.values()] + sorting_list = [load_extractor(v, base_folder=folder_path) for v in dict_sortings.values()] mcmp = MultiSortingComparison(sorting_list=sorting_list, name_list=list(name_list), do_matching=False, **kwargs) filename = str(folder_path / "multicomparison.gpickle") with open(filename, "rb") as f: diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index 97269edc76..75976ed44f 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -395,6 +395,8 @@ def get_performance(self, method="by_unit", output="pandas"): perf: pandas dataframe/series (or dict) dataframe/series (based on 'output') with performance entries """ + import pandas as pd + possibles = ("raw_count", "by_unit", "pooled_with_average") if method not in possibles: raise Exception("'method' can be " + " or ".join(possibles)) @@ -408,7 +410,7 @@ def get_performance(self, method="by_unit", output="pandas"): elif method == "pooled_with_average": perf = self.get_performance(method="by_unit").mean(axis=0) - if output == "dict" and isinstance(perf, pd.Series): + if output == "dict" and isinstance(perf, (pd.DataFrame, pd.Series)): perf = perf.to_dict() return perf diff --git a/src/spikeinterface/comparison/studytools.py b/src/spikeinterface/comparison/studytools.py index 08f3613bc2..79227c865f 100644 --- a/src/spikeinterface/comparison/studytools.py +++ b/src/spikeinterface/comparison/studytools.py @@ -53,7 +53,7 @@ def setup_comparison_study(study_folder, gt_dict, **job_kwargs): for rec_name, (recording, sorting_gt) in gt_dict.items(): # write recording using save with binary folder = study_folder / "ground_truth" / rec_name - sorting_gt.save(folder=folder, format="npz") + sorting_gt.save(folder=folder, format="numpy_folder") folder = study_folder / "raw_files" / rec_name recording.save(folder=folder, format="binary", **job_kwargs) diff --git a/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py b/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py index a2f043b9e7..931c989cef 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py @@ -1,6 +1,8 @@ import numpy as np from numpy.testing import assert_array_equal +import pandas as pd + from spikeinterface.extractors import NumpySorting, toy_example from spikeinterface.comparison import compare_sorter_to_ground_truth @@ -55,8 +57,10 @@ def test_compare_sorter_to_ground_truth(): "pooled_with_average", ] for method in methods: - perf = sc.get_performance(method=method) - # ~ print(perf) + perf_df = sc.get_performance(method=method, output="pandas") + assert isinstance(perf_df, (pd.Series, pd.DataFrame)) + perf_dict = sc.get_performance(method=method, output="dict") + assert isinstance(perf_dict, dict) for method in methods: sc.print_performance(method=method) diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 6af2698211..70f8a63c8c 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -1,3 +1,4 @@ +import importlib import shutil import pytest from pathlib import Path @@ -6,6 +7,13 @@ from spikeinterface.sorters import installed_sorters 
from spikeinterface.comparison import GroundTruthStudy +try: + import tridesclous + + HAVE_TDC = True +except ImportError: + HAVE_TDC = False + if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "comparison" @@ -33,6 +41,7 @@ def _setup_comparison_study(): study = GroundTruthStudy.create(study_folder, gt_dict) +@pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") def test_run_study_sorters(): study = GroundTruthStudy(study_folder) sorter_list = [ @@ -45,6 +54,7 @@ def test_run_study_sorters(): study.run_sorters(sorter_list) +@pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") def test_extract_sortings(): study = GroundTruthStudy(study_folder) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index e379777a44..d44890f844 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -8,10 +8,10 @@ # main extractor from dump and cache from .binaryrecordingextractor import BinaryRecordingExtractor, read_binary from .npzsortingextractor import NpzSortingExtractor, read_npz_sorting -from .numpyextractors import NumpyRecording, NumpySorting, NumpyEvent, NumpySnippets +from .numpyextractors import NumpyRecording, NumpySorting, SharedMemorySorting, NumpyEvent, NumpySnippets from .zarrrecordingextractor import ZarrRecordingExtractor, read_zarr, get_default_zarr_compressor from .binaryfolder import BinaryFolderRecording, read_binary_folder -from .npzfolder import NpzFolderSorting, read_npz_folder +from .sortingfolder import NumpyFolderSorting, NpzFolderSorting, read_numpy_sorting_folder, read_npz_folder from .npysnippetsextractor import NpySnippetsExtractor, read_npy_snippets from .npyfoldersnippets import NpyFolderSnippets, read_npy_snippets_folder diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 3925f41d2b..87c0805630 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -322,7 +322,7 @@ def to_dict( include_properties: bool If True, all properties are added to the dict, by default False relative_to: str, Path, or None - If not None, file_paths are serialized relative to this path, by default None + If not None, files and folders are serialized relative to this path, by default None Used in waveform extractor to maintain relative paths to binary files even if the containing folder / diretory is moved folder_metadata: str, Path, or None @@ -338,6 +338,9 @@ def to_dict( kwargs = self._kwargs + if relative_to and not recursive: + raise ValueError("`relative_to` is only posible when `recursive=True`") + if recursive: to_dict_kwargs = dict( include_annotations=include_annotations, @@ -394,13 +397,13 @@ def to_dict( dump_dict["properties"] = {k: self._properties.get(k, None) for k in self._main_properties} if relative_to is not None: - relative_to = Path(relative_to).absolute() + relative_to = Path(relative_to).resolve().absolute() assert relative_to.is_dir(), "'relative_to' must be an existing directory" dump_dict = _make_paths_relative(dump_dict, relative_to) if folder_metadata is not None: if relative_to is not None: - folder_metadata = Path(folder_metadata).absolute().relative_to(relative_to) + folder_metadata = Path(folder_metadata).resolve().absolute().relative_to(relative_to) dump_dict["folder_metadata"] = str(folder_metadata) return dump_dict @@ -547,8 +550,9 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No ---------- 
file_path: str or Path The output file (either .json or .pkl/.pickle) - relative_to: str, Path, or None - If not None, file_paths are serialized relative to this path + relative_to: str, Path, True or None + If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder. + This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. """ if str(file_path).endswith(".json"): self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata) @@ -560,18 +564,31 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=None, folder_metadata=None) -> None: """ Dump recording extractor to json file. - The extractor can be re-loaded with load_extractor_from_json(json_file) + The extractor can be re-loaded with load_extractor(json_file) Parameters ---------- file_path: str Path of the json file - relative_to: str, Path, or None - If not None, file_paths are serialized relative to this path + relative_to: str, Path, True or None + If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder. + This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. + folder_metadata: str, Path, or None + Folder with files containing additional information (e.g. probe in BaseRecording) and properties. """ - assert self.check_if_dumpable() + assert self.check_if_json_serializable(), "The extractor is not json serializable" + + # Writing paths as relative_to requires recursively expanding the dict + if relative_to: + relative_to = Path(file_path).parent if relative_to is True else Path(relative_to) + relative_to = relative_to.resolve().absolute() + dump_dict = self.to_dict( - include_annotations=True, include_properties=False, relative_to=relative_to, folder_metadata=folder_metadata + include_annotations=True, + include_properties=False, + relative_to=relative_to, + folder_metadata=folder_metadata, + recursive=True, ) file_path = self._get_file_path(file_path, [".json"]) @@ -584,13 +601,11 @@ def dump_to_pickle( self, file_path: Union[str, Path, None] = None, include_properties: bool = True, - relative_to=None, folder_metadata=None, - recursive: bool = False, ): """ Dump recording extractor to a pickle file. - The extractor can be re-loaded with load_extractor_from_json(json_file) + The extractor can be re-loaded with load_extractor(pickle_file) Parameters ---------- @@ -598,41 +613,43 @@ def dump_to_pickle( Path of the json file include_properties: bool If True, all properties are dumped - relative_to: str, Path, or None - If not None, file_paths are serialized relative to this path - recursive: bool - If True, all dicitionaries in the kwargs are expanded with `to_dict` as well, by default False. + folder_metadata: str, Path, or None + Folder with files containing additional information (e.g. probe in BaseRecording) and properties. 
""" - assert self.check_if_dumpable() + assert self.check_if_dumpable(), "The extractor is not dumpable" + dump_dict = self.to_dict( include_annotations=True, include_properties=include_properties, - relative_to=relative_to, folder_metadata=folder_metadata, - recursive=recursive, + recursive=False, ) file_path = self._get_file_path(file_path, [".pkl", ".pickle"]) file_path.write_bytes(pickle.dumps(dump_dict)) @staticmethod - def load(file_path: Union[str, Path], base_folder=None) -> "BaseExtractor": + def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, bool]] = None) -> "BaseExtractor": """ Load extractor from file path (.json or .pkl) Used both after: * dump(...) json or pickle file * save (...) a folder which contain data + json (or pickle) + metadata. + """ file_path = Path(file_path) + if base_folder is True: + base_folder = file_path.parent + if file_path.is_file(): # standard case based on a file (json or pickle) if str(file_path).endswith(".json"): - with open(str(file_path), "r") as f: + with open(file_path, "r") as f: d = json.load(f) elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): - with open(str(file_path), "rb") as f: + with open(file_path, "rb") as f: d = pickle.load(f) else: raise ValueError(f"Impossible to load {file_path}") @@ -705,7 +722,7 @@ def save(self, **kwargs) -> "BaseExtractor": Parameters ---------- kwargs: Keyword arguments for saving. - * format: "memory", "zarr", or "binary" (for recording) / "memory" or "npz" for sorting. + * format: "memory", "zarr", or "binary" (for recording) / "memory" or "numpy_folder" or "npz_folder" for sorting. In case format is not memory, the recording is saved to a folder. See format specific functions for more info (`save_to_memory()`, `save_to_folder()`, `save_to_zarr()`) * folder: if provided, the folder path where the object is saved @@ -921,7 +938,7 @@ def save_to_zarr( def _make_paths_relative(d, relative) -> dict: - relative = str(Path(relative).absolute()) + relative = str(Path(relative).resolve().absolute()) func = lambda p: os.path.relpath(str(p), start=relative) return recursive_path_modifier(d, func, target="path", copy=True) @@ -1030,7 +1047,11 @@ def load_extractor(file_or_folder_or_dict, base_folder=None) -> BaseExtractor: Parameters ---------- - file_or_folder_or_dict: dictionary or folder or file (json, pickle) + file_or_folder_or_dict : dictionary or folder or file (json, pickle) + The file path, folder path, or dictionary to load the extractor from + base_folder : str | Path | bool (optional) + The base folder to make relative paths absolute. + If True and file_or_folder_or_dict is a file, the parent folder of the file is used. 
Returns ------- @@ -1038,6 +1059,7 @@ def load_extractor(file_or_folder_or_dict, base_folder=None) -> BaseExtractor: The loaded extractor object """ if isinstance(file_or_folder_or_dict, dict): + assert not isinstance(base_folder, bool), "`base_folder` must be a string or Path when loading from dict" return BaseExtractor.from_dict(file_or_folder_or_dict, base_folder=base_folder) else: return BaseExtractor.load(file_or_folder_or_dict, base_folder=base_folder) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 8c24e4e624..e7166def75 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -445,6 +445,8 @@ def _save(self, format="binary", **save_kwargs): from .binaryrecordingextractor import BinaryRecordingExtractor + # This is created so it can be saved as json because the `BinaryFolderRecording` requires it loading + # See the __init__ of `BinaryFolderRecording` binary_rec = BinaryRecordingExtractor( file_paths=file_paths, sampling_frequency=self.get_sampling_frequency(), diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 259d3edc17..affde8a75e 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -58,7 +58,7 @@ def has_probe(self): return "contact_vector" in self.get_property_keys() def has_channel_location(self): - return self.has_probe() or "channel_location" in self.get_property_keys() + return self.has_probe() or "location" in self.get_property_keys() def is_filtered(self): # the is_filtered is handle with annotation diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 8ecddb39ca..56f46f0a38 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -7,6 +7,9 @@ from .waveform_tools import has_exceeding_spikes +minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] + + class BaseSorting(BaseExtractor): """ Abstract class representing several segment several units and relative spiketrains. 
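A minimal usage sketch of the relative-path serialization options introduced above, assuming an existing recording object `rec` and a writable folder (both placeholder names, not part of this PR):

from pathlib import Path
from spikeinterface.core import load_extractor

folder = Path("my_cache_folder")
folder.mkdir(exist_ok=True)

# relative_to=True serializes kwargs paths relative to the json file's parent folder
rec.dump_to_json(folder / "recording.json", relative_to=True)

# base_folder=True resolves those relative paths against the json file's parent folder
rec_loaded = load_extractor(folder / "recording.json", base_folder=True)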
@@ -20,6 +23,10 @@ def __init__(self, sampling_frequency: float, unit_ids: List): self._recording = None self._sorting_info = None + # caching + self._cached_spike_vector = None + self._cached_spike_trains = {} + def __repr__(self): clsname = self.__class__.__name__ nseg = self.get_num_segments() @@ -106,17 +113,36 @@ def get_unit_spike_train( start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, return_times: bool = False, + use_cache: bool = True, ): segment_index = self._check_segment_index(segment_index) - segment = self._sorting_segments[segment_index] - spike_frames = segment.get_unit_spike_train( - unit_id=unit_id, start_frame=start_frame, end_frame=end_frame - ).astype("int64") + if use_cache: + if segment_index not in self._cached_spike_trains: + self._cached_spike_trains[segment_index] = {} + if unit_id not in self._cached_spike_trains[segment_index]: + segment = self._sorting_segments[segment_index] + spike_frames = segment.get_unit_spike_train(unit_id=unit_id, start_frame=None, end_frame=None).astype( + "int64" + ) + self._cached_spike_trains[segment_index][unit_id] = spike_frames + else: + spike_frames = self._cached_spike_trains[segment_index][unit_id] + if start_frame is not None: + spike_frames = spike_frames[spike_frames >= start_frame] + if end_frame is not None: + spike_frames = spike_frames[spike_frames < end_frame] + else: + segment = self._sorting_segments[segment_index] + spike_frames = segment.get_unit_spike_train( + unit_id=unit_id, start_frame=start_frame, end_frame=end_frame + ).astype("int64") + if return_times: if self.has_recording(): times = self.get_times(segment_index=segment_index) return times[spike_frames] else: + segment = self._sorting_segments[segment_index] t_start = segment._t_start if segment._t_start is not None else 0 spike_times = spike_frames / self.get_sampling_frequency() return t_start + spike_times @@ -190,31 +216,41 @@ def get_times(self, segment_index=None): else: return None - def _save(self, format="npz", **save_kwargs): + def _save(self, format="numpy_folder", **save_kwargs): """ This function replaces the old CachesortingExtractor, but enables more engines - for caching a results. At the moment only 'npz' is supported. + for caching results. + + Since v0.98.0 'numpy_folder' is used by default. + From v0.96.0 to 0.97.0 'npz_folder' was the default. + """ - if format == "npz": + if format == "numpy_folder": + from .sortingfolder import NumpyFolderSorting + folder = save_kwargs.pop("folder") - # TODO save properties/features as npz!!!!!
- from .npzsortingextractor import NpzSortingExtractor + NumpyFolderSorting.write_sorting(self, folder) + cached = NumpyFolderSorting(folder) - save_path = folder / "sorting_cached.npz" - NpzSortingExtractor.write_sorting(self, save_path) - cached = NpzSortingExtractor(save_path) - cached.dump(folder / "npz.json", relative_to=folder) + if self.has_recording(): + warnings.warn("The registered recording will not be persistent on disk, but only available in memory") + cached.register_recording(self._recording) - from .npzfolder import NpzFolderSorting + elif format == "npz_folder": + from .sortingfolder import NpzFolderSorting + folder = save_kwargs.pop("folder") + NpzFolderSorting.write_sorting(self, folder) cached = NpzFolderSorting(folder_path=folder) + if self.has_recording(): warnings.warn("The registered recording will not be persistent on disk, but only available in memory") cached.register_recording(self._recording) + elif format == "memory": from .numpyextractors import NumpySorting - cached = NumpySorting.from_extractor(self) + cached = NumpySorting.from_sorting(self) else: raise ValueError(f"format {format} not supported") return cached @@ -225,8 +261,16 @@ def get_unit_property(self, unit_id, key): return v def get_total_num_spikes(self): + warnings.warn( + "Sorting.get_total_num_spikes() is deprecated, use sorting.count_num_spikes_per_unit()", + DeprecationWarning, + stacklevel=2, + ) + return self.count_num_spikes_per_unit() + + def count_num_spikes_per_unit(self): """ - Get total number of spikes for each unit across segments. + For each unit: get the number of spikes across segments. Returns ------- @@ -242,6 +286,17 @@ def get_total_num_spikes(self): num_spikes[unit_id] = n return num_spikes + def count_total_num_spikes(self): + """ + Get the total number of spikes summed across segments and units. + + Returns + ------- + total_num_spikes: int + The total number of spikes + """ + return self.to_spike_vector().size + def select_units(self, unit_ids, renamed_unit_ids=None): """ Selects a subset of units @@ -319,8 +374,17 @@ def frame_slice(self, start_frame, end_frame, check_spike_frames=True): def get_all_spike_trains(self, outputs="unit_id"): """ - Return all spike trains concatenated + Return all spike trains concatenated. + + This is deprecated; use sorting.to_spike_vector() instead """ + + warnings.warn( + "Sorting.get_all_spike_trains() will be deprecated. Use Sorting.to_spike_vector() instead", + DeprecationWarning, + stacklevel=2, + ) + assert outputs in ("unit_id", "unit_index") spikes = [] for segment_index in range(self.get_num_segments()): @@ -347,7 +411,7 @@ def get_all_spike_trains(self, outputs="unit_id"): spikes.append((spike_times, spike_labels)) return spikes - def to_spike_vector(self, extremum_channel_inds=None): + def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cache=True): """ Construct a unique structured numpy vector concatenating all spikes with several fields: sample_index, unit_index, segment_index. @@ -356,11 +420,17 @@ Parameters ---------- + concatenated: bool + With concatenated=True (default) the output is one numpy "spike vector" with spikes from all segments. + With concatenated=False the output is a list of "spike vectors", one per segment. extremum_channel_inds: None or dict If a dictionnary of unit_id to channel_ind is given then an extra field 'channel_index'. This can be convinient for computing spikes postion after sorter.
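A minimal sketch of the new spike-counting helpers and the per-segment spike vector added above, assuming an existing sorting object `sorting` (placeholder name):

num_spikes_per_unit = sorting.count_num_spikes_per_unit()  # dict: unit_id -> number of spikes
total_num_spikes = sorting.count_total_num_spikes()  # int, summed over units and segments

# one structured array per segment, replacing the deprecated get_all_spike_trains()
spike_vectors = sorting.to_spike_vector(concatenated=False)
seg0_sample_indices = spike_vectors[0]["sample_index"]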
This dict can be computed with `get_template_extremum_channel(we, outputs="index")` + use_cache: bool + When True (default) the spikes vector is cached as an attribute of the object (`_cached_spike_vector`). + This caching only occurs when extremum_channel_inds=None. Returns ------- @@ -370,31 +440,131 @@ def to_spike_vector(self, extremum_channel_inds=None): is given """ - spikes_ = self.get_all_spike_trains(outputs="unit_index") - - n = np.sum([e[0].size for e in spikes_]) - spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] + spike_dtype = minimum_spike_dtype if extremum_channel_inds is not None: - spike_dtype += [("channel_index", "int64")] - - spikes = np.zeros(n, dtype=spike_dtype) + spike_dtype = spike_dtype + [("channel_index", "int64")] + ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in self.unit_ids]) - pos = 0 - for segment_index, (spike_times, spike_labels) in enumerate(spikes_): - n = spike_times.size - spikes[pos : pos + n]["sample_index"] = spike_times - spikes[pos : pos + n]["unit_index"] = spike_labels - spikes[pos : pos + n]["segment_index"] = segment_index - pos += n + if use_cache and self._cached_spike_vector is not None: + # the cache already exists + if extremum_channel_inds is None: + spikes = self._cached_spike_vector + else: + spikes = np.zeros(self._cached_spike_vector.size, dtype=spike_dtype) + spikes["sample_index"] = self._cached_spike_vector["sample_index"] + spikes["unit_index"] = self._cached_spike_vector["unit_index"] + spikes["segment_index"] = self._cached_spike_vector["segment_index"] + if extremum_channel_inds is not None: + spikes["channel_index"] = ext_channel_inds[spikes["unit_index"]] + + if not concatenated: + spikes_ = [] + for segment_index in range(self.get_num_segments()): + s0 = np.searchsorted(spikes["segment_index"], segment_index, side="left") + s1 = np.searchsorted(spikes["segment_index"], segment_index + 1, side="left") + spikes_.append(spikes[s0:s1]) + spikes = spikes_ - if extremum_channel_inds is not None: - ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in self.unit_ids]) - # vector way - spikes["channel_index"] = ext_channel_inds[spikes["unit_index"]] + else: + # the cache not needed or do not exists yet + spikes = [] + for segment_index in range(self.get_num_segments()): + sample_indices = [] + unit_indices = [] + for u, unit_id in enumerate(self.unit_ids): + spike_times = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + sample_indices.append(spike_times) + unit_indices.append(np.full(spike_times.size, u, dtype="int64")) + + if len(sample_indices) > 0: + sample_indices = np.concatenate(sample_indices, dtype="int64") + unit_indices = np.concatenate(unit_indices, dtype="int64") + order = np.argsort(sample_indices) + sample_indices = sample_indices[order] + unit_indices = unit_indices[order] + + spikes_in_seg = np.zeros(len(sample_indices), dtype=spike_dtype) + spikes_in_seg["sample_index"] = sample_indices + spikes_in_seg["unit_index"] = unit_indices + spikes_in_seg["segment_index"] = segment_index + if extremum_channel_inds is not None: + # vector way + spikes_in_seg["channel_index"] = ext_channel_inds[spikes_in_seg["unit_index"]] + spikes.append(spikes_in_seg) + + if concatenated: + spikes = np.concatenate(spikes) + + if use_cache and self._cached_spike_vector is None and extremum_channel_inds is None: + # cache it if necessary but only without "channel_index" + if concatenated: + self._cached_spike_vector = 
spikes + else: + self._cached_spike_vector = np.concatenate(spikes) return spikes + def to_numpy_sorting(self, propagate_cache=True): + """ + Turn any sorting into a NumpySorting. + Useful to have it in memory with a unique vector representation. + + Parameters + ---------- + propagate_cache : bool + Propagate the cache of individual spike trains. + + """ + from .numpyextractors import NumpySorting + + sorting = NumpySorting.from_sorting(self) + if propagate_cache and self._cached_spike_trains is not None: + sorting._cached_spike_trains = self._cached_spike_trains + return sorting + + def to_shared_memory_sorting(self): + """ + Turn any sorting into a SharedMemorySorting. + Useful to have it in memory with a unique vector representation and sharable across processes. + """ + from .numpyextractors import SharedMemorySorting + + sorting = SharedMemorySorting.from_sorting(self) + return sorting + + def to_multiprocessing(self, n_jobs): + """ + When necessary, turn the sorting object into: + * NumpySorting when n_jobs=1 + * SharedMemorySorting when n_jobs>1 + + If the sorting is already a NumpySorting, SharedMemorySorting or NumpyFolderSorting + then this returns the sorting itself with no transformation. + + Parameters + ---------- + n_jobs: int + The number of jobs. + Returns + ------- + sharable_sorting: + A sorting that can be used for multiprocessing. + """ + from .numpyextractors import NumpySorting, SharedMemorySorting + from .sortingfolder import NumpyFolderSorting + + if n_jobs == 1: + if isinstance(self, (NumpySorting, SharedMemorySorting, NumpyFolderSorting)): + return self + else: + return NumpySorting.from_sorting(self) + else: + if isinstance(self, (SharedMemorySorting, NumpyFolderSorting)): + return self + else: + return SharedMemorySorting.from_sorting(self) + class BaseSortingSegment(BaseSegment): """ diff --git a/src/spikeinterface/core/binaryfolder.py b/src/spikeinterface/core/binaryfolder.py index d9a4ce0963..d185111b8c 100644 --- a/src/spikeinterface/core/binaryfolder.py +++ b/src/spikeinterface/core/binaryfolder.py @@ -47,7 +47,7 @@ def __init__(self, folder_path): folder_metadata = folder_path self.load_metadata_from_folder(folder_metadata) - self._kwargs = dict(folder_path=str(folder_path.absolute())) + self._kwargs = dict(folder_path=str(Path(folder_path).absolute())) self._bin_kwargs = d["kwargs"] if "num_channels" not in self._bin_kwargs: assert "num_chan" in self._bin_kwargs, "Cannot find num_channels or num_chan in binary.json" diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index b5c1d2c888..72a95637f6 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -116,7 +116,7 @@ def __init__( self.set_channel_offsets(offset_to_uV) self._kwargs = { - "file_paths": [str(e.absolute()) for e in file_path_list], + "file_paths": [str(Path(e).absolute()) for e in file_path_list], "sampling_frequency": sampling_frequency, "t_starts": t_starts, "num_channels": num_channels, diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index f77982fd1e..123e2f0bdf 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -107,7 +107,7 @@ def generate_sorting( else: units_dict[unit_id] = np.array([], dtype=int) units_dict_list.append(units_dict) - sorting = NumpySorting.from_dict(units_dict_list, sampling_frequency) + sorting = NumpySorting.from_unit_dict(units_dict_list, sampling_frequency)
return sorting @@ -319,7 +319,7 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No d[unit_id] = times spiketrains.append(d) - sorting_with_dup = NumpySorting.from_dict(spiketrains, sampling_frequency=sorting.get_sampling_frequency()) + sorting_with_dup = NumpySorting.from_unit_dict(spiketrains, sampling_frequency=sorting.get_sampling_frequency()) return sorting_with_dup @@ -357,7 +357,7 @@ def inject_some_split_units(sorting, split_ids=[], num_split=2, output_ids=False new_units[unit_id] = original_times spiketrains.append(new_units) - sorting_with_split = NumpySorting.from_dict(spiketrains, sampling_frequency=sorting.get_sampling_frequency()) + sorting_with_split = NumpySorting.from_unit_dict(spiketrains, sampling_frequency=sorting.get_sampling_frequency()) if output_ids: return sorting_with_split, other_ids else: diff --git a/src/spikeinterface/core/npyfoldersnippets.py b/src/spikeinterface/core/npyfoldersnippets.py index b7c773aad3..c002bbe044 100644 --- a/src/spikeinterface/core/npyfoldersnippets.py +++ b/src/spikeinterface/core/npyfoldersnippets.py @@ -48,7 +48,7 @@ def __init__(self, folder_path): folder_metadata = folder_path self.load_metadata_from_folder(folder_metadata) - self._kwargs = dict(folder_path=str(folder_path.absolute())) + self._kwargs = dict(folder_path=str(Path(folder_path).absolute())) self._bin_kwargs = d["kwargs"] diff --git a/src/spikeinterface/core/npysnippetsextractor.py b/src/spikeinterface/core/npysnippetsextractor.py index f534592624..80979ce6c9 100644 --- a/src/spikeinterface/core/npysnippetsextractor.py +++ b/src/spikeinterface/core/npysnippetsextractor.py @@ -47,7 +47,7 @@ def __init__( self.set_channel_offsets(offset_to_uV) self._kwargs = { - "file_paths": [str(f) for f in file_paths], + "file_paths": [str(Path(f).absolute()) for f in file_paths], "sampling_frequency": sampling_frequency, "channel_ids": channel_ids, "nbefore": nbefore, diff --git a/src/spikeinterface/core/npzfolder.py b/src/spikeinterface/core/npzfolder.py deleted file mode 100644 index 9d2eb43af6..0000000000 --- a/src/spikeinterface/core/npzfolder.py +++ /dev/null @@ -1,54 +0,0 @@ -from pathlib import Path -import json - -import numpy as np - -from .base import _make_paths_absolute -from .npzsortingextractor import NpzSortingExtractor -from .core_tools import define_function_from_class - - -class NpzFolderSorting(NpzSortingExtractor): - """ - NpzFolderSorting is an internal format used in spikeinterface. - It is a NpzSortingExtractor + metadata contained in a folder. 
- - It is created with the function: `sorting.save(folder='/myfolder')` - - Parameters - ---------- - folder_path: str or Path - - Returns - ------- - sorting: NpzFolderSorting - The sorting - """ - - extractor_name = "NpzFolder" - mode = "folder" - name = "npzfolder" - - def __init__(self, folder_path): - folder_path = Path(folder_path) - - with open(folder_path / "npz.json", "r") as f: - d = json.load(f) - - if not d["class"].endswith(".NpzSortingExtractor"): - raise ValueError("This folder is not an npz spikeinterface folder") - - assert d["relative_paths"] - - d = _make_paths_absolute(d, folder_path) - - NpzSortingExtractor.__init__(self, **d["kwargs"]) - - folder_metadata = folder_path - self.load_metadata_from_folder(folder_metadata) - - self._kwargs = dict(folder_path=str(folder_path.absolute())) - self._npz_kwargs = d["kwargs"] - - -read_npz_folder = define_function_from_class(source_class=NpzFolderSorting, name="read_npz_folder") diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 8e54258a14..97f22615df 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -9,6 +9,11 @@ BaseSnippets, BaseSnippetsSegment, ) +from .basesorting import minimum_spike_dtype +from .core_tools import make_shared_array + +from multiprocessing.shared_memory import SharedMemory + + from typing import List, Union @@ -94,37 +99,66 @@ def get_traces(self, start_frame, end_frame, channel_indices): class NumpySorting(BaseSorting): + """ + In-memory sorting object. + The internal representation is always done with a long "spike vector". + + + But we have convenient class methods to instantiate from: + * other sorting object: `NumpySorting.from_sorting()` + * from time+labels: `NumpySorting.from_times_labels()` + * from a list of dicts: `NumpySorting.from_unit_dict()` + * from neo: `NumpySorting.from_neo_spiketrain_list()` + + Parameters + ---------- + spikes: numpy.array + A numpy vector, the one given by Sorting.to_spike_vector(). + sampling_frequency: float + The sampling frequency in Hz + unit_ids: list + A list of unit_ids.
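A minimal construction sketch for the class methods listed above; the arrays and unit labels below are made up for illustration:

import numpy as np
from spikeinterface.core import NumpySorting

sampling_frequency = 30000.0

# from spike times (frames) + labels, one array per segment
times = np.array([10, 55, 120, 300], dtype="int64")
labels = np.array(["a", "b", "a", "b"])
sorting = NumpySorting.from_times_labels([times], [labels], sampling_frequency)

# from a dict of unit_id -> spike frames, one dict per segment
units_dict = {"a": np.array([10, 120], dtype="int64"), "b": np.array([55, 300], dtype="int64")}
sorting2 = NumpySorting.from_unit_dict([units_dict], sampling_frequency)

# any other sorting can be converted through its spike vector
sorting3 = NumpySorting.from_sorting(sorting2)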
+ """ + name = "numpy" - def __init__(self, sampling_frequency, unit_ids=[]): + def __init__(self, spikes, sampling_frequency, unit_ids): + """ """ BaseSorting.__init__(self, sampling_frequency, unit_ids) - self._is_dumpable = False - self._is_json_serializable = False - @staticmethod - def from_extractor(source_sorting: BaseSorting) -> "NumpySorting": - """ - Create a numpy sorting from another extractor - """ - unit_ids = source_sorting.get_unit_ids() - nseg = source_sorting.get_num_segments() + self._is_dumpable = True + self._is_json_serializable = False - sorting = NumpySorting(source_sorting.get_sampling_frequency(), unit_ids) + if spikes.size == 0: + nseg = 1 + else: + nseg = spikes[-1]["segment_index"] + 1 for segment_index in range(nseg): - units_dict = {} - for unit_id in unit_ids: - units_dict[unit_id] = source_sorting.get_unit_spike_train(unit_id, segment_index) - sorting.add_sorting_segment(NumpySortingSegment(units_dict)) + self.add_sorting_segment(NumpySortingSegment(spikes, segment_index, unit_ids)) - sorting.copy_metadata(source_sorting) + # important trick : the cache is already spikes vector + self._cached_spike_vector = spikes + self._kwargs = dict(spikes=spikes, sampling_frequency=sampling_frequency, unit_ids=unit_ids) + + @staticmethod + def from_sorting(source_sorting: BaseSorting, with_metadata=False) -> "NumpySorting": + """ + Create a numpy sorting from another sorting extractor + """ + + sorting = NumpySorting( + source_sorting.to_spike_vector(), source_sorting.get_sampling_frequency(), source_sorting.unit_ids + ) + if with_metadata: + sorting.copy_metadata(source_sorting) return sorting @staticmethod def from_times_labels(times_list, labels_list, sampling_frequency, unit_ids=None) -> "NumpySorting": """ - Construct sorting extractor from: + Construct NumpySorting extractor from: * an array of spike times (in frames) * an array of spike labels and adds all the In case of multisegment, it is a list of array. @@ -148,25 +182,34 @@ def from_times_labels(times_list, labels_list, sampling_frequency, unit_ids=None labels_list = [np.asarray(e) for e in labels_list] nseg = len(times_list) + if unit_ids is None: unit_ids = np.unique(np.concatenate([np.unique(labels_list[i]) for i in range(nseg)])) - sorting = NumpySorting(sampling_frequency, unit_ids) + spikes = [] for i in range(nseg): - units_dict = {} times, labels = times_list[i], labels_list[i] - for unit_id in unit_ids: - mask = labels == unit_id - units_dict[unit_id] = times[mask] - sorting.add_sorting_segment(NumpySortingSegment(units_dict)) + unit_index = np.zeros(labels.size, dtype="int64") + for u, unit_id in enumerate(unit_ids): + unit_index[labels == unit_id] = u + spikes_in_seg = np.zeros(len(times), dtype=minimum_spike_dtype) + spikes_in_seg["sample_index"] = times + spikes_in_seg["unit_index"] = unit_index + spikes_in_seg["segment_index"] = i + order = np.argsort(times) + spikes_in_seg = spikes_in_seg[order] + spikes.append(spikes_in_seg) + spikes = np.concatenate(spikes) + + sorting = NumpySorting(spikes, sampling_frequency, unit_ids) return sorting @staticmethod - def from_dict(units_dict_list, sampling_frequency) -> "NumpySorting": + def from_unit_dict(units_dict_list, sampling_frequency) -> "NumpySorting": """ - Construct sorting extractor from a list of dict. - The list length is the segment count + Construct NumpySorting from a list of dict. + The list length is the segment count. Each dict have unit_ids as keys and spike times as values. 
Parameters @@ -178,16 +221,44 @@ def from_dict(units_dict_list, sampling_frequency) -> "NumpySorting": unit_ids = list(units_dict_list[0].keys()) - sorting = NumpySorting(sampling_frequency, unit_ids) - for i, units_dict in enumerate(units_dict_list): - sorting.add_sorting_segment(NumpySortingSegment(units_dict)) + nseg = len(units_dict_list) + spikes = [] + for seg_index in range(nseg): + units_dict = units_dict_list[seg_index] + + sample_indices = [] + unit_indices = [] + for u, unit_id in enumerate(unit_ids): + spike_times = units_dict[unit_id] + sample_indices.append(spike_times) + + unit_indices.append(np.full(spike_times.size, u, dtype="int64")) + if len(sample_indices) > 0: + sample_indices = np.concatenate(sample_indices) + unit_indices = np.concatenate(unit_indices) + + order = np.argsort(sample_indices) + sample_indices = sample_indices[order] + unit_indices = unit_indices[order] + + spikes_in_seg = np.zeros(len(sample_indices), dtype=minimum_spike_dtype) + spikes_in_seg["sample_index"] = sample_indices + spikes_in_seg["unit_index"] = unit_indices + spikes_in_seg["segment_index"] = seg_index + spikes.append(spikes_in_seg) + spikes = np.concatenate(spikes) + + sorting = NumpySorting(spikes, sampling_frequency, unit_ids) + + # Trick : populate the cache with dict that already exists + sorting._cached_spike_trains = {seg_ind: d for seg_ind, d in enumerate(units_dict_list)} return sorting @staticmethod def from_neo_spiketrain_list(neo_spiketrains, sampling_frequency, unit_ids=None) -> "NumpySorting": """ - Construct a sorting with a neo spiketrain list. + Construct a NumpySorting with a neo spiketrain list. If this is a list of list, it is multi segment. @@ -211,18 +282,20 @@ def from_neo_spiketrain_list(neo_spiketrains, sampling_frequency, unit_ids=None) if unit_ids is None: unit_ids = np.arange(len(neo_spiketrains[0]), dtype="int64") - sorting = NumpySorting(sampling_frequency, unit_ids) + units_dict_list = [] for seg_index in range(nseg): units_dict = {} for u, unit_id in enumerate(unit_ids): st = neo_spiketrains[seg_index][u] units_dict[unit_id] = (st.rescale("s").magnitude * sampling_frequency).astype("int64") - sorting.add_sorting_segment(NumpySortingSegment(units_dict)) + units_dict_list.append(units_dict) + + sorting = NumpySorting.from_unit_dict(units_dict_list, sampling_frequency) return sorting @staticmethod - def from_peaks(peaks, sampling_frequency) -> "NumpySorting": + def from_peaks(peaks, sampling_frequency, unit_ids=None) -> "NumpySorting": """ Construct a sorting from peaks returned by 'detect_peaks()' function. 
The unit ids correspond to the recording channel ids and spike trains are the @@ -240,19 +313,38 @@ def from_peaks(peaks, sampling_frequency) -> "NumpySorting": sorting The NumpySorting object """ - return NumpySorting.from_times_labels(peaks["sample_index"], peaks["channel_index"], sampling_frequency) + spikes = np.zeros(peaks.size, dtype=minimum_spike_dtype) + spikes["sample_index"] = peaks["sample_index"] + spikes["unit_index"] = peaks["channel_index"] + spikes["segment_index"] = peaks["segment_index"] + + if unit_ids is None: + unit_ids = np.unique(peaks["channel_index"]) + + sorting = NumpySorting(spikes, sampling_frequency, unit_ids) + + return sorting class NumpySortingSegment(BaseSortingSegment): - def __init__(self, units_dict): + def __init__(self, spikes, segment_index, unit_ids): BaseSortingSegment.__init__(self) - for unit_id, times in units_dict.items(): - assert times.dtype.kind == "i", "numpy array of spike times must be integer" - assert np.all(np.diff(times) >= 0), "unsorted times" - self._units_dict = units_dict + self.spikes = spikes + self.segment_index = segment_index + self.unit_ids = list(unit_ids) + self.spikes_in_seg = None def get_unit_spike_train(self, unit_id, start_frame, end_frame): - times = self._units_dict[unit_id] + if self.spikes_in_seg is None: + # the slicing of segment is done only once the first time + # this fasten the constructor a lot + s0 = np.searchsorted(self.spikes["segment_index"], self.segment_index, side="left") + s1 = np.searchsorted(self.spikes["segment_index"], self.segment_index + 1, side="left") + self.spikes_in_seg = self.spikes[s0:s1] + + unit_index = self.unit_ids.index(unit_id) + times = self.spikes_in_seg[self.spikes_in_seg["unit_index"] == unit_index]["sample_index"] + if start_frame is not None: times = times[times >= start_frame] if end_frame is not None: @@ -260,6 +352,61 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): return times +class SharedMemorySorting(BaseSorting): + def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_spike_dtype, main_shm_owner=True): + assert len(shape) == 1 + assert shape[0] > 0, "SharedMemorySorting only supported with no empty sorting" + + BaseSorting.__init__(self, sampling_frequency, unit_ids) + self._is_dumpable = True + self._is_json_serializable = False + + self.shm = SharedMemory(shm_name, create=False) + self.shm_spikes = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf) + + nseg = self.shm_spikes[-1]["segment_index"] + 1 + for segment_index in range(nseg): + self.add_sorting_segment(NumpySortingSegment(self.shm_spikes, segment_index, unit_ids)) + + # important trick : the cache is already spikes vector + self._cached_spike_vector = self.shm_spikes + + # this is very important for the shm.unlink() + # only the main instance need to call it + # all other instances that are loaded from dict are not the main owner + self.main_shm_owner = main_shm_owner + + self._kwargs = dict( + shm_name=shm_name, + shape=shape, + sampling_frequency=sampling_frequency, + unit_ids=unit_ids, + # this ensure that all dump/load will not be main shm owner + main_shm_owner=False, + ) + + def __del__(self): + self.shm.close() + if self.main_shm_owner: + self.shm.unlink() + + @staticmethod + def from_sorting(source_sorting): + spikes = source_sorting.to_spike_vector() + shm_spikes, shm = make_shared_array(spikes.shape, spikes.dtype) + shm_spikes[:] = spikes + sorting = SharedMemorySorting( + shm.name, + spikes.shape, + source_sorting.get_sampling_frequency(), + 
source_sorting.unit_ids, + dtype=spikes.dtype, + main_shm_owner=True, + ) + shm.close() + return sorting + + class NumpyEvent(BaseEvent): def __init__(self, channel_ids, structured_dtype): BaseEvent.__init__(self, channel_ids, structured_dtype) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 865e5cc283..e5901d7ee0 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -312,9 +312,9 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): The input recording channel_ids : list/array or None If given, a subset of channels to order locations for - dimensions : str or tuple + dimensions : str, tuple, or list If str, it needs to be 'x', 'y', 'z'. - If tuple, it sorts the locations in two dimensions using lexsort. + If tuple or list, it sorts the locations in two dimensions using lexsort. This approach is recommended since there is less ambiguity, by default ('x', 'y') Returns @@ -334,7 +334,7 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): assert dim < ndim, "Invalid dimensions!" order_f = np.argsort(locations[:, dim], kind="stable") else: - assert isinstance(dimensions, tuple), "dimensions can be a str or a tuple" + assert isinstance(dimensions, (tuple, list)), "dimensions can be str, tuple, or list" locations_to_sort = () for dim in dimensions: dim = ["x", "y", "z"].index(dim) diff --git a/src/spikeinterface/core/snippets_tools.py b/src/spikeinterface/core/snippets_tools.py index a88056b8b1..7f342ef604 100644 --- a/src/spikeinterface/core/snippets_tools.py +++ b/src/spikeinterface/core/snippets_tools.py @@ -26,7 +26,7 @@ def snippets_from_sorting(recording, sorting, nbefore=20, nafter=44, wf_folder=N Snippets extractor created """ job_kwargs = fix_job_kwargs(job_kwargs) - strains = sorting.get_all_spike_trains() + spikes = sorting.to_spike_vector(concatenated=False) peaks2 = sorting.to_spike_vector() peaks2["unit_index"] = 0 @@ -58,7 +58,7 @@ def snippets_from_sorting(recording, sorting, nbefore=20, nafter=44, wf_folder=N nse = NumpySnippets( snippets_list=wfs, - spikesframes_list=[np.sort(s[0]) for s in strains], + spikesframes_list=[s["sample_index"] for s in spikes], sampling_frequency=recording.get_sampling_frequency(), nbefore=nbefore, channel_ids=recording.get_channel_ids(), diff --git a/src/spikeinterface/core/sortingfolder.py b/src/spikeinterface/core/sortingfolder.py new file mode 100644 index 0000000000..c813c26442 --- /dev/null +++ b/src/spikeinterface/core/sortingfolder.py @@ -0,0 +1,136 @@ +from pathlib import Path +import json + +import numpy as np + +from .base import _make_paths_absolute +from .basesorting import BaseSorting, BaseSortingSegment +from .npzsortingextractor import NpzSortingExtractor +from .core_tools import define_function_from_class +from .numpyextractors import NumpySortingSegment + + +class NumpyFolderSorting(BaseSorting): + """ + NumpyFolderSorting is the new internal format used in spikeinterface (>=0.99.0) for caching sorting objects. + + It is a simple folder that contains: + * a file "spikes.npy" (numpy format) with all flattened spikes (using sorting.to_spike_vector()) + * a "numpysorting_info.json" containing sampling_frequency, unit_ids and num_segments + * a metadata folder for units properties.
+ + It is created with the function: `sorting.save(folder='/myfolder', format="numpy_folder")` + + """ + + extractor_name = "NumpyFolderSorting" + mode = "folder" + name = "NumpyFolder" + + def __init__(self, folder_path): + folder_path = Path(folder_path) + + with open(folder_path / "numpysorting_info.json", "r") as f: + info = json.load(f) + + sampling_frequency = info["sampling_frequency"] + unit_ids = np.array(info["unit_ids"]) + num_segments = info["num_segments"] + + BaseSorting.__init__(self, sampling_frequency, unit_ids) + + self.spikes = np.load(folder_path / "spikes.npy", mmap_mode="r") + + for segment_index in range(num_segments): + self.add_sorting_segment(NumpySortingSegment(self.spikes, segment_index, unit_ids)) + + # important trick : the cache is already spikes vector + self._cached_spike_vector = self.spikes + + folder_metadata = folder_path + self.load_metadata_from_folder(folder_metadata) + + self._kwargs = dict(folder_path=str(folder_path.absolute())) + + @staticmethod + def write_sorting(sorting, save_path): + # the folder can already exist but must not contain numpysorting_info.json + save_path = Path(save_path) + save_path.mkdir(parents=True, exist_ok=True) + + info_file = save_path / "numpysorting_info.json" + if info_file.exists(): + raise ValueError("NumpyFolderSorting.write_sorting the folder already contains numpysorting_info.json") + d = { + "sampling_frequency": float(sorting.get_sampling_frequency()), + "unit_ids": sorting.unit_ids.tolist(), + "num_segments": sorting.get_num_segments(), + } + info_file.write_text(json.dumps(d), encoding="utf8") + np.save(save_path / "spikes.npy", sorting.to_spike_vector()) + + +class NpzFolderSorting(NpzSortingExtractor): + """ + NpzFolderSorting is the old internal format used in spikeinterface (<=0.98.0) + + This is a folder that contains: + + * "sorting_cached.npz" file in the NpzSortingExtractor format + * "npz.json" which is the json description of NpzSortingExtractor + * a metadata folder for units properties.
+ + It is created with the function: `sorting.save(folder='/myfolder', format="npz_folder")` + + Parameters + ---------- + folder_path: str or Path + + Returns + ------- + sorting: NpzFolderSorting + The sorting + """ + + extractor_name = "NpzFolder" + mode = "folder" + name = "npzfolder" + + def __init__(self, folder_path): + folder_path = Path(folder_path) + + with open(folder_path / "npz.json", "r") as f: + d = json.load(f) + + if not d["class"].endswith(".NpzSortingExtractor"): + raise ValueError("This folder is not an npz spikeinterface folder") + + assert d["relative_paths"] + + d = _make_paths_absolute(d, folder_path) + + NpzSortingExtractor.__init__(self, **d["kwargs"]) + + folder_metadata = folder_path + self.load_metadata_from_folder(folder_metadata) + + self._kwargs = dict(folder_path=str(folder_path.absolute())) + self._npz_kwargs = d["kwargs"] + + @staticmethod + def write_sorting(sorting, save_path): + save_path = Path(save_path) + save_path.mkdir(parents=True, exist_ok=True) + + npz_file = save_path / "sorting_cached.npz" + if npz_file.exists(): + raise ValueError("NpzFolderSorting.write_sorting the folder already contains sorting_cached.npz") + NpzSortingExtractor.write_sorting(sorting, npz_file) + cached = NpzSortingExtractor(npz_file) + cached.dump(save_path / "npz.json", relative_to=save_path) + + +read_numpy_sorting_folder = define_function_from_class( + source_class=NumpyFolderSorting, name="read_numpy_sorting_folder" +) +read_npz_folder = define_function_from_class(source_class=NpzFolderSorting, name="read_npz_folder") diff --git a/src/spikeinterface/core/testing.py b/src/spikeinterface/core/testing.py index e337d0b035..1a13974b51 100644 --- a/src/spikeinterface/core/testing.py +++ b/src/spikeinterface/core/testing.py @@ -78,6 +78,7 @@ def check_sortings_equal( ) -> None: assert SX1.get_num_segments() == SX2.get_num_segments() + # TODO for later use to_spike_vector() to do this without looping for segment_idx in range(SX1.get_num_segments()): # get_unit_ids ids1 = np.sort(np.array(SX1.get_unit_ids())) diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index ed9a79d055..38987a58e5 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -2,6 +2,7 @@ test for BaseRecording are done with BinaryRecordingExtractor. but check only for BaseRecording general methods. 
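A minimal sketch of caching a sorting with the new folder format and reloading it, assuming an existing sorting object `sorting` and a writable path (placeholder names):

from spikeinterface.core import load_extractor

# writes spikes.npy, numpysorting_info.json and the metadata folder, and returns a NumpyFolderSorting
cached = sorting.save(folder="my_cached_sorting", format="numpy_folder")

# the folder can be reloaded later
sorting_loaded = load_extractor("my_cached_sorting")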
""" +import json import shutil from pathlib import Path import pytest @@ -106,7 +107,7 @@ def test_BaseRecording(): check_recordings_equal(rec, rec3, return_scaled=False, check_annotations=True, check_properties=True) # dump/load dict - relative - d = rec.to_dict(relative_to=cache_folder) + d = rec.to_dict(relative_to=cache_folder, recursive=True) rec2 = BaseExtractor.from_dict(d, base_folder=cache_folder) rec3 = load_extractor(d, base_folder=cache_folder) @@ -115,6 +116,18 @@ def test_BaseRecording(): rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder) rec3 = load_extractor(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder) + # dump/load relative=True + rec.dump_to_json(cache_folder / "test_BaseRecording_rel_true.json", relative_to=True) + rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel_true.json", base_folder=True) + rec3 = load_extractor(cache_folder / "test_BaseRecording_rel_true.json", base_folder=True) + check_recordings_equal(rec, rec2, return_scaled=False, check_annotations=True) + check_recordings_equal(rec, rec3, return_scaled=False, check_annotations=True) + with open(cache_folder / "test_BaseRecording_rel_true.json") as json_file: + data = json.load(json_file) + assert ( + "/" not in data["kwargs"]["file_paths"][0] + ) # Relative to parent folder, so there shouldn't be any '/' in the path. + # cache to binary folder = cache_folder / "simple_recording" rec.save(format="binary", folder=folder) diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index d286a0dd37..ba2c3bbfb4 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -87,7 +87,7 @@ def test_BaseSnippets(): times0 = snippets.get_frames(segment_index=0) - seg0_times = sorting.get_all_spike_trains()[0][0] + seg0_times = sorting.to_spike_vector(concatenated=False)[0]["sample_index"] assert np.array_equal(seg0_times, times0) @@ -107,7 +107,7 @@ def test_BaseSnippets(): snippets3 = load_extractor(cache_folder / "test_BaseSnippets.pkl") # dump/load dict - relative - d = snippets.to_dict(relative_to=cache_folder) + d = snippets.to_dict(relative_to=cache_folder, recursive=True) snippets2 = BaseExtractor.from_dict(d, base_folder=cache_folder) snippets3 = load_extractor(d, base_folder=cache_folder) diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 6e471121b6..0bdd9aecdd 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -13,6 +13,9 @@ NpzSortingExtractor, NumpyRecording, NumpySorting, + SharedMemorySorting, + NpzFolderSorting, + NumpyFolderSorting, create_sorting_npz, generate_sorting, load_extractor, @@ -29,6 +32,7 @@ def test_BaseSorting(): num_seg = 2 file_path = cache_folder / "test_BaseSorting.npz" + file_path.parent.mkdir(exist_ok=True) create_sorting_npz(num_seg, file_path) @@ -67,11 +71,20 @@ def test_BaseSorting(): check_sortings_equal(sorting, sorting2, check_annotations=True, check_properties=True) check_sortings_equal(sorting, sorting3, check_annotations=True, check_properties=True) - # cache - folder = cache_folder / "simple_sorting" + # cache old format : npz_folder + folder = cache_folder / "simple_sorting_npz_folder" sorting.set_property("test", np.ones(len(sorting.unit_ids))) - sorting.save(folder=folder) + sorting.save(folder=folder, format="npz_folder") 
sorting2 = BaseExtractor.load_from_folder(folder) + assert isinstance(sorting2, NpzFolderSorting) + + # cache new format : numpy_folder + folder = cache_folder / "simple_sorting_numpy_folder" + sorting.set_property("test", np.ones(len(sorting.unit_ids))) + sorting.save(folder=folder, format="numpy_folder") + sorting2 = BaseExtractor.load_from_folder(folder) + assert isinstance(sorting2, NumpyFolderSorting) + # but also possible sorting3 = BaseExtractor.load(folder) check_sortings_equal(sorting, sorting2, check_annotations=True, check_properties=True) @@ -81,14 +94,19 @@ def test_BaseSorting(): sorting4 = sorting.save(format="memory") check_sortings_equal(sorting, sorting4, check_annotations=True, check_properties=True) - spikes = sorting.get_all_spike_trains() + with pytest.warns(DeprecationWarning): + num_spikes = sorting.get_all_spike_trains() # print(spikes) spikes = sorting.to_spike_vector() # print(spikes) + assert sorting._cached_spike_vector is not None spikes = sorting.to_spike_vector(extremum_channel_inds={0: 15, 1: 5, 2: 18}) # print(spikes) + num_spikes_per_unit = sorting.count_num_spikes_per_unit() + total_spikes = sorting.count_total_num_spikes() + # select units keep_units = [0, 1] sorting_select = sorting.select_units(unit_ids=keep_units) @@ -102,6 +120,14 @@ def test_BaseSorting(): for unit in sorting_clean.get_unit_ids(): assert unit not in empty_units + sorting4 = sorting.to_numpy_sorting() + sorting5 = sorting.to_multiprocessing(n_jobs=2) + # create a clone with the same share mem buffer + sorting6 = load_extractor(sorting5.to_dict()) + assert isinstance(sorting6, SharedMemorySorting) + del sorting6 + del sorting5 + def test_npy_sorting(): sfreq = 10 @@ -113,7 +139,7 @@ def test_npy_sorting(): "0": np.array([0, 1]), "1": np.array([], dtype="int64"), } - sorting = NumpySorting.from_dict( + sorting = NumpySorting.from_unit_dict( [spike_times_0, spike_times_1], sfreq, ) @@ -134,7 +160,7 @@ def test_npy_sorting(): seg_nframes = [9, 5] rec = NumpyRecording([np.zeros((nframes, 10)) for nframes in seg_nframes], sampling_frequency=sfreq) # assert_raises(Exception, sorting.register_recording, rec) - with pytest.warns(): + with pytest.warns(UserWarning): sorting.register_recording(rec) # Registering a rec with too many segments @@ -144,14 +170,15 @@ def test_npy_sorting(): def test_empty_sorting(): - sorting = NumpySorting.from_dict({}, 30000) + sorting = NumpySorting.from_unit_dict({}, 30000) assert len(sorting.unit_ids) == 0 - spikes = sorting.get_all_spike_trains() - assert len(spikes) == 1 - assert len(spikes[0][0]) == 0 - assert len(spikes[0][1]) == 0 + with pytest.warns(DeprecationWarning): + spikes = sorting.get_all_spike_trains() + assert len(spikes) == 1 + assert len(spikes[0][0]) == 0 + assert len(spikes[0][1]) == 0 spikes = sorting.to_spike_vector() assert spikes.shape == (0,) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 89a4143e19..3dc09f1e08 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -35,7 +35,7 @@ def test_write_binary_recording(tmp_path): # Check if written data matches original data recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, sampling_frequency=sampling_frequency, num_chan=num_channels, dtype=dtype + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype ) assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) @@ -62,7 +62,7 @@ 
def test_write_binary_recording_offset(tmp_path): recorder_binary = BinaryRecordingExtractor( file_paths=file_paths, sampling_frequency=sampling_frequency, - num_chan=num_channels, + num_channels=num_channels, dtype=dtype, file_offset=byte_offset, ) @@ -91,7 +91,7 @@ def test_write_binary_recording_parallel(tmp_path): # Check if written data matches original data recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, sampling_frequency=sampling_frequency, num_chan=num_channels, dtype=dtype + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype ) for segment_index in range(recording.get_num_segments()): binary_traces = recorder_binary.get_traces(segment_index=segment_index) @@ -118,7 +118,7 @@ def test_write_binary_recording_multiple_segment(tmp_path): # Check if written data matches original data recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, sampling_frequency=sampling_frequency, num_chan=num_channels, dtype=dtype + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype ) for segment_index in range(recording.get_num_segments()): diff --git a/src/spikeinterface/core/tests/test_frameslicesorting.py b/src/spikeinterface/core/tests/test_frameslicesorting.py index 010d733f6d..e404cfb1be 100644 --- a/src/spikeinterface/core/tests/test_frameslicesorting.py +++ b/src/spikeinterface/core/tests/test_frameslicesorting.py @@ -20,13 +20,13 @@ def test_FrameSliceSorting(): "1": np.arange(min_spike_time, max_spike_time), } # Sorting with attached rec - sorting = NumpySorting.from_dict([spike_times], sf) + sorting = NumpySorting.from_unit_dict([spike_times], sf) rec = NumpyRecording([np.zeros((nsamp, 5))], sampling_frequency=sf) sorting.register_recording(rec) # Sorting without attached rec - sorting_norec = NumpySorting.from_dict([spike_times], sf) + sorting_norec = NumpySorting.from_unit_dict([spike_times], sf) # Sorting with attached rec and exceeding spikes - sorting_exceeding = NumpySorting.from_dict([spike_times], sf) + sorting_exceeding = NumpySorting.from_unit_dict([spike_times], sf) rec_exceeding = NumpyRecording([np.zeros((max_spike_time - 1, 5))], sampling_frequency=sf) with warnings.catch_warnings(): warnings.filterwarnings("ignore") diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index 23752699a2..4a5bffbc05 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -4,9 +4,10 @@ import pytest import numpy as np -from spikeinterface.core import NumpyRecording, NumpySorting, NumpyEvent -from spikeinterface.core import create_sorting_npz +from spikeinterface.core import NumpyRecording, NumpySorting, SharedMemorySorting, NumpyEvent +from spikeinterface.core import create_sorting_npz, load_extractor from spikeinterface.core import NpzSortingExtractor +from spikeinterface.core.basesorting import minimum_spike_dtype if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "core" @@ -34,7 +35,8 @@ def test_NumpySorting(): # empty unit_ids = [] - sorting = NumpySorting(sampling_frequency, unit_ids) + spikes = np.zeros(0, dtype=minimum_spike_dtype) + sorting = NumpySorting(spikes, sampling_frequency, unit_ids) # print(sorting) # 2 columns @@ -57,9 +59,38 @@ def test_NumpySorting(): create_sorting_npz(num_seg, file_path) other_sorting = NpzSortingExtractor(file_path) - sorting = 
NumpySorting.from_extractor(other_sorting) + sorting = NumpySorting.from_sorting(other_sorting) # print(sorting) + # construct back from kwargs keep the same array + sorting2 = load_extractor(sorting.to_dict()) + assert np.shares_memory(sorting2._cached_spike_vector, sorting._cached_spike_vector) + + +def test_SharedMemorySorting(): + sampling_frequency = 30000 + unit_ids = ["a", "b", "c"] + spikes = np.zeros(100, dtype=minimum_spike_dtype) + spikes["sample_index"][:] = np.arange(0, 1000, 10, dtype="int64") + spikes["unit_index"][0::3] = 0 + spikes["unit_index"][1::3] = 1 + spikes["unit_index"][2::3] = 2 + np_sorting = NumpySorting(spikes, sampling_frequency, unit_ids) + print(np_sorting) + + sorting = SharedMemorySorting.from_sorting(np_sorting) + # print(sorting) + assert sorting._cached_spike_vector is not None + + # print(sorting.to_spike_vector()) + d = sorting.to_dict() + + sorting_reload = load_extractor(d) + # print(sorting_reload) + # print(sorting_reload.to_spike_vector()) + + assert sorting.shm.name == sorting_reload.shm.name + def test_NumpyEvent(): # one segment - dtype simple @@ -100,6 +131,7 @@ def test_NumpyEvent(): if __name__ == "__main__": - test_NumpyRecording() + # test_NumpyRecording() test_NumpySorting() - test_NumpyEvent() + # test_SharedMemorySorting() + # test_NumpyEvent() diff --git a/src/spikeinterface/core/tests/test_sorting_folder.py b/src/spikeinterface/core/tests/test_sorting_folder.py new file mode 100644 index 0000000000..cf7cade3ef --- /dev/null +++ b/src/spikeinterface/core/tests/test_sorting_folder.py @@ -0,0 +1,56 @@ +import pytest + +from pathlib import Path +import shutil + +import numpy as np + +from spikeinterface.core import NpzFolderSorting, NumpyFolderSorting, load_extractor +from spikeinterface.core import generate_sorting +from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "core" +else: + cache_folder = Path("cache_folder") / "core" + + +def test_NumpyFolderSorting(): + sorting = generate_sorting() + + folder = cache_folder / "numpy_sorting_1" + if folder.is_dir(): + shutil.rmtree(folder) + + NumpyFolderSorting.write_sorting(sorting, folder) + + sorting_loaded = NumpyFolderSorting(folder) + check_sortings_equal(sorting_loaded, sorting) + assert np.array_equal(sorting_loaded.unit_ids, sorting.unit_ids) + assert np.array_equal( + sorting_loaded.to_spike_vector(), + sorting.to_spike_vector(), + ) + + +def test_NpzFolderSorting(): + sorting = generate_sorting() + + folder = cache_folder / "npz_folder_sorting_1" + if folder.is_dir(): + shutil.rmtree(folder) + + NpzFolderSorting.write_sorting(sorting, folder) + + sorting_loaded = NpzFolderSorting(folder) + check_sortings_equal(sorting_loaded, sorting) + assert np.array_equal(sorting_loaded.unit_ids, sorting.unit_ids) + assert np.array_equal( + sorting_loaded.to_spike_vector(), + sorting.to_spike_vector(), + ) + + +if __name__ == "__main__": + test_NumpyFolderSorting() + test_NpzFolderSorting() diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index 9057659124..1a79019f96 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -82,7 +82,7 @@ def test_get_template_extremum_amplitude(): setup_module() test_get_template_amplitudes() - # test_get_template_extremum_channel() - # test_get_template_extremum_channel_peak_shift() - # 
test_get_template_extremum_amplitude() - # test_get_template_channel_sparsity() + test_get_template_extremum_channel() + test_get_template_extremum_channel_peak_shift() + test_get_template_extremum_amplitude() + test_get_template_channel_sparsity() diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 5f5695d7f6..107ef5f180 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -212,7 +212,8 @@ def test_extract_waveforms(): if folder_sort.is_dir(): shutil.rmtree(folder_sort) recording = recording.save(folder=folder_rec) - sorting = sorting.save(folder=folder_sort) + # we force "npz_folder" because we want to force the to_multiprocessing to be a SharedMemorySorting + sorting = sorting.save(folder=folder_sort, format="npz_folder") # 1 job folder1 = cache_folder / "test_extract_waveforms_1job" @@ -467,7 +468,7 @@ def test_empty_sorting(): num_channels = 2 recording = generate_recording(num_channels=num_channels, sampling_frequency=sf, durations=[15.32]) - sorting = NumpySorting.from_dict({}, sf) + sorting = NumpySorting.from_unit_dict({}, sf) folder = cache_folder / "empty_sorting" wvf_extractor = extract_waveforms(recording, sorting, folder, allow_unfiltered=True) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 546bee2ec1..ef60ee6e47 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -921,10 +921,10 @@ def save( zarr_root.attrs["params"] = check_json(self._params) if self.has_recording(): if self.recording.check_if_json_serializable(): - rec_dict = self.recording.to_dict(relative_to=relative_to) + rec_dict = self.recording.to_dict(relative_to=relative_to, recursive=True) zarr_root.attrs["recording"] = check_json(rec_dict) if self.sorting.check_if_json_serializable(): - sort_dict = self.sorting.to_dict(relative_to=relative_to) + sort_dict = self.sorting.to_dict(relative_to=relative_to, recursive=True) zarr_root.attrs["sorting"] = check_json(sort_dict) else: warn( @@ -1334,7 +1334,7 @@ def run_extract_waveforms(self, seed=None, **job_kwargs): sel = selected_spikes[unit_id][segment_index] selected_spike_times[segment_index][unit_id] = spike_times[sel] - spikes = NumpySorting.from_dict(selected_spike_times, self.sampling_frequency).to_spike_vector() + spikes = NumpySorting.from_unit_dict(selected_spike_times, self.sampling_frequency).to_spike_vector() if self.folder is not None: wf_folder = self.folder / "waveforms" diff --git a/src/spikeinterface/core/zarrrecordingextractor.py b/src/spikeinterface/core/zarrrecordingextractor.py index afa27da905..4dc94a24dd 100644 --- a/src/spikeinterface/core/zarrrecordingextractor.py +++ b/src/spikeinterface/core/zarrrecordingextractor.py @@ -49,7 +49,7 @@ def __init__(self, root_path: Union[Path, str], storage_options=None): root_path = Path(root_path) else: root_path_init = str(root_path) - root_path_kwarg = str(root_path.absolute()) + root_path_kwarg = str(Path(root_path).absolute()) else: root_path_init = root_path root_path_kwarg = root_path_init diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index e32086d0df..5e7047a5c1 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -137,7 +137,7 @@ def get_potential_auto_merge( # STEP 1 : if "min_spikes" in steps: - num_spikes 
= np.array(list(sorting.get_total_num_spikes().values())) + num_spikes = np.array(list(sorting.count_num_spikes_per_unit().values())) to_remove = num_spikes < minimum_spikes pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False @@ -256,7 +256,7 @@ def compute_correlogram_diff( # Index of the middle of the correlograms. m = correlograms_smoothed.shape[2] // 2 - num_spikes = sorting.get_total_num_spikes() + num_spikes = sorting.count_num_spikes_per_unit() corr_diff = np.full((n, n), np.nan, dtype="float64") for unit_ind1 in range(n): diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py index 9f2e52ab5e..c2617d5b52 100644 --- a/src/spikeinterface/curation/remove_redundant.py +++ b/src/spikeinterface/curation/remove_redundant.py @@ -112,7 +112,7 @@ def remove_redundant_units( else: remove_unit_ids.append(u2) elif remove_strategy == "max_spikes": - num_spikes = sorting.get_total_num_spikes() + num_spikes = sorting.count_num_spikes_per_unit() for u1, u2 in redundant_unit_pairs: if num_spikes[u1] < num_spikes[u2]: remove_unit_ids.append(u1) diff --git a/src/spikeinterface/curation/tests/test_curationsorting.py b/src/spikeinterface/curation/tests/test_curationsorting.py index fd0b206629..91bc21a49f 100644 --- a/src/spikeinterface/curation/tests/test_curationsorting.py +++ b/src/spikeinterface/curation/tests/test_curationsorting.py @@ -16,7 +16,7 @@ def test_split_merge(): }, {0: np.arange(15), 1: np.arange(17), 2: np.arange(40, 140), 4: np.arange(40, 140), 5: np.arange(40, 140)}, ] - parent_sort = NumpySorting.from_dict(spikestimes, sampling_frequency=1000) # to have 1 sample=1ms + parent_sort = NumpySorting.from_unit_dict(spikestimes, sampling_frequency=1000) # to have 1 sample=1ms parent_sort.set_property("someprop", [float(k) for k in spikestimes[0].keys()]) # float # %% @@ -54,7 +54,7 @@ def test_curation(): }, {"a": np.arange(12, 15), "b": np.arange(3, 17), "c": np.arange(50)}, ] - parent_sort = NumpySorting.from_dict(spikestimes, sampling_frequency=1000) # to have 1 sample=1ms + parent_sort = NumpySorting.from_unit_dict(spikestimes, sampling_frequency=1000) # to have 1 sample=1ms parent_sort.set_property("some_names", ["unit_{}".format(k) for k in spikestimes[0].keys()]) # float cs = CurationSorting(parent_sort, properties_policy="remove") # %% @@ -81,7 +81,7 @@ def test_curation(): ) # Test with empty sorting - empty_sorting = CurationSorting(NumpySorting.from_dict({}, parent_sort.sampling_frequency)) + empty_sorting = CurationSorting(NumpySorting.from_unit_dict({}, parent_sort.sampling_frequency)) if __name__ == "__main__": diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 33a46d2bea..8f669657ef 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -158,8 +158,9 @@ def export_to_phy( # export spike_times/spike_templates/spike_clusters # here spike_labels is a remapping to unit_index - all_spikes = sorting.get_all_spike_trains(outputs="unit_index") - spike_times, spike_labels = all_spikes[0] + all_spikes_seg0 = sorting.to_spike_vector(concatenated=False)[0] + spike_times = all_spikes_seg0["sample_index"] + spike_labels = all_spikes_seg0["unit_index"] np.save(str(output_folder / "spike_times.npy"), spike_times[:, np.newaxis]) np.save(str(output_folder / "spike_templates.npy"), spike_labels[:, np.newaxis]) np.save(str(output_folder / "spike_clusters.npy"), spike_labels[:, np.newaxis]) @@ -168,7 +169,7 @@ def export_to_phy( # shape 
(num_units, num_samples, max_num_channels) max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values()) num_samples = waveform_extractor.nbefore + waveform_extractor.nafter - templates = np.zeros((len(unit_ids), num_samples, max_num_channels), dtype=waveform_extractor.dtype) + templates = np.zeros((len(unit_ids), num_samples, max_num_channels), dtype="float64") # here we pad template inds with -1 if len of sparse channels is unequal templates_ind = -np.ones((len(unit_ids), max_num_channels), dtype="int64") for unit_ind, unit_id in enumerate(unit_ids): diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index dda04fbb17..1fac418e85 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -106,7 +106,10 @@ def __init__(self, folder_path, load_sync_channel=False): sample_shifts = get_neuropixels_sample_shifts(self.get_num_channels(), num_channels_per_adc) self.set_property("inter_sample_shift", sample_shifts) - self._kwargs = {"folder_path": str(folder_path.absolute()), "load_sync_channel": load_sync_channel} + self._kwargs = { + "folder_path": str(Path(folder_path).absolute()), + "load_sync_channel": load_sync_channel, + } class CBinIblRecordingSegment(BaseRecordingSegment): diff --git a/src/spikeinterface/extractors/combinatoextractors.py b/src/spikeinterface/extractors/combinatoextractors.py index 5e17fd3045..fa2bdde450 100644 --- a/src/spikeinterface/extractors/combinatoextractors.py +++ b/src/spikeinterface/extractors/combinatoextractors.py @@ -44,7 +44,7 @@ def __init__(self, folder_path, sampling_frequency=None, user="simple", det_sign folder_path = Path(folder_path) assert folder_path.is_dir(), "Folder {} doesn't exist".format(folder_path) if sampling_frequency is None: - h5_path = str(folder_path) + ".h5" + h5_path = str(Path(folder_path).absolute()) + ".h5" if Path(h5_path).exists(): with h5py.File(h5_path, mode="r") as f: sampling_frequency = f["sr"][0] @@ -85,7 +85,7 @@ def __init__(self, folder_path, sampling_frequency=None, user="simple", det_sign self.add_sorting_segment(CombinatoSortingSegment(spiketrains)) self.set_property("unsorted", np.array([metadata[u]["group_type"] == 0 for u in range(unit_counter)])) self.set_property("artifact", np.array([metadata[u]["group_type"] == -1 for u in range(unit_counter)])) - self._kwargs = {"folder_path": str(folder_path), "user": user, "det_sign": det_sign} + self._kwargs = {"folder_path": str(Path(folder_path).absolute()), "user": user, "det_sign": det_sign} self.extra_requirements.append("h5py") diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py index 407d388044..ebff40fae0 100644 --- a/src/spikeinterface/extractors/extractorlist.py +++ b/src/spikeinterface/extractors/extractorlist.py @@ -4,11 +4,13 @@ from spikeinterface.core import ( BaseRecording, BaseSorting, + BinaryFolderRecording, BinaryRecordingExtractor, NumpyRecording, NpzSortingExtractor, NumpySorting, NpySnippetsExtractor, + ZarrRecordingExtractor, ) # sorting/recording/event from neo @@ -57,7 +59,9 @@ ######################################## recording_extractor_full_list = [ + BinaryFolderRecording, BinaryRecordingExtractor, + ZarrRecordingExtractor, # natively implemented in spikeinterface.extractors NumpyRecording, SHYBRIDRecordingExtractor, diff --git a/src/spikeinterface/extractors/hdsortextractors.py b/src/spikeinterface/extractors/hdsortextractors.py index 6b904f812b..178596d052 100644 --- 
a/src/spikeinterface/extractors/hdsortextractors.py +++ b/src/spikeinterface/extractors/hdsortextractors.py @@ -108,7 +108,7 @@ def __init__(self, file_path, keep_good_only=True): self.set_property("template", np.array(templates)) self.set_property("template_frames_cut_before", np.array(templates_frames_cut_before)) - self._kwargs = {"file_path": str(file_path), "keep_good_only": keep_good_only} + self._kwargs = {"file_path": str(Path(file_path).absolute()), "keep_good_only": keep_good_only} # TODO features # ~ for uc, unit in enumerate(units): diff --git a/src/spikeinterface/extractors/matlabhelpers.py b/src/spikeinterface/extractors/matlabhelpers.py index 46bcf2d88c..4f22d25339 100644 --- a/src/spikeinterface/extractors/matlabhelpers.py +++ b/src/spikeinterface/extractors/matlabhelpers.py @@ -26,7 +26,7 @@ def __init__(self, file_path): if not file_path.is_file(): raise ValueError(f"Specified file path '{file_path}' is not a file.") - self._kwargs = {"file_path": str(file_path.absolute())} + self._kwargs = {"file_path": str(Path(file_path).absolute())} try: # load old-style (up to 7.2) .mat file self._data = loadmat(file_path, matlab_compatible=True) diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index 68317e25be..815c617677 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -196,7 +196,7 @@ class MdaSortingExtractor(BaseSorting): name = "mda" def __init__(self, file_path, sampling_frequency): - firings = readmda(str(file_path)) + firings = readmda(str(Path(file_path).absolute())) labels = firings[2, :] unit_ids = np.unique(labels).astype(int) BaseSorting.__init__(self, unit_ids=unit_ids, sampling_frequency=sampling_frequency) @@ -204,7 +204,10 @@ def __init__(self, file_path, sampling_frequency): sorting_segment = MdaSortingSegment(firings) self.add_sorting_segment(sorting_segment) - self._kwargs = {"file_path": str(Path(file_path).absolute()), "sampling_frequency": sampling_frequency} + self._kwargs = { + "file_path": str(Path(file_path).absolute()), + "sampling_frequency": sampling_frequency, + } @staticmethod def write_sorting(sorting, save_path, write_primary_channels=False): diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index 78844a5267..a58b5ab5ec 100644 --- a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseEventExtractor @@ -32,7 +34,7 @@ def __init__(self, folder_path, lsx_files=None, stream_id="RAW", stream_name=Non NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(folder_path=str(folder_path), lsx_files=lsx_files)) + self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()), lsx_files=lsx_files)) @classmethod def map_to_neo_kwargs(cls, folder_path, lsx_files=None): diff --git a/src/spikeinterface/extractors/neoextractors/axona.py b/src/spikeinterface/extractors/neoextractors/axona.py index cb0cf19ff8..6b1d47e4fa 100644 --- a/src/spikeinterface/extractors/neoextractors/axona.py +++ b/src/spikeinterface/extractors/neoextractors/axona.py @@ -1,3 +1,5 @@ +from pathlib import Path + from 
spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor @@ -24,7 +26,7 @@ class AxonaRecordingExtractor(NeoBaseRecordingExtractor): def __init__(self, file_path, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) NeoBaseRecordingExtractor.__init__(self, all_annotations=all_annotations, **neo_kwargs) - self._kwargs.update({"file_path": file_path}) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) @classmethod def map_to_neo_kwargs(cls, file_path): diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 23fcb2c419..3e30cf77ae 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -1,3 +1,5 @@ +from pathlib import Path + import probeinterface as pi from spikeinterface.core.core_tools import define_function_from_class @@ -57,7 +59,13 @@ def __init__( self.set_property("row", self.get_property("contact_vector")["row"]) self.set_property("col", self.get_property("contact_vector")["col"]) - self._kwargs.update({"file_path": str(file_path), "mea_pitch": mea_pitch, "electrode_width": electrode_width}) + self._kwargs.update( + { + "file_path": str(Path(file_path).absolute()), + "mea_pitch": mea_pitch, + "electrode_width": electrode_width, + } + ) @classmethod def map_to_neo_kwargs(cls, file_path): diff --git a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py index 5408173a12..8300e6bc5e 100644 --- a/src/spikeinterface/extractors/neoextractors/blackrock.py +++ b/src/spikeinterface/extractors/neoextractors/blackrock.py @@ -56,7 +56,7 @@ def __init__( use_names_as_ids=use_names_as_ids, **neo_kwargs, ) - self._kwargs.update({"file_path": str(file_path)}) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) @classmethod def map_to_neo_kwargs(cls, file_path): @@ -107,7 +107,7 @@ def __init__( ) self._kwargs = { - "file_path": file_path, + "file_path": str(Path(file_path).absolute()), "sampling_frequency": sampling_frequency, "stream_id": stream_id, "stream_name": stream_name, diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index 86865f312b..2451ca8fe1 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor @@ -34,7 +36,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(file_path=str(file_path))) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()))) self.extra_requirements.append("neo[ced]") @classmethod diff --git a/src/spikeinterface/extractors/neoextractors/edf.py b/src/spikeinterface/extractors/neoextractors/edf.py index b50df7868c..5d8c56ee87 100644 --- a/src/spikeinterface/extractors/neoextractors/edf.py +++ b/src/spikeinterface/extractors/neoextractors/edf.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor @@ -31,7 +33,7 @@ def 
__init__(self, file_path, stream_id=None, stream_name=None, all_annotations= NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update({"file_path": str(file_path)}) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) self.extra_requirements.append("neo[edf]") @classmethod diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 1cbd5bd869..2a61e7385f 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor @@ -36,7 +38,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= **neo_kwargs, ) - self._kwargs.update(dict(file_path=str(file_path))) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()))) @classmethod def map_to_neo_kwargs(cls, file_path): diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 7010b55721..ac85dbdf30 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -1,4 +1,5 @@ import numpy as np +from pathlib import Path import probeinterface as pi @@ -70,7 +71,7 @@ def __init__( probe = pi.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) self.set_probe(probe, in_place=True) self.set_property("electrode", self.get_property("contact_vector")["electrode"]) - self._kwargs.update(dict(file_path=str(file_path), rec_name=rec_name)) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name)) @classmethod def map_to_neo_kwargs(cls, file_path, rec_name=None): diff --git a/src/spikeinterface/extractors/neoextractors/mcsraw.py b/src/spikeinterface/extractors/neoextractors/mcsraw.py index 6e377ea799..4b6af54bcd 100644 --- a/src/spikeinterface/extractors/neoextractors/mcsraw.py +++ b/src/spikeinterface/extractors/neoextractors/mcsraw.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor @@ -40,7 +42,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None all_annotations=all_annotations, **neo_kwargs, ) - self._kwargs.update(dict(file_path=str(file_path))) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()))) @classmethod def map_to_neo_kwargs(cls, file_path): diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index ab24034b9a..7dda9175f5 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -8,6 +8,22 @@ from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor +def drop_invalid_neo_arguments_before_version_0_13_0(neo_kwargs): + # Temporary function until neo version 0.13.0 is released + from packaging.version import parse as parse_version + from importlib.metadata import version + + neo_version = version("neo") + minor_version = parse_version(neo_version).minor + + # The possibility of loading only spike_trains or only analog_signals is not present in neo <= 0.11.0 + if 
minor_version < 13: + neo_kwargs.pop("load_spiketrains") + neo_kwargs.pop("load_analogsignal") + + return neo_kwargs + + class MEArecRecordingExtractor(NeoBaseRecordingExtractor): """ Class for reading data from a MEArec simulated data. @@ -40,14 +56,19 @@ def __init__(self, file_path: Union[str, Path], all_annotations: bool = False): if hasattr(self.neo_reader._recgen, "gain_to_uV"): self.set_channel_gains(self.neo_reader._recgen.gain_to_uV) - self._kwargs.update({"file_path": str(file_path)}) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) @classmethod def map_to_neo_kwargs( cls, file_path, ): - neo_kwargs = {"filename": str(file_path), "load_spiketrains": False, "load_analogsignal": True} + neo_kwargs = { + "filename": str(file_path), + "load_spiketrains": False, + "load_analogsignal": True, + } + neo_kwargs = drop_invalid_neo_arguments_before_version_0_13_0(neo_kwargs=neo_kwargs) return neo_kwargs @@ -63,11 +84,17 @@ def __init__(self, file_path: Union[str, Path]): sampling_frequency = self.read_sampling_frequency(file_path=file_path) NeoBaseSortingExtractor.__init__(self, sampling_frequency=sampling_frequency, use_format_ids=True, **neo_kwargs) - self._kwargs = {"file_path": str(file_path)} + self._kwargs = {"file_path": str(Path(file_path).absolute())} @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path), "load_spiketrains": True, "load_analogsignal": False} + neo_kwargs = { + "filename": str(file_path), + "load_spiketrains": True, + "load_analogsignal": False, + } + neo_kwargs = drop_invalid_neo_arguments_before_version_0_13_0(neo_kwargs=neo_kwargs) + return neo_kwargs def read_sampling_frequency(self, file_path: Union[str, Path]) -> float: diff --git a/src/spikeinterface/extractors/neoextractors/neuralynx.py b/src/spikeinterface/extractors/neoextractors/neuralynx.py index 6f73952eb1..672602b66c 100644 --- a/src/spikeinterface/extractors/neoextractors/neuralynx.py +++ b/src/spikeinterface/extractors/neoextractors/neuralynx.py @@ -1,4 +1,5 @@ from typing import Optional +from pathlib import Path from spikeinterface.core.core_tools import define_function_from_class @@ -32,7 +33,7 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, all_annotation NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(folder_path=str(folder_path))) + self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()))) @classmethod def map_to_neo_kwargs(cls, folder_path): @@ -61,7 +62,7 @@ class NeuralynxSortingExtractor(NeoBaseSortingExtractor): mode = "folder" NeoRawIOClass = "NeuralynxRawIO" - neo_returns_timestamps = False + neo_returns_frames = True need_t_start_from_signal_stream = True name = "neuralynx" @@ -82,7 +83,7 @@ def __init__( ) self._kwargs = { - "folder_path": folder_path, + "folder_path": str(Path(folder_path).absolute()), "sampling_frequency": sampling_frequency, "stream_id": stream_id, "stream_name": stream_name, diff --git a/src/spikeinterface/extractors/neoextractors/neuroscope.py b/src/spikeinterface/extractors/neoextractors/neuroscope.py index 49c194ce92..801b9c1928 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroscope.py +++ b/src/spikeinterface/extractors/neoextractors/neuroscope.py @@ -51,7 +51,9 @@ def __init__(self, file_path, xml_file_path=None, stream_id=None, stream_name=No NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, 
all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(file_path=str(file_path), xml_file_path=xml_file_path)) + if xml_file_path is not None: + xml_file_path = str(Path(xml_file_path).absolute()) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), xml_file_path=xml_file_path)) @classmethod def map_to_neo_kwargs(cls, file_path, xml_file_path=None): @@ -60,9 +62,9 @@ def map_to_neo_kwargs(cls, file_path, xml_file_path=None): # binary_file is the binary file in .dat, .lfp, .eeg if xml_file_path is not None: - neo_kwargs = {"binary_file": file_path, "filename": xml_file_path} + neo_kwargs = {"binary_file": Path(file_path), "filename": Path(xml_file_path)} else: - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": Path(file_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/nix.py b/src/spikeinterface/extractors/neoextractors/nix.py index 8781a4df71..2762e5645b 100644 --- a/src/spikeinterface/extractors/neoextractors/nix.py +++ b/src/spikeinterface/extractors/neoextractors/nix.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor @@ -37,7 +39,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None all_annotations=all_annotations, **neo_kwargs, ) - self._kwargs.update(dict(file_path=str(file_path), stream_id=stream_id)) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), stream_id=stream_id)) self.extra_requirements.append("neo[nixio]") @classmethod diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index e1a6598f61..a771dc47b1 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -61,7 +61,7 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No all_annotations=all_annotations, **neo_kwargs, ) - self._kwargs.update(dict(folder_path=str(folder_path))) + self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()))) @classmethod def map_to_neo_kwargs(cls, folder_path): @@ -149,8 +149,8 @@ def __init__( else: exp_id = exp_ids[block_index] - # do not load probe for NIDQ stream - if "NI-DAQmx" not in stream_name: + # do not load probe for NIDQ stream or if load_sync_channel is True + if "NI-DAQmx" not in stream_name and not load_sync_channel: settings_file = self.neo_reader.folder_structure[record_node]["experiments"][exp_id]["settings_file"] if Path(settings_file).is_file(): @@ -204,7 +204,7 @@ def __init__( self._kwargs.update( dict( - folder_path=str(folder_path), + folder_path=str(Path(folder_path).absolute()), load_sync_channel=load_sync_channel, load_sync_timestamps=load_sync_timestamps, experiment_names=experiment_names, diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index 84c06e6974..c3ff59fe82 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor @@ -30,7 +32,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, 
stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update({"file_path": str(file_path)}) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) @classmethod def map_to_neo_kwargs(cls, file_path): @@ -61,7 +63,7 @@ def __init__(self, file_path): self.neo_reader.parse_header() sampling_frequency = self.neo_reader._global_ssampling_rate NeoBaseSortingExtractor.__init__(self, sampling_frequency=sampling_frequency, **neo_kwargs) - self._kwargs = {"file_path": str(file_path)} + self._kwargs = {"file_path": str(Path(file_path).absolute())} @classmethod def map_to_neo_kwargs(cls, file_path): diff --git a/src/spikeinterface/extractors/neoextractors/spike2.py b/src/spikeinterface/extractors/neoextractors/spike2.py index 7fab8e4087..af172855ed 100644 --- a/src/spikeinterface/extractors/neoextractors/spike2.py +++ b/src/spikeinterface/extractors/neoextractors/spike2.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor @@ -31,7 +33,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update({"file_path": str(file_path)}) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) self.extra_requirements.append("sonpy") @classmethod diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index e5841e9df8..49d55ca3eb 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor @@ -30,7 +32,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(file_path=str(file_path), stream_id=stream_id)) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), stream_id=stream_id)) @classmethod def map_to_neo_kwargs(cls, file_path): diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index 1998995cb4..8c3b33505d 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -1,6 +1,7 @@ from packaging import version import numpy as np +from pathlib import Path import neo import probeinterface as pi @@ -90,7 +91,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_ self.set_property("inter_sample_shift", sample_shifts) - self._kwargs.update(dict(folder_path=str(folder_path), load_sync_channel=load_sync_channel)) + self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()), load_sync_channel=load_sync_channel)) @classmethod def map_to_neo_kwargs(cls, folder_path, load_sync_channel=False): diff --git a/src/spikeinterface/extractors/neoextractors/tdt.py b/src/spikeinterface/extractors/neoextractors/tdt.py index bd4cbe2339..60cd39c010 100644 --- a/src/spikeinterface/extractors/neoextractors/tdt.py +++ 
b/src/spikeinterface/extractors/neoextractors/tdt.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor @@ -35,7 +37,7 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No all_annotations=all_annotations, **neo_kwargs, ) - self._kwargs.update(dict(folder_path=str(folder_path))) + self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()))) @classmethod def map_to_neo_kwargs(cls, folder_path): diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index e5ac7e18bc..bca4c75d99 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -104,7 +104,6 @@ def read_nwbfile( -------- >>> nwbfile = read_nwbfile("data.nwb", stream_mode="ros3") """ - file_path = str(file_path) from pynwb import NWBHDF5IO, NWBFile if stream_mode == "fsspec": @@ -131,6 +130,7 @@ def read_nwbfile( io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True, driver="ros3") else: + file_path = str(Path(file_path).absolute()) io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True) nwbfile = io.read() @@ -475,19 +475,17 @@ def __init__( self.stream_cache_path = stream_cache_path if stream_cache_path is not None else "cache" self.cfs = CachingFileSystem( fs=fsspec.filesystem("http"), - cache_storage=self.stream_cache_path, + cache_storage=str(self.stream_cache_path), ) - self._file_path = self.cfs.open(str(file_path), "rb") - file = h5py.File(self._file_path) + file_path_ = self.cfs.open(file_path, "rb") + file = h5py.File(file_path_) self.io = NWBHDF5IO(file=file, mode="r", load_namespaces=True) elif stream_mode == "ros3": - self._file_path = str(file_path) - self.io = NWBHDF5IO(self._file_path, mode="r", load_namespaces=True, driver="ros3") - + self.io = NWBHDF5IO(file_path, mode="r", load_namespaces=True, driver="ros3") else: - self._file_path = str(file_path) - self.io = NWBHDF5IO(self._file_path, mode="r", load_namespaces=True) + file_path_ = str(Path(file_path).absolute()) + self.io = NWBHDF5IO(file_path_, mode="r", load_namespaces=True) self._nwbfile = self.io.read() units_ids = list(self._nwbfile.units.id[:]) diff --git a/src/spikeinterface/extractors/tests/common_tests.py b/src/spikeinterface/extractors/tests/common_tests.py index c1a98698b0..858c86d92a 100644 --- a/src/spikeinterface/extractors/tests/common_tests.py +++ b/src/spikeinterface/extractors/tests/common_tests.py @@ -38,7 +38,6 @@ def test_open(self): # test streams and blocks retrieval full_path = self.get_full_path(path) - rec = self.ExtractorClass(full_path, **kwargs) assert hasattr(rec, "extra_requirements") diff --git a/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py b/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py index 3c4e23f14a..2e364b13bc 100644 --- a/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py +++ b/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py @@ -22,7 +22,7 @@ class CompressedBinaryIblExtractorTest(RecordingCommonTestSuite, unittest.TestCa # ~ import matplotlib.pyplot as plt # ~ import spikeinterface.widgets as sw # ~ from probeinterface.plotting import plot_probe -# ~ sw.plot_timeseries(rec) +# ~ sw.plot_traces(rec) # ~ plot_probe(rec.get_probe()) # ~ plt.show() diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py 
b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py index a41fd080b4..71a19f30d3 100644 --- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py +++ b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py @@ -4,7 +4,7 @@ import numpy as np import h5py -from spikeinterface.extractors import NwbRecordingExtractor +from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "extractors" @@ -15,7 +15,7 @@ @pytest.mark.ros3_test @pytest.mark.streaming_extractors @pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed") -def test_s3_nwb_ros3(): +def test_recording_s3_nwb_ros3(): file_path = ( "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" ) @@ -42,7 +42,7 @@ @pytest.mark.streaming_extractors -def test_s3_nwb_fsspec(): +def test_recording_s3_nwb_fsspec(): file_path = ( "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" ) @@ -66,3 +66,56 @@ if rec.has_scaled(): trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) assert trace_scaled.dtype == "float32" + + +@pytest.mark.ros3_test +@pytest.mark.streaming_extractors +@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed") +def test_sorting_s3_nwb_ros3(): + file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" + # we provide the 'sampling_frequency' because the NWB file does not contain the electrical series + sort = NwbSortingExtractor(file_path, sampling_frequency=30000, stream_mode="ros3") + + start_frame = 0 + end_frame = 300 + num_frames = end_frame - start_frame + + num_seg = sort.get_num_segments() + num_units = len(sort.unit_ids) + + for segment_index in range(num_seg): + for unit in sort.unit_ids: + spike_train = sort.get_unit_spike_train(unit_id=unit, segment_index=segment_index) + assert len(spike_train) > 0 + assert spike_train.dtype == "int64" + assert np.all(spike_train >= 0) + + +@pytest.mark.streaming_extractors +def test_sorting_s3_nwb_fsspec(): + file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" + # we provide the 'sampling_frequency' because the NWB file does not contain the electrical series + sort = NwbSortingExtractor( + file_path, sampling_frequency=30000, stream_mode="fsspec", stream_cache_path=cache_folder + ) + + start_frame = 0 + end_frame = 300 + num_frames = end_frame - start_frame + + num_seg = sort.get_num_segments() + num_units = len(sort.unit_ids) + + for segment_index in range(num_seg): + for unit in sort.unit_ids: + spike_train = sort.get_unit_spike_train(unit_id=unit, segment_index=segment_index) + assert len(spike_train) > 0 + assert spike_train.dtype == "int64" + assert np.all(spike_train >= 0) + + +if __name__ == "__main__": + test_recording_s3_nwb_ros3() + test_recording_s3_nwb_fsspec() + test_sorting_s3_nwb_ros3() + test_sorting_s3_nwb_fsspec() diff --git a/src/spikeinterface/extractors/yassextractors.py b/src/spikeinterface/extractors/yassextractors.py index 1fb4ad9555..bb04c21533 100644 --- a/src/spikeinterface/extractors/yassextractors.py +++ b/src/spikeinterface/extractors/yassextractors.py @@ -56,7 +56,7 @@ def __init__(self, folder_path): BaseSorting.__init__(self, sampling_frequency, unit_ids)
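For orientation, the streaming pattern exercised by the two new sorting tests above can be summarized in the short sketch below. The DANDI blob URL and the sampling frequency are copied from the test itself; the local cache directory name "cache" is only an illustrative placeholder, not a recommended value.

from spikeinterface.extractors import NwbSortingExtractor

# Stream a remote NWB file over HTTP with fsspec, caching fetched chunks locally.
# URL and sampling_frequency come from the test above; "cache" is a placeholder directory.
file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b"
sorting = NwbSortingExtractor(
    file_path,
    sampling_frequency=30000,  # required because this NWB file does not contain the electrical series
    stream_mode="fsspec",
    stream_cache_path="cache",  # placeholder; the test uses its pytest cache folder
)

# Spike trains are then read exactly as for a local sorting.
for unit_id in sorting.unit_ids:
    spike_train = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=0)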
self.add_sorting_segment(YassSortingSegment(spiketrains)) - self._kwargs = {"folder_path": str(folder_path)} + self._kwargs = {"folder_path": str(Path(folder_path).absolute())} self.extra_requirements.append("pyyaml") diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index dc3624ba3e..3ebeafcfec 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -18,7 +18,9 @@ def __init__(self, waveform_extractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index") - self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) + self.spikes = self.waveform_extractor.sorting.to_spike_vector( + extremum_channel_inds=extremum_channel_inds, use_cache=False + ) def _set_params(self, sparsity, max_dense_channels, ms_before, ms_after): params = dict(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 39118e6304..6cd5238abd 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -216,7 +216,7 @@ def compute_correlograms_numpy(sorting, window_size, bin_size): """ num_seg = sorting.get_num_segments() num_units = len(sorting.unit_ids) - spikes = sorting.get_all_spike_trains(outputs="unit_index") + spikes = sorting.to_spike_vector(concatenated=False) num_half_bins = int(window_size // bin_size) num_bins = int(2 * num_half_bins) @@ -224,7 +224,8 @@ def compute_correlograms_numpy(sorting, window_size, bin_size): correlograms = np.zeros((num_units, num_units, num_bins), dtype="int64") for seg_index in range(num_seg): - spike_times, spike_labels = spikes[seg_index] + spike_times = spikes[seg_index]["sample_index"] + spike_labels = spikes[seg_index]["unit_index"] c0 = correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size) @@ -305,11 +306,13 @@ def compute_correlograms_numba(sorting, window_size, bin_size): num_bins = 2 * int(window_size / bin_size) num_units = len(sorting.unit_ids) - spikes = sorting.get_all_spike_trains(outputs="unit_index") + spikes = sorting.to_spike_vector(concatenated=False) correlograms = np.zeros((num_units, num_units, num_bins), dtype=np.int64) for seg_index in range(sorting.get_num_segments()): - spike_times, spike_labels = spikes[seg_index] + spike_times = spikes[seg_index]["sample_index"] + spike_labels = spikes[seg_index]["unit_index"] + _compute_correlograms_numba( correlograms, spike_times.astype(np.int64), spike_labels.astype(np.int32), window_size, bin_size ) diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 678ce8c2fd..aec70141cf 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -233,15 +233,18 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float assert num_bins >= 1 bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs - spikes = sorting.get_all_spike_trains(outputs="unit_index") + spikes = sorting.to_spike_vector(concatenated=False) ISIs = np.zeros((num_units, num_bins), dtype=np.int64) for seg_index in range(sorting.get_num_segments()): + spike_times = 
spikes[seg_index]["sample_index"].astype(np.int64) + spike_labels = spikes[seg_index]["unit_index"].astype(np.int32) + _compute_isi_histograms_numba( ISIs, - spikes[seg_index][0].astype(np.int64), - spikes[seg_index][1].astype(np.int32), + spike_times, + spike_labels, window_size, bin_size, fs, diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index a94f275f62..991d79506e 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -310,8 +310,11 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): file_path = self.extension_folder / "all_pcs.npy" file_path = Path(file_path) - all_spikes = sorting.get_all_spike_trains(outputs="unit_index") - spike_times, spike_labels = all_spikes[0] + # spikes = sorting.to_spike_vector(concatenated=False) + # # This is the first segment only + # spikes = spikes[0] + # spike_times = spikes["sample_index"] + # spike_labels = spikes["unit_index"] sparsity = self.get_sparsity() if sparsity is None: @@ -330,7 +333,8 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): # nSpikes, nFeaturesPerChannel, nPCFeatures # this comes from phy template-gui # https://github.com/kwikteam/phy-contrib/blob/master/docs/template-gui.md#datasets - shape = (spike_times.size, p["n_components"], max_channels_per_template) + num_spikes = sorting.to_spike_vector().size + shape = (num_spikes, p["n_components"], max_channels_per_template) all_pcs = np.lib.format.open_memmap(filename=file_path, mode="w+", dtype="float32", shape=shape) all_pcs_args = dict(filename=file_path, mode="r+", dtype="float32", shape=shape) @@ -339,9 +343,8 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): init_func = _init_work_all_pc_extractor init_args = ( recording, + sorting.to_multiprocessing(job_kwargs["n_jobs"]), all_pcs_args, - spike_times, - spike_labels, we.nbefore, we.nafter, unit_channels, @@ -635,15 +638,21 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): all_pcs[i, :, c] = pca_model[chan_ind].transform(w) -def _init_work_all_pc_extractor( - recording, all_pcs_args, spike_times, spike_labels, nbefore, nafter, unit_channels, pca_model -): +def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafter, unit_channels, pca_model): worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) worker_ctx["recording"] = recording + worker_ctx["sorting"] = sorting + + spikes = sorting.to_spike_vector(concatenated=False) + # This is the first segment only + spikes = spikes[0] + spike_times = spikes["sample_index"] + spike_labels = spikes["unit_index"] + worker_ctx["all_pcs"] = np.lib.format.open_memmap(**all_pcs_args) worker_ctx["spike_times"] = spike_times worker_ctx["spike_labels"] = spike_labels diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index b86a33179c..62a4e2c320 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -26,12 +26,15 @@ def _set_params(self, peak_sign="neg", return_scaled=True): def _select_extension_data(self, unit_ids): # load filter and save amplitude files + sorting = self.waveform_extractor.sorting + spikes = sorting.to_spike_vector(concatenated=False) + (keep_unit_indices,) = np.nonzero(np.in1d(sorting.unit_ids, unit_ids)) + 
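The hunks in this area (correlograms, ISIs, refractory-period violations, spike amplitudes) all replace get_all_spike_trains(outputs="unit_index") with the structured spike vector. A minimal sketch of the new access pattern follows; the toy NumpySorting.from_unit_dict input and its spike times and sampling rate are purely illustrative, while the function and field names are those used in the diff.

import numpy as np
from spikeinterface.core import NumpySorting

# Toy single-segment sorting; the unit dict values are illustrative only.
sorting = NumpySorting.from_unit_dict(
    [{"a": np.array([10, 50, 90]), "b": np.array([20, 60])}], sampling_frequency=30000.0
)

# New pattern used throughout this PR: one structured array per segment,
# with "sample_index" and "unit_index" fields replacing the old
# (spike_times, spike_labels) tuples returned by get_all_spike_trains().
spikes_per_segment = sorting.to_spike_vector(concatenated=False)
for seg_index in range(sorting.get_num_segments()):
    spike_times = spikes_per_segment[seg_index]["sample_index"]
    spike_labels = spikes_per_segment[seg_index]["unit_index"]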
new_extension_data = dict() - for seg_index in range(self.waveform_extractor.recording.get_num_segments()): + for seg_index in range(sorting.get_num_segments()): amp_data_name = f"amplitude_segment_{seg_index}" amps = self._extension_data[amp_data_name] - _, all_labels = self.waveform_extractor.sorting.get_all_spike_trains()[seg_index] - filtered_idxs = np.in1d(all_labels, np.array(unit_ids)).nonzero() + filtered_idxs = np.in1d(spikes[seg_index]["unit_index"], keep_unit_indices) new_extension_data[amp_data_name] = amps[filtered_idxs] return new_extension_data @@ -45,7 +48,7 @@ def _run(self, **job_kwargs): recording = we.recording sorting = we.sorting - all_spikes = sorting.get_all_spike_trains(outputs="unit_index") + all_spikes = sorting.to_spike_vector() self._all_spikes = all_spikes peak_sign = self._params["peak_sign"] @@ -76,7 +79,7 @@ def _run(self, **job_kwargs): "The sorting object is not dumpable and cannot be processed in parallel. You can use the " "`sorting.save()` function to make it dumpable" ) - init_args = (recording, sorting, extremum_channels_index, peak_shifts, return_scaled) + init_args = (recording, sorting.to_multiprocessing(n_jobs), extremum_channels_index, peak_shifts, return_scaled) processor = ChunkRecordingExecutor( recording, func, init_func, init_args, handle_returns=True, job_name="extract amplitudes", **job_kwargs ) @@ -107,7 +110,6 @@ def get_data(self, outputs="concatenated"): """ we = self.waveform_extractor sorting = we.sorting - all_spikes = sorting.get_all_spike_trains(outputs="unit_index") if outputs == "concatenated": amplitudes = [] @@ -115,11 +117,13 @@ def get_data(self, outputs="concatenated"): amplitudes.append(self._extension_data[f"amplitude_segment_{segment_index}"]) return amplitudes elif outputs == "by_unit": + all_spikes = sorting.to_spike_vector(concatenated=False) + amplitudes_by_unit = [] for segment_index in range(we.get_num_segments()): amplitudes_by_unit.append({}) for unit_index, unit_id in enumerate(sorting.unit_ids): - _, spike_labels = all_spikes[segment_index] + spike_labels = all_spikes[segment_index]["unit_index"] mask = spike_labels == unit_index amps = self._extension_data[f"amplitude_segment_{segment_index}"][mask] amplitudes_by_unit[segment_index][unit_id] = amps @@ -193,8 +197,7 @@ def _init_worker_spike_amplitudes(recording, sorting, extremum_channels_index, p worker_ctx["min_shift"] = np.min(peak_shifts) worker_ctx["max_shifts"] = np.max(peak_shifts) - all_spikes = sorting.get_all_spike_trains(outputs="unit_index") - worker_ctx["all_spikes"] = all_spikes + worker_ctx["all_spikes"] = sorting.to_spike_vector(concatenated=False) worker_ctx["extremum_channels_index"] = extremum_channels_index return worker_ctx @@ -209,7 +212,9 @@ def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx): seg_size = recording.get_num_samples(segment_index=segment_index) - spike_times, spike_labels = all_spikes[segment_index] + spike_times = all_spikes[segment_index]["sample_index"] + spike_labels = all_spikes[segment_index]["unit_index"] + d = np.diff(spike_times) assert np.all(d >= 0) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index aac96be7b6..c6f498f7e8 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -50,9 +50,6 @@ def _run(self, **job_kwargs): we = self.waveform_extractor - extremum_channel_inds = get_template_extremum_channel(we, outputs="index") - self.spikes = 
we.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) - spike_locations = localize_peaks(we.recording, self.spikes, **self._params, **job_kwargs) self._extension_data["spike_locations"] = spike_locations diff --git a/src/spikeinterface/postprocessing/tests/test_align_sorting.py b/src/spikeinterface/postprocessing/tests/test_align_sorting.py index f9df45df2a..0adda426a9 100644 --- a/src/spikeinterface/postprocessing/tests/test_align_sorting.py +++ b/src/spikeinterface/postprocessing/tests/test_align_sorting.py @@ -29,7 +29,7 @@ def test_compute_unit_center_of_mass(): # sorting to dict d = {unit_id: sorting.get_unit_spike_train(unit_id) + unit_peak_shifts[unit_id] for unit_id in sorting.unit_ids} - sorting_unaligned = NumpySorting.from_dict(d, sampling_frequency=sorting.get_sampling_frequency()) + sorting_unaligned = NumpySorting.from_unit_dict(d, sampling_frequency=sorting.get_sampling_frequency()) print(sorting_unaligned) sorting_aligned = align_sorting(sorting_unaligned, unit_peak_shifts) diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index bfbac11722..d6648150de 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -128,7 +128,7 @@ def test_auto_equal_cross_correlograms(): spike_times = np.sort(np.unique(np.random.randint(0, 100000, num_spike))) num_spike = spike_times.size units_dict = {"1": spike_times, "2": spike_times} - sorting = NumpySorting.from_dict([units_dict], sampling_frequency=10000.0) + sorting = NumpySorting.from_unit_dict([units_dict], sampling_frequency=10000.0) for method in methods: correlograms, bins = compute_correlograms(sorting, window_ms=10.0, bin_ms=0.1, method=method) @@ -178,7 +178,7 @@ def test_detect_injected_correlation(): spike_times2 = np.sort(spike_times2) units_dict = {"1": spike_times1, "2": spike_times2} - sorting = NumpySorting.from_dict([units_dict], sampling_frequency=sampling_frequency) + sorting = NumpySorting.from_unit_dict([units_dict], sampling_frequency=sampling_frequency) for method in methods: correlograms, bins = compute_correlograms(sorting, window_ms=10.0, bin_ms=0.1, method=method) @@ -204,13 +204,13 @@ def test_detect_injected_correlation(): if __name__ == "__main__": - # ~ test_make_bins() - # test_equal_results_correlograms() - # ~ test_flat_cross_correlogram() - # ~ test_auto_equal_cross_correlograms() + test_make_bins() + test_equal_results_correlograms() + test_flat_cross_correlogram() + test_auto_equal_cross_correlograms() test_detect_injected_correlation() - # test = CorrelogramsExtensionTest() - # test.setUp() - # test.test_compute_correlograms() - # test.test_extension() + test = CorrelogramsExtensionTest() + test.setUp() + test.test_compute_correlograms() + test.test_extension() diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 870e710877..5d64525b52 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -64,11 +64,10 @@ def test_compute_for_all_spikes(self): pc_file_sparse = pc.extension_folder / "all_pc_sparse.npy" pc_sparse.run_for_all_spikes(pc_file_sparse, chunk_size=10000, n_jobs=1) all_pc_sparse = np.load(pc_file_sparse) - all_spikes = we_copy.sorting.get_all_spike_trains(outputs="unit_id") - _, spike_labels = 
all_spikes[0] - for unit_id, sparse_channel_ids in sparsity.unit_id_to_channel_ids.items(): - # check dimensions - pc_unit = all_pc_sparse[spike_labels == unit_id] + all_spikes_seg0 = we_copy.sorting.to_spike_vector(concatenated=False)[0] + for unit_index, unit_id in enumerate(we.unit_ids): + sparse_channel_ids = sparsity.unit_id_to_channel_ids[unit_id] + pc_unit = all_pc_sparse[all_spikes_seg0["unit_index"] == unit_index] assert np.allclose(pc_unit[:, :, len(sparse_channel_ids) :], 0) def test_sparse(self): @@ -198,8 +197,8 @@ def test_project_new(self): if __name__ == "__main__": test = PrincipalComponentsExtensionTest() test.setUp() - # test.test_extension() - # test.test_shapes() - # test.test_compute_for_all_spikes() - # test.test_sparse() + test.test_extension() + test.test_shapes() + test.test_compute_for_all_spikes() + test.test_sparse() test.test_project_new() diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index 9f303de6e1..740fdd234b 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -568,7 +568,7 @@ def enforce_decrease_shells_data(wf_data, maxchan, radial_parents, in_place=Fals def get_grid_convolution_templates_and_weights( - contact_locations, local_radius_um=50, upsampling_um=5, sigma_um=np.linspace(10, 50.0, 5), margin_um=50 + contact_locations, radius_um=50, upsampling_um=5, sigma_um=np.linspace(10, 50.0, 5), margin_um=50 ): x_min, x_max = contact_locations[:, 0].min(), contact_locations[:, 0].max() y_min, y_max = contact_locations[:, 1].min(), contact_locations[:, 1].max() @@ -597,7 +597,7 @@ def get_grid_convolution_templates_and_weights( # mask to get nearest template given a channel dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions) - nearest_template_mask = dist < local_radius_um + nearest_template_mask = dist < radius_um weights = np.zeros((len(sigma_um), len(contact_locations), nb_templates), dtype=np.float32) for count, sigma in enumerate(sigma_um): diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index e30176f099..0b8d8a730b 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -14,9 +14,9 @@ class DepthOrderRecording(ChannelSliceRecording): The recording to re-order. channel_ids : list/array or None If given, a subset of channels to order locations for - dimensions : str or tuple + dimensions : str, tuple, list If str, it needs to be 'x', 'y', 'z'. - If tuple, it sorts the locations in two dimensions using lexsort. + If tuple or list, it sorts the locations in two dimensions using lexsort. 
This approach is recommended since there is less ambiguity, by default ('x', 'y') """ @@ -31,6 +31,11 @@ def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y")): parent_recording, channel_ids=reordered_channel_ids, ) + self._kwargs = dict( + parent_recording=parent_recording, + channel_ids=channel_ids, + dimensions=dimensions, + ) depth_order = define_function_from_class(source_class=DepthOrderRecording, name="depth_order") diff --git a/src/spikeinterface/preprocessing/filter_gaussian.py b/src/spikeinterface/preprocessing/filter_gaussian.py index 56d43e13e8..79b5ba5bc3 100644 --- a/src/spikeinterface/preprocessing/filter_gaussian.py +++ b/src/spikeinterface/preprocessing/filter_gaussian.py @@ -40,7 +40,7 @@ def __init__(self, recording: BaseRecording, freq_min: float = 300.0, freq_max: for parent_segment in recording._recording_segments: self.add_recording_segment(GaussianFilterRecordingSegment(parent_segment, freq_min, freq_max)) - self._kwargs = {"recording": recording.to_dict(), "freq_min": freq_min, "freq_max": freq_max} + self._kwargs = {"recording": recording, "freq_min": freq_min, "freq_max": freq_max} class GaussianFilterRecordingSegment(BasePreprocessorSegment): diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 957d4f588e..8b0c8006d2 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -18,12 +18,12 @@ peak_sign="neg", detect_threshold=8.0, exclude_sweep_ms=0.1, - local_radius_um=50, + radius_um=50, ), "select_kwargs": None, "localize_peaks_kwargs": dict( method="monopolar_triangulation", - local_radius_um=75.0, + radius_um=75.0, max_distance_um=150.0, optimizer="minimize_with_log_penality", enforce_decrease=True, @@ -81,12 +81,12 @@ peak_sign="neg", detect_threshold=8.0, exclude_sweep_ms=0.1, - local_radius_um=50, + radius_um=50, ), "select_kwargs": None, "localize_peaks_kwargs": dict( method="center_of_mass", - local_radius_um=75.0, + radius_um=75.0, feature="ptp", ), "estimate_motion_kwargs": dict( @@ -109,12 +109,12 @@ peak_sign="neg", detect_threshold=8.0, exclude_sweep_ms=0.1, - local_radius_um=50, + radius_um=50, ), "select_kwargs": None, "localize_peaks_kwargs": dict( method="grid_convolution", - local_radius_um=40.0, + radius_um=40.0, upsampling_um=5.0, sigma_um=np.linspace(5.0, 25.0, 5), sigma_ms=0.25, @@ -258,7 +258,7 @@ def correct_motion( noise_levels = get_noise_levels(recording, return_scaled=False) if select_kwargs is None: - # maybe do this directly in the folderwhen not None + # maybe do this directly in the folder when not None gather_mode = "memory" # node detect @@ -328,10 +328,12 @@ def correct_motion( estimate_motion_kwargs=estimate_motion_kwargs, interpolate_motion_kwargs=interpolate_motion_kwargs, job_kwargs=job_kwargs, + sampling_frequency=recording.sampling_frequency, ) (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8") (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - recording.dump_to_json(folder / "recording.json") + if recording.check_if_json_serializable(): + recording.dump_to_json(folder / "recording.json") np.save(folder / "peaks.npy", peaks) np.save(folder / "peak_locations.npy", peak_locations) diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index 5d6cc0eb16..95e5a097ff 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ 
b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -105,10 +105,10 @@ def test_filter_opencl(): # rec2_cached0 = rec2.save(chunk_size=1000,verbose=False, progress_bar=True, n_jobs=4) # import matplotlib.pyplot as plt - # from spikeinterface.widgets import plot_timeseries - # plot_timeseries(rec, segment_index=0) - # plot_timeseries(rec_filtered, segment_index=0) - # plot_timeseries(rec2_cached0, segment_index=0) + # from spikeinterface.widgets import plot_traces + # plot_traces(rec, segment_index=0) + # plot_traces(rec_filtered, segment_index=0) + # plot_traces(rec2_cached0, segment_index=0) # plt.show() diff --git a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py index 45db8440b9..b62a73a8cb 100644 --- a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py +++ b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py @@ -30,7 +30,7 @@ def test_normalize_by_quantile(): rec2.save(verbose=False) # import matplotlib.pyplot as plt - # from spikeinterface.widgets import plot_timeseries + # from spikeinterface.widgets import plot_traces # fig, ax = plt.subplots() # ax.plot(rec.get_traces(segment_index=0)[:, 0], color='g') # ax.plot(rec2.get_traces(segment_index=0)[:, 0], color='r') diff --git a/src/spikeinterface/preprocessing/tests/test_phase_shift.py b/src/spikeinterface/preprocessing/tests/test_phase_shift.py index 41293b6c25..b1ccc433b3 100644 --- a/src/spikeinterface/preprocessing/tests/test_phase_shift.py +++ b/src/spikeinterface/preprocessing/tests/test_phase_shift.py @@ -104,9 +104,9 @@ def test_phase_shift(): # ~ import matplotlib.pyplot as plt # ~ import spikeinterface.full as si - # ~ si.plot_timeseries(rec, segment_index=0, time_range=[0, 10]) - # ~ si.plot_timeseries(rec2, segment_index=0, time_range=[0, 10]) - # ~ si.plot_timeseries(rec3, segment_index=0, time_range=[0, 10]) + # ~ si.plot_traces(rec, segment_index=0, time_range=[0, 10]) + # ~ si.plot_traces(rec2, segment_index=0, time_range=[0, 10]) + # ~ si.plot_traces(rec3, segment_index=0, time_range=[0, 10]) # ~ plt.show() diff --git a/src/spikeinterface/preprocessing/tests/test_rectify.py b/src/spikeinterface/preprocessing/tests/test_rectify.py index d4f58d3cc3..cca41ebf7d 100644 --- a/src/spikeinterface/preprocessing/tests/test_rectify.py +++ b/src/spikeinterface/preprocessing/tests/test_rectify.py @@ -27,7 +27,7 @@ def test_rectify(): assert traces.shape[1] == 1 # import matplotlib.pyplot as plt - # from spikeinterface.widgets import plot_timeseries + # from spikeinterface.widgets import plot_traces # fig, ax = plt.subplots() # ax.plot(rec.get_traces(segment_index=0)[:, 0], color='g') # ax.plot(rec2.get_traces(segment_index=0)[:, 0], color='r') diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 9130cb37d5..778de8aea4 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -367,9 +367,12 @@ def compute_refrac_period_violations( fs = sorting.get_sampling_frequency() num_units = len(sorting.unit_ids) num_segments = sorting.get_num_segments() - spikes = sorting.get_all_spike_trains(outputs="unit_index") + + spikes = sorting.to_spike_vector(concatenated=False) + if unit_ids is None: unit_ids = sorting.unit_ids + num_spikes = compute_num_spikes(waveform_extractor) t_c = int(round(censored_period_ms * fs * 1e-3)) @@ -377,9 +380,9 @@ def compute_refrac_period_violations( nb_rp_violations = 
np.zeros((num_units), dtype=np.int64) for seg_index in range(num_segments): - _compute_rp_violations_numba( - nb_rp_violations, spikes[seg_index][0].astype(np.int64), spikes[seg_index][1].astype(np.int32), t_c, t_r - ) + spike_times = spikes[seg_index]["sample_index"].astype(np.int64) + spike_labels = spikes[seg_index]["unit_index"].astype(np.int32) + _compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r) T = waveform_extractor.get_total_samples() diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 2a0feb4da8..e725498773 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -466,15 +466,14 @@ def nearest_neighbors_isolation( # if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN if n_spikes_all_units[this_unit_id] < min_spikes: warnings.warn( - f"Warning: unit {this_unit_id} has fewer spikes than ", - f"specified by `min_spikes` ({min_spikes}); ", - f"returning NaN as the quality metric...", + f"Unit {this_unit_id} has fewer spikes than specified by `min_spikes` " + f"({min_spikes}); returning NaN as the quality metric..." ) return np.nan, np.nan elif fr_all_units[this_unit_id] < min_fr: warnings.warn( - f"Warning: unit {this_unit_id} has a firing rate ", - f"below the specified `min_fr` ({min_fr}Hz); " f"returning NaN as the quality metric...", + f"Unit {this_unit_id} has a firing rate below the specified `min_fr` " + f"({min_fr} Hz); returning NaN as the quality metric..." ) return np.nan, np.nan else: @@ -652,15 +651,14 @@ def nearest_neighbors_noise_overlap( # if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN if n_spikes_all_units[this_unit_id] < min_spikes: warnings.warn( - f"Warning: unit {this_unit_id} has fewer spikes than ", - f"specified by `min_spikes` ({min_spikes}); ", - f"returning NaN as the quality metric...", + f"Unit {this_unit_id} has fewer spikes than specified by `min_spikes` " + f"({min_spikes}); returning NaN as the quality metric..." 
) return np.nan elif fr_all_units[this_unit_id] < min_fr: warnings.warn( - f"Warning: unit {this_unit_id} has a firing rate ", - f"below the specified `min_fr` ({min_fr}Hz); " f"returning NaN as the quality metric...", + f"Unit {this_unit_id} has a firing rate below the specified `min_fr` " + f"({min_fr} Hz); returning NaN as the quality metric...", ) return np.nan else: diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index ececf167b0..e2b95c8e39 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -264,7 +264,9 @@ def test_calculate_rp_violations(simulated_data): assert np.allclose(list(rp_contamination_gt.values()), list(rp_contamination.values()), rtol=0.05) np.testing.assert_array_equal(list(counts_gt.values()), list(counts.values())) - sorting = NumpySorting.from_dict({0: np.array([28, 150], dtype=np.int16), 1: np.array([], dtype=np.int16)}, 30000) + sorting = NumpySorting.from_unit_dict( + {0: np.array([28, 150], dtype=np.int16), 1: np.array([], dtype=np.int16)}, 30000 + ) we.sorting = sorting rp_contamination, counts = compute_refrac_period_violations(we, refractory_period_ms=1, censored_period_ms=0.0) assert np.isnan(rp_contamination[1]) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 4bc61768c0..bd792e1aac 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -279,7 +279,7 @@ def test_recordingless(self): def test_empty_units(self): we = self.we1 empty_spike_train = np.array([], dtype="int64") - empty_sorting = NumpySorting.from_dict( + empty_sorting = NumpySorting.from_unit_dict( {100: empty_spike_train, 200: empty_spike_train, 300: empty_spike_train}, sampling_frequency=we.sampling_frequency, ) @@ -296,7 +296,8 @@ def test_empty_units(self): if __name__ == "__main__": test = QualityMetricsExtensionTest() test.setUp() - test.test_drift_metrics() - test.test_extension() + # test.test_drift_metrics() + # test.test_extension() # test.test_nn_metrics() # test.test_peak_sign() + test.test_empty_units() diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index cbaba31d02..7ea2fe5a23 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -120,8 +120,8 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo if output_folder is None: output_folder = cls.sorter_name + "_output" - #  .absolute() not anymore - output_folder = Path(output_folder) + # Resolve path + output_folder = Path(output_folder).absolute() sorter_output_folder = output_folder / "sorter_output" if output_folder.is_dir(): diff --git a/src/spikeinterface/sorters/external/kilosort2.py b/src/spikeinterface/sorters/external/kilosort2.py index 8a6998db92..267ff38e36 100644 --- a/src/spikeinterface/sorters/external/kilosort2.py +++ b/src/spikeinterface/sorters/external/kilosort2.py @@ -44,6 +44,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "ntbuff": 64, "nfilt_factor": 4, "NT": None, + "AUCsplit": 0.9, "wave_length": 61, "keep_good_only": False, "skip_kilosort_preprocessing": False, @@ -66,6 +67,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "ntbuff": "Samples of 
symmetrical buffer for whitening and spike detection", "nfilt_factor": "Max number of clusters per good channel (even temporary ones) 4", "NT": "Batch size (if None it is automatically computed)", + "AUCsplit": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", "keep_good_only": "If True only 'good' units are returned", "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", @@ -161,7 +163,7 @@ def _get_specific_options(cls, ops, params): ops["lam"] = 10.0 # splitting a cluster at the end requires at least this much isolation for each sub-cluster (max = 1) - ops["AUCsplit"] = 0.9 + ops["AUCsplit"] = params["AUCsplit"] # minimum spike rate (Hz), if a cluster falls below this for too long it gets removed ops["minFR"] = params["minFR"] diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index ced2bd05ab..0c9e36177e 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -50,6 +50,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "ntbuff": 64, "nfilt_factor": 4, "NT": None, + "AUCsplit": 0.9, "do_correction": True, "wave_length": 61, "keep_good_only": False, @@ -76,6 +77,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "nfilt_factor": "Max number of clusters per good channel (even temporary ones) 4", "do_correction": "If True drift registration is applied", "NT": "Batch size (if None it is automatically computed)", + "AUCsplit": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "keep_good_only": "If True only 'good' units are returned", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", @@ -182,7 +184,7 @@ def _get_specific_options(cls, ops, params): ops["lam"] = 10.0 # splitting a cluster at the end requires at least this much isolation for each sub-cluster (max = 1) - ops["AUCsplit"] = 0.9 + ops["AUCsplit"] = params["AUCsplit"] # minimum spike rate (Hz), if a cluster falls below this for too long it gets removed ops["minFR"] = params["minFR"] diff --git a/src/spikeinterface/sorters/external/kilosort3.py b/src/spikeinterface/sorters/external/kilosort3.py index c514480896..77e83e35b9 100644 --- a/src/spikeinterface/sorters/external/kilosort3.py +++ b/src/spikeinterface/sorters/external/kilosort3.py @@ -48,6 +48,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "nfilt_factor": 4, "do_correction": True, "NT": None, + "AUCsplit": 0.8, "wave_length": 61, "keep_good_only": False, "skip_kilosort_preprocessing": False, @@ -73,6 +74,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "nfilt_factor": "Max number of clusters per good channel (even temporary ones) 4", "do_correction": "If True drift registration is applied", "NT": "Batch size (if None it is automatically computed)", + "AUCsplit": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", "keep_good_only": "If True only 'good' units are returned", "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", @@ -171,7 +173,7 @@ def _get_specific_options(cls, ops, 
params): ops["lam"] = 20.0 # splitting a cluster at the end requires at least this much isolation for each sub-cluster (max = 1) - ops["AUCsplit"] = 0.8 + ops["AUCsplit"] = params["AUCsplit"] # minimum firing rate on a "good" channel (0 to skip) ops["minfr_goodchannels"] = params["minfr_goodchannels"] diff --git a/src/spikeinterface/sorters/external/mountainsort4.py b/src/spikeinterface/sorters/external/mountainsort4.py index 64ec92793e..69f97fd11c 100644 --- a/src/spikeinterface/sorters/external/mountainsort4.py +++ b/src/spikeinterface/sorters/external/mountainsort4.py @@ -140,7 +140,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # convert sorting to new API and save it unit_ids = old_api_sorting.get_unit_ids() units_dict_list = [{u: old_api_sorting.get_unit_spike_train(u) for u in unit_ids}] - new_api_sorting = NumpySorting.from_dict(units_dict_list, samplerate) + new_api_sorting = NumpySorting.from_unit_dict(units_dict_list, samplerate) NpzSortingExtractor.write_sorting(new_api_sorting, str(sorter_output_folder / "firings.npz")) @classmethod diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 24c4a7ccfc..9de2762562 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -21,7 +21,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 2, "ms_after": 2, "local_radius_um": 100}, + "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "waveforms": {"max_spikes_per_unit": 200, "overwrite": True}, "filtering": {"dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, @@ -75,8 +75,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## Then, we are detecting peaks with a locally_exclusive method detection_params = params["detection"].copy() detection_params.update(job_kwargs) - if "local_radius_um" not in detection_params: - detection_params["local_radius_um"] = params["general"]["local_radius_um"] + if "radius_um" not in detection_params: + detection_params["radius_um"] = params["general"]["radius_um"] if "exclude_sweep_ms" not in detection_params: detection_params["exclude_sweep_ms"] = max(params["general"]["ms_before"], params["general"]["ms_after"]) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index a812d4ce49..42f51d3a77 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -12,7 +12,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): _default_params = { "apply_preprocessing": True, - "general": {"ms_before": 2.5, "ms_after": 3.5, "local_radius_um": 100}, + "general": {"ms_before": 2.5, "ms_after": 3.5, "radius_um": 100}, "filtering": {"freq_min": 300, "freq_max": 8000.0}, "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 0.4}, "hdbscan_kwargs": { @@ -68,7 +68,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # detection detection_params = params["detection"].copy() - detection_params["local_radius_um"] = params["general"]["local_radius_um"] + detection_params["radius_um"] = params["general"]["radius_um"] detection_params["noise_levels"] = noise_levels peaks = detect_peaks(recording, method="locally_exclusive", **detection_params, **job_kwargs) @@ -89,7 +89,7 @@ def _run_from_folder(cls, 
sorter_output_folder, params, verbose): # localization localization_params = params["localization"].copy() - localization_params["local_radius_um"] = params["general"]["local_radius_um"] + localization_params["radius_um"] = params["general"]["radius_um"] peak_locations = localize_peaks( recording, some_peaks, method="monopolar_triangulation", **localization_params, **job_kwargs ) @@ -127,7 +127,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_params["noise_levels"] = noise_levels matching_params["peak_sign"] = params["detection"]["peak_sign"] matching_params["detect_threshold"] = params["detection"]["detect_threshold"] - matching_params["local_radius_um"] = params["general"]["local_radius_um"] + matching_params["radius_um"] = params["general"]["radius_um"] # TODO: route that params # ~ 'num_closest' : 5, diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index 478ace13e5..52098f45cd 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -273,7 +273,7 @@ def run_sorters( if need_dump: if not recording.check_if_dumpable(): raise Exception("recording not dumpable call recording.save() before") - recording_arg = recording.to_dict() + recording_arg = recording.to_dict(recursive=True) else: recording_arg = recording diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 7eea75ce81..d68b8e5449 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -161,15 +161,16 @@ def run(self, peaks=None, positions=None, method=None, method_kwargs={}, delta=0 if self.verbose: print("Performing the comparison with (sliced) ground truth") - times1 = self.gt_sorting.get_all_spike_trains()[0] - times2 = self.clustering.get_all_spike_trains()[0] - matches = make_matching_events(times1[0], times2[0], int(delta * self.sampling_rate / 1000)) + spikes1 = self.gt_sorting.to_spike_vector(concatenated=False)[0] + spikes2 = self.clustering.to_spike_vector(concatenated=False)[0] + + matches = make_matching_events( + spikes1["sample_index"], spikes2["sample_index"], int(delta * self.sampling_rate / 1000) + ) self.matches = matches idx = matches["index1"] - self.sliced_gt_sorting = NumpySorting.from_times_labels( - times1[0][idx], times1[1][idx], self.sampling_rate, unit_ids=self.gt_sorting.unit_ids - ) + self.sliced_gt_sorting = NumpySorting(spikes1[idx], self.sampling_rate, self.gt_sorting.unit_ids) self.comp = GroundTruthComparison(self.sliced_gt_sorting, self.clustering, exhaustive_gt=self.exhaustive_gt) @@ -251,10 +252,11 @@ def _scatter_clusters( # scatter and collect gaussian info means = {} covs = {} - labels_ids = sorting.get_all_spike_trains()[0][1] + labels = sorting.to_spike_vector(concatenated=False)[0]["unit_index"] for unit_ind, unit_id in enumerate(sorting.unit_ids): - where = np.flatnonzero(labels_ids == unit_id) + where = np.flatnonzero(labels == unit_ind) + xk = xs[where] yk = ys[where] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index bf3577368e..dd35670abd 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -337,12 +337,10 @@ 
def plot_motion_corrected_peaks(self, scaling_probe=1.5, alpha=0.05, figsize=(15 channel_positions = self.recording.get_channel_locations() probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() - times = self.recording.get_times() - peak_locations_corrected = correct_motion_on_peaks( self.selected_peaks, self.peak_locations, - times, + self.recording.sampling_frequency, self.motion, self.temporal_bins, self.spatial_bins, diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index b5ad24a5b3..d02464d0d0 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -447,15 +447,15 @@ def plot_figure_1(benchmark, mode="average", cell_ind="auto"): import spikeinterface.full as si - unit_id = benchmark.waveforms.sorting.unit_ids[cell_ind] + sorting = benchmark.waveforms.sorting + unit_id = sorting.unit_ids[cell_ind] - mask = benchmark.waveforms.sorting.get_all_spike_trains()[0][1] == unit_id - times = ( - benchmark.waveforms.sorting.get_all_spike_trains()[0][0][mask] / benchmark.recording.get_sampling_frequency() - ) + spikes_seg0 = sorting.to_spike_vector(concatenated=False)[0] + mask = spikes_seg0["unit_index"] == cell_ind + times = spikes_seg0["sample_index"][mask] / sorting.get_sampling_frequency() print(benchmark.recording) - # si.plot_timeseries(benchmark.recording, mode='line', time_range=(times[0]-0.01, times[0] + 0.1), channel_ids=benchmark.recording.channel_ids[:20], ax=axs[0, 1]) + # si.plot_traces(benchmark.recording, mode='line', time_range=(times[0]-0.01, times[0] + 0.1), channel_ids=benchmark.recording.channel_ids[:20], ax=axs[0, 1]) # axs[0, 1].set_ylabel('Neurons') # si.plot_spikes_on_traces(benchmark.waveforms, unit_ids=[unit_id], time_range=(times[0]-0.01, times[0] + 0.1), unit_colors={unit_id : 'r'}, ax=axs[0, 1], diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index b82102e9fd..1514a63dd4 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -113,26 +113,24 @@ def run(self, peaks=None, positions=None, delta=0.2): if positions is not None: self._positions = positions - times1 = self.gt_sorting.get_all_spike_trains()[0] + spikes1 = self.gt_sorting.to_spike_vector(concatenated=False)[0] times2 = self.peaks["sample_index"] - print("The gt recording has {} peaks and {} have been detected".format(len(times1[0]), len(times2))) + print("The gt recording has {} peaks and {} have been detected".format(len(spikes1), len(times2))) - matches = make_matching_events(times1[0], times2, int(delta * self.sampling_rate / 1000)) + matches = make_matching_events(spikes1["sample_index"], times2, int(delta * self.sampling_rate / 1000)) self.matches = matches self.deltas = {"labels": [], "delta": matches["delta_frame"]} - self.deltas["labels"] = times1[1][matches["index1"]] + self.deltas["labels"] = spikes1["unit_index"][matches["index1"]] - # print(len(times1[0]), len(matches['index1'])) gt_matches = matches["index1"] - self.sliced_gt_sorting = NumpySorting.from_times_labels( - times1[0][gt_matches], times1[1][gt_matches], self.sampling_rate, unit_ids=self.gt_sorting.unit_ids - ) - ratio = 100 * len(gt_matches) / len(times1[0]) + self.sliced_gt_sorting = NumpySorting(spikes1[gt_matches],
self.sampling_rate, self.gt_sorting.unit_ids) + + ratio = 100 * len(gt_matches) / len(spikes1) print("Only {0:.2f}% of gt peaks are matched to detected peaks".format(ratio)) - matches = make_matching_events(times2, times1[0], int(delta * self.sampling_rate / 1000)) + matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000)) self.good_matches = matches["index1"] garbage_matches = ~np.in1d(np.arange(len(times2)), self.good_matches) @@ -231,10 +229,11 @@ def _scatter_clusters( # scatter and collect gaussian info means = {} covs = {} - labels_ids = sorting.get_all_spike_trains()[0][1] + labels = sorting.to_spike_vector(concatenated=False)[0]["unit_index"] for unit_ind, unit_id in enumerate(sorting.unit_ids): - where = np.flatnonzero(labels_ids == unit_id) + where = np.flatnonzero(labels == unit_ind) + xk = xs[where] yk = ys[where] @@ -539,11 +538,11 @@ def plot_statistics(self, metric="cosine", annotations=True, detect_threshold=5) nb_spikes += [b] centers = compute_center_of_mass(self.waveforms["gt"]) - times, labels = self.sliced_gt_sorting.get_all_spike_trains()[0] + spikes_seg0 = self.sliced_gt_sorting.to_spike_vector(concatenated=False)[0] stds = [] means = [] - for found, real in zip(unit_ids2, inds_1): - mask = labels == found + for found, real in zip(inds_2, inds_1): + mask = spikes_seg0["unit_index"] == found center = np.array([self.sliced_gt_positions[mask]["x"], self.sliced_gt_positions[mask]["y"]]).mean() means += [np.mean(center - centers[real])] stds += [np.std(center - centers[real])] @@ -613,22 +612,23 @@ def plot_statistics(self, metric="cosine", annotations=True, detect_threshold=5) def explore_garbage(self, channel_index, nb_bins=None, dt=None): mask = self.garbage_peaks["channel_index"] == channel_index times2 = self.garbage_peaks[mask]["sample_index"] - times1 = self.gt_sorting.get_all_spike_trains()[0] + spikes1 = self.gt_sorting.to_spike_vector(concatenated=False)[0] + from spikeinterface.comparison.comparisontools import make_matching_events if dt is None: delta = self.waveforms["garbage"].nafter else: delta = dt - matches = make_matching_events(times2, times1[0], delta) - units = times1[1][matches["index2"]] + matches = make_matching_events(times2, spikes1["sample_index"], delta) + unit_inds = spikes1["unit_index"][matches["index2"]] dt = matches["delta_frame"] res = {} fig, ax = plt.subplots() if nb_bins is None: nb_bins = 2 * delta xaxis = np.linspace(-delta, delta, nb_bins) - for unit_id in np.unique(units): - mask = units == unit_id - res[unit_id] = dt[mask] - ax.hist(res[unit_id], bins=xaxis) + for unit_ind in np.unique(unit_inds): - mask = unit_inds == unit_ind + res[unit_ind] = dt[mask] + ax.hist(res[unit_ind], bins=xaxis) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index a6185f5193..46aba7e96f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -37,7 +37,7 @@ class CircusClustering: }, "cleaning_kwargs": {}, "tmp_folder": None, - "local_radius_um": 100, + "radius_um": 100, "n_pca": 10, "max_spikes_per_unit": 200, "ms_before": 1.5, @@ -104,7 +104,7 @@ def main_function(cls, recording, peaks, params): chan_distances = get_channel_distances(recording) for main_chan in unit_inds: - (closest_chans,) = np.nonzero(chan_distances[main_chan, :] <=
params["radius_um"]) sparsity_mask[main_chan, closest_chans] = True if params["waveform_mode"] == "shared_memory": diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 53833b01a2..6edf5af16b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -579,7 +579,7 @@ def remove_duplicates_via_matching( f.write(blanck) f.close() - recording = BinaryRecordingExtractor(tmp_filename, num_chan=num_chans, sampling_frequency=fs, dtype="float32") + recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32") recording.annotate(is_filtered=True) margin = 2 * max(waveform_extractor.nbefore, waveform_extractor.nafter) diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index 082d2dc0ba..8d21041599 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -35,7 +35,7 @@ class PositionAndFeaturesClustering: "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, - "local_radius_um": 100, + "radius_um": 100, "max_spikes_per_unit": 200, "selection_method": "random", "ms_before": 1.5, @@ -69,9 +69,9 @@ def main_function(cls, recording, peaks, params): features_list = [position_method, "ptp", "energy"] features_params = { - position_method: {"local_radius_um": params["local_radius_um"]}, - "ptp": {"all_channels": False, "local_radius_um": params["local_radius_um"]}, - "energy": {"local_radius_um": params["local_radius_um"]}, + position_method: {"radius_um": params["radius_um"]}, + "ptp": {"all_channels": False, "radius_um": params["radius_um"]}, + "energy": {"radius_um": params["radius_um"]}, } features_data = compute_features_from_peaks( diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 02247dd288..fcbcac097f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -34,7 +34,7 @@ class RandomProjectionClustering: "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, - "local_radius_um": 100, + "radius_um": 100, "max_spikes_per_unit": 200, "selection_method": "closest_to_centroid", "nb_projections": {"ptp": 8, "energy": 2}, @@ -106,7 +106,7 @@ def main_function(cls, recording, peaks, params): projections = np.random.randn(num_chans, d["nb_projections"][proj_type]) features_params[f"random_projections_{proj_type}"] = { - "local_radius_um": params["local_radius_um"], + "radius_um": params["radius_um"], "projections": projections, "min_values": min_values, } diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py index 24fea6429f..68b34a7041 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py @@ -367,7 +367,7 @@ def main_function(cls, recording, peaks, params): if HAVE_NUMBA: - @numba.jit(fastmath=True, cache=True) + @numba.jit(nopython=True, fastmath=True, cache=True) def sparse_euclidean(x, y, n_samples, n_dense): """Euclidean distance metric 
over sparse vectors, where first n_dense elements are indices, and n_samples is the length of the second dimension diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index c075e8e7c1..adc025e829 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -105,15 +105,15 @@ def compute(self, traces, peaks, waveforms): class PeakToPeakFeature(PipelineNode): def __init__( - self, recording, name="ptp_feature", return_output=True, parents=None, local_radius_um=150.0, all_channels=True + self, recording, name="ptp_feature", return_output=True, parents=None, radius_um=150.0, all_channels=True ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance < local_radius_um + self.neighbours_mask = self.channel_distance < radius_um self.all_channels = all_channels - self._kwargs.update(dict(local_radius_um=local_radius_um, all_channels=all_channels)) + self._kwargs.update(dict(radius_um=radius_um, all_channels=all_channels)) self._dtype = recording.get_dtype() def get_dtype(self): @@ -139,19 +139,19 @@ def __init__( name="ptp_lag_feature", return_output=True, parents=None, - local_radius_um=150.0, + radius_um=150.0, all_channels=True, ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.all_channels = all_channels - self.local_radius_um = local_radius_um + self.radius_um = radius_um self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance < local_radius_um + self.neighbours_mask = self.channel_distance < radius_um - self._kwargs.update(dict(local_radius_um=local_radius_um, all_channels=all_channels)) + self._kwargs.update(dict(radius_um=radius_um, all_channels=all_channels)) self._dtype = recording.get_dtype() def get_dtype(self): @@ -184,20 +184,20 @@ def __init__( return_output=True, parents=None, projections=None, - local_radius_um=150.0, + radius_um=150.0, min_values=None, ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.projections = projections - self.local_radius_um = local_radius_um + self.radius_um = radius_um self.min_values = min_values self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance < local_radius_um + self.neighbours_mask = self.channel_distance < radius_um - self._kwargs.update(dict(projections=projections, local_radius_um=local_radius_um, min_values=min_values)) + self._kwargs.update(dict(projections=projections, radius_um=radius_um, min_values=min_values)) self._dtype = recording.get_dtype() @@ -230,19 +230,19 @@ def __init__( return_output=True, parents=None, projections=None, - local_radius_um=150.0, + radius_um=150.0, min_values=None, ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance < local_radius_um + self.neighbours_mask = self.channel_distance < radius_um self.projections = projections self.min_values = min_values - 
self.local_radius_um = local_radius_um - self._kwargs.update(dict(projections=projections, min_values=min_values, local_radius_um=local_radius_um)) + self.radius_um = radius_um + self._kwargs.update(dict(projections=projections, min_values=min_values, radius_um=radius_um)) self._dtype = recording.get_dtype() def get_dtype(self): @@ -267,14 +267,14 @@ def compute(self, traces, peaks, waveforms): class StdPeakToPeakFeature(PipelineNode): - def __init__(self, recording, name="std_ptp_feature", return_output=True, parents=None, local_radius_um=150.0): + def __init__(self, recording, name="std_ptp_feature", return_output=True, parents=None, radius_um=150.0): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance < local_radius_um + self.neighbours_mask = self.channel_distance < radius_um - self._kwargs.update(dict(local_radius_um=local_radius_um)) + self._kwargs.update(dict(radius_um=radius_um)) self._dtype = recording.get_dtype() @@ -292,14 +292,14 @@ def compute(self, traces, peaks, waveforms): class GlobalPeakToPeakFeature(PipelineNode): - def __init__(self, recording, name="global_ptp_feature", return_output=True, parents=None, local_radius_um=150.0): + def __init__(self, recording, name="global_ptp_feature", return_output=True, parents=None, radius_um=150.0): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance < local_radius_um + self.neighbours_mask = self.channel_distance < radius_um - self._kwargs.update(dict(local_radius_um=local_radius_um)) + self._kwargs.update(dict(radius_um=radius_um)) self._dtype = recording.get_dtype() @@ -317,14 +317,14 @@ def compute(self, traces, peaks, waveforms): class KurtosisPeakToPeakFeature(PipelineNode): - def __init__(self, recording, name="kurtosis_ptp_feature", return_output=True, parents=None, local_radius_um=150.0): + def __init__(self, recording, name="kurtosis_ptp_feature", return_output=True, parents=None, radius_um=150.0): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance < local_radius_um + self.neighbours_mask = self.channel_distance < radius_um - self._kwargs.update(dict(local_radius_um=local_radius_um)) + self._kwargs.update(dict(radius_um=radius_um)) self._dtype = recording.get_dtype() @@ -344,14 +344,14 @@ def compute(self, traces, peaks, waveforms): class EnergyFeature(PipelineNode): - def __init__(self, recording, name="energy_feature", return_output=True, parents=None, local_radius_um=50.0): + def __init__(self, recording, name="energy_feature", return_output=True, parents=None, radius_um=50.0): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance < local_radius_um + self.neighbours_mask = self.channel_distance < radius_um - self._kwargs.update(dict(local_radius_um=local_radius_um)) + self._kwargs.update(dict(radius_um=radius_um)) def get_dtype(self): return 
np.dtype("float32") diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index ba4c0e93f3..4e2625acec 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -35,7 +35,7 @@ class NaiveMatching(BaseTemplateMatchingEngine): "exclude_sweep_ms": 0.1, "detect_threshold": 5, "noise_levels": None, - "local_radius_um": 100, + "radius_um": 100, "random_chunk_kwargs": {}, } @@ -54,7 +54,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] channel_distance = get_channel_distances(recording) - d["neighbours_mask"] = channel_distance < d["local_radius_um"] + d["neighbours_mask"] = channel_distance < d["radius_um"] d["nbefore"] = we.nbefore d["nafter"] = we.nafter diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 5fbe1b94f3..7d6d707ea2 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -50,7 +50,7 @@ class TridesclousPeeler(BaseTemplateMatchingEngine): "peak_shift_ms": 0.2, "detect_threshold": 5, "noise_levels": None, - "local_radius_um": 100, + "radius_um": 100, "num_closest": 5, "sample_shift": 3, "ms_before": 0.8, @@ -103,7 +103,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] channel_distance = get_channel_distances(recording) - d["neighbours_mask"] = channel_distance < d["local_radius_um"] + d["neighbours_mask"] = channel_distance < d["radius_um"] sparsity = compute_sparsity(we, method="snr", peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) template_sparsity_inds = sparsity.unit_id_to_channel_indices @@ -154,7 +154,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): # distance channel from unit distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, metric="euclidean") - near_cluster_mask = distances < d["local_radius_um"] + near_cluster_mask = distances < d["radius_um"] # nearby cluster for each channel possible_clusters_by_channel = [] diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 2fa23ee98d..b4a44105e4 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -19,7 +19,7 @@ def correct_motion_on_peaks( peaks, peak_locations, - times, + sampling_frequency, motion, temporal_bins, spatial_bins, @@ -34,8 +34,8 @@ def correct_motion_on_peaks( peaks vector peak_locations: np.array peaks location vector - times: np.array - times vector of recording + sampling_frequency: np.array + sampling_frequency of the recording motion: np.array 2D motion.shape[0] equal temporal_bins.shape[0] motion.shape[1] equal 1 when "rigid" motion equal temporal_bins.shape[0] when "non-rigid" @@ -51,18 +51,17 @@ def correct_motion_on_peaks( """ corrected_peak_locations = peak_locations.copy() + spike_times = peaks["sample_index"] / sampling_frequency if spatial_bins.shape[0] == 1: # rigid motion interpolation 1D - sample_bins = np.searchsorted(times, temporal_bins) - f = scipy.interpolate.interp1d(sample_bins, motion[:, 0], bounds_error=False, fill_value="extrapolate") - shift = f(peaks["sample_index"]) + f = scipy.interpolate.interp1d(temporal_bins, motion[:, 0], 
bounds_error=False, fill_value="extrapolate") + shift = f(spike_times) corrected_peak_locations[direction] -= shift else: # non rigid motion = interpolation 2D f = scipy.interpolate.RegularGridInterpolator( (temporal_bins, spatial_bins), motion, method="linear", bounds_error=False, fill_value=None ) - spike_times = times[peaks["sample_index"]] shift = f(np.c_[spike_times, peak_locations[direction]]) corrected_peak_locations[direction] -= shift diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index c495a3bfa4..4fd7611bb7 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -504,7 +504,7 @@ class DetectPeakLocallyExclusive(PeakDetectorWrapper): params_doc = ( DetectPeakByChannel.params_doc + """ - local_radius_um: float + radius_um: float The radius to use to select neighbour channels for locally exclusive detection. """ ) @@ -516,7 +516,7 @@ def check_params( peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1, - local_radius_um=50, + radius_um=50, noise_levels=None, random_chunk_kwargs={}, ): @@ -533,7 +533,7 @@ def check_params( ) channel_distance = get_channel_distances(recording) - neighbours_mask = channel_distance < local_radius_um + neighbours_mask = channel_distance < radius_um return args + (neighbours_mask,) @classmethod @@ -580,7 +580,7 @@ class DetectPeakLocallyExclusiveTorch(PeakDetectorWrapper): params_doc = ( DetectPeakByChannel.params_doc + """ - local_radius_um: float + radius_um: float The radius to use to select neighbour channels for locally exclusive detection. """ ) @@ -594,7 +594,7 @@ def check_params( exclude_sweep_ms=0.1, noise_levels=None, device=None, - local_radius_um=50, + radius_um=50, return_tensor=False, random_chunk_kwargs={}, ): @@ -615,7 +615,7 @@ def check_params( neighbour_indices_by_chan = [] num_channels = recording.get_num_channels() for chan in range(num_channels): - neighbour_indices_by_chan.append(np.nonzero(channel_distance[chan] < local_radius_um)[0]) + neighbour_indices_by_chan.append(np.nonzero(channel_distance[chan] < radius_um)[0]) max_neighbs = np.max([len(neigh) for neigh in neighbour_indices_by_chan]) neighbours_idxs = num_channels * np.ones((num_channels, max_neighbs), dtype=int) for i, neigh in enumerate(neighbour_indices_by_chan): @@ -640,7 +640,7 @@ def detect_peaks(cls, traces, peak_sign, abs_threholds, exclude_sweep_size, devi if HAVE_NUMBA: - @numba.jit(parallel=False) + @numba.jit(nopython=True, parallel=False) def _numba_detect_peak_pos( traces, traces_center, peak_mask, exclude_sweep_size, abs_threholds, peak_sign, neighbours_mask ): @@ -665,7 +665,7 @@ def _numba_detect_peak_pos( break return peak_mask - @numba.jit(parallel=False) + @numba.jit(nopython=True, parallel=False) def _numba_detect_peak_neg( traces, traces_center, peak_mask, exclude_sweep_size, abs_threholds, peak_sign, neighbours_mask ): @@ -836,7 +836,7 @@ def check_params( peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1, - local_radius_um=50, + radius_um=50, noise_levels=None, random_chunk_kwargs={}, ): @@ -847,7 +847,7 @@ def check_params( abs_threholds = noise_levels * detect_threshold exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) channel_distance = get_channel_distances(recording) - neighbours_mask = channel_distance < local_radius_um + neighbours_mask = channel_distance < radius_um executor = OpenCLDetectPeakExecutor(abs_threholds, exclude_sweep_size, 
neighbours_mask, peak_sign) diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index d1df720624..bd793b3f53 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -101,14 +101,14 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_ class LocalizeBase(PipelineNode): - def __init__(self, recording, return_output=True, parents=None, local_radius_um=75.0): + def __init__(self, recording, return_output=True, parents=None, radius_um=75.0): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) - self.local_radius_um = local_radius_um + self.radius_um = radius_um self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance < local_radius_um - self._kwargs["local_radius_um"] = local_radius_um + self.neighbours_mask = self.channel_distance < radius_um + self._kwargs["radius_um"] = radius_um def get_dtype(self): return self._dtype @@ -152,18 +152,14 @@ class LocalizeCenterOfMass(LocalizeBase): need_waveforms = True name = "center_of_mass" params_doc = """ - local_radius_um: float + radius_um: float Radius in um for channel sparsity. feature: str ['ptp', 'mean', 'energy', 'peak_voltage'] Feature to consider for computation. Default is 'ptp' """ - def __init__( - self, recording, return_output=True, parents=["extract_waveforms"], local_radius_um=75.0, feature="ptp" - ): - LocalizeBase.__init__( - self, recording, return_output=return_output, parents=parents, local_radius_um=local_radius_um - ) + def __init__(self, recording, return_output=True, parents=["extract_waveforms"], radius_um=75.0, feature="ptp"): + LocalizeBase.__init__(self, recording, return_output=return_output, parents=parents, radius_um=radius_um) self._dtype = np.dtype(dtype_localize_by_method["center_of_mass"]) assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature" @@ -216,7 +212,7 @@ class LocalizeMonopolarTriangulation(PipelineNode): need_waveforms = False name = "monopolar_triangulation" params_doc = """ - local_radius_um: float + radius_um: float For channel sparsity. max_distance_um: float, default: 1000 Boundary for distance estimation. @@ -234,15 +230,13 @@ def __init__( recording, return_output=True, parents=["extract_waveforms"], - local_radius_um=75.0, + radius_um=75.0, max_distance_um=150.0, optimizer="minimize_with_log_penality", enforce_decrease=True, feature="ptp", ): - LocalizeBase.__init__( - self, recording, return_output=return_output, parents=parents, local_radius_um=local_radius_um - ) + LocalizeBase.__init__(self, recording, return_output=return_output, parents=parents, radius_um=radius_um) assert feature in ["ptp", "energy", "peak_voltage"], f"{feature} is not a valid feature" self.max_distance_um = max_distance_um @@ -309,7 +303,7 @@ class LocalizeGridConvolution(PipelineNode): need_waveforms = True name = "grid_convolution" params_doc = """ - local_radius_um: float + radius_um: float Radius in um for channel sparsity. 
upsampling_um: float Upsampling resolution for the grid of templates @@ -333,7 +327,7 @@ def __init__( recording, return_output=True, parents=["extract_waveforms"], - local_radius_um=40.0, + radius_um=40.0, upsampling_um=5.0, sigma_um=np.linspace(5.0, 25.0, 5), sigma_ms=0.25, @@ -344,7 +338,7 @@ def __init__( ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) - self.local_radius_um = local_radius_um + self.radius_um = radius_um self.sigma_um = sigma_um self.margin_um = margin_um self.upsampling_um = upsampling_um @@ -371,7 +365,7 @@ def __init__( self.prototype = self.prototype[:, np.newaxis] self.template_positions, self.weights, self.nearest_template_mask = get_grid_convolution_templates_and_weights( - contact_locations, self.local_radius_um, self.upsampling_um, self.sigma_um, self.margin_um + contact_locations, self.radius_um, self.upsampling_um, self.sigma_um, self.margin_um ) self.weights_sparsity_mask = self.weights > self.sparsity_threshold @@ -379,7 +373,7 @@ def __init__( self._dtype = np.dtype(dtype_localize_by_method["grid_convolution"]) self._kwargs.update( dict( - local_radius_um=self.local_radius_um, + radius_um=self.radius_um, prototype=self.prototype, template_positions=self.template_positions, nearest_template_mask=self.nearest_template_mask, diff --git a/src/spikeinterface/sortingcomponents/peak_pipeline.py b/src/spikeinterface/sortingcomponents/peak_pipeline.py index 9e43fd2d78..6f0f26201f 100644 --- a/src/spikeinterface/sortingcomponents/peak_pipeline.py +++ b/src/spikeinterface/sortingcomponents/peak_pipeline.py @@ -223,7 +223,7 @@ def __init__( ms_after: float, parents: Optional[List[PipelineNode]] = None, return_output: bool = False, - local_radius_um: float = 100.0, + radius_um: float = 100.0, ): """ Extract sparse waveforms from a recording. 
The strategy in this specific node is to reshape the waveforms @@ -260,10 +260,10 @@ def __init__( return_output=return_output, ) - self.local_radius_um = local_radius_um + self.radius_um = radius_um self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance < local_radius_um + self.neighbours_mask = self.channel_distance < radius_um self.max_num_chans = np.max(np.sum(self.neighbours_mask, axis=1)) def get_trace_margin(self): diff --git a/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py b/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py index e46d037c9e..b3b5f656cb 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py @@ -34,8 +34,8 @@ def test_features_from_peaks(): feature_params = { "amplitude": {"all_channels": False, "peak_sign": "neg"}, "ptp": {"all_channels": False}, - "center_of_mass": {"local_radius_um": 120.0}, - "energy": {"local_radius_um": 160.0}, + "center_of_mass": {"radius_um": 120.0}, + "energy": {"radius_um": 160.0}, } features = compute_features_from_peaks(recording, peaks, feature_list, feature_params=feature_params, **job_kwargs) diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index 9860275739..0558c16cca 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -45,7 +45,7 @@ def setup_module(): extract_dense_waveforms = ExtractDenseWaveforms(recording, ms_before=0.1, ms_after=0.3, return_output=False) pipeline_nodes = [ extract_dense_waveforms, - LocalizeCenterOfMass(recording, parents=[extract_dense_waveforms], local_radius_um=60.0), + LocalizeCenterOfMass(recording, parents=[extract_dense_waveforms], radius_um=60.0), ] peaks, peak_locations = detect_peaks( recording, diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index b25cea69a6..b7ab67350e 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -44,12 +44,11 @@ def test_correct_motion_on_peaks(): # fake locations peak_locations = np.zeros((peaks.size), dtype=[("x", "float32"), ("y", "float")]) - times = rec.get_times() corrected_peak_locations = correct_motion_on_peaks( peaks, peak_locations, - times, + rec.sampling_frequency, motion, temporal_bins, spatial_bins, diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index 380bd67a94..0511f9b69d 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -73,7 +73,7 @@ def sorting_fixture(): def spike_trains(sorting): - spike_trains = sorting.get_all_spike_trains()[0][0] + spike_trains = sorting.to_spike_vector()["sample_index"] return spike_trains @@ -139,7 +139,7 @@ def peak_detector_kwargs(recording): exclude_sweep_ms=1.0, peak_sign="both", detect_threshold=5, - local_radius_um=50, + radius_um=50, ) return peak_detector_keyword_arguments @@ -194,12 +194,12 @@ def 
test_iterative_peak_detection_sparse(recording, job_kwargs, pca_model_folder ms_before = 1.0 ms_after = 1.0 - local_radius_um = 40 + radius_um = 40 waveform_extraction_node = ExtractSparseWaveforms( recording=recording, ms_before=ms_before, ms_after=ms_after, - local_radius_um=local_radius_um, + radius_um=radius_um, ) waveform_denoising_node = TemporalPCADenoising( @@ -368,7 +368,7 @@ def test_peak_detection_with_pipeline(recording, job_kwargs, torch_job_kwargs): pipeline_nodes = [ extract_dense_waveforms, PeakToPeakFeature(recording, all_channels=False, parents=[extract_dense_waveforms]), - LocalizeCenterOfMass(recording, local_radius_um=50.0, parents=[extract_dense_waveforms]), + LocalizeCenterOfMass(recording, radius_um=50.0, parents=[extract_dense_waveforms]), ] peaks, ptp, peak_locations = detect_peaks( recording, diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py index c4192c5fcf..34bc93fbfa 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py @@ -83,7 +83,7 @@ def test_pca_denoising_sparse(mearec_recording, detected_peaks, model_path_of_tr peaks = detected_peaks # Parameters - local_radius_um = 40 + radius_um = 40 ms_before = 1.0 ms_after = 1.0 @@ -94,7 +94,7 @@ def test_pca_denoising_sparse(mearec_recording, detected_peaks, model_path_of_tr parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, - local_radius_um=local_radius_um, + radius_um=radius_um, return_output=True, ) pca_denoising = TemporalPCADenoising( @@ -143,7 +143,7 @@ def test_pca_projection_sparsity(mearec_recording, detected_peaks, model_path_of peaks = detected_peaks # Parameters - local_radius_um = 40 + radius_um = 40 ms_before = 1.0 ms_after = 1.0 @@ -154,7 +154,7 @@ def test_pca_projection_sparsity(mearec_recording, detected_peaks, model_path_of parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, - local_radius_um=local_radius_um, + radius_um=radius_um, return_output=True, ) temporal_pca = TemporalPCAProjection( @@ -181,7 +181,7 @@ def test_initialization_with_wrong_parents_failure(mearec_recording, model_path_ model_folder_path = model_path_of_trained_pca dummy_parent = PipelineNode(recording=recording) extract_waveforms = ExtractSparseWaveforms( - recording=recording, ms_before=1, ms_after=1, local_radius_um=40, return_output=True + recording=recording, ms_before=1, ms_after=1, radius_um=40, return_output=True ) match_error = f"TemporalPCA should have a single {WaveformsNode.__name__} in its parents" diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 5283fd0f99..14b66fc847 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -29,7 +29,7 @@ def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0 ms_before=ms_before, ms_after=ms_after, return_output=True, - local_radius_um=5, + radius_um=5, ) nbefore = sparse_waveforms.nbefore diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index de96fe445a..28cf8a3be0 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -93,7 +93,7 @@ def fit( ms_before: float = 1.0, ms_after: float = 
1.0, whiten: bool = True, - local_radius_um: float = None, + radius_um: float = None, ) -> IncrementalPCA: """ Train a pca model using the data in the recording object and the parameters provided. @@ -114,7 +114,7 @@ def fit( The parameters for peak selection. whiten : bool, optional Whether to whiten the data, by default True. - local_radius_um : float, optional + radius_um : float, optional The radius (in micrometers) to use for definint sparsity, by default None. ms_before : float, optional The number of milliseconds to include before the peak of the spike, by default 1. @@ -148,7 +148,7 @@ def fit( ) # compute PCA by_channel_global (with sparsity) - sparsity = ChannelSparsity.from_radius(we, radius_um=local_radius_um) if local_radius_um else None + sparsity = ChannelSparsity.from_radius(we, radius_um=radius_um) if radius_um else None pc = compute_principal_components( we, n_components=n_components, mode="by_channel_global", sparsity=sparsity, whiten=whiten ) diff --git a/src/spikeinterface/widgets/__init__.py b/src/spikeinterface/widgets/__init__.py index 83f4b85fee..d3066f51fa 100644 --- a/src/spikeinterface/widgets/__init__.py +++ b/src/spikeinterface/widgets/__init__.py @@ -1,37 +1,3 @@ -# check if backend are available -try: - import matplotlib - - HAVE_MPL = True -except: - HAVE_MPL = False - -try: - import sortingview - - HAVE_SV = True -except: - HAVE_SV = False - -try: - import ipywidgets - - HAVE_IPYW = True -except: - HAVE_IPYW = False - - -# theses import make the Widget.resgister() at import time -if HAVE_MPL: - import spikeinterface.widgets.matplotlib - -if HAVE_SV: - import spikeinterface.widgets.sortingview - -if HAVE_IPYW: - import spikeinterface.widgets.ipywidgets - -# when importing widget list backend are already registered from .widget_list import * # general functions diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 06f68a754e..c0dcd7ea6e 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,30 +1,20 @@ # basics -# from .timeseries import plot_timeseries, TimeseriesWidget +# from .timeseries import plot_timeseries, TracesWidget from .rasters import plot_rasters, RasterWidget from .probemap import plot_probe_map, ProbeMapWidget # isi/ccg/acg from .isidistribution import plot_isi_distribution, ISIDistributionWidget -# from .correlograms import (plot_crosscorrelograms, CrossCorrelogramsWidget, -# plot_autocorrelograms, AutoCorrelogramsWidget) - # peak activity from .activity import plot_peak_activity_map, PeakActivityMapWidget # waveform/PC related -# from .unitwaveforms import plot_unit_waveforms, plot_unit_templates -# from .unitwaveformdensitymap import plot_unit_waveform_density_map, UnitWaveformDensityMapWidget -# from .amplitudes import plot_amplitudes_distribution from .principalcomponent import plot_principal_component -# from .unitlocalization import plot_unit_localization, UnitLocalizationWidget - # units on probe from .unitprobemap import plot_unit_probe_map, UnitProbeMapWidget -# from .depthamplitude import plot_units_depth_vs_amplitude - # comparison related from .confusionmatrix import plot_confusion_matrix, ConfusionMatrixWidget from .agreementmatrix import plot_agreement_matrix, AgreementMatrixWidget @@ -77,8 +67,6 @@ ComparisonPerformancesByTemplateSimilarity, ) -# unit summary -# from .unitsummary import plot_unit_summary, UnitSummaryWidget # unit presence from .presence 
import plot_presence, PresenceWidget diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py deleted file mode 100644 index 37bfab9d66..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py +++ /dev/null @@ -1,147 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt - -from .basewidget import BaseWidget - -from ...postprocessing import compute_spike_amplitudes -from .utils import get_unit_colors - - -class AmplitudeBaseWidget(BaseWidget): - def __init__(self, waveform_extractor, unit_ids=None, compute_kwargs={}, unit_colors=None, figure=None, ax=None): - BaseWidget.__init__(self, figure, ax) - - self.we = waveform_extractor - - if self.we.is_extension("spike_amplitudes"): - sac = self.we.load_extension("spike_amplitudes") - self.amplitudes = sac.get_data(outputs="by_unit") - else: - self.amplitudes = compute_spike_amplitudes(self.we, outputs="by_unit", **compute_kwargs) - - if unit_ids is None: - unit_ids = waveform_extractor.sorting.unit_ids - self.unit_ids = unit_ids - - if unit_colors is None: - unit_colors = get_unit_colors(self.we.sorting) - self.unit_colors = unit_colors - - def plot(self): - self._do_plot() - - -class AmplitudeTimeseriesWidget(AmplitudeBaseWidget): - """ - Plots waveform amplitudes distribution. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - - amplitudes: None or pre computed amplitudes - If None then amplitudes are recomputed - peak_sign: 'neg', 'pos', 'both' - In case of recomputing amplitudes. - - Returns - ------- - W: AmplitudeDistributionWidget - The output widget - """ - - def _do_plot(self): - sorting = self.we.sorting - # ~ unit_ids = sorting.unit_ids - num_seg = sorting.get_num_segments() - fs = sorting.get_sampling_frequency() - - # TODO handle segment - ax = self.ax - for i, unit_id in enumerate(self.unit_ids): - for segment_index in range(num_seg): - times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) - times = times / fs - amps = self.amplitudes[segment_index][unit_id] - ax.scatter(times, amps, color=self.unit_colors[unit_id], s=3, alpha=1) - - if i == 0: - ax.set_title(f"segment {segment_index}") - if i == len(self.unit_ids) - 1: - ax.set_xlabel("Times [s]") - if segment_index == 0: - ax.set_ylabel(f"Amplitude") - - ylims = ax.get_ylim() - if np.max(ylims) < 0: - ax.set_ylim(min(ylims), 0) - if np.min(ylims) > 0: - ax.set_ylim(0, max(ylims)) - - -class AmplitudeDistributionWidget(AmplitudeBaseWidget): - """ - Plots waveform amplitudes distribution. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - - amplitudes: None or pre computed amplitudes - If None then amplitudes are recomputed - peak_sign: 'neg', 'pos', 'both' - In case of recomputing amplitudes. 
- - Returns - ------- - W: AmplitudeDistributionWidget - The output widget - """ - - def _do_plot(self): - sorting = self.we.sorting - unit_ids = sorting.unit_ids - num_seg = sorting.get_num_segments() - - ax = self.ax - unit_amps = [] - for i, unit_id in enumerate(unit_ids): - amps = [] - for segment_index in range(num_seg): - amps.append(self.amplitudes[segment_index][unit_id]) - amps = np.concatenate(amps) - unit_amps.append(amps) - parts = ax.violinplot(unit_amps, showmeans=False, showmedians=False, showextrema=False) - - for i, pc in enumerate(parts["bodies"]): - color = self.unit_colors[unit_ids[i]] - pc.set_facecolor(color) - pc.set_edgecolor("black") - pc.set_alpha(1) - - ax.set_xticks(np.arange(len(unit_ids)) + 1) - ax.set_xticklabels([str(unit_id) for unit_id in unit_ids]) - - ylims = ax.get_ylim() - if np.max(ylims) < 0: - ax.set_ylim(min(ylims), 0) - if np.min(ylims) > 0: - ax.set_ylim(0, max(ylims)) - - -def plot_amplitudes_timeseries(*args, **kwargs): - W = AmplitudeTimeseriesWidget(*args, **kwargs) - W.plot() - return W - - -plot_amplitudes_timeseries.__doc__ = AmplitudeTimeseriesWidget.__doc__ - - -def plot_amplitudes_distribution(*args, **kwargs): - W = AmplitudeDistributionWidget(*args, **kwargs) - W.plot() - return W - - -plot_amplitudes_distribution.__doc__ = AmplitudeDistributionWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/correlograms_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/correlograms_.py deleted file mode 100644 index 8e12559066..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/correlograms_.py +++ /dev/null @@ -1,107 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt -from .basewidget import BaseWidget - -from spikeinterface.postprocessing import compute_correlograms - - -class CrossCorrelogramsWidget(BaseWidget): - """ - Plots spike train cross-correlograms. - The diagonal is auto-correlogram. - - Parameters - ---------- - sorting: SortingExtractor - The sorting extractor object - unit_ids: list - List of unit ids - bin_ms: float - bins duration in ms - window_ms: float - Window duration in ms - """ - - def __init__(self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, axes=None): - if unit_ids is not None: - sorting = sorting.select_units(unit_ids) - self.sorting = sorting - self.compute_kwargs = dict(window_ms=window_ms, bin_ms=bin_ms) - - if axes is None: - n = len(sorting.unit_ids) - fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, sharey=True) - BaseWidget.__init__(self, None, None, axes) - - def plot(self): - correlograms, bins = compute_correlograms(self.sorting, **self.compute_kwargs) - bin_width = bins[1] - bins[0] - unit_ids = self.sorting.unit_ids - for i, unit_id1 in enumerate(unit_ids): - for j, unit_id2 in enumerate(unit_ids): - ccg = correlograms[i, j] - ax = self.axes[i, j] - if i == j: - color = "g" - else: - color = "k" - ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") - - for i, unit_id in enumerate(unit_ids): - self.axes[0, i].set_title(str(unit_id)) - self.axes[-1, i].set_xlabel("CCG (ms)") - - -def plot_crosscorrelograms(*args, **kwargs): - W = CrossCorrelogramsWidget(*args, **kwargs) - W.plot() - return W - - -plot_crosscorrelograms.__doc__ = CrossCorrelogramsWidget.__doc__ - - -class AutoCorrelogramsWidget(BaseWidget): - """ - Plots spike train auto-correlograms. 
- Parameters - ---------- - sorting: SortingExtractor - The sorting extractor object - unit_ids: list - List of unit ids - bin_ms: float - bins duration in ms - window_ms: float - Window duration in ms - """ - - def __init__(self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, ncols=5, axes=None): - if unit_ids is not None: - sorting = sorting.select_units(unit_ids) - self.sorting = sorting - self.compute_kwargs = dict(window_ms=window_ms, bin_ms=bin_ms) - - if axes is None: - num_axes = len(sorting.unit_ids) - BaseWidget.__init__(self, None, None, axes, ncols=ncols, num_axes=num_axes) - - def plot(self): - correlograms, bins = compute_correlograms(self.sorting, **self.compute_kwargs) - bin_width = bins[1] - bins[0] - unit_ids = self.sorting.unit_ids - for i, unit_id in enumerate(unit_ids): - ccg = correlograms[i, i] - ax = self.axes.flatten()[i] - color = "g" - ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") - ax.set_title(str(unit_id)) - - -def plot_autocorrelograms(*args, **kwargs): - W = AutoCorrelogramsWidget(*args, **kwargs) - W.plot() - return W - - -plot_autocorrelograms.__doc__ = AutoCorrelogramsWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/depthamplitude.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/depthamplitude.py deleted file mode 100644 index a382fee9bc..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/depthamplitude.py +++ /dev/null @@ -1,58 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt - -from .basewidget import BaseWidget - -from ...postprocessing import get_template_extremum_channel, get_template_extremum_amplitude -from .utils import get_unit_colors - - -class UnitsDepthAmplitudeWidget(BaseWidget): - def __init__(self, waveform_extractor, peak_sign="neg", depth_axis=1, unit_colors=None, figure=None, ax=None): - BaseWidget.__init__(self, figure, ax) - - self.we = waveform_extractor - self.peak_sign = peak_sign - self.depth_axis = depth_axis - if unit_colors is None: - unit_colors = get_unit_colors(self.we.sorting) - self.unit_colors = unit_colors - - def plot(self): - ax = self.ax - we = self.we - unit_ids = we.unit_ids - - channels_index = get_template_extremum_channel(we, peak_sign=self.peak_sign, outputs="index") - contact_positions = we.get_channel_locations() - - channel_depth = contact_positions[:, self.depth_axis] - unit_depth = [channel_depth[channels_index[unit_id]] for unit_id in unit_ids] - - unit_amplitude = get_template_extremum_amplitude(we, peak_sign=self.peak_sign) - unit_amplitude = np.abs([unit_amplitude[unit_id] for unit_id in unit_ids]) - - colors = [self.unit_colors[unit_id] for unit_id in unit_ids] - - num_spikes = np.zeros(len(unit_ids)) - for i, unit_id in enumerate(unit_ids): - for segment_index in range(we.get_num_segments()): - st = we.sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - num_spikes[i] += st.size - - size = num_spikes / max(num_spikes) * 120 - ax.scatter(unit_amplitude, unit_depth, color=colors, s=size) - - ax.set_aspect(3) - ax.set_xlabel("amplitude") - ax.set_ylabel("depth [um]") - ax.set_xlim(0, max(unit_amplitude) * 1.2) - - -def plot_units_depth_vs_amplitude(*args, **kwargs): - W = UnitsDepthAmplitudeWidget(*args, **kwargs) - W.plot() - return W - - -plot_units_depth_vs_amplitude.__doc__ = UnitsDepthAmplitudeWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py index 5856549da3..ab6fa2ace5 100644 --- 
a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py @@ -6,7 +6,7 @@ import scipy.spatial -class TimeseriesWidget(BaseWidget): +class TracesWidget(BaseWidget): """ Plots recording timeseries. @@ -46,7 +46,7 @@ class TimeseriesWidget(BaseWidget): Returns ------- - W: TimeseriesWidget + W: TracesWidget The output widget """ @@ -225,9 +225,9 @@ def _initialize_stats(self): def plot_timeseries(*args, **kwargs): - W = TimeseriesWidget(*args, **kwargs) + W = TracesWidget(*args, **kwargs) W.plot() return W -plot_timeseries.__doc__ = TimeseriesWidget.__doc__ +plot_timeseries.__doc__ = TracesWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitlocalization_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/unitlocalization_.py deleted file mode 100644 index a2b8beea3f..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitlocalization_.py +++ /dev/null @@ -1,109 +0,0 @@ -import numpy as np -import matplotlib.pylab as plt -from .basewidget import BaseWidget - -from probeinterface.plotting import plot_probe - -from spikeinterface.postprocessing import compute_unit_locations - -from .utils import get_unit_colors - - -class UnitLocalizationWidget(BaseWidget): - """ - Plot unit localization on probe. - - Parameters - ---------- - waveform_extractor: WaveformaExtractor - WaveformaExtractorr object - peaks: None or numpy array - Optionally can give already detected peaks - to avoid multiple computation. - method: str default 'center_of_mass' - Method used to estimate unit localization if 'unit_location' is None - method_kwargs: dict - Option for the method - unit_colors: None or dict - A dict key is unit_id and value is any color format handled by matplotlib. - If None, then the get_unit_colors() is internally used. - with_channel_ids: bool False default - add channel ids text on the probe - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. 
If not given an axis is created - - Returns - ------- - W: ProbeMapWidget - The output widget - """ - - def __init__( - self, - waveform_extractor, - method="center_of_mass", - method_kwargs={}, - unit_colors=None, - with_channel_ids=False, - figure=None, - ax=None, - ): - BaseWidget.__init__(self, figure, ax) - - self.waveform_extractor = waveform_extractor - self.method = method - self.method_kwargs = method_kwargs - - if unit_colors is None: - unit_colors = get_unit_colors(waveform_extractor.sorting) - self.unit_colors = unit_colors - - self.with_channel_ids = with_channel_ids - - def plot(self): - we = self.waveform_extractor - unit_ids = we.unit_ids - - if we.is_extension("unit_locations"): - unit_locations = we.load_extension("unit_locations").get_data() - else: - unit_locations = compute_unit_locations(we, method=self.method, **self.method_kwargs) - - ax = self.ax - probegroup = we.get_probegroup() - probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) - contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) - - for probe in probegroup.probes: - text_on_contact = None - if self.with_channel_ids: - text_on_contact = self.waveform_extractor.recording.channel_ids - - poly_contact, poly_contour = plot_probe( - probe, - ax=ax, - contacts_colors="w", - contacts_kargs=contacts_kargs, - probe_shape_kwargs=probe_shape_kwargs, - text_on_contact=text_on_contact, - ) - poly_contact.set_zorder(2) - if poly_contour is not None: - poly_contour.set_zorder(1) - - ax.set_title("") - - color = np.array([self.unit_colors[unit_id] for unit_id in unit_ids]) - loc = ax.scatter(unit_locations[:, 0], unit_locations[:, 1], marker="1", color=color, s=80, lw=3) - loc.set_zorder(3) - - -def plot_unit_localization(*args, **kwargs): - W = UnitLocalizationWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_localization.__doc__ = UnitLocalizationWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitsummary.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/unitsummary.py deleted file mode 100644 index a1d0589abc..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitsummary.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt -from .basewidget import BaseWidget - -from .utils import get_unit_colors - -from .unitprobemap import plot_unit_probe_map -from .unitwaveformdensitymap_ import plot_unit_waveform_density_map -from .amplitudes import plot_amplitudes_timeseries -from .unitwaveforms_ import plot_unit_waveforms -from .isidistribution import plot_isi_distribution - - -class UnitSummaryWidget(BaseWidget): - """ - Plot a unit summary. - - If amplitudes are alreday computed they are displayed. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - The waveform extractor object - unit_id: into or str - The unit id to plot the summary of - unit_colors: list or None - Optional matplotlib color for the unit - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. 
If not given an axis is created - - Returns - ------- - W: UnitSummaryWidget - The output widget - """ - - def __init__(self, waveform_extractor, unit_id, unit_colors=None, figure=None, ax=None): - assert ax is None - # ~ assert axes is None - - if figure is None: - figure = plt.figure( - constrained_layout=False, - figsize=(15, 7), - ) - - BaseWidget.__init__(self, figure, None) - - self.waveform_extractor = waveform_extractor - self.recording = waveform_extractor.recording - self.sorting = waveform_extractor.sorting - self.unit_id = unit_id - - if unit_colors is None: - unit_colors = get_unit_colors(self.sorting) - self.unit_colors = unit_colors - - def plot(self): - we = self.waveform_extractor - - fig = self.figure - self.ax.remove() - - if we.is_extension("spike_amplitudes"): - nrows = 3 - else: - nrows = 2 - - gs = fig.add_gridspec(nrows, 6) - - ax = fig.add_subplot(gs[:, 0]) - plot_unit_probe_map(we, unit_ids=[self.unit_id], axes=[ax], colorbar=False) - ax.set_title("") - - ax = fig.add_subplot(gs[0:2, 1:3]) - plot_unit_waveforms(we, unit_ids=[self.unit_id], radius_um=60, axes=[ax], unit_colors=self.unit_colors) - ax.set_title(None) - - ax = fig.add_subplot(gs[0:2, 3:5]) - plot_unit_waveform_density_map(we, unit_ids=[self.unit_id], max_channels=1, ax=ax, same_axis=True) - ax.set_ylabel(None) - - ax = fig.add_subplot(gs[0:2, 5]) - plot_isi_distribution(we.sorting, unit_ids=[self.unit_id], axes=[ax]) - ax.set_title("") - - if we.is_extension("spike_amplitudes"): - ax = fig.add_subplot(gs[-1, 1:]) - plot_amplitudes_timeseries(we, unit_ids=[self.unit_id], ax=ax, unit_colors=self.unit_colors) - ax.set_ylabel(None) - ax.set_title(None) - - fig.suptitle(f"Unit ID: {self.unit_id}") - - -def plot_unit_summary(*args, **kwargs): - W = UnitSummaryWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_summary.__doc__ = UnitSummaryWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveformdensitymap_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveformdensitymap_.py deleted file mode 100644 index c5cbe07a7b..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveformdensitymap_.py +++ /dev/null @@ -1,199 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt - -from .basewidget import BaseWidget -from .utils import get_unit_colors -from ...postprocessing import get_template_channel_sparsity - - -class UnitWaveformDensityMapWidget(BaseWidget): - """ - Plots unit waveforms using heat map density. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - channel_ids: list - The channel ids to display - unit_ids: list - List of unit ids. - plot_templates: bool - If True, templates are plotted over the waveforms - max_channels : None or int - If not None only max_channels are displayed per units. - Incompatible with with `radius_um` - radius_um: None or float - If not None, all channels within a circle around the peak waveform will be displayed - Incompatible with with `max_channels` - unit_colors: None or dict - A dict key is unit_id and value is any color format handled by matplotlib. - If None, then the get_unit_colors() is internally used. - same_axis: bool - If True then all density are plot on the same axis and then channels is the union - all channel per units. - set_title: bool - Create a plot title with the unit number if True. 
- plot_channels: bool - Plot channel locations below traces, only used if channel_locs is True - """ - - def __init__( - self, - waveform_extractor, - channel_ids=None, - unit_ids=None, - max_channels=None, - radius_um=None, - same_axis=False, - unit_colors=None, - ax=None, - axes=None, - ): - self.waveform_extractor = waveform_extractor - self.recording = waveform_extractor.recording - self.sorting = waveform_extractor.sorting - - if unit_ids is None: - unit_ids = self.sorting.get_unit_ids() - self.unit_ids = unit_ids - - if channel_ids is None: - channel_ids = self.recording.get_channel_ids() - self.channel_ids = channel_ids - - if unit_colors is None: - unit_colors = get_unit_colors(self.sorting) - self.unit_colors = unit_colors - - if radius_um is not None: - assert max_channels is None, "radius_um and max_channels are mutually exclusive" - if max_channels is not None: - assert radius_um is None, "radius_um and max_channels are mutually exclusive" - - self.radius_um = radius_um - self.max_channels = max_channels - self.same_axis = same_axis - - if axes is None and ax is None: - if same_axis: - fig, ax = plt.subplots() - axes = None - else: - nrows = len(unit_ids) - fig, axes = plt.subplots(nrows=nrows, squeeze=False) - axes = axes[:, 0] - ax = None - BaseWidget.__init__(self, figure=None, ax=ax, axes=axes) - - def plot(self): - we = self.waveform_extractor - - # channel sparsity - if self.radius_um is not None: - channel_inds = get_template_channel_sparsity(we, method="radius", outputs="index", radius_um=self.radius_um) - elif self.max_channels is not None: - channel_inds = get_template_channel_sparsity( - we, method="best_channels", outputs="index", num_channels=self.max_channels - ) - else: - # all channels - channel_inds = {unit_id: np.arange(len(self.channel_ids)) for unit_id in self.unit_ids} - channel_inds = {unit_id: inds for unit_id, inds in channel_inds.items() if unit_id in self.unit_ids} - - if self.same_axis: - # channel union - inds = np.unique(np.concatenate([inds.tolist() for inds in channel_inds.values()])) - channel_inds = {unit_id: inds for unit_id in self.unit_ids} - - # bins - templates = we.get_all_templates(unit_ids=self.unit_ids, mode="median") - bin_min = np.min(templates) * 1.3 - bin_max = np.max(templates) * 1.3 - bin_size = (bin_max - bin_min) / 100 - bins = np.arange(bin_min, bin_max, bin_size) - - # 2d histograms - all_hist2d = None - for unit_index, unit_id in enumerate(self.unit_ids): - chan_inds = channel_inds[unit_id] - - wfs = we.get_waveforms(unit_id) - wfs = wfs[:, :, chan_inds] - - # make histogram density - wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) - hist2d = np.zeros((wfs_flat.shape[1], bins.size)) - indexes0 = np.arange(wfs_flat.shape[1]) - - wf_bined = np.floor((wfs_flat - bin_min) / bin_size).astype("int32") - wf_bined = wf_bined.clip(0, bins.size - 1) - for d in wf_bined: - hist2d[indexes0, d] += 1 - - if self.same_axis: - if all_hist2d is None: - all_hist2d = hist2d - else: - all_hist2d += hist2d - else: - ax = self.axes[unit_index] - im = ax.imshow( - hist2d.T, - interpolation="nearest", - origin="lower", - aspect="auto", - extent=(0, hist2d.shape[0], bin_min, bin_max), - cmap="hot", - ) - - if self.same_axis: - ax = self.ax - im = ax.imshow( - all_hist2d.T, - interpolation="nearest", - origin="lower", - aspect="auto", - extent=(0, hist2d.shape[0], bin_min, bin_max), - cmap="hot", - ) - - # plot median - for unit_index, unit_id in enumerate(self.unit_ids): - if self.same_axis: - ax = self.ax - else: - ax = self.axes[unit_index] 
- chan_inds = channel_inds[unit_id] - template = templates[unit_index, :, chan_inds] - template_flat = template.flatten() - color = self.unit_colors[unit_id] - ax.plot(template_flat, color=color, lw=1) - - # final cosmetics - for unit_index, unit_id in enumerate(self.unit_ids): - if self.same_axis: - ax = self.ax - if unit_index != 0: - continue - else: - ax = self.axes[unit_index] - chan_inds = channel_inds[unit_id] - for i, chan_ind in enumerate(chan_inds): - if i != 0: - ax.axvline(i * wfs.shape[1], color="w", lw=3) - channel_id = self.recording.channel_ids[chan_ind] - x = i * wfs.shape[1] + wfs.shape[1] // 2 - y = (bin_max + bin_min) / 2.0 - ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center") - - ax.set_xticks([]) - ax.set_ylabel(f"unit_id {unit_id}") - - -def plot_unit_waveform_density_map(*args, **kwargs): - W = UnitWaveformDensityMapWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_waveform_density_map.__doc__ = UnitWaveformDensityMapWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveforms_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveforms_.py deleted file mode 100644 index a1e28bbb82..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveforms_.py +++ /dev/null @@ -1,218 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt - -from .basewidget import BaseWidget -from .utils import get_unit_colors -from ...postprocessing import get_template_channel_sparsity - - -class UnitWaveformsWidget(BaseWidget): - """ - Plots unit waveforms. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - channel_ids: list - The channel ids to display - unit_ids: list - List of unit ids. - plot_templates: bool - If True, templates are plotted over the waveforms - radius_um: None or float - If not None, all channels within a circle around the peak waveform will be displayed - Incompatible with with `max_channels` - max_channels : None or int - If not None only max_channels are displayed per units. - Incompatible with with `radius_um` - set_title: bool - Create a plot title with the unit number if True. - plot_channels: bool - Plot channel locations below traces. - axis_equal: bool - Equal aspect ratio for x and y axis, to visualize the array geometry to scale. - lw: float - Line width for the traces. - unit_colors: None or dict - A dict key is unit_id and value is any color format handled by matplotlib. - If None, then the get_unit_colors() is internally used. - unit_selected_waveforms: None or dict - A dict key is unit_id and value is the subset of waveforms indices that should be - be displayed - show_all_channels: bool - Show the whole probe if True, or only selected channels if False - The axis to be used. If not given an axis is created - axes: list of matplotlib axes - The axes to be used for the individual plots. If not given the required axes are created. 
If provided, the ax - and figure parameters are ignored - """ - - def __init__( - self, - waveform_extractor, - channel_ids=None, - unit_ids=None, - plot_waveforms=True, - plot_templates=True, - plot_channels=False, - unit_colors=None, - max_channels=None, - radius_um=None, - ncols=5, - axes=None, - lw=2, - axis_equal=False, - unit_selected_waveforms=None, - set_title=True, - ): - self.waveform_extractor = waveform_extractor - self._recording = waveform_extractor.recording - self._sorting = waveform_extractor.sorting - sorting = waveform_extractor.sorting - - if unit_ids is None: - unit_ids = self._sorting.get_unit_ids() - self._unit_ids = unit_ids - if channel_ids is None: - channel_ids = self._recording.get_channel_ids() - self._channel_ids = channel_ids - - if unit_colors is None: - unit_colors = get_unit_colors(self._sorting) - self.unit_colors = unit_colors - - self.ncols = ncols - self._plot_waveforms = plot_waveforms - self._plot_templates = plot_templates - self._plot_channels = plot_channels - - if radius_um is not None: - assert max_channels is None, "radius_um and max_channels are mutually exclusive" - if max_channels is not None: - assert radius_um is None, "radius_um and max_channels are mutually exclusive" - - self.radius_um = radius_um - self.max_channels = max_channels - self.unit_selected_waveforms = unit_selected_waveforms - - # TODO - self._lw = lw - self._axis_equal = axis_equal - - self._set_title = set_title - - if axes is None: - num_axes = len(unit_ids) - else: - num_axes = None - BaseWidget.__init__(self, None, None, axes, ncols=ncols, num_axes=num_axes) - - def plot(self): - self._do_plot() - - def _do_plot(self): - we = self.waveform_extractor - unit_ids = self._unit_ids - channel_ids = self._channel_ids - - channel_locations = self._recording.get_channel_locations(channel_ids=channel_ids) - templates = we.get_all_templates(unit_ids=unit_ids) - - xvectors, y_scale, y_offset = get_waveforms_scales(we, templates, channel_locations) - - ncols = min(self.ncols, len(unit_ids)) - nrows = int(np.ceil(len(unit_ids) / ncols)) - - if self.radius_um is not None: - channel_inds = get_template_channel_sparsity(we, method="radius", outputs="index", radius_um=self.radius_um) - elif self.max_channels is not None: - channel_inds = get_template_channel_sparsity( - we, method="best_channels", outputs="index", num_channels=self.max_channels - ) - else: - # all channels - channel_inds = {unit_id: slice(None) for unit_id in unit_ids} - - for i, unit_id in enumerate(unit_ids): - ax = self.axes.flatten()[i] - color = self.unit_colors[unit_id] - - chan_inds = channel_inds[unit_id] - xvectors_flat = xvectors[:, chan_inds].T.flatten() - - # plot waveforms - if self._plot_waveforms: - wfs = we.get_waveforms(unit_id) - if self.unit_selected_waveforms is not None: - wfs = wfs[self.unit_selected_waveforms[unit_id]][:, :, chan_inds] - else: - wfs = wfs[:, :, chan_inds] - wfs = wfs * y_scale + y_offset[None, :, chan_inds] - wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1).T - ax.plot(xvectors_flat, wfs_flat, lw=1, alpha=0.3, color=color) - - # plot template - if self._plot_templates: - template = templates[i, :, :][:, chan_inds] * y_scale + y_offset[:, chan_inds] - if self._plot_waveforms and self._plot_templates: - color = "k" - ax.plot(xvectors_flat, template.T.flatten(), lw=1, color=color) - template_label = unit_ids[i] - ax.set_title(f"template {template_label}") - - # plot channels - if self._plot_channels: - # TODO enhance this - ax.scatter(channel_locations[:, 0], 
channel_locations[:, 1], color="k") - - -def get_waveforms_scales(we, templates, channel_locations): - """ - Return scales and x_vector for templates plotting - """ - wf_max = np.max(templates) - wf_min = np.max(templates) - - x_chans = np.unique(channel_locations[:, 0]) - if x_chans.size > 1: - delta_x = np.min(np.diff(x_chans)) - else: - delta_x = 40.0 - - y_chans = np.unique(channel_locations[:, 1]) - if y_chans.size > 1: - delta_y = np.min(np.diff(y_chans)) - else: - delta_y = 40.0 - - m = max(np.abs(wf_max), np.abs(wf_min)) - y_scale = delta_y / m * 0.7 - - y_offset = channel_locations[:, 1][None, :] - - xvect = delta_x * (np.arange(we.nsamples) - we.nbefore) / we.nsamples * 0.7 - - xvectors = channel_locations[:, 0][None, :] + xvect[:, None] - # put nan for discontinuity - xvectors[-1, :] = np.nan - - return xvectors, y_scale, y_offset - - -def plot_unit_waveforms(*args, **kwargs): - W = UnitWaveformsWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_waveforms.__doc__ = UnitWaveformsWidget.__doc__ - - -def plot_unit_templates(*args, **kwargs): - kwargs["plot_waveforms"] = False - W = UnitWaveformsWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_templates.__doc__ = UnitWaveformsWidget.__doc__ diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index d1a0acfe1e..e8b25f6823 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -1,7 +1,7 @@ import numpy as np from warnings import warn -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_some_colors from ..core.waveform_extractor import WaveformExtractor @@ -21,8 +21,6 @@ class AllAmplitudesDistributionsWidget(BaseWidget): Dict of colors with key: unit, value: color, default None """ - possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, unit_ids=None, unit_colors=None, backend=None, **backend_kwargs ): @@ -47,3 +45,37 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + ax = self.ax + + unit_amps = [] + for i, unit_id in enumerate(dp.unit_ids): + amps = [] + for segment_index in range(dp.num_segments): + amps.append(dp.amplitudes[segment_index][unit_id]) + amps = np.concatenate(amps) + unit_amps.append(amps) + parts = ax.violinplot(unit_amps, showmeans=False, showmedians=False, showextrema=False) + + for i, pc in enumerate(parts["bodies"]): + color = dp.unit_colors[dp.unit_ids[i]] + pc.set_facecolor(color) + pc.set_edgecolor("black") + pc.set_alpha(1) + + ax.set_xticks(np.arange(len(dp.unit_ids)) + 1) + ax.set_xticklabels([str(unit_id) for unit_id in dp.unit_ids]) + + ylims = ax.get_ylim() + if np.max(ylims) < 0: + ax.set_ylim(min(ylims), 0) + if np.min(ylims) > 0: + ax.set_ylim(0, max(ylims)) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 833bdf2b06..7ef6e0ff61 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -1,7 +1,7 @@ import numpy as np from warnings import warn -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_some_colors from 
..core.waveform_extractor import WaveformExtractor @@ -35,8 +35,6 @@ class AmplitudesWidget(BaseWidget): True includes legend in plot, default True """ - possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, @@ -112,3 +110,162 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + if backend_kwargs["axes"] is not None: + axes = backend_kwargs["axes"] + if dp.plot_histograms: + assert np.asarray(axes).size == 2 + else: + assert np.asarray(axes).size == 1 + elif backend_kwargs["ax"] is not None: + assert not dp.plot_histograms + else: + if dp.plot_histograms: + backend_kwargs["num_axes"] = 2 + backend_kwargs["ncols"] = 2 + else: + backend_kwargs["num_axes"] = None + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + scatter_ax = self.axes.flatten()[0] + + for unit_id in dp.unit_ids: + spiketrains = dp.spiketrains[unit_id] + amps = dp.amplitudes[unit_id] + scatter_ax.scatter(spiketrains, amps, color=dp.unit_colors[unit_id], s=3, alpha=1, label=unit_id) + + if dp.plot_histograms: + if dp.bins is None: + bins = int(len(spiketrains) / 30) + else: + bins = dp.bins + ax_hist = self.axes.flatten()[1] + ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) + + if dp.plot_histograms: + ax_hist = self.axes.flatten()[1] + ax_hist.set_ylim(scatter_ax.get_ylim()) + ax_hist.axis("off") + self.figure.tight_layout() + + if dp.plot_legend: + if hasattr(self, "legend") and self.legend is not None: + self.legend.remove() + self.legend = self.figure.legend( + loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + scatter_ax.set_xlim(0, dp.total_duration) + scatter_ax.set_xlabel("Times [s]") + scatter_ax.set_ylabel(f"Amplitude") + scatter_ax.spines["top"].set_visible(False) + scatter_ax.spines["right"].set_visible(False) + self.figure.subplots_adjust(bottom=0.1, top=0.9, left=0.1) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + we = data_plot["waveform_extractor"] + + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + + data_plot["unit_ids"] = data_plot["unit_ids"][:1] + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm + ) + + plot_histograms = widgets.Checkbox( + value=data_plot["plot_histograms"], + description="plot histograms", + disabled=False, + ) + + footer = plot_histograms + + self.controller = {"plot_histograms": plot_histograms} + self.controller.update(unit_controller) + + for w in self.controller.values(): + w.observe(self._update_ipywidget) + + self.widget = widgets.AppLayout( + center=self.figure.canvas, + left_sidebar=unit_widget, + pane_widths=ratios + [0], + footer=footer, + ) + + # a first update + self._update_ipywidget(None) + + if backend_kwargs["display"]: + display(self.widget) + 
+ def _update_ipywidget(self, change): + self.figure.clear() + + unit_ids = self.controller["unit_ids"].value + plot_histograms = self.controller["plot_histograms"].value + + # matplotlib next_data_plot dict update at each call + data_plot = self.next_data_plot + data_plot["unit_ids"] = unit_ids + data_plot["plot_histograms"] = plot_histograms + + backend_kwargs = {} + # backend_kwargs["figure"] = self.fig + backend_kwargs["figure"] = self.figure + backend_kwargs["axes"] = None + backend_kwargs["ax"] = None + + self.plot_matplotlib(data_plot, **backend_kwargs) + + self.figure.canvas.draw() + self.figure.canvas.flush_events() + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + + unit_ids = make_serializable(dp.unit_ids) + + sa_items = [ + vv.SpikeAmplitudesItem( + unit_id=u, + spike_times_sec=dp.spiketrains[u].astype("float32"), + spike_amplitudes=dp.amplitudes[u].astype("float32"), + ) + for u in unit_ids + ] + + self.view = vv.SpikeAmplitudes( + start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector + ) + + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py index 701817e168..e98abbed8f 100644 --- a/src/spikeinterface/widgets/autocorrelograms.py +++ b/src/spikeinterface/widgets/autocorrelograms.py @@ -1,11 +1,59 @@ +from .base import BaseWidget, to_attr + from .crosscorrelograms import CrossCorrelogramsWidget class AutoCorrelogramsWidget(CrossCorrelogramsWidget): - possible_backends = {} + # the doc is copied form CrossCorrelogramsWidget def __init__(self, *args, **kargs): CrossCorrelogramsWidget.__init__(self, *args, **kargs) + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + backend_kwargs["num_axes"] = len(dp.unit_ids) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + bins = dp.bins + unit_ids = dp.unit_ids + correlograms = dp.correlograms + bin_width = bins[1] - bins[0] + + for i, unit_id in enumerate(unit_ids): + ccg = correlograms[i, i] + ax = self.axes.flatten()[i] + if dp.unit_colors is None: + color = "g" + else: + color = dp.unit_colors[unit_id] + ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") + ax.set_title(str(unit_id)) + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .utils_sortingview import make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + unit_ids = make_serializable(dp.unit_ids) + + ac_items = [] + for i in range(len(unit_ids)): + for j in range(i, len(unit_ids)): + if i == j: + ac_items.append( + vv.AutocorrelogramItem( + unit_id=unit_ids[i], + bin_edges_sec=(dp.bins / 1000.0).astype("float32"), + bin_counts=dp.correlograms[i, j].astype("int32"), + ) + ) + + self.view = vv.Autocorrelograms(autocorrelograms=ac_items) + + self.url = handle_display_and_url(self, self.view, **backend_kwargs) + AutoCorrelogramsWidget.__doc__ = CrossCorrelogramsWidget.__doc__ diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 9a914bf28d..dea46b8f51 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -19,43 +19,83 @@ def 
set_default_plotter_backend(backend): default_backend_ = backend -class BaseWidget: - # this need to be reset in the subclass - possible_backends = None +backend_kwargs_desc = { + "matplotlib": { + "figure": "Matplotlib figure. When None, it is created. Default None", + "ax": "Single matplotlib axis. When None, it is created. Default None", + "axes": "Multiple matplotlib axes. When None, they is created. Default None", + "ncols": "Number of columns to create in subplots. Default 5", + "figsize": "Size of matplotlib figure. Default None", + "figtitle": "The figure title. Default None", + }, + "sortingview": { + "generate_url": "If True, the figurl URL is generated and printed. Default True", + "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.", + "figlabel": "The figurl figure label. Default None", + "height": "The height of the sortingview View in jupyter. Default None", + }, + "ipywidgets": { + "width_cm": "Width of the figure in cm (default 10)", + "height_cm": "Height of the figure in cm (default 6)", + "display": "If True, widgets are immediately displayed", + }, +} + +default_backend_kwargs = { + "matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None}, + "sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None}, + "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True}, +} + - def __init__(self, plot_data=None, backend=None, **backend_kwargs): +class BaseWidget: + def __init__( + self, + data_plot=None, + backend=None, + immediate_plot=True, + **backend_kwargs, + ): # every widgets must prepare a dict "plot_data" in the init - self.plot_data = plot_data + self.data_plot = data_plot + backend = self.check_backend(backend) self.backend = backend - self.backend_kwargs = backend_kwargs - def check_backend(self, backend): - if backend is None: - backend = get_default_plotter_backend() - assert backend in self.possible_backends, ( - f"{backend} backend not available! Available backends are: " f"{list(self.possible_backends.keys())}" - ) - return backend - - def check_backend_kwargs(self, plotter, backend, **backend_kwargs): - plotter_kwargs = plotter.default_backend_kwargs + # check backend kwargs for k in backend_kwargs: - if k not in plotter_kwargs: + if k not in default_backend_kwargs[backend]: raise Exception( f"{k} is not a valid plot argument or backend keyword argument. 
" - f"Possible backend keyword arguments for {backend} are: {list(plotter_kwargs.keys())}" + f"Possible backend keyword arguments for {backend} are: {list(default_backend_kwargs[backend].keys())}" ) + backend_kwargs_ = default_backend_kwargs[self.backend].copy() + backend_kwargs_.update(backend_kwargs) - def do_plot(self, backend, **backend_kwargs): - backend = self.check_backend(backend) - plotter = self.possible_backends[backend]() - self.check_backend_kwargs(plotter, backend, **backend_kwargs) - plotter.do_plot(self.plot_data, **backend_kwargs) - self.plotter = plotter + self.backend_kwargs = backend_kwargs_ + + if immediate_plot: + self.do_plot() + + # subclass must define one method per supported backend: + # def plot_matplotlib(self, data_plot, **backend_kwargs): + # def plot_ipywidgets(self, data_plot, **backend_kwargs): + # def plot_sortingview(self, data_plot, **backend_kwargs): @classmethod - def register_backend(cls, backend_plotter): - cls.possible_backends[backend_plotter.backend] = backend_plotter + def get_possible_backends(cls): + return [k for k in default_backend_kwargs if hasattr(cls, f"plot_{k}")] + + def check_backend(self, backend): + if backend is None: + backend = get_default_plotter_backend() + assert backend in self.get_possible_backends(), ( + f"{backend} backend not available! Available backends are: " f"{self.get_possible_backends()}" + ) + return backend + + def do_plot(self): + func = getattr(self, f"plot_{self.backend}") + func(self.data_plot, **self.backend_kwargs) @staticmethod def check_extensions(waveform_extractor, extensions): @@ -74,27 +114,6 @@ def check_extensions(waveform_extractor, extensions): raise Exception(error_msg) -class BackendPlotter: - backend = "" - - @classmethod - def register(cls, widget_cls): - widget_cls.register_backend(cls) - - def update_backend_kwargs(self, **backend_kwargs): - backend_kwargs_ = self.default_backend_kwargs.copy() - backend_kwargs_.update(backend_kwargs) - return backend_kwargs_ - - -def copy_signature(source_fct): - def copy(target_fct): - target_fct.__signature__ = inspect.signature(source_fct) - return target_fct - - return copy - - class to_attr(object): def __init__(self, d): """ @@ -111,16 +130,3 @@ def __init__(self, d): def __getattribute__(self, k): d = object.__getattribute__(self, "__d") return d[k] - - -def define_widget_function_from_class(widget_class, name): - @copy_signature(widget_class) - def widget_func(*args, **kwargs): - W = widget_class(*args, **kwargs) - W.do_plot(W.backend, **W.backend_kwargs) - return W.plotter - - widget_func.__doc__ = widget_class.__doc__ - widget_func.__name__ = name - - return widget_func diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index 8481c8ef0d..3ec3fa11b6 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -1,7 +1,7 @@ import numpy as np from typing import Union -from .base import BaseWidget +from .base import BaseWidget, to_attr from ..core.waveform_extractor import WaveformExtractor from ..core.basesorting import BaseSorting from ..postprocessing import compute_correlograms @@ -27,8 +27,6 @@ class CrossCorrelogramsWidget(BaseWidget): If given, a dictionary with unit ids as keys and colors as values, default None """ - possible_backends = {} - def __init__( self, waveform_or_sorting_extractor: Union[WaveformExtractor, BaseSorting], @@ -65,3 +63,61 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, 
**backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + backend_kwargs["ncols"] = len(dp.unit_ids) + backend_kwargs["num_axes"] = int(len(dp.unit_ids) ** 2) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + assert self.axes.ndim == 2 + + bins = dp.bins + unit_ids = dp.unit_ids + correlograms = dp.correlograms + bin_width = bins[1] - bins[0] + + for i, unit_id1 in enumerate(unit_ids): + for j, unit_id2 in enumerate(unit_ids): + ccg = correlograms[i, j] + ax = self.axes[i, j] + if i == j: + if dp.unit_colors is None: + color = "g" + else: + color = dp.unit_colors[unit_id1] + else: + color = "k" + ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") + + for i, unit_id in enumerate(unit_ids): + self.axes[0, i].set_title(str(unit_id)) + self.axes[-1, i].set_xlabel("CCG (ms)") + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + + unit_ids = make_serializable(dp.unit_ids) + + cc_items = [] + for i in range(len(unit_ids)): + for j in range(i, len(unit_ids)): + cc_items.append( + vv.CrossCorrelogramItem( + unit_id1=unit_ids[i], + unit_id2=unit_ids[j], + bin_edges_sec=(dp.bins / 1000.0).astype("float32"), + bin_counts=dp.correlograms[i, j].astype("int32"), + ) + ) + + self.view = vv.CrossCorrelograms(cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector) + + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/ipywidgets/__init__.py b/src/spikeinterface/widgets/ipywidgets/__init__.py deleted file mode 100644 index 63d1b3a486..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .amplitudes import AmplitudesPlotter -from .quality_metrics import QualityMetricsPlotter -from .spike_locations import SpikeLocationsPlotter -from .spikes_on_traces import SpikesOnTracesPlotter -from .template_metrics import TemplateMetricsPlotter -from .timeseries import TimeseriesPlotter -from .unit_locations import UnitLocationsPlotter -from .unit_templates import UnitTemplatesPlotter -from .unit_waveforms import UnitWaveformPlotter diff --git a/src/spikeinterface/widgets/ipywidgets/amplitudes.py b/src/spikeinterface/widgets/ipywidgets/amplitudes.py deleted file mode 100644 index dc55b927e0..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/amplitudes.py +++ /dev/null @@ -1,99 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..amplitudes import AmplitudesWidget -from ..matplotlib.amplitudes import AmplitudesPlotter as MplAmplitudesPlotter - -from IPython.display import display - - -class AmplitudesPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - we = data_plot["waveform_extractor"] - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = widgets.Output() - with output: - fig = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - - 
data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) - - plot_histograms = widgets.Checkbox( - value=data_plot["plot_histograms"], - description="plot histograms", - disabled=False, - ) - - footer = plot_histograms - - self.controller = {"plot_histograms": plot_histograms} - self.controller.update(unit_controller) - - mpl_plotter = MplAmplitudesPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], footer=footer - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -AmplitudesPlotter.register(AmplitudesWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, fig, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.fig = fig - self.controller = controller - - self.we = data_plot["waveform_extractor"] - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.fig.clear() - - unit_ids = self.controller["unit_ids"].value - plot_histograms = self.controller["plot_histograms"].value - - # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["plot_histograms"] = plot_histograms - - backend_kwargs = {} - backend_kwargs["figure"] = self.fig - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - - self.fig.canvas.draw() - self.fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/base_ipywidgets.py b/src/spikeinterface/widgets/ipywidgets/base_ipywidgets.py deleted file mode 100644 index e0eff7f330..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/base_ipywidgets.py +++ /dev/null @@ -1,20 +0,0 @@ -from spikeinterface.widgets.base import BackendPlotter - -import matplotlib as mpl -import matplotlib.pyplot as plt -from matplotlib import gridspec -import numpy as np - - -class IpywidgetsPlotter(BackendPlotter): - backend = "ipywidgets" - backend_kwargs_desc = { - "width_cm": "Width of the figure in cm (default 10)", - "height_cm": "Height of the figure in cm (default 6)", - "display": "If True, widgets are immediately displayed", - } - default_backend_kwargs = {"width_cm": 25, "height_cm": 10, "display": True} - - def check_backend(self): - mpl_backend = mpl.get_backend() - assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" diff --git a/src/spikeinterface/widgets/ipywidgets/metrics.py b/src/spikeinterface/widgets/ipywidgets/metrics.py deleted file mode 100644 index ba6859b2a1..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/metrics.py +++ /dev/null @@ -1,108 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - -from matplotlib.lines import Line2D - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..matplotlib.metrics import MetricsPlotter as MplMetricsPlotter - -from IPython.display import display - - -class MetricsPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = 
backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = widgets.Output() - with output: - fig = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - if data_plot["unit_ids"] is None: - data_plot["unit_ids"] = [] - - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller - - mpl_plotter = MplMetricsPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, - left_sidebar=unit_widget, - pane_widths=ratios + [0], - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, fig, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.fig = fig - self.controller = controller - self.unit_colors = data_plot["unit_colors"] - - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - unit_ids = self.controller["unit_ids"].value - - # matplotlib next_data_plot dict update at each call - all_units = list(self.unit_colors.keys()) - colors = [] - sizes = [] - for unit in all_units: - color = "gray" if unit not in unit_ids else self.unit_colors[unit] - size = 1 if unit not in unit_ids else 5 - colors.append(color) - sizes.append(size) - - # here we do a trick: we just update colors - if hasattr(self.mpl_plotter, "patches"): - for p in self.mpl_plotter.patches: - p.set_color(colors) - p.set_sizes(sizes) - else: - backend_kwargs = {} - backend_kwargs["figure"] = self.fig - self.mpl_plotter.do_plot(self.data_plot, **backend_kwargs) - - if len(unit_ids) > 0: - for l in self.fig.legends: - l.remove() - handles = [ - Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=self.unit_colors[unit]) - for unit in unit_ids - ] - labels = unit_ids - self.fig.legend( - handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - self.fig.canvas.draw() - self.fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/quality_metrics.py b/src/spikeinterface/widgets/ipywidgets/quality_metrics.py deleted file mode 100644 index 3fc368770b..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/quality_metrics.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..quality_metrics import QualityMetricsWidget -from .metrics import MetricsPlotter - - -class QualityMetricsPlotter(MetricsPlotter): - pass - - -QualityMetricsPlotter.register(QualityMetricsWidget) diff --git a/src/spikeinterface/widgets/ipywidgets/spike_locations.py b/src/spikeinterface/widgets/ipywidgets/spike_locations.py deleted file mode 100644 index 633eb0ac39..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/spike_locations.py +++ /dev/null @@ -1,97 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..spike_locations import SpikeLocationsWidget -from ..matplotlib.spike_locations import ( - SpikeLocationsPlotter as MplSpikeLocationsPlotter, -) - -from IPython.display import display - - -class SpikeLocationsPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 
- - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = widgets.Output() - with output: - fig, ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], - list(data_plot["unit_colors"].keys()), - ratios[0] * width_cm, - height_cm, - ) - - self.controller = unit_controller - - mpl_plotter = MplSpikeLocationsPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, - left_sidebar=unit_widget, - pane_widths=ratios + [0], - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -SpikeLocationsPlotter.register(SpikeLocationsWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ax, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.ax = ax - self.controller = controller - - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.ax.clear() - - unit_ids = self.controller["unit_ids"].value - - # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["plot_all_units"] = True - data_plot["plot_legend"] = True - data_plot["hide_axis"] = True - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - fig = self.ax.get_figure() - fig.canvas.draw() - fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/spikes_on_traces.py b/src/spikeinterface/widgets/ipywidgets/spikes_on_traces.py deleted file mode 100644 index e5a3ebcc71..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/spikes_on_traces.py +++ /dev/null @@ -1,145 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from .base_ipywidgets import IpywidgetsPlotter -from .timeseries import TimeseriesPlotter -from .utils import make_unit_controller - -from ..spikes_on_traces import SpikesOnTracesWidget -from ..matplotlib.spikes_on_traces import SpikesOnTracesPlotter as MplSpikesOnTracesPlotter - -from IPython.display import display - - -class SpikesOnTracesPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - ratios = [0.2, 0.8] - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs_ts = backend_kwargs.copy() - backend_kwargs_ts["width_cm"] = ratios[1] * backend_kwargs_ts["width_cm"] - backend_kwargs_ts["display"] = False - height_cm = backend_kwargs["height_cm"] - width_cm = backend_kwargs["width_cm"] - - # plot timeseries - tsplotter = TimeseriesPlotter() - data_plot["timeseries"]["add_legend"] = False - tsplotter.do_plot(data_plot["timeseries"], **backend_kwargs_ts) - - ts_w = tsplotter.widget - ts_updater = tsplotter.updater - - we = data_plot["waveform_extractor"] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) - - self.controller = ts_updater.controller - self.controller.update(unit_controller) - - mpl_plotter = MplSpikesOnTracesPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, ts_updater, 
self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout(center=ts_w, left_sidebar=unit_widget, pane_widths=ratios + [0]) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -SpikesOnTracesPlotter.register(SpikesOnTracesWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ts_updater, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - - self.ts_updater = ts_updater - self.ax = ts_updater.ax - self.fig = self.ax.figure - self.controller = controller - - def __call__(self, change): - self.ax.clear() - - unit_ids = self.controller["unit_ids"].value - - # update ts - # self.ts_updater.__call__(change) - - # update data plot - data_plot = self.data_plot.copy() - data_plot["timeseries"] = self.ts_updater.next_data_plot - data_plot["unit_ids"] = unit_ids - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - - self.fig.canvas.draw() - self.fig.canvas.flush_events() - - # t = self.time_slider.value - # d = self.win_sizer.value - - # selected_layer = self.layer_selector.value - # segment_index = self.seg_selector.value - # mode = self.mode_selector.value - - # t_stop = self.t_stops[segment_index] - # if self.actual_segment_index != segment_index: - # # change time_slider limits - # self.time_slider.max = t_stop - # self.actual_segment_index = segment_index - - # # protect limits - # if t >= t_stop - d: - # t = t_stop - d - - # time_range = np.array([t, t+d]) - - # if mode =='line': - # # plot all layer - # layer_keys = self.data_plot['layer_keys'] - # recordings = self.recordings - # clims = None - # elif mode =='map': - # layer_keys = [selected_layer] - # recordings = {selected_layer: self.recordings[selected_layer]} - # clims = {selected_layer: self.data_plot["clims"][selected_layer]} - - # channel_ids = self.data_plot['channel_ids'] - # order = self.data_plot['order'] - # times, list_traces, frame_range, order = _get_trace_list(recordings, channel_ids, time_range, order, - # segment_index) - - # # matplotlib next_data_plot dict update at each call - # data_plot = self.next_data_plot - # data_plot['mode'] = mode - # data_plot['frame_range'] = frame_range - # data_plot['time_range'] = time_range - # data_plot['with_colorbar'] = False - # data_plot['recordings'] = recordings - # data_plot['layer_keys'] = layer_keys - # data_plot['list_traces'] = list_traces - # data_plot['times'] = times - # data_plot['clims'] = clims - - # backend_kwargs = {} - # backend_kwargs['ax'] = self.ax - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - - # fig = self.ax.figure - # fig.canvas.draw() - # fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/template_metrics.py b/src/spikeinterface/widgets/ipywidgets/template_metrics.py deleted file mode 100644 index 0aea8ae428..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/template_metrics.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..template_metrics import TemplateMetricsWidget -from .metrics import MetricsPlotter - - -class TemplateMetricsPlotter(MetricsPlotter): - pass - - -TemplateMetricsPlotter.register(TemplateMetricsWidget) diff --git a/src/spikeinterface/widgets/ipywidgets/timeseries.py b/src/spikeinterface/widgets/ipywidgets/timeseries.py deleted file mode 100644 index 2448166f16..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/timeseries.py +++ /dev/null @@ -1,232 +0,0 @@ -import 
numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - -from ...core import order_channels_by_depth - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_timeseries_controller, make_channel_controller, make_scale_controller - -from ..timeseries import TimeseriesWidget, _get_trace_list -from ..matplotlib.timeseries import TimeseriesPlotter as MplTimeseriesPlotter - -from IPython.display import display - - -class TimeseriesPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - recordings = data_plot["recordings"] - - # first layer - rec0 = recordings[data_plot["layer_keys"][0]] - - cm = 1 / 2.54 - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - ratios = [0.1, 0.8, 0.2] - - with plt.ioff(): - output = widgets.Output() - with output: - fig, ax = plt.subplots(figsize=(0.9 * ratios[1] * width_cm * cm, height_cm * cm)) - plt.show() - - t_start = 0.0 - t_stop = rec0.get_num_samples(segment_index=0) / rec0.get_sampling_frequency() - - ts_widget, ts_controller = make_timeseries_controller( - t_start, - t_stop, - data_plot["layer_keys"], - rec0.get_num_segments(), - data_plot["time_range"], - data_plot["mode"], - False, - width_cm, - ) - - ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) - - scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) - - self.controller = ts_controller - self.controller.update(ch_controller) - self.controller.update(scale_controller) - - mpl_plotter = MplTimeseriesPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - for w in self.controller.values(): - if isinstance(w, widgets.Button): - w.on_click(self.updater) - else: - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, - footer=ts_widget, - left_sidebar=scale_widget, - right_sidebar=ch_widget, - pane_heights=[0, 6, 1], - pane_widths=ratios, - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -TimeseriesPlotter.register(TimeseriesWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ax, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - - self.ax = ax - self.controller = controller - - self.recordings = data_plot["recordings"] - self.return_scaled = data_plot["return_scaled"] - self.next_data_plot = data_plot.copy() - self.list_traces = None - - self.actual_segment_index = self.controller["segment_index"].value - - self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] - self.t_stops = [ - self.rec0.get_num_samples(segment_index=seg_index) / self.rec0.get_sampling_frequency() - for seg_index in range(self.rec0.get_num_segments()) - ] - - def __call__(self, change): - self.ax.clear() - - # if changing the layer_key, no need to retrieve and process traces - retrieve_traces = True - scale_up = False - scale_down = False - if change is not None: - for cname, c in self.controller.items(): - if isinstance(change, dict): - if change["owner"] is c and cname == "layer_key": - retrieve_traces = False - elif isinstance(change, widgets.Button): - if change is c and cname == "plus": - scale_up = True - if change is c and cname == "minus": - scale_down = True - - t_start = self.controller["t_start"].value - window = self.controller["window"].value - 
layer_key = self.controller["layer_key"].value - segment_index = self.controller["segment_index"].value - mode = self.controller["mode"].value - chan_start, chan_stop = self.controller["channel_inds"].value - - if mode == "line": - self.controller["all_layers"].layout.visibility = "visible" - all_layers = self.controller["all_layers"].value - elif mode == "map": - self.controller["all_layers"].layout.visibility = "hidden" - all_layers = False - - if all_layers: - self.controller["layer_key"].layout.visibility = "hidden" - else: - self.controller["layer_key"].layout.visibility = "visible" - - if chan_start == chan_stop: - chan_stop += 1 - channel_indices = np.arange(chan_start, chan_stop) - - t_stop = self.t_stops[segment_index] - if self.actual_segment_index != segment_index: - # change time_slider limits - self.controller["t_start"].max = t_stop - self.actual_segment_index = segment_index - - # protect limits - if t_start >= t_stop - window: - t_start = t_stop - window - - time_range = np.array([t_start, t_start + window]) - data_plot = self.next_data_plot - - if retrieve_traces: - all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids - if self.data_plot["order"] is not None: - all_channel_ids = all_channel_ids[self.data_plot["order"]] - channel_ids = all_channel_ids[channel_indices] - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None - times, list_traces, frame_range, channel_ids = _get_trace_list( - self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled - ) - self.list_traces = list_traces - else: - times = data_plot["times"] - list_traces = data_plot["list_traces"] - frame_range = data_plot["frame_range"] - channel_ids = data_plot["channel_ids"] - - if all_layers: - layer_keys = self.data_plot["layer_keys"] - recordings = self.recordings - list_traces_plot = self.list_traces - else: - layer_keys = [layer_key] - recordings = {layer_key: self.recordings[layer_key]} - list_traces_plot = [self.list_traces[list(self.recordings.keys()).index(layer_key)]] - - if scale_up: - if mode == "line": - data_plot["vspacing"] *= 0.8 - elif mode == "map": - data_plot["clims"] = { - layer: (1.2 * val[0], 1.2 * val[1]) for layer, val in self.data_plot["clims"].items() - } - if scale_down: - if mode == "line": - data_plot["vspacing"] *= 1.2 - elif mode == "map": - data_plot["clims"] = { - layer: (0.8 * val[0], 0.8 * val[1]) for layer, val in self.data_plot["clims"].items() - } - - self.next_data_plot["vspacing"] = data_plot["vspacing"] - self.next_data_plot["clims"] = data_plot["clims"] - - if mode == "line": - clims = None - elif mode == "map": - clims = {layer_key: self.data_plot["clims"][layer_key]} - - # matplotlib next_data_plot dict update at each call - data_plot["mode"] = mode - data_plot["frame_range"] = frame_range - data_plot["time_range"] = time_range - data_plot["with_colorbar"] = False - data_plot["recordings"] = recordings - data_plot["layer_keys"] = layer_keys - data_plot["list_traces"] = list_traces_plot - data_plot["times"] = times - data_plot["clims"] = clims - data_plot["channel_ids"] = channel_ids - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - - fig = self.ax.figure - fig.canvas.draw() - fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/unit_locations.py b/src/spikeinterface/widgets/ipywidgets/unit_locations.py deleted file mode 100644 index 
e78c0d8fe5..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/unit_locations.py +++ /dev/null @@ -1,91 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..unit_locations import UnitLocationsWidget -from ..matplotlib.unit_locations import UnitLocationsPlotter as MplUnitLocationsPlotter - -from IPython.display import display - - -class UnitLocationsPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = widgets.Output() - with output: - fig, ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller - - mpl_plotter = MplUnitLocationsPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, - left_sidebar=unit_widget, - pane_widths=ratios + [0], - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -UnitLocationsPlotter.register(UnitLocationsWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ax, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.ax = ax - self.controller = controller - - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.ax.clear() - - unit_ids = self.controller["unit_ids"].value - - # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["plot_all_units"] = True - data_plot["plot_legend"] = True - data_plot["hide_axis"] = True - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - fig = self.ax.get_figure() - fig.canvas.draw() - fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/unit_templates.py b/src/spikeinterface/widgets/ipywidgets/unit_templates.py deleted file mode 100644 index 41da9d8cd3..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/unit_templates.py +++ /dev/null @@ -1,11 +0,0 @@ -from ..unit_templates import UnitTemplatesWidget -from .unit_waveforms import UnitWaveformPlotter - - -class UnitTemplatesPlotter(UnitWaveformPlotter): - def do_plot(self, data_plot, **backend_kwargs): - super().do_plot(data_plot, **backend_kwargs) - self.controller["plot_templates"].layout.visibility = "hidden" - - -UnitTemplatesPlotter.register(UnitTemplatesWidget) diff --git a/src/spikeinterface/widgets/ipywidgets/unit_waveforms.py b/src/spikeinterface/widgets/ipywidgets/unit_waveforms.py deleted file mode 100644 index 012b46038a..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/unit_waveforms.py +++ /dev/null @@ -1,169 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils 
import make_unit_controller - -from ..unit_waveforms import UnitWaveformsWidget -from ..matplotlib.unit_waveforms import UnitWaveformPlotter as MplUnitWaveformPlotter - -from IPython.display import display - - -class UnitWaveformPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - we = data_plot["waveform_extractor"] - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.1, 0.7, 0.2] - - with plt.ioff(): - output1 = widgets.Output() - with output1: - fig_wf = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - output2 = widgets.Output() - with output2: - fig_probe, ax_probe = plt.subplots(figsize=((ratios[2] * width_cm) * cm, height_cm * cm)) - plt.show() - - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) - - same_axis_button = widgets.Checkbox( - value=False, - description="same axis", - disabled=False, - ) - - plot_templates_button = widgets.Checkbox( - value=True, - description="plot templates", - disabled=False, - ) - - hide_axis_button = widgets.Checkbox( - value=True, - description="hide axis", - disabled=False, - ) - - footer = widgets.HBox([same_axis_button, plot_templates_button, hide_axis_button]) - - self.controller = { - "same_axis": same_axis_button, - "plot_templates": plot_templates_button, - "hide_axis": hide_axis_button, - } - self.controller.update(unit_controller) - - mpl_plotter = MplUnitWaveformPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, fig_wf, ax_probe, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig_wf.canvas, - left_sidebar=unit_widget, - right_sidebar=fig_probe.canvas, - pane_widths=ratios, - footer=footer, - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -UnitWaveformPlotter.register(UnitWaveformsWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, fig_wf, ax_probe, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.fig_wf = fig_wf - self.ax_probe = ax_probe - self.controller = controller - - self.we = data_plot["waveform_extractor"] - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.fig_wf.clear() - self.ax_probe.clear() - - unit_ids = self.controller["unit_ids"].value - same_axis = self.controller["same_axis"].value - plot_templates = self.controller["plot_templates"].value - hide_axis = self.controller["hide_axis"].value - - # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["templates"] = self.we.get_all_templates(unit_ids=unit_ids) - data_plot["template_stds"] = self.we.get_all_templates(unit_ids=unit_ids, mode="std") - data_plot["same_axis"] = same_axis - data_plot["plot_templates"] = plot_templates - if data_plot["plot_waveforms"]: - data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} - - backend_kwargs = {} - - if same_axis: - backend_kwargs["ax"] = self.fig_wf.add_subplot() - data_plot["set_title"] = False - else: - backend_kwargs["figure"] = self.fig_wf - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - if same_axis: - self.mpl_plotter.ax.axis("equal") - 
if hide_axis: - self.mpl_plotter.ax.axis("off") - else: - if hide_axis: - for i in range(len(unit_ids)): - ax = self.mpl_plotter.axes.flatten()[i] - ax.axis("off") - - # update probe plot - channel_locations = self.we.get_channel_locations() - self.ax_probe.plot( - channel_locations[:, 0], channel_locations[:, 1], ls="", marker="o", color="gray", markersize=2, alpha=0.5 - ) - self.ax_probe.axis("off") - self.ax_probe.axis("equal") - - for unit in unit_ids: - channel_inds = data_plot["sparsity"].unit_id_to_channel_indices[unit] - self.ax_probe.plot( - channel_locations[channel_inds, 0], - channel_locations[channel_inds, 1], - ls="", - marker="o", - markersize=3, - color=self.next_data_plot["unit_colors"][unit], - ) - self.ax_probe.set_xlim(np.min(channel_locations[:, 0]) - 10, np.max(channel_locations[:, 0]) + 10) - fig_probe = self.ax_probe.get_figure() - - self.fig_wf.canvas.draw() - self.fig_wf.canvas.flush_events() - fig_probe.canvas.draw() - fig_probe.canvas.flush_events() diff --git a/src/spikeinterface/widgets/matplotlib/__init__.py b/src/spikeinterface/widgets/matplotlib/__init__.py deleted file mode 100644 index 525396e30d..0000000000 --- a/src/spikeinterface/widgets/matplotlib/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .all_amplitudes_distributions import AllAmplitudesDistributionsPlotter -from .amplitudes import AmplitudesPlotter -from .autocorrelograms import AutoCorrelogramsPlotter -from .crosscorrelograms import CrossCorrelogramsPlotter -from .quality_metrics import QualityMetricsPlotter -from .motion import MotionPlotter -from .spike_locations import SpikeLocationsPlotter -from .spikes_on_traces import SpikesOnTracesPlotter -from .template_metrics import TemplateMetricsPlotter -from .template_similarity import TemplateSimilarityPlotter -from .timeseries import TimeseriesPlotter -from .unit_locations import UnitLocationsPlotter -from .unit_templates import UnitTemplatesWidget -from .unit_waveforms import UnitWaveformPlotter -from .unit_waveforms_density_map import UnitWaveformDensityMapPlotter -from .unit_depths import UnitDepthsPlotter -from .unit_summary import UnitSummaryPlotter diff --git a/src/spikeinterface/widgets/matplotlib/all_amplitudes_distributions.py b/src/spikeinterface/widgets/matplotlib/all_amplitudes_distributions.py deleted file mode 100644 index 6985d2167a..0000000000 --- a/src/spikeinterface/widgets/matplotlib/all_amplitudes_distributions.py +++ /dev/null @@ -1,41 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..all_amplitudes_distributions import AllAmplitudesDistributionsWidget -from .base_mpl import MplPlotter - - -class AllAmplitudesDistributionsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - self.make_mpl_figure(**backend_kwargs) - - ax = self.ax - - unit_amps = [] - for i, unit_id in enumerate(dp.unit_ids): - amps = [] - for segment_index in range(dp.num_segments): - amps.append(dp.amplitudes[segment_index][unit_id]) - amps = np.concatenate(amps) - unit_amps.append(amps) - parts = ax.violinplot(unit_amps, showmeans=False, showmedians=False, showextrema=False) - - for i, pc in enumerate(parts["bodies"]): - color = dp.unit_colors[dp.unit_ids[i]] - pc.set_facecolor(color) - pc.set_edgecolor("black") - pc.set_alpha(1) - - ax.set_xticks(np.arange(len(dp.unit_ids)) + 1) - ax.set_xticklabels([str(unit_id) for unit_id in dp.unit_ids]) - - ylims = ax.get_ylim() - if np.max(ylims) < 0: - ax.set_ylim(min(ylims), 0) - if 
np.min(ylims) > 0: - ax.set_ylim(0, max(ylims)) - - -AllAmplitudesDistributionsPlotter.register(AllAmplitudesDistributionsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/amplitudes.py b/src/spikeinterface/widgets/matplotlib/amplitudes.py deleted file mode 100644 index 747709211a..0000000000 --- a/src/spikeinterface/widgets/matplotlib/amplitudes.py +++ /dev/null @@ -1,69 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..amplitudes import AmplitudesWidget -from .base_mpl import MplPlotter - - -class AmplitudesPlotter(MplPlotter): - def __init__(self) -> None: - self.legend = None - - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - if backend_kwargs["axes"] is not None: - axes = backend_kwargs["axes"] - if dp.plot_histograms: - assert np.asarray(axes).size == 2 - else: - assert np.asarray(axes).size == 1 - elif backend_kwargs["ax"] is not None: - assert not dp.plot_histograms - else: - if dp.plot_histograms: - backend_kwargs["num_axes"] = 2 - backend_kwargs["ncols"] = 2 - else: - backend_kwargs["num_axes"] = None - - self.make_mpl_figure(**backend_kwargs) - - scatter_ax = self.axes.flatten()[0] - - for unit_id in dp.unit_ids: - spiketrains = dp.spiketrains[unit_id] - amps = dp.amplitudes[unit_id] - scatter_ax.scatter(spiketrains, amps, color=dp.unit_colors[unit_id], s=3, alpha=1, label=unit_id) - - if dp.plot_histograms: - if dp.bins is None: - bins = int(len(spiketrains) / 30) - else: - bins = dp.bins - ax_hist = self.axes.flatten()[1] - ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) - - if dp.plot_histograms: - ax_hist = self.axes.flatten()[1] - ax_hist.set_ylim(scatter_ax.get_ylim()) - ax_hist.axis("off") - self.figure.tight_layout() - - if dp.plot_legend: - if self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - scatter_ax.set_xlim(0, dp.total_duration) - scatter_ax.set_xlabel("Times [s]") - scatter_ax.set_ylabel(f"Amplitude") - scatter_ax.spines["top"].set_visible(False) - scatter_ax.spines["right"].set_visible(False) - self.figure.subplots_adjust(bottom=0.1, top=0.9, left=0.1) - - -AmplitudesPlotter.register(AmplitudesWidget) diff --git a/src/spikeinterface/widgets/matplotlib/autocorrelograms.py b/src/spikeinterface/widgets/matplotlib/autocorrelograms.py deleted file mode 100644 index 9245ef6881..0000000000 --- a/src/spikeinterface/widgets/matplotlib/autocorrelograms.py +++ /dev/null @@ -1,30 +0,0 @@ -from ..base import to_attr -from ..autocorrelograms import AutoCorrelogramsWidget -from .base_mpl import MplPlotter - - -class AutoCorrelogramsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["num_axes"] = len(dp.unit_ids) - - self.make_mpl_figure(**backend_kwargs) - - bins = dp.bins - unit_ids = dp.unit_ids - correlograms = dp.correlograms - bin_width = bins[1] - bins[0] - - for i, unit_id in enumerate(unit_ids): - ccg = correlograms[i, i] - ax = self.axes.flatten()[i] - if dp.unit_colors is None: - color = "g" - else: - color = dp.unit_colors[unit_id] - ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") - ax.set_title(str(unit_id)) - - -AutoCorrelogramsPlotter.register(AutoCorrelogramsWidget) diff --git 
a/src/spikeinterface/widgets/matplotlib/base_mpl.py b/src/spikeinterface/widgets/matplotlib/base_mpl.py deleted file mode 100644 index 266adc8782..0000000000 --- a/src/spikeinterface/widgets/matplotlib/base_mpl.py +++ /dev/null @@ -1,102 +0,0 @@ -from spikeinterface.widgets.base import BackendPlotter - -import matplotlib.pyplot as plt -import numpy as np - - -class MplPlotter(BackendPlotter): - backend = "matplotlib" - backend_kwargs_desc = { - "figure": "Matplotlib figure. When None, it is created. Default None", - "ax": "Single matplotlib axis. When None, it is created. Default None", - "axes": "Multiple matplotlib axes. When None, they is created. Default None", - "ncols": "Number of columns to create in subplots. Default 5", - "figsize": "Size of matplotlib figure. Default None", - "figtitle": "The figure title. Default None", - } - default_backend_kwargs = {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None} - - def make_mpl_figure(self, figure=None, ax=None, axes=None, ncols=None, num_axes=None, figsize=None, figtitle=None): - """ - figure/ax/axes : only one of then can be not None - """ - if figure is not None: - assert ax is None and axes is None, "figure/ax/axes : only one of then can be not None" - if num_axes is None: - ax = figure.add_subplot(111) - axes = np.array([[ax]]) - else: - assert ncols is not None - axes = [] - nrows = int(np.ceil(num_axes / ncols)) - axes = np.full((nrows, ncols), fill_value=None, dtype=object) - for i in range(num_axes): - ax = figure.add_subplot(nrows, ncols, i + 1) - r = i // ncols - c = i % ncols - axes[r, c] = ax - elif ax is not None: - assert figure is None and axes is None, "figure/ax/axes : only one of then can be not None" - figure = ax.get_figure() - axes = np.array([[ax]]) - elif axes is not None: - assert figure is None and ax is None, "figure/ax/axes : only one of then can be not None" - axes = np.asarray(axes) - figure = axes.flatten()[0].get_figure() - else: - # 'figure/ax/axes are all None - if num_axes is None: - # one fig with one ax - figure, ax = plt.subplots(figsize=figsize) - axes = np.array([[ax]]) - else: - if num_axes == 0: - # one figure without plots (diffred subplot creation with - figure = plt.figure(figsize=figsize) - ax = None - axes = None - elif num_axes == 1: - figure = plt.figure(figsize=figsize) - ax = figure.add_subplot(111) - axes = np.array([[ax]]) - else: - assert ncols is not None - if num_axes < ncols: - ncols = num_axes - nrows = int(np.ceil(num_axes / ncols)) - figure, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize) - ax = None - # remove extra axes - if ncols * nrows > num_axes: - for i, extra_ax in enumerate(axes.flatten()): - if i >= num_axes: - extra_ax.remove() - r = i // ncols - c = i % ncols - axes[r, c] = None - - self.figure = figure - self.ax = ax - # axes is always a 2D array of ax - self.axes = axes - - if figtitle is not None: - self.figure.suptitle(figtitle) - - -class to_attr(object): - def __init__(self, d): - """ - Helper function that transform a dict into - an object where attributes are the keys of the dict - - d = {'a': 1, 'b': 'yep'} - o = to_attr(d) - print(o.a, o.b) - """ - object.__init__(self) - object.__setattr__(self, "__d", d) - - def __getattribute__(self, k): - d = object.__getattribute__(self, "__d") - return d[k] diff --git a/src/spikeinterface/widgets/matplotlib/crosscorrelograms.py b/src/spikeinterface/widgets/matplotlib/crosscorrelograms.py deleted file mode 100644 index 24ecdcdffc..0000000000 --- 
a/src/spikeinterface/widgets/matplotlib/crosscorrelograms.py +++ /dev/null @@ -1,39 +0,0 @@ -from ..base import to_attr -from ..crosscorrelograms import CrossCorrelogramsWidget -from .base_mpl import MplPlotter - - -class CrossCorrelogramsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["ncols"] = len(dp.unit_ids) - backend_kwargs["num_axes"] = int(len(dp.unit_ids) ** 2) - - self.make_mpl_figure(**backend_kwargs) - assert self.axes.ndim == 2 - - bins = dp.bins - unit_ids = dp.unit_ids - correlograms = dp.correlograms - bin_width = bins[1] - bins[0] - - for i, unit_id1 in enumerate(unit_ids): - for j, unit_id2 in enumerate(unit_ids): - ccg = correlograms[i, j] - ax = self.axes[i, j] - if i == j: - if dp.unit_colors is None: - color = "g" - else: - color = dp.unit_colors[unit_id1] - else: - color = "k" - ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") - - for i, unit_id in enumerate(unit_ids): - self.axes[0, i].set_title(str(unit_id)) - self.axes[-1, i].set_xlabel("CCG (ms)") - - -CrossCorrelogramsPlotter.register(CrossCorrelogramsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/metrics.py b/src/spikeinterface/widgets/matplotlib/metrics.py deleted file mode 100644 index cec4c11644..0000000000 --- a/src/spikeinterface/widgets/matplotlib/metrics.py +++ /dev/null @@ -1,50 +0,0 @@ -import numpy as np - -from ..base import to_attr -from .base_mpl import MplPlotter - - -class MetricsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - metrics = dp.metrics - num_metrics = len(metrics.columns) - - if "figsize" not in backend_kwargs: - backend_kwargs["figsize"] = (2 * num_metrics, 2 * num_metrics) - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["num_axes"] = num_metrics**2 - backend_kwargs["ncols"] = num_metrics - - all_unit_ids = metrics.index.values - - self.make_mpl_figure(**backend_kwargs) - assert self.axes.ndim == 2 - - if dp.unit_ids is None: - colors = ["gray"] * len(all_unit_ids) - else: - colors = [] - for unit in all_unit_ids: - color = "gray" if unit not in dp.unit_ids else dp.unit_colors[unit] - colors.append(color) - - self.patches = [] - for i, m1 in enumerate(metrics.columns): - for j, m2 in enumerate(metrics.columns): - if i == j: - self.axes[i, j].hist(metrics[m1], color="gray") - else: - p = self.axes[i, j].scatter(metrics[m1], metrics[m2], c=colors, s=3, marker="o") - self.patches.append(p) - if i == num_metrics - 1: - self.axes[i, j].set_xlabel(m2, fontsize=10) - if j == 0: - self.axes[i, j].set_ylabel(m1, fontsize=10) - self.axes[i, j].set_xticklabels([]) - self.axes[i, j].set_yticklabels([]) - self.axes[i, j].spines["top"].set_visible(False) - self.axes[i, j].spines["right"].set_visible(False) - - self.figure.subplots_adjust(top=0.8, wspace=0.2, hspace=0.2) diff --git a/src/spikeinterface/widgets/matplotlib/motion.py b/src/spikeinterface/widgets/matplotlib/motion.py deleted file mode 100644 index abf02f4697..0000000000 --- a/src/spikeinterface/widgets/matplotlib/motion.py +++ /dev/null @@ -1,104 +0,0 @@ -from ..base import to_attr -from ..motion import MotionWidget -from .base_mpl import MplPlotter - -import numpy as np - - -class MotionPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks - - 
dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - assert backend_kwargs["axes"] is None - assert backend_kwargs["ax"] is None - - self.make_mpl_figure(**backend_kwargs) - fig = self.figure - fig.clear() - - is_rigid = dp.motion.shape[1] == 1 - - gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.3) - ax0 = fig.add_subplot(gs[0, 0]) - ax1 = fig.add_subplot(gs[0, 1]) - ax2 = fig.add_subplot(gs[1, 0]) - if not is_rigid: - ax3 = fig.add_subplot(gs[1, 1]) - ax1.sharex(ax0) - ax1.sharey(ax0) - - if dp.motion_lim is None: - motion_lim = np.max(np.abs(dp.motion)) * 1.05 - else: - motion_lim = dp.motion_lim - - corrected_location = correct_motion_on_peaks( - dp.peaks, dp.peak_locations, dp.rec.get_times(), dp.motion, dp.temporal_bins, dp.spatial_bins, direction="y" - ) - - x = dp.peaks["sample_index"] / dp.rec.get_sampling_frequency() - y = dp.peak_locations["y"] - y2 = corrected_location["y"] - if dp.scatter_decimate is not None: - x = x[:: dp.scatter_decimate] - y = y[:: dp.scatter_decimate] - y2 = y2[:: dp.scatter_decimate] - - if dp.color_amplitude: - amps = np.abs(dp.peaks["amplitude"]) - amps /= np.quantile(amps, 0.95) - if dp.scatter_decimate is not None: - amps = amps[:: dp.scatter_decimate] - c = plt.get_cmap(dp.amplitude_cmap)(amps) - color_kwargs = dict( - color=None, - c=c, - ) # alpha=0.02 - else: - color_kwargs = dict(color="k", c=None) # alpha=0.02 - - ax0.scatter(x, y, s=1, **color_kwargs) - # for i in range(dp.motion.shape[1]): - # ax0.plot(dp.temporal_bins, dp.motion[:, i] + dp.spatial_bins[i], color="C8", alpha=1.0) - if dp.depth_lim is not None: - ax0.set_ylim(*dp.depth_lim) - ax0.set_title("Peak depth") - ax0.set_xlabel("Times [s]") - ax0.set_ylabel("Depth [um]") - - ax1.scatter(x, y2, s=1, **color_kwargs) - ax1.set_xlabel("Times [s]") - ax1.set_ylabel("Depth [um]") - ax1.set_title("Corrected peak depth") - - ax2.plot(dp.temporal_bins, dp.motion, alpha=0.2, color="black") - ax2.plot(dp.temporal_bins, np.mean(dp.motion, axis=1), color="C0") - ax2.set_ylim(-motion_lim, motion_lim) - ax2.set_ylabel("motion [um]") - ax2.set_title("Motion vectors") - - if not is_rigid: - im = ax3.imshow( - dp.motion.T, - aspect="auto", - origin="lower", - extent=( - dp.temporal_bins[0], - dp.temporal_bins[-1], - dp.spatial_bins[0], - dp.spatial_bins[-1], - ), - ) - im.set_clim(-motion_lim, motion_lim) - cbar = fig.colorbar(im) - cbar.ax.set_xlabel("motion [um]") - ax3.set_xlabel("Times [s]") - ax3.set_ylabel("Depth [um]") - ax3.set_title("Motion vectors") - - -MotionPlotter.register(MotionWidget) diff --git a/src/spikeinterface/widgets/matplotlib/quality_metrics.py b/src/spikeinterface/widgets/matplotlib/quality_metrics.py deleted file mode 100644 index 3fc368770b..0000000000 --- a/src/spikeinterface/widgets/matplotlib/quality_metrics.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..quality_metrics import QualityMetricsWidget -from .metrics import MetricsPlotter - - -class QualityMetricsPlotter(MetricsPlotter): - pass - - -QualityMetricsPlotter.register(QualityMetricsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/spike_locations.py b/src/spikeinterface/widgets/matplotlib/spike_locations.py deleted file mode 100644 index 5c74df3fc8..0000000000 --- a/src/spikeinterface/widgets/matplotlib/spike_locations.py +++ /dev/null @@ -1,96 +0,0 @@ -from probeinterface import ProbeGroup -from probeinterface.plotting import plot_probe - -import numpy as np - -from ..base import to_attr -from ..spike_locations import SpikeLocationsWidget, estimate_axis_lims -from 
.base_mpl import MplPlotter - -from matplotlib.patches import Ellipse -from matplotlib.lines import Line2D - - -class SpikeLocationsPlotter(MplPlotter): - def __init__(self) -> None: - self.legend = None - - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - self.make_mpl_figure(**backend_kwargs) - - spike_locations = dp.spike_locations - - probegroup = ProbeGroup.from_dict(dp.probegroup_dict) - probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) - contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) - - for probe in probegroup.probes: - text_on_contact = None - if dp.with_channel_ids: - text_on_contact = dp.channel_ids - - poly_contact, poly_contour = plot_probe( - probe, - ax=self.ax, - contacts_colors="w", - contacts_kargs=contacts_kargs, - probe_shape_kwargs=probe_shape_kwargs, - text_on_contact=text_on_contact, - ) - poly_contact.set_zorder(2) - if poly_contour is not None: - poly_contour.set_zorder(1) - - self.ax.set_title("") - - if dp.plot_all_units: - unit_colors = {} - unit_ids = dp.all_unit_ids - for unit in dp.all_unit_ids: - if unit not in dp.unit_ids: - unit_colors[unit] = "gray" - else: - unit_colors[unit] = dp.unit_colors[unit] - else: - unit_ids = dp.unit_ids - unit_colors = dp.unit_colors - labels = dp.unit_ids - - for i, unit in enumerate(unit_ids): - locs = spike_locations[unit] - - zorder = 5 if unit in dp.unit_ids else 3 - self.ax.scatter(locs["x"], locs["y"], s=2, alpha=0.3, color=unit_colors[unit], zorder=zorder) - - handles = [ - Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) - for unit in dp.unit_ids - ] - if dp.plot_legend: - if self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - # set proper axis limits - xlims, ylims = estimate_axis_lims(spike_locations) - - ax_xlims = list(self.ax.get_xlim()) - ax_ylims = list(self.ax.get_ylim()) - - ax_xlims[0] = xlims[0] if xlims[0] < ax_xlims[0] else ax_xlims[0] - ax_xlims[1] = xlims[1] if xlims[1] > ax_xlims[1] else ax_xlims[1] - ax_ylims[0] = ylims[0] if ylims[0] < ax_ylims[0] else ax_ylims[0] - ax_ylims[1] = ylims[1] if ylims[1] > ax_ylims[1] else ax_ylims[1] - - self.ax.set_xlim(ax_xlims) - self.ax.set_ylim(ax_ylims) - if dp.hide_axis: - self.ax.axis("off") - - -SpikeLocationsPlotter.register(SpikeLocationsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/spikes_on_traces.py b/src/spikeinterface/widgets/matplotlib/spikes_on_traces.py deleted file mode 100644 index d620c8f28f..0000000000 --- a/src/spikeinterface/widgets/matplotlib/spikes_on_traces.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..spikes_on_traces import SpikesOnTracesWidget -from .base_mpl import MplPlotter -from .timeseries import TimeseriesPlotter - -from matplotlib.patches import Ellipse -from matplotlib.lines import Line2D - - -class SpikesOnTracesPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - - # first plot time series - tsplotter = TimeseriesPlotter() - data_plot["timeseries"]["add_legend"] = False - tsplotter.do_plot(dp.timeseries, **backend_kwargs) - self.ax = tsplotter.ax - self.axes = tsplotter.axes - self.figure = tsplotter.figure - - ax = self.ax - - we = dp.waveform_extractor - sorting = dp.waveform_extractor.sorting - frame_range = 
dp.timeseries["frame_range"] - segment_index = dp.timeseries["segment_index"] - min_y = np.min(dp.timeseries["channel_locations"][:, 1]) - max_y = np.max(dp.timeseries["channel_locations"][:, 1]) - - n = len(dp.timeseries["channel_ids"]) - order = dp.timeseries["order"] - if order is None: - order = np.arange(n) - - if ax.get_legend() is not None: - ax.get_legend().remove() - - # loop through units and plot a scatter of spikes at estimated location - handles = [] - labels = [] - - for unit in dp.unit_ids: - spike_frames = sorting.get_unit_spike_train(unit, segment_index=segment_index) - spike_start, spike_end = np.searchsorted(spike_frames, frame_range) - - chan_ids = dp.sparsity.unit_id_to_channel_ids[unit] - - spike_frames_to_plot = spike_frames[spike_start:spike_end] - - if dp.timeseries["mode"] == "map": - spike_times_to_plot = sorting.get_unit_spike_train( - unit, segment_index=segment_index, return_times=True - )[spike_start:spike_end] - unit_y_loc = min_y + max_y - dp.unit_locations[unit][1] - # markers = np.ones_like(spike_frames_to_plot) * (min_y + max_y - dp.unit_locations[unit][1]) - width = 2 * 1e-3 - ellipse_kwargs = dict(width=width, height=10, fc="none", ec=dp.unit_colors[unit], lw=2) - patches = [Ellipse((s, unit_y_loc), **ellipse_kwargs) for s in spike_times_to_plot] - for p in patches: - ax.add_patch(p) - handles.append( - Line2D( - [0], - [0], - ls="", - marker="o", - markersize=5, - markeredgewidth=2, - markeredgecolor=dp.unit_colors[unit], - markerfacecolor="none", - ) - ) - labels.append(unit) - else: - # construct waveforms - label_set = False - if len(spike_frames_to_plot) > 0: - vspacing = dp.timeseries["vspacing"] - traces = dp.timeseries["list_traces"][0] - waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] - waveform_idxs = np.clip(waveform_idxs, 0, len(dp.timeseries["times"]) - 1) - - times = dp.timeseries["times"][waveform_idxs] - # discontinuity - times[:, -1] = np.nan - times_r = times.reshape(times.shape[0] * times.shape[1]) - waveforms = traces[waveform_idxs] # [:, :, order] - waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2])) - - for i, chan_id in enumerate(dp.timeseries["channel_ids"]): - offset = vspacing * i - if chan_id in chan_ids: - l = ax.plot(times_r, offset + waveforms_r[:, i], color=dp.unit_colors[unit]) - if not label_set: - handles.append(l[0]) - labels.append(unit) - label_set = True - ax.legend(handles, labels) - - -SpikesOnTracesPlotter.register(SpikesOnTracesWidget) diff --git a/src/spikeinterface/widgets/matplotlib/template_metrics.py b/src/spikeinterface/widgets/matplotlib/template_metrics.py deleted file mode 100644 index 0aea8ae428..0000000000 --- a/src/spikeinterface/widgets/matplotlib/template_metrics.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..template_metrics import TemplateMetricsWidget -from .metrics import MetricsPlotter - - -class TemplateMetricsPlotter(MetricsPlotter): - pass - - -TemplateMetricsPlotter.register(TemplateMetricsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/template_similarity.py b/src/spikeinterface/widgets/matplotlib/template_similarity.py deleted file mode 100644 index 1e0a2e6fae..0000000000 --- a/src/spikeinterface/widgets/matplotlib/template_similarity.py +++ /dev/null @@ -1,30 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..template_similarity import TemplateSimilarityWidget -from .base_mpl import MplPlotter - - -class TemplateSimilarityPlotter(MplPlotter): - def do_plot(self, data_plot, 
**backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - self.make_mpl_figure(**backend_kwargs) - - im = self.ax.matshow(dp.similarity, cmap=dp.cmap) - - if dp.show_unit_ticks: - # Major ticks - self.ax.set_xticks(np.arange(0, len(dp.unit_ids))) - self.ax.set_yticks(np.arange(0, len(dp.unit_ids))) - self.ax.xaxis.tick_bottom() - - # Labels for major ticks - self.ax.set_yticklabels(dp.unit_ids, fontsize=12) - self.ax.set_xticklabels(dp.unit_ids, fontsize=12) - if dp.show_colorbar: - self.figure.colorbar(im) - - -TemplateSimilarityPlotter.register(TemplateSimilarityWidget) diff --git a/src/spikeinterface/widgets/matplotlib/timeseries.py b/src/spikeinterface/widgets/matplotlib/timeseries.py deleted file mode 100644 index 0a887b559f..0000000000 --- a/src/spikeinterface/widgets/matplotlib/timeseries.py +++ /dev/null @@ -1,70 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..timeseries import TimeseriesWidget -from .base_mpl import MplPlotter -from matplotlib.ticker import MaxNLocator - - -class TimeseriesPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - self.make_mpl_figure(**backend_kwargs) - ax = self.ax - n = len(dp.channel_ids) - if dp.channel_locations is not None: - y_locs = dp.channel_locations[:, 1] - else: - y_locs = np.arange(n) - min_y = np.min(y_locs) - max_y = np.max(y_locs) - - if dp.mode == "line": - offset = dp.vspacing * (n - 1) - - for layer_key, traces in zip(dp.layer_keys, dp.list_traces): - for i, chan_id in enumerate(dp.channel_ids): - offset = dp.vspacing * i - color = dp.colors[layer_key][chan_id] - ax.plot(dp.times, offset + traces[:, i], color=color) - ax.get_lines()[-1].set_label(layer_key) - - if dp.show_channel_ids: - ax.set_yticks(np.arange(n) * dp.vspacing) - channel_labels = np.array([str(chan_id) for chan_id in dp.channel_ids]) - ax.set_yticklabels(channel_labels) - else: - ax.get_yaxis().set_visible(False) - - ax.set_xlim(*dp.time_range) - ax.set_ylim(-dp.vspacing, dp.vspacing * n) - ax.get_xaxis().set_major_locator(MaxNLocator(prune="both")) - ax.set_xlabel("time (s)") - if dp.add_legend: - ax.legend(loc="upper right") - - elif dp.mode == "map": - assert len(dp.list_traces) == 1, 'plot_timeseries with mode="map" do not support multi recording' - assert len(dp.clims) == 1 - clim = list(dp.clims.values())[0] - extent = (dp.time_range[0], dp.time_range[1], min_y, max_y) - im = ax.imshow( - dp.list_traces[0].T, interpolation="nearest", origin="lower", aspect="auto", extent=extent, cmap=dp.cmap - ) - - im.set_clim(*clim) - - if dp.with_colorbar: - self.figure.colorbar(im, ax=ax) - - if dp.show_channel_ids: - ax.set_yticks(np.linspace(min_y, max_y, n) + (max_y - min_y) / n * 0.5) - channel_labels = np.array([str(chan_id) for chan_id in dp.channel_ids]) - ax.set_yticklabels(channel_labels) - else: - ax.get_yaxis().set_visible(False) - - -TimeseriesPlotter.register(TimeseriesWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_depths.py b/src/spikeinterface/widgets/matplotlib/unit_depths.py deleted file mode 100644 index aa16ff3578..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_depths.py +++ /dev/null @@ -1,22 +0,0 @@ -from ..base import to_attr -from ..unit_depths import UnitDepthsWidget -from .base_mpl import MplPlotter - - -class UnitDepthsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = 
self.update_backend_kwargs(**backend_kwargs) - self.make_mpl_figure(**backend_kwargs) - - ax = self.ax - size = dp.num_spikes / max(dp.num_spikes) * 120 - ax.scatter(dp.unit_amplitudes, dp.unit_depths, color=dp.colors, s=size) - - ax.set_aspect(3) - ax.set_xlabel("amplitude") - ax.set_ylabel("depth [um]") - ax.set_xlim(0, max(dp.unit_amplitudes) * 1.2) - - -UnitDepthsPlotter.register(UnitDepthsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_locations.py b/src/spikeinterface/widgets/matplotlib/unit_locations.py deleted file mode 100644 index 6f084c0aec..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_locations.py +++ /dev/null @@ -1,95 +0,0 @@ -from probeinterface import ProbeGroup -from probeinterface.plotting import plot_probe - -import numpy as np -from spikeinterface.core import waveform_extractor - -from ..base import to_attr -from ..unit_locations import UnitLocationsWidget -from .base_mpl import MplPlotter - -from matplotlib.patches import Ellipse -from matplotlib.lines import Line2D - - -class UnitLocationsPlotter(MplPlotter): - def __init__(self) -> None: - self.legend = None - - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - self.make_mpl_figure(**backend_kwargs) - - unit_locations = dp.unit_locations - - probegroup = ProbeGroup.from_dict(dp.probegroup_dict) - probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) - contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) - - for probe in probegroup.probes: - text_on_contact = None - if dp.with_channel_ids: - text_on_contact = dp.channel_ids - - poly_contact, poly_contour = plot_probe( - probe, - ax=self.ax, - contacts_colors="w", - contacts_kargs=contacts_kargs, - probe_shape_kwargs=probe_shape_kwargs, - text_on_contact=text_on_contact, - ) - poly_contact.set_zorder(2) - if poly_contour is not None: - poly_contour.set_zorder(1) - - self.ax.set_title("") - - # color = np.array([dp.unit_colors[unit_id] for unit_id in dp.unit_ids]) - width = height = 10 - ellipse_kwargs = dict(width=width, height=height, lw=2) - - if dp.plot_all_units: - unit_colors = {} - unit_ids = dp.all_unit_ids - for unit in dp.all_unit_ids: - if unit not in dp.unit_ids: - unit_colors[unit] = "gray" - else: - unit_colors[unit] = dp.unit_colors[unit] - else: - unit_ids = dp.unit_ids - unit_colors = dp.unit_colors - labels = dp.unit_ids - - patches = [ - Ellipse( - (unit_locations[unit]), - color=unit_colors[unit], - zorder=5 if unit in dp.unit_ids else 3, - alpha=0.9 if unit in dp.unit_ids else 0.5, - **ellipse_kwargs, - ) - for i, unit in enumerate(unit_ids) - ] - for p in patches: - self.ax.add_patch(p) - handles = [ - Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) - for unit in dp.unit_ids - ] - - if dp.plot_legend: - if self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - if dp.hide_axis: - self.ax.axis("off") - - -UnitLocationsPlotter.register(UnitLocationsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_summary.py b/src/spikeinterface/widgets/matplotlib/unit_summary.py deleted file mode 100644 index 5327afa25e..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_summary.py +++ /dev/null @@ -1,73 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..unit_summary import UnitSummaryWidget -from 
.base_mpl import MplPlotter - - -from .unit_locations import UnitLocationsPlotter -from .amplitudes import AmplitudesPlotter -from .unit_waveforms import UnitWaveformPlotter -from .unit_waveforms_density_map import UnitWaveformDensityMapPlotter - -from .autocorrelograms import AutoCorrelogramsPlotter - - -class UnitSummaryPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - - # force the figure without axes - if "figsize" not in backend_kwargs: - backend_kwargs["figsize"] = (18, 7) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["num_axes"] = 0 - backend_kwargs["ax"] = None - backend_kwargs["axes"] = None - - self.make_mpl_figure(**backend_kwargs) - - # and use custum grid spec - fig = self.figure - nrows = 2 - ncols = 3 - if dp.plot_data_acc is not None or dp.plot_data_amplitudes is not None: - ncols += 1 - if dp.plot_data_amplitudes is not None: - nrows += 1 - gs = fig.add_gridspec(nrows, ncols) - - if dp.plot_data_unit_locations is not None: - ax1 = fig.add_subplot(gs[:2, 0]) - UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) - x, y = dp.unit_location[0], dp.unit_location[1] - ax1.set_xlim(x - 80, x + 80) - ax1.set_ylim(y - 250, y + 250) - ax1.set_xticks([]) - ax1.set_xlabel(None) - ax1.set_ylabel(None) - - ax2 = fig.add_subplot(gs[:2, 1]) - UnitWaveformPlotter().do_plot(dp.plot_data_waveforms, ax=ax2) - ax2.set_title(None) - - ax3 = fig.add_subplot(gs[:2, 2]) - UnitWaveformDensityMapPlotter().do_plot(dp.plot_data_waveform_density, ax=ax3) - ax3.set_ylabel(None) - - if dp.plot_data_acc is not None: - ax4 = fig.add_subplot(gs[:2, 3]) - AutoCorrelogramsPlotter().do_plot(dp.plot_data_acc, ax=ax4) - ax4.set_title(None) - ax4.set_yticks([]) - - if dp.plot_data_amplitudes is not None: - ax5 = fig.add_subplot(gs[2, :3]) - ax6 = fig.add_subplot(gs[2, 3]) - axes = np.array([ax5, ax6]) - AmplitudesPlotter().do_plot(dp.plot_data_amplitudes, axes=axes) - - fig.suptitle(f"unit_id: {dp.unit_id}") - - -UnitSummaryPlotter.register(UnitSummaryWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_templates.py b/src/spikeinterface/widgets/matplotlib/unit_templates.py deleted file mode 100644 index c1ce085bf2..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_templates.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..unit_templates import UnitTemplatesWidget -from .unit_waveforms import UnitWaveformPlotter - - -class UnitTemplatesPlotter(UnitWaveformPlotter): - pass - - -UnitTemplatesPlotter.register(UnitTemplatesWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_waveforms.py b/src/spikeinterface/widgets/matplotlib/unit_waveforms.py deleted file mode 100644 index f499954918..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_waveforms.py +++ /dev/null @@ -1,95 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..unit_waveforms import UnitWaveformsWidget -from .base_mpl import MplPlotter - - -class UnitWaveformPlotter(MplPlotter): - def __init__(self) -> None: - self.legend = None - - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - if backend_kwargs["axes"] is not None: - assert len(backend_kwargs["axes"]) >= len(dp.unit_ids), "Provide as many 'axes' as neurons" - elif backend_kwargs["ax"] is not None: - assert dp.same_axis, "If 'same_axis' is not used, provide as many 'axes' as neurons" - else: - if dp.same_axis: - backend_kwargs["num_axes"] = 1 - 
backend_kwargs["ncols"] = None - else: - backend_kwargs["num_axes"] = len(dp.unit_ids) - backend_kwargs["ncols"] = min(dp.ncols, len(dp.unit_ids)) - - self.make_mpl_figure(**backend_kwargs) - - for i, unit_id in enumerate(dp.unit_ids): - if dp.same_axis: - ax = self.ax - else: - ax = self.axes.flatten()[i] - color = dp.unit_colors[unit_id] - - chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id] - xvectors_flat = dp.xvectors[:, chan_inds].T.flatten() - - # plot waveforms - if dp.plot_waveforms: - wfs = dp.wfs_by_ids[unit_id] - if dp.unit_selected_waveforms is not None: - wfs = wfs[dp.unit_selected_waveforms[unit_id]] - elif dp.max_spikes_per_unit is not None: - if len(wfs) > dp.max_spikes_per_unit: - random_idxs = np.random.permutation(len(wfs))[: dp.max_spikes_per_unit] - wfs = wfs[random_idxs] - wfs = wfs * dp.y_scale + dp.y_offset[None, :, chan_inds] - wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1).T - - if dp.x_offset_units: - # 0.7 is to match spacing in xvect - xvec = xvectors_flat + i * 0.7 * dp.delta_x - else: - xvec = xvectors_flat - - ax.plot(xvec, wfs_flat, lw=dp.lw_waveforms, alpha=dp.alpha_waveforms, color=color) - - if not dp.plot_templates: - ax.get_lines()[-1].set_label(f"{unit_id}") - - # plot template - if dp.plot_templates: - template = dp.templates[i, :, :][:, chan_inds] * dp.y_scale + dp.y_offset[:, chan_inds] - - if dp.x_offset_units: - # 0.7 is to match spacing in xvect - xvec = xvectors_flat + i * 0.7 * dp.delta_x - else: - xvec = xvectors_flat - - ax.plot( - xvec, template.T.flatten(), lw=dp.lw_templates, alpha=dp.alpha_templates, color=color, label=unit_id - ) - - template_label = dp.unit_ids[i] - if dp.set_title: - ax.set_title(f"template {template_label}") - - # plot channels - if dp.plot_channels: - # TODO enhance this - ax.scatter(dp.channel_locations[:, 0], dp.channel_locations[:, 1], color="k") - - if dp.same_axis and dp.plot_legend: - if self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - -UnitWaveformPlotter.register(UnitWaveformsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_waveforms_density_map.py b/src/spikeinterface/widgets/matplotlib/unit_waveforms_density_map.py deleted file mode 100644 index ff9c1ec91b..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_waveforms_density_map.py +++ /dev/null @@ -1,77 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..unit_waveforms_density_map import UnitWaveformDensityMapWidget -from .base_mpl import MplPlotter - - -class UnitWaveformDensityMapPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - if backend_kwargs["axes"] is not None or backend_kwargs["ax"] is not None: - self.make_mpl_figure(**backend_kwargs) - else: - if dp.same_axis: - num_axes = 1 - else: - num_axes = len(dp.unit_ids) - backend_kwargs["ncols"] = 1 - backend_kwargs["num_axes"] = num_axes - self.make_mpl_figure(**backend_kwargs) - - if dp.same_axis: - ax = self.ax - hist2d = dp.all_hist2d - im = ax.imshow( - hist2d.T, - interpolation="nearest", - origin="lower", - aspect="auto", - extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), - cmap="hot", - ) - else: - for unit_index, unit_id in enumerate(dp.unit_ids): - hist2d = dp.all_hist2d[unit_id] - ax = self.axes.flatten()[unit_index] - im = ax.imshow( - hist2d.T, - interpolation="nearest", - origin="lower", - 
aspect="auto", - extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), - cmap="hot", - ) - - for unit_index, unit_id in enumerate(dp.unit_ids): - if dp.same_axis: - ax = self.ax - else: - ax = self.axes.flatten()[unit_index] - color = dp.unit_colors[unit_id] - ax.plot(dp.templates_flat[unit_id], color=color, lw=1) - - # final cosmetics - for unit_index, unit_id in enumerate(dp.unit_ids): - if dp.same_axis: - ax = self.ax - if unit_index != 0: - continue - else: - ax = self.axes.flatten()[unit_index] - chan_inds = dp.channel_inds[unit_id] - for i, chan_ind in enumerate(chan_inds): - if i != 0: - ax.axvline(i * dp.template_width, color="w", lw=3) - channel_id = dp.channel_ids[chan_ind] - x = i * dp.template_width + dp.template_width // 2 - y = (dp.bin_max + dp.bin_min) / 2.0 - ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center") - - ax.set_xticks([]) - ax.set_ylabel(f"unit_id {unit_id}") - - -UnitWaveformDensityMapPlotter.register(UnitWaveformDensityMapWidget) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 8e77e4a0f0..9dc51f522e 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -1,8 +1,9 @@ import warnings import numpy as np -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors +from ..core.core_tools import check_json class MetricsBaseWidget(BaseWidget): @@ -29,8 +30,6 @@ class MetricsBaseWidget(BaseWidget): If True, metrics data are included in unit table, by default True """ - possible_backends = {} - def __init__( self, metrics, @@ -77,3 +76,191 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + metrics = dp.metrics + num_metrics = len(metrics.columns) + + if "figsize" not in backend_kwargs: + backend_kwargs["figsize"] = (2 * num_metrics, 2 * num_metrics) + + backend_kwargs["num_axes"] = num_metrics**2 + backend_kwargs["ncols"] = num_metrics + + all_unit_ids = metrics.index.values + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + assert self.axes.ndim == 2 + + if dp.unit_ids is None: + colors = ["gray"] * len(all_unit_ids) + else: + colors = [] + for unit in all_unit_ids: + color = "gray" if unit not in dp.unit_ids else dp.unit_colors[unit] + colors.append(color) + + self.patches = [] + for i, m1 in enumerate(metrics.columns): + for j, m2 in enumerate(metrics.columns): + if i == j: + self.axes[i, j].hist(metrics[m1], color="gray") + else: + p = self.axes[i, j].scatter(metrics[m1], metrics[m2], c=colors, s=3, marker="o") + self.patches.append(p) + if i == num_metrics - 1: + self.axes[i, j].set_xlabel(m2, fontsize=10) + if j == 0: + self.axes[i, j].set_ylabel(m1, fontsize=10) + self.axes[i, j].set_xticklabels([]) + self.axes[i, j].set_yticklabels([]) + self.axes[i, j].spines["top"].set_visible(False) + self.axes[i, j].spines["right"].set_visible(False) + + self.figure.subplots_adjust(top=0.8, wspace=0.2, hspace=0.2) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + + # backend_kwargs = 
self.update_backend_kwargs(**backend_kwargs) + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + if data_plot["unit_ids"] is None: + data_plot["unit_ids"] = [] + + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm + ) + + self.controller = unit_controller + + for w in self.controller.values(): + w.observe(self._update_ipywidget) + + self.widget = widgets.AppLayout( + center=self.figure.canvas, + left_sidebar=unit_widget, + pane_widths=ratios + [0], + ) + + # a first update + self._update_ipywidget(None) + + if backend_kwargs["display"]: + display(self.widget) + + def _update_ipywidget(self, change): + from matplotlib.lines import Line2D + + unit_ids = self.controller["unit_ids"].value + + unit_colors = self.data_plot["unit_colors"] + # matplotlib next_data_plot dict update at each call + all_units = list(unit_colors.keys()) + colors = [] + sizes = [] + for unit in all_units: + color = "gray" if unit not in unit_ids else unit_colors[unit] + size = 1 if unit not in unit_ids else 5 + colors.append(color) + sizes.append(size) + + # here we do a trick: we just update colors + if hasattr(self, "patches"): + for p in self.patches: + p.set_color(colors) + p.set_sizes(sizes) + else: + backend_kwargs = {} + backend_kwargs["figure"] = self.figure + self.plot_matplotlib(self.data_plot, **backend_kwargs) + + if len(unit_ids) > 0: + for l in self.figure.legends: + l.remove() + handles = [ + Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) + for unit in unit_ids + ] + labels = unit_ids + self.figure.legend( + handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + self.figure.canvas.draw() + self.figure.canvas.flush_events() + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + + metrics = dp.metrics + metric_names = list(metrics.columns) + + if dp.unit_ids is None: + unit_ids = metrics.index.values + else: + unit_ids = dp.unit_ids + unit_ids = make_serializable(unit_ids) + + metrics_sv = [] + for col in metric_names: + dtype = metrics.iloc[0][col].dtype + metric = vv.UnitMetricsGraphMetric(key=col, label=col, dtype=dtype.str) + metrics_sv.append(metric) + + units_m = [] + for unit_id in unit_ids: + values = check_json(metrics.loc[unit_id].to_dict()) + values_skip_nans = {} + for k, v in values.items(): + if np.isnan(v): + continue + values_skip_nans[k] = v + + units_m.append(vv.UnitMetricsGraphUnit(unit_id=unit_id, values=values_skip_nans)) + v_metrics = vv.UnitMetricsGraph(units=units_m, metrics=metrics_sv) + + if not dp.hide_unit_selector: + if dp.include_metrics_data: + # make a view of the sorting to add tmp properties + sorting_copy = dp.sorting.select_units(unit_ids=dp.sorting.unit_ids) + for col in metric_names: + if col not in sorting_copy.get_property_keys(): + sorting_copy.set_property(col, metrics[col].values) + # generate table with properties + v_units_table = generate_unit_table_view(sorting_copy, unit_properties=metric_names) + else: + v_units_table = generate_unit_table_view(dp.sorting) + + self.view = vv.Splitter( 
+ direction="horizontal", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_metrics) + ) + else: + self.view = v_metrics + + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 6a89050856..cb11bcce0c 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -1,11 +1,6 @@ import numpy as np -from warnings import warn -from .base import BaseWidget -from .utils import get_unit_colors - - -from ..core.template_tools import get_template_extremum_amplitude +from .base import BaseWidget, to_attr class MotionWidget(BaseWidget): @@ -14,44 +9,175 @@ class MotionWidget(BaseWidget): Parameters ---------- - recording : RecordingExtractor - The recording extractor object motion_info: dict The motion info return by correct_motion() or load back with load_motion_info() - depth_lim: tuple + recording : RecordingExtractor, optional + The recording extractor object (only used to get "real" times), default None + sampling_frequency : float, optional + The sampling frequency (needed if recording is None), default None + depth_lim : tuple The min and max depth to display, default None (min and max of the recording) - motion_lim: tuple + motion_lim : tuple The min and max motion to display, default None (min and max of the motion) - color_amplitude: bool + color_amplitude : bool If True, the color of the scatter points is the amplitude of the peaks, default False - scatter_decimate: int + scatter_decimate : int If > 1, the scatter points are decimated, default None - amplitude_cmap: str + amplitude_cmap : str The colormap to use for the amplitude, default 'inferno' + amplitude_clim : tuple + The min and max amplitude to display, default None (min and max of the amplitudes) + amplitude_alpha : float + The alpha of the scatter points, default 0.5 """ - possible_backends = {} - def __init__( self, - recording, motion_info, + recording=None, depth_lim=None, motion_lim=None, color_amplitude=False, scatter_decimate=None, amplitude_cmap="inferno", + amplitude_clim=None, + amplitude_alpha=1, backend=None, **backend_kwargs, ): + times = recording.get_times() if recording is not None else None + plot_data = dict( - rec=recording, + sampling_frequency=motion_info["parameters"]["sampling_frequency"], + times=times, depth_lim=depth_lim, motion_lim=motion_lim, color_amplitude=color_amplitude, scatter_decimate=scatter_decimate, amplitude_cmap=amplitude_cmap, + amplitude_clim=amplitude_clim, + amplitude_alpha=amplitude_alpha, **motion_info, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from matplotlib.colors import Normalize + + from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks + + dp = to_attr(data_plot) + + assert backend_kwargs["axes"] is None + assert backend_kwargs["ax"] is None + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + fig = self.figure + fig.clear() + + is_rigid = dp.motion.shape[1] == 1 + + gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.3) + ax0 = fig.add_subplot(gs[0, 0]) + ax1 = fig.add_subplot(gs[0, 1]) + ax2 = fig.add_subplot(gs[1, 0]) + if not is_rigid: + ax3 = fig.add_subplot(gs[1, 1]) + ax1.sharex(ax0) + ax1.sharey(ax0) + + if dp.motion_lim is None: + motion_lim = np.max(np.abs(dp.motion)) * 1.05 + else: + motion_lim 
= dp.motion_lim + + if dp.times is None: + temporal_bins_plot = dp.temporal_bins + x = dp.peaks["sample_index"] / dp.sampling_frequency + else: + # use real times and adjust temporal bins with t_start + temporal_bins_plot = dp.temporal_bins + dp.times[0] + x = dp.times[dp.peaks["sample_index"]] + + corrected_location = correct_motion_on_peaks( + dp.peaks, + dp.peak_locations, + dp.sampling_frequency, + dp.motion, + dp.temporal_bins, + dp.spatial_bins, + direction="y", + ) + + y = dp.peak_locations["y"] + y2 = corrected_location["y"] + if dp.scatter_decimate is not None: + x = x[:: dp.scatter_decimate] + y = y[:: dp.scatter_decimate] + y2 = y2[:: dp.scatter_decimate] + + if dp.color_amplitude: + amps = dp.peaks["amplitude"] + amps_abs = np.abs(amps) + q_95 = np.quantile(amps_abs, 0.95) + if dp.scatter_decimate is not None: + amps = amps[:: dp.scatter_decimate] + amps_abs = amps_abs[:: dp.scatter_decimate] + cmap = plt.get_cmap(dp.amplitude_cmap) + if dp.amplitude_clim is None: + amps = amps_abs + amps /= q_95 + c = cmap(amps) + else: + norm_function = Normalize(vmin=dp.amplitude_clim[0], vmax=dp.amplitude_clim[1], clip=True) + c = cmap(norm_function(amps)) + color_kwargs = dict( + color=None, + c=c, + alpha=dp.amplitude_alpha, + ) + else: + color_kwargs = dict(color="k", c=None, alpha=dp.amplitude_alpha) + + ax0.scatter(x, y, s=1, **color_kwargs) + if dp.depth_lim is not None: + ax0.set_ylim(*dp.depth_lim) + ax0.set_title("Peak depth") + ax0.set_xlabel("Times [s]") + ax0.set_ylabel("Depth [um]") + + ax1.scatter(x, y2, s=1, **color_kwargs) + ax1.set_xlabel("Times [s]") + ax1.set_ylabel("Depth [um]") + ax1.set_title("Corrected peak depth") + + ax2.plot(temporal_bins_plot, dp.motion, alpha=0.2, color="black") + ax2.plot(temporal_bins_plot, np.mean(dp.motion, axis=1), color="C0") + ax2.set_ylim(-motion_lim, motion_lim) + ax2.set_ylabel("Motion [um]") + ax2.set_title("Motion vectors") + axes = [ax0, ax1, ax2] + + if not is_rigid: + im = ax3.imshow( + dp.motion.T, + aspect="auto", + origin="lower", + extent=( + temporal_bins_plot[0], + temporal_bins_plot[-1], + dp.spatial_bins[0], + dp.spatial_bins[-1], + ), + ) + im.set_clim(-motion_lim, motion_lim) + cbar = fig.colorbar(im) + cbar.ax.set_xlabel("motion [um]") + ax3.set_xlabel("Times [s]") + ax3.set_ylabel("Depth [um]") + ax3.set_title("Motion vectors") + axes.append(ax3) + self.axes = np.array(axes) diff --git a/src/spikeinterface/widgets/quality_metrics.py b/src/spikeinterface/widgets/quality_metrics.py index f1c2ad6e23..4a6b46b72d 100644 --- a/src/spikeinterface/widgets/quality_metrics.py +++ b/src/spikeinterface/widgets/quality_metrics.py @@ -1,6 +1,5 @@ from .metrics import MetricsBaseWidget from ..core.waveform_extractor import WaveformExtractor -from ..qualitymetrics import compute_quality_metrics class QualityMetricsWidget(MetricsBaseWidget): @@ -23,8 +22,6 @@ class QualityMetricsWidget(MetricsBaseWidget): For sortingview backend, if True the unit selector is not displayed, default False """ - possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 8f50eb1dde..b9760205f9 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -1,6 +1,6 @@ import numpy as np -from .base import BaseWidget, define_widget_function_from_class +from .base import BaseWidget, to_attr from .amplitudes import AmplitudesWidget from .crosscorrelograms import 
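The amplitude coloring above is a per-point colormap lookup: absolute amplitudes are either rescaled by their 95th percentile (when amplitude_clim is None) or clipped to the given amplitude_clim via Normalize. A small self-contained sketch of that mapping with toy amplitudes standing in for peaks["amplitude"]:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

rng = np.random.default_rng(0)
amps_abs = np.abs(rng.normal(0.0, 50.0, size=10_000))   # toy absolute peak amplitudes
q_95 = np.quantile(amps_abs, 0.95)

scatter_decimate = 10
amps_abs = amps_abs[::scatter_decimate]                  # thin the points, like scatter_decimate

cmap = plt.get_cmap("inferno")
# amplitude_clim is None: rescale by the 95th percentile, as in the widget
c_auto = cmap(amps_abs / q_95)
# explicit amplitude_clim: clip into [vmin, vmax] before the colormap lookup
norm = Normalize(vmin=0.0, vmax=100.0, clip=True)
c_clim = cmap(norm(amps_abs))
# either RGBA array can be passed as c=... to ax.scatter, together with alpha=amplitude_alpha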
CrossCorrelogramsWidget @@ -9,7 +9,7 @@ from .unit_templates import UnitTemplatesWidget -from ..core import WaveformExtractor, ChannelSparsity +from ..core import WaveformExtractor class SortingSummaryWidget(BaseWidget): @@ -34,8 +34,6 @@ class SortingSummaryWidget(BaseWidget): (sortingview backend) """ - possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, @@ -55,28 +53,97 @@ def __init__( if unit_ids is None: unit_ids = sorting.get_unit_ids() - # use other widgets to generate data (except for similarity) - template_plot_data = UnitTemplatesWidget( - we, unit_ids=unit_ids, sparsity=sparsity, hide_unit_selector=True - ).plot_data - ccg_plot_data = CrossCorrelogramsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data - amps_plot_data = AmplitudesWidget( - we, unit_ids=unit_ids, max_spikes_per_unit=max_amplitudes_per_unit, hide_unit_selector=True - ).plot_data - locs_plot_data = UnitLocationsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data - sim_plot_data = TemplateSimilarityWidget(we, unit_ids=unit_ids).plot_data - plot_data = dict( waveform_extractor=waveform_extractor, unit_ids=unit_ids, - templates=template_plot_data, - correlograms=ccg_plot_data, - amplitudes=amps_plot_data, - similarity=sim_plot_data, - unit_locations=locs_plot_data, + sparsity=sparsity, unit_table_properties=unit_table_properties, curation=curation, label_choices=label_choices, + max_amplitudes_per_unit=max_amplitudes_per_unit, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + we = dp.waveform_extractor + unit_ids = dp.unit_ids + sparsity = dp.sparsity + + unit_ids = make_serializable(dp.unit_ids) + + v_spike_amplitudes = AmplitudesWidget( + we, + unit_ids=unit_ids, + max_spikes_per_unit=dp.max_amplitudes_per_unit, + hide_unit_selector=True, + generate_url=False, + display=False, + backend="sortingview", + ).view + v_average_waveforms = UnitTemplatesWidget( + we, + unit_ids=unit_ids, + sparsity=sparsity, + hide_unit_selector=True, + generate_url=False, + display=False, + backend="sortingview", + ).view + v_cross_correlograms = CrossCorrelogramsWidget( + we, unit_ids=unit_ids, hide_unit_selector=True, generate_url=False, display=False, backend="sortingview" + ).view + + v_unit_locations = UnitLocationsWidget( + we, unit_ids=unit_ids, hide_unit_selector=True, generate_url=False, display=False, backend="sortingview" + ).view + + w = TemplateSimilarityWidget( + we, unit_ids=unit_ids, immediate_plot=False, generate_url=False, display=False, backend="sortingview" + ) + similarity = w.data_plot["similarity"] + + # similarity + similarity_scores = [] + for i1, u1 in enumerate(unit_ids): + for i2, u2 in enumerate(unit_ids): + similarity_scores.append( + vv.UnitSimilarityScore(unit_id1=u1, unit_id2=u2, similarity=similarity[i1, i2].astype("float32")) + ) + + # unit ids + v_units_table = generate_unit_table_view( + dp.waveform_extractor.sorting, dp.unit_table_properties, similarity_scores=similarity_scores + ) + + if dp.curation: + v_curation = vv.SortingCuration2(label_choices=dp.label_choices) + v1 = vv.Splitter(direction="vertical", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_curation)) + else: + v1 = v_units_table + v2 = vv.Splitter( + direction="horizontal", + item1=vv.LayoutItem(v_unit_locations, 
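The nested loop above flattens the dense similarity matrix taken from TemplateSimilarityWidget into the pairwise scores that the sortingview units table expects. A minimal sketch with a toy 3x3 matrix (requires sortingview to be installed):

import numpy as np
import sortingview.views as vv

unit_ids = [0, 1, 2]                         # toy ids; the widget uses the serialized dp.unit_ids
similarity = np.full((3, 3), 0.2, dtype="float32")
np.fill_diagonal(similarity, 1.0)

# One UnitSimilarityScore per ordered pair, exactly as in plot_sortingview above.
similarity_scores = [
    vv.UnitSimilarityScore(unit_id1=u1, unit_id2=u2, similarity=similarity[i1, i2].astype("float32"))
    for i1, u1 in enumerate(unit_ids)
    for i2, u2 in enumerate(unit_ids)
]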
stretch=0.2), + item2=vv.LayoutItem( + vv.Splitter( + direction="horizontal", + item1=vv.LayoutItem(v_average_waveforms), + item2=vv.LayoutItem( + vv.Splitter( + direction="vertical", + item1=vv.LayoutItem(v_spike_amplitudes), + item2=vv.LayoutItem(v_cross_correlograms), + ) + ), + ) + ), + ) + + # assemble layout + self.view = vv.Splitter(direction="horizontal", item1=vv.LayoutItem(v1), item2=vv.LayoutItem(v2)) + + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/sortingview/__init__.py b/src/spikeinterface/widgets/sortingview/__init__.py deleted file mode 100644 index 5663f95078..0000000000 --- a/src/spikeinterface/widgets/sortingview/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .amplitudes import AmplitudesPlotter -from .autocorrelograms import AutoCorrelogramsPlotter -from .crosscorrelograms import CrossCorrelogramsPlotter -from .quality_metrics import QualityMetricsPlotter -from .sorting_summary import SortingSummaryPlotter -from .spike_locations import SortingviewPlotter -from .template_metrics import TemplateMetricsPlotter -from .template_similarity import TemplateSimilarityPlotter -from .timeseries import TimeseriesPlotter -from .unit_locations import UnitLocationsPlotter -from .unit_templates import UnitTemplatesPlotter diff --git a/src/spikeinterface/widgets/sortingview/amplitudes.py b/src/spikeinterface/widgets/sortingview/amplitudes.py deleted file mode 100644 index 8676ccd994..0000000000 --- a/src/spikeinterface/widgets/sortingview/amplitudes.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..amplitudes import AmplitudesWidget -from .base_sortingview import SortingviewPlotter - - -class AmplitudesPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Amplitudes" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - unit_ids = self.make_serializable(dp.unit_ids) - - sa_items = [ - vv.SpikeAmplitudesItem( - unit_id=u, - spike_times_sec=dp.spiketrains[u].astype("float32"), - spike_amplitudes=dp.amplitudes[u].astype("float32"), - ) - for u in unit_ids - ] - - v_spike_amplitudes = vv.SpikeAmplitudes( - start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector - ) - - self.handle_display_and_url(v_spike_amplitudes, **backend_kwargs) - return v_spike_amplitudes - - -AmplitudesPlotter.register(AmplitudesWidget) diff --git a/src/spikeinterface/widgets/sortingview/autocorrelograms.py b/src/spikeinterface/widgets/sortingview/autocorrelograms.py deleted file mode 100644 index 345f8c2bdf..0000000000 --- a/src/spikeinterface/widgets/sortingview/autocorrelograms.py +++ /dev/null @@ -1,34 +0,0 @@ -from ..base import to_attr -from ..autocorrelograms import AutoCorrelogramsWidget -from .base_sortingview import SortingviewPlotter - - -class AutoCorrelogramsPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Auto Correlograms" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - unit_ids = self.make_serializable(dp.unit_ids) - - ac_items = [] - for i in range(len(unit_ids)): - for j in range(i, len(unit_ids)): - if i == j: - ac_items.append( - vv.AutocorrelogramItem( - unit_id=unit_ids[i], - bin_edges_sec=(dp.bins / 1000.0).astype("float32"), - bin_counts=dp.correlograms[i, 
j].astype("int32"), - ) - ) - - v_autocorrelograms = vv.Autocorrelograms(autocorrelograms=ac_items) - - self.handle_display_and_url(v_autocorrelograms, **backend_kwargs) - return v_autocorrelograms - - -AutoCorrelogramsPlotter.register(AutoCorrelogramsWidget) diff --git a/src/spikeinterface/widgets/sortingview/base_sortingview.py b/src/spikeinterface/widgets/sortingview/base_sortingview.py deleted file mode 100644 index c42da0fba3..0000000000 --- a/src/spikeinterface/widgets/sortingview/base_sortingview.py +++ /dev/null @@ -1,103 +0,0 @@ -import numpy as np - -from ...core.core_tools import check_json -from spikeinterface.widgets.base import BackendPlotter - - -class SortingviewPlotter(BackendPlotter): - backend = "sortingview" - backend_kwargs_desc = { - "generate_url": "If True, the figurl URL is generated and printed. Default True", - "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.", - "figlabel": "The figurl figure label. Default None", - "height": "The height of the sortingview View in jupyter. Default None", - } - default_backend_kwargs = {"generate_url": True, "display": True, "figlabel": None, "height": None} - - def __init__(self): - self.view = None - self.url = None - - def make_serializable(*args): - dict_to_serialize = {int(i): a for i, a in enumerate(args[1:])} - serializable_dict = check_json(dict_to_serialize) - returns = () - for i in range(len(args) - 1): - returns += (serializable_dict[str(i)],) - if len(returns) == 1: - returns = returns[0] - return returns - - @staticmethod - def is_notebook() -> bool: - try: - shell = get_ipython().__class__.__name__ - if shell == "ZMQInteractiveShell": - return True # Jupyter notebook or qtconsole - elif shell == "TerminalInteractiveShell": - return False # Terminal running IPython - else: - return False # Other type (?) - except NameError: - return False - - def handle_display_and_url(self, view, **backend_kwargs): - self.set_view(view) - if self.is_notebook() and backend_kwargs["display"]: - display(self.view.jupyter(height=backend_kwargs["height"])) - if backend_kwargs["generate_url"]: - figlabel = backend_kwargs.get("figlabel") - if figlabel is None: - figlabel = self.default_label - url = view.url(label=figlabel) - self.set_url(url) - print(url) - - # make view and url accessible by the plotter - def set_view(self, view): - self.view = view - - def set_url(self, url): - self.url = url - - -def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=None): - import sortingview.views as vv - - if unit_properties is None: - ut_columns = [] - ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids] - else: - ut_columns = [] - ut_rows = [] - values = {} - valid_unit_properties = [] - for prop_name in unit_properties: - property_values = sorting.get_property(prop_name) - # make dtype available - val0 = np.array(property_values[0]) - if val0.dtype.kind in ("i", "u"): - dtype = "int" - elif val0.dtype.kind in ("U", "S"): - dtype = "str" - elif val0.dtype.kind == "f": - dtype = "float" - elif val0.dtype.kind == "b": - dtype = "bool" - else: - print(f"Unsupported dtype {val0.dtype} for property {prop_name}. 
Skipping") - continue - ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype)) - valid_unit_properties.append(prop_name) - - for ui, unit in enumerate(sorting.unit_ids): - for prop_name in valid_unit_properties: - property_values = sorting.get_property(prop_name) - val0 = property_values[0] - if np.isnan(property_values[ui]): - continue - values[prop_name] = property_values[ui] - ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values))) - - v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores) - return v_units_table diff --git a/src/spikeinterface/widgets/sortingview/crosscorrelograms.py b/src/spikeinterface/widgets/sortingview/crosscorrelograms.py deleted file mode 100644 index ec9c7bb16c..0000000000 --- a/src/spikeinterface/widgets/sortingview/crosscorrelograms.py +++ /dev/null @@ -1,37 +0,0 @@ -from ..base import to_attr -from ..crosscorrelograms import CrossCorrelogramsWidget -from .base_sortingview import SortingviewPlotter - - -class CrossCorrelogramsPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Cross Correlograms" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - unit_ids = self.make_serializable(dp.unit_ids) - - cc_items = [] - for i in range(len(unit_ids)): - for j in range(i, len(unit_ids)): - cc_items.append( - vv.CrossCorrelogramItem( - unit_id1=unit_ids[i], - unit_id2=unit_ids[j], - bin_edges_sec=(dp.bins / 1000.0).astype("float32"), - bin_counts=dp.correlograms[i, j].astype("int32"), - ) - ) - - v_cross_correlograms = vv.CrossCorrelograms( - cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector - ) - - self.handle_display_and_url(v_cross_correlograms, **backend_kwargs) - return v_cross_correlograms - - -CrossCorrelogramsPlotter.register(CrossCorrelogramsWidget) diff --git a/src/spikeinterface/widgets/sortingview/metrics.py b/src/spikeinterface/widgets/sortingview/metrics.py deleted file mode 100644 index d46256739e..0000000000 --- a/src/spikeinterface/widgets/sortingview/metrics.py +++ /dev/null @@ -1,61 +0,0 @@ -import numpy as np - -from ...core.core_tools import check_json -from ..base import to_attr -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - - -class MetricsPlotter(SortingviewPlotter): - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - metrics = dp.metrics - metric_names = list(metrics.columns) - - if dp.unit_ids is None: - unit_ids = metrics.index.values - else: - unit_ids = dp.unit_ids - unit_ids = self.make_serializable(unit_ids) - - metrics_sv = [] - for col in metric_names: - dtype = metrics.iloc[0][col].dtype - metric = vv.UnitMetricsGraphMetric(key=col, label=col, dtype=dtype.str) - metrics_sv.append(metric) - - units_m = [] - for unit_id in unit_ids: - values = check_json(metrics.loc[unit_id].to_dict()) - values_skip_nans = {} - for k, v in values.items(): - if np.isnan(v): - continue - values_skip_nans[k] = v - - units_m.append(vv.UnitMetricsGraphUnit(unit_id=unit_id, values=values_skip_nans)) - v_metrics = vv.UnitMetricsGraph(units=units_m, metrics=metrics_sv) - - if not dp.hide_unit_selector: - if dp.include_metrics_data: - # make a view of the sorting to add tmp properties - sorting_copy = dp.sorting.select_units(unit_ids=dp.sorting.unit_ids) - 
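The branch above picks the UnitsTable column dtype from the numpy kind of the first property value (presumably the generate_unit_table_view now living in utils_sortingview keeps the same logic). A compact sketch of that mapping, with the same four supported kinds and None for anything the table would skip:

import numpy as np

def column_dtype_for_units_table(values):
    # Same branches as the deleted generate_unit_table_view: the numpy kind of
    # the first property value decides the UnitsTable column dtype, and
    # unsupported kinds are skipped (None here).
    kind = np.asarray(values[0]).dtype.kind
    mapping = {"i": "int", "u": "int", "U": "str", "S": "str", "f": "float", "b": "bool"}
    return mapping.get(kind)

print(column_dtype_for_units_table(np.array([3, 1, 2])))        # int
print(column_dtype_for_units_table(np.array(["good", "mua"])))  # str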
for col in metric_names: - if col not in sorting_copy.get_property_keys(): - sorting_copy.set_property(col, metrics[col].values) - # generate table with properties - v_units_table = generate_unit_table_view(sorting_copy, unit_properties=metric_names) - else: - v_units_table = generate_unit_table_view(dp.sorting) - - view = vv.Splitter( - direction="horizontal", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_metrics) - ) - else: - view = v_metrics - - self.handle_display_and_url(view, **backend_kwargs) - return view diff --git a/src/spikeinterface/widgets/sortingview/quality_metrics.py b/src/spikeinterface/widgets/sortingview/quality_metrics.py deleted file mode 100644 index 379ba158a5..0000000000 --- a/src/spikeinterface/widgets/sortingview/quality_metrics.py +++ /dev/null @@ -1,11 +0,0 @@ -from .metrics import MetricsPlotter -from ..quality_metrics import QualityMetricsWidget - - -class QualityMetricsPlotter(MetricsPlotter): - default_label = "SpikeInterface - Quality Metrics" - - pass - - -QualityMetricsPlotter.register(QualityMetricsWidget) diff --git a/src/spikeinterface/widgets/sortingview/sorting_summary.py b/src/spikeinterface/widgets/sortingview/sorting_summary.py deleted file mode 100644 index bb248e1691..0000000000 --- a/src/spikeinterface/widgets/sortingview/sorting_summary.py +++ /dev/null @@ -1,86 +0,0 @@ -from ..base import to_attr -from ..sorting_summary import SortingSummaryWidget -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - -from .amplitudes import AmplitudesPlotter -from .autocorrelograms import AutoCorrelogramsPlotter -from .crosscorrelograms import CrossCorrelogramsPlotter -from .template_similarity import TemplateSimilarityPlotter -from .unit_locations import UnitLocationsPlotter -from .unit_templates import UnitTemplatesPlotter - - -class SortingSummaryPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Sorting Summary" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - dp = to_attr(data_plot) - - unit_ids = self.make_serializable(dp.unit_ids) - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - amplitudes_plotter = AmplitudesPlotter() - v_spike_amplitudes = amplitudes_plotter.do_plot( - dp.amplitudes, generate_url=False, display=False, backend="sortingview" - ) - template_plotter = UnitTemplatesPlotter() - v_average_waveforms = template_plotter.do_plot( - dp.templates, generate_url=False, display=False, backend="sortingview" - ) - xcorrelograms_plotter = CrossCorrelogramsPlotter() - v_cross_correlograms = xcorrelograms_plotter.do_plot( - dp.correlograms, generate_url=False, display=False, backend="sortingview" - ) - unitlocation_plotter = UnitLocationsPlotter() - v_unit_locations = unitlocation_plotter.do_plot( - dp.unit_locations, generate_url=False, display=False, backend="sortingview" - ) - # similarity - similarity_scores = [] - for i1, u1 in enumerate(unit_ids): - for i2, u2 in enumerate(unit_ids): - similarity_scores.append( - vv.UnitSimilarityScore( - unit_id1=u1, unit_id2=u2, similarity=dp.similarity["similarity"][i1, i2].astype("float32") - ) - ) - - # unit ids - v_units_table = generate_unit_table_view( - dp.waveform_extractor.sorting, dp.unit_table_properties, similarity_scores=similarity_scores - ) - - if dp.curation: - v_curation = vv.SortingCuration2(label_choices=dp.label_choices) - v1 = vv.Splitter(direction="vertical", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_curation)) - else: - v1 = v_units_table - v2 = vv.Splitter( - 
direction="horizontal", - item1=vv.LayoutItem(v_unit_locations, stretch=0.2), - item2=vv.LayoutItem( - vv.Splitter( - direction="horizontal", - item1=vv.LayoutItem(v_average_waveforms), - item2=vv.LayoutItem( - vv.Splitter( - direction="vertical", - item1=vv.LayoutItem(v_spike_amplitudes), - item2=vv.LayoutItem(v_cross_correlograms), - ) - ), - ) - ), - ) - - # assemble layout - v_summary = vv.Splitter(direction="horizontal", item1=vv.LayoutItem(v1), item2=vv.LayoutItem(v2)) - - self.handle_display_and_url(v_summary, **backend_kwargs) - return v_summary - - -SortingSummaryPlotter.register(SortingSummaryWidget) diff --git a/src/spikeinterface/widgets/sortingview/spike_locations.py b/src/spikeinterface/widgets/sortingview/spike_locations.py deleted file mode 100644 index 747c3df4e7..0000000000 --- a/src/spikeinterface/widgets/sortingview/spike_locations.py +++ /dev/null @@ -1,64 +0,0 @@ -from ..base import to_attr -from ..spike_locations import SpikeLocationsWidget, estimate_axis_lims -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - - -class SpikeLocationsPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Spike Locations" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - spike_locations = dp.spike_locations - - # ensure serializable for sortingview - unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) - - locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} - xlims, ylims = estimate_axis_lims(spike_locations) - - unit_items = [] - for unit in unit_ids: - spike_times_sec = dp.sorting.get_unit_spike_train( - unit_id=unit, segment_index=dp.segment_index, return_times=True - ) - unit_items.append( - vv.SpikeLocationsItem( - unit_id=unit, - spike_times_sec=spike_times_sec.astype("float32"), - x_locations=spike_locations[unit]["x"].astype("float32"), - y_locations=spike_locations[unit]["y"].astype("float32"), - ) - ) - - v_spike_locations = vv.SpikeLocations( - units=unit_items, - hide_unit_selector=dp.hide_unit_selector, - x_range=xlims.astype("float32"), - y_range=ylims.astype("float32"), - channel_locations=locations, - disable_auto_rotate=True, - ) - - if not dp.hide_unit_selector: - v_units_table = generate_unit_table_view(dp.sorting) - - view = vv.Box( - direction="horizontal", - items=[ - vv.LayoutItem(v_units_table, max_size=150), - vv.LayoutItem(v_spike_locations), - ], - ) - else: - view = v_spike_locations - - self.set_view(view) - - self.handle_display_and_url(view, **backend_kwargs) - return view - - -SpikeLocationsPlotter.register(SpikeLocationsWidget) diff --git a/src/spikeinterface/widgets/sortingview/template_metrics.py b/src/spikeinterface/widgets/sortingview/template_metrics.py deleted file mode 100644 index 204bb8f377..0000000000 --- a/src/spikeinterface/widgets/sortingview/template_metrics.py +++ /dev/null @@ -1,11 +0,0 @@ -from .metrics import MetricsPlotter -from ..template_metrics import TemplateMetricsWidget - - -class TemplateMetricsPlotter(MetricsPlotter): - default_label = "SpikeInterface - Template Metrics" - - pass - - -TemplateMetricsPlotter.register(TemplateMetricsWidget) diff --git a/src/spikeinterface/widgets/sortingview/template_similarity.py b/src/spikeinterface/widgets/sortingview/template_similarity.py deleted file mode 100644 index e35b8c2e34..0000000000 --- 
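All of the Plotter classes removed above used to register themselves against a widget (XxxPlotter.register(XxxWidget)); after this patch each widget defines its own plot_matplotlib / plot_ipywidgets / plot_sortingview methods and reports them through get_possible_backends(). A rough, hypothetical sketch of that dispatch idea, not the actual BaseWidget in widgets/base.py, which may differ in detail:

class SketchWidget:
    # Illustrates dispatching on plot_<backend> methods instead of separately
    # registered Plotter classes.
    def __init__(self, plot_data, backend="matplotlib", **backend_kwargs):
        self.data_plot = plot_data
        plot_method = getattr(self, f"plot_{backend}", None)
        if plot_method is None:
            raise ValueError(f"Backend {backend} not supported, use one of {self.get_possible_backends()}")
        plot_method(self.data_plot, **backend_kwargs)

    @classmethod
    def get_possible_backends(cls):
        # Any plot_<backend> method defined on the class counts as a backend.
        return [name[len("plot_"):] for name in dir(cls) if name.startswith("plot_")]

    def plot_matplotlib(self, data_plot, **backend_kwargs):
        print("plotting", data_plot, "with matplotlib")

SketchWidget({"unit_ids": [0, 1]})   # -> plotting {'unit_ids': [0, 1]} with matplotlib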
a/src/spikeinterface/widgets/sortingview/template_similarity.py +++ /dev/null @@ -1,32 +0,0 @@ -from ..base import to_attr -from ..template_similarity import TemplateSimilarityWidget -from .base_sortingview import SortingviewPlotter - - -class TemplateSimilarityPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Template Similarity" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - # ensure serializable for sortingview - unit_ids = self.make_serializable(dp.unit_ids) - - # similarity - ss_items = [] - for i1, u1 in enumerate(unit_ids): - for i2, u2 in enumerate(unit_ids): - ss_items.append( - vv.UnitSimilarityScore(unit_id1=u1, unit_id2=u2, similarity=dp.similarity[i1, i2].astype("float32")) - ) - - view = vv.UnitSimilarityMatrix(unit_ids=list(unit_ids), similarity_scores=ss_items) - - self.handle_display_and_url(view, **backend_kwargs) - return view - - -TemplateSimilarityPlotter.register(TemplateSimilarityWidget) diff --git a/src/spikeinterface/widgets/sortingview/timeseries.py b/src/spikeinterface/widgets/sortingview/timeseries.py deleted file mode 100644 index eec0e920e4..0000000000 --- a/src/spikeinterface/widgets/sortingview/timeseries.py +++ /dev/null @@ -1,54 +0,0 @@ -import numpy as np -import warnings - -from ..base import to_attr -from ..timeseries import TimeseriesWidget -from ..utils import array_to_image -from .base_sortingview import SortingviewPlotter - - -class TimeseriesPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Timeseries" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - try: - import pyvips - except ImportError: - raise ImportError("To use the timeseries in sorting view you need the pyvips package.") - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - assert dp.mode == "map", 'sortingview plot_timeseries is only mode="map"' - - if not dp.order_channel_by_depth: - warnings.warn( - "It is recommended to set 'order_channel_by_depth' to True " "when using the sortingview backend" - ) - - tiled_layers = [] - for layer_key, traces in zip(dp.layer_keys, dp.list_traces): - img = array_to_image( - traces, - clim=dp.clims[layer_key], - num_timepoints_per_row=dp.num_timepoints_per_row, - colormap=dp.cmap, - scalebar=True, - sampling_frequency=dp.recordings[layer_key].get_sampling_frequency(), - ) - - tiled_layers.append(vv.TiledImageLayer(layer_key, img)) - - view_ts = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers) - - self.set_view(view_ts) - - # timeseries currently doesn't display on the jupyter backend - backend_kwargs["display"] = False - self.handle_display_and_url(view_ts, **backend_kwargs) - return view_ts - - -TimeseriesPlotter.register(TimeseriesWidget) diff --git a/src/spikeinterface/widgets/sortingview/unit_locations.py b/src/spikeinterface/widgets/sortingview/unit_locations.py deleted file mode 100644 index 368b45321f..0000000000 --- a/src/spikeinterface/widgets/sortingview/unit_locations.py +++ /dev/null @@ -1,44 +0,0 @@ -from ..base import to_attr -from ..unit_locations import UnitLocationsWidget -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - - -class UnitLocationsPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Unit Locations" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = 
self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - # ensure serializable for sortingview - unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) - - locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} - - unit_items = [] - for unit_id in unit_ids: - unit_items.append( - vv.UnitLocationsItem( - unit_id=unit_id, x=float(dp.unit_locations[unit_id][0]), y=float(dp.unit_locations[unit_id][1]) - ) - ) - - v_unit_locations = vv.UnitLocations(units=unit_items, channel_locations=locations, disable_auto_rotate=True) - - if not dp.hide_unit_selector: - v_units_table = generate_unit_table_view(dp.sorting) - - view = vv.Box( - direction="horizontal", - items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_unit_locations)], - ) - else: - view = v_unit_locations - - self.handle_display_and_url(view, **backend_kwargs) - return view - - -UnitLocationsPlotter.register(UnitLocationsWidget) diff --git a/src/spikeinterface/widgets/sortingview/unit_templates.py b/src/spikeinterface/widgets/sortingview/unit_templates.py deleted file mode 100644 index 37595740fd..0000000000 --- a/src/spikeinterface/widgets/sortingview/unit_templates.py +++ /dev/null @@ -1,54 +0,0 @@ -from ..base import to_attr -from ..unit_templates import UnitTemplatesWidget -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - - -class UnitTemplatesPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Unit Templates" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - # ensure serializable for sortingview - unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids - unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices - - unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) - - templates_dict = {} - for u_i, unit in enumerate(unit_ids): - templates_dict[unit] = {} - templates_dict[unit]["mean"] = dp.templates[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]] - templates_dict[unit]["std"] = dp.template_stds[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]] - - aw_items = [ - vv.AverageWaveformItem( - unit_id=u, - channel_ids=list(unit_id_to_channel_ids[u]), - waveform=t["mean"].astype("float32"), - waveform_std_dev=t["std"].astype("float32"), - ) - for u, t in templates_dict.items() - ] - - locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} - v_average_waveforms = vv.AverageWaveforms(average_waveforms=aw_items, channel_locations=locations) - - if not dp.hide_unit_selector: - v_units_table = generate_unit_table_view(dp.waveform_extractor.sorting) - - view = vv.Box( - direction="horizontal", - items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_average_waveforms)], - ) - else: - view = v_average_waveforms - - self.handle_display_and_url(view, **backend_kwargs) - return view - - -UnitTemplatesPlotter.register(UnitTemplatesWidget) diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index da5ad5b08c..9771b2c0e9 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -1,7 +1,6 @@ import numpy as np -from typing import Union -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors from 
..core.waveform_extractor import WaveformExtractor @@ -36,7 +35,7 @@ class SpikeLocationsWidget(BaseWidget): If True, the axis is set to off. Default False (matplotlib backend) """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -105,6 +104,210 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from matplotlib.lines import Line2D + + from probeinterface import ProbeGroup + from probeinterface.plotting import plot_probe + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + spike_locations = dp.spike_locations + + probegroup = ProbeGroup.from_dict(dp.probegroup_dict) + probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) + contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) + + for probe in probegroup.probes: + text_on_contact = None + if dp.with_channel_ids: + text_on_contact = dp.channel_ids + + poly_contact, poly_contour = plot_probe( + probe, + ax=self.ax, + contacts_colors="w", + contacts_kargs=contacts_kargs, + probe_shape_kwargs=probe_shape_kwargs, + text_on_contact=text_on_contact, + ) + poly_contact.set_zorder(2) + if poly_contour is not None: + poly_contour.set_zorder(1) + + self.ax.set_title("") + + if dp.plot_all_units: + unit_colors = {} + unit_ids = dp.all_unit_ids + for unit in dp.all_unit_ids: + if unit not in dp.unit_ids: + unit_colors[unit] = "gray" + else: + unit_colors[unit] = dp.unit_colors[unit] + else: + unit_ids = dp.unit_ids + unit_colors = dp.unit_colors + labels = dp.unit_ids + + for i, unit in enumerate(unit_ids): + locs = spike_locations[unit] + + zorder = 5 if unit in dp.unit_ids else 3 + self.ax.scatter(locs["x"], locs["y"], s=2, alpha=0.3, color=unit_colors[unit], zorder=zorder) + + handles = [ + Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) + for unit in dp.unit_ids + ] + if dp.plot_legend: + if hasattr(self, "legend") and self.legend is not None: + self.legend.remove() + self.legend = self.figure.legend( + handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + # set proper axis limits + xlims, ylims = estimate_axis_lims(spike_locations) + + ax_xlims = list(self.ax.get_xlim()) + ax_ylims = list(self.ax.get_ylim()) + + ax_xlims[0] = xlims[0] if xlims[0] < ax_xlims[0] else ax_xlims[0] + ax_xlims[1] = xlims[1] if xlims[1] > ax_xlims[1] else ax_xlims[1] + ax_ylims[0] = ylims[0] if ylims[0] < ax_ylims[0] else ax_ylims[0] + ax_ylims[1] = ylims[1] if ylims[1] > ax_ylims[1] else ax_ylims[1] + + self.ax.set_xlim(ax_xlims) + self.ax.set_ylim(ax_ylims) + if dp.hide_axis: + self.ax.axis("off") + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + + data_plot["unit_ids"] = data_plot["unit_ids"][:1] + + unit_widget, unit_controller = 
make_unit_controller( + data_plot["unit_ids"], + list(data_plot["unit_colors"].keys()), + ratios[0] * width_cm, + height_cm, + ) + + self.controller = unit_controller + + for w in self.controller.values(): + w.observe(self._update_ipywidget) + + self.widget = widgets.AppLayout( + center=fig.canvas, + left_sidebar=unit_widget, + pane_widths=ratios + [0], + ) + + # a first update + self._update_ipywidget(None) + + if backend_kwargs["display"]: + display(self.widget) + + def _update_ipywidget(self, change): + self.ax.clear() + + unit_ids = self.controller["unit_ids"].value + + # matplotlib next_data_plot dict update at each call + data_plot = self.next_data_plot + data_plot["unit_ids"] = unit_ids + data_plot["plot_all_units"] = True + data_plot["plot_legend"] = True + data_plot["hide_axis"] = True + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) + fig = self.ax.get_figure() + fig.canvas.draw() + fig.canvas.flush_events() + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + spike_locations = dp.spike_locations + + # ensure serializable for sortingview + unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) + + locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} + xlims, ylims = estimate_axis_lims(spike_locations) + + unit_items = [] + for unit in unit_ids: + spike_times_sec = dp.sorting.get_unit_spike_train( + unit_id=unit, segment_index=dp.segment_index, return_times=True + ) + unit_items.append( + vv.SpikeLocationsItem( + unit_id=unit, + spike_times_sec=spike_times_sec.astype("float32"), + x_locations=spike_locations[unit]["x"].astype("float32"), + y_locations=spike_locations[unit]["y"].astype("float32"), + ) + ) + + v_spike_locations = vv.SpikeLocations( + units=unit_items, + hide_unit_selector=dp.hide_unit_selector, + x_range=xlims.astype("float32"), + y_range=ylims.astype("float32"), + channel_locations=locations, + disable_auto_rotate=True, + ) + + if not dp.hide_unit_selector: + v_units_table = generate_unit_table_view(dp.sorting) + + self.view = vv.Box( + direction="horizontal", + items=[ + vv.LayoutItem(v_units_table, max_size=150), + vv.LayoutItem(v_spike_locations), + ], + ) + else: + self.view = v_spike_locations + + self.url = handle_display_and_url(self, self.view, **backend_kwargs) + def estimate_axis_lims(spike_locations, quantile=0.02): # set proper axis limits diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index b50896df4d..e7bcff0832 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -1,8 +1,8 @@ import numpy as np -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors -from .timeseries import TimeseriesWidget +from .traces import TracesWidget from ..core import ChannelSparsity from ..core.template_tools import get_template_extremum_channel from ..core.waveform_extractor import WaveformExtractor @@ -60,8 +60,6 @@ class SpikesOnTracesWidget(BaseWidget): For 'map' mode and sortingview backend, seconds to render in each row, default 0.2 """ - possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, @@ -86,29 +84,8 @@ def 
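The ipywidgets backends above all follow the same pattern: make_unit_controller() (from utils_ipywidgets) builds a unit selector and a controller dict, the widget stores it in self.controller, and every control observes _update_ipywidget so any change redraws the matplotlib figure in place. A stripped-down stand-in for that loop, without reproducing make_unit_controller itself:

import ipywidgets.widgets as widgets

all_unit_ids = [0, 1, 2, 3]                       # toy unit ids
unit_selector = widgets.SelectMultiple(options=all_unit_ids, value=(0,), description="units")
controller = {"unit_ids": unit_selector}

def _update(change):
    # In the widgets above this re-runs plot_matplotlib on the cached data_plot,
    # then calls figure.canvas.draw() / flush_events().
    print("replot with unit_ids =", list(controller["unit_ids"].value))

for w in controller.values():
    w.observe(_update)                            # all traits observed, as in the widgets

unit_selector.value = (0, 1)                      # triggers _update (possibly more than once)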
__init__( **backend_kwargs, ): we = waveform_extractor - recording: BaseRecording = we.recording sorting: BaseSorting = we.sorting - ts_widget = TimeseriesWidget( - recording, - segment_index, - channel_ids, - order_channel_by_depth, - time_range, - mode, - return_scaled, - cmap, - show_channel_ids, - color_groups, - color, - clim, - tile_size, - seconds_per_row, - with_colorbar, - backend, - **backend_kwargs, - ) - if unit_ids is None: unit_ids = sorting.get_unit_ids() unit_ids = unit_ids @@ -133,9 +110,25 @@ def __init__( # get templates unit_locations = compute_unit_locations(we, outputs="by_unit") + options = dict( + segment_index=segment_index, + channel_ids=channel_ids, + order_channel_by_depth=order_channel_by_depth, + time_range=time_range, + mode=mode, + return_scaled=return_scaled, + cmap=cmap, + show_channel_ids=show_channel_ids, + color_groups=color_groups, + color=color, + clim=clim, + tile_size=tile_size, + with_colorbar=with_colorbar, + ) + plot_data = dict( - timeseries=ts_widget.plot_data, waveform_extractor=waveform_extractor, + options=options, unit_ids=unit_ids, sparsity=sparsity, unit_colors=unit_colors, @@ -143,3 +136,163 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + from matplotlib.patches import Ellipse + from matplotlib.lines import Line2D + + dp = to_attr(data_plot) + we = dp.waveform_extractor + recording = we.recording + sorting = we.sorting + + # first plot time series + ts_widget = TracesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) + self.ax = ts_widget.ax + self.axes = ts_widget.axes + self.figure = ts_widget.figure + + ax = self.ax + + frame_range = ts_widget.data_plot["frame_range"] + segment_index = ts_widget.data_plot["segment_index"] + min_y = np.min(ts_widget.data_plot["channel_locations"][:, 1]) + max_y = np.max(ts_widget.data_plot["channel_locations"][:, 1]) + + n = len(ts_widget.data_plot["channel_ids"]) + order = ts_widget.data_plot["order"] + + if order is None: + order = np.arange(n) + + if ax.get_legend() is not None: + ax.get_legend().remove() + + # loop through units and plot a scatter of spikes at estimated location + handles = [] + labels = [] + + for unit in dp.unit_ids: + spike_frames = sorting.get_unit_spike_train(unit, segment_index=segment_index) + spike_start, spike_end = np.searchsorted(spike_frames, frame_range) + + chan_ids = dp.sparsity.unit_id_to_channel_ids[unit] + + spike_frames_to_plot = spike_frames[spike_start:spike_end] + + if dp.options["mode"] == "map": + spike_times_to_plot = sorting.get_unit_spike_train( + unit, segment_index=segment_index, return_times=True + )[spike_start:spike_end] + unit_y_loc = min_y + max_y - dp.unit_locations[unit][1] + # markers = np.ones_like(spike_frames_to_plot) * (min_y + max_y - dp.unit_locations[unit][1]) + width = 2 * 1e-3 + ellipse_kwargs = dict(width=width, height=10, fc="none", ec=dp.unit_colors[unit], lw=2) + patches = [Ellipse((s, unit_y_loc), **ellipse_kwargs) for s in spike_times_to_plot] + for p in patches: + ax.add_patch(p) + handles.append( + Line2D( + [0], + [0], + ls="", + marker="o", + markersize=5, + markeredgewidth=2, + markeredgecolor=dp.unit_colors[unit], + markerfacecolor="none", + ) + ) + labels.append(unit) + else: + # construct waveforms + label_set = False + if len(spike_frames_to_plot) > 0: + vspacing = ts_widget.data_plot["vspacing"] + traces = 
ts_widget.data_plot["list_traces"][0] + + waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] + waveform_idxs = np.clip(waveform_idxs, 0, len(ts_widget.data_plot["times"]) - 1) + + times = ts_widget.data_plot["times"][waveform_idxs] + + # discontinuity + times[:, -1] = np.nan + times_r = times.reshape(times.shape[0] * times.shape[1]) + waveforms = traces[waveform_idxs] # [:, :, order] + waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2])) + + for i, chan_id in enumerate(ts_widget.data_plot["channel_ids"]): + offset = vspacing * i + if chan_id in chan_ids: + l = ax.plot(times_r, offset + waveforms_r[:, i], color=dp.unit_colors[unit]) + if not label_set: + handles.append(l[0]) + labels.append(unit) + label_set = True + ax.legend(handles, labels) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + dp = to_attr(data_plot) + we = dp.waveform_extractor + + ratios = [0.2, 0.8] + + backend_kwargs_ts = backend_kwargs.copy() + backend_kwargs_ts["width_cm"] = ratios[1] * backend_kwargs_ts["width_cm"] + backend_kwargs_ts["display"] = False + height_cm = backend_kwargs["height_cm"] + width_cm = backend_kwargs["width_cm"] + + # plot timeseries + ts_widget = TracesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) + self.ax = ts_widget.ax + self.axes = ts_widget.axes + self.figure = ts_widget.figure + + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm + ) + + self.controller = dict() + self.controller.update(ts_widget.controller) + self.controller.update(unit_controller) + + for w in self.controller.values(): + w.observe(self._update_ipywidget) + + self.widget = widgets.AppLayout(center=ts_widget.widget, left_sidebar=unit_widget, pane_widths=ratios + [0]) + + # a first update + self._update_ipywidget(None) + + if backend_kwargs["display"]: + display(self.widget) + + def _update_ipywidget(self, change): + self.ax.clear() + + unit_ids = self.controller["unit_ids"].value + + data_plot = self.next_data_plot + data_plot["unit_ids"] = unit_ids + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + self.plot_matplotlib(data_plot, **backend_kwargs) + + self.figure.canvas.draw() + self.figure.canvas.flush_events() diff --git a/src/spikeinterface/widgets/template_metrics.py b/src/spikeinterface/widgets/template_metrics.py index b441882730..748babb57d 100644 --- a/src/spikeinterface/widgets/template_metrics.py +++ b/src/spikeinterface/widgets/template_metrics.py @@ -22,8 +22,6 @@ class TemplateMetricsWidget(MetricsBaseWidget): For sortingview backend, if True the unit selector is not displayed, default False """ - possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index 475c873c29..63ac177835 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -1,9 +1,7 @@ import numpy as np -from typing import Union -from .base import BaseWidget +from .base import BaseWidget, to_attr from ..core.waveform_extractor import WaveformExtractor -from 
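The "discontinuity" trick in plot_matplotlib above is worth spelling out: all waveform snippets of a unit on a channel are drawn with a single ax.plot call, and setting the last time sample of each snippet to NaN stops matplotlib from connecting one snippet to the next. A toy version with three fake snippets:

import numpy as np
import matplotlib.pyplot as plt

t = np.linspace(0.0, 1e-3, 30)
snippet_times = np.vstack([t + offset for offset in (0.010, 0.050, 0.120)])  # 3 fake spikes
snippet_traces = np.sin(snippet_times * 2.0e4)

snippet_times[:, -1] = np.nan            # same discontinuity marker as in the widget
plt.plot(snippet_times.ravel(), snippet_traces.ravel(), color="C1")
plt.show()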
..core.basesorting import BaseSorting class TemplateSimilarityWidget(BaseWidget): @@ -27,8 +25,6 @@ class TemplateSimilarityWidget(BaseWidget): If True, color bar is displayed, default True. """ - possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, @@ -63,3 +59,46 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + im = self.ax.matshow(dp.similarity, cmap=dp.cmap) + + if dp.show_unit_ticks: + # Major ticks + self.ax.set_xticks(np.arange(0, len(dp.unit_ids))) + self.ax.set_yticks(np.arange(0, len(dp.unit_ids))) + self.ax.xaxis.tick_bottom() + + # Labels for major ticks + self.ax.set_yticklabels(dp.unit_ids, fontsize=12) + self.ax.set_xticklabels(dp.unit_ids, fontsize=12) + if dp.show_colorbar: + self.figure.colorbar(im) + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + + # ensure serializable for sortingview + unit_ids = make_serializable(dp.unit_ids) + + # similarity + ss_items = [] + for i1, u1 in enumerate(unit_ids): + for i2, u2 in enumerate(unit_ids): + ss_items.append( + vv.UnitSimilarityScore(unit_id1=u1, unit_id2=u2, similarity=dp.similarity[i1, i2].astype("float32")) + ) + + self.view = vv.UnitSimilarityMatrix(unit_ids=list(unit_ids), similarity_scores=ss_items) + + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 3a60a9d2c7..a5f75ebf50 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -13,7 +13,6 @@ from spikeinterface import extract_waveforms, load_waveforms, download_dataset, compute_sparsity -from spikeinterface.widgets import HAVE_MPL, HAVE_SV import spikeinterface.extractors as se import spikeinterface.widgets as sw @@ -68,7 +67,10 @@ def setUpClass(cls): # make sparse waveforms cls.sparsity_radius = compute_sparsity(cls.we, method="radius", radius_um=50) cls.sparsity_best = compute_sparsity(cls.we, method="best_channels", num_channels=5) - cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) + if (cache_folder / "mearec_test_sparse").is_dir(): + cls.we_sparse = load_waveforms(cache_folder / "mearec_test_sparse") + else: + cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) cls.skip_backends = ["ipywidgets"] @@ -81,16 +83,16 @@ def setUpClass(cls): cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) - def test_plot_timeseries(self): - possible_backends = list(sw.TimeseriesWidget.possible_backends.keys()) + def test_plot_traces(self): + possible_backends = list(sw.TracesWidget.get_possible_backends()) for backend in possible_backends: if ON_GITHUB and backend == "sortingview": continue if backend not in self.skip_backends: - sw.plot_timeseries( + sw.plot_traces( self.recording, mode="map", show_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) - sw.plot_timeseries( + sw.plot_traces( self.recording, mode="map", show_channel_ids=True, @@ -100,8 +102,8 @@ 
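The test changes that follow are mostly mechanical: plot_timeseries becomes plot_traces and the per-widget possible_backends dict is replaced by a get_possible_backends() classmethod. The resulting idiom looks roughly like this; generate_recording is only used here to have a small toy recording and is not part of the patch:

import spikeinterface.widgets as sw
from spikeinterface.core import generate_recording

recording = generate_recording(durations=[2.0])   # toy recording, illustration only

print(sw.TracesWidget.get_possible_backends())    # e.g. ['matplotlib', 'ipywidgets', 'sortingview']
sw.plot_traces(recording, mode="map", show_channel_ids=True, backend="matplotlib")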
def test_plot_timeseries(self): ) if backend != "sortingview": - sw.plot_timeseries(self.recording, mode="auto", backend=backend, **self.backend_kwargs[backend]) - sw.plot_timeseries( + sw.plot_traces(self.recording, mode="auto", backend=backend, **self.backend_kwargs[backend]) + sw.plot_traces( self.recording, mode="line", show_channel_ids=True, @@ -109,7 +111,7 @@ def test_plot_timeseries(self): **self.backend_kwargs[backend], ) # multi layer - sw.plot_timeseries( + sw.plot_traces( {"rec0": self.recording, "rec1": scale(self.recording, gain=0.8, offset=0)}, color="r", mode="line", @@ -119,7 +121,7 @@ def test_plot_timeseries(self): ) def test_plot_unit_waveforms(self): - possible_backends = list(sw.UnitWaveformsWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_waveforms(self.we, backend=backend, **self.backend_kwargs[backend]) @@ -143,7 +145,7 @@ def test_plot_unit_waveforms(self): ) def test_plot_unit_templates(self): - possible_backends = list(sw.UnitWaveformsWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_templates(self.we, backend=backend, **self.backend_kwargs[backend]) @@ -164,7 +166,7 @@ def test_plot_unit_templates(self): ) def test_plot_unit_waveforms_density_map(self): - possible_backends = list(sw.UnitWaveformDensityMapWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformDensityMapWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] @@ -173,7 +175,7 @@ def test_plot_unit_waveforms_density_map(self): ) def test_plot_unit_waveforms_density_map_sparsity_radius(self): - possible_backends = list(sw.UnitWaveformDensityMapWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformDensityMapWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] @@ -187,7 +189,7 @@ def test_plot_unit_waveforms_density_map_sparsity_radius(self): ) def test_plot_unit_waveforms_density_map_sparsity_None_same_axis(self): - possible_backends = list(sw.UnitWaveformDensityMapWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformDensityMapWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] @@ -201,7 +203,7 @@ def test_plot_unit_waveforms_density_map_sparsity_None_same_axis(self): ) def test_autocorrelograms(self): - possible_backends = list(sw.AutoCorrelogramsWidget.possible_backends.keys()) + possible_backends = list(sw.AutoCorrelogramsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:4] @@ -215,7 +217,7 @@ def test_autocorrelograms(self): ) def test_crosscorrelogram(self): - possible_backends = list(sw.CrossCorrelogramsWidget.possible_backends.keys()) + possible_backends = list(sw.CrossCorrelogramsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:4] @@ -229,7 +231,7 @@ def test_crosscorrelogram(self): ) def test_amplitudes(self): - possible_backends = list(sw.AmplitudesWidget.possible_backends.keys()) + possible_backends = 
list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_amplitudes(self.we, backend=backend, **self.backend_kwargs[backend]) @@ -247,7 +249,7 @@ def test_amplitudes(self): ) def test_plot_all_amplitudes_distributions(self): - possible_backends = list(sw.AllAmplitudesDistributionsWidget.possible_backends.keys()) + possible_backends = list(sw.AllAmplitudesDistributionsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.we.unit_ids[:4] @@ -259,7 +261,7 @@ def test_plot_all_amplitudes_distributions(self): ) def test_unit_locations(self): - possible_backends = list(sw.UnitLocationsWidget.possible_backends.keys()) + possible_backends = list(sw.UnitLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) @@ -268,7 +270,7 @@ def test_unit_locations(self): ) def test_spike_locations(self): - possible_backends = list(sw.SpikeLocationsWidget.possible_backends.keys()) + possible_backends = list(sw.SpikeLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_spike_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) @@ -277,35 +279,35 @@ def test_spike_locations(self): ) def test_similarity(self): - possible_backends = list(sw.TemplateSimilarityWidget.possible_backends.keys()) + possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_template_similarity(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_similarity(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_quality_metrics(self): - possible_backends = list(sw.QualityMetricsWidget.possible_backends.keys()) + possible_backends = list(sw.QualityMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_quality_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_quality_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_template_metrics(self): - possible_backends = list(sw.TemplateMetricsWidget.possible_backends.keys()) + possible_backends = list(sw.TemplateMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_template_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_depths(self): - possible_backends = list(sw.UnitDepthsWidget.possible_backends.keys()) + possible_backends = list(sw.UnitDepthsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_depths(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_unit_depths(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_summary(self): - possible_backends = list(sw.UnitSummaryWidget.possible_backends.keys()) + possible_backends = list(sw.UnitSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_summary( @@ -316,7 +318,7 @@ def test_plot_unit_summary(self): ) def 
test_sorting_summary(self): - possible_backends = list(sw.SortingSummaryWidget.possible_backends.keys()) + possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_sorting_summary(self.we, backend=backend, **self.backend_kwargs[backend]) @@ -332,15 +334,17 @@ def test_sorting_summary(self): # mytest.test_plot_unit_waveforms_density_map() # mytest.test_plot_unit_summary() # mytest.test_plot_all_amplitudes_distributions() - # mytest.test_plot_timeseries() + # mytest.test_plot_traces() # mytest.test_plot_unit_waveforms() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_depths() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_summary() - mytest.test_quality_metrics() - mytest.test_template_metrics() + # mytest.test_unit_locations() + # mytest.test_quality_metrics() + # mytest.test_template_metrics() + mytest.test_amplitudes() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/timeseries.py b/src/spikeinterface/widgets/timeseries.py deleted file mode 100644 index 93e0358460..0000000000 --- a/src/spikeinterface/widgets/timeseries.py +++ /dev/null @@ -1,251 +0,0 @@ -import numpy as np - -from ..core import BaseRecording, order_channels_by_depth -from .base import BaseWidget -from .utils import get_some_colors - - -class TimeseriesWidget(BaseWidget): - """ - Plots recording timeseries. - - Parameters - ---------- - recording: RecordingExtractor, dict, or list - The recording extractor object. If dict (or list) then it is a multi-layer display to compare, for example, - different processing steps - segment_index: None or int - The segment index (required for multi-segment recordings), default None - channel_ids: list - The channel ids to display, default None - order_channel_by_depth: bool - Reorder channel by depth, default False - time_range: list - List with start time and end time, default None - mode: str - Three possible modes, default 'auto': - * 'line': classical for low channel count - * 'map': for high channel count use color heat map - * 'auto': auto switch depending on the channel count ('line' if less than 64 channels, 'map' otherwise) - return_scaled: bool - If True and the recording has scaled traces, it plots the scaled traces, default False - cmap: str - matplotlib colormap used in mode 'map', default 'RdBu' - show_channel_ids: bool - Set yticks with channel ids, default False - color_groups: bool - If True groups are plotted with different colors, default False - color: str - The color used to draw the traces, default None - clim: None, tuple or dict - When mode is 'map', this argument controls color limits. 
- If dict, keys should be the same as recording keys - Default None - with_colorbar: bool - When mode is 'map', a colorbar is added, by default True - tile_size: int - For sortingview backend, the size of each tile in the rendered image, default 1500 - seconds_per_row: float - For 'map' mode and sortingview backend, seconds to render in each row, default 0.2 - add_legend : bool - If True adds legend to figures, default True - - Returns - ------- - W: TimeseriesWidget - The output widget - """ - - possible_backends = {} - - def __init__( - self, - recording, - segment_index=None, - channel_ids=None, - order_channel_by_depth=False, - time_range=None, - mode="auto", - return_scaled=False, - cmap="RdBu_r", - show_channel_ids=False, - color_groups=False, - color=None, - clim=None, - tile_size=1500, - seconds_per_row=0.2, - with_colorbar=True, - add_legend=True, - backend=None, - **backend_kwargs, - ): - if isinstance(recording, BaseRecording): - recordings = {"rec": recording} - rec0 = recording - elif isinstance(recording, dict): - recordings = recording - k0 = list(recordings.keys())[0] - rec0 = recordings[k0] - elif isinstance(recording, list): - recordings = {f"rec{i}": rec for i, rec in enumerate(recording)} - rec0 = recordings[0] - else: - raise ValueError("plot_timeseries recording must be recording or dict or list") - - layer_keys = list(recordings.keys()) - - if segment_index is None: - if rec0.get_num_segments() != 1: - raise ValueError("You must provide segment_index=...") - segment_index = 0 - - if channel_ids is None: - channel_ids = rec0.channel_ids - - if "location" in rec0.get_property_keys(): - channel_locations = rec0.get_channel_locations() - else: - channel_locations = None - - if order_channel_by_depth: - if channel_locations is not None: - order, _ = order_channels_by_depth(rec0, channel_ids) - else: - order = None - - fs = rec0.get_sampling_frequency() - if time_range is None: - time_range = (0, 1.0) - time_range = np.array(time_range) - - assert mode in ("auto", "line", "map"), "Mode must be in auto/line/map" - if mode == "auto": - if len(channel_ids) <= 64: - mode = "line" - else: - mode = "map" - mode = mode - cmap = cmap - - times, list_traces, frame_range, channel_ids = _get_trace_list( - recordings, channel_ids, time_range, segment_index, order, return_scaled - ) - - # stat for auto scaling done on the first layer - traces0 = list_traces[0] - mean_channel_std = np.mean(np.std(traces0, axis=0)) - max_channel_amp = np.max(np.max(np.abs(traces0), axis=0)) - vspacing = max_channel_amp * 1.5 - - if rec0.get_channel_groups() is None: - color_groups = False - - # colors is a nested dict by layer and channels - # lets first create black for all channels and layer - colors = {} - for k in layer_keys: - colors[k] = {chan_id: "k" for chan_id in channel_ids} - - if color_groups: - channel_groups = rec0.get_channel_groups(channel_ids=channel_ids) - groups = np.unique(channel_groups) - - group_colors = get_some_colors(groups, color_engine="auto") - - channel_colors = {} - for i, chan_id in enumerate(channel_ids): - group = channel_groups[i] - channel_colors[chan_id] = group_colors[group] - - # first layer is colored then black - colors[layer_keys[0]] = channel_colors - - elif color is not None: - # old behavior one color for all channel - # if multi layer then black for all - colors[layer_keys[0]] = {chan_id: color for chan_id in channel_ids} - elif color is None and len(recordings) > 1: - # several layer - layer_colors = get_some_colors(layer_keys) - for k in layer_keys: - 
colors[k] = {chan_id: layer_colors[k] for chan_id in channel_ids} - else: - # color is None unique layer : all channels black - pass - - if clim is None: - clims = {layer_key: [-200, 200] for layer_key in layer_keys} - else: - if isinstance(clim, tuple): - clims = {layer_key: clim for layer_key in layer_keys} - elif isinstance(clim, dict): - assert all(layer_key in clim for layer_key in layer_keys), "" - clims = clim - else: - raise TypeError(f"'clim' can be None, tuple, or dict! Unsupported type {type(clim)}") - - plot_data = dict( - recordings=recordings, - segment_index=segment_index, - channel_ids=channel_ids, - channel_locations=channel_locations, - time_range=time_range, - frame_range=frame_range, - times=times, - layer_keys=layer_keys, - list_traces=list_traces, - mode=mode, - cmap=cmap, - clims=clims, - with_colorbar=with_colorbar, - mean_channel_std=mean_channel_std, - max_channel_amp=max_channel_amp, - vspacing=vspacing, - colors=colors, - show_channel_ids=show_channel_ids, - add_legend=add_legend, - order_channel_by_depth=order_channel_by_depth, - order=order, - tile_size=tile_size, - num_timepoints_per_row=int(seconds_per_row * fs), - return_scaled=return_scaled, - ) - - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - - -def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): - # function also used in ipywidgets plotter - k0 = list(recordings.keys())[0] - rec0 = recordings[k0] - - fs = rec0.get_sampling_frequency() - - if return_scaled: - assert all( - rec.has_scaled() for rec in recordings.values() - ), "Some recording layers do not have scaled traces. Use `return_scaled=False`" - frame_range = (time_range * fs).astype("int64") - a_max = rec0.get_num_frames(segment_index=segment_index) - frame_range = np.clip(frame_range, 0, a_max) - time_range = frame_range / fs - times = np.arange(frame_range[0], frame_range[1]) / fs - - list_traces = [] - for rec_name, rec in recordings.items(): - traces = rec.get_traces( - segment_index=segment_index, - channel_ids=channel_ids, - start_frame=frame_range[0], - end_frame=frame_range[1], - return_scaled=return_scaled, - ) - - if order is not None: - traces = traces[:, order] - list_traces.append(traces) - - if order is not None: - channel_ids = np.array(channel_ids)[order] - - return times, list_traces, frame_range, channel_ids diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py new file mode 100644 index 0000000000..9a2ec4a215 --- /dev/null +++ b/src/spikeinterface/widgets/traces.py @@ -0,0 +1,566 @@ +import warnings + +import numpy as np + +from ..core import BaseRecording, order_channels_by_depth +from .base import BaseWidget, to_attr +from .utils import get_some_colors, array_to_image + + +class TracesWidget(BaseWidget): + """ + Plots recording timeseries. + + Parameters + ---------- + recording: RecordingExtractor, dict, or list + The recording extractor object. 
If dict (or list) then it is a multi-layer display to compare, for example, + different processing steps + segment_index: None or int + The segment index (required for multi-segment recordings), default None + channel_ids: list + The channel ids to display, default None + order_channel_by_depth: bool + Reorder channel by depth, default False + time_range: list + List with start time and end time, default None + mode: str + Three possible modes, default 'auto': + * 'line': classical for low channel count + * 'map': for high channel count use color heat map + * 'auto': auto switch depending on the channel count ('line' if less than 64 channels, 'map' otherwise) + return_scaled: bool + If True and the recording has scaled traces, it plots the scaled traces, default False + cmap: str + matplotlib colormap used in mode 'map', default 'RdBu' + show_channel_ids: bool + Set yticks with channel ids, default False + color_groups: bool + If True groups are plotted with different colors, default False + color: str + The color used to draw the traces, default None + clim: None, tuple or dict + When mode is 'map', this argument controls color limits. + If dict, keys should be the same as recording keys + Default None + with_colorbar: bool + When mode is 'map', a colorbar is added, by default True + tile_size: int + For sortingview backend, the size of each tile in the rendered image, default 1500 + seconds_per_row: float + For 'map' mode and sortingview backend, seconds to render in each row, default 0.2 + add_legend : bool + If True adds legend to figures, default True + + Returns + ------- + W: TracesWidget + The output widget + """ + + def __init__( + self, + recording, + segment_index=None, + channel_ids=None, + order_channel_by_depth=False, + time_range=None, + mode="auto", + return_scaled=False, + cmap="RdBu_r", + show_channel_ids=False, + color_groups=False, + color=None, + clim=None, + tile_size=1500, + seconds_per_row=0.2, + with_colorbar=True, + add_legend=True, + backend=None, + **backend_kwargs, + ): + if isinstance(recording, BaseRecording): + recordings = {"rec": recording} + rec0 = recording + elif isinstance(recording, dict): + recordings = recording + k0 = list(recordings.keys())[0] + rec0 = recordings[k0] + elif isinstance(recording, list): + recordings = {f"rec{i}": rec for i, rec in enumerate(recording)} + rec0 = recordings[0] + else: + raise ValueError("plot_traces recording must be recording or dict or list") + + layer_keys = list(recordings.keys()) + + if segment_index is None: + if rec0.get_num_segments() != 1: + raise ValueError("You must provide segment_index=...") + segment_index = 0 + + if channel_ids is None: + channel_ids = rec0.channel_ids + + if "location" in rec0.get_property_keys(): + channel_locations = rec0.get_channel_locations() + else: + channel_locations = None + + if order_channel_by_depth: + if channel_locations is not None: + order, _ = order_channels_by_depth(rec0, channel_ids) + else: + order = None + + fs = rec0.get_sampling_frequency() + if time_range is None: + time_range = (0, 1.0) + time_range = np.array(time_range) + + assert mode in ("auto", "line", "map"), "Mode must be in auto/line/map" + if mode == "auto": + if len(channel_ids) <= 64: + mode = "line" + else: + mode = "map" + mode = mode + cmap = cmap + + times, list_traces, frame_range, channel_ids = _get_trace_list( + recordings, channel_ids, time_range, segment_index, order, return_scaled + ) + + # stat for auto scaling done on the first layer + traces0 = list_traces[0] + mean_channel_std = 
np.mean(np.std(traces0, axis=0)) + max_channel_amp = np.max(np.max(np.abs(traces0), axis=0)) + vspacing = max_channel_amp * 1.5 + + if rec0.get_channel_groups() is None: + color_groups = False + + # colors is a nested dict by layer and channels + # lets first create black for all channels and layer + colors = {} + for k in layer_keys: + colors[k] = {chan_id: "k" for chan_id in channel_ids} + + if color_groups: + channel_groups = rec0.get_channel_groups(channel_ids=channel_ids) + groups = np.unique(channel_groups) + + group_colors = get_some_colors(groups, color_engine="auto") + + channel_colors = {} + for i, chan_id in enumerate(channel_ids): + group = channel_groups[i] + channel_colors[chan_id] = group_colors[group] + + # first layer is colored then black + colors[layer_keys[0]] = channel_colors + + elif color is not None: + # old behavior one color for all channel + # if multi layer then black for all + colors[layer_keys[0]] = {chan_id: color for chan_id in channel_ids} + elif color is None and len(recordings) > 1: + # several layer + layer_colors = get_some_colors(layer_keys) + for k in layer_keys: + colors[k] = {chan_id: layer_colors[k] for chan_id in channel_ids} + else: + # color is None unique layer : all channels black + pass + + if clim is None: + clims = {layer_key: [-200, 200] for layer_key in layer_keys} + else: + if isinstance(clim, tuple): + clims = {layer_key: clim for layer_key in layer_keys} + elif isinstance(clim, dict): + assert all(layer_key in clim for layer_key in layer_keys), "" + clims = clim + else: + raise TypeError(f"'clim' can be None, tuple, or dict! Unsupported type {type(clim)}") + + plot_data = dict( + recordings=recordings, + segment_index=segment_index, + channel_ids=channel_ids, + channel_locations=channel_locations, + time_range=time_range, + frame_range=frame_range, + times=times, + layer_keys=layer_keys, + list_traces=list_traces, + mode=mode, + cmap=cmap, + clims=clims, + with_colorbar=with_colorbar, + mean_channel_std=mean_channel_std, + max_channel_amp=max_channel_amp, + vspacing=vspacing, + colors=colors, + show_channel_ids=show_channel_ids, + add_legend=add_legend, + order_channel_by_depth=order_channel_by_depth, + order=order, + tile_size=tile_size, + num_timepoints_per_row=int(seconds_per_row * fs), + return_scaled=return_scaled, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from matplotlib.ticker import MaxNLocator + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + ax = self.ax + n = len(dp.channel_ids) + if dp.channel_locations is not None: + y_locs = dp.channel_locations[:, 1] + else: + y_locs = np.arange(n) + min_y = np.min(y_locs) + max_y = np.max(y_locs) + + if dp.mode == "line": + offset = dp.vspacing * (n - 1) + + for layer_key, traces in zip(dp.layer_keys, dp.list_traces): + for i, chan_id in enumerate(dp.channel_ids): + offset = dp.vspacing * i + color = dp.colors[layer_key][chan_id] + ax.plot(dp.times, offset + traces[:, i], color=color) + ax.get_lines()[-1].set_label(layer_key) + + if dp.show_channel_ids: + ax.set_yticks(np.arange(n) * dp.vspacing) + channel_labels = np.array([str(chan_id) for chan_id in dp.channel_ids]) + ax.set_yticklabels(channel_labels) + else: + ax.get_yaxis().set_visible(False) + + ax.set_xlim(*dp.time_range) + ax.set_ylim(-dp.vspacing, dp.vspacing * n) + 
ax.get_xaxis().set_major_locator(MaxNLocator(prune="both")) + ax.set_xlabel("time (s)") + if dp.add_legend: + ax.legend(loc="upper right") + + elif dp.mode == "map": + assert len(dp.list_traces) == 1, 'plot_traces with mode="map" do not support multi recording' + assert len(dp.clims) == 1 + clim = list(dp.clims.values())[0] + extent = (dp.time_range[0], dp.time_range[1], min_y, max_y) + im = ax.imshow( + dp.list_traces[0].T, interpolation="nearest", origin="lower", aspect="auto", extent=extent, cmap=dp.cmap + ) + + im.set_clim(*clim) + + if dp.with_colorbar: + self.figure.colorbar(im, ax=ax) + + if dp.show_channel_ids: + ax.set_yticks(np.linspace(min_y, max_y, n) + (max_y - min_y) / n * 0.5) + channel_labels = np.array([str(chan_id) for chan_id in dp.channel_ids]) + ax.set_yticklabels(channel_labels) + else: + ax.get_yaxis().set_visible(False) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .utils_ipywidgets import ( + check_ipywidget_backend, + make_timeseries_controller, + make_channel_controller, + make_scale_controller, + ) + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + self.next_data_plot["add_legend"] = False + + recordings = data_plot["recordings"] + + # first layer + rec0 = recordings[data_plot["layer_keys"][0]] + + cm = 1 / 2.54 + + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + ratios = [0.1, 0.8, 0.2] + + with plt.ioff(): + output = widgets.Output() + with output: + self.figure, self.ax = plt.subplots(figsize=(0.9 * ratios[1] * width_cm * cm, height_cm * cm)) + plt.show() + + t_start = 0.0 + t_stop = rec0.get_num_samples(segment_index=0) / rec0.get_sampling_frequency() + + ts_widget, ts_controller = make_timeseries_controller( + t_start, + t_stop, + data_plot["layer_keys"], + rec0.get_num_segments(), + data_plot["time_range"], + data_plot["mode"], + False, + width_cm, + ) + + ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) + + scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) + + self.controller = ts_controller + self.controller.update(ch_controller) + self.controller.update(scale_controller) + + self.recordings = data_plot["recordings"] + self.return_scaled = data_plot["return_scaled"] + self.list_traces = None + self.actual_segment_index = self.controller["segment_index"].value + + self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] + self.t_stops = [ + self.rec0.get_num_samples(segment_index=seg_index) / self.rec0.get_sampling_frequency() + for seg_index in range(self.rec0.get_num_segments()) + ] + + for w in self.controller.values(): + if isinstance(w, widgets.Button): + w.on_click(self._update_ipywidget) + else: + w.observe(self._update_ipywidget) + + self.widget = widgets.AppLayout( + center=self.figure.canvas, + footer=ts_widget, + left_sidebar=scale_widget, + right_sidebar=ch_widget, + pane_heights=[0, 6, 1], + pane_widths=ratios, + ) + + # a first update + self._update_ipywidget(None) + + if backend_kwargs["display"]: + # self.check_backend() + display(self.widget) + + def _update_ipywidget(self, change): + import ipywidgets.widgets as widgets + + # if changing the layer_key, no need to retrieve and process traces + retrieve_traces = True + scale_up = False + scale_down = False + if change is not None: + for cname, c in self.controller.items(): + if 
isinstance(change, dict): + if change["owner"] is c and cname == "layer_key": + retrieve_traces = False + elif isinstance(change, widgets.Button): + if change is c and cname == "plus": + scale_up = True + if change is c and cname == "minus": + scale_down = True + + t_start = self.controller["t_start"].value + window = self.controller["window"].value + layer_key = self.controller["layer_key"].value + segment_index = self.controller["segment_index"].value + mode = self.controller["mode"].value + chan_start, chan_stop = self.controller["channel_inds"].value + + if mode == "line": + self.controller["all_layers"].layout.visibility = "visible" + all_layers = self.controller["all_layers"].value + elif mode == "map": + self.controller["all_layers"].layout.visibility = "hidden" + all_layers = False + + if all_layers: + self.controller["layer_key"].layout.visibility = "hidden" + else: + self.controller["layer_key"].layout.visibility = "visible" + + if chan_start == chan_stop: + chan_stop += 1 + channel_indices = np.arange(chan_start, chan_stop) + + t_stop = self.t_stops[segment_index] + if self.actual_segment_index != segment_index: + # change time_slider limits + self.controller["t_start"].max = t_stop + self.actual_segment_index = segment_index + + # protect limits + if t_start >= t_stop - window: + t_start = t_stop - window + + time_range = np.array([t_start, t_start + window]) + data_plot = self.next_data_plot + + if retrieve_traces: + all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids + if self.data_plot["order"] is not None: + all_channel_ids = all_channel_ids[self.data_plot["order"]] + channel_ids = all_channel_ids[channel_indices] + if self.data_plot["order_channel_by_depth"]: + order, _ = order_channels_by_depth(self.rec0, channel_ids) + else: + order = None + times, list_traces, frame_range, channel_ids = _get_trace_list( + self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled + ) + self.list_traces = list_traces + else: + times = data_plot["times"] + list_traces = data_plot["list_traces"] + frame_range = data_plot["frame_range"] + channel_ids = data_plot["channel_ids"] + + if all_layers: + layer_keys = self.data_plot["layer_keys"] + recordings = self.recordings + list_traces_plot = self.list_traces + else: + layer_keys = [layer_key] + recordings = {layer_key: self.recordings[layer_key]} + list_traces_plot = [self.list_traces[list(self.recordings.keys()).index(layer_key)]] + + if scale_up: + if mode == "line": + data_plot["vspacing"] *= 0.8 + elif mode == "map": + data_plot["clims"] = { + layer: (1.2 * val[0], 1.2 * val[1]) for layer, val in self.data_plot["clims"].items() + } + if scale_down: + if mode == "line": + data_plot["vspacing"] *= 1.2 + elif mode == "map": + data_plot["clims"] = { + layer: (0.8 * val[0], 0.8 * val[1]) for layer, val in self.data_plot["clims"].items() + } + + self.next_data_plot["vspacing"] = data_plot["vspacing"] + self.next_data_plot["clims"] = data_plot["clims"] + + if mode == "line": + clims = None + elif mode == "map": + clims = {layer_key: self.data_plot["clims"][layer_key]} + + # matplotlib next_data_plot dict update at each call + data_plot["mode"] = mode + data_plot["frame_range"] = frame_range + data_plot["time_range"] = time_range + data_plot["with_colorbar"] = False + data_plot["recordings"] = recordings + data_plot["layer_keys"] = layer_keys + data_plot["list_traces"] = list_traces_plot + data_plot["times"] = times + data_plot["clims"] = clims + data_plot["channel_ids"] = channel_ids + + 
backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + self.plot_matplotlib(data_plot, **backend_kwargs) + + fig = self.ax.figure + fig.canvas.draw() + fig.canvas.flush_events() + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .utils_sortingview import handle_display_and_url + + try: + import pyvips + except ImportError: + raise ImportError("To use the timeseries in sorting view you need the pyvips package.") + + dp = to_attr(data_plot) + + assert dp.mode == "map", 'sortingview plot_traces is only mode="map"' + + if not dp.order_channel_by_depth: + warnings.warn( + "It is recommended to set 'order_channel_by_depth' to True " "when using the sortingview backend" + ) + + tiled_layers = [] + for layer_key, traces in zip(dp.layer_keys, dp.list_traces): + img = array_to_image( + traces, + clim=dp.clims[layer_key], + num_timepoints_per_row=dp.num_timepoints_per_row, + colormap=dp.cmap, + scalebar=True, + sampling_frequency=dp.recordings[layer_key].get_sampling_frequency(), + ) + + tiled_layers.append(vv.TiledImageLayer(layer_key, img)) + + self.view = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers) + + # timeseries currently doesn't display on the jupyter backend + backend_kwargs["display"] = False + + self.url = handle_display_and_url(self, self.view, **backend_kwargs) + + +def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): + # function also used in ipywidgets plotter + k0 = list(recordings.keys())[0] + rec0 = recordings[k0] + + fs = rec0.get_sampling_frequency() + + if return_scaled: + assert all( + rec.has_scaled() for rec in recordings.values() + ), "Some recording layers do not have scaled traces. Use `return_scaled=False`" + frame_range = (time_range * fs).astype("int64") + a_max = rec0.get_num_frames(segment_index=segment_index) + frame_range = np.clip(frame_range, 0, a_max) + time_range = frame_range / fs + times = np.arange(frame_range[0], frame_range[1]) / fs + + list_traces = [] + for rec_name, rec in recordings.items(): + traces = rec.get_traces( + segment_index=segment_index, + channel_ids=channel_ids, + start_frame=frame_range[0], + end_frame=frame_range[1], + return_scaled=return_scaled, + ) + + if order is not None: + traces = traces[:, order] + list_traces.append(traces) + + if order is not None: + channel_ids = np.array(channel_ids)[order] + + return times, list_traces, frame_range, channel_ids diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index 5ceee0c133..1cc7c909a1 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -1,7 +1,7 @@ import numpy as np from warnings import warn -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -24,8 +24,6 @@ class UnitDepthsWidget(BaseWidget): Sign of peak for amplitudes, default 'neg' """ - possible_backends = {} - def __init__( self, waveform_extractor, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs ): @@ -45,7 +43,7 @@ def __init__( unit_amplitudes = get_template_extremum_amplitude(we, peak_sign=peak_sign) unit_amplitudes = np.abs([unit_amplitudes[unit_id] for unit_id in unit_ids]) - num_spikes = np.array(list(we.sorting.get_total_num_spikes().values())) + num_spikes = np.array(list(we.sorting.count_num_spikes_per_unit().values())) plot_data = dict( unit_depths=unit_depths, @@ -56,3 +54,20 @@ def __init__( ) 
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + ax = self.ax + size = dp.num_spikes / max(dp.num_spikes) * 120 + ax.scatter(dp.unit_amplitudes, dp.unit_depths, color=dp.colors, s=size) + + ax.set_aspect(3) + ax.set_xlabel("amplitude") + ax.set_ylabel("depth [um]") + ax.set_xlim(0, max(dp.unit_amplitudes) * 1.2) diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 2c58fdfe45..42267e711f 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -1,7 +1,9 @@ import numpy as np from typing import Union -from .base import BaseWidget +from probeinterface import ProbeGroup + +from .base import BaseWidget, to_attr from .utils import get_unit_colors from ..core.waveform_extractor import WaveformExtractor @@ -31,8 +33,6 @@ class UnitLocationsWidget(BaseWidget): If True, the axis is set to off, default False (matplotlib backend) """ - possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, @@ -62,7 +62,7 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids - plot_data = dict( + data_plot = dict( all_unit_ids=sorting.unit_ids, unit_locations=unit_locations, sorting=sorting, @@ -78,4 +78,185 @@ def __init__( hide_axis=hide_axis, ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from probeinterface.plotting import plot_probe + + from matplotlib.patches import Ellipse + from matplotlib.lines import Line2D + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + unit_locations = dp.unit_locations + + probegroup = ProbeGroup.from_dict(dp.probegroup_dict) + probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) + contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) + + for probe in probegroup.probes: + text_on_contact = None + if dp.with_channel_ids: + text_on_contact = dp.channel_ids + + poly_contact, poly_contour = plot_probe( + probe, + ax=self.ax, + contacts_colors="w", + contacts_kargs=contacts_kargs, + probe_shape_kwargs=probe_shape_kwargs, + text_on_contact=text_on_contact, + ) + poly_contact.set_zorder(2) + if poly_contour is not None: + poly_contour.set_zorder(1) + + self.ax.set_title("") + + width = height = 10 + ellipse_kwargs = dict(width=width, height=height, lw=2) + + if dp.plot_all_units: + unit_colors = {} + unit_ids = dp.all_unit_ids + for unit in dp.all_unit_ids: + if unit not in dp.unit_ids: + unit_colors[unit] = "gray" + else: + unit_colors[unit] = dp.unit_colors[unit] + else: + unit_ids = dp.unit_ids + unit_colors = dp.unit_colors + labels = dp.unit_ids + + patches = [ + Ellipse( + (unit_locations[unit]), + color=unit_colors[unit], + zorder=5 if unit in dp.unit_ids else 3, + alpha=0.9 if unit in dp.unit_ids else 0.5, + **ellipse_kwargs, + ) + for i, unit in enumerate(unit_ids) + ] + for p in patches: + self.ax.add_patch(p) + handles = [ + Line2D([0], [0], 
ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) + for unit in dp.unit_ids + ] + + if dp.plot_legend: + if hasattr(self, "legend") and self.legend is not None: + # if self.legend is not None: + self.legend.remove() + self.legend = self.figure.legend( + handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + if dp.hide_axis: + self.ax.axis("off") + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + + data_plot["unit_ids"] = data_plot["unit_ids"][:1] + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm + ) + + self.controller = unit_controller + + for w in self.controller.values(): + w.observe(self._update_ipywidget) + + self.widget = widgets.AppLayout( + center=fig.canvas, + left_sidebar=unit_widget, + pane_widths=ratios + [0], + ) + + # a first update + self._update_ipywidget(None) + + if backend_kwargs["display"]: + display(self.widget) + + def _update_ipywidget(self, change): + self.ax.clear() + + unit_ids = self.controller["unit_ids"].value + + # matplotlib next_data_plot dict update at each call + data_plot = self.next_data_plot + data_plot["unit_ids"] = unit_ids + data_plot["plot_all_units"] = True + data_plot["plot_legend"] = True + data_plot["hide_axis"] = True + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + self.plot_matplotlib(data_plot, **backend_kwargs) + fig = self.ax.get_figure() + fig.canvas.draw() + fig.canvas.flush_events() + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + + # ensure serializable for sortingview + unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) + + locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} + + unit_items = [] + for unit_id in unit_ids: + unit_items.append( + vv.UnitLocationsItem( + unit_id=unit_id, x=float(dp.unit_locations[unit_id][0]), y=float(dp.unit_locations[unit_id][1]) + ) + ) + + v_unit_locations = vv.UnitLocations(units=unit_items, channel_locations=locations, disable_auto_rotate=True) + + if not dp.hide_unit_selector: + v_units_table = generate_unit_table_view(dp.sorting) + + self.view = vv.Box( + direction="horizontal", + items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_unit_locations)], + ) + else: + self.view = v_unit_locations + + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 8e1ffe2637..964b5813e6 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -1,7 +1,6 @@ import numpy as np -from typing import Union -from .base import BaseWidget +from .base 
import BaseWidget, to_attr from .utils import get_unit_colors @@ -31,7 +30,7 @@ class UnitSummaryWidget(BaseWidget): If WaveformExtractor is already sparse, the argument is ignored """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -48,17 +47,63 @@ def __init__( if unit_colors is None: unit_colors = get_unit_colors(we.sorting) + plot_data = dict( + we=we, + unit_id=unit_id, + unit_colors=unit_colors, + sparsity=sparsity, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + unit_id = dp.unit_id + we = dp.we + unit_colors = dp.unit_colors + sparsity = dp.sparsity + + # force the figure without axes + if "figsize" not in backend_kwargs: + backend_kwargs["figsize"] = (18, 7) + backend_kwargs["num_axes"] = 0 + backend_kwargs["ax"] = None + backend_kwargs["axes"] = None + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + # and use custum grid spec + fig = self.figure + nrows = 2 + ncols = 3 + if we.is_extension("correlograms") or we.is_extension("spike_amplitudes"): + ncols += 1 + if we.is_extension("spike_amplitudes"): + nrows += 1 + gs = fig.add_gridspec(nrows, ncols) + if we.is_extension("unit_locations"): - plot_data_unit_locations = UnitLocationsWidget( - we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False - ).plot_data - unit_locations = waveform_extractor.load_extension("unit_locations").get_data(outputs="by_unit") - unit_location = unit_locations[unit_id] - else: - plot_data_unit_locations = None - unit_location = None + ax1 = fig.add_subplot(gs[:2, 0]) + # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) + w = UnitLocationsWidget( + we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, backend="matplotlib", ax=ax1 + ) - plot_data_waveforms = UnitWaveformsWidget( + unit_locations = we.load_extension("unit_locations").get_data(outputs="by_unit") + unit_location = unit_locations[unit_id] + x, y = unit_location[0], unit_location[1] + ax1.set_xlim(x - 80, x + 80) + ax1.set_ylim(y - 250, y + 250) + ax1.set_xticks([]) + ax1.set_xlabel(None) + ax1.set_ylabel(None) + + ax2 = fig.add_subplot(gs[:2, 1]) + w = UnitWaveformsWidget( we, unit_ids=[unit_id], unit_colors=unit_colors, @@ -66,37 +111,49 @@ def __init__( same_axis=True, plot_legend=False, sparsity=sparsity, - ).plot_data + backend="matplotlib", + ax=ax2, + ) - plot_data_waveform_density = UnitWaveformDensityMapWidget( - we, unit_ids=[unit_id], unit_colors=unit_colors, use_max_channel=True, plot_templates=True, same_axis=False - ).plot_data + ax2.set_title(None) + + ax3 = fig.add_subplot(gs[:2, 2]) + UnitWaveformDensityMapWidget( + we, + unit_ids=[unit_id], + unit_colors=unit_colors, + use_max_channel=True, + same_axis=False, + backend="matplotlib", + ax=ax3, + ) + ax3.set_ylabel(None) if we.is_extension("correlograms"): - plot_data_acc = AutoCorrelogramsWidget( + ax4 = fig.add_subplot(gs[:2, 3]) + AutoCorrelogramsWidget( we, unit_ids=[unit_id], unit_colors=unit_colors, - ).plot_data - else: - plot_data_acc = None + backend="matplotlib", + ax=ax4, + ) - # use other widget to plot data - if we.is_extension("spike_amplitudes"): - plot_data_amplitudes = AmplitudesWidget( - we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, plot_histograms=True - ).plot_data - else: - plot_data_amplitudes = None + ax4.set_title(None) + 
ax4.set_yticks([]) - plot_data = dict( - unit_id=unit_id, - unit_location=unit_location, - plot_data_unit_locations=plot_data_unit_locations, - plot_data_waveforms=plot_data_waveforms, - plot_data_waveform_density=plot_data_waveform_density, - plot_data_acc=plot_data_acc, - plot_data_amplitudes=plot_data_amplitudes, - ) + if we.is_extension("spike_amplitudes"): + ax5 = fig.add_subplot(gs[2, :3]) + ax6 = fig.add_subplot(gs[2, 3]) + axes = np.array([ax5, ax6]) + AmplitudesWidget( + we, + unit_ids=[unit_id], + unit_colors=unit_colors, + plot_legend=False, + plot_histograms=True, + backend="matplotlib", + axes=axes, + ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + fig.suptitle(f"unit_id: {dp.unit_id}") diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index 41c4ece09c..cf58e91aa0 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -1,12 +1,56 @@ from .unit_waveforms import UnitWaveformsWidget +from .base import to_attr class UnitTemplatesWidget(UnitWaveformsWidget): - possible_backends = {} + # doc is copied from UnitWaveformsWidget def __init__(self, *args, **kargs): kargs["plot_waveforms"] = False UnitWaveformsWidget.__init__(self, *args, **kargs) + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + + # ensure serializable for sortingview + unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids + unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices + + unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) + + templates_dict = {} + for u_i, unit in enumerate(unit_ids): + templates_dict[unit] = {} + templates_dict[unit]["mean"] = dp.templates[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]] + templates_dict[unit]["std"] = dp.template_stds[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]] + + aw_items = [ + vv.AverageWaveformItem( + unit_id=u, + channel_ids=list(unit_id_to_channel_ids[u]), + waveform=t["mean"].astype("float32"), + waveform_std_dev=t["std"].astype("float32"), + ) + for u, t in templates_dict.items() + ] + + locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} + v_average_waveforms = vv.AverageWaveforms(average_waveforms=aw_items, channel_locations=locations) + + if not dp.hide_unit_selector: + v_units_table = generate_unit_table_view(dp.waveform_extractor.sorting) + + self.view = vv.Box( + direction="horizontal", + items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_average_waveforms)], + ) + else: + self.view = v_average_waveforms + + self.url = handle_display_and_url(self, self.view, **backend_kwargs) + UnitTemplatesWidget.__doc__ = UnitWaveformsWidget.__doc__ diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index ba707a8221..e64765b44b 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -1,6 +1,6 @@ import numpy as np -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors from ..core import ChannelSparsity @@ -59,8 +59,6 @@ class UnitWaveformsWidget(BaseWidget): Display legend, default True """ - possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, @@ 
-165,6 +163,230 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + if backend_kwargs.get("axes", None) is not None: + assert len(backend_kwargs["axes"]) >= len(dp.unit_ids), "Provide as many 'axes' as neurons" + elif backend_kwargs.get("ax", None) is not None: + assert dp.same_axis, "If 'same_axis' is not used, provide as many 'axes' as neurons" + else: + if dp.same_axis: + backend_kwargs["num_axes"] = 1 + backend_kwargs["ncols"] = None + else: + backend_kwargs["num_axes"] = len(dp.unit_ids) + backend_kwargs["ncols"] = min(dp.ncols, len(dp.unit_ids)) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + for i, unit_id in enumerate(dp.unit_ids): + if dp.same_axis: + ax = self.ax + else: + ax = self.axes.flatten()[i] + color = dp.unit_colors[unit_id] + + chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id] + xvectors_flat = dp.xvectors[:, chan_inds].T.flatten() + + # plot waveforms + if dp.plot_waveforms: + wfs = dp.wfs_by_ids[unit_id] + if dp.unit_selected_waveforms is not None: + wfs = wfs[dp.unit_selected_waveforms[unit_id]] + elif dp.max_spikes_per_unit is not None: + if len(wfs) > dp.max_spikes_per_unit: + random_idxs = np.random.permutation(len(wfs))[: dp.max_spikes_per_unit] + wfs = wfs[random_idxs] + wfs = wfs * dp.y_scale + dp.y_offset[None, :, chan_inds] + wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1).T + + if dp.x_offset_units: + # 0.7 is to match spacing in xvect + xvec = xvectors_flat + i * 0.7 * dp.delta_x + else: + xvec = xvectors_flat + + ax.plot(xvec, wfs_flat, lw=dp.lw_waveforms, alpha=dp.alpha_waveforms, color=color) + + if not dp.plot_templates: + ax.get_lines()[-1].set_label(f"{unit_id}") + + # plot template + if dp.plot_templates: + template = dp.templates[i, :, :][:, chan_inds] * dp.y_scale + dp.y_offset[:, chan_inds] + + if dp.x_offset_units: + # 0.7 is to match spacing in xvect + xvec = xvectors_flat + i * 0.7 * dp.delta_x + else: + xvec = xvectors_flat + + ax.plot( + xvec, template.T.flatten(), lw=dp.lw_templates, alpha=dp.alpha_templates, color=color, label=unit_id + ) + + template_label = dp.unit_ids[i] + if dp.set_title: + ax.set_title(f"template {template_label}") + + # plot channels + if dp.plot_channels: + # TODO enhance this + ax.scatter(dp.channel_locations[:, 0], dp.channel_locations[:, 1], color="k") + + if dp.same_axis and dp.plot_legend: + if hasattr(self, "legend") and self.legend is not None: + self.legend.remove() + self.legend = self.figure.legend( + loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + self.we = we = data_plot["waveform_extractor"] + + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.1, 0.7, 0.2] + + with plt.ioff(): + output1 = widgets.Output() + with output1: + self.fig_wf = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + output2 = widgets.Output() + with output2: + self.fig_probe, self.ax_probe = plt.subplots(figsize=((ratios[2] * width_cm) 
* cm, height_cm * cm)) + plt.show() + + data_plot["unit_ids"] = data_plot["unit_ids"][:1] + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm + ) + + same_axis_button = widgets.Checkbox( + value=False, + description="same axis", + disabled=False, + ) + + plot_templates_button = widgets.Checkbox( + value=True, + description="plot templates", + disabled=False, + ) + + hide_axis_button = widgets.Checkbox( + value=True, + description="hide axis", + disabled=False, + ) + + footer = widgets.HBox([same_axis_button, plot_templates_button, hide_axis_button]) + + self.controller = { + "same_axis": same_axis_button, + "plot_templates": plot_templates_button, + "hide_axis": hide_axis_button, + } + self.controller.update(unit_controller) + + for w in self.controller.values(): + w.observe(self._update_ipywidget) + + self.widget = widgets.AppLayout( + center=self.fig_wf.canvas, + left_sidebar=unit_widget, + right_sidebar=self.fig_probe.canvas, + pane_widths=ratios, + footer=footer, + ) + + # a first update + self._update_ipywidget(None) + + if backend_kwargs["display"]: + display(self.widget) + + def _update_ipywidget(self, change): + self.fig_wf.clear() + self.ax_probe.clear() + + unit_ids = self.controller["unit_ids"].value + same_axis = self.controller["same_axis"].value + plot_templates = self.controller["plot_templates"].value + hide_axis = self.controller["hide_axis"].value + + # matplotlib next_data_plot dict update at each call + data_plot = self.next_data_plot + data_plot["unit_ids"] = unit_ids + data_plot["templates"] = self.we.get_all_templates(unit_ids=unit_ids) + data_plot["template_stds"] = self.we.get_all_templates(unit_ids=unit_ids, mode="std") + data_plot["same_axis"] = same_axis + data_plot["plot_templates"] = plot_templates + if data_plot["plot_waveforms"]: + data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} + + backend_kwargs = {} + + if same_axis: + backend_kwargs["ax"] = self.fig_wf.add_subplot() + data_plot["set_title"] = False + else: + backend_kwargs["figure"] = self.fig_wf + + self.plot_matplotlib(data_plot, **backend_kwargs) + if same_axis: + self.ax.axis("equal") + if hide_axis: + self.ax.axis("off") + else: + if hide_axis: + for i in range(len(unit_ids)): + ax = self.axes.flatten()[i] + ax.axis("off") + + # update probe plot + channel_locations = self.we.get_channel_locations() + self.ax_probe.plot( + channel_locations[:, 0], channel_locations[:, 1], ls="", marker="o", color="gray", markersize=2, alpha=0.5 + ) + self.ax_probe.axis("off") + self.ax_probe.axis("equal") + + for unit in unit_ids: + channel_inds = data_plot["sparsity"].unit_id_to_channel_indices[unit] + self.ax_probe.plot( + channel_locations[channel_inds, 0], + channel_locations[channel_inds, 1], + ls="", + marker="o", + markersize=3, + color=self.next_data_plot["unit_colors"][unit], + ) + self.ax_probe.set_xlim(np.min(channel_locations[:, 0]) - 10, np.max(channel_locations[:, 0]) + 10) + fig_probe = self.ax_probe.get_figure() + + self.fig_wf.canvas.draw() + self.fig_wf.canvas.flush_events() + fig_probe.canvas.draw() + fig_probe.canvas.flush_events() + def get_waveforms_scales(we, templates, channel_locations, x_offset_units=False): """ diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 9f3e5e86b5..e8a6868e92 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ 
b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -1,6 +1,6 @@ import numpy as np -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors from ..core import ChannelSparsity, get_template_extremum_channel @@ -33,8 +33,6 @@ class UnitWaveformDensityMapWidget(BaseWidget): all channel per units, default False """ - possible_backends = {} - def __init__( self, waveform_extractor, @@ -156,3 +154,72 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + if backend_kwargs["axes"] is not None or backend_kwargs["ax"] is not None: + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + else: + if dp.same_axis: + num_axes = 1 + else: + num_axes = len(dp.unit_ids) + backend_kwargs["ncols"] = 1 + backend_kwargs["num_axes"] = num_axes + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + if dp.same_axis: + ax = self.ax + hist2d = dp.all_hist2d + im = ax.imshow( + hist2d.T, + interpolation="nearest", + origin="lower", + aspect="auto", + extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), + cmap="hot", + ) + else: + for unit_index, unit_id in enumerate(dp.unit_ids): + hist2d = dp.all_hist2d[unit_id] + ax = self.axes.flatten()[unit_index] + im = ax.imshow( + hist2d.T, + interpolation="nearest", + origin="lower", + aspect="auto", + extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), + cmap="hot", + ) + + for unit_index, unit_id in enumerate(dp.unit_ids): + if dp.same_axis: + ax = self.ax + else: + ax = self.axes.flatten()[unit_index] + color = dp.unit_colors[unit_id] + ax.plot(dp.templates_flat[unit_id], color=color, lw=1) + + # final cosmetics + for unit_index, unit_id in enumerate(dp.unit_ids): + if dp.same_axis: + ax = self.ax + if unit_index != 0: + continue + else: + ax = self.axes.flatten()[unit_index] + chan_inds = dp.channel_inds[unit_id] + for i, chan_ind in enumerate(chan_inds): + if i != 0: + ax.axvline(i * dp.template_width, color="w", lw=3) + channel_id = dp.channel_ids[chan_ind] + x = i * dp.template_width + dp.template_width // 2 + y = (dp.bin_max + dp.bin_min) / 2.0 + ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center") + + ax.set_xticks([]) + ax.set_ylabel(f"unit_id {unit_id}") diff --git a/src/spikeinterface/widgets/ipywidgets/utils.py b/src/spikeinterface/widgets/utils_ipywidgets.py similarity index 94% rename from src/spikeinterface/widgets/ipywidgets/utils.py rename to src/spikeinterface/widgets/utils_ipywidgets.py index f4b86c3fc2..a7c571d1f0 100644 --- a/src/spikeinterface/widgets/ipywidgets/utils.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -2,6 +2,13 @@ import numpy as np +def check_ipywidget_backend(): + import matplotlib + + mpl_backend = matplotlib.get_backend() + assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" + + def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): time_slider = widgets.FloatSlider( orientation="horizontal", diff --git a/src/spikeinterface/widgets/utils_matplotlib.py b/src/spikeinterface/widgets/utils_matplotlib.py new file mode 100644 index 0000000000..a9128d7b66 --- /dev/null +++ b/src/spikeinterface/widgets/utils_matplotlib.py @@ -0,0 +1,67 @@ +import matplotlib.pyplot as plt +import numpy as np + + +def 
make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None, figsize=None, figtitle=None): + """ + figure/ax/axes : only one of then can be not None + """ + if figure is not None: + assert ax is None and axes is None, "figure/ax/axes : only one of then can be not None" + if num_axes is None: + ax = figure.add_subplot(111) + axes = np.array([[ax]]) + else: + assert ncols is not None + axes = [] + nrows = int(np.ceil(num_axes / ncols)) + axes = np.full((nrows, ncols), fill_value=None, dtype=object) + for i in range(num_axes): + ax = figure.add_subplot(nrows, ncols, i + 1) + r = i // ncols + c = i % ncols + axes[r, c] = ax + elif ax is not None: + assert figure is None and axes is None, "figure/ax/axes : only one of then can be not None" + figure = ax.get_figure() + axes = np.array([[ax]]) + elif axes is not None: + assert figure is None and ax is None, "figure/ax/axes : only one of then can be not None" + axes = np.asarray(axes) + figure = axes.flatten()[0].get_figure() + else: + # 'figure/ax/axes are all None + if num_axes is None: + # one fig with one ax + figure, ax = plt.subplots(figsize=figsize) + axes = np.array([[ax]]) + else: + if num_axes == 0: + # one figure without plots (diffred subplot creation with + figure = plt.figure(figsize=figsize) + ax = None + axes = None + elif num_axes == 1: + figure = plt.figure(figsize=figsize) + ax = figure.add_subplot(111) + axes = np.array([[ax]]) + else: + assert ncols is not None + if num_axes < ncols: + ncols = num_axes + nrows = int(np.ceil(num_axes / ncols)) + figure, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize) + ax = None + # remove extra axes + if ncols * nrows > num_axes: + for i, extra_ax in enumerate(axes.flatten()): + if i >= num_axes: + extra_ax.remove() + r = i // ncols + c = i % ncols + axes[r, c] = None + + if figtitle is not None: + figure.suptitle(figtitle) + + return figure, axes, ax diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py new file mode 100644 index 0000000000..50bbab99df --- /dev/null +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -0,0 +1,85 @@ +import numpy as np + +from ..core.core_tools import check_json + + +def make_serializable(*args): + dict_to_serialize = {int(i): a for i, a in enumerate(args)} + serializable_dict = check_json(dict_to_serialize) + returns = () + for i in range(len(args)): + returns += (serializable_dict[str(i)],) + if len(returns) == 1: + returns = returns[0] + return returns + + +def is_notebook() -> bool: + try: + shell = get_ipython().__class__.__name__ + if shell == "ZMQInteractiveShell": + return True # Jupyter notebook or qtconsole + elif shell == "TerminalInteractiveShell": + return False # Terminal running IPython + else: + return False # Other type (?) 
+ except NameError: + return False + + +def handle_display_and_url(widget, view, **backend_kwargs): + url = None + # TODO: put this back when figurl-jupyter is working again + # if is_notebook() and backend_kwargs["display"]: + # display(view.jupyter(height=backend_kwargs["height"])) + if backend_kwargs["generate_url"]: + figlabel = backend_kwargs.get("figlabel") + if figlabel is None: + # figlabel = widget.default_label + figlabel = "" + url = view.url(label=figlabel) + print(url) + + return url + + +def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=None): + import sortingview.views as vv + + if unit_properties is None: + ut_columns = [] + ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids] + else: + ut_columns = [] + ut_rows = [] + values = {} + valid_unit_properties = [] + for prop_name in unit_properties: + property_values = sorting.get_property(prop_name) + # make dtype available + val0 = np.array(property_values[0]) + if val0.dtype.kind in ("i", "u"): + dtype = "int" + elif val0.dtype.kind in ("U", "S"): + dtype = "str" + elif val0.dtype.kind == "f": + dtype = "float" + elif val0.dtype.kind == "b": + dtype = "bool" + else: + print(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping") + continue + ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype)) + valid_unit_properties.append(prop_name) + + for ui, unit in enumerate(sorting.unit_ids): + for prop_name in valid_unit_properties: + property_values = sorting.get_property(prop_name) + val0 = property_values[0] + if np.isnan(property_values[ui]): + continue + values[prop_name] = property_values[ui] + ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values))) + + v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores) + return v_units_table diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index a6e0896e99..f3c640ff16 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -1,80 +1,46 @@ -from .base import define_widget_function_from_class +import warnings -# basics -from .timeseries import TimeseriesWidget +from .base import backend_kwargs_desc -# waveform -from .unit_waveforms import UnitWaveformsWidget -from .unit_templates import UnitTemplatesWidget -from .unit_waveforms_density_map import UnitWaveformDensityMapWidget - -# isi/ccg/acg +from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget +from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget from .crosscorrelograms import CrossCorrelogramsWidget - -# peak activity - -# drift/motion - -# spikes-traces -from .spikes_on_traces import SpikesOnTracesWidget - -# PC related - -# units on probe -from .unit_locations import UnitLocationsWidget -from .spike_locations import SpikeLocationsWidget - -# unit presence - - -# comparison related - -# correlogram comparison - -# amplitudes -from .amplitudes import AmplitudesWidget -from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget - -# metrics +from .motion import MotionWidget from .quality_metrics import QualityMetricsWidget +from .sorting_summary import SortingSummaryWidget +from .spike_locations import SpikeLocationsWidget +from .spikes_on_traces import SpikesOnTracesWidget from .template_metrics import TemplateMetricsWidget - - -# motion/drift -from .motion import MotionWidget - -# similarity from 
 from .template_similarity import TemplateSimilarityWidget
-
-
+from .traces import TracesWidget
 from .unit_depths import UnitDepthsWidget
-
-# summary
+from .unit_locations import UnitLocationsWidget
 from .unit_summary import UnitSummaryWidget
-from .sorting_summary import SortingSummaryWidget
+from .unit_templates import UnitTemplatesWidget
+from .unit_waveforms_density_map import UnitWaveformDensityMapWidget
+from .unit_waveforms import UnitWaveformsWidget
 
 
 widget_list = [
-    AmplitudesWidget,
     AllAmplitudesDistributionsWidget,
+    AmplitudesWidget,
     AutoCorrelogramsWidget,
     CrossCorrelogramsWidget,
+    MotionWidget,
     QualityMetricsWidget,
+    SortingSummaryWidget,
     SpikeLocationsWidget,
     SpikesOnTracesWidget,
     TemplateMetricsWidget,
-    MotionWidget,
     TemplateSimilarityWidget,
-    TimeseriesWidget,
+    TracesWidget,
+    UnitDepthsWidget,
     UnitLocationsWidget,
+    UnitSummaryWidget,
     UnitTemplatesWidget,
-    UnitWaveformsWidget,
     UnitWaveformDensityMapWidget,
-    UnitDepthsWidget,
-    # summary
-    UnitSummaryWidget,
-    SortingSummaryWidget,
+    UnitWaveformsWidget,
 ]
@@ -89,37 +55,41 @@
     **backend_kwargs: kwargs
     {backend_kwargs}
     """
-    backend_str = f" {list(wcls.possible_backends.keys())}"
+    # backend_str = f" {list(wcls.possible_backends.keys())}"
+    backend_str = f" {wcls.get_possible_backends()}"
     backend_kwargs_str = ""
-    for backend, backend_plotter in wcls.possible_backends.items():
-        backend_kwargs_desc = backend_plotter.backend_kwargs_desc
-        if len(backend_kwargs_desc) > 0:
+    # for backend, backend_plotter in wcls.possible_backends.items():
+    for backend in wcls.get_possible_backends():
+        # backend_kwargs_desc = backend_plotter.backend_kwargs_desc
+        kwargs_desc = backend_kwargs_desc[backend]
+        if len(kwargs_desc) > 0:
            backend_kwargs_str += f"\n {backend}:\n\n"
-            for bk, bk_dsc in backend_kwargs_desc.items():
+            for bk, bk_dsc in kwargs_desc.items():
                backend_kwargs_str += f" * {bk}: {bk_dsc}\n"
     wcls.__doc__ = wcls_doc.format(backends=backend_str, backend_kwargs=backend_kwargs_str)
 
 
 # make function for all widgets
-plot_amplitudes = define_widget_function_from_class(AmplitudesWidget, "plot_amplitudes")
-plot_all_amplitudes_distributions = define_widget_function_from_class(
-    AllAmplitudesDistributionsWidget, "plot_all_amplitudes_distributions"
-)
-plot_autocorrelograms = define_widget_function_from_class(AutoCorrelogramsWidget, "plot_autocorrelograms")
-plot_crosscorrelograms = define_widget_function_from_class(CrossCorrelogramsWidget, "plot_crosscorrelograms")
-plot_quality_metrics = define_widget_function_from_class(QualityMetricsWidget, "plot_quality_metrics")
-plot_spike_locations = define_widget_function_from_class(SpikeLocationsWidget, "plot_spike_locations")
-plot_spikes_on_traces = define_widget_function_from_class(SpikesOnTracesWidget, "plot_spikes_on_traces")
-plot_template_metrics = define_widget_function_from_class(TemplateMetricsWidget, "plot_template_metrics")
-plot_motion = define_widget_function_from_class(MotionWidget, "plot_motion")
-plot_template_similarity = define_widget_function_from_class(TemplateSimilarityWidget, "plot_template_similarity")
-plot_timeseries = define_widget_function_from_class(TimeseriesWidget, "plot_timeseries")
-plot_unit_locations = define_widget_function_from_class(UnitLocationsWidget, "plot_unit_locations")
-plot_unit_templates = define_widget_function_from_class(UnitTemplatesWidget, "plot_unit_templates")
-plot_unit_waveforms = define_widget_function_from_class(UnitWaveformsWidget, "plot_unit_waveforms")
-plot_unit_waveforms_density_map = define_widget_function_from_class(
-    UnitWaveformDensityMapWidget, "plot_unit_waveforms_density_map"
-)
-plot_unit_depths = define_widget_function_from_class(UnitDepthsWidget, "plot_unit_depths")
-plot_unit_summary = define_widget_function_from_class(UnitSummaryWidget, "plot_unit_summary")
-plot_sorting_summary = define_widget_function_from_class(SortingSummaryWidget, "plot_sorting_summary")
+plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget
+plot_amplitudes = AmplitudesWidget
+plot_autocorrelograms = AutoCorrelogramsWidget
+plot_crosscorrelograms = CrossCorrelogramsWidget
+plot_motion = MotionWidget
+plot_quality_metrics = QualityMetricsWidget
+plot_sorting_summary = SortingSummaryWidget
+plot_spike_locations = SpikeLocationsWidget
+plot_spikes_on_traces = SpikesOnTracesWidget
+plot_template_metrics = TemplateMetricsWidget
+plot_template_similarity = TemplateSimilarityWidget
+plot_traces = TracesWidget
+plot_unit_depths = UnitDepthsWidget
+plot_unit_locations = UnitLocationsWidget
+plot_unit_summary = UnitSummaryWidget
+plot_unit_templates = UnitTemplatesWidget
+plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget
+plot_unit_waveforms = UnitWaveformsWidget
+
+
+def plot_timeseries(*args, **kwargs):
+    warnings.warn("plot_timeseries() is now plot_traces()")
+    return plot_traces(*args, **kwargs)
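
For illustration only (not part of the diff), a minimal usage sketch of the new make_mpl_figure() helper added above. The import path spikeinterface.widgets.utils_matplotlib is an assumption, since the file header for that hunk is not visible here; the return order (figure, axes, ax), the grid layout, and the removal of unused axes follow the function body shown in the diff.

    # Illustrative sketch; the module path below is assumed, not confirmed by this diff.
    import numpy as np
    import matplotlib.pyplot as plt
    from spikeinterface.widgets.utils_matplotlib import make_mpl_figure  # assumed path

    # Grid mode: 5 axes on 3 columns -> a 2 x 3 object array, with the unused slot removed and set to None.
    figure, axes, ax = make_mpl_figure(num_axes=5, ncols=3, figsize=(10, 6), figtitle="Per-unit plots")
    for i, sub_ax in enumerate(a for a in axes.flatten() if a is not None):
        sub_ax.plot(np.random.randn(100))
        sub_ax.set_title(f"unit {i}")

    # Wrap-an-existing-axis mode: the figure and a (1, 1) axes array are derived from the given axis.
    _, existing_ax = plt.subplots()
    figure2, axes2, ax2 = make_mpl_figure(ax=existing_ax)
    assert axes2.shape == (1, 1) and ax2 is existing_ax
    plt.show()

On the public-API side, the same changes drop the define_widget_function_from_class() indirection in favor of plain aliases (e.g. plot_traces = TracesWidget) and keep plot_timeseries() only as a thin shim that warns and forwards to plot_traces().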