diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml new file mode 100644 index 0000000000..b3bf08954d --- /dev/null +++ b/.github/workflows/installation-tips-test.yml @@ -0,0 +1,33 @@ +name: Creates Conda Install for Installation Tips + +on: + workflow_dispatch: + schedule: + - cron: "0 12 * * 0" # Weekly at noon UTC on Sundays + +jobs: + installation-tips-testing: + name: Build Conda Env on ${{ matrix.os }} OS + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -el {0} + strategy: + fail-fast: false + matrix: + include: + - os: ubuntu-latest + label: linux_dandi + - os: macos-latest + label: mac + - os: windows-latest + label: windows + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Test Conda Environment Creation + uses: conda-incubator/setup-miniconda@v2.2.0 + with: + environment-file: ./installation_tips/full_spikeinterface_environment_${{ matrix.label }}.yml diff --git a/doc/api.rst b/doc/api.rst index 7a72ead33f..97c956c2f6 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -19,6 +19,8 @@ spikeinterface.core .. autofunction:: extract_waveforms .. autofunction:: load_waveforms .. autofunction:: compute_sparsity + .. autoclass:: ChannelSparsity + :members: .. autoclass:: BinaryRecordingExtractor .. autoclass:: ZarrRecordingExtractor .. autoclass:: BinaryFolderRecording @@ -48,10 +50,6 @@ spikeinterface.core .. autofunction:: get_template_extremum_channel .. autofunction:: get_template_extremum_channel_peak_shift .. autofunction:: get_template_extremum_amplitude - -.. - .. autofunction:: read_binary - .. autofunction:: read_zarr .. autofunction:: append_recordings .. autofunction:: concatenate_recordings .. autofunction:: split_recording @@ -59,6 +57,8 @@ spikeinterface.core .. autofunction:: append_sortings .. autofunction:: split_sorting .. autofunction:: select_segment_sorting + .. autofunction:: read_binary + .. autofunction:: read_zarr Low-level ~~~~~~~~~ @@ -67,7 +67,6 @@ Low-level :noindex: .. autoclass:: BaseWaveformExtractorExtension - .. autoclass:: ChannelSparsity .. autoclass:: ChunkRecordingExecutor spikeinterface.extractors @@ -83,6 +82,7 @@ NEO-based .. autofunction:: read_alphaomega_event .. autofunction:: read_axona .. autofunction:: read_biocam + .. autofunction:: read_binary .. autofunction:: read_blackrock .. autofunction:: read_ced .. autofunction:: read_intan @@ -98,10 +98,13 @@ NEO-based .. autofunction:: read_openephys_event .. autofunction:: read_plexon .. autofunction:: read_plexon_sorting + .. autofunction:: read_plexon2 + .. autofunction:: read_plexon2_sorting .. autofunction:: read_spike2 .. autofunction:: read_spikegadgets .. autofunction:: read_spikeglx .. autofunction:: read_tdt + .. autofunction:: read_zarr Non-NEO-based @@ -214,8 +217,10 @@ spikeinterface.sorters .. autofunction:: print_sorter_versions .. autofunction:: get_sorter_description .. autofunction:: run_sorter + .. autofunction:: run_sorter_jobs .. autofunction:: run_sorters .. autofunction:: run_sorter_by_property + .. 
autofunction:: read_sorter_folder Low level ~~~~~~~~~ diff --git a/doc/conf.py b/doc/conf.py index 847de9ff42..15cb65d46a 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -67,6 +67,8 @@ 'numpydoc', "sphinx.ext.intersphinx", "sphinx.ext.extlinks", + "IPython.sphinxext.ipython_directive", + "IPython.sphinxext.ipython_console_highlighting" ] numpydoc_show_class_members = False diff --git a/doc/development/development.rst b/doc/development/development.rst index cd613a27e6..7656da11ab 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -1,5 +1,5 @@ Development -========== +=========== How to contribute ----------------- @@ -14,7 +14,7 @@ There are various ways to contribute to SpikeInterface as a user or developer. S * Writing unit tests to expand code coverage and use case scenarios. * Reporting bugs and issues. -We use a forking workflow _ to manage contributions. Here's a summary of the steps involved, with more details available in the provided link: +We use a forking workflow ``_ to manage contributions. Here's a summary of the steps involved, with more details available in the provided link: * Fork the SpikeInterface repository. * Create a new branch (e.g., :code:`git switch -c my-contribution`). @@ -22,7 +22,7 @@ We use a forking workflow _ . +While we appreciate all the contributions please be mindful of the cost of reviewing pull requests ``_ . How to run tests locally @@ -201,7 +201,7 @@ Implement a new extractor SpikeInterface already supports over 30 file formats, but the acquisition system you use might not be among the supported formats list (***ref***). Most of the extractord rely on the `NEO `_ package to read information from files. -Therefore, to implement a new extractor to handle the unsupported format, we recommend make a new `neo.rawio `_ class. +Therefore, to implement a new extractor to handle the unsupported format, we recommend make a new :code:`neo.rawio.BaseRawIO` class (see `example `_). Once that is done, the new class can be easily wrapped into SpikeInterface as an extension of the :py:class:`~spikeinterface.extractors.neoextractors.neobaseextractors.NeoBaseRecordingExtractor` (for :py:class:`~spikeinterface.core.BaseRecording` objects) or diff --git a/doc/how_to/analyse_neuropixels.rst b/doc/how_to/analyse_neuropixels.rst index c921b13719..37646c2146 100644 --- a/doc/how_to/analyse_neuropixels.rst +++ b/doc/how_to/analyse_neuropixels.rst @@ -4,11 +4,11 @@ Analyse Neuropixels datasets This example shows how to perform Neuropixels-specific analysis, including custom pre- and post-processing. -.. code:: ipython3 +.. code:: ipython %matplotlib inline -.. code:: ipython3 +.. code:: ipython import spikeinterface.full as si @@ -16,7 +16,7 @@ including custom pre- and post-processing. import matplotlib.pyplot as plt from pathlib import Path -.. code:: ipython3 +.. code:: ipython base_folder = Path('/mnt/data/sam/DataSpikeSorting/neuropixel_example/') @@ -29,7 +29,7 @@ Read the data The ``SpikeGLX`` folder can contain several “streams” (AP, LF and NIDQ). We need to specify which one to read: -.. code:: ipython3 +.. code:: ipython stream_names, stream_ids = si.get_neo_streams('spikeglx', spikeglx_folder) stream_names @@ -43,7 +43,7 @@ We need to specify which one to read: -.. code:: ipython3 +.. code:: ipython # we do not load the sync channel, so the probe is automatically loaded raw_rec = si.read_spikeglx(spikeglx_folder, stream_name='imec0.ap', load_sync_channel=False) @@ -58,7 +58,7 @@ We need to specify which one to read: -.. 
code:: ipython3 +.. code:: ipython # we automaticaly have the probe loaded! raw_rec.get_probe().to_dataframe() @@ -201,7 +201,7 @@ We need to specify which one to read: -.. code:: ipython3 +.. code:: ipython fig, ax = plt.subplots(figsize=(15, 10)) si.plot_probe_map(raw_rec, ax=ax, with_channel_ids=True) @@ -229,7 +229,7 @@ Let’s do something similar to the IBL destriping chain (See - instead of interpolating bad channels, we remove then. - instead of highpass_spatial_filter() we use common_reference() -.. code:: ipython3 +.. code:: ipython rec1 = si.highpass_filter(raw_rec, freq_min=400.) bad_channel_ids, channel_labels = si.detect_bad_channels(rec1) @@ -271,7 +271,7 @@ preprocessing chain wihtout to save the entire file to disk. Everything is lazy, so you can change the previsous cell (parameters, step order, …) and visualize it immediatly. -.. code:: ipython3 +.. code:: ipython # here we use static plot using matplotlib backend fig, axs = plt.subplots(ncols=3, figsize=(20, 10)) @@ -287,7 +287,7 @@ is lazy, so you can change the previsous cell (parameters, step order, .. image:: analyse_neuropixels_files/analyse_neuropixels_13_0.png -.. code:: ipython3 +.. code:: ipython # plot some channels fig, ax = plt.subplots(figsize=(20, 10)) @@ -326,7 +326,7 @@ Depending on the complexity of the preprocessing chain, this operation can take a while. However, we can make use of the powerful parallelization mechanism of SpikeInterface. -.. code:: ipython3 +.. code:: ipython job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) @@ -344,7 +344,7 @@ parallelization mechanism of SpikeInterface. write_binary_recording: 0%| | 0/1139 [00:00 0.9) -.. code:: ipython3 +.. code:: ipython keep_units = metrics.query(our_query) keep_unit_ids = keep_units.index.values @@ -1071,11 +1071,11 @@ In order to export the final results we need to make a copy of the the waveforms, but only for the selected units (so we can avoid to compute them again). -.. code:: ipython3 +.. code:: ipython we_clean = we.select_units(keep_unit_ids, new_folder=base_folder / 'waveforms_clean') -.. code:: ipython3 +.. code:: ipython we_clean @@ -1091,12 +1091,12 @@ them again). Then we export figures to a report folder -.. code:: ipython3 +.. code:: ipython # export spike sorting report to a folder si.export_report(we_clean, base_folder / 'report', format='png') -.. code:: ipython3 +.. code:: ipython we_clean = si.load_waveforms(base_folder / 'waveforms_clean') we_clean diff --git a/doc/how_to/get_started.rst b/doc/how_to/get_started.rst index a235eb4272..a923393916 100644 --- a/doc/how_to/get_started.rst +++ b/doc/how_to/get_started.rst @@ -11,7 +11,7 @@ dataset, and we will then perform some pre-processing, run a spike sorting algorithm, post-process the spike sorting output, perform curation (manual and automatic), and compare spike sorting results. -.. code:: ipython3 +.. code:: ipython import matplotlib.pyplot as plt from pprint import pprint @@ -19,7 +19,7 @@ curation (manual and automatic), and compare spike sorting results. The spikeinterface module by itself import only the spikeinterface.core submodule which is not useful for end user -.. code:: ipython3 +.. code:: ipython import spikeinterface @@ -35,7 +35,7 @@ We need to import one by one different submodules separately - ``comparison`` : comparison of spike sorting output - ``widgets`` : visualization -.. code:: ipython3 +.. 
code:: ipython import spikeinterface as si # import core only import spikeinterface.extractors as se @@ -56,14 +56,14 @@ This is useful for notebooks, but it is a heavier import because internally many more dependencies are imported (scipy/sklearn/networkx/matplotlib/h5py…) -.. code:: ipython3 +.. code:: ipython import spikeinterface.full as si Before getting started, we can set some global arguments for parallel processing. For this example, let’s use 4 jobs and time chunks of 1s: -.. code:: ipython3 +.. code:: ipython global_job_kwargs = dict(n_jobs=4, chunk_duration="1s") si.set_global_job_kwargs(**global_job_kwargs) @@ -75,7 +75,7 @@ Then we can open it. Note that `MEArec `__ simulated file contains both “recording” and a “sorting” object. -.. code:: ipython3 +.. code:: ipython local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') recording, sorting_true = se.read_mearec(local_path) @@ -102,7 +102,7 @@ ground-truth information of the spiking activity of each unit. Let’s use the ``spikeinterface.widgets`` module to visualize the traces and the raster plots. -.. code:: ipython3 +.. code:: ipython w_ts = sw.plot_traces(recording, time_range=(0, 5)) w_rs = sw.plot_rasters(sorting_true, time_range=(0, 5)) @@ -118,7 +118,7 @@ and the raster plots. This is how you retrieve info from a ``BaseRecording``\ … -.. code:: ipython3 +.. code:: ipython channel_ids = recording.get_channel_ids() fs = recording.get_sampling_frequency() @@ -143,7 +143,7 @@ This is how you retrieve info from a ``BaseRecording``\ … …and a ``BaseSorting`` -.. code:: ipython3 +.. code:: ipython num_seg = recording.get_num_segments() unit_ids = sorting_true.get_unit_ids() @@ -173,7 +173,7 @@ any probe in the probeinterface collections can be downloaded and set to a ``Recording`` object. In this case, the MEArec dataset already handles a ``Probe`` and we don’t need to set it *manually*. -.. code:: ipython3 +.. code:: ipython probe = recording.get_probe() print(probe) @@ -200,7 +200,7 @@ All these preprocessing steps are “lazy”. The computation is done on demand when we call ``recording.get_traces(...)`` or when we save the object to disk. -.. code:: ipython3 +.. code:: ipython recording_cmr = recording recording_f = si.bandpass_filter(recording, freq_min=300, freq_max=6000) @@ -224,7 +224,7 @@ Now you are ready to spike sort using the ``spikeinterface.sorters`` module! Let’s first check which sorters are implemented and which are installed -.. code:: ipython3 +.. code:: ipython print('Available sorters', ss.available_sorters()) print('Installed sorters', ss.installed_sorters()) @@ -241,7 +241,7 @@ machine. We can see we have HerdingSpikes and Tridesclous installed. Spike sorters come with a set of parameters that users can change. The available parameters are dictionaries and can be accessed with: -.. code:: ipython3 +.. code:: ipython print("Tridesclous params:") pprint(ss.get_default_sorter_params('tridesclous')) @@ -279,7 +279,7 @@ available parameters are dictionaries and can be accessed with: Let’s run ``tridesclous`` and change one of the parameter, say, the ``detect_threshold``: -.. code:: ipython3 +.. code:: ipython sorting_TDC = ss.run_sorter(sorter_name="tridesclous", recording=recording_preprocessed, detect_threshold=4) print(sorting_TDC) @@ -292,7 +292,7 @@ Let’s run ``tridesclous`` and change one of the parameter, say, the Alternatively we can pass full dictionary containing the parameters: -.. code:: ipython3 +.. 
code:: ipython other_params = ss.get_default_sorter_params('tridesclous') other_params['detect_threshold'] = 6 @@ -310,7 +310,7 @@ Alternatively we can pass full dictionary containing the parameters: Let’s run ``spykingcircus2`` as well, with default parameters: -.. code:: ipython3 +.. code:: ipython sorting_SC2 = ss.run_sorter(sorter_name="spykingcircus2", recording=recording_preprocessed) print(sorting_SC2) @@ -341,7 +341,7 @@ If a sorter is not installed locally, we can also avoid to install it and run it anyways, using a container (Docker or Singularity). For example, let’s run ``Kilosort2`` using Docker: -.. code:: ipython3 +.. code:: ipython sorting_KS2 = ss.run_sorter(sorter_name="kilosort2", recording=recording_preprocessed, docker_image=True, verbose=True) @@ -370,7 +370,7 @@ extracts, their waveforms, and stores them to disk. These waveforms are helpful to compute the average waveform, or “template”, for each unit and then to compute, for example, quality metrics. -.. code:: ipython3 +.. code:: ipython we_TDC = si.extract_waveforms(recording_preprocessed, sorting_TDC, 'waveforms_folder', overwrite=True) print(we_TDC) @@ -399,7 +399,7 @@ compute spike amplitudes, PCA projections, unit locations, and more. Let’s compute some postprocessing information that will be needed later for computing quality metrics, exporting, and visualization: -.. code:: ipython3 +.. code:: ipython amplitudes = spost.compute_spike_amplitudes(we_TDC) unit_locations = spost.compute_unit_locations(we_TDC) @@ -411,7 +411,7 @@ for computing quality metrics, exporting, and visualization: All of this postprocessing functions are saved in the waveforms folder as extensions: -.. code:: ipython3 +.. code:: ipython print(we_TDC.get_available_extension_names()) @@ -424,7 +424,7 @@ as extensions: Importantly, waveform extractors (and all extensions) can be reloaded at later times: -.. code:: ipython3 +.. code:: ipython we_loaded = si.load_waveforms('waveforms_folder') print(we_loaded.get_available_extension_names()) @@ -439,7 +439,7 @@ Once we have computed all these postprocessing information, we can compute quality metrics (different quality metrics require different extensions - e.g., drift metrics resuire ``spike_locations``): -.. code:: ipython3 +.. code:: ipython qm_params = sqm.get_default_qm_params() pprint(qm_params) @@ -485,14 +485,14 @@ extensions - e.g., drift metrics resuire ``spike_locations``): Since the recording is very short, let’s change some parameters to accomodate the duration: -.. code:: ipython3 +.. code:: ipython qm_params["presence_ratio"]["bin_duration_s"] = 1 qm_params["amplitude_cutoff"]["num_histogram_bins"] = 5 qm_params["drift"]["interval_s"] = 2 qm_params["drift"]["min_spikes_per_interval"] = 2 -.. code:: ipython3 +.. code:: ipython qm = sqm.compute_quality_metrics(we_TDC, qm_params=qm_params) display(qm) @@ -522,7 +522,7 @@ We can export a sorting summary and quality metrics plot using the ``sortingview`` backend. This will generate shareble links for web-based visualization. -.. code:: ipython3 +.. code:: ipython w1 = sw.plot_quality_metrics(we_TDC, display=False, backend="sortingview") @@ -530,7 +530,7 @@ visualization. https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://901a11ba31ae9ab512a99bdf36a3874173249d87&label=SpikeInterface%20-%20Quality%20Metrics -.. code:: ipython3 +.. code:: ipython w2 = sw.plot_sorting_summary(we_TDC, display=False, curation=True, backend="sortingview") @@ -543,7 +543,7 @@ curation. 
In the example above, we manually merged two units (0, 4) and added accept labels (2, 6, 7). After applying our curation, we can click on the “Save as snapshot (sha://)” and copy the URI: -.. code:: ipython3 +.. code:: ipython uri = "sha1://68cb54a9aaed2303fb82dedbc302c853e818f1b6" @@ -562,7 +562,7 @@ Alternatively, we can export the data locally to Phy. `Phy `_ is a GUI for manual curation of the spike sorting output. To export to phy you can run: -.. code:: ipython3 +.. code:: ipython sexp.export_to_phy(we_TDC, 'phy_folder_for_TDC', verbose=True) @@ -581,7 +581,7 @@ After curating with Phy, the curated sorting can be reloaded to SpikeInterface. In this case, we exclude the units that have been labeled as “noise”: -.. code:: ipython3 +.. code:: ipython sorting_curated_phy = se.read_phy('phy_folder_for_TDC', exclude_cluster_groups=["noise"]) @@ -589,7 +589,7 @@ Quality metrics can be also used to automatically curate the spike sorting output. For example, you can select sorted units with a SNR above a certain threshold: -.. code:: ipython3 +.. code:: ipython keep_mask = (qm['snr'] > 10) & (qm['isi_violations_ratio'] < 0.01) print("Mask:", keep_mask.values) @@ -615,7 +615,7 @@ outputs. We can either: 3. compare the output of multiple sorters (Tridesclous, SpykingCircus2, Kilosort2) -.. code:: ipython3 +.. code:: ipython comp_gt = sc.compare_sorter_to_ground_truth(gt_sorting=sorting_true, tested_sorting=sorting_TDC) comp_pair = sc.compare_two_sorters(sorting1=sorting_TDC, sorting2=sorting_SC2) @@ -625,7 +625,7 @@ outputs. We can either: When comparing with a ground-truth sorting (1,), you can get the sorting performance and plot a confusion matrix -.. code:: ipython3 +.. code:: ipython print(comp_gt.get_performance()) w_conf = sw.plot_confusion_matrix(comp_gt) @@ -659,7 +659,7 @@ performance and plot a confusion matrix When comparing two sorters (2.), we can see the matching of units between sorters. Units which are not matched has -1 as unit id: -.. code:: ipython3 +.. code:: ipython comp_pair.hungarian_match_12 @@ -683,7 +683,7 @@ between sorters. Units which are not matched has -1 as unit id: or the reverse: -.. code:: ipython3 +.. code:: ipython comp_pair.hungarian_match_21 @@ -709,7 +709,7 @@ When comparing multiple sorters (3.), you can extract a ``BaseSorting`` object with units in agreement between sorters. You can also plot a graph showing how the units are matched between the sorters. -.. code:: ipython3 +.. code:: ipython sorting_agreement = comp_multi.get_agreement_sorting(minimum_agreement_count=2) diff --git a/doc/how_to/handle_drift.rst b/doc/how_to/handle_drift.rst index 7ff98a666b..5c4476187b 100644 --- a/doc/how_to/handle_drift.rst +++ b/doc/how_to/handle_drift.rst @@ -1,4 +1,4 @@ -.. code:: ipython3 +.. code:: ipython %matplotlib inline %load_ext autoreload @@ -42,7 +42,7 @@ Neuropixels 1 and a Neuropixels 2 probe. Here we will use *dataset1* with *neuropixel1*. This dataset is the *“hello world”* for drift correction in the spike sorting community! -.. code:: ipython3 +.. code:: ipython from pathlib import Path import matplotlib.pyplot as plt @@ -52,12 +52,12 @@ Here we will use *dataset1* with *neuropixel1*. This dataset is the import spikeinterface.full as si -.. code:: ipython3 +.. code:: ipython base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick') dataset_folder = base_folder / 'dataset1/NP1' -.. code:: ipython3 +.. 
code:: ipython # read the file raw_rec = si.read_spikeglx(dataset_folder) @@ -77,7 +77,7 @@ We preprocess the recording with bandpass filter and a common median reference. Note, that it is better to not whiten the recording before motion estimation to get a better estimate of peak locations! -.. code:: ipython3 +.. code:: ipython def preprocess_chain(rec): rec = si.bandpass_filter(rec, freq_min=300., freq_max=6000.) @@ -85,7 +85,7 @@ motion estimation to get a better estimate of peak locations! return rec rec = preprocess_chain(raw_rec) -.. code:: ipython3 +.. code:: ipython job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) @@ -101,7 +101,7 @@ parameters for each step. Here we also save the motion correction results into a folder to be able to load them later. -.. code:: ipython3 +.. code:: ipython # internally, we can explore a preset like this # every parameter can be overwritten at runtime @@ -143,13 +143,13 @@ to load them later. -.. code:: ipython3 +.. code:: ipython # lets try theses 3 presets some_presets = ('rigid_fast', 'kilosort_like', 'nonrigid_accurate') # some_presets = ('nonrigid_accurate', ) -.. code:: ipython3 +.. code:: ipython # compute motion with 3 presets for preset in some_presets: @@ -195,7 +195,7 @@ A few comments on the figures: (2000-3000um) experience some drift, but the lower part (0-1000um) is relatively stable. The method defined by this preset is able to capture this. -.. code:: ipython3 +.. code:: ipython for preset in some_presets: # load @@ -243,7 +243,7 @@ locations (:py:func:`correct_motion_on_peaks()`) Case 1 is used before running a spike sorter and the case 2 is used here to display the results. -.. code:: ipython3 +.. code:: ipython from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks @@ -303,7 +303,7 @@ run times Presets and related methods have differents accuracies but also computation speeds. It is good to have this in mind! -.. code:: ipython3 +.. 
code:: ipython run_times = [] for preset in some_presets: diff --git a/doc/install_sorters.rst b/doc/install_sorters.rst index 3fda05848c..10a3185c5c 100644 --- a/doc/install_sorters.rst +++ b/doc/install_sorters.rst @@ -117,7 +117,7 @@ Kilosort2.5 git clone https://github.com/MouseLand/Kilosort # provide installation path by setting the KILOSORT2_5_PATH environment variable - # or using Kilosort2_5Sorter.set_kilosort2_path() + # or using Kilosort2_5Sorter.set_kilosort2_5_path() * See also for Matlab/CUDA: https://www.mathworks.com/help/parallel-computing/gpu-support-by-release.html diff --git a/doc/modules/extractors.rst b/doc/modules/extractors.rst index a6752e2c9d..5aed24ca41 100644 --- a/doc/modules/extractors.rst +++ b/doc/modules/extractors.rst @@ -129,13 +129,15 @@ For raw recording formats, we currently support: * **MCS RAW** :py:func:`~spikeinterface.extractors.read_mcsraw()` * **MEArec** :py:func:`~spikeinterface.extractors.read_mearec()` * **Mountainsort MDA** :py:func:`~spikeinterface.extractors.read_mda_recording()` +* **Neuralynx** :py:func:`~spikeinterface.extractors.read_neuralynx()` * **Neurodata Without Borders** :py:func:`~spikeinterface.extractors.read_nwb_recording()` * **Neuroscope** :py:func:`~spikeinterface.coextractorsre.read_neuroscope_recording()` +* **Neuroexplorer** :py:func:`~spikeinterface.extractors.read_neuroexplorer()` * **NIX** :py:func:`~spikeinterface.extractors.read_nix()` -* **Neuralynx** :py:func:`~spikeinterface.extractors.read_neuralynx()` * **Open Ephys Legacy** :py:func:`~spikeinterface.extractors.read_openephys()` * **Open Ephys Binary** :py:func:`~spikeinterface.extractors.read_openephys()` -* **Plexon** :py:func:`~spikeinterface.corextractorse.read_plexon()` +* **Plexon** :py:func:`~spikeinterface.extractors.read_plexon()` +* **Plexon 2** :py:func:`~spikeinterface.extractors.read_plexon2()` * **Shybrid** :py:func:`~spikeinterface.extractors.read_shybrid_recording()` * **SpikeGLX** :py:func:`~spikeinterface.extractors.read_spikeglx()` * **SpikeGLX IBL compressed** :py:func:`~spikeinterface.extractors.read_cbin_ibl()` @@ -165,6 +167,7 @@ For sorted data formats, we currently support: * **Neuralynx spikes** :py:func:`~spikeinterface.extractors.read_neuralynx_sorting()` * **NPZ (created by SpikeInterface)** :py:func:`~spikeinterface.core.read_npz_sorting()` * **Plexon spikes** :py:func:`~spikeinterface.extractors.read_plexon_sorting()` +* **Plexon 2 spikes** :py:func:`~spikeinterface.extractors.read_plexon2_sorting()` * **Shybrid** :py:func:`~spikeinterface.extractors.read_shybrid_sorting()` * **Spyking Circus** :py:func:`~spikeinterface.extractors.read_spykingcircus()` * **Trideclous** :py:func:`~spikeinterface.extractors.read_tridesclous()` diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index 62c0d6b8d4..afedc4f982 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -9,7 +9,7 @@ Overview Mechanical drift, often observed in recordings, is currently a major issue for spike sorting. This is especially striking with the new generation of high-density devices used for in-vivo electrophyisology such as the neuropixel electrodes. 
-The first sorter that has introduced motion/drift correction as a prepossessing step was Kilosort2.5 (see [Steinmetz2021]_) +The first sorter that has introduced motion/drift correction as a prepossessing step was Kilosort2.5 (see [Steinmetz2021]_ [SteinmetzDataset]_) Long story short, the main idea is the same as the one used for non-rigid image registration, for example with calcium imaging. However, because with extracellular recording we do not have a proper image to use as a reference, the main idea diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index ee3234af6c..8c7c0a2cc3 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -38,6 +38,7 @@ For more details about each metric and it's availability and use within SpikeInt qualitymetrics/snr qualitymetrics/noise_cutoff qualitymetrics/silhouette_score + qualitymetrics/synchrony This code snippet shows how to compute quality metrics (with or without principal components) in SpikeInterface: diff --git a/doc/modules/qualitymetrics/silhouette_score.rst b/doc/modules/qualitymetrics/silhouette_score.rst index 275805c6a7..b924cdbf73 100644 --- a/doc/modules/qualitymetrics/silhouette_score.rst +++ b/doc/modules/qualitymetrics/silhouette_score.rst @@ -1,3 +1,5 @@ +.. _silhouette_score : + Silhouette score (:code:`silhouette`, :code:`silhouette_full`) ============================================================== @@ -7,7 +9,7 @@ Calculation Gives the ratio between the cohesiveness of a cluster and its separation from other clusters. Values for silhouette score range from -1 to 1. -For the full method as proposed by [Rouseeuw]_, the pairwise distances between each point +For the full method as proposed by [Rousseeuw]_, the pairwise distances between each point and every other point :math:`a(i)` in a cluster :math:`C_i` are calculated and then iterating through every other cluster's distances between the points in :math:`C_i` and the points in :math:`C_j` are calculated. The cluster with the minimal mean distance is taken to be :math:`b(i)`. The @@ -48,6 +50,13 @@ To reduce complexity the default implementation in SpikeInterface is to use the This can be changes by switching the silhouette method to either 'full' (the Rousseeuw implementation) or ('simplified', 'full') for both methods when entering the qm_params parameter. +References +---------- + +.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.simplified_silhouette_score + +.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.silhouette_score + Literature ---------- diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index b41e194466..2f566bf8a7 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -1,3 +1,5 @@ +.. _synchrony: + Synchrony Metrics (:code:`synchrony`) ===================================== @@ -39,11 +41,11 @@ The SpikeInterface implementation is a partial port of the low-level complexity References ---------- -.. automodule:: spikeinterface.toolkit.qualitymetrics.misc_metrics +.. automodule:: spikeinterface.qualitymetrics.misc_metrics .. 
autofunction:: compute_synchrony_metrics Literature ---------- -Based on concepts described in Gruen_ +Based on concepts described in [Gruen]_ diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index 26f2365202..f3c8e7b733 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -130,7 +130,7 @@ Parameters from all sorters can be retrieved with these functions: .. _containerizedsorters: Running sorters in Docker/Singularity Containers ------------------------------------------------ +------------------------------------------------ One of the biggest bottlenecks for users is installing spike sorting software. To alleviate this, we build and maintain containerized versions of several popular spike sorters on the @@ -239,7 +239,7 @@ There are three options: 1. **released PyPi version**: if you installed :code:`spikeinterface` with :code:`pip install spikeinterface`, the latest released version will be installed in the container. -2. **development :code:`main` version**: if you installed :code:`spikeinterface` from source from the cloned repo +2. **development** :code:`main` **version**: if you installed :code:`spikeinterface` from source from the cloned repo (with :code:`pip install .`) or with :code:`pip install git+https://github.com/SpikeInterface/spikeinterface.git`, the current development version from the :code:`main` branch will be installed in the container. @@ -285,27 +285,26 @@ Running several sorters in parallel The :py:mod:`~spikeinterface.sorters` module also includes tools to run several spike sorting jobs sequentially or in parallel. This can be done with the -:py:func:`~spikeinterface.sorters.run_sorters()` function by specifying +:py:func:`~spikeinterface.sorters.run_sorter_jobs()` function by specifying an :code:`engine` that supports parallel processing (such as :code:`joblib` or :code:`slurm`). .. code-block:: python - recordings = {'rec1' : recording, 'rec2': another_recording} - sorter_list = ['herdingspikes', 'tridesclous'] - sorter_params = { - 'herdingspikes': {'clustering_bandwidth' : 8}, - 'tridesclous': {'detect_threshold' : 5.}, - } - sorting_output = run_sorters(sorter_list, recordings, working_folder='tmp_some_sorters', - mode_if_folder_exists='overwrite', sorter_params=sorter_params) + # here we run 2 sorters on 2 different recordings = 4 jobs + recording = ... + another_recording = ... + + job_list = [ + {'sorter_name': 'tridesclous', 'recording': recording, 'output_folder': 'folder1','detect_threshold': 5.}, + {'sorter_name': 'tridesclous', 'recording': another_recording, 'output_folder': 'folder2', 'detect_threshold': 5.}, + {'sorter_name': 'herdingspikes', 'recording': recording, 'output_folder': 'folder3', 'clustering_bandwidth': 8., 'docker_image': True}, + {'sorter_name': 'herdingspikes', 'recording': another_recording, 'output_folder': 'folder4', 'clustering_bandwidth': 8., 'docker_image': True}, + ] + + # run in loop + sortings = run_sorter_jobs(job_list, engine='loop') - # the output is a dict with (rec_name, sorter_name) as keys - for (rec_name, sorter_name), sorting in sorting_output.items(): - print(rec_name, sorter_name, ':', sorting.get_unit_ids()) -After the jobs are run, the :code:`sorting_outputs` is a dictionary with :code:`(rec_name, sorter_name)` as a key (e.g. -:code:`('rec1', 'tridesclous')` in this example), and the corresponding :py:class:`~spikeinterface.core.BaseSorting` -as a value. 
:py:func:`~spikeinterface.sorters.run_sorters` has several "engines" available to launch the computation: @@ -315,13 +314,11 @@ as a value. .. code-block:: python - run_sorters(sorter_list, recordings, engine='loop') + run_sorter_jobs(job_list, engine='loop') - run_sorters(sorter_list, recordings, engine='joblib', - engine_kwargs={'n_jobs': 2}) + run_sorter_jobs(job_list, engine='joblib', engine_kwargs={'n_jobs': 2}) - run_sorters(sorter_list, recordings, engine='slurm', - engine_kwargs={'cpus_per_task': 10, 'mem', '5G'}) + run_sorter_jobs(job_list, engine='slurm', engine_kwargs={'cpus_per_task': 10, 'mem', '5G'}) Spike sorting by group @@ -458,7 +455,7 @@ Here is the list of external sorters accessible using the run_sorter wrapper: * **Kilosort** :code:`run_sorter('kilosort')` * **Kilosort2** :code:`run_sorter('kilosort2')` * **Kilosort2.5** :code:`run_sorter('kilosort2_5')` -* **Kilosort3** :code:`run_sorter('Kilosort3')` +* **Kilosort3** :code:`run_sorter('kilosort3')` * **PyKilosort** :code:`run_sorter('pykilosort')` * **Klusta** :code:`run_sorter('klusta')` * **Mountainsort4** :code:`run_sorter('mountainsort4')` @@ -474,7 +471,7 @@ Here is the list of external sorters accessible using the run_sorter wrapper: Here a list of internal sorter based on `spikeinterface.sortingcomponents`; they are totally experimental for now: -* **Spyking circus2** :code:`run_sorter('spykingcircus2')` +* **Spyking Circus2** :code:`run_sorter('spykingcircus2')` * **Tridesclous2** :code:`run_sorter('tridesclous2')` In 2023, we expect to add many more sorters to this list. diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index aa62ea5b33..422eaea890 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -223,7 +223,7 @@ Here is a short example that depends on the output of "Motion interpolation": **Notes**: * :code:`spatial_interpolation_method` "kriging" or "iwd" do not play a big role. - * :code:`border_mode` is a very important parameter. It controls how to deal with the border because motion causes units on the + * :code:`border_mode` is a very important parameter. It controls dealing with the border because motion causes units on the border to not be present throughout the entire recording. We highly recommend the :code:`border_mode='remove_channels'` because this removes channels on the border that will be impacted by drift. Of course the larger the motion is the more channels are removed. 
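A minimal sketch of how these two options fit together is shown below. It assumes the :code:`interpolate_motion()` helper from :code:`spikeinterface.sortingcomponents.motion_interpolation`, a preprocessed recording :code:`rec`, and the :code:`motion`, :code:`temporal_bins` and :code:`spatial_bins` arrays returned by the motion estimation step; none of these names are defined in this excerpt, so treat them as placeholders rather than a definitive API reference.

.. code-block:: python

    from spikeinterface.sortingcomponents.motion_interpolation import interpolate_motion

    # interpolate the preprocessed recording on the estimated motion;
    # border_mode='remove_channels' drops the border channels that drift moves in and out of view
    recording_corrected = interpolate_motion(
        recording=rec,
        motion=motion,
        temporal_bins=temporal_bins,
        spatial_bins=spatial_bins,
        spatial_interpolation_method="kriging",
        border_mode="remove_channels",
    )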
@@ -278,7 +278,7 @@ At the moment, there are five methods implemented: * 'naive': a very naive implemenation used as a reference for benchmarks * 'tridesclous': the algorithm for template matching implemented in Tridesclous * 'circus': the algorithm for template matching implemented in SpyKING-Circus - * 'circus-omp': a updated algorithm similar to SpyKING-Circus but with OMP (orthogonal macthing + * 'circus-omp': a updated algorithm similar to SpyKING-Circus but with OMP (orthogonal matching pursuit) * 'wobble' : an algorithm loosely based on YASS that scales template amplitudes and shifts them in time to match detected spikes diff --git a/installation_tips/check_your_install.py b/installation_tips/check_your_install.py index 20809ec6c0..2b13a941cd 100644 --- a/installation_tips/check_your_install.py +++ b/installation_tips/check_your_install.py @@ -103,8 +103,8 @@ def _clean(): try: func() done = '...OK' - except: - done = '...Fail' + except Exception as err: + done = f'...Fail, Error: {err}' print(label, done) if platform.system() == "Windows": diff --git a/installation_tips/full_spikeinterface_environment_linux_dandi.yml b/installation_tips/full_spikeinterface_environment_linux_dandi.yml index d402f6805f..2ed176b16c 100755 --- a/installation_tips/full_spikeinterface_environment_linux_dandi.yml +++ b/installation_tips/full_spikeinterface_environment_linux_dandi.yml @@ -3,13 +3,11 @@ channels: - conda-forge - defaults dependencies: - - python=3.9 + - python=3.10 - pip>=21.0 - mamba - # numpy 1.22 break numba which break tridesclous - numpy<1.22 - # joblib 1.2 is breaking hdbscan - - joblib=1.1 + - joblib - tqdm - matplotlib - h5py @@ -31,12 +29,11 @@ dependencies: - ipympl - pip: - ephyviewer - - neo>=0.11 - - elephant>=0.10.0 + - neo>=0.12 - probeinterface>=0.2.11 - MEArec>=1.8 - spikeinterface[full, widgets] - spikeinterface-gui - - tridesclous>=1.6.6.1 + - tridesclous>=1.6.8 - spyking-circus>=1.1.0 # - phy==2.0b5 diff --git a/installation_tips/full_spikeinterface_environment_mac.yml b/installation_tips/full_spikeinterface_environment_mac.yml index 8b872981aa..7ce4a149cc 100755 --- a/installation_tips/full_spikeinterface_environment_mac.yml +++ b/installation_tips/full_spikeinterface_environment_mac.yml @@ -3,12 +3,10 @@ channels: - conda-forge - defaults dependencies: - - python=3.9 + - python=3.10 - pip>=21.0 - # numpy 1.21 break numba which break tridesclous - - numpy<1.22 - # joblib 1.2 is breaking hdbscan - - joblib=1.1 + - numpy + - joblib - tqdm - matplotlib - h5py @@ -30,13 +28,12 @@ dependencies: - pip: # - PyQt5 - ephyviewer - - neo>=0.11 - - elephant>=0.10.0 + - neo>=0.12 - probeinterface>=0.2.11 - MEArec>=1.8 - spikeinterface[full, widgets] - spikeinterface-gui - - tridesclous>=1.6.6.1 + - tridesclous>=1.6.8 # - phy==2.0b5 - - mountainsort4>=1.0.0 - - mountainsort5>=0.3.0 + # - mountainsort4>=1.0.0 isosplit5 fails on pip install for mac + # - mountainsort5>=0.3.0 diff --git a/installation_tips/full_spikeinterface_environment_windows.yml b/installation_tips/full_spikeinterface_environment_windows.yml index 8c793edcb1..38c26e6a78 100755 --- a/installation_tips/full_spikeinterface_environment_windows.yml +++ b/installation_tips/full_spikeinterface_environment_windows.yml @@ -3,12 +3,11 @@ channels: - conda-forge - defaults dependencies: - - python=3.9 + - python=3.10 - pip>=21.0 # numpy 1.21 break numba which break tridesclous - - numpy<1.22 - # joblib 1.2 is breaking hdbscan - - joblib=1.1 + - numpy + - joblib - tqdm - matplotlib - h5py @@ -26,11 +25,10 @@ dependencies: - 
ipympl - pip: - ephyviewer - - neo>=0.11 - - elephant>=0.10.0 + - neo>=0.12 - probeinterface>=0.2.11 - MEArec>=1.8 - spikeinterface[full, widgets] - spikeinterface-gui - - tridesclous>=1.6.6.1 + - tridesclous>=1.6.8 # - phy==2.0b5 diff --git a/pyproject.toml b/pyproject.toml index 474cdc483f..51efe1f585 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,6 +148,7 @@ docs = [ "sphinx_rtd_theme==1.0.0", "sphinx-gallery", "numpydoc", + "ipython", # for notebooks in the gallery "MEArec", # Use as an example diff --git a/src/spikeinterface/comparison/basecomparison.py b/src/spikeinterface/comparison/basecomparison.py index 79c784491a..5af20d79b5 100644 --- a/src/spikeinterface/comparison/basecomparison.py +++ b/src/spikeinterface/comparison/basecomparison.py @@ -262,11 +262,11 @@ def get_ordered_agreement_scores(self): indexes = np.arange(scores.shape[1]) order1 = [] for r in range(scores.shape[0]): - possible = indexes[~np.in1d(indexes, order1)] + possible = indexes[~np.isin(indexes, order1)] if possible.size > 0: ind = np.argmax(scores.iloc[r, possible].values) order1.append(possible[ind]) - remain = indexes[~np.in1d(indexes, order1)] + remain = indexes[~np.isin(indexes, order1)] order1.extend(remain) scores = scores.iloc[:, order1] diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index db45e2b25b..20ee7910b4 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -538,7 +538,7 @@ def do_confusion_matrix(event_counts1, event_counts2, match_12, match_event_coun matched_units2 = match_12[match_12 != -1].values unmatched_units1 = match_12[match_12 == -1].index - unmatched_units2 = unit2_ids[~np.in1d(unit2_ids, matched_units2)] + unmatched_units2 = unit2_ids[~np.isin(unit2_ids, matched_units2)] ordered_units1 = np.hstack([matched_units1, unmatched_units1]) ordered_units2 = np.hstack([matched_units2, unmatched_units2]) diff --git a/src/spikeinterface/comparison/studytools.py b/src/spikeinterface/comparison/studytools.py index 79227c865f..26d2c1ad6f 100644 --- a/src/spikeinterface/comparison/studytools.py +++ b/src/spikeinterface/comparison/studytools.py @@ -22,12 +22,45 @@ from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.extractors import NpzSortingExtractor from spikeinterface.sorters import sorter_dict -from spikeinterface.sorters.launcher import iter_working_folder, iter_sorting_output +from spikeinterface.sorters.basesorter import is_log_ok + from .comparisontools import _perf_keys from .paircomparisons import compare_sorter_to_ground_truth +# This is deprecated and will be removed +def iter_working_folder(working_folder): + working_folder = Path(working_folder) + for rec_folder in working_folder.iterdir(): + if not rec_folder.is_dir(): + continue + for output_folder in rec_folder.iterdir(): + if (output_folder / "spikeinterface_job.json").is_file(): + with open(output_folder / "spikeinterface_job.json", "r") as f: + job_dict = json.load(f) + rec_name = job_dict["rec_name"] + sorter_name = job_dict["sorter_name"] + yield rec_name, sorter_name, output_folder + else: + rec_name = rec_folder.name + sorter_name = output_folder.name + if not output_folder.is_dir(): + continue + if not is_log_ok(output_folder): + continue + yield rec_name, sorter_name, output_folder + + +# This is deprecated and will be removed +def iter_sorting_output(working_folder): + """Iterator over output_folder to retrieve all triplets of (rec_name, sorter_name, 
sorting).""" + for rec_name, sorter_name, output_folder in iter_working_folder(working_folder): + SorterClass = sorter_dict[sorter_name] + sorting = SorterClass.get_result_from_folder(output_folder) + yield rec_name, sorter_name, sorting + + def setup_comparison_study(study_folder, gt_dict, **job_kwargs): """ Based on a dict of (recording, sorting) create the study folder. diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index af4970a4ad..08f187895b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -592,7 +592,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceRecording - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceRecording(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 737087abc1..f35bc2b266 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -139,7 +139,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceSnippets - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceSnippets(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 52f71c2399..e6d08d38f7 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -346,7 +346,7 @@ def remove_units(self, remove_unit_ids): """ from spikeinterface import UnitsSelectionSorting - new_unit_ids = self.unit_ids[~np.in1d(self.unit_ids, remove_unit_ids)] + new_unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)] new_sorting = UnitsSelectionSorting(self, new_unit_ids) return new_sorting @@ -473,8 +473,7 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac if not concatenated: spikes_ = [] for segment_index in range(self.get_num_segments()): - s0 = np.searchsorted(spikes["segment_index"], segment_index, side="left") - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1, side="left") + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1], side="left") spikes_.append(spikes[s0:s1]) spikes = spikes_ diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index d78a2e4e57..07837bcef7 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2,6 +2,7 @@ import warnings import numpy as np from typing import Union, Optional, List, Literal +import warnings from .numpyextractors import NumpyRecording, NumpySorting @@ -31,7 +32,7 @@ def generate_recording( set_probe: Optional[bool] = True, ndim: Optional[int] = 2, seed: Optional[int] = None, - mode: Literal["lazy", "legacy"] = "legacy", + mode: Literal["lazy", "legacy"] = "lazy", ) -> BaseRecording: """ Generate a recording object. @@ -51,10 +52,10 @@ def generate_recording( The number of dimensions of the probe, by default 2. Set to 3 to make 3 dimensional probes. 
seed : Optional[int] A seed for the np.ramdom.default_rng function - mode: str ["lazy", "legacy"] Default "legacy". + mode: str ["lazy", "legacy"] Default "lazy". "legacy": generate a NumpyRecording with white noise. - This mode is kept for backward compatibility and will be deprecated in next release. - "lazy": return a NoiseGeneratorRecording + This mode is kept for backward compatibility and will be deprecated version 0.100.0. + "lazy": return a NoiseGeneratorRecording instance. Returns ------- @@ -64,6 +65,10 @@ def generate_recording( seed = _ensure_seed(seed) if mode == "legacy": + warnings.warn( + "generate_recording() : mode='legacy' will be deprecated in version 0.100.0. Use mode='lazy' instead.", + DeprecationWarning, + ) recording = _generate_recording_legacy(num_channels, sampling_frequency, durations, seed) elif mode == "lazy": recording = NoiseGeneratorRecording( @@ -161,7 +166,7 @@ def generate_sorting( ) if empty_units is not None: - keep = ~np.in1d(labels, empty_units) + keep = ~np.isin(labels, empty_units) times = times[keep] labels = labels[keep] @@ -214,7 +219,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): sample_index = spike["sample_index"] if sample_index not in units_used_for_spike: units_used_for_spike[sample_index] = np.array([spike["unit_index"]]) - units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])] + units_not_used = unit_ids[~np.isin(unit_ids, units_used_for_spike[sample_index])] if len(units_not_used) == 0: continue @@ -538,7 +543,7 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol class NoiseGeneratorRecording(BaseRecording): """ - A lazy recording that generates random samples if and only if `get_traces` is called. + A lazy recording that generates white noise samples if and only if `get_traces` is called. This done by tiling small noise chunk. @@ -555,7 +560,7 @@ class NoiseGeneratorRecording(BaseRecording): The sampling frequency of the recorder. durations : List[float] The durations of each segment in seconds. Note that the length of this list is the number of segments. - noise_level: float, default 5: + noise_level: float, default 1: Std of the white noise dtype : Optional[Union[np.dtype, str]], default='float32' The dtype of the recording. Note that only np.float32 and np.float64 are supported. 
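# --- Usage sketch for the lazy generator documented above ---
# The constructor arguments follow the NoiseGeneratorRecording definition in this file;
# the import path and the get_traces() call are the standard BaseRecording API and are
# assumed here rather than shown in this patch.
from spikeinterface.core.generate import NoiseGeneratorRecording

lazy_rec = NoiseGeneratorRecording(
    num_channels=4,
    sampling_frequency=30_000.0,
    durations=[10.0],                 # one 10 s segment
    noise_level=1.0,                  # std of the white noise (new default)
    dtype="float32",
    seed=42,
    strategy="tile_pregenerated",     # or "on_the_fly" to regenerate each chunk
)

# samples are only generated when traces are actually requested
traces = lazy_rec.get_traces(start_frame=0, end_frame=30_000)
assert traces.shape == (30_000, 4)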
@@ -581,7 +586,7 @@ def __init__( num_channels: int, sampling_frequency: float, durations: List[float], - noise_level: float = 5.0, + noise_level: float = 1.0, dtype: Optional[Union[np.dtype, str]] = "float32", seed: Optional[int] = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", @@ -647,7 +652,7 @@ def __init__( if self.strategy == "tile_pregenerated": rng = np.random.default_rng(seed=self.seed) self.noise_block = ( - rng.standard_normal(size=(self.noise_block_size, self.num_channels)).astype(self.dtype) * noise_level + rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype) * noise_level ) elif self.strategy == "on_the_fly": pass @@ -664,35 +669,35 @@ def get_traces( start_frame = 0 if start_frame is None else max(start_frame, 0) end_frame = self.num_samples if end_frame is None else min(end_frame, self.num_samples) - start_frame_mod = start_frame % self.noise_block_size - end_frame_mod = end_frame % self.noise_block_size + start_frame_within_block = start_frame % self.noise_block_size + end_frame_within_block = end_frame % self.noise_block_size num_samples = end_frame - start_frame traces = np.empty(shape=(num_samples, self.num_channels), dtype=self.dtype) - start_block_index = start_frame // self.noise_block_size - end_block_index = end_frame // self.noise_block_size + first_block_index = start_frame // self.noise_block_size + last_block_index = end_frame // self.noise_block_size pos = 0 - for block_index in range(start_block_index, end_block_index + 1): + for block_index in range(first_block_index, last_block_index + 1): if self.strategy == "tile_pregenerated": noise_block = self.noise_block elif self.strategy == "on_the_fly": rng = np.random.default_rng(seed=(self.seed, block_index)) - noise_block = rng.standard_normal(size=(self.noise_block_size, self.num_channels)).astype(self.dtype) + noise_block = rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype) noise_block *= self.noise_level - if block_index == start_block_index: - if start_block_index != end_block_index: - end_first_block = self.noise_block_size - start_frame_mod - traces[:end_first_block] = noise_block[start_frame_mod:] + if block_index == first_block_index: + if first_block_index != last_block_index: + end_first_block = self.noise_block_size - start_frame_within_block + traces[:end_first_block] = noise_block[start_frame_within_block:] pos += end_first_block else: # special case when unique block - traces[:] = noise_block[start_frame_mod : start_frame_mod + traces.shape[0]] - elif block_index == end_block_index: - if end_frame_mod > 0: - traces[pos:] = noise_block[:end_frame_mod] + traces[:] = noise_block[start_frame_within_block : start_frame_within_block + num_samples] + elif block_index == last_block_index: + if end_frame_within_block > 0: + traces[pos:] = noise_block[:end_frame_within_block] else: traces[pos : pos + self.noise_block_size] = noise_block pos += self.noise_block_size @@ -710,7 +715,7 @@ def get_traces( def generate_recording_by_size( full_traces_size_GiB: float, - num_channels: int = 1024, + num_channels: int = 384, seed: Optional[int] = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", ) -> NoiseGeneratorRecording: @@ -719,7 +724,7 @@ def generate_recording_by_size( This is a convenience wrapper around the NoiseGeneratorRecording class where only the size in GiB (NOT GB!) is specified. - It is generated with 1024 channels and a sampling frequency of 1 Hz. 
The duration is manipulted to + It is generated with 384 channels and a sampling frequency of 1 Hz. The duration is manipulted to produced the desired size. Seee GeneratorRecording for more details. @@ -727,7 +732,7 @@ def generate_recording_by_size( Parameters ---------- full_traces_size_GiB : float - The size in gibibyte (GiB) of the recording. + The size in gigabytes (GiB) of the recording. num_channels: int Number of channels. seed : int, optional @@ -740,7 +745,7 @@ def generate_recording_by_size( dtype = np.dtype("float32") sampling_frequency = 30_000.0 # Hz - num_channels = 1024 + num_channels = 384 GiB_to_bytes = 1024**3 full_traces_size_bytes = int(full_traces_size_GiB * GiB_to_bytes) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 9ea5ad59e7..651804c995 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -84,7 +84,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar raise NotImplementedError -# nodes graph must have either a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever) +# nodes graph must have a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever) # as first element they play the same role in pipeline : give some peaks (and eventually more) @@ -111,8 +111,7 @@ def __init__(self, recording, peaks): # precompute segment slice self.segment_slices = [] for segment_index in range(recording.get_num_segments()): - i0 = np.searchsorted(peaks["segment_index"], segment_index) - i1 = np.searchsorted(peaks["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) def get_trace_margin(self): @@ -125,8 +124,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -138,7 +136,97 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): - pass + """ + This class is useful to inject a sorting object in the node pipepline mechanism. + It allows to compute some post-processing steps with the same machinery used for sorting components. + This is used by: + * compute_spike_locations() + * compute_amplitude_scalings() + * compute_spike_amplitudes() + * compute_principal_components() + + recording : BaseRecording + The recording object. + sorting: BaseSorting + The sorting object. + channel_from_template: bool, default: True + If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided. + If False, the max channel is computed for each spike given a radius around the template max channel. + extremum_channel_inds: dict of int + The extremum channel index dict given from template. + radius_um: float (default 50.) + The radius to find the real max channel. + Used only when channel_from_template=False + peak_sign: str (default "neg") + Peak sign to find the max channel. 
+ Used only when channel_from_template=False + """ + + def __init__( + self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg" + ): + PipelineNode.__init__(self, recording, return_output=False) + + self.channel_from_template = channel_from_template + + assert extremum_channel_inds is not None, "SpikeRetriever needs the extremum_channel_inds dictionary" + + self.peaks = sorting_to_peaks(sorting, extremum_channel_inds) + + if not channel_from_template: + channel_distance = get_channel_distances(recording) + self.neighbours_mask = channel_distance < radius_um + self.peak_sign = peak_sign + + # precompute segment slice + self.segment_slices = [] + for segment_index in range(recording.get_num_segments()): + i0, i1 = np.searchsorted(self.peaks["segment_index"], [segment_index, segment_index + 1]) + self.segment_slices.append(slice(i0, i1)) + + def get_trace_margin(self): + return 0 + + def get_dtype(self): + return base_peak_dtype + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + # get local peaks + sl = self.segment_slices[segment_index] + peaks_in_segment = self.peaks[sl] + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + local_peaks = peaks_in_segment[i0:i1] + + # make sample index local to traces + local_peaks = local_peaks.copy() + local_peaks["sample_index"] -= start_frame - max_margin + + if not self.channel_from_template: + # handle channel spike per spike + for i, peak in enumerate(local_peaks): + chans = np.flatnonzero(self.neighbours_mask[peak["channel_index"]]) + sparse_wfs = traces[peak["sample_index"], chans] + if self.peak_sign == "neg": + local_peaks[i]["channel_index"] = chans[np.argmin(sparse_wfs)] + elif self.peak_sign == "pos": + local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)] + elif self.peak_sign == "both": + local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] + + # TODO: "amplitude" ??? + + return (local_peaks,) + + +def sorting_to_peaks(sorting, extremum_channel_inds): + spikes = sorting.to_spike_vector() + peaks = np.zeros(spikes.size, dtype=base_peak_dtype) + peaks["sample_index"] = spikes["sample_index"] + extremum_channel_inds_ = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids]) + peaks["channel_index"] = extremum_channel_inds_[spikes["unit_index"]] + peaks["amplitude"] = 0.0 + peaks["segment_index"] = spikes["segment_index"] + return peaks class WaveformsNode(PipelineNode): @@ -423,7 +511,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin) # set sample index to local node_output[0]["sample_index"] += extra_margin - elif isinstance(node, PeakRetriever): + elif isinstance(node, PeakSource): node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin) else: # TODO later when in master: change the signature of all nodes (or maybe not!) 
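# --- Usage sketch for the new SpikeRetriever / sorting_to_peaks machinery above ---
# Assembled from this file and from test_node_pipeline.py further below; the paired
# `recording`/`sorting` objects and the spikeinterface.core imports are assumed to exist.
from spikeinterface.core import extract_waveforms, get_template_extremum_channel
from spikeinterface.core.node_pipeline import SpikeRetriever, ExtractDenseWaveforms, run_node_pipeline

we = extract_waveforms(recording, sorting, mode="memory")
extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index")

# inject the sorted spikes as the peak source of the node pipeline
spike_retriever = SpikeRetriever(
    recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
)
dense_waveforms = ExtractDenseWaveforms(
    recording, parents=[spike_retriever], ms_before=0.5, ms_after=1.0, return_output=True
)
job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=False)
waveforms = run_node_pipeline(recording, [spike_retriever, dense_waveforms], job_kwargs, squeeze_output=True)
# waveforms has shape (num_spikes, num_samples, num_channels)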
diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 97f22615df..d5663156c7 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -338,8 +338,7 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): if self.spikes_in_seg is None: # the slicing of segment is done only once the first time # this fasten the constructor a lot - s0 = np.searchsorted(self.spikes["segment_index"], self.segment_index, side="left") - s1 = np.searchsorted(self.spikes["segment_index"], self.segment_index + 1, side="left") + s0, s1 = np.searchsorted(self.spikes["segment_index"], [self.segment_index, self.segment_index + 1]) self.spikes_in_seg = self.spikes[s0:s1] unit_index = self.unit_ids.index(unit_id) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index e5901d7ee0..ff9cd99389 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -302,7 +302,7 @@ def get_chunk_with_margin( return traces_chunk, left_margin, right_margin -def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): +def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y"), flip=False): """ Order channels by depth, by first ordering the x-axis, and then the y-axis. @@ -316,6 +316,9 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): If str, it needs to be 'x', 'y', 'z'. If tuple or list, it sorts the locations in two dimensions using lexsort. This approach is recommended since there is less ambiguity, by default ('x', 'y') + flip: bool, default: False + If flip is False then the order is bottom first (starting from tip of the probe). + If flip is True then the order is upper first. Returns ------- @@ -341,6 +344,8 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): assert dim < ndim, "Invalid dimensions!" 
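Side note on the recurring refactor in these hunks (illustrative only, not part of the patch): passing both boundaries to a single np.searchsorted call returns the same pair of indices as the two scalar calls it replaces, because the vectors being searched ("segment_index", "sample_index", cumulative lengths) are already sorted.

.. code:: python

    import numpy as np

    segment_index = np.array([0, 0, 0, 1, 1, 2, 2, 2])  # sorted, like a spike/peak vector

    # two scalar calls ...
    s0 = np.searchsorted(segment_index, 1)
    s1 = np.searchsorted(segment_index, 1 + 1)

    # ... give the same boundaries as one vectorized call
    s0_v, s1_v = np.searchsorted(segment_index, [1, 1 + 1])
    assert (s0, s1) == (s0_v, s1_v) == (3, 5)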
locations_to_sort += (locations[:, dim],) order_f = np.lexsort(locations_to_sort) + if flip: + order_f = order_f[::-1] order_r = np.argsort(order_f, kind="stable") return order_f, order_r diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index f70c45bfe5..85e36cf7a5 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -174,8 +174,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # Return (0 * num_channels) array of correct dtype return self.parent_segments[0].get_traces(0, 0, channel_indices) - i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1 - i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1 + i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: # * come from one segment (i0 == i1) @@ -469,8 +468,7 @@ def get_unit_spike_train( if end_frame is None: end_frame = self.get_num_samples() - i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1 - i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1 + i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: # * come from one segment (i0 == i1) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 85f41924c1..bcb15b6455 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -12,9 +12,10 @@ from spikeinterface.core.node_pipeline import ( run_node_pipeline, PeakRetriever, + SpikeRetriever, PipelineNode, ExtractDenseWaveforms, - base_peak_dtype, + sorting_to_peaks, ) @@ -78,99 +79,107 @@ def test_run_node_pipeline(): # create peaks from spikes we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") - # print(extremum_channel_inds) - ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids]) - # print(ext_channel_inds) - peaks = np.zeros(spikes.size, dtype=base_peak_dtype) - peaks["sample_index"] = spikes["sample_index"] - peaks["channel_index"] = ext_channel_inds[spikes["unit_index"]] - peaks["amplitude"] = 0.0 - peaks["segment_index"] = 0 - - # one step only : squeeze output - peak_retriever = PeakRetriever(recording, peaks) - nodes = [ - peak_retriever, - AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6), - ] - step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) - assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) - - # 3 nodes two have outputs - ms_before = 0.5 - ms_after = 1.0 + peaks = sorting_to_peaks(sorting, extremum_channel_inds) + peak_retriever = PeakRetriever(recording, peaks) - dense_waveforms = ExtractDenseWaveforms( - recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False - ) - waveform_denoiser = WaveformDenoiser(recording, parents=[peak_retriever, dense_waveforms], return_output=False) - amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6, return_output=True) - waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_retriever, dense_waveforms], return_output=True) - denoised_waveforms_rms = WaveformsRootMeanSquare( - recording, parents=[peak_retriever, waveform_denoiser], return_output=True + # channel 
index is from template + spike_retriever_T = SpikeRetriever( + recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds ) - - nodes = [ - peak_retriever, - dense_waveforms, - waveform_denoiser, - amplitue_extraction, - waveforms_rms, - denoised_waveforms_rms, - ] - - # gather memory mode - output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory") - amplitudes, waveforms_rms, denoised_waveforms_rms = output - assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"]) - - num_peaks = peaks.shape[0] - num_channels = recording.get_num_channels() - assert waveforms_rms.shape[0] == num_peaks - assert waveforms_rms.shape[1] == num_channels - - assert waveforms_rms.shape[0] == num_peaks - assert waveforms_rms.shape[1] == num_channels - - # gather npy mode - folder = cache_folder / "pipeline_folder" - if folder.is_dir(): - shutil.rmtree(folder) - - output = run_node_pipeline( + # channel index is per spike + spike_retriever_S = SpikeRetriever( recording, - nodes, - job_kwargs, - gather_mode="npy", - folder=folder, - names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"], + sorting, + channel_from_template=False, + extremum_channel_inds=extremum_channel_inds, + radius_um=50, + peak_sign="neg", ) - amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output - - amplitudes_file = folder / "amplitudes.npy" - assert amplitudes_file.is_file() - amplitudes3 = np.load(amplitudes_file) - assert np.array_equal(amplitudes, amplitudes2) - assert np.array_equal(amplitudes2, amplitudes3) - - waveforms_rms_file = folder / "waveforms_rms.npy" - assert waveforms_rms_file.is_file() - waveforms_rms3 = np.load(waveforms_rms_file) - assert np.array_equal(waveforms_rms, waveforms_rms2) - assert np.array_equal(waveforms_rms2, waveforms_rms3) - - denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy" - assert denoised_waveforms_rms_file.is_file() - denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file) - assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2) - assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3) - - # Test pickle mechanism - for node in nodes: - import pickle - - pickled_node = pickle.dumps(node) - unpickled_node = pickle.loads(pickled_node) + + # test with 3 differents first nodes + for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)): + # one step only : squeeze output + nodes = [ + peak_source, + AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6), + ] + step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) + assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) + + # 3 nodes two have outputs + ms_before = 0.5 + ms_after = 1.0 + peak_retriever = PeakRetriever(recording, peaks) + dense_waveforms = ExtractDenseWaveforms( + recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False + ) + waveform_denoiser = WaveformDenoiser(recording, parents=[peak_source, dense_waveforms], return_output=False) + amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6, return_output=True) + waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_source, dense_waveforms], return_output=True) + denoised_waveforms_rms = WaveformsRootMeanSquare( + recording, parents=[peak_source, waveform_denoiser], return_output=True + ) + + nodes = [ + peak_source, + dense_waveforms, + waveform_denoiser, + 
amplitue_extraction, + waveforms_rms, + denoised_waveforms_rms, + ] + + # gather memory mode + output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory") + amplitudes, waveforms_rms, denoised_waveforms_rms = output + assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"]) + + num_peaks = peaks.shape[0] + num_channels = recording.get_num_channels() + assert waveforms_rms.shape[0] == num_peaks + assert waveforms_rms.shape[1] == num_channels + + assert waveforms_rms.shape[0] == num_peaks + assert waveforms_rms.shape[1] == num_channels + + # gather npy mode + folder = cache_folder / f"pipeline_folder_{loop}" + if folder.is_dir(): + shutil.rmtree(folder) + output = run_node_pipeline( + recording, + nodes, + job_kwargs, + gather_mode="npy", + folder=folder, + names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"], + ) + amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output + + amplitudes_file = folder / "amplitudes.npy" + assert amplitudes_file.is_file() + amplitudes3 = np.load(amplitudes_file) + assert np.array_equal(amplitudes, amplitudes2) + assert np.array_equal(amplitudes2, amplitudes3) + + waveforms_rms_file = folder / "waveforms_rms.npy" + assert waveforms_rms_file.is_file() + waveforms_rms3 = np.load(waveforms_rms_file) + assert np.array_equal(waveforms_rms, waveforms_rms2) + assert np.array_equal(waveforms_rms2, waveforms_rms3) + + denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy" + assert denoised_waveforms_rms_file.is_file() + denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file) + assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2) + assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3) + + # Test pickle mechanism + for node in nodes: + import pickle + + pickled_node = pickle.dumps(node) + unpickled_node = pickle.loads(pickled_node) if __name__ == "__main__": diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 6e92d155fe..1d99b192ee 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -138,11 +138,13 @@ def test_order_channels_by_depth(): order_1d, order_r1d = order_channels_by_depth(rec, dimensions="y") order_2d, order_r2d = order_channels_by_depth(rec, dimensions=("x", "y")) locations_rev = locations_copy[order_1d][order_r1d] + order_2d_fliped, order_r2d_fliped = order_channels_by_depth(rec, dimensions=("x", "y"), flip=True) assert np.array_equal(locations[:, 1], locations_copy[order_1d][:, 1]) assert np.array_equal(locations_copy[order_1d][:, 1], locations_copy[order_2d][:, 1]) assert np.array_equal(locations, locations_copy[order_2d]) assert np.array_equal(locations_copy, locations_copy[order_2d][order_r2d]) + assert np.array_equal(order_2d[::-1], order_2d_fliped) if __name__ == "__main__": diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index c3a1c378f7..082ed8c0a6 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -34,7 +34,7 @@ def test_ChannelSparsity(): for key, v in sparsity.unit_id_to_channel_ids.items(): assert key in unit_ids - assert np.all(np.in1d(v, channel_ids)) + assert np.all(np.isin(v, channel_ids)) for key, v in sparsity.unit_id_to_channel_indices.items(): assert key in unit_ids diff --git a/src/spikeinterface/core/waveform_extractor.py 
b/src/spikeinterface/core/waveform_extractor.py index 877c9fb00c..6881ab3ec5 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -4,6 +4,7 @@ import shutil from typing import Iterable, Literal, Optional import json +import os import numpy as np from copy import deepcopy @@ -87,6 +88,7 @@ def __init__( self._template_cache = {} self._params = {} self._loaded_extensions = dict() + self._is_read_only = False self.sparsity = sparsity self.folder = folder @@ -103,6 +105,8 @@ def __init__( if (self.folder / "params.json").is_file(): with open(str(self.folder / "params.json"), "r") as f: self._params = json.load(f) + if not os.access(self.folder, os.W_OK): + self._is_read_only = True else: # this is in case of in-memory self.format = "memory" @@ -399,6 +403,9 @@ def return_scaled(self) -> bool: def dtype(self): return self._params["dtype"] + def is_read_only(self) -> bool: + return self._is_read_only + def has_recording(self) -> bool: return self._recording is not None @@ -516,6 +523,10 @@ def is_extension(self, extension_name) -> bool: """ if self.folder is None: return extension_name in self._loaded_extensions + + if extension_name in self._loaded_extensions: + # extension already loaded in memory + return True else: if self.format == "binary": return (self.folder / extension_name).is_dir() and ( @@ -1740,13 +1751,33 @@ def __init__(self, waveform_extractor): if self.format == "binary": self.extension_folder = self.folder / self.extension_name if not self.extension_folder.is_dir(): - self.extension_folder.mkdir() + if self.waveform_extractor.is_read_only(): + warn( + "WaveformExtractor: cannot save extension in read-only mode. " + "Extension will be saved in memory." + ) + self.format = "memory" + self.extension_folder = None + self.folder = None + else: + self.extension_folder.mkdir() + else: import zarr - zarr_root = zarr.open(self.folder, mode="r+") + mode = "r+" if not self.waveform_extractor.is_read_only() else "r" + zarr_root = zarr.open(self.folder, mode=mode) if self.extension_name not in zarr_root.keys(): - self.extension_group = zarr_root.create_group(self.extension_name) + if self.waveform_extractor.is_read_only(): + warn( + "WaveformExtractor: cannot save extension in read-only mode. " + "Extension will be saved in memory." 
+ ) + self.format = "memory" + self.extension_folder = None + self.folder = None + else: + self.extension_group = zarr_root.create_group(self.extension_name) else: self.extension_group = zarr_root[self.extension_name] else: @@ -1863,6 +1894,9 @@ def save(self, **kwargs): self._save(**kwargs) def _save(self, **kwargs): + # Only save if not read only + if self.waveform_extractor.is_read_only(): + return if self.format == "binary": import pandas as pd @@ -1900,7 +1934,9 @@ def _save(self, **kwargs): self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) elif isinstance(ext_data, pd.DataFrame): ext_data.to_xarray().to_zarr( - store=self.extension_group.store, group=f"{self.extension_group.name}/{ext_data_name}", mode="a" + store=self.extension_group.store, + group=f"{self.extension_group.name}/{ext_data_name}", + mode="a", ) self.extension_group[ext_data_name].attrs["dataframe"] = True else: diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index da8e3d64b6..a2f1296e31 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -344,15 +344,15 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx # take only spikes with the correct segment_index # this is a slice so no copy!! - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) in_seg_spikes = spikes[s0:s1] # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) - i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + i0, i1 = np.searchsorted( + in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)] + ) # slice in absolut in spikes vector l0 = i0 + s0 @@ -562,8 +562,7 @@ def _init_worker_distribute_single_buffer( # prepare segment slices segment_slices = [] for segment_index in range(recording.get_num_segments()): - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append((s0, s1)) worker_ctx["segment_slices"] = segment_slices @@ -590,8 +589,9 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! 
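Editorial sketch of what the read-only handling added to waveform_extractor.py above means in practice (the folder path is hypothetical): extensions computed on a non-writable waveform folder are kept in memory and a warning is emitted instead of raising.

.. code:: python

    from spikeinterface.core import load_waveforms
    from spikeinterface.postprocessing import compute_spike_amplitudes

    # a waveform folder opened without write permission (e.g. shared or archived data)
    we = load_waveforms("/path/to/read_only_waveforms_folder")
    assert we.is_read_only()

    # the computation warns and the result stays in memory only
    amplitudes = compute_spike_amplitudes(we)
    ext = we.load_extension("spike_amplitudes")
    assert ext.format == "memory" and ext.extension_folder is None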
# the border of segment are protected by nbefore on left an nafter on the right - i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) - i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + i0, i1 = np.searchsorted( + in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)] + ) # slice in absolut in spikes vector l0 = i0 + s0 @@ -685,8 +685,7 @@ def has_exceeding_spikes(recording, sorting): """ spike_vector = sorting.to_spike_vector() for segment_index in range(recording.get_num_segments()): - start_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index) - end_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index + 1) + start_seg_ind, end_seg_ind = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) spike_vector_seg = spike_vector[start_seg_ind:end_seg_ind] if len(spike_vector_seg) > 0: if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1: diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 264ac3a56d..5295cc76d8 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -12,7 +12,7 @@ class MergeUnitsSorting(BaseSorting): ---------- parent_sorting: Recording The sorting object - units_to_merge: list of lists + units_to_merge: list/tuple of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), but it can also have more (merge multiple units at once). new_unit_ids: None or list @@ -24,6 +24,7 @@ class MergeUnitsSorting(BaseSorting): Default: 'keep' delta_time_ms: float or None Number of ms to consider for duplicated spikes. None won't check for duplications + Returns ------- sorting: Sorting @@ -33,7 +34,7 @@ class MergeUnitsSorting(BaseSorting): def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4): self._parent_sorting = parent_sorting - if not isinstance(units_to_merge[0], list): + if not isinstance(units_to_merge[0], (list, tuple)): # keep backward compatibility : the previous behavior was only one merge units_to_merge = [units_to_merge] @@ -59,7 +60,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties else: # we cannot automatically find new names new_unit_ids = [f"merge{i}" for i in range(num_merge)] - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(np.isin(new_unit_ids, keep_unit_ids)): raise ValueError( "Unable to find 'new_unit_ids' because it is a string and parents " "already contain merges. Pass a list of 'new_unit_ids' as an argument." @@ -68,7 +69,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties # dtype int new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) else: - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(np.isin(new_unit_ids, keep_unit_ids)): raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. 
Provide new ones") assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge" diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 5615402fdb..c92861a8bf 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -178,7 +178,11 @@ def export_to_phy( templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity") + if waveform_extractor.is_extension("similarity"): + tmc = waveform_extractor.load_extension("similarity") + template_similarity = tmc.get_data() + else: + template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity") np.save(str(output_folder / "templates.npy"), templates) np.save(str(output_folder / "template_ind.npy"), templates_ind) diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py index 02e7d5677d..8b70722652 100644 --- a/src/spikeinterface/extractors/bids.py +++ b/src/spikeinterface/extractors/bids.py @@ -76,7 +76,7 @@ def _read_probe_group(folder, bids_name, recording_channel_ids): contact_ids = channels["contact_id"].values.astype("U") # extracting information of requested channels - keep = np.in1d(channel_ids, recording_channel_ids) + keep = np.isin(channel_ids, recording_channel_ids) channel_ids = channel_ids[keep] contact_ids = contact_ids[keep] diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py index ebff40fae0..235dd705dc 100644 --- a/src/spikeinterface/extractors/extractorlist.py +++ b/src/spikeinterface/extractors/extractorlist.py @@ -11,6 +11,8 @@ NumpySorting, NpySnippetsExtractor, ZarrRecordingExtractor, + read_binary, + read_zarr, ) # sorting/recording/event from neo diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 5a0148c5c4..22b40a51c5 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -47,9 +47,9 @@ def _set_params( def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(self.spikes["unit_index"], unit_inds) new_amplitude_scalings = self._extension_data["amplitude_scalings"][spike_mask] return dict(amplitude_scalings=new_amplitude_scalings) @@ -99,8 +99,7 @@ def _run(self, **job_kwargs): # precompute segment slice segment_slices = [] for segment_index in range(we.get_num_segments()): - i0 = np.searchsorted(self.spikes["segment_index"], segment_index) - i1 = np.searchsorted(self.spikes["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append(slice(i0, i1)) # and run @@ -317,8 +316,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) spikes_in_segment = spikes[segment_slices[segment_index]] - i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) if i0 != 
i1: local_spikes = spikes_in_segment[i0:i1] @@ -335,8 +333,9 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: # local spikes with margin! - i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left) - i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) + i0_margin, i1_margin = np.searchsorted( + spikes_in_segment["sample_index"], [start_frame - left, end_frame + right] + ) local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] collisions_local = find_collisions( local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices @@ -462,14 +461,11 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0] # find the possible spikes per and post within delta_collision_samples - consecutive_window_pre = np.searchsorted( - spikes_w_margin["sample_index"], - spike["sample_index"] - delta_collision_samples, - ) - consecutive_window_post = np.searchsorted( + consecutive_window_pre, consecutive_window_post = np.searchsorted( spikes_w_margin["sample_index"], - spike["sample_index"] + delta_collision_samples, + [spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples], ) + # exclude the spike itself (it is included in the collision_spikes by construction) pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin) post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 991d79506e..ce1c3bd5a0 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -600,8 +600,7 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): seg_size = recording.get_num_samples(segment_index=segment_index) - i0 = np.searchsorted(spike_times, start_frame) - i1 = np.searchsorted(spike_times, end_frame) + i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) if i0 != i1: # protect from spikes on border : spike_time<0 or spike_time>seg_size @@ -694,11 +693,10 @@ def compute_principal_components( If True and pc scores are already in the waveform extractor folders, pc scores are loaded and not recomputed. n_components: int Number of components fo PCA - default 5 - mode: str + mode: str, default: 'by_channel_local' - 'by_channel_local': a local PCA is fitted for each channel (projection by channel) - 'by_channel_global': a global PCA is fitted for all channels (projection by channel) - 'concatenated': channels are concatenated and a global PCA is fitted - default 'by_channel_local' sparsity: ChannelSparsity or None The sparsity to apply to waveforms. 
If waveform_extractor is already sparse, the default sparsity will be used - default None @@ -735,6 +733,7 @@ def compute_principal_components( >>> # run for all spikes in the SortingExtractor >>> pc.run_for_all_spikes(file_path="all_pca_projections.npy") """ + if load_if_exists and waveform_extractor.is_extension(WaveformPrincipalComponent.extension_name): pc = waveform_extractor.load_extension(WaveformPrincipalComponent.extension_name) else: diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 62a4e2c320..38cb714d59 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -28,13 +28,13 @@ def _select_extension_data(self, unit_ids): # load filter and save amplitude files sorting = self.waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) - (keep_unit_indices,) = np.nonzero(np.in1d(sorting.unit_ids, unit_ids)) + (keep_unit_indices,) = np.nonzero(np.isin(sorting.unit_ids, unit_ids)) new_extension_data = dict() for seg_index in range(sorting.get_num_segments()): amp_data_name = f"amplitude_segment_{seg_index}" amps = self._extension_data[amp_data_name] - filtered_idxs = np.in1d(spikes[seg_index]["unit_index"], keep_unit_indices) + filtered_idxs = np.isin(spikes[seg_index]["unit_index"], keep_unit_indices) new_extension_data[amp_data_name] = amps[filtered_idxs] return new_extension_data @@ -218,9 +218,7 @@ def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx): d = np.diff(spike_times) assert np.all(d >= 0) - i0 = np.searchsorted(spike_times, start_frame) - i1 = np.searchsorted(spike_times, end_frame) - + i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) n_spikes = i1 - i0 amplitudes = np.zeros(n_spikes, dtype=recording.get_dtype()) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index c6f498f7e8..4cbe4d665e 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -32,9 +32,9 @@ def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", meth def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(self.spikes["unit_index"], unit_inds) new_spike_locations = self._extension_data["spike_locations"][spike_mask] return dict(spike_locations=new_spike_locations) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index b9c72f9b99..f7272ddefe 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -2,9 +2,10 @@ import numpy as np import pandas as pd import shutil +import platform from pathlib import Path -from spikeinterface import extract_waveforms, load_extractor, compute_sparsity +from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity from spikeinterface.extractors import toy_example if hasattr(pytest, "global_test_folder"): @@ -76,6 +77,16 @@ def setUp(self): overwrite=True, ) self.we2 = we2 + + # make we read-only + if platform.system() != "Windows": + 
we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + if not we_ro_folder.is_dir(): + shutil.copytree(we2.folder, we_ro_folder) + # change permissions (R+X) + we_ro_folder.chmod(0o555) + self.we_ro = load_waveforms(we_ro_folder) + self.sparsity2 = compute_sparsity(we2, method="radius", radius_um=30) we_memory = extract_waveforms( recording, @@ -97,6 +108,12 @@ def setUp(self): folder=cache_folder / "toy_sorting_2seg_sparse", format="binary", sparsity=sparsity, overwrite=True ) + def tearDown(self): + # allow pytest to delete RO folder + if platform.system() != "Windows": + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + we_ro_folder.chmod(0o777) + def _test_extension_folder(self, we, in_memory=False): if self.extension_function_kwargs_list is None: extension_function_kwargs_list = [dict()] @@ -177,3 +194,11 @@ def test_extension(self): assert ext_data_mem.equals(ext_data_zarr) else: print(f"{ext_data_name} of type {type(ext_data_mem)} not tested.") + + # read-only - Extension is memory only + if platform.system() != "Windows": + _ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False) + assert self.extension_class.extension_name in self.we_ro.get_available_extension_names() + ext_ro = self.we_ro.load_extension(self.extension_class.extension_name) + assert ext_ro.format == "memory" + assert ext_ro.extension_folder is None diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index 740fdd234b..d2739f69dd 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -570,6 +570,8 @@ def enforce_decrease_shells_data(wf_data, maxchan, radial_parents, in_place=Fals def get_grid_convolution_templates_and_weights( contact_locations, radius_um=50, upsampling_um=5, sigma_um=np.linspace(10, 50.0, 5), margin_um=50 ): + import sklearn.metrics + x_min, x_max = contact_locations[:, 0].min(), contact_locations[:, 0].max() y_min, y_max = contact_locations[:, 1].min(), contact_locations[:, 1].max() @@ -593,8 +595,6 @@ def get_grid_convolution_templates_and_weights( template_positions[:, 0] = all_x.flatten() template_positions[:, 1] = all_y.flatten() - import sklearn - # mask to get nearest template given a channel dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions) nearest_template_mask = dist < radius_um diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index 0b8d8a730b..55e34ba5dd 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -18,13 +18,18 @@ class DepthOrderRecording(ChannelSliceRecording): If str, it needs to be 'x', 'y', 'z'. If tuple or list, it sorts the locations in two dimensions using lexsort. This approach is recommended since there is less ambiguity, by default ('x', 'y') + flip: bool, default: False + If flip is False then the order is bottom first (starting from tip of the probe). + If flip is True then the order is upper first. 
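Illustration of the new `flip` option wired through `order_channels_by_depth` and the `depth_order`/`DepthOrderRecording` preprocessing step modified here (editorial, not part of the patch): it simply reverses the depth ordering so the top of the probe comes first, mirroring the assertion added in test_recording_tools.py. The sketch assumes `generate_recording` attaches its default dummy probe.

.. code:: python

    import numpy as np
    from spikeinterface.core import generate_recording
    from spikeinterface.core.recording_tools import order_channels_by_depth

    rec = generate_recording(num_channels=8)

    order_f, order_r = order_channels_by_depth(rec, dimensions=("x", "y"))
    order_flipped, _ = order_channels_by_depth(rec, dimensions=("x", "y"), flip=True)

    assert np.array_equal(order_flipped, order_f[::-1])  # same ordering, top of the probe first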
""" name = "depth_order" installed = True - def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y")): - order_f, order_r = order_channels_by_depth(parent_recording, channel_ids=channel_ids, dimensions=dimensions) + def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y"), flip=False): + order_f, order_r = order_channels_by_depth( + parent_recording, channel_ids=channel_ids, dimensions=dimensions, flip=flip + ) reordered_channel_ids = parent_recording.channel_ids[order_f] ChannelSliceRecording.__init__( self, @@ -35,6 +40,7 @@ def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y")): parent_recording=parent_recording, channel_ids=channel_ids, dimensions=dimensions, + flip=flip, ) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 0f4800c6e8..cc4e8601e2 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -18,7 +18,7 @@ def detect_bad_channels( nyquist_threshold=0.8, direction="y", chunk_duration_s=0.3, - num_random_chunks=10, + num_random_chunks=100, welch_window_ms=10.0, highpass_filter_cutoff=300, neighborhood_r2_threshold=0.9, @@ -81,9 +81,10 @@ def detect_bad_channels( highpass_filter_cutoff : float If the recording is not filtered, the cutoff frequency of the highpass filter, by default 300 chunk_duration_s : float - Duration of each chunk, by default 0.3 + Duration of each chunk, by default 0.5 num_random_chunks : int - Number of random chunks, by default 10 + Number of random chunks, by default 100 + Having many chunks is important for reproducibility. welch_window_ms : float Window size for the scipy.signal.welch that will be converted to nperseg, by default 10ms neighborhood_r2_threshold : float, default 0.95 @@ -174,20 +175,18 @@ def detect_bad_channels( channel_locations = recording.get_channel_locations() dim = ["x", "y", "z"].index(direction) assert dim < channel_locations.shape[1], f"Direction {direction} is wrong" - locs_depth = channel_locations[:, dim] - if np.array_equal(np.sort(locs_depth), locs_depth): + order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y")) + if np.all(np.diff(order_f) == 1): + # already ordered order_f = None order_r = None - else: - # sort by x, y to avoid ambiguity - order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y")) # Create empty channel labels and fill with bad-channel detection estimate for each chunk chunk_channel_labels = np.zeros((recording.get_num_channels(), len(random_data)), dtype=np.int8) for i, random_chunk in enumerate(random_data): - random_chunk_sorted = random_chunk[order_f] if order_f is not None else random_chunk - chunk_channel_labels[:, i] = detect_bad_channels_ibl( + random_chunk_sorted = random_chunk[:, order_f] if order_f is not None else random_chunk + chunk_labels = detect_bad_channels_ibl( raw=random_chunk_sorted, fs=recording.sampling_frequency, psd_hf_threshold=psd_hf_threshold, @@ -198,11 +197,10 @@ def detect_bad_channels( nyquist_threshold=nyquist_threshold, welch_window_ms=welch_window_ms, ) + chunk_channel_labels[:, i] = chunk_labels[order_r] if order_r is not None else chunk_labels # Take the mode of the chunk estimates as final result. Convert to binary good / bad channel output. 
mode_channel_labels, _ = scipy.stats.mode(chunk_channel_labels, axis=1, keepdims=False) - if order_r is not None: - mode_channel_labels = mode_channel_labels[order_r] (bad_inds,) = np.where(mode_channel_labels != 0) bad_channel_ids = recording.channel_ids[bad_inds] diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index e634d55e7f..95ecd0fe52 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -49,7 +49,7 @@ def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=Non self.bad_channel_ids = bad_channel_ids self._bad_channel_idxs = recording.ids_to_indices(self.bad_channel_ids) - self._good_channel_idxs = ~np.in1d(np.arange(recording.get_num_channels()), self._bad_channel_idxs) + self._good_channel_idxs = ~np.isin(np.arange(recording.get_num_channels()), self._bad_channel_idxs) self._bad_channel_idxs.setflags(write=False) if sigma_um is None: diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index ff2a5b60c2..e2ef6e6794 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -186,11 +186,12 @@ def correct_motion( Parameters for each step are handled as separate dictionaries. For more information please check the documentation of the following functions: - * :py:func:`~spikeinterface.sortingcomponents.peak_detection.detect_peaks' - * :py:func:`~spikeinterface.sortingcomponents.peak_selection.select_peaks' - * :py:func:`~spikeinterface.sortingcomponents.peak_localization.localize_peaks' - * :py:func:`~spikeinterface.sortingcomponents.motion_estimation.estimate_motion' - * :py:func:`~spikeinterface.sortingcomponents.motion_interpolation.interpolate_motion' + + * :py:func:`~spikeinterface.sortingcomponents.peak_detection.detect_peaks` + * :py:func:`~spikeinterface.sortingcomponents.peak_selection.select_peaks` + * :py:func:`~spikeinterface.sortingcomponents.peak_localization.localize_peaks` + * :py:func:`~spikeinterface.sortingcomponents.motion_estimation.estimate_motion` + * :py:func:`~spikeinterface.sortingcomponents.motion_interpolation.interpolate_motion` Possible presets: {} diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index ee28485983..8dd5f857f6 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -544,7 +544,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k # some segments/units might have no spikes if len(spikes_per_unit) == 0: continue - spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])] + spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])] for synchrony_size in synchrony_sizes: synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size) @@ -848,16 +848,14 @@ def compute_drift_metrics( spike_vector = sorting.to_spike_vector() # retrieve spikes in segment - i0 = np.searchsorted(spike_vector["segment_index"], segment_index) - i1 = np.searchsorted(spike_vector["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) spikes_in_segment = spike_vector[i0:i1] spike_locations_in_segment = spike_locations[i0:i1] # compute median 
positions (if less than min_spikes_per_interval, median position is 0) median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1)) for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])): - i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) spikes_in_bin = spikes_in_segment[i0:i1] spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index b7b267251d..ed06f7d738 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -152,8 +152,8 @@ def calculate_pc_metrics( neighbor_unit_ids = unit_ids neighbor_channel_indices = we.channel_ids_to_indices(neighbor_channel_ids) - labels = all_labels[np.in1d(all_labels, neighbor_unit_ids)] - pcs = all_pcs[np.in1d(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] + labels = all_labels[np.isin(all_labels, neighbor_unit_ids)] + pcs = all_pcs[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) func_args = ( @@ -506,7 +506,7 @@ def nearest_neighbors_isolation( other_units_ids = [ unit_id for unit_id in other_units_ids - if np.sum(np.in1d(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) + if np.sum(np.isin(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) >= (n_channels_target_unit * min_spatial_overlap) ] @@ -536,10 +536,10 @@ def nearest_neighbors_isolation( if waveform_extractor.is_sparse(): # in this case, waveforms are sparse so we need to do some smart indexing waveforms_target_unit_sampled = waveforms_target_unit_sampled[ - :, :, np.in1d(closest_chans_target_unit, common_channel_idxs) + :, :, np.isin(closest_chans_target_unit, common_channel_idxs) ] waveforms_other_unit_sampled = waveforms_other_unit_sampled[ - :, :, np.in1d(closest_chans_other_unit, common_channel_idxs) + :, :, np.isin(closest_chans_other_unit, common_channel_idxs) ] else: waveforms_target_unit_sampled = waveforms_target_unit_sampled[:, :, common_channel_idxs] @@ -736,6 +736,7 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id): """Calculates the simplified silhouette score for each cluster. The value ranges from -1 (bad clustering) to 1 (good clustering). The simplified silhoutte score utilizes the centroids for distance calculations rather than pairwise calculations. + Parameters ---------- all_pcs : 2d array @@ -744,12 +745,14 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id): The cluster labels for all spikes. Must have length of number of spikes. this_unit_id : int The ID for the unit to calculate this metric for. + Returns ------- unit_silhouette_score : float Simplified Silhouette Score for this unit + References - ------------ + ---------- Based on simplified silhouette score suggested by [Hruschka]_ """ @@ -782,6 +785,7 @@ def silhouette_score(all_pcs, all_labels, this_unit_id): """Calculates the silhouette score which is a marker of cluster quality ranging from -1 (bad clustering) to 1 (good clustering). Distances are all calculated as pairwise comparisons of all data points. 
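Side note (illustrative, not part of the patch): the repeated `np.in1d` to `np.isin` swap across these files tracks NumPy's move away from `in1d`; for this membership-mask usage the two calls return identical results.

.. code:: python

    import numpy as np

    unit_ids = np.array([0, 1, 2, 3, 4])
    kept = np.array([1, 3])

    mask = np.isin(unit_ids, kept)  # modern spelling
    assert np.array_equal(mask, [False, True, False, True, False])
    assert np.array_equal(mask, np.in1d(unit_ids, kept))  # same result as the old call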
+ Parameters ---------- all_pcs : 2d array The scores for all spikes and projections. all_labels : 1d array The cluster labels for all spikes. Must have length of number of spikes. this_unit_id : int The ID for the unit to calculate this metric for. + Returns ------- unit_silhouette_score : float Silhouette Score for this unit + References - ------------ + ---------- Based on [Rousseeuw]_ """ diff --git a/src/spikeinterface/sorters/__init__.py b/src/spikeinterface/sorters/__init__.py index a0d437559d..ba663327e8 100644 --- a/src/spikeinterface/sorters/__init__.py +++ b/src/spikeinterface/sorters/__init__.py @@ -1,11 +1,4 @@ from .basesorter import BaseSorter from .sorterlist import * from .runsorter import * - -from .launcher import ( - run_sorters, - run_sorter_by_property, - collect_sorting_outputs, - iter_working_folder, - iter_sorting_output, -) +from .launcher import run_sorter_jobs, run_sorters, run_sorter_by_property diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index ff559cc78d..c7581ba1e1 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -411,3 +411,14 @@ def get_job_kwargs(params, verbose): if not verbose: job_kwargs["progress_bar"] = False return job_kwargs + + +def is_log_ok(output_folder): + # log is OK when run_time is not None + if (output_folder / "spikeinterface_log.json").is_file(): + with open(output_folder / "spikeinterface_log.json", mode="r", encoding="utf8") as logfile: + log = json.load(logfile) + run_time = log.get("run_time", None) + ok = run_time is not None + return ok + return False diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index 52098f45cd..f32a468a22 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -4,61 +4,193 @@ from pathlib import Path import shutil import numpy as np -import json import tempfile import os import stat import subprocess import sys +import warnings -from spikeinterface.core import load_extractor, aggregate_units -from spikeinterface.core.core_tools import check_json +from spikeinterface.core import aggregate_units from .sorterlist import sorter_dict -from .runsorter import run_sorter, run_sorter - - -def _run_one(arg_list): - # the multiprocessing python module force to have one unique tuple argument - ( - sorter_name, - recording, - output_folder, - verbose, - sorter_params, - docker_image, - singularity_image, - with_output, - ) = arg_list - - if isinstance(recording, dict): - recording = load_extractor(recording) +from .runsorter import run_sorter +from .basesorter import is_log_ok + +_default_engine_kwargs = dict( + loop=dict(), + joblib=dict(n_jobs=-1, backend="loky"), + processpoolexecutor=dict(max_workers=2, mp_context=None), + dask=dict(client=None), + slurm=dict(tmp_script_folder=None, cpus_per_task=1, mem="1G"), +) + + +_implemented_engine = list(_default_engine_kwargs.keys()) + + +def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=False): + """ + Run several :py:func:`run_sorter()` sequentially or in parallel given a list of jobs. + + For **engine="loop"** this is equivalent to: + + .. code:: + + for job in job_list: + run_sorter(**job) + + The following engines are *blocking*: + * "loop" + * "joblib" + * "processpoolexecutor" + * "dask" + + The following engines are *asynchronous*: + * "slurm" + + Where *blocking* means that this function blocks until the results are returned.
+ This is in opposition to *asynchronous*, where the function returns `None` almost immediately (aka non-blocking), + but the results must be retrieved by hand when jobs are finished. No mechanism is provided here to know + when jobs are finished. + In this *asynchronous* case, :py:func:`~spikeinterface.sorters.read_sorter_folder` helps to retrieve individual results. + + + Parameters + ---------- + job_list: list of dict + A list of dicts that are propagated to run_sorter(...) + engine: str "loop", "joblib", "processpoolexecutor", "dask", "slurm" + The engine used to run the list. + * "loop": a simple loop in the main process. + engine_kwargs: dict + Extra kwargs for the engine; missing values are taken from `_default_engine_kwargs`. + return_output: bool, default: False + If True, the sortings are returned (only for the blocking engines). + + Returns + ------- + sortings: None or list of sorting + With engine="loop", "joblib", or "processpoolexecutor" you can optionally get the list of sorting results directly if return_output=True. + """ + + assert engine in _implemented_engine, f"engine must be in {_implemented_engine}" + + engine_kwargs_ = dict() + engine_kwargs_.update(_default_engine_kwargs[engine]) + engine_kwargs_.update(engine_kwargs) + engine_kwargs = engine_kwargs_ + + if return_output: + assert engine in ( + "loop", + "joblib", + "processpoolexecutor", + ), "Only 'loop', 'joblib', and 'processpoolexecutor' support return_output=True." + out = [] else: - recording = recording - - # because this is checks in run_sorters before this call - remove_existing_folder = False - # result is retrieve later - delete_output_folder = False - # because we won't want the loop/worker to break - raise_error = False - - run_sorter( - sorter_name, - recording, - output_folder=output_folder, - remove_existing_folder=remove_existing_folder, - delete_output_folder=delete_output_folder, - verbose=verbose, - raise_error=raise_error, - docker_image=docker_image, - singularity_image=singularity_image, - with_output=with_output, - **sorter_params, - ) + out = None + + if engine == "loop": + # simple loop in main process + for kwargs in job_list: + sorting = run_sorter(**kwargs) + if return_output: + out.append(sorting) + + elif engine == "joblib": + from joblib import Parallel, delayed + + n_jobs = engine_kwargs["n_jobs"] + backend = engine_kwargs["backend"] + sortings = Parallel(n_jobs=n_jobs, backend=backend)(delayed(run_sorter)(**kwargs) for kwargs in job_list) + if return_output: + out.extend(sortings) + + elif engine == "processpoolexecutor": + from concurrent.futures import ProcessPoolExecutor + + max_workers = engine_kwargs["max_workers"] + mp_context = engine_kwargs["mp_context"] + with ProcessPoolExecutor(max_workers=max_workers, mp_context=mp_context) as executor: + futures = [] + for kwargs in job_list: + res = executor.submit(run_sorter, **kwargs) + futures.append(res) + for futur in futures: + sorting = futur.result() + if return_output: + out.append(sorting) -_implemented_engine = ("loop", "joblib", "dask", "slurm") + elif engine == "dask": + client = engine_kwargs["client"] + assert client is not None, "For dask engine you have to provide : client = dask.distributed.Client(...)" + + tasks = [] + for kwargs in job_list: + task = client.submit(run_sorter, **kwargs) + tasks.append(task) + + for task in tasks: + task.result() + + elif engine == "slurm": + # generate python script for slurm + tmp_script_folder = engine_kwargs["tmp_script_folder"] + if tmp_script_folder is None: + tmp_script_folder = tempfile.mkdtemp(prefix="spikeinterface_slurm_") + tmp_script_folder = Path(tmp_script_folder) + cpus_per_task = engine_kwargs["cpus_per_task"] + mem = engine_kwargs["mem"] +
tmp_script_folder.mkdir(exist_ok=True, parents=True) + + for i, kwargs in enumerate(job_list): + script_name = tmp_script_folder / f"si_script_{i}.py" + with open(script_name, "w") as f: + kwargs_txt = "" + for k, v in kwargs.items(): + kwargs_txt += " " + if k == "recording": + # put None temporarily + kwargs_txt += "recording=None" + else: + if isinstance(v, str): + kwargs_txt += f'{k}="{v}"' + elif isinstance(v, Path): + kwargs_txt += f'{k}="{str(v.absolute())}"' + else: + kwargs_txt += f"{k}={v}" + kwargs_txt += ",\n" + + # recording_dict = task_args[1] + recording_dict = kwargs["recording"].to_dict() + slurm_script = _slurm_script.format( + python=sys.executable, recording_dict=recording_dict, kwargs_txt=kwargs_txt + ) + f.write(slurm_script) + os.fchmod(f.fileno(), mode=stat.S_IRWXU) + + subprocess.Popen(["sbatch", f"--cpus-per-task={cpus_per_task}", f"--mem={mem}", str(script_name.absolute())]) + + return out + + +_slurm_script = """#! {python} +from numpy import array +from spikeinterface import load_extractor +from spikeinterface.sorters import run_sorter + +rec_dict = {recording_dict} + +kwargs = dict( +{kwargs_txt} +) +kwargs['recording'] = load_extractor(rec_dict) + +run_sorter(**kwargs) +""" def run_sorter_by_property( @@ -66,7 +198,7 @@ recording, grouping_property, working_folder, - mode_if_folder_exists="raise", + mode_if_folder_exists=None, engine="loop", engine_kwargs={}, verbose=False, @@ -93,11 +225,10 @@ Property to split by before sorting working_folder: str The working directory. - mode_if_folder_exists: {'raise', 'overwrite', 'keep'} - The mode when the subfolder of recording/sorter already exists. - * 'raise' : raise error if subfolder exists - * 'overwrite' : delete and force recompute - * 'keep' : do not compute again if f=subfolder exists and log is OK + mode_if_folder_exists: None + Deprecated: must be None. + If not None, a warning is raised. + Will be removed in the next release. engine: {'loop', 'joblib', 'dask'} Which engine to use to run sorter. engine_kwargs: dict @@ -127,46 +258,49 @@ engine_kwargs={"n_jobs": 4}) """ + if mode_if_folder_exists is not None: + warnings.warn( + "run_sorter_by_property(): mode_if_folder_exists is not used anymore", + DeprecationWarning, + stacklevel=2, + ) + + working_folder = Path(working_folder).absolute() assert grouping_property in recording.get_property_keys(), ( f"The 'grouping_property' {grouping_property} is not " f"a recording property!"
) recording_dict = recording.split_by(grouping_property) - sorting_output = run_sorters( - [sorter_name], - recording_dict, - working_folder, - mode_if_folder_exists=mode_if_folder_exists, - engine=engine, - engine_kwargs=engine_kwargs, - verbose=verbose, - with_output=True, - docker_images={sorter_name: docker_image}, - singularity_images={sorter_name: singularity_image}, - sorter_params={sorter_name: sorter_params}, - ) - grouping_property_values = None - sorting_list = [] - for output_name, sorting in sorting_output.items(): - prop_name, sorter_name = output_name - sorting_list.append(sorting) - if grouping_property_values is None: - grouping_property_values = np.array( - [prop_name] * len(sorting.get_unit_ids()), dtype=np.dtype(type(prop_name)) - ) - else: - grouping_property_values = np.concatenate( - (grouping_property_values, [prop_name] * len(sorting.get_unit_ids())) - ) + job_list = [] + for k, rec in recording_dict.items(): + job = dict( + sorter_name=sorter_name, + recording=rec, + output_folder=working_folder / str(k), + verbose=verbose, + docker_image=docker_image, + singularity_image=singularity_image, + **sorter_params, + ) + job_list.append(job) + + sorting_list = run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=True) + + unit_groups = [] + for sorting, group in zip(sorting_list, recording_dict.keys()): + num_units = sorting.get_unit_ids().size + unit_groups.extend([group] * num_units) + unit_groups = np.array(unit_groups) aggregate_sorting = aggregate_units(sorting_list) - aggregate_sorting.set_property(key=grouping_property, values=grouping_property_values) + aggregate_sorting.set_property(key=grouping_property, values=unit_groups) aggregate_sorting.register_recording(recording) return aggregate_sorting +# This is deprecated and will be removed def run_sorters( sorter_list, recording_dict_or_list, @@ -180,7 +314,9 @@ def run_sorters( docker_images={}, singularity_images={}, ): - """Run several sorter on several recordings. + """ + This function is deprecated and will be removed in version 0.100 + Please use run_sorter_jobs() instead. Parameters ---------- @@ -221,6 +357,13 @@ def run_sorters( results : dict The output is nested dict[(rec_name, sorter_name)] of SortingExtractor. """ + + warnings.warn( + "run_sorters() is deprecated please use run_sorter_jobs() instead. This will be removed in 0.100", + DeprecationWarning, + stacklevel=2, + ) + working_folder = Path(working_folder) mode_if_folder_exists in ("raise", "keep", "overwrite") @@ -247,8 +390,7 @@ def run_sorters( dtype_rec_name = np.dtype(type(list(recording_dict.keys())[0])) assert dtype_rec_name.kind in ("i", "u", "S", "U"), "Dict keys can only be integers or strings!" 
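Editorial sketch of the new launcher entry point that replaces the deprecated `run_sorters()`; the recordings, sorter name and output paths below are placeholders.

.. code:: python

    from spikeinterface.sorters import run_sorter_jobs

    # one dict of run_sorter() arguments per job (rec0 / rec1 are assumed BaseRecording objects)
    job_list = [
        dict(sorter_name="tridesclous2", recording=rec0, output_folder="outputs/rec0"),
        dict(sorter_name="tridesclous2", recording=rec1, output_folder="outputs/rec1"),
    ]

    # blocking engines ("loop", "joblib", "processpoolexecutor") can hand the sortings back directly
    sortings = run_sorter_jobs(job_list, engine="joblib", engine_kwargs=dict(n_jobs=2), return_output=True)

With engine="slurm" the call returns immediately and the individual results are read back later with `read_sorter_folder()`.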
- need_dump = engine != "loop" - task_args_list = [] + job_list = [] for rec_name, recording in recording_dict.items(): for sorter_name in sorter_list: output_folder = working_folder / str(rec_name) / sorter_name @@ -268,181 +410,21 @@ def run_sorters( params = sorter_params.get(sorter_name, {}) docker_image = docker_images.get(sorter_name, None) singularity_image = singularity_images.get(sorter_name, None) - _check_container_images(docker_image, singularity_image, sorter_name) - - if need_dump: - if not recording.check_if_dumpable(): - raise Exception("recording not dumpable call recording.save() before") - recording_arg = recording.to_dict(recursive=True) - else: - recording_arg = recording - - task_args = ( - sorter_name, - recording_arg, - output_folder, - verbose, - params, - docker_image, - singularity_image, - with_output, - ) - task_args_list.append(task_args) - if engine == "loop": - # simple loop in main process - for task_args in task_args_list: - _run_one(task_args) - - elif engine == "joblib": - from joblib import Parallel, delayed - - n_jobs = engine_kwargs.get("n_jobs", -1) - backend = engine_kwargs.get("backend", "loky") - Parallel(n_jobs=n_jobs, backend=backend)(delayed(_run_one)(task_args) for task_args in task_args_list) - - elif engine == "dask": - client = engine_kwargs.get("client", None) - assert client is not None, "For dask engine you have to provide : client = dask.distributed.Client(...)" - - tasks = [] - for task_args in task_args_list: - task = client.submit(_run_one, task_args) - tasks.append(task) - - for task in tasks: - task.result() - - elif engine == "slurm": - # generate python script for slurm - tmp_script_folder = engine_kwargs.get("tmp_script_folder", None) - if tmp_script_folder is None: - tmp_script_folder = tempfile.mkdtemp(prefix="spikeinterface_slurm_") - tmp_script_folder = Path(tmp_script_folder) - cpus_per_task = engine_kwargs.get("cpus_per_task", 1) - mem = engine_kwargs.get("mem", "1G") - - for i, task_args in enumerate(task_args_list): - script_name = tmp_script_folder / f"si_script_{i}.py" - with open(script_name, "w") as f: - arg_list_txt = "(\n" - for j, arg in enumerate(task_args): - arg_list_txt += "\t" - if j != 1: - if isinstance(arg, str): - arg_list_txt += f'"{arg}"' - elif isinstance(arg, Path): - arg_list_txt += f'"{str(arg.absolute())}"' - else: - arg_list_txt += f"{arg}" - else: - arg_list_txt += "recording" - arg_list_txt += ",\r" - arg_list_txt += ")" - - recording_dict = task_args[1] - slurm_script = _slurm_script.format( - python=sys.executable, recording_dict=recording_dict, arg_list_txt=arg_list_txt - ) - f.write(slurm_script) - os.fchmod(f.fileno(), mode=stat.S_IRWXU) - - print(slurm_script) - - subprocess.Popen(["sbatch", str(script_name.absolute()), f"-cpus-per-task={cpus_per_task}", f"-mem={mem}"]) + job = dict( + sorter_name=sorter_name, + recording=recording, + output_folder=output_folder, + verbose=verbose, + docker_image=docker_image, + singularity_image=singularity_image, + **params, + ) + job_list.append(job) - non_blocking_engine = ("loop", "joblib") - if engine in non_blocking_engine: - # dump spikeinterface_job.json - # only for non blocking engine - for rec_name, recording in recording_dict.items(): - for sorter_name in sorter_list: - output_folder = working_folder / str(rec_name) / sorter_name - with open(output_folder / "spikeinterface_job.json", "w") as f: - dump_dict = {"rec_name": rec_name, "sorter_name": sorter_name, "engine": engine} - if engine != "dask": - dump_dict.update({"engine_kwargs": 
engine_kwargs}) - json.dump(check_json(dump_dict), f) + sorting_list = run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=with_output) if with_output: - if engine not in non_blocking_engine: - print( - f'Warning!! With engine="{engine}" you cannot have directly output results\n' - "Use : run_sorters(..., with_output=False)\n" - "And then: results = collect_sorting_outputs(output_folders)" - ) - return - - results = collect_sorting_outputs(working_folder) + keys = [(rec_name, sorter_name) for rec_name in recording_dict for sorter_name in sorter_list] + results = dict(zip(keys, sorting_list)) return results - - -_slurm_script = """#! {python} -from numpy import array -from spikeinterface.sorters.launcher import _run_one - -recording = {recording_dict} - -arg_list = {arg_list_txt} - -_run_one(arg_list) -""" - - -def is_log_ok(output_folder): - # log is OK when run_time is not None - if (output_folder / "spikeinterface_log.json").is_file(): - with open(output_folder / "spikeinterface_log.json", mode="r", encoding="utf8") as logfile: - log = json.load(logfile) - run_time = log.get("run_time", None) - ok = run_time is not None - return ok - return False - - -def iter_working_folder(working_folder): - working_folder = Path(working_folder) - for rec_folder in working_folder.iterdir(): - if not rec_folder.is_dir(): - continue - for output_folder in rec_folder.iterdir(): - if (output_folder / "spikeinterface_job.json").is_file(): - with open(output_folder / "spikeinterface_job.json", "r") as f: - job_dict = json.load(f) - rec_name = job_dict["rec_name"] - sorter_name = job_dict["sorter_name"] - yield rec_name, sorter_name, output_folder - else: - rec_name = rec_folder.name - sorter_name = output_folder.name - if not output_folder.is_dir(): - continue - if not is_log_ok(output_folder): - continue - yield rec_name, sorter_name, output_folder - - -def iter_sorting_output(working_folder): - """Iterator over output_folder to retrieve all triplets of (rec_name, sorter_name, sorting).""" - for rec_name, sorter_name, output_folder in iter_working_folder(working_folder): - SorterClass = sorter_dict[sorter_name] - sorting = SorterClass.get_result_from_folder(output_folder) - yield rec_name, sorter_name, sorting - - -def collect_sorting_outputs(working_folder): - """Collect results in a working_folder. - - The output is a dict with double key access results[(rec_name, sorter_name)] of SortingExtractor. 
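# Editor's sketch (not part of the patch): with collect_sorting_outputs() removed, the old
# results[(rec_name, sorter_name)] mapping returned by the deprecated run_sorters() can be
# rebuilt from run_sorter_jobs(), mirroring the replacement code in run_sorters() above.
# The recording folders and sorter list are illustrative placeholders.
from pathlib import Path

from spikeinterface.core import load_extractor
from spikeinterface.sorters import run_sorter_jobs

working_folder = Path("working_folder")
recording_dict = {"rec0": load_extractor("toy_rec_0"), "rec1": load_extractor("toy_rec_1")}
sorter_list = ["tridesclous2"]

# same (recording, sorter) nesting and folder layout as the deprecated run_sorters()
job_list = [
    dict(sorter_name=sorter_name, recording=rec, output_folder=working_folder / rec_name / sorter_name)
    for rec_name, rec in recording_dict.items()
    for sorter_name in sorter_list
]

# run_sorter_jobs() returns one sorting per job, in job order
sorting_list = run_sorter_jobs(job_list, engine="loop", return_output=True)

# rebuild the double-keyed dict that collect_sorting_outputs() used to provide
keys = [(rec_name, sorter_name) for rec_name in recording_dict for sorter_name in sorter_list]
results = dict(zip(keys, sorting_list))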
- """ - results = {} - for rec_name, sorter_name, sorting in iter_sorting_output(working_folder): - results[(rec_name, sorter_name)] = sorting - return results - - -def _check_container_images(docker_image, singularity_image, sorter_name): - if docker_image is not None: - assert singularity_image is None, f"Provide either a docker or a singularity image " f"for sorter {sorter_name}" - if singularity_image is not None: - assert docker_image is None, f"Provide either a docker or a singularity image " f"for sorter {sorter_name}" diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py index cd8bc0fa5d..14c938f8ba 100644 --- a/src/spikeinterface/sorters/tests/test_launcher.py +++ b/src/spikeinterface/sorters/tests/test_launcher.py @@ -1,4 +1,5 @@ import os +import sys import shutil import time @@ -6,8 +7,10 @@ from pathlib import Path from spikeinterface.core import load_extractor -from spikeinterface.extractors import toy_example -from spikeinterface.sorters import run_sorters, run_sorter_by_property, collect_sorting_outputs + +# from spikeinterface.extractors import toy_example +from spikeinterface import generate_ground_truth_recording +from spikeinterface.sorters import run_sorter_jobs, run_sorters, run_sorter_by_property if hasattr(pytest, "global_test_folder"): @@ -15,10 +18,17 @@ else: cache_folder = Path("cache_folder") / "sorters" +base_output = cache_folder / "sorter_output" + +# no need to have many +num_recordings = 2 +sorters = ["tridesclous2"] + def setup_module(): - rec, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1) - for i in range(4): + base_seed = 42 + for i in range(num_recordings): + rec, _ = generate_ground_truth_recording(num_channels=8, durations=[10.0], seed=base_seed + i) rec_folder = cache_folder / f"toy_rec_{i}" if rec_folder.is_dir(): shutil.rmtree(rec_folder) @@ -31,19 +41,106 @@ def setup_module(): rec.save(folder=rec_folder) -def test_run_sorters_with_list(): - working_folder = cache_folder / "test_run_sorters_list" +def get_job_list(): + jobs = [] + for i in range(num_recordings): + for sorter_name in sorters: + recording = load_extractor(cache_folder / f"toy_rec_{i}") + kwargs = dict( + sorter_name=sorter_name, + recording=recording, + output_folder=base_output / f"{sorter_name}_rec{i}", + verbose=True, + raise_error=False, + ) + jobs.append(kwargs) + + return jobs + + +@pytest.fixture(scope="module") +def job_list(): + return get_job_list() + + +def test_run_sorter_jobs_loop(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + sortings = run_sorter_jobs(job_list, engine="loop", return_output=True) + print(sortings) + + +def test_run_sorter_jobs_joblib(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + sortings = run_sorter_jobs( + job_list, engine="joblib", engine_kwargs=dict(n_jobs=2, backend="loky"), return_output=True + ) + print(sortings) + + +def test_run_sorter_jobs_processpoolexecutor(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + sortings = run_sorter_jobs( + job_list, engine="processpoolexecutor", engine_kwargs=dict(max_workers=2), return_output=True + ) + print(sortings) + + +@pytest.mark.skipif(True, reason="This is tested locally") +def test_run_sorter_jobs_dask(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + + # create a dask Client for a slurm queue + from dask.distributed import Client + + test_mode = "local" + # test_mode = "client_slurm" + + if test_mode == "local": + client = Client() + 
elif test_mode == "client_slurm": + from dask_jobqueue import SLURMCluster + + cluster = SLURMCluster( + processes=1, + cores=1, + memory="12GB", + python=sys.executable, + walltime="12:00:00", + ) + cluster.scale(2) + client = Client(cluster) + + # dask + t0 = time.perf_counter() + run_sorter_jobs(job_list, engine="dask", engine_kwargs=dict(client=client)) + t1 = time.perf_counter() + print(t1 - t0) + + +@pytest.mark.skip("Slurm launcher need a machine with slurm") +def test_run_sorter_jobs_slurm(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + + working_folder = cache_folder / "test_run_sorters_slurm" if working_folder.is_dir(): shutil.rmtree(working_folder) - # make dumpable - rec0 = load_extractor(cache_folder / "toy_rec_0") - rec1 = load_extractor(cache_folder / "toy_rec_1") - - recording_list = [rec0, rec1] - sorter_list = ["tridesclous"] + tmp_script_folder = working_folder / "slurm_scripts" - run_sorters(sorter_list, recording_list, working_folder, engine="loop", verbose=False, with_output=False) + run_sorter_jobs( + job_list, + engine="slurm", + engine_kwargs=dict( + tmp_script_folder=tmp_script_folder, + cpus_per_task=32, + mem="32G", + ), + ) def test_run_sorter_by_property(): @@ -59,7 +156,7 @@ def test_run_sorter_by_property(): rec0_by = rec0.split_by("group") group_names0 = list(rec0_by.keys()) - sorter_name = "tridesclous" + sorter_name = "tridesclous2" sorting0 = run_sorter_by_property(sorter_name, rec0, "group", working_folder1, engine="loop", verbose=False) assert "group" in sorting0.get_property_keys() assert all([g in group_names0 for g in sorting0.get_property("group")]) @@ -68,12 +165,31 @@ def test_run_sorter_by_property(): rec1_by = rec1.split_by("group") group_names1 = list(rec1_by.keys()) - sorter_name = "tridesclous" + sorter_name = "tridesclous2" sorting1 = run_sorter_by_property(sorter_name, rec1, "group", working_folder2, engine="loop", verbose=False) assert "group" in sorting1.get_property_keys() assert all([g in group_names1 for g in sorting1.get_property("group")]) +# run_sorters is deprecated +# This will test will be removed in next release +def test_run_sorters_with_list(): + working_folder = cache_folder / "test_run_sorters_list" + if working_folder.is_dir(): + shutil.rmtree(working_folder) + + # make dumpable + rec0 = load_extractor(cache_folder / "toy_rec_0") + rec1 = load_extractor(cache_folder / "toy_rec_1") + + recording_list = [rec0, rec1] + sorter_list = ["tridesclous2"] + + run_sorters(sorter_list, recording_list, working_folder, engine="loop", verbose=False, with_output=False) + + +# run_sorters is deprecated +# This will test will be removed in next release def test_run_sorters_with_dict(): working_folder = cache_folder / "test_run_sorters_dict" if working_folder.is_dir(): @@ -84,9 +200,9 @@ def test_run_sorters_with_dict(): recording_dict = {"toy_tetrode": rec0, "toy_octotrode": rec1} - sorter_list = ["tridesclous", "tridesclous2"] + sorter_list = ["tridesclous2"] - sorter_params = {"tridesclous": dict(detect_threshold=5.6), "tridesclous2": dict()} + sorter_params = {"tridesclous2": dict()} # simple loop t0 = time.perf_counter() @@ -116,143 +232,19 @@ def test_run_sorters_with_dict(): ) -@pytest.mark.skipif(True, reason="This is tested locally") -def test_run_sorters_joblib(): - working_folder = cache_folder / "test_run_sorters_joblib" - if working_folder.is_dir(): - shutil.rmtree(working_folder) - - recording_dict = {} - for i in range(4): - rec = load_extractor(cache_folder / f"toy_rec_{i}") - recording_dict[f"rec_{i}"] 
= rec - - sorter_list = [ - "tridesclous", - ] - - # joblib - t0 = time.perf_counter() - run_sorters( - sorter_list, - recording_dict, - working_folder / "with_joblib", - engine="joblib", - engine_kwargs={"n_jobs": 4}, - with_output=False, - mode_if_folder_exists="keep", - ) - t1 = time.perf_counter() - print(t1 - t0) - - -@pytest.mark.skipif(True, reason="This is tested locally") -def test_run_sorters_dask(): - working_folder = cache_folder / "test_run_sorters_dask" - if working_folder.is_dir(): - shutil.rmtree(working_folder) - - recording_dict = {} - for i in range(4): - rec = load_extractor(cache_folder / f"toy_rec_{i}") - recording_dict[f"rec_{i}"] = rec - - sorter_list = [ - "tridesclous", - ] - - # create a dask Client for a slurm queue - from dask.distributed import Client - from dask_jobqueue import SLURMCluster - - python = "/home/samuel.garcia/.virtualenvs/py36/bin/python3.6" - cluster = SLURMCluster( - processes=1, - cores=1, - memory="12GB", - python=python, - walltime="12:00:00", - ) - cluster.scale(5) - client = Client(cluster) - - # dask - t0 = time.perf_counter() - run_sorters( - sorter_list, - recording_dict, - working_folder, - engine="dask", - engine_kwargs={"client": client}, - with_output=False, - mode_if_folder_exists="keep", - ) - t1 = time.perf_counter() - print(t1 - t0) - - -@pytest.mark.skipif(True, reason="This is tested locally") -def test_run_sorters_slurm(): - working_folder = cache_folder / "test_run_sorters_slurm" - if working_folder.is_dir(): - shutil.rmtree(working_folder) - - # create recording - recording_dict = {} - for i in range(4): - rec = load_extractor(cache_folder / f"toy_rec_{i}") - recording_dict[f"rec_{i}"] = rec - - sorter_list = [ - "spykingcircus2", - "tridesclous2", - ] - - tmp_script_folder = working_folder / "slurm_scripts" - tmp_script_folder.mkdir(parents=True) - - run_sorters( - sorter_list, - recording_dict, - working_folder, - engine="slurm", - engine_kwargs={ - "tmp_script_folder": tmp_script_folder, - "cpus_per_task": 32, - "mem": "32G", - }, - with_output=False, - mode_if_folder_exists="keep", - verbose=True, - ) - - -def test_collect_sorting_outputs(): - working_folder = cache_folder / "test_run_sorters_dict" - results = collect_sorting_outputs(working_folder) - print(results) - - -def test_sorter_installation(): - # This import is to get error on github when import fails - import tridesclous - - # import circus - - if __name__ == "__main__": - setup_module() - # pass - # test_run_sorters_with_list() - - # test_run_sorter_by_property() + # setup_module() + job_list = get_job_list() - test_run_sorters_with_dict() + # test_run_sorter_jobs_loop(job_list) + # test_run_sorter_jobs_joblib(job_list) + # test_run_sorter_jobs_processpoolexecutor(job_list) + # test_run_sorter_jobs_multiprocessing(job_list) + # test_run_sorter_jobs_dask(job_list) + test_run_sorter_jobs_slurm(job_list) - # test_run_sorters_joblib() - - # test_run_sorters_dask() - - # test_run_sorters_slurm() + # test_run_sorter_by_property() - # test_collect_sorting_outputs() + # this deprecated + # test_run_sorters_with_list() + # test_run_sorters_with_dict() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 07c7db155c..772c99bc0a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -502,7 +502,7 @@ def plot_errors_matching(benchmark, comp, unit_id, 
nb_spikes=200, metric="cosine seg_num = 0 # TODO: make compatible with multiple segments idx_1 = np.where(comp.get_labels1(unit_id)[seg_num] == label) idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(np.isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] if len(intersection) == 0: print(f"No {label}s found for unit {unit_id}") @@ -552,7 +552,7 @@ def plot_errors_matching_all_neurons(benchmark, comp, nb_spikes=200, metric="cos for label in ["TP", "FN"]: idx_1 = np.where(comp.get_labels1(unit_id) == label)[0] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(np.isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] wfs_sliced = wfs[intersection, :, :] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index 1514a63dd4..73497a59fd 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -133,7 +133,7 @@ def run(self, peaks=None, positions=None, delta=0.2): matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000)) self.good_matches = matches["index1"] - garbage_matches = ~np.in1d(np.arange(len(times2)), self.good_matches) + garbage_matches = ~np.isin(np.arange(len(times2)), self.good_matches) garbage_channels = self.peaks["channel_index"][garbage_matches] garbage_peaks = times2[garbage_matches] nb_garbage = len(garbage_peaks) @@ -365,7 +365,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["full_gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["full_gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.gt_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.gt_peaks["amplitude"][mask]) ax.scatter(self.gt_positions["x"][mask], self.gt_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.gt_positions["x"][mask].mean(), self.gt_positions["y"][mask].mean()) @@ -391,7 +391,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.sliced_gt_peaks["amplitude"][mask]) ax.scatter( self.sliced_gt_positions["x"][mask], self.sliced_gt_positions["y"][mask], c=colors, s=1, alpha=0.5 @@ -420,7 +420,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["garbage"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["garbage"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.garbage_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.garbage_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.garbage_peaks["amplitude"][mask]) ax.scatter(self.garbage_positions["x"][mask], self.garbage_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.garbage_positions["x"][mask].mean(), 
self.garbage_positions["y"][mask].mean()) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6edf5af16b..23fdbf1979 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -30,7 +30,7 @@ def _split_waveforms( local_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(cluster_probability > probability_thr) - local_labels_with_noise[~np.in1d(local_labels_with_noise, persistent_clusters)] = -1 + local_labels_with_noise[~np.isin(local_labels_with_noise, persistent_clusters)] = -1 # remove super small cluster labels, count = np.unique(local_labels_with_noise[:valid_size], return_counts=True) @@ -43,7 +43,7 @@ def _split_waveforms( to_remove = labels[(count / valid_size) < minimum_cluster_size_ratio] # ~ print('to_remove', to_remove, count / valid_size) if to_remove.size > 0: - local_labels_with_noise[np.in1d(local_labels_with_noise, to_remove)] = -1 + local_labels_with_noise[np.isin(local_labels_with_noise, to_remove)] = -1 local_labels_with_noise[valid_size:] = -2 @@ -123,7 +123,7 @@ def _split_waveforms_nested( active_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(clustering[2] > probability_thr) - active_labels_with_noise[~np.in1d(active_labels_with_noise, persistent_clusters)] = -1 + active_labels_with_noise[~np.isin(active_labels_with_noise, persistent_clusters)] = -1 active_labels = active_labels_with_noise[active_ind < valid_size] active_labels_set = np.unique(active_labels) @@ -381,9 +381,9 @@ def auto_clean_clustering( continue wfs0 = wfs_arrays[label0] - wfs0 = wfs0[:, :, np.in1d(channel_inds0, used_chans)] + wfs0 = wfs0[:, :, np.isin(channel_inds0, used_chans)] wfs1 = wfs_arrays[label1] - wfs1 = wfs1[:, :, np.in1d(channel_inds1, used_chans)] + wfs1 = wfs1[:, :, np.isin(channel_inds1, used_chans)] # TODO : remove assert wfs0.shape[2] == wfs1.shape[2] diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index aeec14158f..08ce9f6791 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -198,7 +198,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): for chan_ind in prev_local_chan_inds: if total_count[chan_ind] == 0: continue - # ~ inds, = np.nonzero(np.in1d(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) + # ~ inds, = np.nonzero(np.isin(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) (inds,) = np.nonzero((peaks["channel_index"] == chan_ind) & (peak_labels == 0)) if inds.size <= d["min_spike_on_channel"]: chan_amps[chan_ind] = 0.0 @@ -235,12 +235,12 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # TODO: only for debug, remove later - assert np.all(np.in1d(local_chan_inds, wf_chans)) + assert np.all(np.isin(local_chan_inds, wf_chans)) # none label spikes wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, local_chan_inds)] + wfs_chan = wfs_chan[:, :, np.isin(wf_chans, local_chan_inds)] wfs.append(wfs_chan) # put noise to enhance clusters @@ -517,7 +517,7 @@ def 
_collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # print('wf_chans', wf_chans) # TODO: only for debug, remove later - assert np.all(np.in1d(wanted_chans, wf_chans)) + assert np.all(np.isin(wanted_chans, wf_chans)) wfs_chan = wfs_arrays[chan_ind] # TODO: only for debug, remove later @@ -525,7 +525,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, wanted_chans)] + wfs_chan = wfs_chan[:, :, np.isin(wf_chans, wanted_chans)] wfs.append(wfs_chan) wfs = np.concatenate(wfs, axis=0) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index f3719b934b..bc52ea2c70 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -65,10 +65,9 @@ def detect_peaks( This avoid reading the recording multiple times. gather_mode: str How to gather the results: - * "memory": results are returned as in-memory numpy arrays - * "npy": results are stored to .npy files in `folder` + folder: str or Path If gather_mode is "npy", the folder where the files are created. names: list @@ -81,9 +80,11 @@ def detect_peaks( ------- peaks: array Detected peaks. + Notes ----- This peak detection ported from tridesclous into spikeinterface. + """ assert method in detect_peak_methods diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py index 939475c17d..9715b7ea87 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py @@ -95,8 +95,7 @@ def plot(self): num_frames = int(duration / self.bin_duration_s) def animate_func(i): - i0 = np.searchsorted(peaks["sample_index"], bin_size * i) - i1 = np.searchsorted(peaks["sample_index"], bin_size * (i + 1)) + i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)]) local_peaks = peaks[i0:i1] artists = self._plot_one_bin(rec, probe, local_peaks, self.bin_duration_s) return artists diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index e7bcff0832..ae036d1ba1 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -30,11 +30,10 @@ class SpikesOnTracesWidget(BaseWidget): sparsity : ChannelSparsity or None Optional ChannelSparsity to apply. If WaveformExtractor is already sparse, the argument is ignored, default None - unit_colors : dict or None + unit_colors : dict or None If given, a dictionary with unit ids as keys and colors as values, default None If None, then the get_unit_colors() is internally used. 
(matplotlib backend) - mode : str - Three possible modes, default 'auto': + mode : str in ('line', 'map', 'auto') default: 'auto' * 'line': classical for low channel count * 'map': for high channel count use color heat map * 'auto': auto switch depending on the channel count ('line' if less than 64 channels, 'map' otherwise) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 9a2ec4a215..e025f779c1 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -26,6 +26,7 @@ class TracesWidget(BaseWidget): List with start time and end time, default None mode: str Three possible modes, default 'auto': + * 'line': classical for low channel count * 'map': for high channel count use color heat map * 'auto': auto switch depending on the channel count ('line' if less than 64 channels, 'map' otherwise) @@ -51,11 +52,6 @@ class TracesWidget(BaseWidget): For 'map' mode and sortingview backend, seconds to render in each row, default 0.2 add_legend : bool If True adds legend to figures, default True - - Returns - ------- - W: TracesWidget - The output widget """ def __init__( diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index e8a6868e92..b3391c0712 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -103,7 +103,7 @@ def __init__( if same_axis and not np.array_equal(chan_inds, shared_chan_inds): # add more channels if necessary wfs_ = np.zeros((wfs.shape[0], wfs.shape[1], shared_chan_inds.size), dtype=float) - mask = np.in1d(shared_chan_inds, chan_inds) + mask = np.isin(shared_chan_inds, chan_inds) wfs_[:, :, mask] = wfs wfs_[:, :, ~mask] = np.nan wfs = wfs_ diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index f3c640ff16..9c89b3981e 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -54,6 +54,12 @@ {backends} **backend_kwargs: kwargs {backend_kwargs} + + + Returns + ------- + w : BaseWidget + The output widget object. """ # backend_str = f" {list(wcls.possible_backends.keys())}" backend_str = f" {wcls.get_possible_backends()}"
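# Editor's note (not part of the patch): the np.in1d -> np.isin swaps in the hunks above are
# behaviour-preserving for the 1-D index arrays used in these files; np.isin is the replacement
# NumPy recommends, and for N-D input it additionally keeps the shape of its first argument.
# The arrays below are made-up values for illustration only.
import numpy as np

idx_2 = np.array([3, 7, 11, 42])   # e.g. sampled spike indices
idx_1 = np.array([7, 42, 100])     # e.g. indices carrying a given label

mask_old = np.in1d(idx_2, idx_1)   # always returns a flattened 1-D boolean array
mask_new = np.isin(idx_2, idx_1)   # identical values for 1-D input
assert np.array_equal(mask_old, mask_new)

# the pattern used throughout this diff: positions of idx_2 entries also present in idx_1
intersection = np.where(mask_new)[0]   # -> array([1, 3])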