update to run_sorter_jobs() and slurm #3105

Open · wants to merge 30 commits into base: main

Commits (30)
efa0751  added option to pass extra arguments to `sbatch` when using `run_sort…` (Jun 29, 2024)
4bdf244  added option to pass extra arguments to `sbatch` when using `run_sort…` (Jun 29, 2024)
2d0a83b  Merge remote-tracking branch 'origin/main' (Jun 29, 2024)
ce87e68  mistake in Popen (Jun 29, 2024)
ebabfec  Merge remote-tracking branch 'origin/main' (Jun 29, 2024)
4e59f23  cleaned up code (MarinManuel, Jun 29, 2024)
407c488  Merge branch 'SpikeInterface:main' into slurm_updates (MarinManuel, Jul 2, 2024)
d204642  updated to use sbatch_kwargs instead of putting all slurm arguments d… (Jul 2, 2024)
8fedd95  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 2, 2024)
0733589  removed sbatch_executable from the list of kwargs (Jul 3, 2024)
54a7b8f  Merge remote-tracking branch 'origin/slurm_updates' into slurm_updates (Jul 3, 2024)
9c3ff1d  removed sbatch_executable from the list of kwargs (Jul 3, 2024)
c86f3b2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 3, 2024)
97198de  Merge branch 'main' into slurm_updates (MarinManuel, Aug 16, 2024)
f12461f  clarified docstring and added error for cpus_per_taks (MarinManuel, Aug 16, 2024)
002e959  added test (MarinManuel, Aug 16, 2024)
88ca2f1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 16, 2024)
e4f9f1f  added test (MarinManuel, Aug 16, 2024)
e4b0b81  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 16, 2024)
492c4d0  Merge branch 'SpikeInterface:main' into slurm_updates (MarinManuel, Sep 11, 2024)
5b6d560  docstring fix (MarinManuel, Sep 11, 2024)
35a2a7e  docstring fix (MarinManuel, Sep 11, 2024)
910fa61  docstring fix (MarinManuel, Sep 11, 2024)
6ba8423  added slurm_kwargs argument (MarinManuel, Sep 11, 2024)
0ec9af5  fixed test (MarinManuel, Sep 11, 2024)
e707170  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 11, 2024)
e7a89c9  fixed test failing on Windows by limiting to slurm test to Linux (MarinManuel, Sep 12, 2024)
48aee1b  fixed test failing on Windows by limiting to slurm test to Linux (MarinManuel, Sep 12, 2024)
6bc8be2  reverted slurm_kwargs and improved docstring (MarinManuel, Sep 12, 2024)
d2ac504  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 12, 2024)
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -127,6 +127,9 @@ test_core = [
# for release we need pypi, so this need to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",

# for slurm jobs,
"pytest-mock"
]

test_extractors = [
@@ -177,6 +180,9 @@ test = [
# for release we need pypi, so this need to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",

# for slurm jobs
"pytest-mock",
]

docs = [
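For orientation, `pytest-mock` supplies the `mocker` fixture used by the new launcher test further down. A minimal sketch of the pattern, with a hypothetical helper standing in for the launcher code (not part of this PR):

import subprocess

def launch_job():
    # hypothetical helper that shells out, standing in for the launcher
    subprocess.run(["sbatch", "script.py"], capture_output=True, text=True)

def test_calls_sbatch(mocker):
    # pytest-mock patches subprocess.run for this test only; no real process is spawned
    mock_run = mocker.patch("subprocess.run")
    launch_job()
    mock_run.assert_called_once_with(["sbatch", "script.py"], capture_output=True, text=True)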
82 changes: 53 additions & 29 deletions src/spikeinterface/sorters/launcher.py
@@ -4,36 +4,29 @@

from __future__ import annotations


from pathlib import Path
import shutil
import numpy as np
import tempfile
import os
import stat
import subprocess
import sys
import tempfile
import warnings

import numpy as np
from pathlib import Path
from spikeinterface.core import aggregate_units

from .sorterlist import sorter_dict
from .runsorter import run_sorter
from .basesorter import is_log_ok

_default_engine_kwargs = dict(
loop=dict(),
joblib=dict(n_jobs=-1, backend="loky"),
processpoolexecutor=dict(max_workers=2, mp_context=None),
dask=dict(client=None),
slurm=dict(tmp_script_folder=None, cpus_per_task=1, mem="1G"),
slurm={"tmp_script_folder": None, "sbatch_args": {"cpus-per-task": 1, "mem": "1G"}},
)


_implemented_engine = list(_default_engine_kwargs.keys())


def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=False):
def run_sorter_jobs(job_list, engine="loop", engine_kwargs=None, return_output=False):
"""
Run several :py:func:`run_sorter()` sequentially or in parallel given a list of jobs.

@@ -55,23 +48,43 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal

Where *blocking* means that this function is blocking until the results are returned.
This is in opposition to *asynchronous*, where the function returns `None` almost immediately (aka non-blocking),
but the results must be retrieved by hand when jobs are finished. No mechanisim is provided here to be know
when jobs are finish.
but the results must be retrieved by hand when jobs are finished. No mechanism is provided here to know
when jobs are finished.
In this *asynchronous* case, the :py:func:`~spikeinterface.sorters.read_sorter_folder()` helps to retrieve individual results.


Parameters
----------
job_list : list of dict
A list of dicts that are propagated to run_sorter(...)
engine : str "loop", "joblib", "dask", "slurm"
The engine to run the list.
* "loop" : a simple loop. This engine is
engine_kwargs : dict

return_output : bool, dfault False
Parameters to be passed to the underlying engine.
* loop : None
* joblib :
- n_jobs : int
The maximum number of concurrently running jobs (default=-1, tries to use all CPUs)
- backend : str
Specify the parallelization backend implementation (default="loky")
* multiprocessing :
- max_workers : int
maximum number of processes (default=2)
- mp_context : str
multiprocessing context (default=None)
* dask :
- client : dask.distributed.Client
Dask client to connect to (required)
* slurm :
- tmp_script_folder : str | Path
the folder in which the job scripts are created (default=None, creates a random temporary directory)
- sbatch_args : dict
dictionary of arguments to be passed to the sbatch command. They will be automatically prefixed with --.
Arguments must be in the format SLURM specifies; see the [documentation for `sbatch`](https://slurm.schedmd.com/sbatch.html)
for a list of possible arguments (default={"cpus-per-task": 1, "mem": "1G"})

return_output : bool, default: False
Return the sortings or None.
This also overwrite kwargs in in run_sorter(with_sorting=True/False)
This also overwrites kwargs in run_sorter(with_sorting=True/False)

Returns
-------
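Taken together, a minimal usage sketch of the new slurm engine (illustrative only: the sorter name and folders are placeholders, `recording_0`/`recording_1` are assumed to be already-loaded recording extractors, and the job dict keys follow `run_sorter()`):

from spikeinterface.sorters import run_sorter_jobs

# each dict is forwarded to run_sorter()
job_list = [
    dict(sorter_name="tridesclous2", recording=recording_0, folder="sorting_0"),
    dict(sorter_name="tridesclous2", recording=recording_1, folder="sorting_1"),
]

run_sorter_jobs(
    job_list,
    engine="slurm",
    engine_kwargs={
        "tmp_script_folder": "slurm_scripts",
        "sbatch_args": {"cpus-per-task": 4, "mem": "16G"},
    },
)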
@@ -81,6 +94,8 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal

assert engine in _implemented_engine, f"engine must be in {_implemented_engine}"

if engine_kwargs is None:
engine_kwargs = dict()
engine_kwargs_ = dict()
engine_kwargs_.update(_default_engine_kwargs[engine])
engine_kwargs_.update(engine_kwargs)
@@ -145,14 +160,16 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
task.result()

elif engine == "slurm":
if "cpus_per_task" in engine_kwargs:
raise ValueError(
"keyword argument cpus_per_task is no longer supported for slurm engine, "
"please use cpus-per-task instead."
)
# generate python script for slurm
tmp_script_folder = engine_kwargs["tmp_script_folder"]
if tmp_script_folder is None:
tmp_script_folder = tempfile.mkdtemp(prefix="spikeinterface_slurm_")
tmp_script_folder = Path(tmp_script_folder)
cpus_per_task = engine_kwargs["cpus_per_task"]
mem = engine_kwargs["mem"]

tmp_script_folder.mkdir(exist_ok=True, parents=True)

for i, kwargs in enumerate(job_list):
@@ -181,7 +198,16 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
f.write(slurm_script)
os.fchmod(f.fileno(), mode=stat.S_IRWXU)

subprocess.Popen(["sbatch", str(script_name.absolute()), f"-cpus-per-task={cpus_per_task}", f"-mem={mem}"])
progr = ["sbatch"]
for k, v in engine_kwargs["sbatch_args"].items():
progr.append(f"--{k}")
progr.append(f"{v}")
progr.append(str(script_name.absolute()))
print(f"subprocess called with command {' '.join(progr)}")
p = subprocess.run(progr, capture_output=True, text=True)
print(p.stdout)
if len(p.stderr) > 0:
warnings.warn(p.stderr)

return out
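For reference, the loop above flattens `sbatch_args` into `--key value` pairs placed before the script path. A sketch of the migration from the old keyword arguments, and of the command the launcher would roughly build (values illustrative):

# old style (now raises ValueError):
#     engine_kwargs = {"cpus_per_task": 4, "mem": "16G"}
# new style:
engine_kwargs = {"sbatch_args": {"cpus-per-task": 4, "mem": "16G"}}

# per job, the launcher then invokes roughly:
#     sbatch --cpus-per-task 4 --mem 16G /path/to/si_script_0.py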

Expand Down Expand Up @@ -209,7 +235,7 @@ def run_sorter_by_property(
folder,
mode_if_folder_exists=None,
engine="loop",
engine_kwargs={},
engine_kwargs=None,
verbose=False,
docker_image=None,
singularity_image=None,
@@ -239,13 +265,11 @@ def run_sorter_by_property(
Must be None. This is deprecated.
If not None then a warning is raised.
Will be removed in next release.
engine : "loop" | "joblib" | "dask", default: "loop"
engine : "loop" | "joblib" | "dask" | "slurm", default: "loop"
Which engine to use to run sorter.
engine_kwargs : dict
This contains kwargs specific to the launcher engine:
* "loop" : no kwargs
* "joblib" : {"n_jobs" : } number of processes
* "dask" : {"client":} the dask client for submitting task
This contains kwargs specific to the launcher engine.
See the documentation for :py:func:`~spikeinterface.sorters.launcher.run_sorter_jobs()` for more details.
verbose : bool, default: False
Controls sorter verboseness
docker_image : None or str, default: None
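Since `run_sorter_by_property()` forwards `engine_kwargs` to `run_sorter_jobs()`, the same slurm options apply there too. A sketch, assuming `recording` is a loaded extractor carrying a "group" property and using an illustrative sorter name:

from spikeinterface.sorters import run_sorter_by_property

sorting = run_sorter_by_property(
    sorter_name="tridesclous2",
    recording=recording,
    grouping_property="group",
    folder="sorting_by_group",
    engine="slurm",
    engine_kwargs={"sbatch_args": {"cpus-per-task": 2, "mem": "4G"}},
)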
75 changes: 73 additions & 2 deletions src/spikeinterface/sorters/tests/test_launcher.py
@@ -1,10 +1,10 @@
import sys
import shutil
import tempfile
import time

import pytest
from pathlib import Path

from platform import system
from spikeinterface import generate_ground_truth_recording
from spikeinterface.sorters import run_sorter_jobs, run_sorter_by_property

@@ -126,6 +126,77 @@ def test_run_sorter_jobs_slurm(job_list, create_cache_folder):
)


@pytest.mark.skipif(system() != "Linux", reason="Assumes we are on Linux to run SLURM")
def test_run_sorter_jobs_slurm_kwargs(mocker, tmp_path, job_list):
"""
Mock `subprocess.run()` to check that engine_kwargs are
propagated to the call as expected.
"""
# First, mock `subprocess.run()`, set up a call to `run_sorter_jobs`
# then check the mocked `subprocess.run()` was called with the
# expected signature. Two jobs are passed in `jobs_list`, first
# check the most recent call.
mock_subprocess_run = mocker.patch("spikeinterface.sorters.launcher.subprocess.run")

tmp_script_folder = tmp_path / "slurm_scripts"

engine_kwargs = dict(
tmp_script_folder=tmp_script_folder,
sbatch_args={
"cpus-per-task": 32,
"mem": "32G",
"gres": "gpu:1",
"any_random_kwarg": 12322,
},
)

run_sorter_jobs(job_list, engine="slurm", engine_kwargs=engine_kwargs)

script_0_path = f"{tmp_script_folder}/si_script_0.py"
script_1_path = f"{tmp_script_folder}/si_script_1.py"

expected_command = [
"sbatch",
"--cpus-per-task",
"32",
"--mem",
"32G",
"--gres",
"gpu:1",
"--any_random_kwarg",
"12322",
script_1_path,
]
mock_subprocess_run.assert_called_with(expected_command, capture_output=True, text=True)

# Next, check the first call (which sets up `si_script_0.py`)
# also has the expected arguments.
expected_command[9] = script_0_path
assert mock_subprocess_run.call_args_list[0].args[0] == expected_command

# Next, check that defaults are used properly when no kwargs are
# passed. This will default to `_default_engine_kwargs` as
# set in `launcher.py`
run_sorter_jobs(
job_list,
engine="slurm",
engine_kwargs={"tmp_script_folder": tmp_script_folder},
)
expected_command = ["sbatch", "--cpus-per-task", "1", "--mem", "1G", script_1_path]
mock_subprocess_run.assert_called_with(expected_command, capture_output=True, text=True)

# Finally, check that the `tmp_script_folder` is generated on the
# fly as expected. A random foldername is generated, just check that
# the folder to which the scripts are saved is in the `tempfile` format.
run_sorter_jobs(
job_list,
engine="slurm",
engine_kwargs=None,
)
tmp_script_folder = "_".join(tempfile.mkdtemp(prefix="spikeinterface_slurm_").split("_")[:-1])
assert tmp_script_folder in mock_subprocess_run.call_args_list[-1].args[0][5]


def test_run_sorter_by_property(create_cache_folder):
cache_folder = create_cache_folder
working_folder1 = cache_folder / "test_run_sorter_by_property_1"
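On a Linux machine with the test extras installed, the new mocked test can be run with, for example, `pytest src/spikeinterface/sorters/tests/test_launcher.py -k slurm_kwargs` (path as in this diff; `pytest-mock` provides the `mocker` fixture).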