Skip to content

Commit

Permalink
Merge pull request #1968 from samuelgarcia/run_sorter_jobs
Browse files Browse the repository at this point in the history
Refactor sorter launcher. Deprecated `run_sorters()` and add `run_sorter_jobs()`
  • Loading branch information
samuelgarcia authored Sep 19, 2023
2 parents 2aabd14 + 60e8989 commit 30a6613
Show file tree
Hide file tree
Showing 7 changed files with 445 additions and 435 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,10 @@ spikeinterface.sorters
.. autofunction:: print_sorter_versions
.. autofunction:: get_sorter_description
.. autofunction:: run_sorter
.. autofunction:: run_sorter_jobs
.. autofunction:: run_sorters
.. autofunction:: run_sorter_by_property
.. autofunction:: read_sorter_folder

Low level
~~~~~~~~~
Expand Down
37 changes: 17 additions & 20 deletions doc/modules/sorters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -285,27 +285,26 @@ Running several sorters in parallel

The :py:mod:`~spikeinterface.sorters` module also includes tools to run several spike sorting jobs
sequentially or in parallel. This can be done with the
:py:func:`~spikeinterface.sorters.run_sorters()` function by specifying
:py:func:`~spikeinterface.sorters.run_sorter_jobs()` function by specifying
an :code:`engine` that supports parallel processing (such as :code:`joblib` or :code:`slurm`).

.. code-block:: python
recordings = {'rec1' : recording, 'rec2': another_recording}
sorter_list = ['herdingspikes', 'tridesclous']
sorter_params = {
'herdingspikes': {'clustering_bandwidth' : 8},
'tridesclous': {'detect_threshold' : 5.},
}
sorting_output = run_sorters(sorter_list, recordings, working_folder='tmp_some_sorters',
mode_if_folder_exists='overwrite', sorter_params=sorter_params)
# here we run 2 sorters on 2 different recordings = 4 jobs
recording = ...
another_recording = ...
job_list = [
{'sorter_name': 'tridesclous', 'recording': recording, 'output_folder': 'folder1','detect_threshold': 5.},
{'sorter_name': 'tridesclous', 'recording': another_recording, 'output_folder': 'folder2', 'detect_threshold': 5.},
{'sorter_name': 'herdingspikes', 'recording': recording, 'output_folder': 'folder3', 'clustering_bandwidth': 8., 'docker_image': True},
{'sorter_name': 'herdingspikes', 'recording': another_recording, 'output_folder': 'folder4', 'clustering_bandwidth': 8., 'docker_image': True},
]
# run in loop
sortings = run_sorter_jobs(job_list, engine='loop')
# the output is a list of sortings, one per job, in the same order as job_list
for sorting in sortings:
    print(sorting.get_unit_ids())
After the jobs are run, :code:`sortings` is a list of :py:class:`~spikeinterface.core.BaseSorting`
objects, one per job, returned in the same order as the entries of :code:`job_list`.

:py:func:`~spikeinterface.sorters.run_sorter_jobs` has several "engines" available to launch the computation:

Expand All @@ -315,13 +314,11 @@ as a value.

.. code-block:: python
run_sorters(sorter_list, recordings, engine='loop')
run_sorter_jobs(job_list, engine='loop')
run_sorters(sorter_list, recordings, engine='joblib',
engine_kwargs={'n_jobs': 2})
run_sorter_jobs(job_list, engine='joblib', engine_kwargs={'n_jobs': 2})
run_sorters(sorter_list, recordings, engine='slurm',
engine_kwargs={'cpus_per_task': 10, 'mem', '5G'})
run_sorter_jobs(job_list, engine='slurm', engine_kwargs={'cpus_per_task': 10, 'mem': '5G'})
Spike sorting by group
Expand Down
35 changes: 34 additions & 1 deletion src/spikeinterface/comparison/studytools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,45 @@
from spikeinterface.core.job_tools import fix_job_kwargs
from spikeinterface.extractors import NpzSortingExtractor
from spikeinterface.sorters import sorter_dict
from spikeinterface.sorters.launcher import iter_working_folder, iter_sorting_output
from spikeinterface.sorters.basesorter import is_log_ok


from .comparisontools import _perf_keys
from .paircomparisons import compare_sorter_to_ground_truth


# This is deprecated and will be removed
def iter_working_folder(working_folder):
    """Yield ``(rec_name, sorter_name, output_folder)`` triplets found under *working_folder*.

    The folder layout is ``working_folder/<rec_folder>/<output_folder>``. Two
    cases are recognised:

    * the output folder contains a ``spikeinterface_job.json`` file — then the
      recording and sorter names are read from that file;
    * otherwise the folder names themselves are used, and the triplet is only
      yielded when the folder is a directory whose run log reports success
      (``is_log_ok``).
    """
    base = Path(working_folder)
    for rec_folder in base.iterdir():
        if not rec_folder.is_dir():
            continue
        for output_folder in rec_folder.iterdir():
            job_file = output_folder / "spikeinterface_job.json"
            if job_file.is_file():
                # job metadata file wins: names come from the JSON content
                with open(job_file, "r") as f:
                    job_dict = json.load(f)
                yield job_dict["rec_name"], job_dict["sorter_name"], output_folder
            elif output_folder.is_dir() and is_log_ok(output_folder):
                # legacy layout: infer names from the folder hierarchy
                yield rec_folder.name, output_folder.name, output_folder


# This is deprecated and will be removed
def iter_sorting_output(working_folder):
    """Iterate over *working_folder* and yield ``(rec_name, sorter_name, sorting)`` triplets.

    Thin wrapper over :func:`iter_working_folder` that loads the sorting result
    of each discovered output folder via the matching sorter class.
    """
    for rec_name, sorter_name, folder in iter_working_folder(working_folder):
        # resolve the sorter class by name and read its result back from disk
        yield rec_name, sorter_name, sorter_dict[sorter_name].get_result_from_folder(folder)


def setup_comparison_study(study_folder, gt_dict, **job_kwargs):
"""
Based on a dict of (recording, sorting) create the study folder.
Expand Down
9 changes: 1 addition & 8 deletions src/spikeinterface/sorters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
from .basesorter import BaseSorter
from .sorterlist import *
from .runsorter import *

from .launcher import (
run_sorters,
run_sorter_by_property,
collect_sorting_outputs,
iter_working_folder,
iter_sorting_output,
)
from .launcher import run_sorter_jobs, run_sorters, run_sorter_by_property
11 changes: 11 additions & 0 deletions src/spikeinterface/sorters/basesorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,14 @@ def get_job_kwargs(params, verbose):
if not verbose:
job_kwargs["progress_bar"] = False
return job_kwargs


def is_log_ok(output_folder):
    """Return True when the run log in *output_folder* records a completed run.

    The log is considered OK when ``spikeinterface_log.json`` exists and its
    ``run_time`` entry is not null (a null ``run_time`` marks a failed run).
    """
    log_file = output_folder / "spikeinterface_log.json"
    if not log_file.is_file():
        # no log at all -> the sorter never completed
        return False
    with open(log_file, mode="r", encoding="utf8") as logfile:
        log = json.load(logfile)
    return log.get("run_time") is not None
Loading

0 comments on commit 30a6613

Please sign in to comment.