diff --git a/doc/api.rst b/doc/api.rst index 122c88d01b..97c956c2f6 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -217,8 +217,10 @@ spikeinterface.sorters .. autofunction:: print_sorter_versions .. autofunction:: get_sorter_description .. autofunction:: run_sorter + .. autofunction:: run_sorter_jobs .. autofunction:: run_sorters .. autofunction:: run_sorter_by_property + .. autofunction:: read_sorter_folder Low level ~~~~~~~~~ diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index 1b27ed442c..f3c8e7b733 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -285,27 +285,26 @@ Running several sorters in parallel The :py:mod:`~spikeinterface.sorters` module also includes tools to run several spike sorting jobs sequentially or in parallel. This can be done with the -:py:func:`~spikeinterface.sorters.run_sorters()` function by specifying +:py:func:`~spikeinterface.sorters.run_sorter_jobs()` function by specifying an :code:`engine` that supports parallel processing (such as :code:`joblib` or :code:`slurm`). .. code-block:: python - recordings = {'rec1' : recording, 'rec2': another_recording} - sorter_list = ['herdingspikes', 'tridesclous'] - sorter_params = { - 'herdingspikes': {'clustering_bandwidth' : 8}, - 'tridesclous': {'detect_threshold' : 5.}, - } - sorting_output = run_sorters(sorter_list, recordings, working_folder='tmp_some_sorters', - mode_if_folder_exists='overwrite', sorter_params=sorter_params) + # here we run 2 sorters on 2 different recordings = 4 jobs + recording = ... + another_recording = ... + + job_list = [ + {'sorter_name': 'tridesclous', 'recording': recording, 'output_folder': 'folder1','detect_threshold': 5.}, + {'sorter_name': 'tridesclous', 'recording': another_recording, 'output_folder': 'folder2', 'detect_threshold': 5.}, + {'sorter_name': 'herdingspikes', 'recording': recording, 'output_folder': 'folder3', 'clustering_bandwidth': 8., 'docker_image': True}, + {'sorter_name': 'herdingspikes', 'recording': another_recording, 'output_folder': 'folder4', 'clustering_bandwidth': 8., 'docker_image': True}, + ] + + # run in loop + sortings = run_sorter_jobs(job_list, engine='loop') - # the output is a dict with (rec_name, sorter_name) as keys - for (rec_name, sorter_name), sorting in sorting_output.items(): - print(rec_name, sorter_name, ':', sorting.get_unit_ids()) -After the jobs are run, the :code:`sorting_outputs` is a dictionary with :code:`(rec_name, sorter_name)` as a key (e.g. -:code:`('rec1', 'tridesclous')` in this example), and the corresponding :py:class:`~spikeinterface.core.BaseSorting` -as a value. :py:func:`~spikeinterface.sorters.run_sorters` has several "engines" available to launch the computation: @@ -315,13 +314,11 @@ as a value. .. code-block:: python - run_sorters(sorter_list, recordings, engine='loop') + run_sorter_jobs(job_list, engine='loop') - run_sorters(sorter_list, recordings, engine='joblib', - engine_kwargs={'n_jobs': 2}) + run_sorter_jobs(job_list, engine='joblib', engine_kwargs={'n_jobs': 2}) - run_sorters(sorter_list, recordings, engine='slurm', - engine_kwargs={'cpus_per_task': 10, 'mem', '5G'}) + run_sorter_jobs(job_list, engine='slurm', engine_kwargs={'cpus_per_task': 10, 'mem', '5G'}) Spike sorting by group diff --git a/src/spikeinterface/comparison/studytools.py b/src/spikeinterface/comparison/studytools.py index 79227c865f..26d2c1ad6f 100644 --- a/src/spikeinterface/comparison/studytools.py +++ b/src/spikeinterface/comparison/studytools.py @@ -22,12 +22,45 @@ from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.extractors import NpzSortingExtractor from spikeinterface.sorters import sorter_dict -from spikeinterface.sorters.launcher import iter_working_folder, iter_sorting_output +from spikeinterface.sorters.basesorter import is_log_ok + from .comparisontools import _perf_keys from .paircomparisons import compare_sorter_to_ground_truth +# This is deprecated and will be removed +def iter_working_folder(working_folder): + working_folder = Path(working_folder) + for rec_folder in working_folder.iterdir(): + if not rec_folder.is_dir(): + continue + for output_folder in rec_folder.iterdir(): + if (output_folder / "spikeinterface_job.json").is_file(): + with open(output_folder / "spikeinterface_job.json", "r") as f: + job_dict = json.load(f) + rec_name = job_dict["rec_name"] + sorter_name = job_dict["sorter_name"] + yield rec_name, sorter_name, output_folder + else: + rec_name = rec_folder.name + sorter_name = output_folder.name + if not output_folder.is_dir(): + continue + if not is_log_ok(output_folder): + continue + yield rec_name, sorter_name, output_folder + + +# This is deprecated and will be removed +def iter_sorting_output(working_folder): + """Iterator over output_folder to retrieve all triplets of (rec_name, sorter_name, sorting).""" + for rec_name, sorter_name, output_folder in iter_working_folder(working_folder): + SorterClass = sorter_dict[sorter_name] + sorting = SorterClass.get_result_from_folder(output_folder) + yield rec_name, sorter_name, sorting + + def setup_comparison_study(study_folder, gt_dict, **job_kwargs): """ Based on a dict of (recording, sorting) create the study folder. diff --git a/src/spikeinterface/sorters/__init__.py b/src/spikeinterface/sorters/__init__.py index a0d437559d..ba663327e8 100644 --- a/src/spikeinterface/sorters/__init__.py +++ b/src/spikeinterface/sorters/__init__.py @@ -1,11 +1,4 @@ from .basesorter import BaseSorter from .sorterlist import * from .runsorter import * - -from .launcher import ( - run_sorters, - run_sorter_by_property, - collect_sorting_outputs, - iter_working_folder, - iter_sorting_output, -) +from .launcher import run_sorter_jobs, run_sorters, run_sorter_by_property diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index ff559cc78d..c7581ba1e1 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -411,3 +411,14 @@ def get_job_kwargs(params, verbose): if not verbose: job_kwargs["progress_bar"] = False return job_kwargs + + +def is_log_ok(output_folder): + # log is OK when run_time is not None + if (output_folder / "spikeinterface_log.json").is_file(): + with open(output_folder / "spikeinterface_log.json", mode="r", encoding="utf8") as logfile: + log = json.load(logfile) + run_time = log.get("run_time", None) + ok = run_time is not None + return ok + return False diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index 52098f45cd..f32a468a22 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -4,61 +4,193 @@ from pathlib import Path import shutil import numpy as np -import json import tempfile import os import stat import subprocess import sys +import warnings -from spikeinterface.core import load_extractor, aggregate_units -from spikeinterface.core.core_tools import check_json +from spikeinterface.core import aggregate_units from .sorterlist import sorter_dict -from .runsorter import run_sorter, run_sorter - - -def _run_one(arg_list): - # the multiprocessing python module force to have one unique tuple argument - ( - sorter_name, - recording, - output_folder, - verbose, - sorter_params, - docker_image, - singularity_image, - with_output, - ) = arg_list - - if isinstance(recording, dict): - recording = load_extractor(recording) +from .runsorter import run_sorter +from .basesorter import is_log_ok + +_default_engine_kwargs = dict( + loop=dict(), + joblib=dict(n_jobs=-1, backend="loky"), + processpoolexecutor=dict(max_workers=2, mp_context=None), + dask=dict(client=None), + slurm=dict(tmp_script_folder=None, cpus_per_task=1, mem="1G"), +) + + +_implemented_engine = list(_default_engine_kwargs.keys()) + + +def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=False): + """ + Run several :py:func:`run_sorter()` sequentially or in parallel given a list of jobs. + + For **engine="loop"** this is equivalent to: + + ..code:: + + for job in job_list: + run_sorter(**job) + + The following engines block the I/O: + * "loop" + * "joblib" + * "multiprocessing" + * "dask" + + The following engines are *asynchronous*: + * "slurm" + + Where *blocking* means that this function is blocking until the results are returned. + This is in opposition to *asynchronous*, where the function returns `None` almost immediately (aka non-blocking), + but the results must be retrieved by hand when jobs are finished. No mechanisim is provided here to be know + when jobs are finish. + In this *asynchronous* case, the :py:func:`~spikeinterface.sorters.read_sorter_folder()` helps to retrieve individual results. + + + Parameters + ---------- + job_list: list of dict + A list a dict that are propagated to run_sorter(...) + engine: str "loop", "joblib", "dask", "slurm" + The engine to run the list. + * "loop": a simple loop. This engine is + engine_kwargs: dict + + return_output: bool, dfault False + Return a sorting or None. + + Returns + ------- + sortings: None or list of sorting + With engine="loop" or "joblib" you can optional get directly the list of sorting result if return_output=True. + """ + + assert engine in _implemented_engine, f"engine must be in {_implemented_engine}" + + engine_kwargs_ = dict() + engine_kwargs_.update(_default_engine_kwargs[engine]) + engine_kwargs_.update(engine_kwargs) + engine_kwargs = engine_kwargs_ + + if return_output: + assert engine in ( + "loop", + "joblib", + "processpoolexecutor", + ), "Only 'loop', 'joblib', and 'processpoolexecutor' support return_output=True." + out = [] else: - recording = recording - - # because this is checks in run_sorters before this call - remove_existing_folder = False - # result is retrieve later - delete_output_folder = False - # because we won't want the loop/worker to break - raise_error = False - - run_sorter( - sorter_name, - recording, - output_folder=output_folder, - remove_existing_folder=remove_existing_folder, - delete_output_folder=delete_output_folder, - verbose=verbose, - raise_error=raise_error, - docker_image=docker_image, - singularity_image=singularity_image, - with_output=with_output, - **sorter_params, - ) + out = None + + if engine == "loop": + # simple loop in main process + for kwargs in job_list: + sorting = run_sorter(**kwargs) + if return_output: + out.append(sorting) + + elif engine == "joblib": + from joblib import Parallel, delayed + + n_jobs = engine_kwargs["n_jobs"] + backend = engine_kwargs["backend"] + sortings = Parallel(n_jobs=n_jobs, backend=backend)(delayed(run_sorter)(**kwargs) for kwargs in job_list) + if return_output: + out.extend(sortings) + + elif engine == "processpoolexecutor": + from concurrent.futures import ProcessPoolExecutor + + max_workers = engine_kwargs["max_workers"] + mp_context = engine_kwargs["mp_context"] + with ProcessPoolExecutor(max_workers=max_workers, mp_context=mp_context) as executor: + futures = [] + for kwargs in job_list: + res = executor.submit(run_sorter, **kwargs) + futures.append(res) + for futur in futures: + sorting = futur.result() + if return_output: + out.append(sorting) -_implemented_engine = ("loop", "joblib", "dask", "slurm") + elif engine == "dask": + client = engine_kwargs["client"] + assert client is not None, "For dask engine you have to provide : client = dask.distributed.Client(...)" + + tasks = [] + for kwargs in job_list: + task = client.submit(run_sorter, **kwargs) + tasks.append(task) + + for task in tasks: + task.result() + + elif engine == "slurm": + # generate python script for slurm + tmp_script_folder = engine_kwargs["tmp_script_folder"] + if tmp_script_folder is None: + tmp_script_folder = tempfile.mkdtemp(prefix="spikeinterface_slurm_") + tmp_script_folder = Path(tmp_script_folder) + cpus_per_task = engine_kwargs["cpus_per_task"] + mem = engine_kwargs["mem"] + + tmp_script_folder.mkdir(exist_ok=True, parents=True) + + for i, kwargs in enumerate(job_list): + script_name = tmp_script_folder / f"si_script_{i}.py" + with open(script_name, "w") as f: + kwargs_txt = "" + for k, v in kwargs.items(): + kwargs_txt += " " + if k == "recording": + # put None temporally + kwargs_txt += "recording=None" + else: + if isinstance(v, str): + kwargs_txt += f'{k}="{v}"' + elif isinstance(v, Path): + kwargs_txt += f'{k}="{str(v.absolute())}"' + else: + kwargs_txt += f"{k}={v}" + kwargs_txt += ",\n" + + # recording_dict = task_args[1] + recording_dict = kwargs["recording"].to_dict() + slurm_script = _slurm_script.format( + python=sys.executable, recording_dict=recording_dict, kwargs_txt=kwargs_txt + ) + f.write(slurm_script) + os.fchmod(f.fileno(), mode=stat.S_IRWXU) + + subprocess.Popen(["sbatch", str(script_name.absolute()), f"-cpus-per-task={cpus_per_task}", f"-mem={mem}"]) + + return out + + +_slurm_script = """#! {python} +from numpy import array +from spikeinterface import load_extractor +from spikeinterface.sorters import run_sorter + +rec_dict = {recording_dict} + +kwargs = dict( +{kwargs_txt} +) +kwargs['recording'] = load_extractor(rec_dict) + +run_sorter(**kwargs) +""" def run_sorter_by_property( @@ -66,7 +198,7 @@ def run_sorter_by_property( recording, grouping_property, working_folder, - mode_if_folder_exists="raise", + mode_if_folder_exists=None, engine="loop", engine_kwargs={}, verbose=False, @@ -93,11 +225,10 @@ def run_sorter_by_property( Property to split by before sorting working_folder: str The working directory. - mode_if_folder_exists: {'raise', 'overwrite', 'keep'} - The mode when the subfolder of recording/sorter already exists. - * 'raise' : raise error if subfolder exists - * 'overwrite' : delete and force recompute - * 'keep' : do not compute again if f=subfolder exists and log is OK + mode_if_folder_exists: None + Must be None. This is deprecated. + If not None then a warning is raise. + Will be removed in next release. engine: {'loop', 'joblib', 'dask'} Which engine to use to run sorter. engine_kwargs: dict @@ -127,46 +258,49 @@ def run_sorter_by_property( engine_kwargs={"n_jobs": 4}) """ + if mode_if_folder_exists is not None: + warnings.warn( + "run_sorter_by_property(): mode_if_folder_exists is not used anymore", + DeprecationWarning, + stacklevel=2, + ) + + working_folder = Path(working_folder).absolute() assert grouping_property in recording.get_property_keys(), ( f"The 'grouping_property' {grouping_property} is not " f"a recording property!" ) recording_dict = recording.split_by(grouping_property) - sorting_output = run_sorters( - [sorter_name], - recording_dict, - working_folder, - mode_if_folder_exists=mode_if_folder_exists, - engine=engine, - engine_kwargs=engine_kwargs, - verbose=verbose, - with_output=True, - docker_images={sorter_name: docker_image}, - singularity_images={sorter_name: singularity_image}, - sorter_params={sorter_name: sorter_params}, - ) - grouping_property_values = None - sorting_list = [] - for output_name, sorting in sorting_output.items(): - prop_name, sorter_name = output_name - sorting_list.append(sorting) - if grouping_property_values is None: - grouping_property_values = np.array( - [prop_name] * len(sorting.get_unit_ids()), dtype=np.dtype(type(prop_name)) - ) - else: - grouping_property_values = np.concatenate( - (grouping_property_values, [prop_name] * len(sorting.get_unit_ids())) - ) + job_list = [] + for k, rec in recording_dict.items(): + job = dict( + sorter_name=sorter_name, + recording=rec, + output_folder=working_folder / str(k), + verbose=verbose, + docker_image=docker_image, + singularity_image=singularity_image, + **sorter_params, + ) + job_list.append(job) + + sorting_list = run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=True) + + unit_groups = [] + for sorting, group in zip(sorting_list, recording_dict.keys()): + num_units = sorting.get_unit_ids().size + unit_groups.extend([group] * num_units) + unit_groups = np.array(unit_groups) aggregate_sorting = aggregate_units(sorting_list) - aggregate_sorting.set_property(key=grouping_property, values=grouping_property_values) + aggregate_sorting.set_property(key=grouping_property, values=unit_groups) aggregate_sorting.register_recording(recording) return aggregate_sorting +# This is deprecated and will be removed def run_sorters( sorter_list, recording_dict_or_list, @@ -180,7 +314,9 @@ def run_sorters( docker_images={}, singularity_images={}, ): - """Run several sorter on several recordings. + """ + This function is deprecated and will be removed in version 0.100 + Please use run_sorter_jobs() instead. Parameters ---------- @@ -221,6 +357,13 @@ def run_sorters( results : dict The output is nested dict[(rec_name, sorter_name)] of SortingExtractor. """ + + warnings.warn( + "run_sorters() is deprecated please use run_sorter_jobs() instead. This will be removed in 0.100", + DeprecationWarning, + stacklevel=2, + ) + working_folder = Path(working_folder) mode_if_folder_exists in ("raise", "keep", "overwrite") @@ -247,8 +390,7 @@ def run_sorters( dtype_rec_name = np.dtype(type(list(recording_dict.keys())[0])) assert dtype_rec_name.kind in ("i", "u", "S", "U"), "Dict keys can only be integers or strings!" - need_dump = engine != "loop" - task_args_list = [] + job_list = [] for rec_name, recording in recording_dict.items(): for sorter_name in sorter_list: output_folder = working_folder / str(rec_name) / sorter_name @@ -268,181 +410,21 @@ def run_sorters( params = sorter_params.get(sorter_name, {}) docker_image = docker_images.get(sorter_name, None) singularity_image = singularity_images.get(sorter_name, None) - _check_container_images(docker_image, singularity_image, sorter_name) - - if need_dump: - if not recording.check_if_dumpable(): - raise Exception("recording not dumpable call recording.save() before") - recording_arg = recording.to_dict(recursive=True) - else: - recording_arg = recording - - task_args = ( - sorter_name, - recording_arg, - output_folder, - verbose, - params, - docker_image, - singularity_image, - with_output, - ) - task_args_list.append(task_args) - if engine == "loop": - # simple loop in main process - for task_args in task_args_list: - _run_one(task_args) - - elif engine == "joblib": - from joblib import Parallel, delayed - - n_jobs = engine_kwargs.get("n_jobs", -1) - backend = engine_kwargs.get("backend", "loky") - Parallel(n_jobs=n_jobs, backend=backend)(delayed(_run_one)(task_args) for task_args in task_args_list) - - elif engine == "dask": - client = engine_kwargs.get("client", None) - assert client is not None, "For dask engine you have to provide : client = dask.distributed.Client(...)" - - tasks = [] - for task_args in task_args_list: - task = client.submit(_run_one, task_args) - tasks.append(task) - - for task in tasks: - task.result() - - elif engine == "slurm": - # generate python script for slurm - tmp_script_folder = engine_kwargs.get("tmp_script_folder", None) - if tmp_script_folder is None: - tmp_script_folder = tempfile.mkdtemp(prefix="spikeinterface_slurm_") - tmp_script_folder = Path(tmp_script_folder) - cpus_per_task = engine_kwargs.get("cpus_per_task", 1) - mem = engine_kwargs.get("mem", "1G") - - for i, task_args in enumerate(task_args_list): - script_name = tmp_script_folder / f"si_script_{i}.py" - with open(script_name, "w") as f: - arg_list_txt = "(\n" - for j, arg in enumerate(task_args): - arg_list_txt += "\t" - if j != 1: - if isinstance(arg, str): - arg_list_txt += f'"{arg}"' - elif isinstance(arg, Path): - arg_list_txt += f'"{str(arg.absolute())}"' - else: - arg_list_txt += f"{arg}" - else: - arg_list_txt += "recording" - arg_list_txt += ",\r" - arg_list_txt += ")" - - recording_dict = task_args[1] - slurm_script = _slurm_script.format( - python=sys.executable, recording_dict=recording_dict, arg_list_txt=arg_list_txt - ) - f.write(slurm_script) - os.fchmod(f.fileno(), mode=stat.S_IRWXU) - - print(slurm_script) - - subprocess.Popen(["sbatch", str(script_name.absolute()), f"-cpus-per-task={cpus_per_task}", f"-mem={mem}"]) + job = dict( + sorter_name=sorter_name, + recording=recording, + output_folder=output_folder, + verbose=verbose, + docker_image=docker_image, + singularity_image=singularity_image, + **params, + ) + job_list.append(job) - non_blocking_engine = ("loop", "joblib") - if engine in non_blocking_engine: - # dump spikeinterface_job.json - # only for non blocking engine - for rec_name, recording in recording_dict.items(): - for sorter_name in sorter_list: - output_folder = working_folder / str(rec_name) / sorter_name - with open(output_folder / "spikeinterface_job.json", "w") as f: - dump_dict = {"rec_name": rec_name, "sorter_name": sorter_name, "engine": engine} - if engine != "dask": - dump_dict.update({"engine_kwargs": engine_kwargs}) - json.dump(check_json(dump_dict), f) + sorting_list = run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=with_output) if with_output: - if engine not in non_blocking_engine: - print( - f'Warning!! With engine="{engine}" you cannot have directly output results\n' - "Use : run_sorters(..., with_output=False)\n" - "And then: results = collect_sorting_outputs(output_folders)" - ) - return - - results = collect_sorting_outputs(working_folder) + keys = [(rec_name, sorter_name) for rec_name in recording_dict for sorter_name in sorter_list] + results = dict(zip(keys, sorting_list)) return results - - -_slurm_script = """#! {python} -from numpy import array -from spikeinterface.sorters.launcher import _run_one - -recording = {recording_dict} - -arg_list = {arg_list_txt} - -_run_one(arg_list) -""" - - -def is_log_ok(output_folder): - # log is OK when run_time is not None - if (output_folder / "spikeinterface_log.json").is_file(): - with open(output_folder / "spikeinterface_log.json", mode="r", encoding="utf8") as logfile: - log = json.load(logfile) - run_time = log.get("run_time", None) - ok = run_time is not None - return ok - return False - - -def iter_working_folder(working_folder): - working_folder = Path(working_folder) - for rec_folder in working_folder.iterdir(): - if not rec_folder.is_dir(): - continue - for output_folder in rec_folder.iterdir(): - if (output_folder / "spikeinterface_job.json").is_file(): - with open(output_folder / "spikeinterface_job.json", "r") as f: - job_dict = json.load(f) - rec_name = job_dict["rec_name"] - sorter_name = job_dict["sorter_name"] - yield rec_name, sorter_name, output_folder - else: - rec_name = rec_folder.name - sorter_name = output_folder.name - if not output_folder.is_dir(): - continue - if not is_log_ok(output_folder): - continue - yield rec_name, sorter_name, output_folder - - -def iter_sorting_output(working_folder): - """Iterator over output_folder to retrieve all triplets of (rec_name, sorter_name, sorting).""" - for rec_name, sorter_name, output_folder in iter_working_folder(working_folder): - SorterClass = sorter_dict[sorter_name] - sorting = SorterClass.get_result_from_folder(output_folder) - yield rec_name, sorter_name, sorting - - -def collect_sorting_outputs(working_folder): - """Collect results in a working_folder. - - The output is a dict with double key access results[(rec_name, sorter_name)] of SortingExtractor. - """ - results = {} - for rec_name, sorter_name, sorting in iter_sorting_output(working_folder): - results[(rec_name, sorter_name)] = sorting - return results - - -def _check_container_images(docker_image, singularity_image, sorter_name): - if docker_image is not None: - assert singularity_image is None, f"Provide either a docker or a singularity image " f"for sorter {sorter_name}" - if singularity_image is not None: - assert docker_image is None, f"Provide either a docker or a singularity image " f"for sorter {sorter_name}" diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py index cd8bc0fa5d..14c938f8ba 100644 --- a/src/spikeinterface/sorters/tests/test_launcher.py +++ b/src/spikeinterface/sorters/tests/test_launcher.py @@ -1,4 +1,5 @@ import os +import sys import shutil import time @@ -6,8 +7,10 @@ from pathlib import Path from spikeinterface.core import load_extractor -from spikeinterface.extractors import toy_example -from spikeinterface.sorters import run_sorters, run_sorter_by_property, collect_sorting_outputs + +# from spikeinterface.extractors import toy_example +from spikeinterface import generate_ground_truth_recording +from spikeinterface.sorters import run_sorter_jobs, run_sorters, run_sorter_by_property if hasattr(pytest, "global_test_folder"): @@ -15,10 +18,17 @@ else: cache_folder = Path("cache_folder") / "sorters" +base_output = cache_folder / "sorter_output" + +# no need to have many +num_recordings = 2 +sorters = ["tridesclous2"] + def setup_module(): - rec, _ = toy_example(num_channels=8, duration=30, seed=0, num_segments=1) - for i in range(4): + base_seed = 42 + for i in range(num_recordings): + rec, _ = generate_ground_truth_recording(num_channels=8, durations=[10.0], seed=base_seed + i) rec_folder = cache_folder / f"toy_rec_{i}" if rec_folder.is_dir(): shutil.rmtree(rec_folder) @@ -31,19 +41,106 @@ def setup_module(): rec.save(folder=rec_folder) -def test_run_sorters_with_list(): - working_folder = cache_folder / "test_run_sorters_list" +def get_job_list(): + jobs = [] + for i in range(num_recordings): + for sorter_name in sorters: + recording = load_extractor(cache_folder / f"toy_rec_{i}") + kwargs = dict( + sorter_name=sorter_name, + recording=recording, + output_folder=base_output / f"{sorter_name}_rec{i}", + verbose=True, + raise_error=False, + ) + jobs.append(kwargs) + + return jobs + + +@pytest.fixture(scope="module") +def job_list(): + return get_job_list() + + +def test_run_sorter_jobs_loop(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + sortings = run_sorter_jobs(job_list, engine="loop", return_output=True) + print(sortings) + + +def test_run_sorter_jobs_joblib(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + sortings = run_sorter_jobs( + job_list, engine="joblib", engine_kwargs=dict(n_jobs=2, backend="loky"), return_output=True + ) + print(sortings) + + +def test_run_sorter_jobs_processpoolexecutor(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + sortings = run_sorter_jobs( + job_list, engine="processpoolexecutor", engine_kwargs=dict(max_workers=2), return_output=True + ) + print(sortings) + + +@pytest.mark.skipif(True, reason="This is tested locally") +def test_run_sorter_jobs_dask(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + + # create a dask Client for a slurm queue + from dask.distributed import Client + + test_mode = "local" + # test_mode = "client_slurm" + + if test_mode == "local": + client = Client() + elif test_mode == "client_slurm": + from dask_jobqueue import SLURMCluster + + cluster = SLURMCluster( + processes=1, + cores=1, + memory="12GB", + python=sys.executable, + walltime="12:00:00", + ) + cluster.scale(2) + client = Client(cluster) + + # dask + t0 = time.perf_counter() + run_sorter_jobs(job_list, engine="dask", engine_kwargs=dict(client=client)) + t1 = time.perf_counter() + print(t1 - t0) + + +@pytest.mark.skip("Slurm launcher need a machine with slurm") +def test_run_sorter_jobs_slurm(job_list): + if base_output.is_dir(): + shutil.rmtree(base_output) + + working_folder = cache_folder / "test_run_sorters_slurm" if working_folder.is_dir(): shutil.rmtree(working_folder) - # make dumpable - rec0 = load_extractor(cache_folder / "toy_rec_0") - rec1 = load_extractor(cache_folder / "toy_rec_1") - - recording_list = [rec0, rec1] - sorter_list = ["tridesclous"] + tmp_script_folder = working_folder / "slurm_scripts" - run_sorters(sorter_list, recording_list, working_folder, engine="loop", verbose=False, with_output=False) + run_sorter_jobs( + job_list, + engine="slurm", + engine_kwargs=dict( + tmp_script_folder=tmp_script_folder, + cpus_per_task=32, + mem="32G", + ), + ) def test_run_sorter_by_property(): @@ -59,7 +156,7 @@ def test_run_sorter_by_property(): rec0_by = rec0.split_by("group") group_names0 = list(rec0_by.keys()) - sorter_name = "tridesclous" + sorter_name = "tridesclous2" sorting0 = run_sorter_by_property(sorter_name, rec0, "group", working_folder1, engine="loop", verbose=False) assert "group" in sorting0.get_property_keys() assert all([g in group_names0 for g in sorting0.get_property("group")]) @@ -68,12 +165,31 @@ def test_run_sorter_by_property(): rec1_by = rec1.split_by("group") group_names1 = list(rec1_by.keys()) - sorter_name = "tridesclous" + sorter_name = "tridesclous2" sorting1 = run_sorter_by_property(sorter_name, rec1, "group", working_folder2, engine="loop", verbose=False) assert "group" in sorting1.get_property_keys() assert all([g in group_names1 for g in sorting1.get_property("group")]) +# run_sorters is deprecated +# This will test will be removed in next release +def test_run_sorters_with_list(): + working_folder = cache_folder / "test_run_sorters_list" + if working_folder.is_dir(): + shutil.rmtree(working_folder) + + # make dumpable + rec0 = load_extractor(cache_folder / "toy_rec_0") + rec1 = load_extractor(cache_folder / "toy_rec_1") + + recording_list = [rec0, rec1] + sorter_list = ["tridesclous2"] + + run_sorters(sorter_list, recording_list, working_folder, engine="loop", verbose=False, with_output=False) + + +# run_sorters is deprecated +# This will test will be removed in next release def test_run_sorters_with_dict(): working_folder = cache_folder / "test_run_sorters_dict" if working_folder.is_dir(): @@ -84,9 +200,9 @@ def test_run_sorters_with_dict(): recording_dict = {"toy_tetrode": rec0, "toy_octotrode": rec1} - sorter_list = ["tridesclous", "tridesclous2"] + sorter_list = ["tridesclous2"] - sorter_params = {"tridesclous": dict(detect_threshold=5.6), "tridesclous2": dict()} + sorter_params = {"tridesclous2": dict()} # simple loop t0 = time.perf_counter() @@ -116,143 +232,19 @@ def test_run_sorters_with_dict(): ) -@pytest.mark.skipif(True, reason="This is tested locally") -def test_run_sorters_joblib(): - working_folder = cache_folder / "test_run_sorters_joblib" - if working_folder.is_dir(): - shutil.rmtree(working_folder) - - recording_dict = {} - for i in range(4): - rec = load_extractor(cache_folder / f"toy_rec_{i}") - recording_dict[f"rec_{i}"] = rec - - sorter_list = [ - "tridesclous", - ] - - # joblib - t0 = time.perf_counter() - run_sorters( - sorter_list, - recording_dict, - working_folder / "with_joblib", - engine="joblib", - engine_kwargs={"n_jobs": 4}, - with_output=False, - mode_if_folder_exists="keep", - ) - t1 = time.perf_counter() - print(t1 - t0) - - -@pytest.mark.skipif(True, reason="This is tested locally") -def test_run_sorters_dask(): - working_folder = cache_folder / "test_run_sorters_dask" - if working_folder.is_dir(): - shutil.rmtree(working_folder) - - recording_dict = {} - for i in range(4): - rec = load_extractor(cache_folder / f"toy_rec_{i}") - recording_dict[f"rec_{i}"] = rec - - sorter_list = [ - "tridesclous", - ] - - # create a dask Client for a slurm queue - from dask.distributed import Client - from dask_jobqueue import SLURMCluster - - python = "/home/samuel.garcia/.virtualenvs/py36/bin/python3.6" - cluster = SLURMCluster( - processes=1, - cores=1, - memory="12GB", - python=python, - walltime="12:00:00", - ) - cluster.scale(5) - client = Client(cluster) - - # dask - t0 = time.perf_counter() - run_sorters( - sorter_list, - recording_dict, - working_folder, - engine="dask", - engine_kwargs={"client": client}, - with_output=False, - mode_if_folder_exists="keep", - ) - t1 = time.perf_counter() - print(t1 - t0) - - -@pytest.mark.skipif(True, reason="This is tested locally") -def test_run_sorters_slurm(): - working_folder = cache_folder / "test_run_sorters_slurm" - if working_folder.is_dir(): - shutil.rmtree(working_folder) - - # create recording - recording_dict = {} - for i in range(4): - rec = load_extractor(cache_folder / f"toy_rec_{i}") - recording_dict[f"rec_{i}"] = rec - - sorter_list = [ - "spykingcircus2", - "tridesclous2", - ] - - tmp_script_folder = working_folder / "slurm_scripts" - tmp_script_folder.mkdir(parents=True) - - run_sorters( - sorter_list, - recording_dict, - working_folder, - engine="slurm", - engine_kwargs={ - "tmp_script_folder": tmp_script_folder, - "cpus_per_task": 32, - "mem": "32G", - }, - with_output=False, - mode_if_folder_exists="keep", - verbose=True, - ) - - -def test_collect_sorting_outputs(): - working_folder = cache_folder / "test_run_sorters_dict" - results = collect_sorting_outputs(working_folder) - print(results) - - -def test_sorter_installation(): - # This import is to get error on github when import fails - import tridesclous - - # import circus - - if __name__ == "__main__": - setup_module() - # pass - # test_run_sorters_with_list() - - # test_run_sorter_by_property() + # setup_module() + job_list = get_job_list() - test_run_sorters_with_dict() + # test_run_sorter_jobs_loop(job_list) + # test_run_sorter_jobs_joblib(job_list) + # test_run_sorter_jobs_processpoolexecutor(job_list) + # test_run_sorter_jobs_multiprocessing(job_list) + # test_run_sorter_jobs_dask(job_list) + test_run_sorter_jobs_slurm(job_list) - # test_run_sorters_joblib() - - # test_run_sorters_dask() - - # test_run_sorters_slurm() + # test_run_sorter_by_property() - # test_collect_sorting_outputs() + # this deprecated + # test_run_sorters_with_list() + # test_run_sorters_with_dict()