Merge branch 'main' into read-ks-docstring
alejoe91 authored Sep 21, 2023
2 parents e3cb9bb + df0504c commit b78ff8c
Showing 30 changed files with 1,206 additions and 1,346 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
@@ -217,8 +217,10 @@ spikeinterface.sorters
.. autofunction:: print_sorter_versions
.. autofunction:: get_sorter_description
.. autofunction:: run_sorter
.. autofunction:: run_sorter_jobs
.. autofunction:: run_sorters
.. autofunction:: run_sorter_by_property
.. autofunction:: read_sorter_folder

Low level
~~~~~~~~~
Binary file added doc/images/plot_traces_ephyviewer.png
37 changes: 17 additions & 20 deletions doc/modules/sorters.rst
@@ -285,27 +285,26 @@ Running several sorters in parallel

The :py:mod:`~spikeinterface.sorters` module also includes tools to run several spike sorting jobs
sequentially or in parallel. This can be done with the
:py:func:`~spikeinterface.sorters.run_sorters()` function by specifying
:py:func:`~spikeinterface.sorters.run_sorter_jobs()` function by specifying
an :code:`engine` that supports parallel processing (such as :code:`joblib` or :code:`slurm`).

.. code-block:: python

    recordings = {'rec1' : recording, 'rec2': another_recording}
    sorter_list = ['herdingspikes', 'tridesclous']
    sorter_params = {
        'herdingspikes': {'clustering_bandwidth' : 8},
        'tridesclous': {'detect_threshold' : 5.},
    }
    sorting_output = run_sorters(sorter_list, recordings, working_folder='tmp_some_sorters',
                                 mode_if_folder_exists='overwrite', sorter_params=sorter_params)

    # here we run 2 sorters on 2 different recordings = 4 jobs
    recording = ...
    another_recording = ...

    job_list = [
        {'sorter_name': 'tridesclous', 'recording': recording, 'output_folder': 'folder1', 'detect_threshold': 5.},
        {'sorter_name': 'tridesclous', 'recording': another_recording, 'output_folder': 'folder2', 'detect_threshold': 5.},
        {'sorter_name': 'herdingspikes', 'recording': recording, 'output_folder': 'folder3', 'clustering_bandwidth': 8., 'docker_image': True},
        {'sorter_name': 'herdingspikes', 'recording': another_recording, 'output_folder': 'folder4', 'clustering_bandwidth': 8., 'docker_image': True},
    ]

    # run in loop
    sortings = run_sorter_jobs(job_list, engine='loop')

    # the output is a dict with (rec_name, sorter_name) as keys
    for (rec_name, sorter_name), sorting in sorting_output.items():
        print(rec_name, sorter_name, ':', sorting.get_unit_ids())
After the jobs are run, :code:`sorting_output` is a dictionary with :code:`(rec_name, sorter_name)` keys (e.g.
:code:`('rec1', 'tridesclous')` in this example) and the corresponding :py:class:`~spikeinterface.core.BaseSorting`
objects as values.
:py:func:`~spikeinterface.sorters.run_sorters` has several "engines" available to launch the computation:

@@ -315,13 +314,11 @@

.. code-block:: python

    run_sorters(sorter_list, recordings, engine='loop')
    run_sorter_jobs(job_list, engine='loop')

    run_sorters(sorter_list, recordings, engine='joblib',
                engine_kwargs={'n_jobs': 2})
    run_sorter_jobs(job_list, engine='joblib', engine_kwargs={'n_jobs': 2})

    run_sorters(sorter_list, recordings, engine='slurm',
                engine_kwargs={'cpus_per_task': 10, 'mem': '5G'})
    run_sorter_jobs(job_list, engine='slurm', engine_kwargs={'cpus_per_task': 10, 'mem': '5G'})
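
With an asynchronous engine such as :code:`slurm`, the sorting results are not returned
directly. As a minimal sketch (assuming the jobs above completed successfully and wrote to
the :code:`output_folder` values given in :code:`job_list`), each result can be loaded back
afterwards with :py:func:`~spikeinterface.sorters.read_sorter_folder`:

.. code-block:: python

    from spikeinterface.sorters import read_sorter_folder

    # load the sorting produced by the job that wrote to 'folder1'
    sorting_tdc = read_sorter_folder('folder1')
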
Spike sorting by group
42 changes: 41 additions & 1 deletion doc/modules/widgets.rst
@@ -14,6 +14,9 @@ Since version 0.95.0, the :py:mod:`spikeinterface.widgets` module supports multiple backends:
* | :code:`sortingview`: web-based and interactive rendering using the `sortingview <https://github.com/magland/sortingview>`_
| and `FIGURL <https://github.com/flatironinstitute/figurl>`_ packages.
Version 0.100.0 also comes with this new backend (selected, like the others, via the :code:`backend` argument, as sketched below):

* | :code:`ephyviewer`: interactive, Qt-based rendering using the `ephyviewer <https://ephyviewer.readthedocs.io/en/latest/>`_ package
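
A minimal sketch of backend selection, assuming a :code:`recording` object is already loaded:

.. code-block:: python

    import spikeinterface.widgets as sw

    # the same widget rendered by two different backends
    sw.plot_traces(recording, backend="matplotlib")  # static figure
    sw.plot_traces(recording, backend="ipywidgets")  # interactive, in Jupyter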


Installing backends
-------------------
@@ -85,6 +88,28 @@ Finally, if you wish to set up another cloud provider, follow the instruction from the
`kachery-cloud <https://github.com/flatironinstitute/kachery-cloud>`_ package ("Using your own storage bucket").


ephyviewer
^^^^^^^^^^

This backend is Qt-based, with support for PyQt5, PyQt6, or PySide6. Qt can be tedious to install.


For a pip-based installation, run:

.. code-block:: bash

    pip install PySide6 ephyviewer
Anaconda users will have a better experience with this:

.. code-block:: bash

    conda install pyqt=5
    pip install ephyviewer
Usage
-----

@@ -215,6 +240,21 @@ For example, here is how to combine the timeseries and sorting summary generated
print(url)
ephyviewer
^^^^^^^^^^


The :code:`ephyviewer` backend is currently only available for the :py:func:`~spikeinterface.widgets.plot_traces()` function.


.. code-block:: python

    plot_traces(recording, backend="ephyviewer", mode="line", show_channel_ids=True)
.. image:: ../images/plot_traces_ephyviewer.png



Available plotting functions
----------------------------
@@ -229,7 +269,7 @@ Available plotting functions
* :py:func:`~spikeinterface.widgets.plot_spikes_on_traces` (backends: :code:`matplotlib`, :code:`ipywidgets`)
* :py:func:`~spikeinterface.widgets.plot_template_metrics` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`)
* :py:func:`~spikeinterface.widgets.plot_template_similarity` (backends: :code:`matplotlib`, :code:`sortingview`)
* :py:func:`~spikeinterface.widgets.plot_timeseries` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`)
* :py:func:`~spikeinterface.widgets.plot_traces` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`, :code:`ephyviewer`)
* :py:func:`~spikeinterface.widgets.plot_unit_depths` (backends: :code:`matplotlib`)
* :py:func:`~spikeinterface.widgets.plot_unit_locations` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`)
* :py:func:`~spikeinterface.widgets.plot_unit_summary` (backends: :code:`matplotlib`)
35 changes: 34 additions & 1 deletion src/spikeinterface/comparison/studytools.py
@@ -22,12 +22,45 @@
from spikeinterface.core.job_tools import fix_job_kwargs
from spikeinterface.extractors import NpzSortingExtractor
from spikeinterface.sorters import sorter_dict
from spikeinterface.sorters.launcher import iter_working_folder, iter_sorting_output
from spikeinterface.sorters.basesorter import is_log_ok


from .comparisontools import _perf_keys
from .paircomparisons import compare_sorter_to_ground_truth


# This is deprecated and will be removed
def iter_working_folder(working_folder):
    working_folder = Path(working_folder)
    for rec_folder in working_folder.iterdir():
        if not rec_folder.is_dir():
            continue
        for output_folder in rec_folder.iterdir():
            if (output_folder / "spikeinterface_job.json").is_file():
                with open(output_folder / "spikeinterface_job.json", "r") as f:
                    job_dict = json.load(f)
                rec_name = job_dict["rec_name"]
                sorter_name = job_dict["sorter_name"]
                yield rec_name, sorter_name, output_folder
            else:
                rec_name = rec_folder.name
                sorter_name = output_folder.name
                if not output_folder.is_dir():
                    continue
                if not is_log_ok(output_folder):
                    continue
                yield rec_name, sorter_name, output_folder


# This is deprecated and will be removed
def iter_sorting_output(working_folder):
    """Iterator over output_folder to retrieve all triplets of (rec_name, sorter_name, sorting)."""
    for rec_name, sorter_name, output_folder in iter_working_folder(working_folder):
        SorterClass = sorter_dict[sorter_name]
        sorting = SorterClass.get_result_from_folder(output_folder)
        yield rec_name, sorter_name, sorting


def setup_comparison_study(study_folder, gt_dict, **job_kwargs):
    """
    Based on a dict of (recording, sorting) create the study folder.
109 changes: 98 additions & 11 deletions src/spikeinterface/core/sparsity.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np

from .recording_tools import get_channel_distances, get_noise_levels
@@ -33,7 +35,9 @@

class ChannelSparsity:
    """
    Handle channel sparsity for a set of units.
    Handle channel sparsity for a set of units. That is, for every unit,
    it indicates which channels are used to represent the waveform and the rest
    of the non-represented channels are assumed to be zero.

    Internally, sparsity is stored as a boolean mask.
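
# A minimal hand-built sparsity (an illustrative sketch, not part of this diff),
# assuming 3 units and 4 channels, to show the boolean mask described above.
import numpy as np
from spikeinterface.core import ChannelSparsity

unit_ids = np.array(["u0", "u1", "u2"])
channel_ids = np.array([0, 1, 2, 3])
# mask[i, j] is True when channel j is used to represent unit i
mask = np.array(
    [
        [True, True, False, False],
        [False, True, True, False],
        [False, False, True, True],
    ]
)
sparsity = ChannelSparsity(mask, unit_ids, channel_ids)
print(sparsity)  # ChannelSparsity - units: 3 - channels: 4 - density, P(x=1): 0.50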
@@ -92,13 +96,17 @@ def __init__(self, mask, unit_ids, channel_ids):
        assert self.mask.shape[0] == self.unit_ids.shape[0]
        assert self.mask.shape[1] == self.channel_ids.shape[0]

        # some precomputed dict
        # Those are computed at first call
        self._unit_id_to_channel_ids = None
        self._unit_id_to_channel_indices = None

        self.num_channels = self.channel_ids.size
        self.num_units = self.unit_ids.size
        self.max_num_active_channels = self.mask.sum(axis=1).max()

    def __repr__(self):
        ratio = np.mean(self.mask)
        txt = f"ChannelSparsity - units: {self.unit_ids.size} - channels: {self.channel_ids.size} - ratio: {ratio:0.2f}"
        density = np.mean(self.mask)
        txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}"
        return txt

    @property
@@ -119,6 +127,85 @@ def unit_id_to_channel_indices(self):
            self._unit_id_to_channel_indices[unit_id] = channel_inds
        return self._unit_id_to_channel_indices

    def sparsify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray:
        """
        Sparsify the waveforms according to the sparsity of the given unit_id.

        Given a unit_id, this method selects only the active channels for
        that unit and removes the rest.

        Parameters
        ----------
        waveforms : np.array
            Dense waveforms with shape (num_waveforms, num_samples, num_channels) or a
            single dense waveform (template) with shape (num_samples, num_channels).
        unit_id : str
            The unit_id for which to sparsify the waveform.

        Returns
        -------
        sparsified_waveforms : np.array
            Sparse waveforms with shape (num_waveforms, num_samples, num_active_channels)
            or a single sparsified waveform (template) with shape (num_samples, num_active_channels).
        """

        assert_msg = (
            "Waveforms must be dense to sparsify them. "
            f"Their last dimension {waveforms.shape[-1]} must be equal to the number of channels {self.num_channels}"
        )
        assert self.are_waveforms_dense(waveforms=waveforms), assert_msg

        non_zero_indices = self.unit_id_to_channel_indices[unit_id]
        sparsified_waveforms = waveforms[..., non_zero_indices]

        return sparsified_waveforms

    def densify_waveforms(self, waveforms: np.ndarray, unit_id: str | int) -> np.ndarray:
        """
        Densify sparse waveforms that were sparsified according to a unit's channel sparsity.

        Given a unit_id and its sparsified waveform, this method places the waveform back
        into its original form within a dense array.

        Parameters
        ----------
        waveforms : np.array
            The sparsified waveforms array of shape (num_waveforms, num_samples, num_active_channels) or a single
            sparse waveform (template) with shape (num_samples, num_active_channels).
        unit_id : str
            The unit_id that was used to sparsify the waveform.

        Returns
        -------
        densified_waveforms : np.array
            The densified waveforms array of shape (num_waveforms, num_samples, num_channels) or a single dense
            waveform (template) with shape (num_samples, num_channels).
        """

        non_zero_indices = self.unit_id_to_channel_indices[unit_id]

        assert_msg = (
            "Waveforms do not seem to be in the sparsity shape of this unit_id. The number of active channels is "
            f"{len(non_zero_indices)} but the waveform has {waveforms.shape[-1]} active channels."
        )
        assert self.are_waveforms_sparse(waveforms=waveforms, unit_id=unit_id), assert_msg

        densified_shape = waveforms.shape[:-1] + (self.num_channels,)
        densified_waveforms = np.zeros(densified_shape, dtype=waveforms.dtype)
        densified_waveforms[..., non_zero_indices] = waveforms

        return densified_waveforms

    def are_waveforms_dense(self, waveforms: np.ndarray) -> bool:
        return waveforms.shape[-1] == self.num_channels

    def are_waveforms_sparse(self, waveforms: np.ndarray, unit_id: str | int) -> bool:
        non_zero_indices = self.unit_id_to_channel_indices[unit_id]
        num_active_channels = len(non_zero_indices)
        return waveforms.shape[-1] == num_active_channels
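
# A round-trip sketch for the helpers above (illustrative, not part of this diff),
# reusing the hand-built `sparsity` from the earlier sketch.
rng = np.random.default_rng(seed=0)
dense = rng.standard_normal((10, 50, sparsity.num_channels)).astype("float32")

sparse = sparsity.sparsify_waveforms(dense, unit_id="u0")
assert sparse.shape[-1] == len(sparsity.unit_id_to_channel_indices["u0"])

# densifying restores the dense shape; inactive channels are filled with zeros
round_trip = sparsity.densify_waveforms(sparse, unit_id="u0")
assert round_trip.shape == dense.shape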

    @classmethod
    def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids):
        """
@@ -144,16 +231,16 @@ def to_dict(self):
        )

    @classmethod
    def from_dict(cls, d):
    def from_dict(cls, dictionary: dict):
        unit_id_to_channel_ids_corrected = {}
        for unit_id in d["unit_ids"]:
            if unit_id in d["unit_id_to_channel_ids"]:
                unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][unit_id]
        for unit_id in dictionary["unit_ids"]:
            if unit_id in dictionary["unit_id_to_channel_ids"]:
                unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][unit_id]
            else:
                unit_id_to_channel_ids_corrected[unit_id] = d["unit_id_to_channel_ids"][str(unit_id)]
        d["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected
                unit_id_to_channel_ids_corrected[unit_id] = dictionary["unit_id_to_channel_ids"][str(unit_id)]
        dictionary["unit_id_to_channel_ids"] = unit_id_to_channel_ids_corrected

        return cls.from_unit_id_to_channel_ids(**d)
        return cls.from_unit_id_to_channel_ids(**dictionary)
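
# Serialization round-trip sketch (illustrative, not part of this diff), assuming
# to_dict() produces the keys that from_dict() consumes, as the methods above suggest.
sparsity_dict = sparsity.to_dict()
sparsity_restored = ChannelSparsity.from_dict(sparsity_dict)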

    ## Some convenient functions to compute sparsity from several strategies
    @classmethod