diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml index 0e522e6baa..b3bf08954d 100644 --- a/.github/workflows/installation-tips-test.yml +++ b/.github/workflows/installation-tips-test.yml @@ -30,4 +30,4 @@ jobs: - name: Test Conda Environment Creation uses: conda-incubator/setup-miniconda@v2.2.0 with: - environment-file: ./installations_tips/full_spikeinterface_environment_${{ matrix.label }}.yml + environment-file: ./installation_tips/full_spikeinterface_environment_${{ matrix.label }}.yml diff --git a/doc/api.rst b/doc/api.rst index 43f79386e6..97c956c2f6 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -19,6 +19,8 @@ spikeinterface.core .. autofunction:: extract_waveforms .. autofunction:: load_waveforms .. autofunction:: compute_sparsity + .. autoclass:: ChannelSparsity + :members: .. autoclass:: BinaryRecordingExtractor .. autoclass:: ZarrRecordingExtractor .. autoclass:: BinaryFolderRecording @@ -48,10 +50,6 @@ spikeinterface.core .. autofunction:: get_template_extremum_channel .. autofunction:: get_template_extremum_channel_peak_shift .. autofunction:: get_template_extremum_amplitude - -.. - .. autofunction:: read_binary - .. autofunction:: read_zarr .. autofunction:: append_recordings .. autofunction:: concatenate_recordings .. autofunction:: split_recording @@ -59,6 +57,8 @@ spikeinterface.core .. autofunction:: append_sortings .. autofunction:: split_sorting .. autofunction:: select_segment_sorting + .. autofunction:: read_binary + .. autofunction:: read_zarr Low-level ~~~~~~~~~ @@ -67,7 +67,6 @@ Low-level :noindex: .. autoclass:: BaseWaveformExtractorExtension - .. autoclass:: ChannelSparsity .. autoclass:: ChunkRecordingExecutor spikeinterface.extractors @@ -83,6 +82,7 @@ NEO-based .. autofunction:: read_alphaomega_event .. autofunction:: read_axona .. autofunction:: read_biocam + .. autofunction:: read_binary .. autofunction:: read_blackrock .. autofunction:: read_ced .. autofunction:: read_intan @@ -104,6 +104,7 @@ NEO-based .. autofunction:: read_spikegadgets .. autofunction:: read_spikeglx .. autofunction:: read_tdt + .. autofunction:: read_zarr Non-NEO-based @@ -216,8 +217,10 @@ spikeinterface.sorters .. autofunction:: print_sorter_versions .. autofunction:: get_sorter_description .. autofunction:: run_sorter + .. autofunction:: run_sorter_jobs .. autofunction:: run_sorters .. autofunction:: run_sorter_by_property + .. autofunction:: read_sorter_folder Low level ~~~~~~~~~ diff --git a/doc/development/development.rst b/doc/development/development.rst index f1371639c3..7656da11ab 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -14,7 +14,7 @@ There are various ways to contribute to SpikeInterface as a user or developer. S * Writing unit tests to expand code coverage and use case scenarios. * Reporting bugs and issues. -We use a forking workflow _ to manage contributions. Here's a summary of the steps involved, with more details available in the provided link: +We use a forking workflow ``_ to manage contributions. Here's a summary of the steps involved, with more details available in the provided link: * Fork the SpikeInterface repository. * Create a new branch (e.g., :code:`git switch -c my-contribution`). @@ -22,7 +22,7 @@ We use a forking workflow _ . +While we appreciate all the contributions please be mindful of the cost of reviewing pull requests ``_ . 
How to run tests locally @@ -201,7 +201,7 @@ Implement a new extractor SpikeInterface already supports over 30 file formats, but the acquisition system you use might not be among the supported formats list (***ref***). Most of the extractord rely on the `NEO `_ package to read information from files. -Therefore, to implement a new extractor to handle the unsupported format, we recommend make a new `neo.rawio `_ class. +Therefore, to implement a new extractor to handle the unsupported format, we recommend make a new :code:`neo.rawio.BaseRawIO` class (see `example `_). Once that is done, the new class can be easily wrapped into SpikeInterface as an extension of the :py:class:`~spikeinterface.extractors.neoextractors.neobaseextractors.NeoBaseRecordingExtractor` (for :py:class:`~spikeinterface.core.BaseRecording` objects) or diff --git a/doc/images/plot_traces_ephyviewer.png b/doc/images/plot_traces_ephyviewer.png new file mode 100644 index 0000000000..9d926725a4 Binary files /dev/null and b/doc/images/plot_traces_ephyviewer.png differ diff --git a/doc/install_sorters.rst b/doc/install_sorters.rst index 3fda05848c..10a3185c5c 100644 --- a/doc/install_sorters.rst +++ b/doc/install_sorters.rst @@ -117,7 +117,7 @@ Kilosort2.5 git clone https://github.com/MouseLand/Kilosort # provide installation path by setting the KILOSORT2_5_PATH environment variable - # or using Kilosort2_5Sorter.set_kilosort2_path() + # or using Kilosort2_5Sorter.set_kilosort2_5_path() * See also for Matlab/CUDA: https://www.mathworks.com/help/parallel-computing/gpu-support-by-release.html diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index 34ab3d1151..f3c8e7b733 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -239,7 +239,7 @@ There are three options: 1. **released PyPi version**: if you installed :code:`spikeinterface` with :code:`pip install spikeinterface`, the latest released version will be installed in the container. -2. **development :code:`main` version**: if you installed :code:`spikeinterface` from source from the cloned repo +2. **development** :code:`main` **version**: if you installed :code:`spikeinterface` from source from the cloned repo (with :code:`pip install .`) or with :code:`pip install git+https://github.com/SpikeInterface/spikeinterface.git`, the current development version from the :code:`main` branch will be installed in the container. @@ -285,27 +285,26 @@ Running several sorters in parallel The :py:mod:`~spikeinterface.sorters` module also includes tools to run several spike sorting jobs sequentially or in parallel. This can be done with the -:py:func:`~spikeinterface.sorters.run_sorters()` function by specifying +:py:func:`~spikeinterface.sorters.run_sorter_jobs()` function by specifying an :code:`engine` that supports parallel processing (such as :code:`joblib` or :code:`slurm`). .. code-block:: python - recordings = {'rec1' : recording, 'rec2': another_recording} - sorter_list = ['herdingspikes', 'tridesclous'] - sorter_params = { - 'herdingspikes': {'clustering_bandwidth' : 8}, - 'tridesclous': {'detect_threshold' : 5.}, - } - sorting_output = run_sorters(sorter_list, recordings, working_folder='tmp_some_sorters', - mode_if_folder_exists='overwrite', sorter_params=sorter_params) + # here we run 2 sorters on 2 different recordings = 4 jobs + recording = ... + another_recording = ... 
+ + job_list = [ + {'sorter_name': 'tridesclous', 'recording': recording, 'output_folder': 'folder1','detect_threshold': 5.}, + {'sorter_name': 'tridesclous', 'recording': another_recording, 'output_folder': 'folder2', 'detect_threshold': 5.}, + {'sorter_name': 'herdingspikes', 'recording': recording, 'output_folder': 'folder3', 'clustering_bandwidth': 8., 'docker_image': True}, + {'sorter_name': 'herdingspikes', 'recording': another_recording, 'output_folder': 'folder4', 'clustering_bandwidth': 8., 'docker_image': True}, + ] + + # run in loop + sortings = run_sorter_jobs(job_list, engine='loop') - # the output is a dict with (rec_name, sorter_name) as keys - for (rec_name, sorter_name), sorting in sorting_output.items(): - print(rec_name, sorter_name, ':', sorting.get_unit_ids()) -After the jobs are run, the :code:`sorting_outputs` is a dictionary with :code:`(rec_name, sorter_name)` as a key (e.g. -:code:`('rec1', 'tridesclous')` in this example), and the corresponding :py:class:`~spikeinterface.core.BaseSorting` -as a value. :py:func:`~spikeinterface.sorters.run_sorters` has several "engines" available to launch the computation: @@ -315,13 +314,11 @@ as a value. .. code-block:: python - run_sorters(sorter_list, recordings, engine='loop') + run_sorter_jobs(job_list, engine='loop') - run_sorters(sorter_list, recordings, engine='joblib', - engine_kwargs={'n_jobs': 2}) + run_sorter_jobs(job_list, engine='joblib', engine_kwargs={'n_jobs': 2}) - run_sorters(sorter_list, recordings, engine='slurm', - engine_kwargs={'cpus_per_task': 10, 'mem', '5G'}) + run_sorter_jobs(job_list, engine='slurm', engine_kwargs={'cpus_per_task': 10, 'mem', '5G'}) Spike sorting by group @@ -458,7 +455,7 @@ Here is the list of external sorters accessible using the run_sorter wrapper: * **Kilosort** :code:`run_sorter('kilosort')` * **Kilosort2** :code:`run_sorter('kilosort2')` * **Kilosort2.5** :code:`run_sorter('kilosort2_5')` -* **Kilosort3** :code:`run_sorter('Kilosort3')` +* **Kilosort3** :code:`run_sorter('kilosort3')` * **PyKilosort** :code:`run_sorter('pykilosort')` * **Klusta** :code:`run_sorter('klusta')` * **Mountainsort4** :code:`run_sorter('mountainsort4')` @@ -474,7 +471,7 @@ Here is the list of external sorters accessible using the run_sorter wrapper: Here a list of internal sorter based on `spikeinterface.sortingcomponents`; they are totally experimental for now: -* **Spyking circus2** :code:`run_sorter('spykingcircus2')` +* **Spyking Circus2** :code:`run_sorter('spykingcircus2')` * **Tridesclous2** :code:`run_sorter('tridesclous2')` In 2023, we expect to add many more sorters to this list. diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index aa62ea5b33..422eaea890 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -223,7 +223,7 @@ Here is a short example that depends on the output of "Motion interpolation": **Notes**: * :code:`spatial_interpolation_method` "kriging" or "iwd" do not play a big role. - * :code:`border_mode` is a very important parameter. It controls how to deal with the border because motion causes units on the + * :code:`border_mode` is a very important parameter. It controls dealing with the border because motion causes units on the border to not be present throughout the entire recording. We highly recommend the :code:`border_mode='remove_channels'` because this removes channels on the border that will be impacted by drift. Of course the larger the motion is the more channels are removed. 
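Returning to the :code:`run_sorter_jobs` example above, here is a minimal sketch of loading the sorter outputs back in a later session with :py:func:`~spikeinterface.sorters.read_sorter_folder` (added to the API reference in this patch); the folder names are the illustrative ones from the :code:`job_list`:

.. code-block:: python

    from spikeinterface.sorters import read_sorter_folder

    # output folders as defined in the job_list above
    sortings = {folder: read_sorter_folder(folder)
                for folder in ['folder1', 'folder2', 'folder3', 'folder4']}
    for folder, sorting in sortings.items():
        print(folder, ':', sorting.get_unit_ids())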
@@ -278,7 +278,7 @@ At the moment, there are five methods implemented: * 'naive': a very naive implemenation used as a reference for benchmarks * 'tridesclous': the algorithm for template matching implemented in Tridesclous * 'circus': the algorithm for template matching implemented in SpyKING-Circus - * 'circus-omp': a updated algorithm similar to SpyKING-Circus but with OMP (orthogonal macthing + * 'circus-omp': a updated algorithm similar to SpyKING-Circus but with OMP (orthogonal matching pursuit) * 'wobble' : an algorithm loosely based on YASS that scales template amplitudes and shifts them in time to match detected spikes diff --git a/doc/modules/widgets.rst b/doc/modules/widgets.rst index 86c541dfd0..8565e94fce 100644 --- a/doc/modules/widgets.rst +++ b/doc/modules/widgets.rst @@ -14,6 +14,9 @@ Since version 0.95.0, the :py:mod:`spikeinterface.widgets` module supports multi * | :code:`sortingview`: web-based and interactive rendering using the `sortingview `_ | and `FIGURL `_ packages. +Version 0.100.0, also come with this new backend: +* | :code:`ephyviewer`: interactive Qt based using the `ephyviewer `_ package + Installing backends ------------------- @@ -85,6 +88,28 @@ Finally, if you wish to set up another cloud provider, follow the instruction fr `kachery-cloud `_ package ("Using your own storage bucket"). +ephyviewer +^^^^^^^^^^ + +This backend is Qt based with PyQt5, PyQt6 or PySide6 support. Qt is sometimes tedious to install. + + +For a pip-based installation, run: + +.. code-block:: bash + + pip install PySide6 ephyviewer + + +Anaconda users will have a better experience with this: + +.. code-block:: bash + + conda install pyqt=5 + pip install ephyviewer + + + Usage ----- @@ -215,6 +240,21 @@ For example, here is how to combine the timeseries and sorting summary generated print(url) +ephyviewer +^^^^^^^^^^ + + +The :code:`ephyviewer` backend is currently only available for the :py:func:`~spikeinterface.widgets.plot_traces()` function. + + +.. code-block:: python + + plot_traces(recording, backend="ephyviewer", mode="line", show_channel_ids=True) + + +.. 
image:: ../images/plot_traces_ephyviewer.png + + Available plotting functions ---------------------------- @@ -229,7 +269,7 @@ Available plotting functions * :py:func:`~spikeinterface.widgets.plot_spikes_on_traces` (backends: :code:`matplotlib`, :code:`ipywidgets`) * :py:func:`~spikeinterface.widgets.plot_template_metrics` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) * :py:func:`~spikeinterface.widgets.plot_template_similarity` (backends: ::code:`matplotlib`, :code:`sortingview`) -* :py:func:`~spikeinterface.widgets.plot_timeseries` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) +* :py:func:`~spikeinterface.widgets.plot_traces` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`, :code:`ephyviewer`) * :py:func:`~spikeinterface.widgets.plot_unit_depths` (backends: :code:`matplotlib`) * :py:func:`~spikeinterface.widgets.plot_unit_locations` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) * :py:func:`~spikeinterface.widgets.plot_unit_summary` (backends: :code:`matplotlib`) diff --git a/src/spikeinterface/comparison/basecomparison.py b/src/spikeinterface/comparison/basecomparison.py index 79c784491a..5af20d79b5 100644 --- a/src/spikeinterface/comparison/basecomparison.py +++ b/src/spikeinterface/comparison/basecomparison.py @@ -262,11 +262,11 @@ def get_ordered_agreement_scores(self): indexes = np.arange(scores.shape[1]) order1 = [] for r in range(scores.shape[0]): - possible = indexes[~np.in1d(indexes, order1)] + possible = indexes[~np.isin(indexes, order1)] if possible.size > 0: ind = np.argmax(scores.iloc[r, possible].values) order1.append(possible[ind]) - remain = indexes[~np.in1d(indexes, order1)] + remain = indexes[~np.isin(indexes, order1)] order1.extend(remain) scores = scores.iloc[:, order1] diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index db45e2b25b..20ee7910b4 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -538,7 +538,7 @@ def do_confusion_matrix(event_counts1, event_counts2, match_12, match_event_coun matched_units2 = match_12[match_12 != -1].values unmatched_units1 = match_12[match_12 == -1].index - unmatched_units2 = unit2_ids[~np.in1d(unit2_ids, matched_units2)] + unmatched_units2 = unit2_ids[~np.isin(unit2_ids, matched_units2)] ordered_units1 = np.hstack([matched_units1, unmatched_units1]) ordered_units2 = np.hstack([matched_units2, unmatched_units2]) diff --git a/src/spikeinterface/comparison/studytools.py b/src/spikeinterface/comparison/studytools.py index 79227c865f..26d2c1ad6f 100644 --- a/src/spikeinterface/comparison/studytools.py +++ b/src/spikeinterface/comparison/studytools.py @@ -22,12 +22,45 @@ from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.extractors import NpzSortingExtractor from spikeinterface.sorters import sorter_dict -from spikeinterface.sorters.launcher import iter_working_folder, iter_sorting_output +from spikeinterface.sorters.basesorter import is_log_ok + from .comparisontools import _perf_keys from .paircomparisons import compare_sorter_to_ground_truth +# This is deprecated and will be removed +def iter_working_folder(working_folder): + working_folder = Path(working_folder) + for rec_folder in working_folder.iterdir(): + if not rec_folder.is_dir(): + continue + for output_folder in rec_folder.iterdir(): + if (output_folder / "spikeinterface_job.json").is_file(): + with open(output_folder / 
"spikeinterface_job.json", "r") as f: + job_dict = json.load(f) + rec_name = job_dict["rec_name"] + sorter_name = job_dict["sorter_name"] + yield rec_name, sorter_name, output_folder + else: + rec_name = rec_folder.name + sorter_name = output_folder.name + if not output_folder.is_dir(): + continue + if not is_log_ok(output_folder): + continue + yield rec_name, sorter_name, output_folder + + +# This is deprecated and will be removed +def iter_sorting_output(working_folder): + """Iterator over output_folder to retrieve all triplets of (rec_name, sorter_name, sorting).""" + for rec_name, sorter_name, output_folder in iter_working_folder(working_folder): + SorterClass = sorter_dict[sorter_name] + sorting = SorterClass.get_result_from_folder(output_folder) + yield rec_name, sorter_name, sorting + + def setup_comparison_study(study_folder, gt_dict, **job_kwargs): """ Based on a dict of (recording, sorting) create the study folder. diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index af4970a4ad..08f187895b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -592,7 +592,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceRecording - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceRecording(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 737087abc1..f35bc2b266 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -139,7 +139,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None): def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceSnippets - new_channel_ids = self.channel_ids[~np.in1d(self.channel_ids, remove_channel_ids)] + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)] sub_recording = ChannelSliceSnippets(self, new_channel_ids) return sub_recording diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 52f71c2399..e6d08d38f7 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -346,7 +346,7 @@ def remove_units(self, remove_unit_ids): """ from spikeinterface import UnitsSelectionSorting - new_unit_ids = self.unit_ids[~np.in1d(self.unit_ids, remove_unit_ids)] + new_unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)] new_sorting = UnitsSelectionSorting(self, new_unit_ids) return new_sorting @@ -473,8 +473,7 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac if not concatenated: spikes_ = [] for segment_index in range(self.get_num_segments()): - s0 = np.searchsorted(spikes["segment_index"], segment_index, side="left") - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1, side="left") + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1], side="left") spikes_.append(spikes[s0:s1]) spikes = spikes_ diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 401c498f03..07837bcef7 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -166,7 +166,7 @@ def generate_sorting( ) if 
empty_units is not None: - keep = ~np.in1d(labels, empty_units) + keep = ~np.isin(labels, empty_units) times = times[keep] labels = labels[keep] @@ -219,7 +219,7 @@ def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None): sample_index = spike["sample_index"] if sample_index not in units_used_for_spike: units_used_for_spike[sample_index] = np.array([spike["unit_index"]]) - units_not_used = unit_ids[~np.in1d(unit_ids, units_used_for_spike[sample_index])] + units_not_used = unit_ids[~np.isin(unit_ids, units_used_for_spike[sample_index])] if len(units_not_used) == 0: continue diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index b11f40a441..651804c995 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -111,8 +111,7 @@ def __init__(self, recording, peaks): # precompute segment slice self.segment_slices = [] for segment_index in range(recording.get_num_segments()): - i0 = np.searchsorted(peaks["segment_index"], segment_index) - i1 = np.searchsorted(peaks["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) def get_trace_margin(self): @@ -125,8 +124,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -183,8 +181,7 @@ def __init__( # precompute segment slice self.segment_slices = [] for segment_index in range(recording.get_num_segments()): - i0 = np.searchsorted(self.peaks["segment_index"], segment_index) - i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(self.peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) def get_trace_margin(self): @@ -197,8 +194,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] - i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 97f22615df..d5663156c7 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -338,8 +338,7 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): if self.spikes_in_seg is None: # the slicing of segment is done only once the first time # this fasten the constructor a lot - s0 = np.searchsorted(self.spikes["segment_index"], self.segment_index, side="left") - s1 = np.searchsorted(self.spikes["segment_index"], self.segment_index + 1, side="left") + s0, s1 = np.searchsorted(self.spikes["segment_index"], [self.segment_index, self.segment_index + 1]) self.spikes_in_seg = self.spikes[s0:s1] unit_index = self.unit_ids.index(unit_id) diff --git 
a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index e5901d7ee0..ff9cd99389 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -302,7 +302,7 @@ def get_chunk_with_margin( return traces_chunk, left_margin, right_margin -def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): +def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y"), flip=False): """ Order channels by depth, by first ordering the x-axis, and then the y-axis. @@ -316,6 +316,9 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): If str, it needs to be 'x', 'y', 'z'. If tuple or list, it sorts the locations in two dimensions using lexsort. This approach is recommended since there is less ambiguity, by default ('x', 'y') + flip: bool, default: False + If flip is False then the order is bottom first (starting from tip of the probe). + If flip is True then the order is upper first. Returns ------- @@ -341,6 +344,8 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): assert dim < ndim, "Invalid dimensions!" locations_to_sort += (locations[:, dim],) order_f = np.lexsort(locations_to_sort) + if flip: + order_f = order_f[::-1] order_r = np.argsort(order_f, kind="stable") return order_f, order_r diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index f70c45bfe5..85e36cf7a5 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -174,8 +174,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # Return (0 * num_channels) array of correct dtype return self.parent_segments[0].get_traces(0, 0, channel_indices) - i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1 - i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1 + i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: # * come from one segment (i0 == i1) @@ -469,8 +468,7 @@ def get_unit_spike_train( if end_frame is None: end_frame = self.get_num_samples() - i0 = np.searchsorted(self.cumsum_length, start_frame, side="right") - 1 - i1 = np.searchsorted(self.cumsum_length, end_frame, side="right") - 1 + i0, i1 = np.searchsorted(self.cumsum_length, [start_frame, end_frame], side="right") - 1 # several case: # * come from one segment (i0 == i1) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 4c3680b021..8c5c62d568 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from .recording_tools import get_channel_distances, get_noise_levels @@ -33,7 +35,9 @@ class ChannelSparsity: """ - Handle channel sparsity for a set of units. + Handle channel sparsity for a set of units. That is, for every unit, + it indicates which channels are used to represent the waveform and the rest + of the non-represented channels are assumed to be zero. Internally, sparsity is stored as a boolean mask. 
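# A minimal, self-contained sketch of the boolean-mask representation described in the
# docstring above; the unit ids, channel ids and mask values are made up for illustration.
import numpy as np
from spikeinterface.core import ChannelSparsity

unit_ids = np.array(["u0", "u1"])
channel_ids = np.array([0, 1, 2, 3])
# one row per unit, one column per channel; True means the channel is kept for that unit
mask = np.array([[True, True, False, False],
                 [False, False, True, True]])
sparsity = ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids)
sparsity.unit_id_to_channel_ids["u0"]   # -> array([0, 1])
sparsity.max_num_active_channels        # -> 2 (attribute introduced in this patch)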
@@ -92,13 +96,17 @@ def __init__(self, mask, unit_ids, channel_ids): assert self.mask.shape[0] == self.unit_ids.shape[0] assert self.mask.shape[1] == self.channel_ids.shape[0] - # some precomputed dict + # Those are computed at first call self._unit_id_to_channel_ids = None self._unit_id_to_channel_indices = None + self.num_channels = self.channel_ids.size + self.num_units = self.unit_ids.size + self.max_num_active_channels = self.mask.sum(axis=1).max() + def __repr__(self): - ratio = np.mean(self.mask) - txt = f"ChannelSparsity - units: {self.unit_ids.size} - channels: {self.channel_ids.size} - ratio: {ratio:0.2f}" + density = np.mean(self.mask) + txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}" return txt @property @@ -119,6 +127,85 @@ def unit_id_to_channel_indices(self): self._unit_id_to_channel_indices[unit_id] = channel_inds return self._unit_id_to_channel_indices + def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray: + """ + Sparsify the waveforms according to a unit_id corresponding sparsity. + + + Given a unit_id, this method selects only the active channels for + that unit and removes the rest. + + Parameters + ---------- + waveforms : np.array + Dense waveforms with shape (num_waveforms, num_samples, num_channels) or a + single dense waveform (template) with shape (num_samples, num_channels). + unit_id : str + The unit_id for which to sparsify the waveform. + + Returns + ------- + sparsified_waveforms : np.array + Sparse waveforms with shape (num_waveforms, num_samples, num_active_channels) + or a single sparsified waveform (template) with shape (num_samples, num_active_channels). + """ + + assert_msg = ( + "Waveforms must be dense to sparsify them. " + f"Their last dimension {waveforms.shape[-1]} must be equal to the number of channels {self.num_channels}" + ) + assert self.are_waveforms_dense(waveforms=waveforms), assert_msg + + non_zero_indices = self.unit_id_to_channel_indices[unit_id] + sparsified_waveforms = waveforms[..., non_zero_indices] + + return sparsified_waveforms + + def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray: + """ + Densify sparse waveforms that were sparisified according to a unit's channel sparsity. + + Given a unit_id its sparsified waveform, this method places the waveform back + into its original form within a dense array. + + Parameters + ---------- + waveforms : np.array + The sparsified waveforms array of shape (num_waveforms, num_samples, num_active_channels) or a single + sparse waveform (template) with shape (num_samples, num_active_channels). + unit_id : str + The unit_id that was used to sparsify the waveform. + + Returns + ------- + densified_waveforms : np.array + The densified waveforms array of shape (num_waveforms, num_samples, num_channels) or a single dense + waveform (template) with shape (num_samples, num_channels). + + """ + + non_zero_indices = self.unit_id_to_channel_indices[unit_id] + + assert_msg = ( + "Waveforms do not seem to be be in the sparsity shape of this unit_id. The number of active channels is " + f"{len(non_zero_indices)} but the waveform has {waveforms.shape[-1]} active channels." 
+ ) + assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg + + densified_shape = waveforms.shape[:-1] + (self.num_channels,) + densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype) + densified_waveforms[..., non_zero_indices] = waveforms + + return densified_waveforms + + def are_waveforms_dense(self, waveforms: np.ndarray) -> bool: + return waveforms.shape[-1] == self.num_channels + + def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> bool: + non_zero_indices = self.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) + return waveforms.shape[-1] == num_active_channels + @classmethod def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids): """ @@ -144,16 +231,16 @@ def to_dict(self): ) @classmethod - def from_dict(cls, d): + def from_dict(cls, dictionary: dict): unit_id_to_channel_ids_corrected = {} - for unit_id in d["unit_ids"]: - if unit_id in d["unit_id_to_channel_ids"]: - unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][unit_id] + for unit_id in dictionary["unit_ids"]: + if unit_id in dictionary["unit_id_to_channel_ids"]: + unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][unit_id] else: - unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][str(unit_id)] - d["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected + unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][str(unit_id)] + dictionary["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected - return cls.from_unit_id_to_channel_ids(**d) + return cls.from_unit_id_to_channel_ids(**dictionary) ## Some convinient function to compute sparsity from several strategy @classmethod diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 6e92d155fe..1d99b192ee 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -138,11 +138,13 @@ def test_order_channels_by_depth(): order_1d, order_r1d = order_channels_by_depth(rec, dimensions="y") order_2d, order_r2d = order_channels_by_depth(rec, dimensions=("x", "y")) locations_rev = locations_copy[order_1d][order_r1d] + order_2d_fliped, order_r2d_fliped = order_channels_by_depth(rec, dimensions=("x", "y"), flip=True) assert np.array_equal(locations[:, 1], locations_copy[order_1d][:, 1]) assert np.array_equal(locations_copy[order_1d][:, 1], locations_copy[order_2d][:, 1]) assert np.array_equal(locations, locations_copy[order_2d]) assert np.array_equal(locations_copy, locations_copy[order_2d][order_r2d]) + assert np.array_equal(order_2d[::-1], order_2d_fliped) if __name__ == "__main__": diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index a6b94c9b84..ac114ac161 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -34,7 +34,7 @@ def test_ChannelSparsity(): for key, v in sparsity.unit_id_to_channel_ids.items(): assert key in unit_ids - assert np.all(np.in1d(v, channel_ids)) + assert np.all(np.isin(v, channel_ids)) for key, v in sparsity.unit_id_to_channel_indices.items(): assert key in unit_ids @@ -55,5 +55,93 @@ def test_ChannelSparsity(): assert np.array_equal(sparsity.mask, sparsity4.mask) +def test_sparsify_waveforms(): + seed = 0 + rng = np.random.default_rng(seed=seed) + + num_units = 3 + 
num_samples = 5 + num_channels = 4 + + is_mask_valid = False + while not is_mask_valid: + sparsity_mask = rng.integers(0, 1, size=(num_units, num_channels), endpoint=True, dtype="bool") + is_mask_valid = np.all(sparsity_mask.sum(axis=1) > 0) + + unit_ids = np.arange(num_units) + channel_ids = np.arange(num_channels) + sparsity = ChannelSparsity(mask=sparsity_mask, unit_ids=unit_ids, channel_ids=channel_ids) + + for unit_id in unit_ids: + waveforms_dense = rng.random(size=(num_units, num_samples, num_channels)) + + # Test are_waveforms_dense + assert sparsity.are_waveforms_dense(waveforms_dense) + + # Test sparsify + waveforms_sparse = sparsity.sparsify_waveforms(waveforms_dense, unit_id=unit_id) + non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) + assert waveforms_sparse.shape == (num_units, num_samples, num_active_channels) + + # Test round-trip (note that this is loosy) + unit_id = unit_ids[unit_id] + non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] + waveforms_dense2 = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id) + assert np.array_equal(waveforms_dense[..., non_zero_indices], waveforms_dense2[..., non_zero_indices]) + + # Test sparsify with one waveform (template) + template_dense = waveforms_dense.mean(axis=0) + template_sparse = sparsity.sparsify_waveforms(template_dense, unit_id=unit_id) + assert template_sparse.shape == (num_samples, num_active_channels) + + # Test round trip with template + template_dense2 = sparsity.densify_waveforms(template_sparse, unit_id=unit_id) + assert np.array_equal(template_dense[..., non_zero_indices], template_dense2[:, non_zero_indices]) + + +def test_densify_waveforms(): + seed = 0 + rng = np.random.default_rng(seed=seed) + + num_units = 3 + num_samples = 5 + num_channels = 4 + + is_mask_valid = False + while not is_mask_valid: + sparsity_mask = rng.integers(0, 1, size=(num_units, num_channels), endpoint=True, dtype="bool") + is_mask_valid = np.all(sparsity_mask.sum(axis=1) > 0) + + unit_ids = np.arange(num_units) + channel_ids = np.arange(num_channels) + sparsity = ChannelSparsity(mask=sparsity_mask, unit_ids=unit_ids, channel_ids=channel_ids) + + for unit_id in unit_ids: + non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] + num_active_channels = len(non_zero_indices) + waveforms_sparse = rng.random(size=(num_units, num_samples, num_active_channels)) + + # Test are waveforms sparse + assert sparsity.are_waveforms_sparse(waveforms_sparse, unit_id=unit_id) + + # Test densify + waveforms_dense = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id) + assert waveforms_dense.shape == (num_units, num_samples, num_channels) + + # Test round-trip + waveforms_sparse2 = sparsity.sparsify_waveforms(waveforms_dense, unit_id=unit_id) + assert np.array_equal(waveforms_sparse, waveforms_sparse2) + + # Test densify with one waveform (template) + template_sparse = waveforms_sparse.mean(axis=0) + template_dense = sparsity.densify_waveforms(template_sparse, unit_id=unit_id) + assert template_dense.shape == (num_samples, num_channels) + + # Test round trip with template + template_sparse2 = sparsity.sparsify_waveforms(template_dense, unit_id=unit_id) + assert np.array_equal(template_sparse, template_sparse2) + + if __name__ == "__main__": test_ChannelSparsity() diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 877c9fb00c..6881ab3ec5 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ 
b/src/spikeinterface/core/waveform_extractor.py @@ -4,6 +4,7 @@ import shutil from typing import Iterable, Literal, Optional import json +import os import numpy as np from copy import deepcopy @@ -87,6 +88,7 @@ def __init__( self._template_cache = {} self._params = {} self._loaded_extensions = dict() + self._is_read_only = False self.sparsity = sparsity self.folder = folder @@ -103,6 +105,8 @@ def __init__( if (self.folder / "params.json").is_file(): with open(str(self.folder / "params.json"), "r") as f: self._params = json.load(f) + if not os.access(self.folder, os.W_OK): + self._is_read_only = True else: # this is in case of in-memory self.format = "memory" @@ -399,6 +403,9 @@ def return_scaled(self) -> bool: def dtype(self): return self._params["dtype"] + def is_read_only(self) -> bool: + return self._is_read_only + def has_recording(self) -> bool: return self._recording is not None @@ -516,6 +523,10 @@ def is_extension(self, extension_name) -> bool: """ if self.folder is None: return extension_name in self._loaded_extensions + + if extension_name in self._loaded_extensions: + # extension already loaded in memory + return True else: if self.format == "binary": return (self.folder / extension_name).is_dir() and ( @@ -1740,13 +1751,33 @@ def __init__(self, waveform_extractor): if self.format == "binary": self.extension_folder = self.folder / self.extension_name if not self.extension_folder.is_dir(): - self.extension_folder.mkdir() + if self.waveform_extractor.is_read_only(): + warn( + "WaveformExtractor: cannot save extension in read-only mode. " + "Extension will be saved in memory." + ) + self.format = "memory" + self.extension_folder = None + self.folder = None + else: + self.extension_folder.mkdir() + else: import zarr - zarr_root = zarr.open(self.folder, mode="r+") + mode = "r+" if not self.waveform_extractor.is_read_only() else "r" + zarr_root = zarr.open(self.folder, mode=mode) if self.extension_name not in zarr_root.keys(): - self.extension_group = zarr_root.create_group(self.extension_name) + if self.waveform_extractor.is_read_only(): + warn( + "WaveformExtractor: cannot save extension in read-only mode. " + "Extension will be saved in memory." + ) + self.format = "memory" + self.extension_folder = None + self.folder = None + else: + self.extension_group = zarr_root.create_group(self.extension_name) else: self.extension_group = zarr_root[self.extension_name] else: @@ -1863,6 +1894,9 @@ def save(self, **kwargs): self._save(**kwargs) def _save(self, **kwargs): + # Only save if not read only + if self.waveform_extractor.is_read_only(): + return if self.format == "binary": import pandas as pd @@ -1900,7 +1934,9 @@ def _save(self, **kwargs): self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) elif isinstance(ext_data, pd.DataFrame): ext_data.to_xarray().to_zarr( - store=self.extension_group.store, group=f"{self.extension_group.name}/{ext_data_name}", mode="a" + store=self.extension_group.store, + group=f"{self.extension_group.name}/{ext_data_name}", + mode="a", ) self.extension_group[ext_data_name].attrs["dataframe"] = True else: diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index da8e3d64b6..a2f1296e31 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -344,15 +344,15 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx # take only spikes with the correct segment_index # this is a slice so no copy!! 
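# Note on the recurring refactor applied in this patch: a single np.searchsorted call with a
# list of needles performs the binary searches in one vectorized call and returns both bounds
# at once, which is why the paired calls below are collapsed. Illustrative, standalone example:
#     segment_index_per_spike = np.array([0, 0, 0, 1, 1, 2])   # sorted
#     s0, s1 = np.searchsorted(segment_index_per_spike, [1, 2], side="left")
#     # s0 == 3, s1 == 5, so spikes[s0:s1] selects the segment-1 entries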
- s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) in_seg_spikes = spikes[s0:s1] # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) - i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + i0, i1 = np.searchsorted( + in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)] + ) # slice in absolut in spikes vector l0 = i0 + s0 @@ -562,8 +562,7 @@ def _init_worker_distribute_single_buffer( # prepare segment slices segment_slices = [] for segment_index in range(recording.get_num_segments()): - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append((s0, s1)) worker_ctx["segment_slices"] = segment_slices @@ -590,8 +589,9 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! # the border of segment are protected by nbefore on left an nafter on the right - i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) - i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + i0, i1 = np.searchsorted( + in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)] + ) # slice in absolut in spikes vector l0 = i0 + s0 @@ -685,8 +685,7 @@ def has_exceeding_spikes(recording, sorting): """ spike_vector = sorting.to_spike_vector() for segment_index in range(recording.get_num_segments()): - start_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index) - end_seg_ind = np.searchsorted(spike_vector["segment_index"], segment_index + 1) + start_seg_ind, end_seg_ind = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) spike_vector_seg = spike_vector[start_seg_ind:end_seg_ind] if len(spike_vector_seg) > 0: if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1: diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 264ac3a56d..5295cc76d8 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -12,7 +12,7 @@ class MergeUnitsSorting(BaseSorting): ---------- parent_sorting: Recording The sorting object - units_to_merge: list of lists + units_to_merge: list/tuple of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), but it can also have more (merge multiple units at once). new_unit_ids: None or list @@ -24,6 +24,7 @@ class MergeUnitsSorting(BaseSorting): Default: 'keep' delta_time_ms: float or None Number of ms to consider for duplicated spikes. 
None won't check for duplications + Returns ------- sorting: Sorting @@ -33,7 +34,7 @@ class MergeUnitsSorting(BaseSorting): def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4): self._parent_sorting = parent_sorting - if not isinstance(units_to_merge[0], list): + if not isinstance(units_to_merge[0], (list, tuple)): # keep backward compatibility : the previous behavior was only one merge units_to_merge = [units_to_merge] @@ -59,7 +60,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties else: # we cannot automatically find new names new_unit_ids = [f"merge{i}" for i in range(num_merge)] - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(np.isin(new_unit_ids, keep_unit_ids)): raise ValueError( "Unable to find 'new_unit_ids' because it is a string and parents " "already contain merges. Pass a list of 'new_unit_ids' as an argument." @@ -68,7 +69,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties # dtype int new_unit_ids = list(max(parents_unit_ids) + 1 + np.arange(num_merge, dtype=dtype)) else: - if np.any(np.in1d(new_unit_ids, keep_unit_ids)): + if np.any(np.isin(new_unit_ids, keep_unit_ids)): raise ValueError("'new_unit_ids' already exist in the sorting.unit_ids. Provide new ones") assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge" diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 5615402fdb..c92861a8bf 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -178,7 +178,11 @@ def export_to_phy( templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity") + if waveform_extractor.is_extension("similarity"): + tmc = waveform_extractor.load_extension("similarity") + template_similarity = tmc.get_data() + else: + template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity") np.save(str(output_folder / "templates.npy"), templates) np.save(str(output_folder / "template_ind.npy"), templates_ind) diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py index 02e7d5677d..8b70722652 100644 --- a/src/spikeinterface/extractors/bids.py +++ b/src/spikeinterface/extractors/bids.py @@ -76,7 +76,7 @@ def _read_probe_group(folder, bids_name, recording_channel_ids): contact_ids = channels["contact_id"].values.astype("U") # extracting information of requested channels - keep = np.in1d(channel_ids, recording_channel_ids) + keep = np.isin(channel_ids, recording_channel_ids) channel_ids = channel_ids[keep] contact_ids = contact_ids[keep] diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py index ebff40fae0..235dd705dc 100644 --- a/src/spikeinterface/extractors/extractorlist.py +++ b/src/spikeinterface/extractors/extractorlist.py @@ -11,6 +11,8 @@ NumpySorting, NpySnippetsExtractor, ZarrRecordingExtractor, + read_binary, + read_zarr, ) # sorting/recording/event from neo diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index c91aed644d..05aee160f5 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -1,3 +1,6 @@ +from 
__future__ import annotations + +from typing import Optional from pathlib import Path import numpy as np @@ -13,10 +16,14 @@ class BasePhyKilosortSortingExtractor(BaseSorting): ---------- folder_path: str or Path Path to the output Phy folder (containing the params.py) - exclude_cluster_groups: list or str, optional + exclude_cluster_groups: list or str, default: None Cluster groups to exclude (e.g. "noise" or ["noise", "mua"]). keep_good_only : bool, default: True Whether to only keep good units. + remove_empty_units : bool, default: True + If True, empty units are removed from the sorting extractor. + load_all_cluster_properties : bool, default: True + If True, all cluster properties are loaded from the tsv/csv files. """ extractor_name = "BasePhyKilosortSorting" @@ -29,11 +36,11 @@ class BasePhyKilosortSortingExtractor(BaseSorting): def __init__( self, - folder_path, - exclude_cluster_groups=None, - keep_good_only=False, - remove_empty_units=False, - load_all_cluster_properties=True, + folder_path: Path | str, + exclude_cluster_groups: Optional[list[str] | str] = None, + keep_good_only: bool = False, + remove_empty_units: bool = False, + load_all_cluster_properties: bool = True, ): try: import pandas as pd @@ -195,20 +202,33 @@ class PhySortingExtractor(BasePhyKilosortSortingExtractor): ---------- folder_path: str or Path Path to the output Phy folder (containing the params.py). - exclude_cluster_groups: list or str, optional + exclude_cluster_groups: list or str, default: None Cluster groups to exclude (e.g. "noise" or ["noise", "mua"]). + load_all_cluster_properties : bool, default: True + If True, all cluster properties are loaded from the tsv/csv files. Returns ------- extractor : PhySortingExtractor - The loaded data. + The loaded Sorting object. """ extractor_name = "PhySorting" name = "phy" - def __init__(self, folder_path, exclude_cluster_groups=None): - BasePhyKilosortSortingExtractor.__init__(self, folder_path, exclude_cluster_groups, keep_good_only=False) + def __init__( + self, + folder_path: Path | str, + exclude_cluster_groups: Optional[list[str] | str] = None, + load_all_cluster_properties: bool = True, + ): + BasePhyKilosortSortingExtractor.__init__( + self, + folder_path, + exclude_cluster_groups, + keep_good_only=False, + load_all_cluster_properties=load_all_cluster_properties, + ) self._kwargs = { "folder_path": str(Path(folder_path).absolute()), @@ -223,8 +243,6 @@ class KiloSortSortingExtractor(BasePhyKilosortSortingExtractor): ---------- folder_path: str or Path Path to the output Phy folder (containing the params.py). - exclude_cluster_groups: list or str, optional - Cluster groups to exclude (e.g. "noise" or ["noise", "mua"]). keep_good_only : bool, default: True Whether to only keep good units. If True, only Kilosort-labeled 'good' units are returned. @@ -234,13 +252,13 @@ class KiloSortSortingExtractor(BasePhyKilosortSortingExtractor): Returns ------- extractor : KiloSortSortingExtractor - The loaded data. + The loaded Sorting object. 
""" extractor_name = "KiloSortSorting" name = "kilosort" - def __init__(self, folder_path, keep_good_only=False, remove_empty_units=True): + def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove_empty_units: bool = True): BasePhyKilosortSortingExtractor.__init__( self, folder_path, diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 5a0148c5c4..22b40a51c5 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -47,9 +47,9 @@ def _set_params( def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(self.spikes["unit_index"], unit_inds) new_amplitude_scalings = self._extension_data["amplitude_scalings"][spike_mask] return dict(amplitude_scalings=new_amplitude_scalings) @@ -99,8 +99,7 @@ def _run(self, **job_kwargs): # precompute segment slice segment_slices = [] for segment_index in range(we.get_num_segments()): - i0 = np.searchsorted(self.spikes["segment_index"], segment_index) - i1 = np.searchsorted(self.spikes["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(self.spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append(slice(i0, i1)) # and run @@ -317,8 +316,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) spikes_in_segment = spikes[segment_slices[segment_index]] - i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) + i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) if i0 != i1: local_spikes = spikes_in_segment[i0:i1] @@ -335,8 +333,9 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: # local spikes with margin! 
- i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left) - i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) + i0_margin, i1_margin = np.searchsorted( + spikes_in_segment["sample_index"], [start_frame - left, end_frame + right] + ) local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] collisions_local = find_collisions( local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices @@ -462,14 +461,11 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_ spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0] # find the possible spikes per and post within delta_collision_samples - consecutive_window_pre = np.searchsorted( - spikes_w_margin["sample_index"], - spike["sample_index"] - delta_collision_samples, - ) - consecutive_window_post = np.searchsorted( + consecutive_window_pre, consecutive_window_post = np.searchsorted( spikes_w_margin["sample_index"], - spike["sample_index"] + delta_collision_samples, + [spike["sample_index"] - delta_collision_samples, spike["sample_index"] + delta_collision_samples], ) + # exclude the spike itself (it is included in the collision_spikes by construction) pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin) post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 233625e09e..ce1c3bd5a0 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -600,8 +600,7 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): seg_size = recording.get_num_samples(segment_index=segment_index) - i0 = np.searchsorted(spike_times, start_frame) - i1 = np.searchsorted(spike_times, end_frame) + i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) if i0 != i1: # protect from spikes on border : spike_time<0 or spike_time>seg_size diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 62a4e2c320..38cb714d59 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -28,13 +28,13 @@ def _select_extension_data(self, unit_ids): # load filter and save amplitude files sorting = self.waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) - (keep_unit_indices,) = np.nonzero(np.in1d(sorting.unit_ids, unit_ids)) + (keep_unit_indices,) = np.nonzero(np.isin(sorting.unit_ids, unit_ids)) new_extension_data = dict() for seg_index in range(sorting.get_num_segments()): amp_data_name = f"amplitude_segment_{seg_index}" amps = self._extension_data[amp_data_name] - filtered_idxs = np.in1d(spikes[seg_index]["unit_index"], keep_unit_indices) + filtered_idxs = np.isin(spikes[seg_index]["unit_index"], keep_unit_indices) new_extension_data[amp_data_name] = amps[filtered_idxs] return new_extension_data @@ -218,9 +218,7 @@ def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx): d = np.diff(spike_times) assert np.all(d >= 0) - i0 = np.searchsorted(spike_times, start_frame) - i1 = np.searchsorted(spike_times, end_frame) - + i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) n_spikes = i1 - i0 amplitudes = np.zeros(n_spikes, 
dtype=recording.get_dtype()) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index c6f498f7e8..4cbe4d665e 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -32,9 +32,9 @@ def _set_params(self, ms_before=0.5, ms_after=0.5, method="center_of_mass", meth def _select_extension_data(self, unit_ids): old_unit_ids = self.waveform_extractor.sorting.unit_ids - unit_inds = np.flatnonzero(np.in1d(old_unit_ids, unit_ids)) + unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spike_mask = np.in1d(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(self.spikes["unit_index"], unit_inds) new_spike_locations = self._extension_data["spike_locations"][spike_mask] return dict(spike_locations=new_spike_locations) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 119f0dc53d..ea44dea9cb 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -19,16 +19,16 @@ # plt.show() -def get_1d_template_metric_names(): - return deepcopy(list(_1d_metric_name_to_func.keys())) +def get_single_channel_template_metric_names(): + return deepcopy(list(_single_channel_metric_name_to_func.keys())) -def get_2d_template_metric_names(): - return deepcopy(list(_2d_metric_name_to_func.keys())) +def get_multi_channel_template_metric_names(): + return deepcopy(list(_multi_channel_metric_name_to_func.keys())) def get_template_metric_names(): - return get_1d_template_metric_names() + get_2d_template_metric_names() + return get_single_channel_template_metric_names() + get_multi_channel_template_metric_names() class TemplateMetricsCalculator(BaseWaveformExtractorExtension): @@ -41,7 +41,7 @@ class TemplateMetricsCalculator(BaseWaveformExtractorExtension): """ extension_name = "template_metrics" - min_channels_for_2d_warning = 10 + min_channels_for_multi_channel_warning = 10 def __init__(self, waveform_extractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) @@ -53,12 +53,12 @@ def _set_params( upsampling_factor=10, sparsity=None, functions_kwargs=None, - include_2d_metrics=False, + include_multi_channel_metrics=False, ): if metric_names is None: - metric_names = get_1d_template_metric_names() - if include_2d_metrics: - metric_names += get_2d_template_metric_names() + metric_names = get_single_channel_template_metric_names() + if include_multi_channel_metrics: + metric_names += get_multi_channel_template_metric_names() functions_kwargs = functions_kwargs or dict() params = dict( metric_names=[str(name) for name in metric_names], @@ -86,8 +86,8 @@ def _run(self): unit_ids = self.waveform_extractor.sorting.unit_ids sampling_frequency = self.waveform_extractor.sampling_frequency - metrics_1d = [m for m in metric_names if m in get_1d_template_metric_names()] - metrics_2d = [m for m in metric_names if m in get_2d_template_metric_names()] + metrics_single_channel = [m for m in metric_names if m in get_single_channel_template_metric_names()] + metrics_multi_channel = [m for m in metric_names if m in get_multi_channel_template_metric_names()] if sparsity is None: extremum_channels_ids = get_template_extremum_channel( @@ -118,7 +118,7 @@ def _run(self): chan_ind = self.waveform_extractor.channel_ids_to_indices(chan_ids) template = template_all_chans[:, chan_ind] - # compute 1d metrics + # compute single_channel 
metrics for i, template_single in enumerate(template.T): if sparsity is None: index = unit_id @@ -134,7 +134,7 @@ def _run(self): trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled) - for metric_name in metrics_1d: + for metric_name in metrics_single_channel: func = _metric_name_to_func[metric_name] value = func( template_upsampled, @@ -145,15 +145,15 @@ def _run(self): ) template_metrics.at[index, metric_name] = value - # compute metrics 2d - for metric_name in metrics_2d: + # compute metrics multi_channel + for metric_name in metrics_multi_channel: # retrieve template (with sparsity if waveform extractor is sparse) template = self.waveform_extractor.get_template(unit_id=unit_id) - if template.shape[1] < self.min_channels_for_2d_warning: + if template.shape[1] < self.min_channels_for_multi_channel_warning: warnings.warn( - f"With less than {self.min_channels_for_2d_warning} channels, " - "2D metrics might not be reliable." + f"With less than {self.min_channels_for_multi_channel_warning} channels, " + "multi-channel metrics might not be reliable." ) if self.waveform_extractor.is_sparse(): channel_locations_sparse = channel_locations[self.waveform_extractor.sparsity.mask[unit_index]] @@ -206,7 +206,7 @@ def compute_template_metrics( peak_sign="neg", upsampling_factor=10, sparsity=None, - include_2d_metrics=False, + include_multi_channel_metrics=False, functions_kwargs=dict( recovery_window_ms=0.7, peak_relative_threshold=0.2, @@ -228,7 +228,7 @@ def compute_template_metrics( * num_positive_peaks * num_negative_peaks - Optionally, the following 2d metrics can be computed (when include_2d_metrics=True): + Optionally, the following multi-channel metrics can be computed (when include_multi_channel_metrics=True): * velocity_above * velocity_below * exp_decay @@ -250,8 +250,8 @@ def compute_template_metrics( Default is sparsity=None and template metric is computed on extremum channel only. If given, the dictionary should contain a unit ids as keys and a channel id or a list of channel ids as values. For more generating a sparsity dict, see the postprocessing.compute_sparsity() function. - include_2d_metrics: bool, default: False - Whether to compute 2d metrics + include_multi_channel_metrics: bool, default: False + Whether to compute multi-channel metrics functions_kwargs: dict Additional arguments to pass to the metric functions. Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 @@ -272,26 +272,27 @@ def compute_template_metrics( Notes ----- - If any 2d metric is in the metric_names or include_2d_metrics is True, sparsity must be None, so that one metric - value will be computed per unit. + If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, + so that one metric value will be computed per unit. """ if load_if_exists and waveform_extractor.is_extension(TemplateMetricsCalculator.extension_name): tmc = waveform_extractor.load_extension(TemplateMetricsCalculator.extension_name) else: tmc = TemplateMetricsCalculator(waveform_extractor) # For 2D metrics, external sparsity must be None, so that one metric value will be computed per unit. 
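# A minimal usage sketch of the renamed template-metrics API above (an
# illustrative example, not part of the patch: the toy recording/sorting and
# the extract_waveforms(mode="memory") call are placeholder assumptions to make
# the snippet self-contained; the names include_multi_channel_metrics,
# sparsity and get_template_metric_names come from the code in this diff).
from spikeinterface import generate_ground_truth_recording, extract_waveforms
from spikeinterface.postprocessing import compute_template_metrics, get_template_metric_names

recording, sorting = generate_ground_truth_recording(num_channels=16, durations=[10.0], seed=0)
we = extract_waveforms(recording, sorting, mode="memory")

print(get_template_metric_names())  # single-channel + multi-channel metric names

# multi-channel metrics require sparsity=None, so each unit yields exactly one row
tm = compute_template_metrics(we, include_multi_channel_metrics=True, sparsity=None)
print(tm)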
- if include_2d_metrics or ( - metric_names is not None and any([m in get_2d_template_metric_names() for m in metric_names]) + if include_multi_channel_metrics or ( + metric_names is not None and any([m in get_multi_channel_template_metric_names() for m in metric_names]) ): - assert ( - sparsity is None - ), "If 2D metrics are computed, sparsity must be None, so that each unit will correspond to 1 row of the output dataframe." + assert sparsity is None, ( + "If multi-channel metrics are computed, sparsity must be None, " + "so that each unit will correspond to 1 row of the output dataframe." + ) tmc.set_params( metric_names=metric_names, peak_sign=peak_sign, upsampling_factor=upsampling_factor, sparsity=sparsity, - include_2d_metrics=include_2d_metrics, + include_multi_channel_metrics=include_multi_channel_metrics, functions_kwargs=functions_kwargs, ) tmc.run() @@ -326,7 +327,7 @@ def get_trough_and_peak_idx(template): ######################################################################################### -# 1D metrics +# Single-channel metrics def get_peak_to_valley(template_single, trough_idx=None, peak_idx=None, **kwargs): """ Return the peak to valley duration in seconds of input waveforms. @@ -565,7 +566,7 @@ def get_num_negative_peaks(template_single, **kwargs): return len(neg_peaks[0]) -_1d_metric_name_to_func = { +_single_channel_metric_name_to_func = { "peak_to_valley": get_peak_to_valley, "peak_trough_ratio": get_peak_trough_ratio, "half_width": get_half_width, @@ -577,7 +578,7 @@ def get_num_negative_peaks(template_single, **kwargs): ######################################################################################### -# 2D metrics +# Multi-channel metrics def fit_velocity(peak_times, channel_dist): @@ -802,11 +803,11 @@ def get_spread(template, channel_locations, **kwargs): return spread -_2d_metric_name_to_func = { +_multi_channel_metric_name_to_func = { "velocity_above": get_velocity_above, "velocity_below": get_velocity_below, "exp_decay": get_exp_decay, "spread": get_spread, } -_metric_name_to_func = {**_1d_metric_name_to_func, **_2d_metric_name_to_func} +_metric_name_to_func = {**_single_channel_metric_name_to_func, **_multi_channel_metric_name_to_func} diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index b9c72f9b99..f7272ddefe 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -2,9 +2,10 @@ import numpy as np import pandas as pd import shutil +import platform from pathlib import Path -from spikeinterface import extract_waveforms, load_extractor, compute_sparsity +from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity from spikeinterface.extractors import toy_example if hasattr(pytest, "global_test_folder"): @@ -76,6 +77,16 @@ def setUp(self): overwrite=True, ) self.we2 = we2 + + # make we read-only + if platform.system() != "Windows": + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + if not we_ro_folder.is_dir(): + shutil.copytree(we2.folder, we_ro_folder) + # change permissions (R+X) + we_ro_folder.chmod(0o555) + self.we_ro = load_waveforms(we_ro_folder) + self.sparsity2 = compute_sparsity(we2, method="radius", radius_um=30) we_memory = extract_waveforms( recording, @@ -97,6 +108,12 @@ def setUp(self): folder=cache_folder / "toy_sorting_2seg_sparse", format="binary", sparsity=sparsity, overwrite=True ) + def 
tearDown(self): + # allow pytest to delete RO folder + if platform.system() != "Windows": + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + we_ro_folder.chmod(0o777) + def _test_extension_folder(self, we, in_memory=False): if self.extension_function_kwargs_list is None: extension_function_kwargs_list = [dict()] @@ -177,3 +194,11 @@ def test_extension(self): assert ext_data_mem.equals(ext_data_zarr) else: print(f"{ext_data_name} of type {type(ext_data_mem)} not tested.") + + # read-only - Extension is memory only + if platform.system() != "Windows": + _ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False) + assert self.extension_class.extension_name in self.we_ro.get_available_extension_names() + ext_ro = self.we_ro.load_extension(self.extension_class.extension_name) + assert ext_ro.format == "memory" + assert ext_ro.extension_folder is None diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index 740fdd234b..d2739f69dd 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -570,6 +570,8 @@ def enforce_decrease_shells_data(wf_data, maxchan, radial_parents, in_place=Fals def get_grid_convolution_templates_and_weights( contact_locations, radius_um=50, upsampling_um=5, sigma_um=np.linspace(10, 50.0, 5), margin_um=50 ): + import sklearn.metrics + x_min, x_max = contact_locations[:, 0].min(), contact_locations[:, 0].max() y_min, y_max = contact_locations[:, 1].min(), contact_locations[:, 1].max() @@ -593,8 +595,6 @@ def get_grid_convolution_templates_and_weights( template_positions[:, 0] = all_x.flatten() template_positions[:, 1] = all_y.flatten() - import sklearn - # mask to get nearest template given a channel dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions) nearest_template_mask = dist < radius_um diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index 0b8d8a730b..55e34ba5dd 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -18,13 +18,18 @@ class DepthOrderRecording(ChannelSliceRecording): If str, it needs to be 'x', 'y', 'z'. If tuple or list, it sorts the locations in two dimensions using lexsort. This approach is recommended since there is less ambiguity, by default ('x', 'y') + flip: bool, default: False + If flip is False then the order is bottom first (starting from tip of the probe). + If flip is True then the order is upper first. 
""" name = "depth_order" installed = True - def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y")): - order_f, order_r = order_channels_by_depth(parent_recording, channel_ids=channel_ids, dimensions=dimensions) + def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y"), flip=False): + order_f, order_r = order_channels_by_depth( + parent_recording, channel_ids=channel_ids, dimensions=dimensions, flip=flip + ) reordered_channel_ids = parent_recording.channel_ids[order_f] ChannelSliceRecording.__init__( self, @@ -35,6 +40,7 @@ def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y")): parent_recording=parent_recording, channel_ids=channel_ids, dimensions=dimensions, + flip=flip, ) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 0f4800c6e8..cc4e8601e2 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -18,7 +18,7 @@ def detect_bad_channels( nyquist_threshold=0.8, direction="y", chunk_duration_s=0.3, - num_random_chunks=10, + num_random_chunks=100, welch_window_ms=10.0, highpass_filter_cutoff=300, neighborhood_r2_threshold=0.9, @@ -81,9 +81,10 @@ def detect_bad_channels( highpass_filter_cutoff : float If the recording is not filtered, the cutoff frequency of the highpass filter, by default 300 chunk_duration_s : float - Duration of each chunk, by default 0.3 + Duration of each chunk, by default 0.5 num_random_chunks : int - Number of random chunks, by default 10 + Number of random chunks, by default 100 + Having many chunks is important for reproducibility. welch_window_ms : float Window size for the scipy.signal.welch that will be converted to nperseg, by default 10ms neighborhood_r2_threshold : float, default 0.95 @@ -174,20 +175,18 @@ def detect_bad_channels( channel_locations = recording.get_channel_locations() dim = ["x", "y", "z"].index(direction) assert dim < channel_locations.shape[1], f"Direction {direction} is wrong" - locs_depth = channel_locations[:, dim] - if np.array_equal(np.sort(locs_depth), locs_depth): + order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y")) + if np.all(np.diff(order_f) == 1): + # already ordered order_f = None order_r = None - else: - # sort by x, y to avoid ambiguity - order_f, order_r = order_channels_by_depth(recording=recording, dimensions=("x", "y")) # Create empty channel labels and fill with bad-channel detection estimate for each chunk chunk_channel_labels = np.zeros((recording.get_num_channels(), len(random_data)), dtype=np.int8) for i, random_chunk in enumerate(random_data): - random_chunk_sorted = random_chunk[order_f] if order_f is not None else random_chunk - chunk_channel_labels[:, i] = detect_bad_channels_ibl( + random_chunk_sorted = random_chunk[:, order_f] if order_f is not None else random_chunk + chunk_labels = detect_bad_channels_ibl( raw=random_chunk_sorted, fs=recording.sampling_frequency, psd_hf_threshold=psd_hf_threshold, @@ -198,11 +197,10 @@ def detect_bad_channels( nyquist_threshold=nyquist_threshold, welch_window_ms=welch_window_ms, ) + chunk_channel_labels[:, i] = chunk_labels[order_r] if order_r is not None else chunk_labels # Take the mode of the chunk estimates as final result. Convert to binary good / bad channel output. 
mode_channel_labels, _ = scipy.stats.mode(chunk_channel_labels, axis=1, keepdims=False) - if order_r is not None: - mode_channel_labels = mode_channel_labels[order_r] (bad_inds,) = np.where(mode_channel_labels != 0) bad_channel_ids = recording.channel_ids[bad_inds] diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index e634d55e7f..95ecd0fe52 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -49,7 +49,7 @@ def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=Non self.bad_channel_ids = bad_channel_ids self._bad_channel_idxs = recording.ids_to_indices(self.bad_channel_ids) - self._good_channel_idxs = ~np.in1d(np.arange(recording.get_num_channels()), self._bad_channel_idxs) + self._good_channel_idxs = ~np.isin(np.arange(recording.get_num_channels()), self._bad_channel_idxs) self._bad_channel_idxs.setflags(write=False) if sigma_um is None: diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 3148539165..7e84822c61 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -165,7 +165,9 @@ def __init__( for l in np.unique(labels): assert l in artifacts.keys(), f"Artefacts are provided but label {l} has no value!" else: - assert "ms_before" != None and "ms_after" != None, f"ms_before/after should not be None for mode {mode}" + assert ( + ms_before is not None and ms_after is not None + ), f"ms_before/after should not be None for mode {mode}" sorting = NumpySorting.from_times_labels(list_triggers, list_labels, recording.get_sampling_frequency()) sorting = sorting.save() waveforms_kwargs.update({"ms_before": ms_before, "ms_after": ms_after}) diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index ee28485983..8dd5f857f6 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -544,7 +544,7 @@ def compute_synchrony_metrics(waveform_extractor, synchrony_sizes=(2, 4, 8), **k # some segments/units might have no spikes if len(spikes_per_unit) == 0: continue - spike_complexity = complexity[np.in1d(unique_spike_index, spikes_per_unit["sample_index"])] + spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])] for synchrony_size in synchrony_sizes: synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size) @@ -848,16 +848,14 @@ def compute_drift_metrics( spike_vector = sorting.to_spike_vector() # retrieve spikes in segment - i0 = np.searchsorted(spike_vector["segment_index"], segment_index) - i1 = np.searchsorted(spike_vector["segment_index"], segment_index + 1) + i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) spikes_in_segment = spike_vector[i0:i1] spike_locations_in_segment = spike_locations[i0:i1] # compute median positions (if less than min_spikes_per_interval, median position is 0) median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1)) for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])): - i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) - i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) + i0, i1 = 
np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) spikes_in_bin = spikes_in_segment[i0:i1] spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 59000211d4..ed06f7d738 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -152,8 +152,8 @@ def calculate_pc_metrics( neighbor_unit_ids = unit_ids neighbor_channel_indices = we.channel_ids_to_indices(neighbor_channel_ids) - labels = all_labels[np.in1d(all_labels, neighbor_unit_ids)] - pcs = all_pcs[np.in1d(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] + labels = all_labels[np.isin(all_labels, neighbor_unit_ids)] + pcs = all_pcs[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) func_args = ( @@ -506,7 +506,7 @@ def nearest_neighbors_isolation( other_units_ids = [ unit_id for unit_id in other_units_ids - if np.sum(np.in1d(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) + if np.sum(np.isin(sparsity.unit_id_to_channel_indices[unit_id], closest_chans_target_unit)) >= (n_channels_target_unit * min_spatial_overlap) ] @@ -536,10 +536,10 @@ def nearest_neighbors_isolation( if waveform_extractor.is_sparse(): # in this case, waveforms are sparse so we need to do some smart indexing waveforms_target_unit_sampled = waveforms_target_unit_sampled[ - :, :, np.in1d(closest_chans_target_unit, common_channel_idxs) + :, :, np.isin(closest_chans_target_unit, common_channel_idxs) ] waveforms_other_unit_sampled = waveforms_other_unit_sampled[ - :, :, np.in1d(closest_chans_other_unit, common_channel_idxs) + :, :, np.isin(closest_chans_other_unit, common_channel_idxs) ] else: waveforms_target_unit_sampled = waveforms_target_unit_sampled[:, :, common_channel_idxs] diff --git a/src/spikeinterface/sorters/__init__.py b/src/spikeinterface/sorters/__init__.py index a0d437559d..ba663327e8 100644 --- a/src/spikeinterface/sorters/__init__.py +++ b/src/spikeinterface/sorters/__init__.py @@ -1,11 +1,4 @@ from .basesorter import BaseSorter from .sorterlist import * from .runsorter import * - -from .launcher import ( - run_sorters, - run_sorter_by_property, - collect_sorting_outputs, - iter_working_folder, - iter_sorting_output, -) +from .launcher import run_sorter_jobs, run_sorters, run_sorter_by_property diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index ff559cc78d..c7581ba1e1 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -411,3 +411,14 @@ def get_job_kwargs(params, verbose): if not verbose: job_kwargs["progress_bar"] = False return job_kwargs + + +def is_log_ok(output_folder): + # log is OK when run_time is not None + if (output_folder / "spikeinterface_log.json").is_file(): + with open(output_folder / "spikeinterface_log.json", mode="r", encoding="utf8") as logfile: + log = json.load(logfile) + run_time = log.get("run_time", None) + ok = run_time is not None + return ok + return False diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 9de2762562..db3d88f116 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -3,7 +3,6 @@ import os import shutil import numpy as np -import os from 
spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms from spikeinterface.core.job_tools import fix_job_kwargs @@ -21,18 +20,17 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True}, + "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75}, + "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, "filtering": {"dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, "localization": {}, "clustering": {}, "matching": {}, - "registration": {}, "apply_preprocessing": True, - "shared_memory": False, - "job_kwargs": {}, + "shared_memory": True, + "job_kwargs": {"n_jobs": -1}, } @classmethod @@ -63,8 +61,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: - # if recording.is_filtered == True: - # print('Looks like the recording is already filtered, check preprocessing!') recording_f = bandpass_filter(recording, **filtering_params) recording_f = common_reference(recording_f) else: @@ -103,8 +99,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets clustering_params = params["clustering"].copy() - clustering_params.update(params["waveforms"]) - clustering_params.update(params["general"]) + clustering_params["waveforms_kwargs"] = params["waveforms"] + + for k in ["ms_before", "ms_after"]: + clustering_params["waveforms_kwargs"][k] = params["general"][k] + clustering_params.update(dict(shared_memory=params["shared_memory"])) clustering_params["job_kwargs"] = job_kwargs clustering_params["tmp_folder"] = sorter_output_folder / "clustering" @@ -126,6 +125,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): waveforms_params = params["waveforms"].copy() waveforms_params.update(job_kwargs) + for k in ["ms_before", "ms_after"]: + waveforms_params[k] = params["general"][k] + if params["shared_memory"]: mode = "memory" waveforms_folder = None @@ -143,6 +145,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_params.update({"noise_levels": noise_levels}) matching_job_params = job_kwargs.copy() + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: + if value in matching_job_params: + matching_job_params.pop(value) + matching_job_params["chunk_duration"] = "100ms" spikes = find_spikes_from_templates( diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index 52098f45cd..f32a468a22 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -4,61 +4,193 @@ from pathlib import Path import shutil import numpy as np -import json import tempfile import os import stat import subprocess import sys +import warnings -from spikeinterface.core import load_extractor, aggregate_units -from spikeinterface.core.core_tools import check_json +from spikeinterface.core import aggregate_units from .sorterlist import sorter_dict -from .runsorter import run_sorter, run_sorter - - -def _run_one(arg_list): - # the multiprocessing python module 
force to have one unique tuple argument
-    (
-        sorter_name,
-        recording,
-        output_folder,
-        verbose,
-        sorter_params,
-        docker_image,
-        singularity_image,
-        with_output,
-    ) = arg_list
-
-    if isinstance(recording, dict):
-        recording = load_extractor(recording)
+from .runsorter import run_sorter
+from .basesorter import is_log_ok
+
+_default_engine_kwargs = dict(
+    loop=dict(),
+    joblib=dict(n_jobs=-1, backend="loky"),
+    processpoolexecutor=dict(max_workers=2, mp_context=None),
+    dask=dict(client=None),
+    slurm=dict(tmp_script_folder=None, cpus_per_task=1, mem="1G"),
+)
+
+
+_implemented_engine = list(_default_engine_kwargs.keys())
+
+
+def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=False):
+    """
+    Run several :py:func:`run_sorter()` sequentially or in parallel given a list of jobs.
+
+    For **engine="loop"** this is equivalent to:
+
+    .. code::
+
+        for job in job_list:
+            run_sorter(**job)
+
+    The following engines are *blocking*:
+      * "loop"
+      * "joblib"
+      * "processpoolexecutor"
+      * "dask"
+
+    The following engines are *asynchronous*:
+      * "slurm"
+
+    *Blocking* means that this function does not return until the results are available.
+    This is in opposition to *asynchronous*, where the function returns `None` almost immediately (aka non-blocking),
+    but the results must be retrieved by hand once the jobs are finished. No mechanism is provided here to know
+    when the jobs are finished.
+    In the *asynchronous* case, :py:func:`~spikeinterface.sorters.read_sorter_folder()` helps to retrieve individual results.
+
+
+    Parameters
+    ----------
+    job_list: list of dict
+        A list of dicts whose entries are propagated to run_sorter(...)
+    engine: str "loop", "joblib", "processpoolexecutor", "dask", "slurm"
+        The engine used to run the list of jobs.
+        * "loop": a simple sequential loop in the main process (the default).
+    engine_kwargs: dict
+        Keyword arguments forwarded to the chosen engine; unset keys fall back to _default_engine_kwargs.
+
+    return_output: bool, default False
+        If True, return the list of resulting sortings (blocking engines only); otherwise return None.
+
+    Returns
+    -------
+    sortings: None or list of sorting
+        With engine="loop", "joblib" or "processpoolexecutor" you can optionally get the list of sorting results
+        directly by setting return_output=True.
+    """
+
+    assert engine in _implemented_engine, f"engine must be in {_implemented_engine}"
+
+    engine_kwargs_ = dict()
+    engine_kwargs_.update(_default_engine_kwargs[engine])
+    engine_kwargs_.update(engine_kwargs)
+    engine_kwargs = engine_kwargs_
+
+    if return_output:
+        assert engine in (
+            "loop",
+            "joblib",
+            "processpoolexecutor",
+        ), "Only 'loop', 'joblib', and 'processpoolexecutor' support return_output=True."
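# Illustrative usage sketch for run_sorter_jobs() (not part of the library code:
# the recordings, output paths and engine_kwargs values below are placeholders,
# the dict keys mirror run_sorter() arguments as in the tests later in this diff,
# and "tridesclous2" is simply the sorter name those tests use).
from spikeinterface import generate_ground_truth_recording
from spikeinterface.sorters import run_sorter_jobs

rec_a, _ = generate_ground_truth_recording(num_channels=8, durations=[10.0], seed=0)
rec_b, _ = generate_ground_truth_recording(num_channels=8, durations=[10.0], seed=1)

job_list = [
    dict(sorter_name="tridesclous2", recording=rec_a, output_folder="out/rec_a", verbose=True),
    dict(sorter_name="tridesclous2", recording=rec_b, output_folder="out/rec_b", verbose=True),
]

# e.g. blocking, sequential
sortings = run_sorter_jobs(job_list, engine="loop", return_output=True)

# e.g. blocking, parallel with joblib (backend falls back to the "loky" default)
sortings = run_sorter_jobs(job_list, engine="joblib", engine_kwargs=dict(n_jobs=2), return_output=True)

# e.g. asynchronous: submits sbatch jobs and returns immediately; collect results
# later with read_sorter_folder() on each output_folder
run_sorter_jobs(job_list, engine="slurm", engine_kwargs=dict(cpus_per_task=4, mem="4G"))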
+ out = [] else: - recording = recording - - # because this is checks in run_sorters before this call - remove_existing_folder = False - # result is retrieve later - delete_output_folder = False - # because we won't want the loop/worker to break - raise_error = False - - run_sorter( - sorter_name, - recording, - output_folder=output_folder, - remove_existing_folder=remove_existing_folder, - delete_output_folder=delete_output_folder, - verbose=verbose, - raise_error=raise_error, - docker_image=docker_image, - singularity_image=singularity_image, - with_output=with_output, - **sorter_params, - ) + out = None + + if engine == "loop": + # simple loop in main process + for kwargs in job_list: + sorting = run_sorter(**kwargs) + if return_output: + out.append(sorting) + + elif engine == "joblib": + from joblib import Parallel, delayed + + n_jobs = engine_kwargs["n_jobs"] + backend = engine_kwargs["backend"] + sortings = Parallel(n_jobs=n_jobs, backend=backend)(delayed(run_sorter)(**kwargs) for kwargs in job_list) + if return_output: + out.extend(sortings) + + elif engine == "processpoolexecutor": + from concurrent.futures import ProcessPoolExecutor + + max_workers = engine_kwargs["max_workers"] + mp_context = engine_kwargs["mp_context"] + with ProcessPoolExecutor(max_workers=max_workers, mp_context=mp_context) as executor: + futures = [] + for kwargs in job_list: + res = executor.submit(run_sorter, **kwargs) + futures.append(res) + for futur in futures: + sorting = futur.result() + if return_output: + out.append(sorting) -_implemented_engine = ("loop", "joblib", "dask", "slurm") + elif engine == "dask": + client = engine_kwargs["client"] + assert client is not None, "For dask engine you have to provide : client = dask.distributed.Client(...)" + + tasks = [] + for kwargs in job_list: + task = client.submit(run_sorter, **kwargs) + tasks.append(task) + + for task in tasks: + task.result() + + elif engine == "slurm": + # generate python script for slurm + tmp_script_folder = engine_kwargs["tmp_script_folder"] + if tmp_script_folder is None: + tmp_script_folder = tempfile.mkdtemp(prefix="spikeinterface_slurm_") + tmp_script_folder = Path(tmp_script_folder) + cpus_per_task = engine_kwargs["cpus_per_task"] + mem = engine_kwargs["mem"] + + tmp_script_folder.mkdir(exist_ok=True, parents=True) + + for i, kwargs in enumerate(job_list): + script_name = tmp_script_folder / f"si_script_{i}.py" + with open(script_name, "w") as f: + kwargs_txt = "" + for k, v in kwargs.items(): + kwargs_txt += " " + if k == "recording": + # put None temporally + kwargs_txt += "recording=None" + else: + if isinstance(v, str): + kwargs_txt += f'{k}="{v}"' + elif isinstance(v, Path): + kwargs_txt += f'{k}="{str(v.absolute())}"' + else: + kwargs_txt += f"{k}={v}" + kwargs_txt += ",\n" + + # recording_dict = task_args[1] + recording_dict = kwargs["recording"].to_dict() + slurm_script = _slurm_script.format( + python=sys.executable, recording_dict=recording_dict, kwargs_txt=kwargs_txt + ) + f.write(slurm_script) + os.fchmod(f.fileno(), mode=stat.S_IRWXU) + + subprocess.Popen(["sbatch", str(script_name.absolute()), f"-cpus-per-task={cpus_per_task}", f"-mem={mem}"]) + + return out + + +_slurm_script = """#! 
{python} +from numpy import array +from spikeinterface import load_extractor +from spikeinterface.sorters import run_sorter + +rec_dict = {recording_dict} + +kwargs = dict( +{kwargs_txt} +) +kwargs['recording'] = load_extractor(rec_dict) + +run_sorter(**kwargs) +""" def run_sorter_by_property( @@ -66,7 +198,7 @@ def run_sorter_by_property( recording, grouping_property, working_folder, - mode_if_folder_exists="raise", + mode_if_folder_exists=None, engine="loop", engine_kwargs={}, verbose=False, @@ -93,11 +225,10 @@ def run_sorter_by_property( Property to split by before sorting working_folder: str The working directory. - mode_if_folder_exists: {'raise', 'overwrite', 'keep'} - The mode when the subfolder of recording/sorter already exists. - * 'raise' : raise error if subfolder exists - * 'overwrite' : delete and force recompute - * 'keep' : do not compute again if f=subfolder exists and log is OK + mode_if_folder_exists: None + Must be None. This is deprecated. + If not None then a warning is raise. + Will be removed in next release. engine: {'loop', 'joblib', 'dask'} Which engine to use to run sorter. engine_kwargs: dict @@ -127,46 +258,49 @@ def run_sorter_by_property( engine_kwargs={"n_jobs": 4}) """ + if mode_if_folder_exists is not None: + warnings.warn( + "run_sorter_by_property(): mode_if_folder_exists is not used anymore", + DeprecationWarning, + stacklevel=2, + ) + + working_folder = Path(working_folder).absolute() assert grouping_property in recording.get_property_keys(), ( f"The 'grouping_property' {grouping_property} is not " f"a recording property!" ) recording_dict = recording.split_by(grouping_property) - sorting_output = run_sorters( - [sorter_name], - recording_dict, - working_folder, - mode_if_folder_exists=mode_if_folder_exists, - engine=engine, - engine_kwargs=engine_kwargs, - verbose=verbose, - with_output=True, - docker_images={sorter_name: docker_image}, - singularity_images={sorter_name: singularity_image}, - sorter_params={sorter_name: sorter_params}, - ) - grouping_property_values = None - sorting_list = [] - for output_name, sorting in sorting_output.items(): - prop_name, sorter_name = output_name - sorting_list.append(sorting) - if grouping_property_values is None: - grouping_property_values = np.array( - [prop_name] * len(sorting.get_unit_ids()), dtype=np.dtype(type(prop_name)) - ) - else: - grouping_property_values = np.concatenate( - (grouping_property_values, [prop_name] * len(sorting.get_unit_ids())) - ) + job_list = [] + for k, rec in recording_dict.items(): + job = dict( + sorter_name=sorter_name, + recording=rec, + output_folder=working_folder / str(k), + verbose=verbose, + docker_image=docker_image, + singularity_image=singularity_image, + **sorter_params, + ) + job_list.append(job) + + sorting_list = run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=True) + + unit_groups = [] + for sorting, group in zip(sorting_list, recording_dict.keys()): + num_units = sorting.get_unit_ids().size + unit_groups.extend([group] * num_units) + unit_groups = np.array(unit_groups) aggregate_sorting = aggregate_units(sorting_list) - aggregate_sorting.set_property(key=grouping_property, values=grouping_property_values) + aggregate_sorting.set_property(key=grouping_property, values=unit_groups) aggregate_sorting.register_recording(recording) return aggregate_sorting +# This is deprecated and will be removed def run_sorters( sorter_list, recording_dict_or_list, @@ -180,7 +314,9 @@ def run_sorters( docker_images={}, singularity_images={}, ): - 
"""Run several sorter on several recordings. + """ + This function is deprecated and will be removed in version 0.100 + Please use run_sorter_jobs() instead. Parameters ---------- @@ -221,6 +357,13 @@ def run_sorters( results : dict The output is nested dict[(rec_name, sorter_name)] of SortingExtractor. """ + + warnings.warn( + "run_sorters() is deprecated please use run_sorter_jobs() instead. This will be removed in 0.100", + DeprecationWarning, + stacklevel=2, + ) + working_folder = Path(working_folder) mode_if_folder_exists in ("raise", "keep", "overwrite") @@ -247,8 +390,7 @@ def run_sorters( dtype_rec_name = np.dtype(type(list(recording_dict.keys())[0])) assert dtype_rec_name.kind in ("i", "u", "S", "U"), "Dict keys can only be integers or strings!" - need_dump = engine != "loop" - task_args_list = [] + job_list = [] for rec_name, recording in recording_dict.items(): for sorter_name in sorter_list: output_folder = working_folder / str(rec_name) / sorter_name @@ -268,181 +410,21 @@ def run_sorters( params = sorter_params.get(sorter_name, {}) docker_image = docker_images.get(sorter_name, None) singularity_image = singularity_images.get(sorter_name, None) - _check_container_images(docker_image, singularity_image, sorter_name) - - if need_dump: - if not recording.check_if_dumpable(): - raise Exception("recording not dumpable call recording.save() before") - recording_arg = recording.to_dict(recursive=True) - else: - recording_arg = recording - - task_args = ( - sorter_name, - recording_arg, - output_folder, - verbose, - params, - docker_image, - singularity_image, - with_output, - ) - task_args_list.append(task_args) - if engine == "loop": - # simple loop in main process - for task_args in task_args_list: - _run_one(task_args) - - elif engine == "joblib": - from joblib import Parallel, delayed - - n_jobs = engine_kwargs.get("n_jobs", -1) - backend = engine_kwargs.get("backend", "loky") - Parallel(n_jobs=n_jobs, backend=backend)(delayed(_run_one)(task_args) for task_args in task_args_list) - - elif engine == "dask": - client = engine_kwargs.get("client", None) - assert client is not None, "For dask engine you have to provide : client = dask.distributed.Client(...)" - - tasks = [] - for task_args in task_args_list: - task = client.submit(_run_one, task_args) - tasks.append(task) - - for task in tasks: - task.result() - - elif engine == "slurm": - # generate python script for slurm - tmp_script_folder = engine_kwargs.get("tmp_script_folder", None) - if tmp_script_folder is None: - tmp_script_folder = tempfile.mkdtemp(prefix="spikeinterface_slurm_") - tmp_script_folder = Path(tmp_script_folder) - cpus_per_task = engine_kwargs.get("cpus_per_task", 1) - mem = engine_kwargs.get("mem", "1G") - - for i, task_args in enumerate(task_args_list): - script_name = tmp_script_folder / f"si_script_{i}.py" - with open(script_name, "w") as f: - arg_list_txt = "(\n" - for j, arg in enumerate(task_args): - arg_list_txt += "\t" - if j != 1: - if isinstance(arg, str): - arg_list_txt += f'"{arg}"' - elif isinstance(arg, Path): - arg_list_txt += f'"{str(arg.absolute())}"' - else: - arg_list_txt += f"{arg}" - else: - arg_list_txt += "recording" - arg_list_txt += ",\r" - arg_list_txt += ")" - - recording_dict = task_args[1] - slurm_script = _slurm_script.format( - python=sys.executable, recording_dict=recording_dict, arg_list_txt=arg_list_txt - ) - f.write(slurm_script) - os.fchmod(f.fileno(), mode=stat.S_IRWXU) - - print(slurm_script) - - subprocess.Popen(["sbatch", str(script_name.absolute()), 
f"-cpus-per-task={cpus_per_task}", f"-mem={mem}"]) + job = dict( + sorter_name=sorter_name, + recording=recording, + output_folder=output_folder, + verbose=verbose, + docker_image=docker_image, + singularity_image=singularity_image, + **params, + ) + job_list.append(job) - non_blocking_engine = ("loop", "joblib") - if engine in non_blocking_engine: - # dump spikeinterface_job.json - # only for non blocking engine - for rec_name, recording in recording_dict.items(): - for sorter_name in sorter_list: - output_folder = working_folder / str(rec_name) / sorter_name - with open(output_folder / "spikeinterface_job.json", "w") as f: - dump_dict = {"rec_name": rec_name, "sorter_name": sorter_name, "engine": engine} - if engine != "dask": - dump_dict.update({"engine_kwargs": engine_kwargs}) - json.dump(check_json(dump_dict), f) + sorting_list = run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=with_output) if with_output: - if engine not in non_blocking_engine: - print( - f'Warning!! With engine="{engine}" you cannot have directly output results\n' - "Use : run_sorters(..., with_output=False)\n" - "And then: results = collect_sorting_outputs(output_folders)" - ) - return - - results = collect_sorting_outputs(working_folder) + keys = [(rec_name, sorter_name) for rec_name in recording_dict for sorter_name in sorter_list] + results = dict(zip(keys, sorting_list)) return results - - -_slurm_script = """#! {python} -from numpy import array -from spikeinterface.sorters.launcher import _run_one - -recording = {recording_dict} - -arg_list = {arg_list_txt} - -_run_one(arg_list) -""" - - -def is_log_ok(output_folder): - # log is OK when run_time is not None - if (output_folder / "spikeinterface_log.json").is_file(): - with open(output_folder / "spikeinterface_log.json", mode="r", encoding="utf8") as logfile: - log = json.load(logfile) - run_time = log.get("run_time", None) - ok = run_time is not None - return ok - return False - - -def iter_working_folder(working_folder): - working_folder = Path(working_folder) - for rec_folder in working_folder.iterdir(): - if not rec_folder.is_dir(): - continue - for output_folder in rec_folder.iterdir(): - if (output_folder / "spikeinterface_job.json").is_file(): - with open(output_folder / "spikeinterface_job.json", "r") as f: - job_dict = json.load(f) - rec_name = job_dict["rec_name"] - sorter_name = job_dict["sorter_name"] - yield rec_name, sorter_name, output_folder - else: - rec_name = rec_folder.name - sorter_name = output_folder.name - if not output_folder.is_dir(): - continue - if not is_log_ok(output_folder): - continue - yield rec_name, sorter_name, output_folder - - -def iter_sorting_output(working_folder): - """Iterator over output_folder to retrieve all triplets of (rec_name, sorter_name, sorting).""" - for rec_name, sorter_name, output_folder in iter_working_folder(working_folder): - SorterClass = sorter_dict[sorter_name] - sorting = SorterClass.get_result_from_folder(output_folder) - yield rec_name, sorter_name, sorting - - -def collect_sorting_outputs(working_folder): - """Collect results in a working_folder. - - The output is a dict with double key access results[(rec_name, sorter_name)] of SortingExtractor. 
- """ - results = {} - for rec_name, sorter_name, sorting in iter_sorting_output(working_folder): - results[(rec_name, sorter_name)] = sorting - return results - - -def _check_container_images(docker_image, singularity_image, sorter_name): - if docker_image is not None: - assert singularity_image is None, f"Provide either a docker or a singularity image " f"for sorter {sorter_name}" - if singularity_image is not None: - assert docker_image is None, f"Provide either a docker or a singularity image " f"for sorter {sorter_name}" diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py index cd8bc0fa5d..14c938f8ba 100644 --- a/src/spikeinterface/sorters/tests/test_launcher.py +++ b/src/spikeinterface/sorters/tests/test_launcher.py @@ -1,4 +1,5 @@ import os +import sys import shutil import time @@ -6,8 +7,10 @@ from pathlib import Path from spikeinterface.core import load_extractor -from spikeinterface.extractors import toy_example -from spikeinterface.sorters import run_sorters, run_sorter_by_property, collect_sorting_outputs + +# from spikeinterface.extractors import toy_example +from spikeinterface import generate_ground_truth_recording +from spikeinterface.sorters import run_sorter_jobs, run_sorters, run_sorter_by_property if hasattr(pytest, "global_test_folder"): @@ -15,10 +18,17 @@ else: cache_folder = Path("cache_folder") / "sorters" +base_output = cache_folder / "sorter_output" + +# no need to have many +num_recordings = 2 +sorters = ["tridesclous2"] + def setup_module(): - rec, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1) - for i in range(4): + base_seed = 42 + for i in range(num_recordings): + rec, _ = generate_ground_truth_recording(num_channels=8, durations=[10.0], seed=base_seed + i) rec_folder = cache_folder / f"toy_rec_{i}" if rec_folder.is_dir(): shutil.rmtree(rec_folder) @@ -31,19 +41,106 @@ def setup_module(): rec.save(folder=rec_folder) -def test_run_sorters_with_list(): - working_folder = cache_folder / "test_run_sorters_list" +def get_job_list(): + jobs = [] + for i in range(num_recordings): + for sorter_name in sorters: + recording = load_extractor(cache_folder / f"toy_rec_{i}") + kwargs = dict( + sorter_name=sorter_name, + recording=recording, + output_folder=base_output / f"{sorter_name}_rec{i}", + verbose=True, + raise_error=False, + ) + jobs.append(kwargs) + + return jobs + + +@pytest.fixture(scope="module") +def job_list(): + return get_job_list() + + +def test_run_sorter_jobs_loop(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + sortings = run_sorter_jobs(job_list, engine="loop", return_output=True) + print(sortings) + + +def test_run_sorter_jobs_joblib(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + sortings = run_sorter_jobs( + job_list, engine="joblib", engine_kwargs=dict(n_jobs=2, backend="loky"), return_output=True + ) + print(sortings) + + +def test_run_sorter_jobs_processpoolexecutor(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + sortings = run_sorter_jobs( + job_list, engine="processpoolexecutor", engine_kwargs=dict(max_workers=2), return_output=True + ) + print(sortings) + + +@pytest.mark.skipif(True, reason="This is tested locally") +def test_run_sorter_jobs_dask(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + + # create a dask Client for a slurm queue + from dask.distributed import Client + + test_mode = "local" + # test_mode = "client_slurm" + + if test_mode == "local": + client = Client() + 
elif test_mode == "client_slurm": + from dask_jobqueue import SLURMCluster + + cluster = SLURMCluster( + processes=1, + cores=1, + memory="12GB", + python=sys.executable, + walltime="12:00:00", + ) + cluster.scale(2) + client = Client(cluster) + + # dask + t0 = time.perf_counter() + run_sorter_jobs(job_list, engine="dask", engine_kwargs=dict(client=client)) + t1 = time.perf_counter() + print(t1 - t0) + + +@pytest.mark.skip("Slurm launcher need a machine with slurm") +def test_run_sorter_jobs_slurm(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + + working_folder = cache_folder / "test_run_sorters_slurm" if working_folder.is_dir(): shutil.rmtree(working_folder) - # make dumpable - rec0 = load_extractor(cache_folder / "toy_rec_0") - rec1 = load_extractor(cache_folder / "toy_rec_1") - - recording_list = [rec0, rec1] - sorter_list = ["tridesclous"] + tmp_script_folder = working_folder / "slurm_scripts" - run_sorters(sorter_list, recording_list, working_folder, engine="loop", verbose=False, with_output=False) + run_sorter_jobs( + job_list, + engine="slurm", + engine_kwargs=dict( + tmp_script_folder=tmp_script_folder, + cpus_per_task=32, + mem="32G", + ), + ) def test_run_sorter_by_property(): @@ -59,7 +156,7 @@ def test_run_sorter_by_property(): rec0_by = rec0.split_by("group") group_names0 = list(rec0_by.keys()) - sorter_name = "tridesclous" + sorter_name = "tridesclous2" sorting0 = run_sorter_by_property(sorter_name, rec0, "group", working_folder1, engine="loop", verbose=False) assert "group" in sorting0.get_property_keys() assert all([g in group_names0 for g in sorting0.get_property("group")]) @@ -68,12 +165,31 @@ def test_run_sorter_by_property(): rec1_by = rec1.split_by("group") group_names1 = list(rec1_by.keys()) - sorter_name = "tridesclous" + sorter_name = "tridesclous2" sorting1 = run_sorter_by_property(sorter_name, rec1, "group", working_folder2, engine="loop", verbose=False) assert "group" in sorting1.get_property_keys() assert all([g in group_names1 for g in sorting1.get_property("group")]) +# run_sorters is deprecated +# This will test will be removed in next release +def test_run_sorters_with_list(): + working_folder = cache_folder / "test_run_sorters_list" + if working_folder.is_dir(): + shutil.rmtree(working_folder) + + # make dumpable + rec0 = load_extractor(cache_folder / "toy_rec_0") + rec1 = load_extractor(cache_folder / "toy_rec_1") + + recording_list = [rec0, rec1] + sorter_list = ["tridesclous2"] + + run_sorters(sorter_list, recording_list, working_folder, engine="loop", verbose=False, with_output=False) + + +# run_sorters is deprecated +# This will test will be removed in next release def test_run_sorters_with_dict(): working_folder = cache_folder / "test_run_sorters_dict" if working_folder.is_dir(): @@ -84,9 +200,9 @@ def test_run_sorters_with_dict(): recording_dict = {"toy_tetrode": rec0, "toy_octotrode": rec1} - sorter_list = ["tridesclous", "tridesclous2"] + sorter_list = ["tridesclous2"] - sorter_params = {"tridesclous": dict(detect_threshold=5.6), "tridesclous2": dict()} + sorter_params = {"tridesclous2": dict()} # simple loop t0 = time.perf_counter() @@ -116,143 +232,19 @@ def test_run_sorters_with_dict(): ) -@pytest.mark.skipif(True, reason="This is tested locally") -def test_run_sorters_joblib(): - working_folder = cache_folder / "test_run_sorters_joblib" - if working_folder.is_dir(): - shutil.rmtree(working_folder) - - recording_dict = {} - for i in range(4): - rec = load_extractor(cache_folder / f"toy_rec_{i}") - recording_dict[f"rec_{i}"] 
= rec - - sorter_list = [ - "tridesclous", - ] - - # joblib - t0 = time.perf_counter() - run_sorters( - sorter_list, - recording_dict, - working_folder / "with_joblib", - engine="joblib", - engine_kwargs={"n_jobs": 4}, - with_output=False, - mode_if_folder_exists="keep", - ) - t1 = time.perf_counter() - print(t1 - t0) - - -@pytest.mark.skipif(True, reason="This is tested locally") -def test_run_sorters_dask(): - working_folder = cache_folder / "test_run_sorters_dask" - if working_folder.is_dir(): - shutil.rmtree(working_folder) - - recording_dict = {} - for i in range(4): - rec = load_extractor(cache_folder / f"toy_rec_{i}") - recording_dict[f"rec_{i}"] = rec - - sorter_list = [ - "tridesclous", - ] - - # create a dask Client for a slurm queue - from dask.distributed import Client - from dask_jobqueue import SLURMCluster - - python = "/home/samuel.garcia/.virtualenvs/py36/bin/python3.6" - cluster = SLURMCluster( - processes=1, - cores=1, - memory="12GB", - python=python, - walltime="12:00:00", - ) - cluster.scale(5) - client = Client(cluster) - - # dask - t0 = time.perf_counter() - run_sorters( - sorter_list, - recording_dict, - working_folder, - engine="dask", - engine_kwargs={"client": client}, - with_output=False, - mode_if_folder_exists="keep", - ) - t1 = time.perf_counter() - print(t1 - t0) - - -@pytest.mark.skipif(True, reason="This is tested locally") -def test_run_sorters_slurm(): - working_folder = cache_folder / "test_run_sorters_slurm" - if working_folder.is_dir(): - shutil.rmtree(working_folder) - - # create recording - recording_dict = {} - for i in range(4): - rec = load_extractor(cache_folder / f"toy_rec_{i}") - recording_dict[f"rec_{i}"] = rec - - sorter_list = [ - "spykingcircus2", - "tridesclous2", - ] - - tmp_script_folder = working_folder / "slurm_scripts" - tmp_script_folder.mkdir(parents=True) - - run_sorters( - sorter_list, - recording_dict, - working_folder, - engine="slurm", - engine_kwargs={ - "tmp_script_folder": tmp_script_folder, - "cpus_per_task": 32, - "mem": "32G", - }, - with_output=False, - mode_if_folder_exists="keep", - verbose=True, - ) - - -def test_collect_sorting_outputs(): - working_folder = cache_folder / "test_run_sorters_dict" - results = collect_sorting_outputs(working_folder) - print(results) - - -def test_sorter_installation(): - # This import is to get error on github when import fails - import tridesclous - - # import circus - - if __name__ == "__main__": - setup_module() - # pass - # test_run_sorters_with_list() - - # test_run_sorter_by_property() + # setup_module() + job_list = get_job_list() - test_run_sorters_with_dict() + # test_run_sorter_jobs_loop(job_list) + # test_run_sorter_jobs_joblib(job_list) + # test_run_sorter_jobs_processpoolexecutor(job_list) + # test_run_sorter_jobs_multiprocessing(job_list) + # test_run_sorter_jobs_dask(job_list) + test_run_sorter_jobs_slurm(job_list) - # test_run_sorters_joblib() - - # test_run_sorters_dask() - - # test_run_sorters_slurm() + # test_run_sorter_by_property() - # test_collect_sorting_outputs() + # this deprecated + # test_run_sorters_with_list() + # test_run_sorters_with_dict() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 07c7db155c..4efabbc9c5 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -502,7 +502,7 @@ def plot_errors_matching(benchmark, comp, unit_id, 
nb_spikes=200, metric="cosine seg_num = 0 # TODO: make compatible with multiple segments idx_1 = np.where(comp.get_labels1(unit_id)[seg_num] == label) idx_2 = benchmark.we.get_sampled_indices(unit_id)["spike_index"] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(np.isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] if len(intersection) == 0: print(f"No {label}s found for unit {unit_id}") @@ -552,7 +552,7 @@ def plot_errors_matching_all_neurons(benchmark, comp, nb_spikes=200, metric="cos for label in ["TP", "FN"]: idx_1 = np.where(comp.get_labels1(unit_id) == label)[0] - intersection = np.where(np.in1d(idx_2, idx_1))[0] + intersection = np.where(np.isin(idx_2, idx_1))[0] intersection = np.random.permutation(intersection)[:nb_spikes] wfs_sliced = wfs[intersection, :, :] @@ -600,29 +600,38 @@ def plot_comparison_matching( else: ax = axs[j] comp1, comp2 = comp_per_method[method1], comp_per_method[method2] - for performance, color in zip(performance_names, colors): - perf1 = comp1.get_performance()[performance] - perf2 = comp2.get_performance()[performance] - ax.plot(perf2, perf1, ".", label=performance, color=color) - ax.plot([0, 1], [0, 1], "k--", alpha=0.5) - ax.set_ylim(ylim) - ax.set_xlim(ylim) - ax.spines[["right", "top"]].set_visible(False) - ax.set_aspect("equal") - - if j == 0: - ax.set_ylabel(f"{method1}") - else: - ax.set_yticks([]) - if i == num_methods - 1: - ax.set_xlabel(f"{method2}") + if i <= j: + for performance, color in zip(performance_names, colors): + perf1 = comp1.get_performance()[performance] + perf2 = comp2.get_performance()[performance] + ax.plot(perf2, perf1, ".", label=performance, color=color) + + ax.plot([0, 1], [0, 1], "k--", alpha=0.5) + ax.set_ylim(ylim) + ax.set_xlim(ylim) + ax.spines[["right", "top"]].set_visible(False) + ax.set_aspect("equal") + + if j == i: + ax.set_ylabel(f"{method1}") + else: + ax.set_yticks([]) + if i == j: + ax.set_xlabel(f"{method2}") + else: + ax.set_xticks([]) + if i == num_methods - 1 and j == num_methods - 1: + patches = [] + for color, name in zip(colors, performance_names): + patches.append(mpatches.Patch(color=color, label=name)) + ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) else: + ax.spines["bottom"].set_visible(False) + ax.spines["left"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) ax.set_xticks([]) - if i == num_methods - 1 and j == num_methods - 1: - patches = [] - for color, name in zip(colors, performance_names): - patches.append(mpatches.Patch(color=color, label=name)) - ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) + ax.set_yticks([]) plt.tight_layout(h_pad=0, w_pad=0) return fig, axs diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index 1514a63dd4..73497a59fd 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -133,7 +133,7 @@ def run(self, peaks=None, positions=None, delta=0.2): matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000)) self.good_matches = matches["index1"] - garbage_matches = ~np.in1d(np.arange(len(times2)), self.good_matches) + garbage_matches = ~np.isin(np.arange(len(times2)), self.good_matches) garbage_channels = 
self.peaks["channel_index"][garbage_matches] garbage_peaks = times2[garbage_matches] nb_garbage = len(garbage_peaks) @@ -365,7 +365,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["full_gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["full_gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.gt_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.gt_peaks["amplitude"][mask]) ax.scatter(self.gt_positions["x"][mask], self.gt_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.gt_positions["x"][mask].mean(), self.gt_positions["y"][mask].mean()) @@ -391,7 +391,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["gt"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["gt"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.sliced_gt_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.sliced_gt_peaks["amplitude"][mask]) ax.scatter( self.sliced_gt_positions["x"][mask], self.sliced_gt_positions["y"][mask], c=colors, s=1, alpha=0.5 @@ -420,7 +420,7 @@ def plot_clusters_amplitudes(self, title=None, show_probe=False, clim=(-100, 0), idx = self.waveforms["garbage"].get_sampled_indices(unit_id)["spike_index"] all_spikes = self.waveforms["garbage"].sorting.get_unit_spike_train(unit_id) - mask = np.in1d(self.garbage_peaks["sample_index"], all_spikes[idx]) + mask = np.isin(self.garbage_peaks["sample_index"], all_spikes[idx]) colors = scalarMap.to_rgba(self.garbage_peaks["amplitude"][mask]) ax.scatter(self.garbage_positions["x"][mask], self.garbage_positions["y"][mask], c=colors, s=1, alpha=0.5) x_mean, y_mean = (self.garbage_positions["x"][mask].mean(), self.garbage_positions["y"][mask].mean()) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6edf5af16b..b87bbc7cee 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -30,7 +30,7 @@ def _split_waveforms( local_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(cluster_probability > probability_thr) - local_labels_with_noise[~np.in1d(local_labels_with_noise, persistent_clusters)] = -1 + local_labels_with_noise[~np.isin(local_labels_with_noise, persistent_clusters)] = -1 # remove super small cluster labels, count = np.unique(local_labels_with_noise[:valid_size], return_counts=True) @@ -43,7 +43,7 @@ def _split_waveforms( to_remove = labels[(count / valid_size) < minimum_cluster_size_ratio] # ~ print('to_remove', to_remove, count / valid_size) if to_remove.size > 0: - local_labels_with_noise[np.in1d(local_labels_with_noise, to_remove)] = -1 + local_labels_with_noise[np.isin(local_labels_with_noise, to_remove)] = -1 local_labels_with_noise[valid_size:] = -2 @@ -123,7 +123,7 @@ def _split_waveforms_nested( active_labels_with_noise = clustering[0] cluster_probability = clustering[2] (persistent_clusters,) = np.nonzero(clustering[2] > probability_thr) - active_labels_with_noise[~np.in1d(active_labels_with_noise, persistent_clusters)] = -1 + active_labels_with_noise[~np.isin(active_labels_with_noise, persistent_clusters)] = -1 
active_labels = active_labels_with_noise[active_ind < valid_size] active_labels_set = np.unique(active_labels) @@ -381,9 +381,9 @@ def auto_clean_clustering( continue wfs0 = wfs_arrays[label0] - wfs0 = wfs0[:, :, np.in1d(channel_inds0, used_chans)] + wfs0 = wfs0[:, :, np.isin(channel_inds0, used_chans)] wfs1 = wfs_arrays[label1] - wfs1 = wfs1[:, :, np.in1d(channel_inds1, used_chans)] + wfs1 = wfs1[:, :, np.isin(channel_inds1, used_chans)] # TODO : remove assert wfs0.shape[2] == wfs1.shape[2] @@ -536,7 +536,6 @@ def remove_duplicates_via_matching( waveform_extractor, noise_levels, peak_labels, - sparsify_threshold=1, method_kwargs={}, job_kwargs={}, tmp_folder=None, @@ -552,6 +551,10 @@ def remove_duplicates_via_matching( from pathlib import Path job_kwargs = fix_job_kwargs(job_kwargs) + + if waveform_extractor.is_sparse(): + sparsity = waveform_extractor.sparsity.mask + templates = waveform_extractor.get_all_templates(mode="median").copy() nb_templates = len(templates) duration = waveform_extractor.nbefore + waveform_extractor.nafter @@ -559,9 +562,9 @@ def remove_duplicates_via_matching( fs = waveform_extractor.recording.get_sampling_frequency() num_chans = waveform_extractor.recording.get_num_channels() - for t in range(nb_templates): - is_silent = templates[t].ptp(0) < sparsify_threshold - templates[t, :, is_silent] = 0 + if waveform_extractor.is_sparse(): + for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): + templates[count][:, ~sparsity[count]] = 0 zdata = templates.reshape(nb_templates, -1) @@ -581,6 +584,7 @@ def remove_duplicates_via_matching( recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32") recording.annotate(is_filtered=True) + recording = recording.set_probe(waveform_extractor.recording.get_probe()) margin = 2 * max(waveform_extractor.nbefore, waveform_extractor.nafter) half_marging = margin // 2 @@ -597,7 +601,6 @@ def remove_duplicates_via_matching( "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.95, 1.05], - "sparsify_threshold": sparsify_threshold, "omp_min_sps": 0.1, "templates": None, "overlaps": None, diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index fcbcac097f..be8ecd6702 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -41,7 +41,6 @@ class RandomProjectionClustering: "ms_before": 1.5, "ms_after": 1.5, "random_seed": 42, - "cleaning_method": "matching", "shared_memory": False, "min_values": {"ptp": 0, "energy": 0}, "tmp_folder": None, @@ -160,86 +159,60 @@ def main_function(cls, recording, peaks, params): spikes["segment_index"] = peaks[mask]["segment_index"] spikes["unit_index"] = peak_labels[mask] - cleaning_method = params["cleaning_method"] - if verbose: - print("We found %d raw clusters, starting to clean with %s..." 
% (len(labels), cleaning_method)) - - if cleaning_method == "cosine": - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - labels, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - folder=None, - dtype=recording.get_dtype(), - sparsity_mask=None, - copy=True, - **params["job_kwargs"], - ) - - labels, peak_labels = remove_duplicates( - wfs_arrays, noise_levels, peak_labels, num_samples, num_chans, **params["cleaning_kwargs"] - ) - - elif cleaning_method == "dip": - wfs_arrays = {} - for label in labels: - mask = label == peak_labels - wfs_arrays[label] = hdbscan_data[mask] - - labels, peak_labels = remove_duplicates_via_dip(wfs_arrays, peak_labels, **params["cleaning_kwargs"]) - - elif cleaning_method == "matching": - # create a tmp folder - if params["tmp_folder"] is None: - name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - tmp_folder = get_global_tmp_folder() / name - else: - tmp_folder = Path(params["tmp_folder"]) - - if params["shared_memory"]: - waveform_folder = None - mode = "memory" - else: - waveform_folder = tmp_folder / "waveforms" - mode = "folder" - - sorting_folder = tmp_folder / "sorting" - sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs) - sorting = sorting.save(folder=sorting_folder) - we = extract_waveforms( - recording, - sorting, - waveform_folder, - ms_before=params["ms_before"], - ms_after=params["ms_after"], - **params["job_kwargs"], - return_scaled=False, - mode=mode, - ) - - cleaning_matching_params = params["job_kwargs"].copy() - cleaning_matching_params["chunk_duration"] = "100ms" - cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["verbose"] = False - cleaning_matching_params["progress_bar"] = False - - cleaning_params = params["cleaning_kwargs"].copy() - cleaning_params["tmp_folder"] = tmp_folder - - labels, peak_labels = remove_duplicates_via_matching( - we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params - ) - - if params["tmp_folder"] is None: - shutil.rmtree(tmp_folder) - else: + print("We found %d raw clusters, starting to clean with matching..." 
% (len(labels))) + + # create a tmp folder + if params["tmp_folder"] is None: + name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) + tmp_folder = get_global_tmp_folder() / name + else: + tmp_folder = Path(params["tmp_folder"]) + + if params["shared_memory"]: + waveform_folder = None + mode = "memory" + else: + waveform_folder = tmp_folder / "waveforms" + mode = "folder" + + sorting_folder = tmp_folder / "sorting" + unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) + sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) + sorting = sorting.save(folder=sorting_folder) + we = extract_waveforms( + recording, + sorting, + waveform_folder, + ms_before=params["ms_before"], + ms_after=params["ms_after"], + **params["job_kwargs"], + return_scaled=False, + mode=mode, + ) + + cleaning_matching_params = params["job_kwargs"].copy() + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: + if value in cleaning_matching_params: + cleaning_matching_params.pop(value) + cleaning_matching_params["chunk_duration"] = "100ms" + cleaning_matching_params["n_jobs"] = 1 + cleaning_matching_params["verbose"] = False + cleaning_matching_params["progress_bar"] = False + + cleaning_params = params["cleaning_kwargs"].copy() + cleaning_params["tmp_folder"] = tmp_folder + + labels, peak_labels = remove_duplicates_via_matching( + we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + ) + + if params["tmp_folder"] is None: + shutil.rmtree(tmp_folder) + else: + if not params["shared_memory"]: shutil.rmtree(tmp_folder / "waveforms") - shutil.rmtree(tmp_folder / "sorting") + shutil.rmtree(tmp_folder / "sorting") if verbose: print("We kept %d non-duplicated clusters..." % len(labels)) diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index aeec14158f..08ce9f6791 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -198,7 +198,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): for chan_ind in prev_local_chan_inds: if total_count[chan_ind] == 0: continue - # ~ inds, = np.nonzero(np.in1d(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) + # ~ inds, = np.nonzero(np.isin(peaks['channel_index'], closest_channels[chan_ind]) & (peak_labels==0)) (inds,) = np.nonzero((peaks["channel_index"] == chan_ind) & (peak_labels == 0)) if inds.size <= d["min_spike_on_channel"]: chan_amps[chan_ind] = 0.0 @@ -235,12 +235,12 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # TODO: only for debug, remove later - assert np.all(np.in1d(local_chan_inds, wf_chans)) + assert np.all(np.isin(local_chan_inds, wf_chans)) # none label spikes wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, local_chan_inds)] + wfs_chan = wfs_chan[:, :, np.isin(wf_chans, local_chan_inds)] wfs.append(wfs_chan) # put noise to enhance clusters @@ -517,7 +517,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, (wf_chans,) = np.nonzero(sparsity_mask[chan_ind]) # print('wf_chans', wf_chans) # TODO: only for debug, remove later - assert np.all(np.in1d(wanted_chans, wf_chans)) + assert np.all(np.isin(wanted_chans, wf_chans)) wfs_chan = wfs_arrays[chan_ind] # TODO: only for debug, 
remove later @@ -525,7 +525,7 @@ def _collect_sparse_waveforms(peaks, wfs_arrays, closest_channels, peak_labels, wfs_chan = wfs_chan[inds, :, :] # only some channels - wfs_chan = wfs_chan[:, :, np.in1d(wf_chans, wanted_chans)] + wfs_chan = wfs_chan[:, :, np.isin(wf_chans, wanted_chans)] wfs.append(wfs_chan) wfs = np.concatenate(wfs, axis=0) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 2196320378..a19e7b71b5 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -5,7 +5,6 @@ import scipy.spatial -from tqdm import tqdm import scipy try: @@ -16,7 +15,8 @@ except ImportError: HAVE_SKLEARN = False -from spikeinterface.core import get_noise_levels, get_random_data_chunks + +from spikeinterface.core import get_noise_levels, get_random_data_chunks, compute_sparsity from spikeinterface.sortingcomponents.peak_detection import DetectPeakByChannel (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32) @@ -131,6 +131,38 @@ def _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True): return ret +def compute_overlaps(templates, num_samples, num_channels, sparsities): + num_templates = len(templates) + + dense_templates = np.zeros((num_templates, num_samples, num_channels), dtype=np.float32) + for i in range(num_templates): + dense_templates[i, :, sparsities[i]] = templates[i].T + + size = 2 * num_samples - 1 + + all_delays = list(range(0, num_samples + 1)) + + overlaps = {} + + for delay in all_delays: + source = dense_templates[:, :delay, :].reshape(num_templates, -1) + target = dense_templates[:, num_samples - delay :, :].reshape(num_templates, -1) + + overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) + + if delay < num_samples: + overlaps[size - delay + 1] = overlaps[delay].T.tocsr() + + new_overlaps = [] + + for i in range(num_templates): + data = [overlaps[j][i, :].T for j in range(size)] + data = scipy.sparse.hstack(data) + new_overlaps += [data] + + return new_overlaps + + class CircusOMPPeeler(BaseTemplateMatchingEngine): """ Orthogonal Matching Pursuit inspired from Spyking Circus sorter @@ -152,21 +184,18 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): (Minimal, Maximal) amplitudes allowed for every template omp_min_sps: float Stopping criteria of the OMP algorithm, in percentage of the norm - sparsify_threshold: float - Templates are sparsified in order to keep only the channels necessary - to explain. ptp limit for considering a channel as silent - smoothing_factor: float - Templates are smoothed via Spline Interpolation noise_levels: array The noise levels, for every channels. If None, they will be automatically computed random_chunk_kwargs: dict Parameters for computing noise levels, if not provided (sub optimal) + sparse_kwargs: dict + Parameters to extract a sparsity mask from the waveform_extractor, if not + already sparse. 
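# Toy check (illustration, not part of the patch) of the slicing convention in the new
# module-level compute_overlaps() above: for a lag `delay`, template i's first `delay`
# samples are correlated with template j's last `delay` samples, i.e. the dot product of
# the two templates when one is shifted in time. Single-channel toy values, invented.
import numpy as np

num_samples = 4
t_i = np.array([1.0, 2.0, 3.0, 4.0])        # template i, one channel
t_j = np.array([10.0, 20.0, 30.0, 40.0])    # template j, one channel

delay = 2
overlap_ij = np.dot(t_i[:delay], t_j[num_samples - delay:])
print(overlap_ij)   # 1*30 + 2*40 = 110.0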
----- """ _default_params = { - "sparsify_threshold": 1, "amplitudes": [0.6, 2], "omp_min_sps": 0.1, "waveform_extractor": None, @@ -175,36 +204,21 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): "norms": None, "random_chunk_kwargs": {}, "noise_levels": None, - "smoothing_factor": 0.25, + "sparse_kwargs": {"method": "ptp", "threshold": 1}, "ignored_ids": [], + "vicinity": 0, } - @classmethod - def _sparsify_template(cls, template, sparsify_threshold): - is_silent = template.ptp(0) < sparsify_threshold - template[:, is_silent] = 0 - (active_channels,) = np.where(np.logical_not(is_silent)) - - return template, active_channels - - @classmethod - def _regularize_template(cls, template, smoothing_factor=0.25): - nb_channels = template.shape[1] - nb_timesteps = template.shape[0] - xaxis = np.arange(nb_timesteps) - for i in range(nb_channels): - z = scipy.interpolate.UnivariateSpline(xaxis, template[:, i]) - z.set_smoothing_factor(smoothing_factor) - template[:, i] = z(xaxis) - return template - @classmethod def _prepare_templates(cls, d): waveform_extractor = d["waveform_extractor"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] num_templates = len(d["waveform_extractor"].sorting.unit_ids) + if not waveform_extractor.is_sparse(): + sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask + else: + sparsity = waveform_extractor.sparsity.mask + templates = waveform_extractor.get_all_templates(mode="median").copy() d["sparsities"] = {} @@ -212,52 +226,10 @@ def _prepare_templates(cls, d): d["norms"] = np.zeros(num_templates, dtype=np.float32) for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): - if d["smoothing_factor"] > 0: - template = cls._regularize_template(templates[count], d["smoothing_factor"]) - else: - template = templates[count] - template, active_channels = cls._sparsify_template(template, d["sparsify_threshold"]) - d["sparsities"][count] = active_channels + template = templates[count][:, sparsity[count]] + (d["sparsities"][count],) = np.nonzero(sparsity[count]) d["norms"][count] = np.linalg.norm(template) - d["templates"][count] = template[:, active_channels] / d["norms"][count] - - return d - - @classmethod - def _prepare_overlaps(cls, d): - templates = d["templates"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] - num_templates = d["num_templates"] - sparsities = d["sparsities"] - - dense_templates = np.zeros((num_templates, num_samples, num_channels), dtype=np.float32) - for i in range(num_templates): - dense_templates[i, :, sparsities[i]] = templates[i].T - - size = 2 * num_samples - 1 - - all_delays = list(range(0, num_samples + 1)) - - overlaps = {} - - for delay in all_delays: - source = dense_templates[:, :delay, :].reshape(num_templates, -1) - target = dense_templates[:, num_samples - delay :, :].reshape(num_templates, -1) - - overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) - - if delay < num_samples: - overlaps[size - delay + 1] = overlaps[delay].T.tocsr() - - new_overlaps = [] - - for i in range(num_templates): - data = [overlaps[j][i, :].T for j in range(size)] - data = scipy.sparse.hstack(data) - new_overlaps += [data] - - d["overlaps"] = new_overlaps + d["templates"][count] = template / d["norms"][count] return d @@ -276,6 +248,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["nbefore"] = d["waveform_extractor"].nbefore d["nafter"] = d["waveform_extractor"].nafter d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() + 
d["vicinity"] *= d["num_samples"] if d["noise_levels"] is None: print("CircusOMPPeeler : noise should be computed outside") @@ -290,15 +263,12 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["num_templates"] = len(d["templates"]) if d["overlaps"] is None: - d = cls._prepare_overlaps(d) + d["overlaps"] = compute_overlaps(d["templates"], d["num_samples"], d["num_channels"], d["sparsities"]) d["ignored_ids"] = np.array(d["ignored_ids"]) omp_min_sps = d["omp_min_sps"] - norms = d["norms"] - sparsities = d["sparsities"] - - nb_active_channels = np.array([len(sparsities[i]) for i in range(d["num_templates"])]) + # nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) return d @@ -336,6 +306,7 @@ def main_function(cls, traces, d): sparsities = d["sparsities"] ignored_ids = d["ignored_ids"] stop_criteria = d["stop_criteria"] + vicinity = d["vicinity"] if "cached_fft_kernels" not in d: d["cached_fft_kernels"] = {"fshape": 0} @@ -381,7 +352,7 @@ def main_function(cls, traces, d): spikes = np.empty(scalar_products.size, dtype=spike_dtype) idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) - M = np.zeros((num_peaks, num_peaks), dtype=np.float32) + M = np.zeros((100, 100), dtype=np.float32) all_selections = np.empty((2, scalar_products.size), dtype=np.int32) final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) @@ -393,6 +364,8 @@ def main_function(cls, traces, d): cached_overlaps = {} is_valid = scalar_products > stop_criteria + all_amplitudes = np.zeros(0, dtype=np.float32) + is_in_vicinity = np.zeros(0, dtype=np.int32) while np.any(is_valid): best_amplitude_ind = scalar_products[is_valid].argmax() @@ -412,20 +385,39 @@ def main_function(cls, traces, d): M = Z M[num_selection, idx] = cached_overlaps[best_cluster_ind][selection[0, idx], myline] - scipy.linalg.solve_triangular( - M[:num_selection, :num_selection], - M[num_selection, :num_selection], - trans=0, - lower=1, - overwrite_b=True, - check_finite=False, - ) - - v = nrm2(M[num_selection, :num_selection]) ** 2 - Lkk = 1 - v - if Lkk <= omp_tol: # selected atoms are dependent - break - M[num_selection, num_selection] = np.sqrt(Lkk) + + if vicinity == 0: + scipy.linalg.solve_triangular( + M[:num_selection, :num_selection], + M[num_selection, :num_selection], + trans=0, + lower=1, + overwrite_b=True, + check_finite=False, + ) + + v = nrm2(M[num_selection, :num_selection]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] + + if len(is_in_vicinity) > 0: + L = M[is_in_vicinity, :][:, is_in_vicinity] + + M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( + L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False + ) + + v = nrm2(M[num_selection, is_in_vicinity]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + M[num_selection, num_selection] = 1.0 else: M[0, 0] = 1 @@ -435,9 +427,16 @@ def main_function(cls, traces, d): selection = all_selections[:, :num_selection] res_sps = full_sps[selection[0], selection[1]] - all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) - - all_amplitudes /= norms[selection[0]] + if True: # vicinity == 0: + 
all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) + all_amplitudes /= norms[selection[0]] + else: + # This is not working, need to figure out why + is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) + all_amplitudes = np.append(all_amplitudes, np.float32(1)) + L = M[is_in_vicinity, :][:, is_in_vicinity] + all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) + all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] @@ -515,14 +514,12 @@ class CircusPeeler(BaseTemplateMatchingEngine): Maximal amplitude allowed for every template min_amplitude: float Minimal amplitude allowed for every template - sparsify_threshold: float - Templates are sparsified in order to keep only the channels necessary - to explain a given fraction of the total norm use_sparse_matrix_threshold: float If density of the templates is below a given threshold, sparse matrix are used (memory efficient) - progress_bar_steps: bool - In order to display or not steps from the algorithm + sparse_kwargs: dict + Parameters to extract a sparsity mask from the waveform_extractor, if not + already sparse. ----- @@ -535,68 +532,40 @@ class CircusPeeler(BaseTemplateMatchingEngine): "detect_threshold": 5, "noise_levels": None, "random_chunk_kwargs": {}, - "sparsify_threshold": 0.99, "max_amplitude": 1.5, "min_amplitude": 0.5, "use_sparse_matrix_threshold": 0.25, - "progess_bar_steps": False, "waveform_extractor": None, - "smoothing_factor": 0.25, + "sparse_kwargs": {"method": "ptp", "threshold": 1}, } - @classmethod - def _sparsify_template(cls, template, sparsify_threshold, noise_levels): - is_silent = template.std(0) < 0.1 * noise_levels - - template[:, is_silent] = 0 - - channel_norms = np.linalg.norm(template, axis=0) ** 2 - total_norm = np.linalg.norm(template) ** 2 - - idx = np.argsort(channel_norms)[::-1] - explained_norms = np.cumsum(channel_norms[idx] / total_norm) - channel = np.searchsorted(explained_norms, sparsify_threshold) - active_channels = np.sort(idx[:channel]) - template[:, idx[channel:]] = 0 - return template, active_channels - - @classmethod - def _regularize_template(cls, template, smoothing_factor=0.25): - nb_channels = template.shape[1] - nb_timesteps = template.shape[0] - xaxis = np.arange(nb_timesteps) - for i in range(nb_channels): - z = scipy.interpolate.UnivariateSpline(xaxis, template[:, i]) - z.set_smoothing_factor(smoothing_factor) - template[:, i] = z(xaxis) - return template - @classmethod def _prepare_templates(cls, d): - parameters = d - waveform_extractor = parameters["waveform_extractor"] - num_samples = parameters["num_samples"] - num_channels = parameters["num_channels"] - num_templates = parameters["num_templates"] - max_amplitude = parameters["max_amplitude"] - min_amplitude = parameters["min_amplitude"] - use_sparse_matrix_threshold = parameters["use_sparse_matrix_threshold"] + waveform_extractor = d["waveform_extractor"] + num_samples = d["num_samples"] + num_channels = d["num_channels"] + num_templates = d["num_templates"] + use_sparse_matrix_threshold = d["use_sparse_matrix_threshold"] + + d["norms"] = np.zeros(num_templates, dtype=np.float32) - parameters["norms"] = np.zeros(num_templates, dtype=np.float32) + all_units = list(d["waveform_extractor"].sorting.unit_ids) - all_units = 
list(parameters["waveform_extractor"].sorting.unit_ids) + if not waveform_extractor.is_sparse(): + sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask + else: + sparsity = waveform_extractor.sparsity.mask templates = waveform_extractor.get_all_templates(mode="median").copy() + d["sparsities"] = {} + d["circus_templates"] = {} for count, unit_id in enumerate(all_units): - if parameters["smoothing_factor"] > 0: - templates[count] = cls._regularize_template(templates[count], parameters["smoothing_factor"]) - - templates[count], _ = cls._sparsify_template( - templates[count], parameters["sparsify_threshold"], parameters["noise_levels"] - ) - parameters["norms"][count] = np.linalg.norm(templates[count]) - templates[count] /= parameters["norms"][count] + (d["sparsities"][count],) = np.nonzero(sparsity[count]) + templates[count][:, ~sparsity[count]] = 0 + d["norms"][count] = np.linalg.norm(templates[count]) + templates[count] /= d["norms"][count] + d["circus_templates"][count] = templates[count][:, sparsity[count]] templates = templates.reshape(num_templates, -1) @@ -604,54 +573,11 @@ def _prepare_templates(cls, d): if nnz <= use_sparse_matrix_threshold: templates = scipy.sparse.csr_matrix(templates) print(f"Templates are automatically sparsified (sparsity level is {nnz})") - parameters["is_dense"] = False + d["is_dense"] = False else: - parameters["is_dense"] = True - - parameters["templates"] = templates + d["is_dense"] = True - return parameters - - @classmethod - def _prepare_overlaps(cls, d): - templates = d["templates"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] - num_templates = d["num_templates"] - is_dense = d["is_dense"] - - if not is_dense: - dense_templates = templates.toarray() - else: - dense_templates = templates - - dense_templates = dense_templates.reshape(num_templates, num_samples, num_channels) - - size = 2 * num_samples - 1 - - all_delays = list(range(0, num_samples + 1)) - if d["progess_bar_steps"]: - all_delays = tqdm(all_delays, desc="[1] compute overlaps") - - overlaps = {} - - for delay in all_delays: - source = dense_templates[:, :delay, :].reshape(num_templates, -1) - target = dense_templates[:, num_samples - delay :, :].reshape(num_templates, -1) - - overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) - - if delay < num_samples: - overlaps[size - delay] = overlaps[delay].T.tocsr() - - new_overlaps = [] - - for i in range(num_templates): - data = [overlaps[j][i, :].T for j in range(size)] - data = scipy.sparse.hstack(data) - new_overlaps += [data] - - d["overlaps"] = new_overlaps + d["templates"] = templates return d @@ -687,15 +613,13 @@ def _optimize_amplitudes(cls, noise_snippets, d): alpha = 0.5 norms = parameters["norms"] all_units = list(waveform_extractor.sorting.unit_ids) - if parameters["progess_bar_steps"]: - all_units = tqdm(all_units, desc="[2] compute amplitudes") parameters["amplitudes"] = np.zeros((num_templates, 2), dtype=np.float32) noise = templates.dot(noise_snippets) / norms[:, np.newaxis] all_amps = {} for count, unit_id in enumerate(all_units): - waveform = waveform_extractor.get_waveforms(unit_id) + waveform = waveform_extractor.get_waveforms(unit_id, force_dense=True) snippets = waveform.reshape(waveform.shape[0], -1).T amps = templates.dot(snippets) / norms[:, np.newaxis] good = amps[count, :].flatten() @@ -708,16 +632,6 @@ def _optimize_amplitudes(cls, noise_snippets, d): res = scipy.optimize.differential_evolution(cls._cost_function_mcc, bounds=cost_bounds, args=cost_kwargs) 
parameters["amplitudes"][count] = res.x - # import pylab as plt - # plt.hist(good, 100, alpha=0.5) - # plt.hist(bad, 100, alpha=0.5) - # plt.hist(noise[count], 100, alpha=0.5) - # ymin, ymax = plt.ylim() - # plt.plot([res.x[0], res.x[0]], [ymin, ymax], 'k--') - # plt.plot([res.x[1], res.x[1]], [ymin, ymax], 'k--') - # plt.savefig('test_%d.png' %count) - # plt.close() - return d @classmethod @@ -727,8 +641,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): default_parameters.update(kwargs) # assert isinstance(d['waveform_extractor'], WaveformExtractor) - - for v in ["sparsify_threshold", "use_sparse_matrix_threshold"]: + for v in ["use_sparse_matrix_threshold"]: assert (default_parameters[v] >= 0) and (default_parameters[v] <= 1), f"{v} should be in [0, 1]" default_parameters["num_channels"] = default_parameters["waveform_extractor"].recording.get_num_channels() @@ -746,7 +659,13 @@ def initialize_and_check_kwargs(cls, recording, kwargs): ) default_parameters = cls._prepare_templates(default_parameters) - default_parameters = cls._prepare_overlaps(default_parameters) + + default_parameters["overlaps"] = compute_overlaps( + default_parameters["circus_templates"], + default_parameters["num_samples"], + default_parameters["num_channels"], + default_parameters["sparsities"], + ) default_parameters["exclude_sweep_size"] = int( default_parameters["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0 @@ -817,31 +736,31 @@ def main_function(cls, traces, d): sym_patch = d["sym_patch"] peak_traces = traces[margin // 2 : -margin // 2, :] - peak_sample_ind, peak_chan_ind = DetectPeakByChannel.detect_peaks( + peak_sample_index, peak_chan_ind = DetectPeakByChannel.detect_peaks( peak_traces, peak_sign, abs_threholds, exclude_sweep_size ) if jitter > 0: - jittered_peaks = peak_sample_ind[:, np.newaxis] + np.arange(-jitter, jitter) + jittered_peaks = peak_sample_index[:, np.newaxis] + np.arange(-jitter, jitter) jittered_channels = peak_chan_ind[:, np.newaxis] + np.zeros(2 * jitter) mask = (jittered_peaks > 0) & (jittered_peaks < len(peak_traces)) jittered_peaks = jittered_peaks[mask] jittered_channels = jittered_channels[mask] - peak_sample_ind, unique_idx = np.unique(jittered_peaks, return_index=True) + peak_sample_index, unique_idx = np.unique(jittered_peaks, return_index=True) peak_chan_ind = jittered_channels[unique_idx] else: - peak_sample_ind, unique_idx = np.unique(peak_sample_ind, return_index=True) + peak_sample_index, unique_idx = np.unique(peak_sample_index, return_index=True) peak_chan_ind = peak_chan_ind[unique_idx] - num_peaks = len(peak_sample_ind) + num_peaks = len(peak_sample_index) if sym_patch: - snippets = extract_patches_2d(traces, patch_sizes)[peak_sample_ind] - peak_sample_ind += margin // 2 + snippets = extract_patches_2d(traces, patch_sizes)[peak_sample_index] + peak_sample_index += margin // 2 else: - peak_sample_ind += margin // 2 + peak_sample_index += margin // 2 snippet_window = np.arange(-d["nbefore"], d["nafter"]) - snippets = traces[peak_sample_ind[:, np.newaxis] + snippet_window] + snippets = traces[peak_sample_index[:, np.newaxis] + snippet_window] if num_peaks > 0: snippets = snippets.reshape(num_peaks, -1) @@ -865,10 +784,10 @@ def main_function(cls, traces, d): best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) best_amplitude = scalar_products[best_cluster_ind, peak_index] - best_peak_sample_ind = peak_sample_ind[peak_index] + best_peak_sample_index = peak_sample_index[peak_index] 
best_peak_chan_ind = peak_chan_ind[peak_index] - peak_data = peak_sample_ind - peak_sample_ind[peak_index] + peak_data = peak_sample_index - peak_sample_index[peak_index] is_valid_nn = np.searchsorted(peak_data, [-neighbor_window, neighbor_window + 1]) idx_neighbor = peak_data[is_valid_nn[0] : is_valid_nn[1]] + neighbor_window @@ -880,7 +799,7 @@ def main_function(cls, traces, d): scalar_products[:, is_valid_nn[0] : is_valid_nn[1]] += to_add scalar_products[best_cluster_ind, is_valid_nn[0] : is_valid_nn[1]] = -np.inf - spikes["sample_index"][num_spikes] = best_peak_sample_ind + spikes["sample_index"][num_spikes] = best_peak_sample_index spikes["channel_index"][num_spikes] = best_peak_chan_ind spikes["cluster_index"][num_spikes] = best_cluster_ind spikes["amplitude"][num_spikes] = best_amplitude diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index c0dcd7ea6e..9593f14d1c 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,8 +1,3 @@ -# basics -# from .timeseries import plot_timeseries, TracesWidget -from .rasters import plot_rasters, RasterWidget -from .probemap import plot_probe_map, ProbeMapWidget - # isi/ccg/acg from .isidistribution import plot_isi_distribution, ISIDistributionWidget @@ -15,9 +10,6 @@ # units on probe from .unitprobemap import plot_unit_probe_map, UnitProbeMapWidget -# comparison related -from .confusionmatrix import plot_confusion_matrix, ConfusionMatrixWidget -from .agreementmatrix import plot_agreement_matrix, AgreementMatrixWidget from .multicompgraph import ( plot_multicomp_graph, MultiCompGraphWidget, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py index 939475c17d..9715b7ea87 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/activity.py @@ -95,8 +95,7 @@ def plot(self): num_frames = int(duration / self.bin_duration_s) def animate_func(i): - i0 = np.searchsorted(peaks["sample_index"], bin_size * i) - i1 = np.searchsorted(peaks["sample_index"], bin_size * (i + 1)) + i0, i1 = np.searchsorted(peaks["sample_index"], [bin_size * i, bin_size * (i + 1)]) local_peaks = peaks[i0:i1] artists = self._plot_one_bin(rec, probe, local_peaks, self.bin_duration_s) return artists diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py deleted file mode 100644 index 6e6578a4c4..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py +++ /dev/null @@ -1,77 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class ProbeMapWidget(BaseWidget): - """ - Plot the probe of a recording. - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object - channel_ids: list - The channel ids to display - with_channel_ids: bool False default - Add channel ids text on the probe - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. 
If not given an axis is created - **plot_probe_kwargs: keyword arguments for probeinterface.plotting.plot_probe_group() function - - Returns - ------- - W: ProbeMapWidget - The output widget - """ - - def __init__(self, recording, channel_ids=None, with_channel_ids=False, figure=None, ax=None, **plot_probe_kwargs): - import matplotlib.pylab as plt - from probeinterface.plotting import plot_probe, get_auto_lims - - BaseWidget.__init__(self, figure, ax) - - if channel_ids is not None: - recording = recording.channel_slice(channel_ids) - self._recording = recording - self._probegroup = recording.get_probegroup() - self.with_channel_ids = with_channel_ids - self._plot_probe_kwargs = plot_probe_kwargs - - def plot(self): - self._do_plot() - - def _do_plot(self): - from probeinterface.plotting import get_auto_lims - - xlims, ylims, zlims = get_auto_lims(self._probegroup.probes[0]) - for i, probe in enumerate(self._probegroup.probes): - xlims2, ylims2, _ = get_auto_lims(probe) - xlims = min(xlims[0], xlims2[0]), max(xlims[1], xlims2[1]) - ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1]) - - self._plot_probe_kwargs["title"] = False - pos = 0 - text_on_contact = None - for i, probe in enumerate(self._probegroup.probes): - n = probe.get_contact_count() - if self.with_channel_ids: - text_on_contact = self._recording.channel_ids[pos : pos + n] - pos += n - from probeinterface.plotting import plot_probe - - plot_probe(probe, ax=self.ax, text_on_contact=text_on_contact, **self._plot_probe_kwargs) - - self.ax.set_xlim(*xlims) - self.ax.set_ylim(*ylims) - - -def plot_probe_map(*args, **kwargs): - W = ProbeMapWidget(*args, **kwargs) - W.plot() - return W - - -plot_probe_map.__doc__ = ProbeMapWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py deleted file mode 100644 index d05373103e..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py +++ /dev/null @@ -1,120 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class RasterWidget(BaseWidget): - """ - Plots spike train rasters. - - Parameters - ---------- - sorting: SortingExtractor - The sorting extractor object - segment_index: None or int - The segment index. - unit_ids: list - List of unit ids - time_range: list - List with start time and end time - color: matplotlib color - The color to be used - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. 
If not given an axis is created - - Returns - ------- - W: RasterWidget - The output widget - """ - - def __init__(self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", figure=None, ax=None): - from matplotlib import pyplot as plt - - BaseWidget.__init__(self, figure, ax) - self._sorting = sorting - - if segment_index is None: - nseg = sorting.get_num_segments() - if nseg != 1: - raise ValueError("You must provide segment_index=...") - else: - segment_index = 0 - self.segment_index = segment_index - - self._unit_ids = unit_ids - self._figure = None - self._sampling_frequency = sorting.get_sampling_frequency() - self._color = color - self._max_frame = 0 - for unit_id in self._sorting.get_unit_ids(): - spike_train = self._sorting.get_unit_spike_train(unit_id, segment_index=self.segment_index) - if len(spike_train) > 0: - curr_max_frame = np.max(spike_train) - if curr_max_frame > self._max_frame: - self._max_frame = curr_max_frame - self._visible_trange = time_range - if self._visible_trange is None: - self._visible_trange = [0, self._max_frame] - else: - assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" - self._visible_trange = [int(t * self._sampling_frequency) for t in time_range] - - self._visible_trange = self._fix_trange(self._visible_trange) - self.name = "Raster" - - def plot(self): - self._do_plot() - - def _do_plot(self): - units_ids = self._unit_ids - if units_ids is None: - units_ids = self._sorting.get_unit_ids() - import matplotlib.pyplot as plt - - with plt.rc_context({"axes.edgecolor": "gray"}): - for u_i, unit_id in enumerate(units_ids): - spiketrain = self._sorting.get_unit_spike_train( - unit_id, - start_frame=self._visible_trange[0], - end_frame=self._visible_trange[1], - segment_index=self.segment_index, - ) - spiketimes = spiketrain / float(self._sampling_frequency) - self.ax.plot( - spiketimes, - u_i * np.ones_like(spiketimes), - marker="|", - mew=1, - markersize=3, - ls="", - color=self._color, - ) - visible_start_frame = self._visible_trange[0] / self._sampling_frequency - visible_end_frame = self._visible_trange[1] / self._sampling_frequency - self.ax.set_yticks(np.arange(len(units_ids))) - self.ax.set_yticklabels(units_ids) - self.ax.set_xlim(visible_start_frame, visible_end_frame) - self.ax.set_xlabel("time (s)") - - def _fix_trange(self, trange): - if trange[1] > self._max_frame: - # trange[0] += max_t - trange[1] - trange[1] = self._max_frame - if trange[0] < 0: - # trange[1] += -trange[0] - trange[0] = 0 - # trange[0] = np.maximum(0, trange[0]) - # trange[1] = np.minimum(max_t, trange[1]) - return trange - - -def plot_rasters(*args, **kwargs): - W = RasterWidget(*args, **kwargs) - W.plot() - return W - - -plot_rasters.__doc__ = RasterWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 5004765251..39eb80e2e5 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -43,44 +43,6 @@ def setUp(self): def tearDown(self): pass - # def test_timeseries(self): - # sw.plot_timeseries(self._rec, mode='auto') - # sw.plot_timeseries(self._rec, mode='line', show_channel_ids=True) - # sw.plot_timeseries(self._rec, mode='map', show_channel_ids=True) - # sw.plot_timeseries(self._rec, mode='map', show_channel_ids=True, order_channel_by_depth=True) - - def 
test_rasters(self): - sw.plot_rasters(self._sorting) - - def test_plot_probe_map(self): - sw.plot_probe_map(self._rec) - sw.plot_probe_map(self._rec, with_channel_ids=True) - - # TODO - # def test_spectrum(self): - # sw.plot_spectrum(self._rec) - - # TODO - # def test_spectrogram(self): - # sw.plot_spectrogram(self._rec, channel=0) - - # def test_unitwaveforms(self): - # w = sw.plot_unit_waveforms(self._we) - # unit_ids = self._sorting.unit_ids[:6] - # sw.plot_unit_waveforms(self._we, max_channels=5, unit_ids=unit_ids) - # sw.plot_unit_waveforms(self._we, radius_um=60, unit_ids=unit_ids) - - # def test_plot_unit_waveform_density_map(self): - # unit_ids = self._sorting.unit_ids[:3] - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, max_channels=4) - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, radius_um=50) - # - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, radius_um=25, same_axis=True) - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, max_channels=2, same_axis=True) - - # def test_unittemplates(self): - # sw.plot_unit_templates(self._we) - def test_plot_unit_probe_map(self): sw.plot_unit_probe_map(self._we, with_channel_ids=True) sw.plot_unit_probe_map(self._we, animated=True) @@ -120,12 +82,6 @@ def test_plot_peak_activity_map(self): sw.plot_peak_activity_map(self._rec, with_channel_ids=True) sw.plot_peak_activity_map(self._rec, bin_duration_s=1.0) - def test_confusion(self): - sw.plot_confusion_matrix(self._gt_comp, count_text=True) - - def test_agreement(self): - sw.plot_agreement_matrix(self._gt_comp, count_text=True) - def test_multicomp_graph(self): msc = sc.compare_multiple_sorters([self._sorting, self._sorting, self._sorting]) sw.plot_multicomp_graph(msc, edge_cmap="viridis", node_cmap="rainbow", draw_labels=False) @@ -150,8 +106,6 @@ def test_sorting_performance(self): mytest.setUp() # ~ mytest.test_timeseries() - # ~ mytest.test_rasters() - mytest.test_plot_probe_map() # ~ mytest.test_unitwaveforms() # ~ mytest.test_plot_unit_waveform_density_map() # mytest.test_unittemplates() @@ -169,8 +123,6 @@ def test_sorting_performance(self): # ~ mytest.test_plot_drift_over_time() # ~ mytest.test_plot_peak_activity_map() - # mytest.test_confusion() - # mytest.test_agreement() # ~ mytest.test_multicomp_graph() #  mytest.test_sorting_performance() diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py deleted file mode 100644 index ab6fa2ace5..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py +++ /dev/null @@ -1,233 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt -from matplotlib.ticker import MaxNLocator -from .basewidget import BaseWidget - -import scipy.spatial - - -class TracesWidget(BaseWidget): - """ - Plots recording timeseries. - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object - segment_index: None or int - The segment index. - channel_ids: list - The channel ids to display. - order_channel_by_depth: boolean - Reorder channel by depth. 
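# Aside (illustration, not part of the patch), on the activity.py change further above:
# the two np.searchsorted calls in animate_func are collapsed into one by passing both
# bin edges at once, since np.searchsorted accepts an array of query values and returns
# one insertion index per value. Minimal sketch with invented peak samples:
import numpy as np

sample_index = np.array([5, 12, 33, 47, 90, 150])   # sorted peak sample indices
bin_size = 50
i = 1                                               # second animation bin
i0, i1 = np.searchsorted(sample_index, [bin_size * i, bin_size * (i + 1)])
print(sample_index[i0:i1])   # peaks with 50 <= sample < 100 -> [90]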
- time_range: list - List with start time and end time - mode: 'line' or 'map' or 'auto' - 2 possible mode: - * 'line' : classical for low channel count - * 'map' : for high channel count use color heat map - * 'auto' : auto switch depending the channel count <32ch - cmap: str default 'RdBu' - matplotlib colormap used in mode 'map' - show_channel_ids: bool - Set yticks with channel ids - color_groups: bool - If True groups are plotted with different colors - color: matplotlib color, default: None - The color used to draw the traces. - clim: None or tupple - When mode='map' this control color lims - with_colorbar: bool default True - When mode='map' add colorbar - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: TracesWidget - The output widget - """ - - def __init__( - self, - recording, - segment_index=None, - channel_ids=None, - order_channel_by_depth=False, - time_range=None, - mode="auto", - cmap="RdBu", - show_channel_ids=False, - color_groups=False, - color=None, - clim=None, - with_colorbar=True, - figure=None, - ax=None, - **plot_kwargs, - ): - BaseWidget.__init__(self, figure, ax) - self.recording = recording - self._sampling_frequency = recording.get_sampling_frequency() - self.visible_channel_ids = channel_ids - self._plot_kwargs = plot_kwargs - - if segment_index is None: - nseg = recording.get_num_segments() - if nseg != 1: - raise ValueError("You must provide segment_index=...") - segment_index = 0 - self.segment_index = segment_index - - if self.visible_channel_ids is None: - self.visible_channel_ids = recording.get_channel_ids() - - if order_channel_by_depth: - locations = self.recording.get_channel_locations() - channel_inds = self.recording.ids_to_indices(self.visible_channel_ids) - locations = locations[channel_inds, :] - origin = np.array([np.max(locations[:, 0]), np.min(locations[:, 1])])[None, :] - dist = scipy.spatial.distance.cdist(locations, origin, metric="euclidean") - dist = dist[:, 0] - self.order = np.argsort(dist) - else: - self.order = None - - if channel_ids is None: - channel_ids = recording.get_channel_ids() - - fs = recording.get_sampling_frequency() - if time_range is None: - time_range = (0, 1.0) - time_range = np.array(time_range) - - assert mode in ("auto", "line", "map"), "Mode must be in auto/line/map" - if mode == "auto": - if len(channel_ids) <= 64: - mode = "line" - else: - mode = "map" - self.mode = mode - self.cmap = cmap - - self.show_channel_ids = show_channel_ids - - self._frame_range = (time_range * fs).astype("int64") - a_max = self.recording.get_num_frames(segment_index=self.segment_index) - self._frame_range = np.clip(self._frame_range, 0, a_max) - self._time_range = [e / fs for e in self._frame_range] - - self.clim = clim - self.with_colorbar = with_colorbar - - self._initialize_stats() - - # self._vspacing = self._mean_channel_std * 20 - self._vspacing = self._max_channel_amp * 1.5 - - if recording.get_channel_groups() is None: - color_groups = False - - self._color_groups = color_groups - self._color = color - if color_groups: - self._colors = [] - self._group_color_map = {} - all_groups = recording.get_channel_groups() - groups = np.unique(all_groups) - N = len(groups) - import colorsys - - HSV_tuples = [(x * 1.0 / N, 0.5, 0.5) for x in range(N)] - self._colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), HSV_tuples)) - color_idx = 0 - for group in groups: - self._group_color_map[group] = 
color_idx - color_idx += 1 - self.name = "TimeSeries" - - def plot(self): - self._do_plot() - - def _do_plot(self): - chunk0 = self.recording.get_traces( - segment_index=self.segment_index, - channel_ids=self.visible_channel_ids, - start_frame=self._frame_range[0], - end_frame=self._frame_range[1], - ) - if self.order is not None: - chunk0 = chunk0[:, self.order] - self.visible_channel_ids = np.array(self.visible_channel_ids)[self.order] - - ax = self.ax - - n = len(self.visible_channel_ids) - - if self.mode == "line": - ax.set_xlim( - self._frame_range[0] / self._sampling_frequency, self._frame_range[1] / self._sampling_frequency - ) - ax.set_ylim(-self._vspacing, self._vspacing * n) - ax.get_xaxis().set_major_locator(MaxNLocator(prune="both")) - ax.get_yaxis().set_ticks([]) - ax.set_xlabel("time (s)") - - self._plots = {} - self._plot_offsets = {} - offset0 = self._vspacing * (n - 1) - times = np.arange(self._frame_range[0], self._frame_range[1]) / self._sampling_frequency - for im, m in enumerate(self.visible_channel_ids): - self._plot_offsets[m] = offset0 - if self._color_groups: - group = self.recording.get_channel_groups(channel_ids=[m])[0] - group_color_idx = self._group_color_map[group] - color = self._colors[group_color_idx] - else: - color = self._color - self._plots[m] = ax.plot(times, self._plot_offsets[m] + chunk0[:, im], color=color, **self._plot_kwargs) - offset0 = offset0 - self._vspacing - - if self.show_channel_ids: - ax.set_yticks(np.arange(n) * self._vspacing) - ax.set_yticklabels([str(chan_id) for chan_id in self.visible_channel_ids[::-1]]) - - elif self.mode == "map": - extent = (self._time_range[0], self._time_range[1], 0, self.recording.get_num_channels()) - im = ax.imshow( - chunk0.T, interpolation="nearest", origin="upper", aspect="auto", extent=extent, cmap=self.cmap - ) - - if self.clim is None: - im.set_clim(-self._max_channel_amp, self._max_channel_amp) - else: - im.set_clim(*self.clim) - - if self.with_colorbar: - self.figure.colorbar(im, ax=ax) - - if self.show_channel_ids: - ax.set_yticks(np.arange(n) + 0.5) - ax.set_yticklabels([str(chan_id) for chan_id in self.visible_channel_ids[::-1]]) - - def _initialize_stats(self): - chunk0 = self.recording.get_traces( - segment_index=self.segment_index, - channel_ids=self.visible_channel_ids, - start_frame=self._frame_range[0], - end_frame=self._frame_range[1], - ) - - self._mean_channel_std = np.mean(np.std(chunk0, axis=0)) - self._max_channel_amp = np.max(np.max(np.abs(chunk0), axis=0)) - - -def plot_timeseries(*args, **kwargs): - W = TracesWidget(*args, **kwargs) - W.plot() - return W - - -plot_timeseries.__doc__ = TracesWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py b/src/spikeinterface/widgets/agreement_matrix.py similarity index 53% rename from src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py rename to src/spikeinterface/widgets/agreement_matrix.py index 369746e99b..ec6ea1c87c 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py +++ b/src/spikeinterface/widgets/agreement_matrix.py @@ -1,11 +1,13 @@ import numpy as np +from warnings import warn -from .basewidget import BaseWidget +from .base import BaseWidget, to_attr +from .utils import get_unit_colors class AgreementMatrixWidget(BaseWidget): """ - Plots sorting comparison confusion matrix. + Plots sorting comparison agreement matrix. 
Parameters ---------- @@ -19,31 +21,34 @@ class AgreementMatrixWidget(BaseWidget): If True counts are displayed as text unit_ticks: bool If True unit tick labels are displayed - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created + """ - def __init__(self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, figure=None, ax=None): - from matplotlib import pyplot as plt + def __init__( + self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, backend=None, **backend_kwargs + ): + plot_data = dict( + sorting_comparison=sorting_comparison, + ordered=ordered, + count_text=count_text, + unit_ticks=unit_ticks, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) - BaseWidget.__init__(self, figure, ax) - self._sc = sorting_comparison - self._ordered = ordered - self._count_text = count_text - self._unit_ticks = unit_ticks - self.name = "ConfusionMatrix" + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - def plot(self): - self._do_plot() + comp = dp.sorting_comparison - def _do_plot(self): - # a dataframe - if self._ordered: - scores = self._sc.get_ordered_agreement_scores() + if dp.ordered: + scores = comp.get_ordered_agreement_scores() else: - scores = self._sc.agreement_scores + scores = comp.agreement_scores N1 = scores.shape[0] N2 = scores.shape[1] @@ -54,9 +59,9 @@ def _do_plot(self): # Using matshow here just because it sets the ticks up nicely. imshow is faster. self.ax.matshow(scores.values, cmap="Greens") - if self._count_text: + if dp.count_text: for i, u1 in enumerate(unit_ids1): - u2 = self._sc.best_match_12[u1] + u2 = comp.best_match_12[u1] if u2 != -1: j = np.where(unit_ids2 == u2)[0][0] @@ -68,24 +73,15 @@ def _do_plot(self): self.ax.xaxis.tick_bottom() # Labels for major ticks - if self._unit_ticks: + if dp.unit_ticks: self.ax.set_yticklabels(scores.index, fontsize=12) self.ax.set_xticklabels(scores.columns, fontsize=12) - self.ax.set_xlabel(self._sc.name_list[1], fontsize=20) - self.ax.set_ylabel(self._sc.name_list[0], fontsize=20) + self.ax.set_xlabel(comp.name_list[1], fontsize=20) + self.ax.set_ylabel(comp.name_list[0], fontsize=20) self.ax.set_xlim(-0.5, N2 - 0.5) self.ax.set_ylim( N1 - 0.5, -0.5, ) - - -def plot_agreement_matrix(*args, **kwargs): - W = AgreementMatrixWidget(*args, **kwargs) - W.plot() - return W - - -plot_agreement_matrix.__doc__ = AgreementMatrixWidget.__doc__ diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index dea46b8f51..4ed83fcca9 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -39,12 +39,14 @@ def set_default_plotter_backend(backend): "height_cm": "Height of the figure in cm (default 6)", "display": "If True, widgets are immediately displayed", }, + "ephyviewer": {}, } default_backend_kwargs = { "matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None}, "sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None}, "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True}, + "ephyviewer": {}, } diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py b/src/spikeinterface/widgets/confusion_matrix.py similarity index 62% 
rename from src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py rename to src/spikeinterface/widgets/confusion_matrix.py index 942b613fbf..8eb58f30b2 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py +++ b/src/spikeinterface/widgets/confusion_matrix.py @@ -1,6 +1,8 @@ import numpy as np +from warnings import warn -from .basewidget import BaseWidget +from .base import BaseWidget, to_attr +from .utils import get_unit_colors class ConfusionMatrixWidget(BaseWidget): @@ -15,40 +17,35 @@ class ConfusionMatrixWidget(BaseWidget): If True counts are displayed as text unit_ticks: bool If True unit tick labels are displayed - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: ConfusionMatrixWidget - The output widget + """ - def __init__(self, gt_comparison, count_text=True, unit_ticks=True, figure=None, ax=None): - from matplotlib import pyplot as plt + def __init__(self, gt_comparison, count_text=True, unit_ticks=True, backend=None, **backend_kwargs): + plot_data = dict( + gt_comparison=gt_comparison, + count_text=count_text, + unit_ticks=unit_ticks, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure - BaseWidget.__init__(self, figure, ax) - self._gtcomp = gt_comparison - self._count_text = count_text - self._unit_ticks = unit_ticks - self.name = "ConfusionMatrix" + dp = to_attr(data_plot) - def plot(self): - self._do_plot() + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - def _do_plot(self): - # a dataframe - confusion_matrix = self._gtcomp.get_confusion_matrix() + comp = dp.gt_comparison + confusion_matrix = comp.get_confusion_matrix() N1 = confusion_matrix.shape[0] - 1 N2 = confusion_matrix.shape[1] - 1 # Using matshow here just because it sets the ticks up nicely. imshow is faster. 
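# Aside (illustration, not part of the patch): in these ported widgets, plot_matplotlib
# receives the plot_data dict built in __init__ and wraps it with to_attr() so entries can
# be read as attributes (dp.count_text, dp.unit_ticks, ...). The real helper lives in
# spikeinterface.widgets.base; the SimpleNamespace below only illustrates the access pattern.
from types import SimpleNamespace

data_plot = dict(count_text=True, unit_ticks=False)
dp = SimpleNamespace(**data_plot)       # stand-in for to_attr(data_plot)
print(dp.count_text, dp.unit_ticks)     # True False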
self.ax.matshow(confusion_matrix.values, cmap="Greens") - if self._count_text: + if dp.count_text: for (i, j), z in np.ndenumerate(confusion_matrix.values): if z != 0: if z > np.max(confusion_matrix.values) / 2.0: @@ -65,27 +62,18 @@ def _do_plot(self): self.ax.xaxis.tick_bottom() # Labels for major ticks - if self._unit_ticks: + if dp.unit_ticks: self.ax.set_yticklabels(confusion_matrix.index, fontsize=12) self.ax.set_xticklabels(confusion_matrix.columns, fontsize=12) else: self.ax.set_xticklabels(np.append([""] * N2, "FN"), fontsize=10) self.ax.set_yticklabels(np.append([""] * N1, "FP"), fontsize=10) - self.ax.set_xlabel(self._gtcomp.name_list[1], fontsize=20) - self.ax.set_ylabel(self._gtcomp.name_list[0], fontsize=20) + self.ax.set_xlabel(comp.name_list[1], fontsize=20) + self.ax.set_ylabel(comp.name_list[0], fontsize=20) self.ax.set_xlim(-0.5, N2 + 0.5) self.ax.set_ylim( N1 + 0.5, -0.5, ) - - -def plot_confusion_matrix(*args, **kwargs): - W = ConfusionMatrixWidget(*args, **kwargs) - W.plot() - return W - - -plot_confusion_matrix.__doc__ = ConfusionMatrixWidget.__doc__ diff --git a/src/spikeinterface/widgets/probe_map.py b/src/spikeinterface/widgets/probe_map.py new file mode 100644 index 0000000000..7fb74abd7c --- /dev/null +++ b/src/spikeinterface/widgets/probe_map.py @@ -0,0 +1,75 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr, default_backend_kwargs +from .utils import get_unit_colors + + +class ProbeMapWidget(BaseWidget): + """ + Plot the probe of a recording. + + Parameters + ---------- + recording: RecordingExtractor + The recording extractor object + channel_ids: list + The channel ids to display + with_channel_ids: bool False default + Add channel ids text on the probe + **plot_probe_kwargs: keyword arguments for probeinterface.plotting.plot_probe_group() function + + """ + + def __init__( + self, recording, channel_ids=None, with_channel_ids=False, backend=None, **backend_or_plot_probe_kwargs + ): + # split backend_or_plot_probe_kwargs + backend_kwargs = dict() + plot_probe_kwargs = dict() + backend = self.check_backend(backend) + for k, v in backend_or_plot_probe_kwargs.items(): + if k in default_backend_kwargs[backend]: + backend_kwargs[k] = v + else: + plot_probe_kwargs[k] = v + + plot_data = dict( + recording=recording, + channel_ids=channel_ids, + with_channel_ids=with_channel_ids, + plot_probe_kwargs=plot_probe_kwargs, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from probeinterface.plotting import get_auto_lims, plot_probe + + dp = to_attr(data_plot) + + plot_probe_kwargs = dp.plot_probe_kwargs + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + probegroup = dp.recording.get_probegroup() + + xlims, ylims, zlims = get_auto_lims(probegroup.probes[0]) + for i, probe in enumerate(probegroup.probes): + xlims2, ylims2, _ = get_auto_lims(probe) + xlims = min(xlims[0], xlims2[0]), max(xlims[1], xlims2[1]) + ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1]) + + plot_probe_kwargs["title"] = False + pos = 0 + text_on_contact = None + for i, probe in enumerate(probegroup.probes): + n = probe.get_contact_count() + if dp.with_channel_ids: + text_on_contact = dp.recording.channel_ids[pos : pos + n] + pos += n + plot_probe(probe, ax=self.ax, text_on_contact=text_on_contact, **plot_probe_kwargs) + + self.ax.set_xlim(*xlims) + 
diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py
new file mode 100644
index 0000000000..4a1d76279f
--- /dev/null
+++ b/src/spikeinterface/widgets/rasters.py
@@ -0,0 +1,84 @@
+import numpy as np
+from warnings import warn
+
+from .base import BaseWidget, to_attr, default_backend_kwargs
+
+
+class RasterWidget(BaseWidget):
+    """
+    Plots spike train rasters.
+
+    Parameters
+    ----------
+    sorting: SortingExtractor
+        The sorting extractor object
+    segment_index: None or int
+        The segment index.
+    unit_ids: list
+        List of unit ids
+    time_range: list
+        List with start time and end time
+    color: matplotlib color
+        The color to be used
+    """
+
+    def __init__(
+        self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", backend=None, **backend_kwargs
+    ):
+        if segment_index is None:
+            if sorting.get_num_segments() != 1:
+                raise ValueError("You must provide segment_index=...")
+            segment_index = 0
+
+        if time_range is None:
+            frame_range = [0, sorting.to_spike_vector()[-1]["sample_index"]]
+            time_range = [f / sorting.sampling_frequency for f in frame_range]
+        else:
+            assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds"
+            frame_range = [int(t * sorting.sampling_frequency) for t in time_range]
+
+        plot_data = dict(
+            sorting=sorting,
+            segment_index=segment_index,
+            unit_ids=unit_ids,
+            color=color,
+            frame_range=frame_range,
+            time_range=time_range,
+        )
+        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .utils_matplotlib import make_mpl_figure
+
+        dp = to_attr(data_plot)
+        sorting = dp.sorting
+
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+        units_ids = dp.unit_ids
+        if units_ids is None:
+            units_ids = sorting.unit_ids
+
+        with plt.rc_context({"axes.edgecolor": "gray"}):
+            for unit_index, unit_id in enumerate(units_ids):
+                spiketrain = sorting.get_unit_spike_train(
+                    unit_id,
+                    start_frame=dp.frame_range[0],
+                    end_frame=dp.frame_range[1],
+                    segment_index=dp.segment_index,
+                )
+                spiketimes = spiketrain / float(sorting.sampling_frequency)
+                self.ax.plot(
+                    spiketimes,
+                    unit_index * np.ones_like(spiketimes),
+                    marker="|",
+                    mew=1,
+                    markersize=3,
+                    ls="",
+                    color=dp.color,
+                )
+        self.ax.set_yticks(np.arange(len(units_ids)))
+        self.ax.set_yticklabels(units_ids)
+        self.ax.set_xlim(*dp.time_range)
+        self.ax.set_xlabel("time (s)")
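# --- Editor's usage sketch (not part of the diff above): a raster plot of a toy sorting,
# --- restricted to the first two seconds. time_range is in seconds; omit it to plot everything.
import matplotlib.pyplot as plt
import spikeinterface.widgets as sw
from spikeinterface.core import generate_sorting

sorting = generate_sorting(num_units=8, durations=[10.0], seed=0)
sw.plot_rasters(sorting, time_range=[0.0, 2.0], color="k", backend="matplotlib")
plt.show()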
diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py
index a5f75ebf50..f44878927d 100644
--- a/src/spikeinterface/widgets/tests/test_widgets.py
+++ b/src/spikeinterface/widgets/tests/test_widgets.py
@@ -72,7 +72,7 @@ def setUpClass(cls):
         else:
             cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius)
 
-        cls.skip_backends = ["ipywidgets"]
+        cls.skip_backends = ["ipywidgets", "ephyviewer"]
 
         if ON_GITHUB and not KACHERY_CLOUD_SET:
             cls.skip_backends.append("sortingview")
@@ -324,6 +324,30 @@ def test_sorting_summary(self):
                 sw.plot_sorting_summary(self.we, backend=backend, **self.backend_kwargs[backend])
                 sw.plot_sorting_summary(self.we_sparse, backend=backend, **self.backend_kwargs[backend])
 
+    def test_plot_agreement_matrix(self):
+        possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends())
+        for backend in possible_backends:
+            if backend not in self.skip_backends:
+                sw.plot_agreement_matrix(self.gt_comp)
+
+    def test_plot_confusion_matrix(self):
+        possible_backends = list(sw.ConfusionMatrixWidget.get_possible_backends())
+        for backend in possible_backends:
+            if backend not in self.skip_backends:
+                sw.plot_confusion_matrix(self.gt_comp)
+
+    def test_plot_probe_map(self):
+        possible_backends = list(sw.ProbeMapWidget.get_possible_backends())
+        for backend in possible_backends:
+            if backend not in self.skip_backends:
+                sw.plot_probe_map(self.recording, with_channel_ids=True, with_contact_id=True)
+
+    def test_plot_rasters(self):
+        possible_backends = list(sw.RasterWidget.get_possible_backends())
+        for backend in possible_backends:
+            if backend not in self.skip_backends:
+                sw.plot_rasters(self.sorting)
+
 
 if __name__ == "__main__":
     # unittest.main()
@@ -344,7 +368,11 @@ def test_sorting_summary(self):
     # mytest.test_unit_locations()
     # mytest.test_quality_metrics()
     # mytest.test_template_metrics()
-    mytest.test_amplitudes()
+    # mytest.test_amplitudes()
+    # mytest.test_plot_agreement_matrix()
+    # mytest.test_plot_confusion_matrix()
+    # mytest.test_plot_probe_map()
+    mytest.test_plot_rasters()
 
     # plt.ion()
     plt.show()
diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py
index e025f779c1..7bb2126744 100644
--- a/src/spikeinterface/widgets/traces.py
+++ b/src/spikeinterface/widgets/traces.py
@@ -524,6 +524,30 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
 
         self.url = handle_display_and_url(self, self.view, **backend_kwargs)
 
+    def plot_ephyviewer(self, data_plot, **backend_kwargs):
+        import ephyviewer
+        from ..preprocessing import depth_order
+
+        dp = to_attr(data_plot)
+
+        app = ephyviewer.mkQApp()
+        win = ephyviewer.MainViewer(debug=False, show_auto_scale=True)
+
+        for k, rec in dp.recordings.items():
+            if dp.order_channel_by_depth:
+                rec = depth_order(rec, flip=True)
+
+            sig_source = ephyviewer.SpikeInterfaceRecordingSource(recording=rec)
+            view = ephyviewer.TraceViewer(source=sig_source, name=k)
+            view.params["scale_mode"] = "by_channel"
+            if dp.show_channel_ids:
+                view.params["display_labels"] = True
+            view.auto_scale()
+            win.add_view(view)
+
+        win.show()
+        app.exec()
+
 
 def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False):
     # function also used in ipywidgets plotter
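# --- Editor's usage sketch (not part of the diff above): the new "ephyviewer" backend for
# --- plot_traces. It requires the optional ephyviewer package and a desktop Qt session, since
# --- it opens a blocking viewer window; that is also why the backend is added to
# --- skip_backends in the tests above.
import spikeinterface.widgets as sw
from spikeinterface.core import generate_recording

recording = generate_recording(num_channels=16, durations=[10.0])
sw.plot_traces(recording, order_channel_by_depth=True, show_channel_ids=True, backend="ephyviewer")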
diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py
index e8a6868e92..b3391c0712 100644
--- a/src/spikeinterface/widgets/unit_waveforms_density_map.py
+++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py
@@ -103,7 +103,7 @@ def __init__(
             if same_axis and not np.array_equal(chan_inds, shared_chan_inds):
                 # add more channels if necessary
                 wfs_ = np.zeros((wfs.shape[0], wfs.shape[1], shared_chan_inds.size), dtype=float)
-                mask = np.in1d(shared_chan_inds, chan_inds)
+                mask = np.isin(shared_chan_inds, chan_inds)
                 wfs_[:, :, mask] = wfs
                 wfs_[:, :, ~mask] = np.nan
                 wfs = wfs_
diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py
index 9c89b3981e..6ea2593432 100644
--- a/src/spikeinterface/widgets/widget_list.py
+++ b/src/spikeinterface/widgets/widget_list.py
@@ -2,12 +2,16 @@
 
 from .base import backend_kwargs_desc
 
+from .agreement_matrix import AgreementMatrixWidget
 from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget
 from .amplitudes import AmplitudesWidget
 from .autocorrelograms import AutoCorrelogramsWidget
+from .confusion_matrix import ConfusionMatrixWidget
 from .crosscorrelograms import CrossCorrelogramsWidget
 from .motion import MotionWidget
+from .probe_map import ProbeMapWidget
 from .quality_metrics import QualityMetricsWidget
+from .rasters import RasterWidget
 from .sorting_summary import SortingSummaryWidget
 from .spike_locations import SpikeLocationsWidget
 from .spikes_on_traces import SpikesOnTracesWidget
@@ -23,12 +27,16 @@
 
 
 widget_list = [
+    AgreementMatrixWidget,
     AllAmplitudesDistributionsWidget,
     AmplitudesWidget,
     AutoCorrelogramsWidget,
+    ConfusionMatrixWidget,
     CrossCorrelogramsWidget,
     MotionWidget,
+    ProbeMapWidget,
     QualityMetricsWidget,
+    RasterWidget,
     SortingSummaryWidget,
     SpikeLocationsWidget,
     SpikesOnTracesWidget,
@@ -76,12 +84,16 @@
 
 
 # make function for all widgets
+plot_agreement_matrix = AgreementMatrixWidget
 plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget
 plot_amplitudes = AmplitudesWidget
 plot_autocorrelograms = AutoCorrelogramsWidget
+plot_confusion_matrix = ConfusionMatrixWidget
 plot_crosscorrelograms = CrossCorrelogramsWidget
 plot_motion = MotionWidget
+plot_probe_map = ProbeMapWidget
 plot_quality_metrics = QualityMetricsWidget
+plot_rasters = RasterWidget
 plot_sorting_summary = SortingSummaryWidget
 plot_spike_locations = SpikeLocationsWidget
 plot_spikes_on_traces = SpikesOnTracesWidget