diff --git a/.github/actions/install-wine/action.yml b/.github/actions/install-wine/action.yml new file mode 100644 index 0000000000..3ae08ecd34 --- /dev/null +++ b/.github/actions/install-wine/action.yml @@ -0,0 +1,21 @@ +name: Install packages +description: This action installs the package and its dependencies for testing + +inputs: + python-version: + description: 'Python version to set up' + required: false + os: + description: 'Operating system to set up' + required: false + +runs: + using: "composite" + steps: + - name: Install wine (needed for Plexon2) + run: | + sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list + sudo dpkg --add-architecture i386 + sudo apt-get update -qq + sudo apt-get install -yqq --allow-downgrades libc6:i386 libgcc-s1:i386 libstdc++6:i386 wine + shell: bash diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index ac5130bade..8f88e84039 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -75,6 +75,10 @@ jobs: echo "Extractors changed" echo "EXTRACTORS_CHANGED=true" >> $GITHUB_OUTPUT fi + if [[ $file == *"plexon2"* ]]; then + echo "Plexon2 changed" + echo "PLEXON2_CHANGED=true" >> $GITHUB_OUTPUT + fi if [[ $file == *"/preprocessing/"* ]]; then echo "Preprocessing changed" echo "PREPROCESSING_CHANGED=true" >> $GITHUB_OUTPUT @@ -122,6 +126,9 @@ jobs: done - name: Set execute permissions on run_tests.sh run: chmod +x .github/run_tests.sh + - name: Install Wine (Plexon2) + if: ${{ steps.modules-changed.outputs.PLEXON2_CHANGED == 'true' }} + uses: ./.github/actions/install-wine - name: Test core run: ./.github/run_tests.sh core - name: Test extractors diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml new file mode 100644 index 0000000000..0e522e6baa --- /dev/null +++ b/.github/workflows/installation-tips-test.yml @@ -0,0 +1,33 @@ +name: Creates Conda Install for Installation Tips + +on: + workflow_dispatch: + schedule: + - cron: "0 12 * * 0" # Weekly at noon UTC on Sundays + +jobs: + installation-tips-testing: + name: Build Conda Env on ${{ matrix.os }} OS + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -el {0} + strategy: + fail-fast: false + matrix: + include: + - os: ubuntu-latest + label: linux_dandi + - os: macos-latest + label: mac + - os: windows-latest + label: windows + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Test Conda Environment Creation + uses: conda-incubator/setup-miniconda@v2.2.0 + with: + environment-file: ./installations_tips/full_spikeinterface_environment_${{ matrix.label }}.yml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ced1ee6a2f..07601cd208 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 23.9.1 hooks: - id: black files: ^src/ diff --git a/doc/api.rst b/doc/api.rst index 2e9fc1567a..43f79386e6 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -91,17 +91,21 @@ NEO-based .. autofunction:: read_mcsraw .. autofunction:: read_neuralynx .. autofunction:: read_neuralynx_sorting + .. autofunction:: read_neuroexplorer .. autofunction:: read_neuroscope .. autofunction:: read_nix .. autofunction:: read_openephys .. autofunction:: read_openephys_event .. autofunction:: read_plexon .. autofunction:: read_plexon_sorting + .. autofunction:: read_plexon2 + .. 
autofunction:: read_plexon2_sorting .. autofunction:: read_spike2 .. autofunction:: read_spikegadgets .. autofunction:: read_spikeglx .. autofunction:: read_tdt + Non-NEO-based ~~~~~~~~~~~~~ .. automodule:: spikeinterface.extractors diff --git a/doc/conf.py b/doc/conf.py index 847de9ff42..15cb65d46a 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -67,6 +67,8 @@ 'numpydoc', "sphinx.ext.intersphinx", "sphinx.ext.extlinks", + "IPython.sphinxext.ipython_directive", + "IPython.sphinxext.ipython_console_highlighting" ] numpydoc_show_class_members = False diff --git a/doc/development/development.rst b/doc/development/development.rst index cd613a27e6..f1371639c3 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -1,5 +1,5 @@ Development -========== +=========== How to contribute ----------------- diff --git a/doc/how_to/analyse_neuropixels.rst b/doc/how_to/analyse_neuropixels.rst index c921b13719..37646c2146 100644 --- a/doc/how_to/analyse_neuropixels.rst +++ b/doc/how_to/analyse_neuropixels.rst @@ -4,11 +4,11 @@ Analyse Neuropixels datasets This example shows how to perform Neuropixels-specific analysis, including custom pre- and post-processing. -.. code:: ipython3 +.. code:: ipython %matplotlib inline -.. code:: ipython3 +.. code:: ipython import spikeinterface.full as si @@ -16,7 +16,7 @@ including custom pre- and post-processing. import matplotlib.pyplot as plt from pathlib import Path -.. code:: ipython3 +.. code:: ipython base_folder = Path('/mnt/data/sam/DataSpikeSorting/neuropixel_example/') @@ -29,7 +29,7 @@ Read the data The ``SpikeGLX`` folder can contain several “streams” (AP, LF and NIDQ). We need to specify which one to read: -.. code:: ipython3 +.. code:: ipython stream_names, stream_ids = si.get_neo_streams('spikeglx', spikeglx_folder) stream_names @@ -43,7 +43,7 @@ We need to specify which one to read: -.. code:: ipython3 +.. code:: ipython # we do not load the sync channel, so the probe is automatically loaded raw_rec = si.read_spikeglx(spikeglx_folder, stream_name='imec0.ap', load_sync_channel=False) @@ -58,7 +58,7 @@ We need to specify which one to read: -.. code:: ipython3 +.. code:: ipython # we automaticaly have the probe loaded! raw_rec.get_probe().to_dataframe() @@ -201,7 +201,7 @@ We need to specify which one to read: -.. code:: ipython3 +.. code:: ipython fig, ax = plt.subplots(figsize=(15, 10)) si.plot_probe_map(raw_rec, ax=ax, with_channel_ids=True) @@ -229,7 +229,7 @@ Let’s do something similar to the IBL destriping chain (See - instead of interpolating bad channels, we remove then. - instead of highpass_spatial_filter() we use common_reference() -.. code:: ipython3 +.. code:: ipython rec1 = si.highpass_filter(raw_rec, freq_min=400.) bad_channel_ids, channel_labels = si.detect_bad_channels(rec1) @@ -271,7 +271,7 @@ preprocessing chain wihtout to save the entire file to disk. Everything is lazy, so you can change the previsous cell (parameters, step order, …) and visualize it immediatly. -.. code:: ipython3 +.. code:: ipython # here we use static plot using matplotlib backend fig, axs = plt.subplots(ncols=3, figsize=(20, 10)) @@ -287,7 +287,7 @@ is lazy, so you can change the previsous cell (parameters, step order, .. image:: analyse_neuropixels_files/analyse_neuropixels_13_0.png -.. code:: ipython3 +.. code:: ipython # plot some channels fig, ax = plt.subplots(figsize=(20, 10)) @@ -326,7 +326,7 @@ Depending on the complexity of the preprocessing chain, this operation can take a while. 
However, we can make use of the powerful parallelization mechanism of SpikeInterface. -.. code:: ipython3 +.. code:: ipython job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) @@ -344,7 +344,7 @@ parallelization mechanism of SpikeInterface. write_binary_recording: 0%| | 0/1139 [00:00 0.9) -.. code:: ipython3 +.. code:: ipython keep_units = metrics.query(our_query) keep_unit_ids = keep_units.index.values @@ -1071,11 +1071,11 @@ In order to export the final results we need to make a copy of the the waveforms, but only for the selected units (so we can avoid to compute them again). -.. code:: ipython3 +.. code:: ipython we_clean = we.select_units(keep_unit_ids, new_folder=base_folder / 'waveforms_clean') -.. code:: ipython3 +.. code:: ipython we_clean @@ -1091,12 +1091,12 @@ them again). Then we export figures to a report folder -.. code:: ipython3 +.. code:: ipython # export spike sorting report to a folder si.export_report(we_clean, base_folder / 'report', format='png') -.. code:: ipython3 +.. code:: ipython we_clean = si.load_waveforms(base_folder / 'waveforms_clean') we_clean diff --git a/doc/how_to/get_started.rst b/doc/how_to/get_started.rst index a235eb4272..a923393916 100644 --- a/doc/how_to/get_started.rst +++ b/doc/how_to/get_started.rst @@ -11,7 +11,7 @@ dataset, and we will then perform some pre-processing, run a spike sorting algorithm, post-process the spike sorting output, perform curation (manual and automatic), and compare spike sorting results. -.. code:: ipython3 +.. code:: ipython import matplotlib.pyplot as plt from pprint import pprint @@ -19,7 +19,7 @@ curation (manual and automatic), and compare spike sorting results. The spikeinterface module by itself import only the spikeinterface.core submodule which is not useful for end user -.. code:: ipython3 +.. code:: ipython import spikeinterface @@ -35,7 +35,7 @@ We need to import one by one different submodules separately - ``comparison`` : comparison of spike sorting output - ``widgets`` : visualization -.. code:: ipython3 +.. code:: ipython import spikeinterface as si # import core only import spikeinterface.extractors as se @@ -56,14 +56,14 @@ This is useful for notebooks, but it is a heavier import because internally many more dependencies are imported (scipy/sklearn/networkx/matplotlib/h5py…) -.. code:: ipython3 +.. code:: ipython import spikeinterface.full as si Before getting started, we can set some global arguments for parallel processing. For this example, let’s use 4 jobs and time chunks of 1s: -.. code:: ipython3 +.. code:: ipython global_job_kwargs = dict(n_jobs=4, chunk_duration="1s") si.set_global_job_kwargs(**global_job_kwargs) @@ -75,7 +75,7 @@ Then we can open it. Note that `MEArec `__ simulated file contains both “recording” and a “sorting” object. -.. code:: ipython3 +.. code:: ipython local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') recording, sorting_true = se.read_mearec(local_path) @@ -102,7 +102,7 @@ ground-truth information of the spiking activity of each unit. Let’s use the ``spikeinterface.widgets`` module to visualize the traces and the raster plots. -.. code:: ipython3 +.. code:: ipython w_ts = sw.plot_traces(recording, time_range=(0, 5)) w_rs = sw.plot_rasters(sorting_true, time_range=(0, 5)) @@ -118,7 +118,7 @@ and the raster plots. This is how you retrieve info from a ``BaseRecording``\ … -.. code:: ipython3 +.. 
code:: ipython channel_ids = recording.get_channel_ids() fs = recording.get_sampling_frequency() @@ -143,7 +143,7 @@ This is how you retrieve info from a ``BaseRecording``\ … …and a ``BaseSorting`` -.. code:: ipython3 +.. code:: ipython num_seg = recording.get_num_segments() unit_ids = sorting_true.get_unit_ids() @@ -173,7 +173,7 @@ any probe in the probeinterface collections can be downloaded and set to a ``Recording`` object. In this case, the MEArec dataset already handles a ``Probe`` and we don’t need to set it *manually*. -.. code:: ipython3 +.. code:: ipython probe = recording.get_probe() print(probe) @@ -200,7 +200,7 @@ All these preprocessing steps are “lazy”. The computation is done on demand when we call ``recording.get_traces(...)`` or when we save the object to disk. -.. code:: ipython3 +.. code:: ipython recording_cmr = recording recording_f = si.bandpass_filter(recording, freq_min=300, freq_max=6000) @@ -224,7 +224,7 @@ Now you are ready to spike sort using the ``spikeinterface.sorters`` module! Let’s first check which sorters are implemented and which are installed -.. code:: ipython3 +.. code:: ipython print('Available sorters', ss.available_sorters()) print('Installed sorters', ss.installed_sorters()) @@ -241,7 +241,7 @@ machine. We can see we have HerdingSpikes and Tridesclous installed. Spike sorters come with a set of parameters that users can change. The available parameters are dictionaries and can be accessed with: -.. code:: ipython3 +.. code:: ipython print("Tridesclous params:") pprint(ss.get_default_sorter_params('tridesclous')) @@ -279,7 +279,7 @@ available parameters are dictionaries and can be accessed with: Let’s run ``tridesclous`` and change one of the parameter, say, the ``detect_threshold``: -.. code:: ipython3 +.. code:: ipython sorting_TDC = ss.run_sorter(sorter_name="tridesclous", recording=recording_preprocessed, detect_threshold=4) print(sorting_TDC) @@ -292,7 +292,7 @@ Let’s run ``tridesclous`` and change one of the parameter, say, the Alternatively we can pass full dictionary containing the parameters: -.. code:: ipython3 +.. code:: ipython other_params = ss.get_default_sorter_params('tridesclous') other_params['detect_threshold'] = 6 @@ -310,7 +310,7 @@ Alternatively we can pass full dictionary containing the parameters: Let’s run ``spykingcircus2`` as well, with default parameters: -.. code:: ipython3 +.. code:: ipython sorting_SC2 = ss.run_sorter(sorter_name="spykingcircus2", recording=recording_preprocessed) print(sorting_SC2) @@ -341,7 +341,7 @@ If a sorter is not installed locally, we can also avoid to install it and run it anyways, using a container (Docker or Singularity). For example, let’s run ``Kilosort2`` using Docker: -.. code:: ipython3 +.. code:: ipython sorting_KS2 = ss.run_sorter(sorter_name="kilosort2", recording=recording_preprocessed, docker_image=True, verbose=True) @@ -370,7 +370,7 @@ extracts, their waveforms, and stores them to disk. These waveforms are helpful to compute the average waveform, or “template”, for each unit and then to compute, for example, quality metrics. -.. code:: ipython3 +.. code:: ipython we_TDC = si.extract_waveforms(recording_preprocessed, sorting_TDC, 'waveforms_folder', overwrite=True) print(we_TDC) @@ -399,7 +399,7 @@ compute spike amplitudes, PCA projections, unit locations, and more. Let’s compute some postprocessing information that will be needed later for computing quality metrics, exporting, and visualization: -.. code:: ipython3 +.. 
code:: ipython amplitudes = spost.compute_spike_amplitudes(we_TDC) unit_locations = spost.compute_unit_locations(we_TDC) @@ -411,7 +411,7 @@ for computing quality metrics, exporting, and visualization: All of this postprocessing functions are saved in the waveforms folder as extensions: -.. code:: ipython3 +.. code:: ipython print(we_TDC.get_available_extension_names()) @@ -424,7 +424,7 @@ as extensions: Importantly, waveform extractors (and all extensions) can be reloaded at later times: -.. code:: ipython3 +.. code:: ipython we_loaded = si.load_waveforms('waveforms_folder') print(we_loaded.get_available_extension_names()) @@ -439,7 +439,7 @@ Once we have computed all these postprocessing information, we can compute quality metrics (different quality metrics require different extensions - e.g., drift metrics resuire ``spike_locations``): -.. code:: ipython3 +.. code:: ipython qm_params = sqm.get_default_qm_params() pprint(qm_params) @@ -485,14 +485,14 @@ extensions - e.g., drift metrics resuire ``spike_locations``): Since the recording is very short, let’s change some parameters to accomodate the duration: -.. code:: ipython3 +.. code:: ipython qm_params["presence_ratio"]["bin_duration_s"] = 1 qm_params["amplitude_cutoff"]["num_histogram_bins"] = 5 qm_params["drift"]["interval_s"] = 2 qm_params["drift"]["min_spikes_per_interval"] = 2 -.. code:: ipython3 +.. code:: ipython qm = sqm.compute_quality_metrics(we_TDC, qm_params=qm_params) display(qm) @@ -522,7 +522,7 @@ We can export a sorting summary and quality metrics plot using the ``sortingview`` backend. This will generate shareble links for web-based visualization. -.. code:: ipython3 +.. code:: ipython w1 = sw.plot_quality_metrics(we_TDC, display=False, backend="sortingview") @@ -530,7 +530,7 @@ visualization. https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://901a11ba31ae9ab512a99bdf36a3874173249d87&label=SpikeInterface%20-%20Quality%20Metrics -.. code:: ipython3 +.. code:: ipython w2 = sw.plot_sorting_summary(we_TDC, display=False, curation=True, backend="sortingview") @@ -543,7 +543,7 @@ curation. In the example above, we manually merged two units (0, 4) and added accept labels (2, 6, 7). After applying our curation, we can click on the “Save as snapshot (sha://)” and copy the URI: -.. code:: ipython3 +.. code:: ipython uri = "sha1://68cb54a9aaed2303fb82dedbc302c853e818f1b6" @@ -562,7 +562,7 @@ Alternatively, we can export the data locally to Phy. `Phy `_ is a GUI for manual curation of the spike sorting output. To export to phy you can run: -.. code:: ipython3 +.. code:: ipython sexp.export_to_phy(we_TDC, 'phy_folder_for_TDC', verbose=True) @@ -581,7 +581,7 @@ After curating with Phy, the curated sorting can be reloaded to SpikeInterface. In this case, we exclude the units that have been labeled as “noise”: -.. code:: ipython3 +.. code:: ipython sorting_curated_phy = se.read_phy('phy_folder_for_TDC', exclude_cluster_groups=["noise"]) @@ -589,7 +589,7 @@ Quality metrics can be also used to automatically curate the spike sorting output. For example, you can select sorted units with a SNR above a certain threshold: -.. code:: ipython3 +.. code:: ipython keep_mask = (qm['snr'] > 10) & (qm['isi_violations_ratio'] < 0.01) print("Mask:", keep_mask.values) @@ -615,7 +615,7 @@ outputs. We can either: 3. compare the output of multiple sorters (Tridesclous, SpykingCircus2, Kilosort2) -.. code:: ipython3 +.. 
code:: ipython comp_gt = sc.compare_sorter_to_ground_truth(gt_sorting=sorting_true, tested_sorting=sorting_TDC) comp_pair = sc.compare_two_sorters(sorting1=sorting_TDC, sorting2=sorting_SC2) @@ -625,7 +625,7 @@ outputs. We can either: When comparing with a ground-truth sorting (1,), you can get the sorting performance and plot a confusion matrix -.. code:: ipython3 +.. code:: ipython print(comp_gt.get_performance()) w_conf = sw.plot_confusion_matrix(comp_gt) @@ -659,7 +659,7 @@ performance and plot a confusion matrix When comparing two sorters (2.), we can see the matching of units between sorters. Units which are not matched has -1 as unit id: -.. code:: ipython3 +.. code:: ipython comp_pair.hungarian_match_12 @@ -683,7 +683,7 @@ between sorters. Units which are not matched has -1 as unit id: or the reverse: -.. code:: ipython3 +.. code:: ipython comp_pair.hungarian_match_21 @@ -709,7 +709,7 @@ When comparing multiple sorters (3.), you can extract a ``BaseSorting`` object with units in agreement between sorters. You can also plot a graph showing how the units are matched between the sorters. -.. code:: ipython3 +.. code:: ipython sorting_agreement = comp_multi.get_agreement_sorting(minimum_agreement_count=2) diff --git a/doc/how_to/handle_drift.rst b/doc/how_to/handle_drift.rst index 7ff98a666b..5c4476187b 100644 --- a/doc/how_to/handle_drift.rst +++ b/doc/how_to/handle_drift.rst @@ -1,4 +1,4 @@ -.. code:: ipython3 +.. code:: ipython %matplotlib inline %load_ext autoreload @@ -42,7 +42,7 @@ Neuropixels 1 and a Neuropixels 2 probe. Here we will use *dataset1* with *neuropixel1*. This dataset is the *“hello world”* for drift correction in the spike sorting community! -.. code:: ipython3 +.. code:: ipython from pathlib import Path import matplotlib.pyplot as plt @@ -52,12 +52,12 @@ Here we will use *dataset1* with *neuropixel1*. This dataset is the import spikeinterface.full as si -.. code:: ipython3 +.. code:: ipython base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick') dataset_folder = base_folder / 'dataset1/NP1' -.. code:: ipython3 +.. code:: ipython # read the file raw_rec = si.read_spikeglx(dataset_folder) @@ -77,7 +77,7 @@ We preprocess the recording with bandpass filter and a common median reference. Note, that it is better to not whiten the recording before motion estimation to get a better estimate of peak locations! -.. code:: ipython3 +.. code:: ipython def preprocess_chain(rec): rec = si.bandpass_filter(rec, freq_min=300., freq_max=6000.) @@ -85,7 +85,7 @@ motion estimation to get a better estimate of peak locations! return rec rec = preprocess_chain(raw_rec) -.. code:: ipython3 +.. code:: ipython job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) @@ -101,7 +101,7 @@ parameters for each step. Here we also save the motion correction results into a folder to be able to load them later. -.. code:: ipython3 +.. code:: ipython # internally, we can explore a preset like this # every parameter can be overwritten at runtime @@ -143,13 +143,13 @@ to load them later. -.. code:: ipython3 +.. code:: ipython # lets try theses 3 presets some_presets = ('rigid_fast', 'kilosort_like', 'nonrigid_accurate') # some_presets = ('nonrigid_accurate', ) -.. code:: ipython3 +.. code:: ipython # compute motion with 3 presets for preset in some_presets: @@ -195,7 +195,7 @@ A few comments on the figures: (2000-3000um) experience some drift, but the lower part (0-1000um) is relatively stable. The method defined by this preset is able to capture this. -.. 
code:: ipython3 +.. code:: ipython for preset in some_presets: # load @@ -243,7 +243,7 @@ locations (:py:func:`correct_motion_on_peaks()`) Case 1 is used before running a spike sorter and the case 2 is used here to display the results. -.. code:: ipython3 +.. code:: ipython from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks @@ -303,7 +303,7 @@ run times Presets and related methods have differents accuracies but also computation speeds. It is good to have this in mind! -.. code:: ipython3 +.. code:: ipython run_times = [] for preset in some_presets: diff --git a/doc/modules/extractors.rst b/doc/modules/extractors.rst index a6752e2c9d..5aed24ca41 100644 --- a/doc/modules/extractors.rst +++ b/doc/modules/extractors.rst @@ -129,13 +129,15 @@ For raw recording formats, we currently support: * **MCS RAW** :py:func:`~spikeinterface.extractors.read_mcsraw()` * **MEArec** :py:func:`~spikeinterface.extractors.read_mearec()` * **Mountainsort MDA** :py:func:`~spikeinterface.extractors.read_mda_recording()` +* **Neuralynx** :py:func:`~spikeinterface.extractors.read_neuralynx()` * **Neurodata Without Borders** :py:func:`~spikeinterface.extractors.read_nwb_recording()` * **Neuroscope** :py:func:`~spikeinterface.coextractorsre.read_neuroscope_recording()` +* **Neuroexplorer** :py:func:`~spikeinterface.extractors.read_neuroexplorer()` * **NIX** :py:func:`~spikeinterface.extractors.read_nix()` -* **Neuralynx** :py:func:`~spikeinterface.extractors.read_neuralynx()` * **Open Ephys Legacy** :py:func:`~spikeinterface.extractors.read_openephys()` * **Open Ephys Binary** :py:func:`~spikeinterface.extractors.read_openephys()` -* **Plexon** :py:func:`~spikeinterface.corextractorse.read_plexon()` +* **Plexon** :py:func:`~spikeinterface.extractors.read_plexon()` +* **Plexon 2** :py:func:`~spikeinterface.extractors.read_plexon2()` * **Shybrid** :py:func:`~spikeinterface.extractors.read_shybrid_recording()` * **SpikeGLX** :py:func:`~spikeinterface.extractors.read_spikeglx()` * **SpikeGLX IBL compressed** :py:func:`~spikeinterface.extractors.read_cbin_ibl()` @@ -165,6 +167,7 @@ For sorted data formats, we currently support: * **Neuralynx spikes** :py:func:`~spikeinterface.extractors.read_neuralynx_sorting()` * **NPZ (created by SpikeInterface)** :py:func:`~spikeinterface.core.read_npz_sorting()` * **Plexon spikes** :py:func:`~spikeinterface.extractors.read_plexon_sorting()` +* **Plexon 2 spikes** :py:func:`~spikeinterface.extractors.read_plexon2_sorting()` * **Shybrid** :py:func:`~spikeinterface.extractors.read_shybrid_sorting()` * **Spyking Circus** :py:func:`~spikeinterface.extractors.read_spykingcircus()` * **Trideclous** :py:func:`~spikeinterface.extractors.read_tridesclous()` diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index 62c0d6b8d4..afedc4f982 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -9,7 +9,7 @@ Overview Mechanical drift, often observed in recordings, is currently a major issue for spike sorting. This is especially striking with the new generation of high-density devices used for in-vivo electrophyisology such as the neuropixel electrodes. 
-The first sorter that has introduced motion/drift correction as a prepossessing step was Kilosort2.5 (see [Steinmetz2021]_) +The first sorter that has introduced motion/drift correction as a prepossessing step was Kilosort2.5 (see [Steinmetz2021]_ [SteinmetzDataset]_) Long story short, the main idea is the same as the one used for non-rigid image registration, for example with calcium imaging. However, because with extracellular recording we do not have a proper image to use as a reference, the main idea diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index ee3234af6c..8c7c0a2cc3 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -38,6 +38,7 @@ For more details about each metric and it's availability and use within SpikeInt qualitymetrics/snr qualitymetrics/noise_cutoff qualitymetrics/silhouette_score + qualitymetrics/synchrony This code snippet shows how to compute quality metrics (with or without principal components) in SpikeInterface: diff --git a/doc/modules/qualitymetrics/silhouette_score.rst b/doc/modules/qualitymetrics/silhouette_score.rst index 275805c6a7..b924cdbf73 100644 --- a/doc/modules/qualitymetrics/silhouette_score.rst +++ b/doc/modules/qualitymetrics/silhouette_score.rst @@ -1,3 +1,5 @@ +.. _silhouette_score : + Silhouette score (:code:`silhouette`, :code:`silhouette_full`) ============================================================== @@ -7,7 +9,7 @@ Calculation Gives the ratio between the cohesiveness of a cluster and its separation from other clusters. Values for silhouette score range from -1 to 1. -For the full method as proposed by [Rouseeuw]_, the pairwise distances between each point +For the full method as proposed by [Rousseeuw]_, the pairwise distances between each point and every other point :math:`a(i)` in a cluster :math:`C_i` are calculated and then iterating through every other cluster's distances between the points in :math:`C_i` and the points in :math:`C_j` are calculated. The cluster with the minimal mean distance is taken to be :math:`b(i)`. The @@ -48,6 +50,13 @@ To reduce complexity the default implementation in SpikeInterface is to use the This can be changes by switching the silhouette method to either 'full' (the Rousseeuw implementation) or ('simplified', 'full') for both methods when entering the qm_params parameter. +References +---------- + +.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.simplified_silhouette_score + +.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.silhouette_score + Literature ---------- diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index b41e194466..2f566bf8a7 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -1,3 +1,5 @@ +.. _synchrony: + Synchrony Metrics (:code:`synchrony`) ===================================== @@ -39,11 +41,11 @@ The SpikeInterface implementation is a partial port of the low-level complexity References ---------- -.. automodule:: spikeinterface.toolkit.qualitymetrics.misc_metrics +.. automodule:: spikeinterface.qualitymetrics.misc_metrics .. 
autofunction:: compute_synchrony_metrics Literature ---------- -Based on concepts described in Gruen_ +Based on concepts described in [Gruen]_ diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index 26f2365202..34ab3d1151 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -130,7 +130,7 @@ Parameters from all sorters can be retrieved with these functions: .. _containerizedsorters: Running sorters in Docker/Singularity Containers ------------------------------------------------ +------------------------------------------------ One of the biggest bottlenecks for users is installing spike sorting software. To alleviate this, we build and maintain containerized versions of several popular spike sorters on the diff --git a/installation_tips/check_your_install.py b/installation_tips/check_your_install.py index 20809ec6c0..2b13a941cd 100644 --- a/installation_tips/check_your_install.py +++ b/installation_tips/check_your_install.py @@ -103,8 +103,8 @@ def _clean(): try: func() done = '...OK' - except: - done = '...Fail' + except Exception as err: + done = f'...Fail, Error: {err}' print(label, done) if platform.system() == "Windows": diff --git a/installation_tips/full_spikeinterface_environment_linux_dandi.yml b/installation_tips/full_spikeinterface_environment_linux_dandi.yml index d402f6805f..2ed176b16c 100755 --- a/installation_tips/full_spikeinterface_environment_linux_dandi.yml +++ b/installation_tips/full_spikeinterface_environment_linux_dandi.yml @@ -3,13 +3,11 @@ channels: - conda-forge - defaults dependencies: - - python=3.9 + - python=3.10 - pip>=21.0 - mamba - # numpy 1.22 break numba which break tridesclous - numpy<1.22 - # joblib 1.2 is breaking hdbscan - - joblib=1.1 + - joblib - tqdm - matplotlib - h5py @@ -31,12 +29,11 @@ dependencies: - ipympl - pip: - ephyviewer - - neo>=0.11 - - elephant>=0.10.0 + - neo>=0.12 - probeinterface>=0.2.11 - MEArec>=1.8 - spikeinterface[full, widgets] - spikeinterface-gui - - tridesclous>=1.6.6.1 + - tridesclous>=1.6.8 - spyking-circus>=1.1.0 # - phy==2.0b5 diff --git a/installation_tips/full_spikeinterface_environment_mac.yml b/installation_tips/full_spikeinterface_environment_mac.yml index 8b872981aa..7ce4a149cc 100755 --- a/installation_tips/full_spikeinterface_environment_mac.yml +++ b/installation_tips/full_spikeinterface_environment_mac.yml @@ -3,12 +3,10 @@ channels: - conda-forge - defaults dependencies: - - python=3.9 + - python=3.10 - pip>=21.0 - # numpy 1.21 break numba which break tridesclous - - numpy<1.22 - # joblib 1.2 is breaking hdbscan - - joblib=1.1 + - numpy + - joblib - tqdm - matplotlib - h5py @@ -30,13 +28,12 @@ dependencies: - pip: # - PyQt5 - ephyviewer - - neo>=0.11 - - elephant>=0.10.0 + - neo>=0.12 - probeinterface>=0.2.11 - MEArec>=1.8 - spikeinterface[full, widgets] - spikeinterface-gui - - tridesclous>=1.6.6.1 + - tridesclous>=1.6.8 # - phy==2.0b5 - - mountainsort4>=1.0.0 - - mountainsort5>=0.3.0 + # - mountainsort4>=1.0.0 isosplit5 fails on pip install for mac + # - mountainsort5>=0.3.0 diff --git a/installation_tips/full_spikeinterface_environment_windows.yml b/installation_tips/full_spikeinterface_environment_windows.yml index 8c793edcb1..38c26e6a78 100755 --- a/installation_tips/full_spikeinterface_environment_windows.yml +++ b/installation_tips/full_spikeinterface_environment_windows.yml @@ -3,12 +3,11 @@ channels: - conda-forge - defaults dependencies: - - python=3.9 + - python=3.10 - pip>=21.0 # numpy 1.21 break numba which break tridesclous - - numpy<1.22 - # joblib 1.2 is breaking 
hdbscan - - joblib=1.1 + - numpy + - joblib - tqdm - matplotlib - h5py @@ -26,11 +25,10 @@ dependencies: - ipympl - pip: - ephyviewer - - neo>=0.11 - - elephant>=0.10.0 + - neo>=0.12 - probeinterface>=0.2.11 - MEArec>=1.8 - spikeinterface[full, widgets] - spikeinterface-gui - - tridesclous>=1.6.6.1 + - tridesclous>=1.6.8 # - phy==2.0b5 diff --git a/pyproject.toml b/pyproject.toml index e17d6f6506..51efe1f585 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ extractors = [ "ONE-api>=1.19.1", "ibllib>=2.21.0", "pymatreader>=0.0.32", # For cell explorer matlab files + "zugbruecke>=0.2; sys_platform!='win32'", # For plexon2 ] streaming_extractors = [ @@ -147,6 +148,7 @@ docs = [ "sphinx_rtd_theme==1.0.0", "sphinx-gallery", "numpydoc", + "ipython", # for notebooks in the gallery "MEArec", # Use as an example diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index bbf77682ee..401c498f03 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1,7 +1,8 @@ import math - +import warnings import numpy as np from typing import Union, Optional, List, Literal +import warnings from .numpyextractors import NumpyRecording, NumpySorting @@ -31,7 +32,7 @@ def generate_recording( set_probe: Optional[bool] = True, ndim: Optional[int] = 2, seed: Optional[int] = None, - mode: Literal["lazy", "legacy"] = "legacy", + mode: Literal["lazy", "legacy"] = "lazy", ) -> BaseRecording: """ Generate a recording object. @@ -51,10 +52,10 @@ def generate_recording( The number of dimensions of the probe, by default 2. Set to 3 to make 3 dimensional probes. seed : Optional[int] A seed for the np.ramdom.default_rng function - mode: str ["lazy", "legacy"] Default "legacy". + mode: str ["lazy", "legacy"] Default "lazy". "legacy": generate a NumpyRecording with white noise. - This mode is kept for backward compatibility and will be deprecated in next release. - "lazy": return a NoiseGeneratorRecording + This mode is kept for backward compatibility and will be deprecated version 0.100.0. + "lazy": return a NoiseGeneratorRecording instance. Returns ------- @@ -64,6 +65,10 @@ def generate_recording( seed = _ensure_seed(seed) if mode == "legacy": + warnings.warn( + "generate_recording() : mode='legacy' will be deprecated in version 0.100.0. Use mode='lazy' instead.", + DeprecationWarning, + ) recording = _generate_recording_legacy(num_channels, sampling_frequency, durations, seed) elif mode == "lazy": recording = NoiseGeneratorRecording( @@ -538,7 +543,7 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol class NoiseGeneratorRecording(BaseRecording): """ - A lazy recording that generates random samples if and only if `get_traces` is called. + A lazy recording that generates white noise samples if and only if `get_traces` is called. This done by tiling small noise chunk. @@ -555,7 +560,7 @@ class NoiseGeneratorRecording(BaseRecording): The sampling frequency of the recorder. durations : List[float] The durations of each segment in seconds. Note that the length of this list is the number of segments. - noise_level: float, default 5: + noise_level: float, default 1: Std of the white noise dtype : Optional[Union[np.dtype, str]], default='float32' The dtype of the recording. Note that only np.float32 and np.float64 are supported. 
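A minimal usage sketch of the updated generator defaults, assuming the signatures shown in the surrounding hunks (``mode="lazy"`` becomes the default of ``generate_recording()`` and ``noise_level`` defaults to 1.0 in ``NoiseGeneratorRecording``); import paths follow ``src/spikeinterface/core/generate.py``:

.. code:: ipython

    from spikeinterface.core.generate import generate_recording, NoiseGeneratorRecording

    # default mode is now "lazy"; mode="legacy" still works but emits a DeprecationWarning
    rec = generate_recording(num_channels=4, sampling_frequency=30_000.0, durations=[10.0], seed=42)

    # the lazy class can also be instantiated directly; samples are only generated on get_traces()
    noise_rec = NoiseGeneratorRecording(
        num_channels=4,
        sampling_frequency=30_000.0,
        durations=[10.0],
        noise_level=1.0,               # std of the white noise (new default)
        strategy="tile_pregenerated",  # or "on_the_fly"
        seed=42,
    )
    traces = noise_rec.get_traces(start_frame=0, end_frame=1_000)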
@@ -581,7 +586,7 @@ def __init__( num_channels: int, sampling_frequency: float, durations: List[float], - noise_level: float = 5.0, + noise_level: float = 1.0, dtype: Optional[Union[np.dtype, str]] = "float32", seed: Optional[int] = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", @@ -647,7 +652,7 @@ def __init__( if self.strategy == "tile_pregenerated": rng = np.random.default_rng(seed=self.seed) self.noise_block = ( - rng.standard_normal(size=(self.noise_block_size, self.num_channels)).astype(self.dtype) * noise_level + rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype) * noise_level ) elif self.strategy == "on_the_fly": pass @@ -664,35 +669,35 @@ def get_traces( start_frame = 0 if start_frame is None else max(start_frame, 0) end_frame = self.num_samples if end_frame is None else min(end_frame, self.num_samples) - start_frame_mod = start_frame % self.noise_block_size - end_frame_mod = end_frame % self.noise_block_size + start_frame_within_block = start_frame % self.noise_block_size + end_frame_within_block = end_frame % self.noise_block_size num_samples = end_frame - start_frame traces = np.empty(shape=(num_samples, self.num_channels), dtype=self.dtype) - start_block_index = start_frame // self.noise_block_size - end_block_index = end_frame // self.noise_block_size + first_block_index = start_frame // self.noise_block_size + last_block_index = end_frame // self.noise_block_size pos = 0 - for block_index in range(start_block_index, end_block_index + 1): + for block_index in range(first_block_index, last_block_index + 1): if self.strategy == "tile_pregenerated": noise_block = self.noise_block elif self.strategy == "on_the_fly": rng = np.random.default_rng(seed=(self.seed, block_index)) - noise_block = rng.standard_normal(size=(self.noise_block_size, self.num_channels)).astype(self.dtype) + noise_block = rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype) noise_block *= self.noise_level - if block_index == start_block_index: - if start_block_index != end_block_index: - end_first_block = self.noise_block_size - start_frame_mod - traces[:end_first_block] = noise_block[start_frame_mod:] + if block_index == first_block_index: + if first_block_index != last_block_index: + end_first_block = self.noise_block_size - start_frame_within_block + traces[:end_first_block] = noise_block[start_frame_within_block:] pos += end_first_block else: # special case when unique block - traces[:] = noise_block[start_frame_mod : start_frame_mod + traces.shape[0]] - elif block_index == end_block_index: - if end_frame_mod > 0: - traces[pos:] = noise_block[:end_frame_mod] + traces[:] = noise_block[start_frame_within_block : start_frame_within_block + num_samples] + elif block_index == last_block_index: + if end_frame_within_block > 0: + traces[pos:] = noise_block[:end_frame_within_block] else: traces[pos : pos + self.noise_block_size] = noise_block pos += self.noise_block_size @@ -710,7 +715,7 @@ def get_traces( def generate_recording_by_size( full_traces_size_GiB: float, - num_channels: int = 1024, + num_channels: int = 384, seed: Optional[int] = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", ) -> NoiseGeneratorRecording: @@ -719,7 +724,7 @@ def generate_recording_by_size( This is a convenience wrapper around the NoiseGeneratorRecording class where only the size in GiB (NOT GB!) is specified. - It is generated with 1024 channels and a sampling frequency of 1 Hz. 
The duration is manipulted to + It is generated with 384 channels and a sampling frequency of 1 Hz. The duration is manipulted to produced the desired size. Seee GeneratorRecording for more details. @@ -727,7 +732,7 @@ def generate_recording_by_size( Parameters ---------- full_traces_size_GiB : float - The size in gibibyte (GiB) of the recording. + The size in gigabytes (GiB) of the recording. num_channels: int Number of channels. seed : int, optional @@ -740,7 +745,7 @@ def generate_recording_by_size( dtype = np.dtype("float32") sampling_frequency = 30_000.0 # Hz - num_channels = 1024 + num_channels = 384 GiB_to_bytes = 1024**3 full_traces_size_bytes = int(full_traces_size_GiB * GiB_to_bytes) @@ -1037,13 +1042,14 @@ def __init__( parent_recording: Union[BaseRecording, None] = None, num_samples: Optional[List[int]] = None, upsample_vector: Union[List[int], None] = None, - check_borbers: bool = True, + check_borders: bool = False, ) -> None: templates = np.asarray(templates) - if check_borbers: + # TODO: this should be external to this class. It is not the responsability of this class to check the templates + if check_borders: self._check_templates(templates) - # lets test this only once so force check_borbers=false for kwargs - check_borbers = False + # lets test this only once so force check_borders=False for kwargs + check_borders = False self.templates = templates channel_ids = parent_recording.channel_ids if parent_recording is not None else list(range(templates.shape[2])) @@ -1131,7 +1137,7 @@ def __init__( "nbefore": nbefore, "amplitude_factor": amplitude_factor, "upsample_vector": upsample_vector, - "check_borbers": check_borbers, + "check_borders": check_borders, } if parent_recording is None: self._kwargs["num_samples"] = num_samples @@ -1144,8 +1150,8 @@ def _check_templates(templates: np.ndarray): threshold = 0.01 * max_value if max(np.max(np.abs(templates[:, 0])), np.max(np.abs(templates[:, -1]))) > threshold: - raise Exception( - "Warning!\nYour templates do not go to 0 on the edges in InjectTemplatesRecording.__init__\nPlease make your window bigger." + warnings.warn( + "Warning! Your templates do not go to 0 on the edges in InjectTemplatesRecording. Please make your window bigger." ) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 9ea5ad59e7..b11f40a441 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -84,7 +84,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar raise NotImplementedError -# nodes graph must have either a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever) +# nodes graph must have a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever) # as first element they play the same role in pipeline : give some peaks (and eventually more) @@ -138,7 +138,99 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): - pass + """ + This class is useful to inject a sorting object in the node pipepline mechanism. + It allows to compute some post-processing steps with the same machinery used for sorting components. + This is used by: + * compute_spike_locations() + * compute_amplitude_scalings() + * compute_spike_amplitudes() + * compute_principal_components() + + recording : BaseRecording + The recording object. + sorting: BaseSorting + The sorting object. 
+ channel_from_template: bool, default: True + If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided. + If False, the max channel is computed for each spike given a radius around the template max channel. + extremum_channel_inds: dict of int + The extremum channel index dict given from template. + radius_um: float (default 50.) + The radius to find the real max channel. + Used only when channel_from_template=False + peak_sign: str (default "neg") + Peak sign to find the max channel. + Used only when channel_from_template=False + """ + + def __init__( + self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg" + ): + PipelineNode.__init__(self, recording, return_output=False) + + self.channel_from_template = channel_from_template + + assert extremum_channel_inds is not None, "SpikeRetriever needs the extremum_channel_inds dictionary" + + self.peaks = sorting_to_peaks(sorting, extremum_channel_inds) + + if not channel_from_template: + channel_distance = get_channel_distances(recording) + self.neighbours_mask = channel_distance < radius_um + self.peak_sign = peak_sign + + # precompute segment slice + self.segment_slices = [] + for segment_index in range(recording.get_num_segments()): + i0 = np.searchsorted(self.peaks["segment_index"], segment_index) + i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1) + self.segment_slices.append(slice(i0, i1)) + + def get_trace_margin(self): + return 0 + + def get_dtype(self): + return base_peak_dtype + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + # get local peaks + sl = self.segment_slices[segment_index] + peaks_in_segment = self.peaks[sl] + i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) + i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + local_peaks = peaks_in_segment[i0:i1] + + # make sample index local to traces + local_peaks = local_peaks.copy() + local_peaks["sample_index"] -= start_frame - max_margin + + if not self.channel_from_template: + # handle channel spike per spike + for i, peak in enumerate(local_peaks): + chans = np.flatnonzero(self.neighbours_mask[peak["channel_index"]]) + sparse_wfs = traces[peak["sample_index"], chans] + if self.peak_sign == "neg": + local_peaks[i]["channel_index"] = chans[np.argmin(sparse_wfs)] + elif self.peak_sign == "pos": + local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)] + elif self.peak_sign == "both": + local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] + + # TODO: "amplitude" ??? 
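+            # note: the "amplitude" field keeps the 0.0 filled in by sorting_to_peaks(); only channel_index is refined in this branch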
+ + return (local_peaks,) + + +def sorting_to_peaks(sorting, extremum_channel_inds): + spikes = sorting.to_spike_vector() + peaks = np.zeros(spikes.size, dtype=base_peak_dtype) + peaks["sample_index"] = spikes["sample_index"] + extremum_channel_inds_ = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids]) + peaks["channel_index"] = extremum_channel_inds_[spikes["unit_index"]] + peaks["amplitude"] = 0.0 + peaks["segment_index"] = spikes["segment_index"] + return peaks class WaveformsNode(PipelineNode): @@ -423,7 +515,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin) # set sample index to local node_output[0]["sample_index"] += extra_margin - elif isinstance(node, PeakRetriever): + elif isinstance(node, PeakSource): node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin) else: # TODO later when in master: change the signature of all nodes (or maybe not!) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 85f41924c1..bcb15b6455 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -12,9 +12,10 @@ from spikeinterface.core.node_pipeline import ( run_node_pipeline, PeakRetriever, + SpikeRetriever, PipelineNode, ExtractDenseWaveforms, - base_peak_dtype, + sorting_to_peaks, ) @@ -78,99 +79,107 @@ def test_run_node_pipeline(): # create peaks from spikes we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") - # print(extremum_channel_inds) - ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids]) - # print(ext_channel_inds) - peaks = np.zeros(spikes.size, dtype=base_peak_dtype) - peaks["sample_index"] = spikes["sample_index"] - peaks["channel_index"] = ext_channel_inds[spikes["unit_index"]] - peaks["amplitude"] = 0.0 - peaks["segment_index"] = 0 - - # one step only : squeeze output - peak_retriever = PeakRetriever(recording, peaks) - nodes = [ - peak_retriever, - AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6), - ] - step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) - assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) - - # 3 nodes two have outputs - ms_before = 0.5 - ms_after = 1.0 + peaks = sorting_to_peaks(sorting, extremum_channel_inds) + peak_retriever = PeakRetriever(recording, peaks) - dense_waveforms = ExtractDenseWaveforms( - recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False - ) - waveform_denoiser = WaveformDenoiser(recording, parents=[peak_retriever, dense_waveforms], return_output=False) - amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6, return_output=True) - waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_retriever, dense_waveforms], return_output=True) - denoised_waveforms_rms = WaveformsRootMeanSquare( - recording, parents=[peak_retriever, waveform_denoiser], return_output=True + # channel index is from template + spike_retriever_T = SpikeRetriever( + recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds ) - - nodes = [ - peak_retriever, - dense_waveforms, - waveform_denoiser, - 
amplitue_extraction, - waveforms_rms, - denoised_waveforms_rms, - ] - - # gather memory mode - output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory") - amplitudes, waveforms_rms, denoised_waveforms_rms = output - assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"]) - - num_peaks = peaks.shape[0] - num_channels = recording.get_num_channels() - assert waveforms_rms.shape[0] == num_peaks - assert waveforms_rms.shape[1] == num_channels - - assert waveforms_rms.shape[0] == num_peaks - assert waveforms_rms.shape[1] == num_channels - - # gather npy mode - folder = cache_folder / "pipeline_folder" - if folder.is_dir(): - shutil.rmtree(folder) - - output = run_node_pipeline( + # channel index is per spike + spike_retriever_S = SpikeRetriever( recording, - nodes, - job_kwargs, - gather_mode="npy", - folder=folder, - names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"], + sorting, + channel_from_template=False, + extremum_channel_inds=extremum_channel_inds, + radius_um=50, + peak_sign="neg", ) - amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output - - amplitudes_file = folder / "amplitudes.npy" - assert amplitudes_file.is_file() - amplitudes3 = np.load(amplitudes_file) - assert np.array_equal(amplitudes, amplitudes2) - assert np.array_equal(amplitudes2, amplitudes3) - - waveforms_rms_file = folder / "waveforms_rms.npy" - assert waveforms_rms_file.is_file() - waveforms_rms3 = np.load(waveforms_rms_file) - assert np.array_equal(waveforms_rms, waveforms_rms2) - assert np.array_equal(waveforms_rms2, waveforms_rms3) - - denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy" - assert denoised_waveforms_rms_file.is_file() - denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file) - assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2) - assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3) - - # Test pickle mechanism - for node in nodes: - import pickle - - pickled_node = pickle.dumps(node) - unpickled_node = pickle.loads(pickled_node) + + # test with 3 differents first nodes + for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)): + # one step only : squeeze output + nodes = [ + peak_source, + AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6), + ] + step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) + assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) + + # 3 nodes two have outputs + ms_before = 0.5 + ms_after = 1.0 + peak_retriever = PeakRetriever(recording, peaks) + dense_waveforms = ExtractDenseWaveforms( + recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False + ) + waveform_denoiser = WaveformDenoiser(recording, parents=[peak_source, dense_waveforms], return_output=False) + amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6, return_output=True) + waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_source, dense_waveforms], return_output=True) + denoised_waveforms_rms = WaveformsRootMeanSquare( + recording, parents=[peak_source, waveform_denoiser], return_output=True + ) + + nodes = [ + peak_source, + dense_waveforms, + waveform_denoiser, + amplitue_extraction, + waveforms_rms, + denoised_waveforms_rms, + ] + + # gather memory mode + output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory") + amplitudes, waveforms_rms, denoised_waveforms_rms = output + assert 
np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"]) + + num_peaks = peaks.shape[0] + num_channels = recording.get_num_channels() + assert waveforms_rms.shape[0] == num_peaks + assert waveforms_rms.shape[1] == num_channels + + assert waveforms_rms.shape[0] == num_peaks + assert waveforms_rms.shape[1] == num_channels + + # gather npy mode + folder = cache_folder / f"pipeline_folder_{loop}" + if folder.is_dir(): + shutil.rmtree(folder) + output = run_node_pipeline( + recording, + nodes, + job_kwargs, + gather_mode="npy", + folder=folder, + names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"], + ) + amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output + + amplitudes_file = folder / "amplitudes.npy" + assert amplitudes_file.is_file() + amplitudes3 = np.load(amplitudes_file) + assert np.array_equal(amplitudes, amplitudes2) + assert np.array_equal(amplitudes2, amplitudes3) + + waveforms_rms_file = folder / "waveforms_rms.npy" + assert waveforms_rms_file.is_file() + waveforms_rms3 = np.load(waveforms_rms_file) + assert np.array_equal(waveforms_rms, waveforms_rms2) + assert np.array_equal(waveforms_rms2, waveforms_rms3) + + denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy" + assert denoised_waveforms_rms_file.is_file() + denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file) + assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2) + assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3) + + # Test pickle mechanism + for node in nodes: + import pickle + + pickled_node = pickle.dumps(node) + unpickled_node = pickle.loads(pickled_node) if __name__ == "__main__": diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 8da47b1940..e9cf1bfb5f 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -8,7 +8,9 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( extract_waveforms_to_buffers, -) # allocate_waveforms_buffers, distribute_waveforms_to_buffers + extract_waveforms_to_single_buffer, + split_waveforms_by_units, +) if hasattr(pytest, "global_test_folder"): @@ -52,96 +54,95 @@ def test_waveform_tools(): unit_ids = sorting.unit_ids some_job_kwargs = [ - {}, {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True}, {"n_jobs": 2, "chunk_size": 3000, "progress_bar": True}, ] + some_modes = [ + {"mode": "memmap"}, + {"mode": "shared_memory"}, + ] + # if platform.system() != "Windows": + # # shared memory on windows is buggy... 
+ # some_modes.append( + # { + # "mode": "shared_memory", + # } + # ) + + some_sparsity = [ + dict(sparsity_mask=None), + dict(sparsity_mask=np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")), + ] # memmap mode - list_wfs = [] + list_wfs_dense = [] + list_wfs_sparse = [] for j, job_kwargs in enumerate(some_job_kwargs): - wf_folder = cache_folder / f"test_waveform_tools_{j}" - if wf_folder.is_dir(): - shutil.rmtree(wf_folder) - wf_folder.mkdir(parents=True) - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, **job_kwargs) - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="memmap", - return_scaled=False, - folder=wf_folder, - dtype=dtype, - sparsity_mask=None, - copy=False, - **job_kwargs, - ) - for unit_ind, unit_id in enumerate(unit_ids): - wf = wfs_arrays[unit_id] - assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) - list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids}) - _check_all_wf_equal(list_wfs) - - # memory - if platform.system() != "Windows": - # shared memory on windows is buggy... - list_wfs = [] - for job_kwargs in some_job_kwargs: - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='shared_memory', folder=None, dtype=dtype) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, mode='shared_memory', **job_kwargs) - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - folder=None, - dtype=dtype, - sparsity_mask=None, - copy=True, - **job_kwargs, - ) - for unit_ind, unit_id in enumerate(unit_ids): - wf = wfs_arrays[unit_id] - assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) - list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids}) - # to avoid warning we need to first destroy arrays then sharedmemm object - # del wfs_arrays - # del wfs_arrays_info - _check_all_wf_equal(list_wfs) - - # with sparsity - wf_folder = cache_folder / "test_waveform_tools_sparse" - if wf_folder.is_dir(): - shutil.rmtree(wf_folder) - wf_folder.mkdir() - - sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool") - job_kwargs = {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True} - - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype, sparsity_mask=sparsity_mask) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, sparsity_mask=sparsity_mask, **job_kwargs) - - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="memmap", - return_scaled=False, - folder=wf_folder, - dtype=dtype, - sparsity_mask=sparsity_mask, - copy=False, - **job_kwargs, - ) + for k, mode_kwargs in enumerate(some_modes): + for l, sparsity_kwargs in enumerate(some_sparsity): + # print() + # print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) + + if mode_kwargs["mode"] == "memmap": + wf_folder = cache_folder / f"test_waveform_tools_{j}_{k}_{l}" + if 
wf_folder.is_dir(): + shutil.rmtree(wf_folder) + wf_folder.mkdir(parents=True) + wf_file_path = wf_folder / "waveforms_all_units.npy" + + mode_kwargs_ = dict(**mode_kwargs) + if mode_kwargs["mode"] == "memmap": + mode_kwargs_["folder"] = wf_folder + + wfs_arrays = extract_waveforms_to_buffers( + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + dtype=dtype, + copy=True, + **sparsity_kwargs, + **mode_kwargs_, + **job_kwargs, + ) + for unit_ind, unit_id in enumerate(unit_ids): + wf = wfs_arrays[unit_id] + assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) + + if sparsity_kwargs["sparsity_mask"] is None: + list_wfs_dense.append(wfs_arrays) + else: + list_wfs_sparse.append(wfs_arrays) + + mode_kwargs_ = dict(**mode_kwargs) + if mode_kwargs["mode"] == "memmap": + mode_kwargs_["file_path"] = wf_file_path + + all_waveforms = extract_waveforms_to_single_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + dtype=dtype, + copy=True, + **sparsity_kwargs, + **mode_kwargs_, + **job_kwargs, + ) + wfs_arrays = split_waveforms_by_units( + unit_ids, spikes, all_waveforms, sparsity_mask=sparsity_kwargs["sparsity_mask"] + ) + if sparsity_kwargs["sparsity_mask"] is None: + list_wfs_dense.append(wfs_arrays) + else: + list_wfs_sparse.append(wfs_arrays) + + _check_all_wf_equal(list_wfs_dense) + _check_all_wf_equal(list_wfs_sparse) if __name__ == "__main__": diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a10c209f47..da8e3d64b6 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -36,7 +36,7 @@ def extract_waveforms_to_buffers( Same as calling allocate_waveforms_buffers() and then distribute_waveforms_to_buffers(). - Important note: for the "shared_memory" mode wfs_arrays_info contains reference to + Important note: for the "shared_memory" mode arrays_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. And this variable is also returned. To avoid this a copy to non shared memmory can be perform at the end. @@ -66,17 +66,17 @@ def extract_waveforms_to_buffers( If not None shape must be must be (len(unit_ids), len(channel_ids)) copy: bool If True (default), the output shared memory object is copied to a numpy standard array. - If copy=False then wfs_arrays_info is also return. Please keep in mind that wfs_arrays_info - need to be referenced as long as wfs_arrays will be used otherwise it will be very hard to debug. + If copy=False then arrays_info is also return. Please keep in mind that arrays_info + need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug. Also when copy=False the SharedMemory will need to be unlink manually {} Returns ------- - wfs_arrays: dict of arrays + waveforms_by_units: dict of arrays Arrays for all units (memmap or shared_memmep) - wfs_arrays_info: dict of info + arrays_info: dict of info Optionally return in case of shared_memory if copy=False. 
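# --- Illustrative sketch (not part of the patch) ------------------------------
# A minimal call of extract_waveforms_to_buffers() with the renamed outputs
# described above, using the same generators as the tests. With
# mode="shared_memory" and copy=True, plain numpy arrays are returned and no
# shared-memory bookkeeping is left to the caller.
import numpy as np
from spikeinterface.core import generate_recording, generate_sorting
from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers

recording = generate_recording(num_channels=4, durations=[10.0])
sorting = generate_sorting(num_units=3, durations=[10.0])
spikes = sorting.to_spike_vector()
nbefore, nafter = 20, 30

waveforms_by_units = extract_waveforms_to_buffers(
    recording, spikes, sorting.unit_ids, nbefore, nafter,
    mode="shared_memory", return_scaled=False, folder=None,
    dtype="float32", sparsity_mask=None, copy=True,
    n_jobs=1, chunk_size=3000,
)
# one buffer per unit, with one row per spike of that unit
for unit_index, unit_id in enumerate(sorting.unit_ids):
    assert waveforms_by_units[unit_id].shape[0] == np.sum(spikes["unit_index"] == unit_index)
# -------------------------------------------------------------------------------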
Dictionary to "construct" array in workers process (memmap file or sharemem info) """ @@ -89,7 +89,7 @@ def extract_waveforms_to_buffers( dtype = "float32" dtype = np.dtype(dtype) - wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers( + waveforms_by_units, arrays_info = allocate_waveforms_buffers( recording, spikes, unit_ids, nbefore, nafter, mode=mode, folder=folder, dtype=dtype, sparsity_mask=sparsity_mask ) @@ -97,7 +97,7 @@ def extract_waveforms_to_buffers( recording, spikes, unit_ids, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, @@ -107,19 +107,19 @@ def extract_waveforms_to_buffers( ) if mode == "memmap": - return wfs_arrays + return waveforms_by_units elif mode == "shared_memory": if copy: - wfs_arrays = {unit_id: arr.copy() for unit_id, arr in wfs_arrays.items()} + waveforms_by_units = {unit_id: arr.copy() for unit_id, arr in waveforms_by_units.items()} # release all sharedmem buffer for unit_id in unit_ids: - shm = wfs_arrays_info[unit_id][0] + shm = arrays_info[unit_id][0] if shm is not None: # empty array have None shm.unlink() - return wfs_arrays + return waveforms_by_units else: - return wfs_arrays, wfs_arrays_info + return waveforms_by_units, arrays_info extract_waveforms_to_buffers.__doc__ = extract_waveforms_to_buffers.__doc__.format(_shared_job_kwargs_doc) @@ -131,7 +131,7 @@ def allocate_waveforms_buffers( """ Allocate memmap or shared memory buffers before snippet extraction. - Important note: for the shared memory mode wfs_arrays_info contains reference to + Important note: for the shared memory mode arrays_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. Parameters @@ -158,9 +158,9 @@ def allocate_waveforms_buffers( Returns ------- - wfs_arrays: dict of arrays + waveforms_by_units: dict of arrays Arrays for all units (memmap or shared_memmep - wfs_arrays_info: dict of info + arrays_info: dict of info Dictionary to "construct" array in workers process (memmap file or sharemem) """ @@ -173,8 +173,8 @@ def allocate_waveforms_buffers( folder = Path(folder) # prepare buffers - wfs_arrays = {} - wfs_arrays_info = {} + waveforms_by_units = {} + arrays_info = {} for unit_ind, unit_id in enumerate(unit_ids): n_spikes = np.sum(spikes["unit_index"] == unit_ind) if sparsity_mask is None: @@ -186,34 +186,35 @@ def allocate_waveforms_buffers( if mode == "memmap": filename = str(folder / f"waveforms_{unit_id}.npy") arr = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) - wfs_arrays[unit_id] = arr - wfs_arrays_info[unit_id] = filename + waveforms_by_units[unit_id] = arr + arrays_info[unit_id] = filename elif mode == "shared_memory": - if n_spikes == 0: + if n_spikes == 0 or num_chans == 0: arr = np.zeros(shape, dtype=dtype) shm = None shm_name = None else: arr, shm = make_shared_array(shape, dtype) shm_name = shm.name - wfs_arrays[unit_id] = arr - wfs_arrays_info[unit_id] = (shm, shm_name, dtype.str, shape) + waveforms_by_units[unit_id] = arr + arrays_info[unit_id] = (shm, shm_name, dtype.str, shape) else: raise ValueError("allocate_waveforms_buffers bad mode") - return wfs_arrays, wfs_arrays_info + return waveforms_by_units, arrays_info def distribute_waveforms_to_buffers( recording, spikes, unit_ids, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, mode="memmap", sparsity_mask=None, + job_name=None, **job_kwargs, ): """ @@ -221,7 +222,7 @@ def distribute_waveforms_to_buffers( Buffers must be pre-allocated with the `allocate_waveforms_buffers()` function. 
- Important note, for "shared_memory" mode wfs_arrays_info contain reference to + Important note, for "shared_memory" mode arrays_info contain reference to the shared memmory buffer, this variable must be reference as long as arrays as used. Parameters @@ -233,7 +234,7 @@ def distribute_waveforms_to_buffers( This vector can be spikes = Sorting.to_spike_vector() unit_ids: list ot numpy List of unit_ids - wfs_arrays_info: dict + arrays_info: dict Dictionary to "construct" array in workers process (memmap file or sharemem) nbefore: int N samples before spike @@ -257,14 +258,14 @@ def distribute_waveforms_to_buffers( inds_by_unit[unit_id] = inds # and run - func = _waveform_extractor_chunk - init_func = _init_worker_waveform_extractor + func = _worker_distribute_buffers + init_func = _init_worker_distribute_buffers init_args = ( recording, unit_ids, spikes, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, @@ -272,9 +273,9 @@ def distribute_waveforms_to_buffers( mode, sparsity_mask, ) - processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=f"extract waveforms {mode}", **job_kwargs - ) + if job_name is None: + job_name = f"extract waveforms {mode} multi buffer" + processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() @@ -282,8 +283,8 @@ def distribute_waveforms_to_buffers( # used by ChunkRecordingExecutor -def _init_worker_waveform_extractor( - recording, unit_ids, spikes, wfs_arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask +def _init_worker_distribute_buffers( + recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker worker_ctx = {} @@ -296,23 +297,23 @@ def _init_worker_waveform_extractor( if mode == "memmap": # in memmap mode we have the "too many open file" problem with linux # memmap file will be open on demand and not globally per worker - worker_ctx["wfs_arrays_info"] = wfs_arrays_info + worker_ctx["arrays_info"] = arrays_info elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory - wfs_arrays = {} + waveforms_by_units = {} shms = {} - for unit_id, (shm, shm_name, dtype, shape) in wfs_arrays_info.items(): + for unit_id, (shm, shm_name, dtype, shape) in arrays_info.items(): if shm_name is None: arr = np.zeros(shape=shape, dtype=dtype) else: shm = SharedMemory(shm_name) arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - wfs_arrays[unit_id] = arr + waveforms_by_units[unit_id] = arr # we need a reference to all sham otherwise we get segment fault!!! shms[unit_id] = shm worker_ctx["shms"] = shms - worker_ctx["wfs_arrays"] = wfs_arrays + worker_ctx["waveforms_by_units"] = waveforms_by_units worker_ctx["unit_ids"] = unit_ids worker_ctx["spikes"] = spikes @@ -328,7 +329,7 @@ def _init_worker_waveform_extractor( # used by ChunkRecordingExecutor -def _waveform_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] @@ -349,16 +350,9 @@ def _waveform_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx) # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! 
- i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame) - i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame) - if i0 != i1: - # protect from spikes on border : spike_time<0 or spike_time>seg_size - # useful only when max_spikes_per_unit is not None - # waveform will not be extracted and a zeros will be left in the memmap file - while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1): - i0 = i0 + 1 - while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1): - i1 = i1 - 1 + # the border of segment are protected by nbefore on left an nafter on the right + i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) + i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) # slice in absolut in spikes vector l0 = i0 + s0 @@ -382,10 +376,10 @@ def _waveform_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx) if worker_ctx["mode"] == "memmap": # open file in demand (and also autoclose it after) - filename = worker_ctx["wfs_arrays_info"][unit_id] + filename = worker_ctx["arrays_info"][unit_id] wfs = np.load(str(filename), mmap_mode="r+") elif worker_ctx["mode"] == "shared_memory": - wfs = worker_ctx["wfs_arrays"][unit_id] + wfs = worker_ctx["waveforms_by_units"][unit_id] for pos in in_chunk_pos: sample_index = spikes[inds[pos]]["sample_index"] @@ -397,6 +391,282 @@ def _waveform_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx) wfs[pos, :, :] = wf[:, sparsity_mask[unit_ind]] +def extract_waveforms_to_single_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + mode="memmap", + return_scaled=False, + file_path=None, + dtype=None, + sparsity_mask=None, + copy=False, + job_name=None, + **job_kwargs, +): + """ + Allocate a single buffer (memmap or or shared memory) and then distribute every waveform into it. + + Contrary to extract_waveforms_to_buffers() all waveforms are extracted in the same buffer, so the spike vector is + needed to recover waveforms unit by unit. Importantly in case of sparsity, the channels are not aligned across + units. + + Note: spikes near borders (nbefore/nafter) are not extracted and 0 are put the output buffer. + This ensures that spikes.shape[0] == all_waveforms.shape[0]. + + Important note: for the "shared_memory" mode wf_array_info contains reference to + the shared memmory buffer, this variable must be referenced as long as arrays is used. + This variable must also unlink() when the array is de-referenced. + To avoid this complicated behavior, by default (copy=True) the shared memmory buffer is copied into a standard + numpy array. + + + Parameters + ---------- + recording: recording + The recording object + spikes: 1d numpy array with several fields + Spikes handled as a unique vector. + This vector can be obtained with: `spikes = Sorting.to_spike_vector()` + unit_ids: list ot numpy + List of unit_ids + nbefore: int + N samples before spike + nafter: int + N samples after spike + mode: str + Mode to use ('memmap' | 'shared_memory') + return_scaled: bool + Scale traces before exporting to buffer or not. + file_path: str or path + In case of memmap mode, file to save npy file. + dtype: numpy.dtype + dtype for waveforms buffer + sparsity_mask: None or array of bool + If not None shape must be must be (len(unit_ids), len(channel_ids)) + copy: bool + If True (default), the output shared memory object is copied to a numpy standard array and no reference + to the internal shared memory object is kept. 
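# --- Illustrative sketch (not part of the patch) ------------------------------
# With mode="shared_memory" and copy=False the function documented above returns
# the buffer *and* its descriptor; the caller must keep wf_array_info referenced
# while using the array and unlink the SharedMemory when done. `recording`,
# `spikes`, `unit_ids`, `nbefore`, `nafter` are assumed to exist as in the tests.
from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer

all_waveforms, wf_array_info = extract_waveforms_to_single_buffer(
    recording, spikes, unit_ids, nbefore, nafter,
    mode="shared_memory", return_scaled=False, dtype="float32",
    sparsity_mask=None, copy=False, n_jobs=1, chunk_size=3000,
)
try:
    first_snippets = all_waveforms[:10].copy()
finally:
    if wf_array_info["shm"] is not None:  # None when there are no spikes
        wf_array_info["shm"].unlink()
# -------------------------------------------------------------------------------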
+ If copy=False then the shared memory object is also returned. Please keep in mind that the shared memory object + need to be referenced as long as all_waveforms will be used otherwise it might produce segmentation + faults which are hard to debug. + Also when copy=False the SharedMemory will need to be unlink manually if proper cleanup of resources is desired. + + {} + + Returns + ------- + all_waveforms: numpy array + Single array with shape (nump_spikes, num_samples, num_channels) + + wf_array_info: dict of info + Optionally return in case of shared_memory if copy=False. + Dictionary to "construct" array in workers process (memmap file or sharemem info) + """ + nsamples = nbefore + nafter + + dtype = np.dtype(dtype) + if mode == "shared_memory": + assert file_path is None + else: + file_path = Path(file_path) + + num_spikes = spikes.size + if sparsity_mask is None: + num_chans = recording.get_num_channels() + else: + num_chans = max(np.sum(sparsity_mask, axis=1)) + shape = (num_spikes, nsamples, num_chans) + + if mode == "memmap": + all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape) + # wf_array_info = str(file_path) + wf_array_info = dict(filename=str(file_path)) + elif mode == "shared_memory": + if num_spikes == 0 or num_chans == 0: + all_waveforms = np.zeros(shape, dtype=dtype) + shm = None + shm_name = None + else: + all_waveforms, shm = make_shared_array(shape, dtype) + shm_name = shm.name + # wf_array_info = (shm, shm_name, dtype.str, shape) + wf_array_info = dict(shm=shm, shm_name=shm_name, dtype=dtype.str, shape=shape) + else: + raise ValueError("allocate_waveforms_buffers bad mode") + + job_kwargs = fix_job_kwargs(job_kwargs) + + if num_spikes > 0: + # and run + func = _worker_distribute_single_buffer + init_func = _init_worker_distribute_single_buffer + + init_args = ( + recording, + spikes, + wf_array_info, + nbefore, + nafter, + return_scaled, + mode, + sparsity_mask, + ) + if job_name is None: + job_name = f"extract waveforms {mode} mono buffer" + + processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) + processor.run() + + if mode == "memmap": + return all_waveforms + elif mode == "shared_memory": + if copy: + if shm is not None: + # release all sharedmem buffer + # empty array have None + shm.unlink() + return all_waveforms.copy() + else: + return all_waveforms, wf_array_info + + +def _init_worker_distribute_single_buffer( + recording, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask +): + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["wf_array_info"] = wf_array_info + worker_ctx["spikes"] = spikes + worker_ctx["nbefore"] = nbefore + worker_ctx["nafter"] = nafter + worker_ctx["return_scaled"] = return_scaled + worker_ctx["sparsity_mask"] = sparsity_mask + worker_ctx["mode"] = mode + + if mode == "memmap": + filename = wf_array_info["filename"] + all_waveforms = np.load(str(filename), mmap_mode="r+") + worker_ctx["all_waveforms"] = all_waveforms + elif mode == "shared_memory": + from multiprocessing.shared_memory import SharedMemory + + shm_name, dtype, shape = wf_array_info["shm_name"], wf_array_info["dtype"], wf_array_info["shape"] + shm = SharedMemory(shm_name) + all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + worker_ctx["shm"] = shm + worker_ctx["all_waveforms"] = all_waveforms + + # prepare segment slices + segment_slices = [] + for segment_index in range(recording.get_num_segments()): + s0 = 
np.searchsorted(spikes["segment_index"], segment_index) + s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + segment_slices.append((s0, s1)) + worker_ctx["segment_slices"] = segment_slices + + return worker_ctx + + +# used by ChunkRecordingExecutor +def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): + # recover variables of the worker + recording = worker_ctx["recording"] + segment_slices = worker_ctx["segment_slices"] + spikes = worker_ctx["spikes"] + nbefore = worker_ctx["nbefore"] + nafter = worker_ctx["nafter"] + return_scaled = worker_ctx["return_scaled"] + sparsity_mask = worker_ctx["sparsity_mask"] + all_waveforms = worker_ctx["all_waveforms"] + + seg_size = recording.get_num_samples(segment_index=segment_index) + + s0, s1 = segment_slices[segment_index] + in_seg_spikes = spikes[s0:s1] + + # take only spikes in range [start_frame, end_frame] + # this is a slice so no copy!! + # the border of segment are protected by nbefore on left an nafter on the right + i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) + i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + + # slice in absolut in spikes vector + l0 = i0 + s0 + l1 = i1 + s0 + + if l1 > l0: + start = spikes[l0]["sample_index"] - nbefore + end = spikes[l1 - 1]["sample_index"] + nafter + + # load trace in memory + traces = recording.get_traces( + start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled + ) + + for spike_index in range(l0, l1): + sample_index = spikes[spike_index]["sample_index"] + unit_index = spikes[spike_index]["unit_index"] + wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] + + if sparsity_mask is None: + all_waveforms[spike_index, :, :] = wf + else: + mask = sparsity_mask[unit_index, :] + wf = wf[:, mask] + all_waveforms[spike_index, :, : wf.shape[1]] = wf + + if worker_ctx["mode"] == "memmap": + all_waveforms.flush() + + +def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None, folder=None): + """ + Split a single buffer waveforms into waveforms by units (multi buffers or multi files). + + Parameters + ---------- + unit_ids: list or numpy array + List of unit ids + spikes: numpy array + The spike vector + all_waveforms : numpy array + Single buffer containing all waveforms + sparsity_mask : None or numpy array + Optionally the boolean sparsity mask + folder : None or str or Path + If a folder is given all waveforms by units are copied in a npy file using f"waveforms_{unit_id}.npy" naming. + + Returns + ------- + waveforms_by_units: dict of array + A dict of arrays. + In case of folder not None, this contain the memmap of the files. 
+ """ + if folder is not None: + folder = Path(folder) + waveforms_by_units = {} + for unit_index, unit_id in enumerate(unit_ids): + mask = spikes["unit_index"] == unit_index + if sparsity_mask is not None: + chan_mask = sparsity_mask[unit_index, :] + num_chans = np.sum(chan_mask) + wfs = all_waveforms[mask, :, :][:, :, :num_chans] + else: + wfs = all_waveforms[mask, :, :] + + if folder is None: + waveforms_by_units[unit_id] = wfs + else: + np.save(folder / f"waveforms_{unit_id}.npy", wfs) + # this avoid keeping in memory all waveforms + waveforms_by_units[unit_id] = np.load(f"waveforms_{unit_id}.npy", mmap_mode="r") + + return waveforms_by_units + + def has_exceeding_spikes(recording, sorting): """ Check if the sorting objects has spikes exceeding the recording number of samples, for all segments diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py index 0d9da1960a..0b11b72b2a 100644 --- a/src/spikeinterface/extractors/neoextractors/__init__.py +++ b/src/spikeinterface/extractors/neoextractors/__init__.py @@ -16,6 +16,7 @@ read_neuroscope_sorting, read_neuroscope, ) +from .neuroexplorer import NeuroExplorerRecordingExtractor, read_neuroexplorer from .nix import NixRecordingExtractor, read_nix from .openephys import ( OpenEphysLegacyRecordingExtractor, @@ -25,6 +26,14 @@ read_openephys_event, ) from .plexon import PlexonRecordingExtractor, PlexonSortingExtractor, read_plexon, read_plexon_sorting +from .plexon2 import ( + Plexon2SortingExtractor, + Plexon2RecordingExtractor, + Plexon2EventExtractor, + read_plexon2, + read_plexon2_sorting, + read_plexon2_event, +) from .spike2 import Spike2RecordingExtractor, read_spike2 from .spikegadgets import SpikeGadgetsRecordingExtractor, read_spikegadgets from .spikeglx import SpikeGLXRecordingExtractor, read_spikeglx @@ -49,12 +58,19 @@ OpenEphysBinaryRecordingExtractor, OpenEphysLegacyRecordingExtractor, PlexonRecordingExtractor, + Plexon2RecordingExtractor, Spike2RecordingExtractor, SpikeGadgetsRecordingExtractor, SpikeGLXRecordingExtractor, TdtRecordingExtractor, + NeuroExplorerRecordingExtractor, ] -neo_sorting_extractors_list = [BlackrockSortingExtractor, MEArecSortingExtractor, NeuralynxSortingExtractor] +neo_sorting_extractors_list = [ + BlackrockSortingExtractor, + MEArecSortingExtractor, + NeuralynxSortingExtractor, + Plexon2SortingExtractor, +] -neo_event_extractors_list = [AlphaOmegaEventExtractor, OpenEphysBinaryEventExtractor] +neo_event_extractors_list = [AlphaOmegaEventExtractor, OpenEphysBinaryEventExtractor, Plexon2EventExtractor] diff --git a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py new file mode 100644 index 0000000000..2c8603cb9c --- /dev/null +++ b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py @@ -0,0 +1,66 @@ +from pathlib import Path + +from spikeinterface.core.core_tools import define_function_from_class + +from .neobaseextractor import NeoBaseRecordingExtractor + + +class NeuroExplorerRecordingExtractor(NeoBaseRecordingExtractor): + """ + Class for reading NEX (NeuroExplorer data format) files. + + Based on :py:class:`neo.rawio.NeuroExplorerRawIO` + + Importantly, at the moment, this recorder only extracts one channel of the recording. + This is because the NeuroExplorerRawIO class does not support multi-channel recordings + as in the NeuroExplorer format they might have different sampling rates. 
+ + Consider extracting all the channels and then concatenating them with the aggregate_channels function. + + >>> from spikeinterface.extractors.neoextractors.neuroexplorer import NeuroExplorerRecordingExtractor + >>> from spikeinterface.core import aggregate_channels + >>> + >>> file_path="/the/path/to/your/nex/file.nex" + >>> + >>> streams = NeuroExplorerRecordingExtractor.get_streams(file_path=file_path) + >>> stream_names = streams[0] + >>> + >>> your_signal_stream_names = "Here goes the logic to filter from stream names the ones that you know have the same sampling rate and you want to aggregate" + >>> + >>> recording_list = [NeuroExplorerRecordingExtractor(file_path=file_path, stream_name=stream_name) for stream_name in your_signal_stream_names] + >>> recording = aggregate_channels(recording_list) + + + + Parameters + ---------- + file_path: str + The file path to load the recordings from. + stream_id: str, optional + If there are several streams, specify the stream id you want to load. + For this neo reader streams are defined by their sampling frequency. + stream_name: str, optional + If there are several streams, specify the stream name you want to load. + all_annotations: bool, default: False + Load exhaustively all annotations from neo. + """ + + mode = "file" + NeoRawIOClass = "NeuroExplorerRawIO" + name = "neuroexplorer" + + def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): + neo_kwargs = {"filename": str(file_path)} + NeoBaseRecordingExtractor.__init__( + self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + ) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) + self.extra_requirements.append("neo[edf]") + + @classmethod + def map_to_neo_kwargs(cls, file_path): + neo_kwargs = {"filename": str(file_path)} + return neo_kwargs + + +read_neuroexplorer = define_function_from_class(source_class=NeuroExplorerRecordingExtractor, name="read_neuroexplorer") diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py new file mode 100644 index 0000000000..8dbfc67e90 --- /dev/null +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -0,0 +1,103 @@ +from spikeinterface.core.core_tools import define_function_from_class + +from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor, NeoBaseEventExtractor + + +class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): + """ + Class for reading plexon pl2 files. + + Based on :py:class:`neo.rawio.Plexon2RawIO` + + Parameters + ---------- + file_path: str + The file path to load the recordings from. + stream_id: str, optional + If there are several streams, specify the stream id you want to load. + stream_name: str, optional + If there are several streams, specify the stream name you want to load. + all_annotations: bool, default: False + Load exhaustively all annotations from neo. 
+ """ + + mode = "file" + NeoRawIOClass = "Plexon2RawIO" + name = "plexon2" + + def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): + neo_kwargs = self.map_to_neo_kwargs(file_path) + NeoBaseRecordingExtractor.__init__( + self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + ) + self._kwargs.update({"file_path": str(file_path)}) + + @classmethod + def map_to_neo_kwargs(cls, file_path): + neo_kwargs = {"filename": str(file_path)} + return neo_kwargs + + +class Plexon2SortingExtractor(NeoBaseSortingExtractor): + """ + Class for reading plexon spiking data from .pl2 files. + + Based on :py:class:`neo.rawio.Plexon2RawIO` + + Parameters + ---------- + file_path: str + The file path to load the recordings from. + sampling_frequency: float, default: None + The sampling frequency of the sorting (required for multiple streams with different sampling frequencies). + """ + + mode = "file" + NeoRawIOClass = "Plexon2RawIO" + neo_returns_frames = True + name = "plexon2" + + def __init__(self, file_path, sampling_frequency=None): + from neo.rawio import Plexon2RawIO + + neo_kwargs = self.map_to_neo_kwargs(file_path) + neo_reader = Plexon2RawIO(**neo_kwargs) + neo_reader.parse_header() + NeoBaseSortingExtractor.__init__(self, sampling_frequency=sampling_frequency, **neo_kwargs) + self._kwargs.update({"file_path": str(file_path), "sampling_frequency": sampling_frequency}) + + @classmethod + def map_to_neo_kwargs(cls, file_path): + neo_kwargs = {"filename": str(file_path)} + return neo_kwargs + + +class Plexon2EventExtractor(NeoBaseEventExtractor): + """ + Class for reading plexon spiking data from .pl2 files. + + Based on :py:class:`neo.rawio.Plexon2RawIO` + + Parameters + ---------- + folder_path: str + + """ + + mode = "file" + NeoRawIOClass = "Plexon2RawIO" + name = "plexon2" + + def __init__(self, folder_path, block_index=None): + neo_kwargs = self.map_to_neo_kwargs(folder_path) + NeoBaseEventExtractor.__init__(self, block_index=block_index, **neo_kwargs) + + @classmethod + def map_to_neo_kwargs(cls, folder_path): + neo_kwargs = {"filename": str(folder_path)} + return neo_kwargs + + +read_plexon2 = define_function_from_class(source_class=Plexon2RecordingExtractor, name="read_plexon2") +read_plexon2_sorting = define_function_from_class(source_class=Plexon2SortingExtractor, name="read_plexon2_sorting") +read_plexon2_event = define_function_from_class(source_class=Plexon2EventExtractor, name="read_plexon2_event") diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 900bdec06e..257c1d566a 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -1,5 +1,6 @@ import unittest -from platform import python_version +import platform +import subprocess from packaging import version import pytest @@ -18,6 +19,38 @@ local_folder = get_global_dataset_folder() / "ephy_testing_data" +def has_plexon2_dependencies(): + """ + Check if required Plexon2 dependencies are installed on different OS. 
+ """ + + os_type = platform.system() + + if os_type == "Windows": + # On Windows, no need for additional dependencies + return True + + elif os_type == "Linux": + # Check for 'wine' using dpkg + try: + result_wine = subprocess.run( + ["dpkg", "-l", "wine"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True + ) + except subprocess.CalledProcessError: + return False + + # Check for 'zugbruecke' using pip + try: + import zugbruecke + + return True + except ImportError: + return False + else: + # Not sure about MacOS + raise ValueError(f"Unsupported OS: {os_type}") + + class MearecRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = MEArecRecordingExtractor downloads = ["mearec"] @@ -109,6 +142,17 @@ class NeuroScopeRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] +class NeuroExplorerRecordingTest(RecordingCommonTestSuite, unittest.TestCase): + ExtractorClass = NeuroExplorerRecordingExtractor + downloads = ["neuroexplorer"] + entities = [ + ("neuroexplorer/File_neuroexplorer_1.nex", {"stream_name": "ContChannel01"}), + ("neuroexplorer/File_neuroexplorer_1.nex", {"stream_name": "ContChannel02"}), + ("neuroexplorer/File_neuroexplorer_2.nex", {"stream_name": "ContChannel01"}), + ("neuroexplorer/File_neuroexplorer_2.nex", {"stream_name": "ContChannel02"}), + ] + + class NeuroScopeSortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = NeuroScopeSortingExtractor downloads = ["neuroscope"] @@ -218,7 +262,7 @@ class Spike2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): @pytest.mark.skipif( - version.parse(python_version()) >= version.parse("3.10"), + version.parse(platform.python_version()) >= version.parse("3.10"), reason="Sonpy only testing with Python < 3.10!", ) class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase): @@ -290,6 +334,32 @@ def test_pickling(self): pass +# We run plexon2 tests only if we have dependencies (wine) +@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") +class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2RecordingExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}), + ] + + +@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") +class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2EventExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2"), + ] + + +@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") +class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2SortingExtractor + downloads = ["plexon"] + entities = [("plexon/4chDemoPL2.pl2", {"sampling_frequency": 40000})] + + if __name__ == "__main__": # test = MearecSortingTest() # test = SpikeGLXRecordingTest() @@ -304,7 +374,7 @@ def test_pickling(self): # test = PlexonRecordingTest() # test = PlexonSortingTest() # test = NeuralynxRecordingTest() - test = BlackrockRecordingTest() + test = Plexon2RecordingTest() # test = MCSRawRecordingTest() # test = KiloSortSortingTest() # test = Spike2RecordingTest() diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 3ebeafcfec..5a0148c5c4 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -7,6 +7,9 @@ from 
spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension +# DEBUG = True + + class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension): """ Computes amplitude scalings from WaveformExtractor. @@ -21,9 +24,25 @@ def __init__(self, waveform_extractor): self.spikes = self.waveform_extractor.sorting.to_spike_vector( extremum_channel_inds=extremum_channel_inds, use_cache=False ) - - def _set_params(self, sparsity, max_dense_channels, ms_before, ms_after): - params = dict(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after) + self.collisions = None + + def _set_params( + self, + sparsity, + max_dense_channels, + ms_before, + ms_after, + handle_collisions, + delta_collision_ms, + ): + params = dict( + sparsity=sparsity, + max_dense_channels=max_dense_channels, + ms_before=ms_before, + ms_after=ms_after, + handle_collisions=handle_collisions, + delta_collision_ms=delta_collision_ms, + ) return params def _select_extension_data(self, unit_ids): @@ -43,6 +62,11 @@ def _run(self, **job_kwargs): ms_before = self._params["ms_before"] ms_after = self._params["ms_after"] + # collisions + handle_collisions = self._params["handle_collisions"] + delta_collision_ms = self._params["delta_collision_ms"] + delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency) + return_scaled = we._params["return_scaled"] unit_ids = we.unit_ids @@ -67,6 +91,8 @@ def _run(self, **job_kwargs): assert recording.get_num_channels() <= self._params["max_dense_channels"], "" sparsity = ChannelSparsity.create_dense(we) sparsity_inds = sparsity.unit_id_to_channel_indices + + # easier to use in chunk function as spikes use unit_index instead o id unit_inds_to_channel_indices = {unit_ind: sparsity_inds[unit_id] for unit_ind, unit_id in enumerate(unit_ids)} all_templates = we.get_all_templates() @@ -93,6 +119,8 @@ def _run(self, **job_kwargs): cut_out_before, cut_out_after, return_scaled, + handle_collisions, + delta_collision_samples, ) processor = ChunkRecordingExecutor( recording, @@ -104,10 +132,18 @@ def _run(self, **job_kwargs): **job_kwargs, ) out = processor.run() - (amp_scalings,) = zip(*out) + (amp_scalings, collisions) = zip(*out) amp_scalings = np.concatenate(amp_scalings) - self._extension_data[f"amplitude_scalings"] = amp_scalings + collisions_dict = {} + if handle_collisions: + for collision in collisions: + collisions_dict.update(collision) + self.collisions = collisions_dict + # Note: collisions are note in _extension_data because they are not pickable. We only store the indices + self._extension_data["collisions"] = np.array(list(collisions_dict.keys())) + + self._extension_data["amplitude_scalings"] = amp_scalings def get_data(self, outputs="concatenated"): """ @@ -154,6 +190,8 @@ def compute_amplitude_scalings( max_dense_channels=16, ms_before=None, ms_after=None, + handle_collisions=True, + delta_collision_ms=2, load_if_exists=False, outputs="concatenated", **job_kwargs, @@ -165,22 +203,27 @@ def compute_amplitude_scalings( ---------- waveform_extractor: WaveformExtractor The waveform extractor object - sparsity: ChannelSparsity + sparsity: ChannelSparsity, default: None If waveforms are not sparse, sparsity is required if the number of channels is greater than `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used. - By default None max_dense_channels: int, default: 16 Maximum number of channels to allow running without sparsity. 
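# --- Illustrative sketch (not part of the patch) ------------------------------
# Calling the extension with the new collision handling switched on, using the
# parameters added above. `we` is assumed to be an existing WaveformExtractor;
# the indices of colliding spikes end up in the "collisions" extension data.
from spikeinterface.postprocessing import compute_amplitude_scalings

scalings = compute_amplitude_scalings(
    we,
    handle_collisions=True,
    delta_collision_ms=2,
    outputs="concatenated",
    n_jobs=1,
    chunk_size=10000,
)
# -------------------------------------------------------------------------------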
To compute amplitude scaling using dense waveforms, set this to None, sparsity to None, and pass dense waveforms as input. - ms_before : float, optional + ms_before : float, default: None The cut out to apply before the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_before is used, by default None - ms_after : float, optional + If None, the WaveformExtractor ms_before is used. + ms_after : float, default: None The cut out to apply after the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_after is used, by default None + If None, the WaveformExtractor ms_after is used. + handle_collisions: bool, default: True + Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes + (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a + multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. + delta_collision_ms: float, default: 2 + The maximum time difference in ms before and after a spike to gather colliding spikes. load_if_exists : bool, default: False Whether to load precomputed spike amplitudes, if they already exist. - outputs: str + outputs: str, default: 'concatenated' How the output should be returned: - 'concatenated' - 'by_unit' @@ -197,7 +240,14 @@ def compute_amplitude_scalings( sac = waveform_extractor.load_extension(AmplitudeScalingsCalculator.extension_name) else: sac = AmplitudeScalingsCalculator(waveform_extractor) - sac.set_params(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after) + sac.set_params( + sparsity=sparsity, + max_dense_channels=max_dense_channels, + ms_before=ms_before, + ms_after=ms_after, + handle_collisions=handle_collisions, + delta_collision_ms=delta_collision_ms, + ) sac.run(**job_kwargs) amps = sac.get_data(outputs=outputs) @@ -218,6 +268,8 @@ def _init_worker_amplitude_scalings( cut_out_before, cut_out_after, return_scaled, + handle_collisions, + delta_collision_samples, ): # create a local dict per worker worker_ctx = {} @@ -229,14 +281,24 @@ def _init_worker_amplitude_scalings( worker_ctx["nafter"] = nafter worker_ctx["cut_out_before"] = cut_out_before worker_ctx["cut_out_after"] = cut_out_after - worker_ctx["margin"] = max(nbefore, nafter) worker_ctx["return_scaled"] = return_scaled worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices + worker_ctx["handle_collisions"] = handle_collisions + worker_ctx["delta_collision_samples"] = delta_collision_samples + + if not handle_collisions: + worker_ctx["margin"] = max(nbefore, nafter) + else: + # in this case we extend the margin to be able to get with collisions outside the chunk + margin_waveforms = max(nbefore, nafter) + max_margin_collisions = delta_collision_samples + margin_waveforms + worker_ctx["margin"] = max_margin_collisions return worker_ctx def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx): + # from sklearn.linear_model import LinearRegression from scipy.stats import linregress # recover variables of the worker @@ -250,16 +312,14 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) cut_out_after = worker_ctx["cut_out_after"] margin = worker_ctx["margin"] return_scaled = worker_ctx["return_scaled"] + handle_collisions = worker_ctx["handle_collisions"] + delta_collision_samples = worker_ctx["delta_collision_samples"] spikes_in_segment = spikes[segment_slices[segment_index]] i0 = 
np.searchsorted(spikes_in_segment["sample_index"], start_frame) i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) - local_waveforms = [] - templates = [] - scalings = [] - if i0 != i1: local_spikes = spikes_in_segment[i0:i1] traces_with_margin, left, right = get_chunk_with_margin( @@ -272,8 +332,26 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) offsets = recording.get_property("offset_to_uV") traces_with_margin = traces_with_margin.astype("float32") * gains + offsets - # get all waveforms - for spike in local_spikes: + # set colliding spikes apart (if needed) + if handle_collisions: + # local spikes with margin! + i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left) + i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) + local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] + collisions_local = find_collisions( + local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices + ) + else: + collisions_local = {} + + # compute the scaling for each spike + scalings = np.zeros(len(local_spikes), dtype=float) + # collision_global transforms local spike index to global spike index + collisions_global = {} + for spike_index, spike in enumerate(local_spikes): + if spike_index in collisions_local.keys(): + # we deal with overlapping spikes later + continue unit_index = spike["unit_index"] sample_index = spike["sample_index"] sparse_indices = unit_inds_to_channel_indices[unit_index] @@ -291,10 +369,335 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) else: local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices] assert template.shape == local_waveform.shape - local_waveforms.append(local_waveform) - templates.append(template) + + # here we use linregress, which is equivalent to using sklearn LinearRegression with fit_intercept=True + # y = local_waveform.flatten() + # X = template.flatten()[:, np.newaxis] + # reg = LinearRegression(positive=True, fit_intercept=True).fit(X, y) + # scalings[spike_index] = reg.coef_[0] linregress_res = linregress(template.flatten(), local_waveform.flatten()) - scalings.append(linregress_res[0]) - scalings = np.array(scalings) + scalings[spike_index] = linregress_res[0] + + # deal with collisions + if len(collisions_local) > 0: + num_spikes_in_previous_segments = int( + np.sum([len(spikes[segment_slices[s]]) for s in range(segment_index)]) + ) + for spike_index, collision in collisions_local.items(): + scaled_amps = fit_collision( + collision, + traces_with_margin, + start_frame, + end_frame, + left, + right, + nbefore, + all_templates, + unit_inds_to_channel_indices, + cut_out_before, + cut_out_after, + ) + # the scaling for the current spike is at index 0 + scalings[spike_index] = scaled_amps[0] + + # make collision_dict indices "absolute" by adding i0 and the cumulative number of spikes in previous segments + collisions_global.update({spike_index + i0 + num_spikes_in_previous_segments: collision}) + else: + scalings = np.array([]) + collisions_global = {} + + return (scalings, collisions_global) + + +### Collision handling ### +def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): + """ + Returns True if the unit indices i and j are overlapping, False otherwise + + Parameters + ---------- + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices + i: int + The first unit index + j: int + The second unit index + + 
Returns + ------- + bool + True if the unit indices i and j are overlapping, False otherwise + """ + if len(np.intersect1d(unit_inds_to_channel_indices[i], unit_inds_to_channel_indices[j])) > 0: + return True + else: + return False + - return (scalings,) +def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices): + """ + Finds the collisions between spikes. + + Parameters + ---------- + spikes: np.array + An array of spikes + spikes_w_margin: np.array + An array of spikes within the added margin + delta_collision_samples: int + The maximum number of samples between two spikes to consider them as overlapping + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices + + Returns + ------- + collision_spikes_dict: np.array + A dictionary with collisions. The key is the index of the spike with collision, the value is an + array of overlapping spikes, including the spike itself at position 0. + """ + # TODO: refactor to speed-up + collision_spikes_dict = {} + for spike_index, spike in enumerate(spikes): + # find the index of the spike within the spikes_w_margin + spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0] + + # find the possible spikes per and post within delta_collision_samples + consecutive_window_pre = np.searchsorted( + spikes_w_margin["sample_index"], + spike["sample_index"] - delta_collision_samples, + ) + consecutive_window_post = np.searchsorted( + spikes_w_margin["sample_index"], + spike["sample_index"] + delta_collision_samples, + ) + # exclude the spike itself (it is included in the collision_spikes by construction) + pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin) + post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post) + possible_overlapping_spike_indices = np.concatenate( + (pre_possible_consecutive_spike_indices, post_possible_consecutive_spike_indices) + ) + + # find the overlapping spikes in space as well + for possible_overlapping_spike_index in possible_overlapping_spike_indices: + if _are_unit_indices_overlapping( + unit_inds_to_channel_indices, + spike["unit_index"], + spikes_w_margin[possible_overlapping_spike_index]["unit_index"], + ): + if spike_index not in collision_spikes_dict: + collision_spikes_dict[spike_index] = np.array([spike]) + collision_spikes_dict[spike_index] = np.concatenate( + (collision_spikes_dict[spike_index], [spikes_w_margin[possible_overlapping_spike_index]]) + ) + return collision_spikes_dict + + +def fit_collision( + collision, + traces_with_margin, + start_frame, + end_frame, + left, + right, + nbefore, + all_templates, + unit_inds_to_channel_indices, + cut_out_before, + cut_out_after, +): + """ + Compute the best fit for a collision between a spike and its overlapping spikes. + The function first cuts out the traces around the spike and its overlapping spikes, then + fits a multi-linear regression model to the traces using the centered templates as predictors. + + Parameters + ---------- + collision: np.ndarray + A numpy array of shape (n_colliding_spikes, ) containing the colliding spikes (spike_dtype). + traces_with_margin: np.ndarray + A numpy array of shape (n_samples, n_channels) containing the traces with a margin. + start_frame: int + The start frame of the chunk for traces_with_margin. + end_frame: int + The end frame of the chunk for traces_with_margin. + left: int + The left margin of the chunk for traces_with_margin. 
+ right: int + The right margin of the chunk for traces_with_margin. + nbefore: int + The number of samples before the spike to consider for the fit. + all_templates: np.ndarray + A numpy array of shape (n_units, n_samples, n_channels) containing the templates. + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices. + cut_out_before: int + The number of samples to cut out before the spike. + cut_out_after: int + The number of samples to cut out after the spike. + + Returns + ------- + np.ndarray + The fitted scaling factors for the colliding spikes. + """ + from sklearn.linear_model import LinearRegression + + # make center of the spike externally + sample_first_centered = np.min(collision["sample_index"]) - (start_frame - left) + sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) + + # construct sparsity as union between units' sparsity + sparse_indices = np.array([], dtype="int") + for spike in collision: + sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] + sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + + local_waveform_start = max(0, sample_first_centered - cut_out_before) + local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) + local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] + + y = local_waveform.T.flatten() + X = np.zeros((len(y), len(collision))) + for i, spike in enumerate(collision): + full_template = np.zeros_like(local_waveform) + # center wrt cutout traces + sample_centered = spike["sample_index"] - (start_frame - left) - local_waveform_start + template = all_templates[spike["unit_index"]][:, sparse_indices] + template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after] + # deal with borders + if sample_centered - cut_out_before < 0: + full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :] + elif sample_centered + cut_out_after > end_frame + right: + full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)] + else: + full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut + X[:, i] = full_template.T.flatten() + + reg = LinearRegression(fit_intercept=True, positive=True).fit(X, y) + scalings = reg.coef_ + return scalings + + +# uncomment for debugging +# def plot_collisions(we, sparsity=None, num_collisions=None): +# """ +# Plot the fitting of collision spikes. + +# Parameters +# ---------- +# we : WaveformExtractor +# The WaveformExtractor object. +# sparsity : ChannelSparsity, default=None +# The ChannelSparsity. If None, only main channels are plotted. +# num_collisions : int, default=None +# Number of collisions to plot. If None, all collisions are plotted. +# """ +# assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!" +# sac = we.load_extension("amplitude_scalings") +# handle_collisions = sac._params["handle_collisions"] +# assert handle_collisions, "Amplitude scalings was run without handling collisions!" 
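# --- Illustrative sketch (not part of the patch) ------------------------------
# Stripped-down version of the fit performed by fit_collision() above: the local
# traces of two overlapping spikes are modelled as a non-negative linear
# combination of their time-shifted templates, and the regression coefficients
# are the amplitude scalings. Shapes and numbers are made up for the demo.
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
num_samples, num_channels = 60, 4
template_a = rng.normal(size=(30, num_channels))
template_b = rng.normal(size=(30, num_channels))

# place the two templates at different offsets inside the local window
full_a = np.zeros((num_samples, num_channels))
full_b = np.zeros((num_samples, num_channels))
full_a[5:35] = template_a
full_b[20:50] = template_b

true_scalings = np.array([0.8, 1.3])
traces = (
    true_scalings[0] * full_a
    + true_scalings[1] * full_b
    + 0.01 * rng.normal(size=(num_samples, num_channels))
)

y = traces.T.flatten()
X = np.stack([full_a.T.flatten(), full_b.T.flatten()], axis=1)
reg = LinearRegression(fit_intercept=True, positive=True).fit(X, y)
print(reg.coef_)  # close to [0.8, 1.3]
# -------------------------------------------------------------------------------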
+# scalings = sac.get_data() + +# # overlapping_mask = sac.overlapping_mask +# # num_collisions = num_collisions or len(overlapping_mask) +# spikes = sac.spikes +# collisions = sac._extension_data[f"collisions"] +# collision_keys = list(collisions.keys()) +# num_collisions = num_collisions or len(collisions) +# num_collisions = min(num_collisions, len(collisions)) + +# for i in range(num_collisions): +# overlapping_spikes = collisions[collision_keys[i]] +# ax = plot_one_collision( +# we, collision_keys[i], overlapping_spikes, spikes, scalings=scalings, sparsity=sparsity +# ) + + +# def plot_one_collision( +# we, +# spike_index, +# overlapping_spikes, +# spikes, +# scalings=None, +# sparsity=None, +# cut_out_samples=100, +# ax=None +# ): +# import matplotlib.pyplot as plt + +# if ax is None: +# fig, ax = plt.subplots() + +# recording = we.recording +# nbefore_nafter_max = max(we.nafter, we.nbefore) +# cut_out_samples = max(cut_out_samples, nbefore_nafter_max) + +# if sparsity is not None: +# unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices +# sparse_indices = np.array([], dtype="int") +# for spike in overlapping_spikes: +# sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] +# sparse_indices = np.union1d(sparse_indices, sparse_indices_i) +# else: +# sparse_indices = np.unique(overlapping_spikes["channel_index"]) + +# channel_ids = recording.channel_ids[sparse_indices] + +# center_spike = overlapping_spikes[0] +# max_delta = np.max( +# [ +# np.abs(center_spike["sample_index"] - np.min(overlapping_spikes[1:]["sample_index"])), +# np.abs(center_spike["sample_index"] - np.max(overlapping_spikes[1:]["sample_index"])), +# ] +# ) +# sf = max(0, center_spike["sample_index"] - max_delta - cut_out_samples) +# ef = min( +# center_spike["sample_index"] + max_delta + cut_out_samples, +# recording.get_num_samples(segment_index=center_spike["segment_index"]), +# ) +# tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) +# ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 +# max_tr = np.max(np.abs(tr_overlap)) + +# for ch, tr in enumerate(tr_overlap.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") +# ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") + +# used_labels = [] +# for i, spike in enumerate(overlapping_spikes): +# label = f"U{spike['unit_index']}" +# if label in used_labels: +# label = None +# else: +# used_labels.append(label) +# ax.axvline( +# spike["sample_index"] / recording.sampling_frequency * 1000, color=f"C{spike['unit_index']}", label=label +# ) + +# if scalings is not None: +# fitted_traces = np.zeros_like(tr_overlap) + +# all_templates = we.get_all_templates() +# for i, spike in enumerate(overlapping_spikes): +# template = all_templates[spike["unit_index"]] +# overlap_index = np.where(spikes == spike)[0][0] +# template_scaled = scalings[overlap_index] * template +# template_scaled_sparse = template_scaled[:, sparse_indices] +# sample_start = spike["sample_index"] - we.nbefore +# sample_end = sample_start + template_scaled_sparse.shape[0] + +# fitted_traces[sample_start - sf : sample_end - sf] += template_scaled_sparse + +# for ch, temp in enumerate(template_scaled_sparse.T): +# ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 +# _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", ls="--") + +# for ch, tr in enumerate(fitted_traces.T): +# _ = ax.plot(ts, tr + 1.2 * ch * 
max_tr, color="gray", alpha=0.7) + +# fitted_line = ax.get_lines()[-1] +# fitted_line.set_label("Fitted") + +# ax.legend() +# ax.set_title(f"Spike {spike_index} - sample {center_spike['sample_index']}") +# return ax diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 991d79506e..233625e09e 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -694,11 +694,10 @@ def compute_principal_components( If True and pc scores are already in the waveform extractor folders, pc scores are loaded and not recomputed. n_components: int Number of components fo PCA - default 5 - mode: str + mode: str, default: 'by_channel_local' - 'by_channel_local': a local PCA is fitted for each channel (projection by channel) - 'by_channel_global': a global PCA is fitted for all channels (projection by channel) - 'concatenated': channels are concatenated and a global PCA is fitted - default 'by_channel_local' sparsity: ChannelSparsity or None The sparsity to apply to waveforms. If waveform_extractor is already sparse, the default sparsity will be used - default None @@ -735,6 +734,7 @@ def compute_principal_components( >>> # run for all spikes in the SortingExtractor >>> pc.run_for_all_spikes(file_path="all_pca_projections.npy") """ + if load_if_exists and waveform_extractor.is_extension(WaveformPrincipalComponent.extension_name): pc = waveform_extractor.load_extension(WaveformPrincipalComponent.extension_name) else: diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index ff2a5b60c2..e2ef6e6794 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -186,11 +186,12 @@ def correct_motion( Parameters for each step are handled as separate dictionaries. 
For more information please check the documentation of the following functions: - * :py:func:`~spikeinterface.sortingcomponents.peak_detection.detect_peaks' - * :py:func:`~spikeinterface.sortingcomponents.peak_selection.select_peaks' - * :py:func:`~spikeinterface.sortingcomponents.peak_localization.localize_peaks' - * :py:func:`~spikeinterface.sortingcomponents.motion_estimation.estimate_motion' - * :py:func:`~spikeinterface.sortingcomponents.motion_interpolation.interpolate_motion' + + * :py:func:`~spikeinterface.sortingcomponents.peak_detection.detect_peaks` + * :py:func:`~spikeinterface.sortingcomponents.peak_selection.select_peaks` + * :py:func:`~spikeinterface.sortingcomponents.peak_localization.localize_peaks` + * :py:func:`~spikeinterface.sortingcomponents.motion_estimation.estimate_motion` + * :py:func:`~spikeinterface.sortingcomponents.motion_interpolation.interpolate_motion` Possible presets: {} diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 90b39aee8a..7d43982853 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -293,7 +293,7 @@ def __init__( means = means[None, :] stds = np.std(random_data, axis=0) stds = stds[None, :] - gain = 1 / stds + gain = 1.0 / stds offset = -means / stds if int_scale is not None: diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 223122e927..c2ffcc6843 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -46,7 +46,7 @@ def __init__(self, recording, list_periods, mode="zeros", **random_chunk_kwargs) num_seg = recording.get_num_segments() if num_seg == 1: - if isinstance(list_periods, (list, np.ndarray)) and not np.isscalar(list_periods[0]): + if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2: # when unique segment accept list instead of of list of list/arrays list_periods = [list_periods] diff --git a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py index b62a73a8cb..764acc9852 100644 --- a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py +++ b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py @@ -78,13 +78,18 @@ def test_zscore(): assert np.all(np.abs(np.mean(tr, axis=0)) < 0.01) assert np.all(np.abs(np.std(tr, axis=0) - 1) < 0.01) + +def test_zscore_int(): + seed = 1 + rec = generate_recording(seed=seed, mode="legacy") rec_int = scale(rec, dtype="int16", gain=100) with pytest.raises(AssertionError): - rec4 = zscore(rec_int, dtype=None) - rec4 = zscore(rec_int, dtype="int16", int_scale=256, mode="mean+std", seed=seed) - tr = rec4.get_traces(segment_index=0) - trace_mean = np.mean(tr, axis=0) - trace_std = np.std(tr, axis=0) + zscore(rec_int, dtype=None) + + zscore_recording = zscore(rec_int, dtype="int16", int_scale=256, mode="mean+std", seed=seed) + traces = zscore_recording.get_traces(segment_index=0) + trace_mean = np.mean(traces, axis=0) + trace_std = np.std(traces, axis=0) assert np.all(np.abs(trace_mean) < 1) assert np.all(np.abs(trace_std - 256) < 1) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index b7b267251d..59000211d4 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ 
-736,6 +736,7 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id): """Calculates the simplified silhouette score for each cluster. The value ranges from -1 (bad clustering) to 1 (good clustering). The simplified silhoutte score utilizes the centroids for distance calculations rather than pairwise calculations. + Parameters ---------- all_pcs : 2d array @@ -744,12 +745,14 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id): The cluster labels for all spikes. Must have length of number of spikes. this_unit_id : int The ID for the unit to calculate this metric for. + Returns ------- unit_silhouette_score : float Simplified Silhouette Score for this unit + References - ------------ + ---------- Based on simplified silhouette score suggested by [Hruschka]_ """ @@ -782,6 +785,7 @@ def silhouette_score(all_pcs, all_labels, this_unit_id): """Calculates the silhouette score which is a marker of cluster quality ranging from -1 (bad clustering) to 1 (good clustering). Distances are all calculated as pairwise comparisons of all data points. + Parameters ---------- all_pcs : 2d array @@ -790,12 +794,14 @@ def silhouette_score(all_pcs, all_labels, this_unit_id): The cluster labels for all spikes. Must have length of number of spikes. this_unit_id : int The ID for the unit to calculate this metric for. + Returns ------- unit_silhouette_score : float Silhouette Score for this unit + References - ------------ + ---------- Based on [Rousseeuw]_ """ diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index f3719b934b..bc52ea2c70 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -65,10 +65,9 @@ def detect_peaks( This avoid reading the recording multiple times. gather_mode: str How to gather the results: - * "memory": results are returned as in-memory numpy arrays - * "npy": results are stored to .npy files in `folder` + folder: str or Path If gather_mode is "npy", the folder where the files are created. names: list @@ -81,9 +80,11 @@ def detect_peaks( ------- peaks: array Detected peaks. + Notes ----- This peak detection ported from tridesclous into spikeinterface. + """ assert method in detect_peak_methods diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index e7bcff0832..ae036d1ba1 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -30,11 +30,10 @@ class SpikesOnTracesWidget(BaseWidget): sparsity : ChannelSparsity or None Optional ChannelSparsity to apply. If WaveformExtractor is already sparse, the argument is ignored, default None - unit_colors : dict or None + unit_colors : dict or None If given, a dictionary with unit ids as keys and colors as values, default None If None, then the get_unit_colors() is internally used. 
(matplotlib backend) - mode : str - Three possible modes, default 'auto': + mode : str in ('line', 'map', 'auto'), default: 'auto' * 'line': classical for low channel count * 'map': for high channel count use color heat map * 'auto': auto switch depending on the channel count ('line' if less than 64 channels, 'map' otherwise) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 9a2ec4a215..e025f779c1 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -26,6 +26,7 @@ class TracesWidget(BaseWidget): List with start time and end time, default None mode: str Three possible modes, default 'auto': + * 'line': classical for low channel count * 'map': for high channel count use color heat map * 'auto': auto switch depending on the channel count ('line' if less than 64 channels, 'map' otherwise) @@ -51,11 +52,6 @@ class TracesWidget(BaseWidget): For 'map' mode and sortingview backend, seconds to render in each row, default 0.2 add_legend : bool If True adds legend to figures, default True - - Returns - ------- - W: TracesWidget - The output widget """ def __init__( diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index f3c640ff16..9c89b3981e 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -54,6 +54,12 @@ {backends} **backend_kwargs: kwargs {backend_kwargs} + + + Returns + ------- + w : BaseWidget + The output widget object. """ # backend_str = f" {list(wcls.possible_backends.keys())}" backend_str = f" {wcls.get_possible_backends()}"
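A short, self-contained sketch (not part of the patch) of the int16 z-scoring behaviour that the new test_zscore_int exercises: int_scale keeps the output in an integer dtype by rescaling each channel so its standard deviation lands at the requested value. The helper names (generate_recording, scale, zscore) are the ones the test itself uses; treat the exact import paths as an assumption about this version of spikeinterface.

    import numpy as np
    from spikeinterface.core import generate_recording
    from spikeinterface.preprocessing import scale, zscore

    # build a small float recording and cast it to int16, as in the test
    rec = generate_recording(seed=1, mode="legacy")
    rec_int = scale(rec, dtype="int16", gain=100)

    # zscore on an integer recording needs an explicit dtype and int_scale;
    # int_scale=256 makes one standard deviation correspond to 256 integer units
    rec_z = zscore(rec_int, dtype="int16", int_scale=256, mode="mean+std", seed=1)

    traces = rec_z.get_traces(segment_index=0)
    print(traces.dtype)              # int16
    print(np.std(traces, axis=0))    # ~256 per channel
    print(np.mean(traces, axis=0))   # ~0 per channel

And a similarly hedged sketch of the mode options documented for compute_principal_components; the toy data (toy_example) and the in-memory WaveformExtractor are illustration-only assumptions, not something the patch itself adds.

    from spikeinterface.core import extract_waveforms
    from spikeinterface.extractors import toy_example
    from spikeinterface.postprocessing import compute_principal_components

    recording, sorting = toy_example(num_units=5, duration=10, seed=0)
    we = extract_waveforms(recording, sorting, mode="memory")

    # 'by_channel_local' (the default) fits one PCA per channel;
    # 'by_channel_global' and 'concatenated' are the other documented modes
    pc = compute_principal_components(we, n_components=3, mode="by_channel_local")
    print(pc.get_projections(sorting.unit_ids[0]).shape)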